task_killswitch/
lib.rs

1// Copyright (C) 2025, Cloudflare, Inc.
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without
5// modification, are permitted provided that the following conditions are
6// met:
7//
8//     * Redistributions of source code must retain the above copyright notice,
9//       this list of conditions and the following disclaimer.
10//
11//     * Redistributions in binary form must reproduce the above copyright
12//       notice, this list of conditions and the following disclaimer in the
13//       documentation and/or other materials provided with the distribution.
14//
15// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
16// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
17// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
19// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27use tokio::sync::mpsc;
28use tokio::sync::watch;
29use tokio::task::JoinHandle;
30
31use std::collections::HashMap;
32
33use std::future::Future;
34use std::sync::atomic::AtomicU64;
35use std::sync::atomic::Ordering;
36use std::sync::LazyLock;
37
38enum ActiveTaskOp {
39    Add { id: u64, handle: JoinHandle<()> },
40    Remove { id: u64 },
41}
42
43/// Drop guard for task removal. If a task panics, this makes sure
44/// it is removed from [`ActiveTasks`] properly.
45struct RemoveOnDrop {
46    id: u64,
47    task_tx_weak: mpsc::WeakUnboundedSender<ActiveTaskOp>,
48}
49impl Drop for RemoveOnDrop {
50    fn drop(&mut self) {
51        if let Some(tx) = self.task_tx_weak.upgrade() {
52            let _ = tx.send(ActiveTaskOp::Remove { id: self.id });
53        }
54    }
55}
56
57/// A task killswitch that allows aborting all the tasks spawned with it at
58/// once. The implementation strives to not introduce any in-band locking, so
59/// spawning the future doesn't require acquiring a global lock, keeping the
60/// regular pace of operation.
61struct TaskKillswitch {
62    // NOTE: use a lock without poisoning here to not panic all the threads if
63    // one of the worker threads panic.
64    task_tx: parking_lot::RwLock<Option<mpsc::UnboundedSender<ActiveTaskOp>>>,
65    task_counter: AtomicU64,
66    all_killed: watch::Receiver<()>,
67}
68
69impl TaskKillswitch {
70    fn new() -> Self {
71        let (task_tx, task_rx) = mpsc::unbounded_channel();
72        let (signal_killed, all_killed) = watch::channel(());
73
74        let active_tasks = ActiveTasks {
75            task_rx,
76            tasks: Default::default(),
77            signal_killed,
78        };
79        tokio::spawn(active_tasks.collect());
80
81        Self {
82            task_tx: parking_lot::RwLock::new(Some(task_tx)),
83            task_counter: Default::default(),
84            all_killed,
85        }
86    }
87
88    fn spawn_task(&self, fut: impl Future<Output = ()> + Send + 'static) {
89        // NOTE: acquiring the lock here is very cheap, as unless the killswitch
90        // is activated, this one is always unlocked and this is just a
91        // few atomic operations.
92        let Some(task_tx) = self.task_tx.read().as_ref().cloned() else {
93            return;
94        };
95
96        let id = self.task_counter.fetch_add(1, Ordering::SeqCst);
97        let task_tx_weak = task_tx.downgrade();
98
99        let handle = tokio::spawn(async move {
100            // NOTE: we use a weak sender inside the spawned task - dropping
101            // all strong senders activates the killswitch. In that case,
102            // we don't need to remove anything from ActiveTasks anymore.
103            let _guard = RemoveOnDrop { task_tx_weak, id };
104            fut.await;
105        });
106
107        let _ = task_tx.send(ActiveTaskOp::Add { id, handle });
108    }
109
110    fn activate(&self) {
111        // take()ing the sender here drops it and thereby triggers the killswitch.
112        // Concurrent spawn_task calls may still hold strong senders, which
113        // ensures those tasks are added to ActiveTasks before the killing
114        // starts.
115        assert!(
116            self.task_tx.write().take().is_some(),
117            "killswitch can't be used twice"
118        );
119    }
120
121    fn killed(&self) -> impl Future<Output = ()> + Send + 'static {
122        let mut signal = self.all_killed.clone();
123        async move {
124            let _ = signal.changed().await;
125        }
126    }
127}
128
129struct ActiveTasks {
130    task_rx: mpsc::UnboundedReceiver<ActiveTaskOp>,
131    tasks: HashMap<u64, JoinHandle<()>>,
132    signal_killed: watch::Sender<()>,
133}
134
135impl ActiveTasks {
136    async fn collect(mut self) {
137        while let Some(op) = self.task_rx.recv().await {
138            self.handle_task_op(op);
139        }
140
141        for task in self.tasks.into_values() {
142            task.abort();
143        }
144        drop(self.signal_killed);
145    }
146
147    fn handle_task_op(&mut self, op: ActiveTaskOp) {
148        match op {
149            ActiveTaskOp::Add { id, handle } => {
150                self.tasks.insert(id, handle);
151            },
152            ActiveTaskOp::Remove { id } => {
153                self.tasks.remove(&id);
154            },
155        }
156    }
157}
158
159/// The global [`TaskKillswitch`] exposed publicly from the crate.
160static TASK_KILLSWITCH: LazyLock<TaskKillswitch> =
161    LazyLock::new(TaskKillswitch::new);
162
163/// Spawns a new asynchronous task and registers it in the crate's global
164/// killswitch.
165///
166/// Under the hood, [`tokio::spawn`] schedules the actual execution.
167#[inline]
168pub fn spawn_with_killswitch(fut: impl Future<Output = ()> + Send + 'static) {
169    TASK_KILLSWITCH.spawn_task(fut);
170}
171
172#[deprecated = "activate() was unnecessarily declared async. Use activate_now() instead."]
173pub async fn activate() {
174    TASK_KILLSWITCH.activate()
175}
176
177/// Triggers the killswitch, thereby scheduling all registered tasks to be
178/// killed.
179///
180/// Note: tasks are not killed synchronously in this function. This means
181/// `activate_now()` will return before all tasks have been stopped.
182#[inline]
183pub fn activate_now() {
184    TASK_KILLSWITCH.activate();
185}
186
187/// Returns a future that resolves when all registered tasks have been killed,
188/// after [`activate_now`] has been called.
189///
190/// Note: tokio does not kill a task until the next time it yields to the
191/// runtime. This means some killed tasks may still be running by the time this
192/// Future resolves.
193#[inline]
194pub fn killed_signal() -> impl Future<Output = ()> + Send + 'static {
195    TASK_KILLSWITCH.killed()
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::time::Duration;
    use tokio::sync::oneshot;

    /// How many sleeping tasks each test registers with the killswitch.
    const TEST_TASK_COUNT: usize = 1000;

    /// Sends on its oneshot when dropped, letting a test observe that the
    /// future owning it has been torn down.
    struct TaskAbortSignal(Option<oneshot::Sender<()>>);

    impl TaskAbortSignal {
        fn new() -> (Self, oneshot::Receiver<()>) {
            let (sender, receiver) = oneshot::channel();
            (Self(Some(sender)), receiver)
        }
    }

    impl Drop for TaskAbortSignal {
        fn drop(&mut self) {
            let sender = self.0.take().unwrap();
            let _ = sender.send(());
        }
    }

    /// Spawns [`TEST_TASK_COUNT`] hour-long sleeping tasks through
    /// `killswitch`. Returns one receiver per task; each receiver resolves
    /// once that task's future has been dropped.
    fn start_test_tasks(
        killswitch: &TaskKillswitch,
    ) -> Vec<oneshot::Receiver<()>> {
        let mut receivers = Vec::with_capacity(TEST_TASK_COUNT);
        for _ in 0..TEST_TASK_COUNT {
            let (signal, receiver) = TaskAbortSignal::new();
            killswitch.spawn_task(async move {
                tokio::time::sleep(Duration::from_secs(3600)).await;
                drop(signal);
            });
            receivers.push(receiver);
        }
        receivers
    }

    /// Fails the test unless every abort signal fires within one second.
    async fn expect_all_aborted(abort_signals: Vec<oneshot::Receiver<()>>) {
        tokio::time::timeout(
            Duration::from_secs(1),
            future::join_all(abort_signals),
        )
        .await
        .expect("tasks should be killed within given timeframe");
    }

    #[tokio::test]
    async fn activate_killswitch_early() {
        let killswitch = TaskKillswitch::new();
        let abort_signals = start_test_tasks(&killswitch);

        killswitch.activate();
        expect_all_aborted(abort_signals).await;
    }

    #[tokio::test]
    async fn activate_killswitch_with_delay() {
        let killswitch = TaskKillswitch::new();
        let abort_signals = start_test_tasks(&killswitch);
        let signal_handle = tokio::spawn(killswitch.killed());

        // Let the tasks get polled at least once before pulling the switch.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert!(!signal_handle.is_finished());

        killswitch.activate();
        expect_all_aborted(abort_signals).await;

        tokio::time::timeout(Duration::from_secs(1), signal_handle)
            .await
            .expect("killed() signal should have resolved")
            .expect("signal task should join successfully");
    }
}