1use tokio::sync::mpsc;
28use tokio::sync::watch;
29use tokio::task::JoinHandle;
30
31use std::collections::HashMap;
32
33use std::future::Future;
34use std::sync::atomic::AtomicU64;
35use std::sync::atomic::Ordering;
36use std::sync::LazyLock;
37
38enum ActiveTaskOp {
39 Add { id: u64, handle: JoinHandle<()> },
40 Remove { id: u64 },
41}
42
43struct RemoveOnDrop {
46 id: u64,
47 task_tx_weak: mpsc::WeakUnboundedSender<ActiveTaskOp>,
48}
49impl Drop for RemoveOnDrop {
50 fn drop(&mut self) {
51 if let Some(tx) = self.task_tx_weak.upgrade() {
52 let _ = tx.send(ActiveTaskOp::Remove { id: self.id });
53 }
54 }
55}
56
57struct TaskKillswitch {
62 task_tx: parking_lot::RwLock<Option<mpsc::UnboundedSender<ActiveTaskOp>>>,
65 task_counter: AtomicU64,
66 all_killed: watch::Receiver<()>,
67}
68
69impl TaskKillswitch {
70 fn new() -> Self {
71 let (task_tx, task_rx) = mpsc::unbounded_channel();
72 let (signal_killed, all_killed) = watch::channel(());
73
74 let active_tasks = ActiveTasks {
75 task_rx,
76 tasks: Default::default(),
77 signal_killed,
78 };
79 tokio::spawn(active_tasks.collect());
80
81 Self {
82 task_tx: parking_lot::RwLock::new(Some(task_tx)),
83 task_counter: Default::default(),
84 all_killed,
85 }
86 }
87
88 fn spawn_task(&self, fut: impl Future<Output = ()> + Send + 'static) {
89 let Some(task_tx) = self.task_tx.read().as_ref().cloned() else {
93 return;
94 };
95
96 let id = self.task_counter.fetch_add(1, Ordering::SeqCst);
97 let task_tx_weak = task_tx.downgrade();
98
99 let handle = tokio::spawn(async move {
100 let _guard = RemoveOnDrop { task_tx_weak, id };
104 fut.await;
105 });
106
107 let _ = task_tx.send(ActiveTaskOp::Add { id, handle });
108 }
109
110 fn activate(&self) {
111 assert!(
116 self.task_tx.write().take().is_some(),
117 "killswitch can't be used twice"
118 );
119 }
120
121 fn killed(&self) -> impl Future<Output = ()> + Send + 'static {
122 let mut signal = self.all_killed.clone();
123 async move {
124 let _ = signal.changed().await;
125 }
126 }
127}
128
129struct ActiveTasks {
130 task_rx: mpsc::UnboundedReceiver<ActiveTaskOp>,
131 tasks: HashMap<u64, JoinHandle<()>>,
132 signal_killed: watch::Sender<()>,
133}
134
135impl ActiveTasks {
136 async fn collect(mut self) {
137 while let Some(op) = self.task_rx.recv().await {
138 self.handle_task_op(op);
139 }
140
141 for task in self.tasks.into_values() {
142 task.abort();
143 }
144 drop(self.signal_killed);
145 }
146
147 fn handle_task_op(&mut self, op: ActiveTaskOp) {
148 match op {
149 ActiveTaskOp::Add { id, handle } => {
150 self.tasks.insert(id, handle);
151 },
152 ActiveTaskOp::Remove { id } => {
153 self.tasks.remove(&id);
154 },
155 }
156 }
157}
158
159static TASK_KILLSWITCH: LazyLock<TaskKillswitch> =
161 LazyLock::new(TaskKillswitch::new);
162
163#[inline]
168pub fn spawn_with_killswitch(fut: impl Future<Output = ()> + Send + 'static) {
169 TASK_KILLSWITCH.spawn_task(fut);
170}
171
172#[deprecated = "activate() was unnecessarily declared async. Use activate_now() instead."]
173pub async fn activate() {
174 TASK_KILLSWITCH.activate()
175}
176
177#[inline]
183pub fn activate_now() {
184 TASK_KILLSWITCH.activate();
185}
186
187#[inline]
194pub fn killed_signal() -> impl Future<Output = ()> + Send + 'static {
195 TASK_KILLSWITCH.killed()
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::time::Duration;
    use tokio::sync::oneshot;

    /// Number of long-sleeping tasks each test spawns.
    const TASK_COUNT: usize = 1000;

    /// Notifies its paired receiver when dropped, which happens when the task
    /// owning it finishes or is aborted.
    struct TaskAbortSignal(Option<oneshot::Sender<()>>);

    impl TaskAbortSignal {
        fn new() -> (Self, oneshot::Receiver<()>) {
            let (sender, receiver) = oneshot::channel();
            (Self(Some(sender)), receiver)
        }
    }

    impl Drop for TaskAbortSignal {
        fn drop(&mut self) {
            // `drop` runs exactly once, so the sender is always still present.
            let sender = self.0.take().unwrap();
            let _ = sender.send(());
        }
    }

    /// Spawns `TASK_COUNT` tasks that would each sleep for an hour. Returns
    /// receivers that fire when the corresponding task is torn down.
    fn start_test_tasks(killswitch: &TaskKillswitch) -> Vec<oneshot::Receiver<()>> {
        let mut receivers = Vec::with_capacity(TASK_COUNT);
        for _ in 0..TASK_COUNT {
            let (signal, receiver) = TaskAbortSignal::new();
            killswitch.spawn_task(async move {
                tokio::time::sleep(Duration::from_secs(3600)).await;
                drop(signal);
            });
            receivers.push(receiver);
        }
        receivers
    }

    /// Waits up to one second for every receiver to fire.
    async fn expect_all_killed(receivers: Vec<oneshot::Receiver<()>>) {
        tokio::time::timeout(Duration::from_secs(1), future::join_all(receivers))
            .await
            .expect("tasks should be killed within given timeframe");
    }

    #[tokio::test]
    async fn activate_killswitch_early() {
        let killswitch = TaskKillswitch::new();
        let receivers = start_test_tasks(&killswitch);

        killswitch.activate();
        expect_all_killed(receivers).await;
    }

    #[tokio::test]
    async fn activate_killswitch_with_delay() {
        let killswitch = TaskKillswitch::new();
        let receivers = start_test_tasks(&killswitch);
        let killed = tokio::spawn(killswitch.killed());

        tokio::time::sleep(Duration::from_millis(200)).await;

        // Nothing has been activated yet, so the signal must still be pending.
        assert!(!killed.is_finished());
        killswitch.activate();

        expect_all_killed(receivers).await;

        tokio::time::timeout(Duration::from_secs(1), killed)
            .await
            .expect("killed() signal should have resolved")
            .expect("signal task should join successfully");
    }
}