//! Helper for shutdown signals use futures_util::{ future::{FusedFuture, Shared}, FutureExt, }; use std::{ future::Future, pin::Pin, sync::{atomic::AtomicUsize, Arc}, task::{ready, Context, Poll}, }; use tokio::sync::oneshot; /// A Future that resolves when the shutdown event has been fired. #[derive(Debug)] pub struct GracefulShutdown { shutdown: Shutdown, guard: Option, } impl GracefulShutdown { pub(crate) const fn new(shutdown: Shutdown, guard: GracefulShutdownGuard) -> Self { Self { shutdown, guard: Some(guard) } } /// Returns a new shutdown future that is ignores the returned [`GracefulShutdownGuard`]. /// /// This just maps the return value of the future to `()`, it does not drop the guard. pub fn ignore_guard(self) -> impl Future + Send + Sync + Unpin + 'static { self.map(drop) } } impl Future for GracefulShutdown { type Output = GracefulShutdownGuard; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { ready!(self.shutdown.poll_unpin(cx)); Poll::Ready(self.get_mut().guard.take().expect("Future polled after completion")) } } impl Clone for GracefulShutdown { fn clone(&self) -> Self { Self { shutdown: self.shutdown.clone(), guard: self.guard.as_ref().map(|g| GracefulShutdownGuard::new(Arc::clone(&g.0))), } } } /// A guard that fires once dropped to signal the [`TaskManager`](crate::TaskManager) that the /// [`GracefulShutdown`] has completed. #[derive(Debug)] #[must_use = "if unused the task will not be gracefully shutdown"] pub struct GracefulShutdownGuard(Arc); impl GracefulShutdownGuard { pub(crate) fn new(counter: Arc) -> Self { counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst); Self(counter) } } impl Drop for GracefulShutdownGuard { fn drop(&mut self) { self.0.fetch_sub(1, std::sync::atomic::Ordering::SeqCst); } } /// A Future that resolves when the shutdown event has been fired. #[derive(Debug, Clone)] pub struct Shutdown(Shared>); impl Future for Shutdown { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let pin = self.get_mut(); if pin.0.is_terminated() || pin.0.poll_unpin(cx).is_ready() { Poll::Ready(()) } else { Poll::Pending } } } /// Shutdown signal that fires either manually or on drop by closing the channel #[derive(Debug)] pub struct Signal(oneshot::Sender<()>); impl Signal { /// Fire the signal manually. pub fn fire(self) { let _ = self.0.send(()); } } /// Create a channel pair that's used to propagate shutdown event pub fn signal() -> (Signal, Shutdown) { let (sender, receiver) = oneshot::channel(); (Signal(sender), Shutdown(receiver.shared())) } #[cfg(test)] mod tests { use super::*; use futures_util::future::join_all; use std::time::Duration; #[tokio::test(flavor = "multi_thread")] async fn test_shutdown() { let (_signal, _shutdown) = signal(); } #[tokio::test(flavor = "multi_thread")] async fn test_drop_signal() { let (signal, shutdown) = signal(); tokio::task::spawn(async move { tokio::time::sleep(Duration::from_millis(500)).await; drop(signal) }); shutdown.await; } #[tokio::test(flavor = "multi_thread")] async fn test_multi_shutdowns() { let (signal, shutdown) = signal(); let mut tasks = Vec::with_capacity(100); for _ in 0..100 { let shutdown = shutdown.clone(); let task = tokio::task::spawn(async move { shutdown.await; }); tasks.push(task); } drop(signal); join_all(tasks).await; } #[tokio::test(flavor = "multi_thread")] async fn test_drop_signal_from_thread() { let (signal, shutdown) = signal(); let _thread = std::thread::spawn(|| { std::thread::sleep(Duration::from_millis(500)); drop(signal) }); shutdown.await; } }