use std::convert::Infallible; use futures::Future; use rocket::{ fairing::{self, Fairing, Info, Kind}, http::Status, request::{FromRequest, Outcome}, tokio::{ self, sync::broadcast::{channel, Sender}, task::JoinHandle, }, Request, Response, Sentinel, }; // does the same as try_outcome!(...).map_forward(|_| data), // but without moving data in non-failure cases. macro_rules! try_outcome_data { ($data:expr, $e:expr) => { match $e { ::rocket::outcome::Outcome::Success(val) => val, ::rocket::outcome::Outcome::Failure(e) => { return ::rocket::outcome::Outcome::Failure(::std::convert::From::from(e)) }, ::rocket::outcome::Outcome::Forward(()) => { return ::rocket::outcome::Outcome::Forward($data) }, } }; } pub fn spawn_logged(context: &'static str, future: F) -> JoinHandle<()> where F: Future> + Send + 'static, { tokio::spawn(async move { if let Err(e) = future.await { warn!("{context}: {e:?}"); } }) } pub struct DeferredActions; struct HasDeferredActions; pub struct DeferAction(Sender<()>); #[async_trait] impl Fairing for DeferredActions { fn info(&self) -> Info { Info { name: "deferred actions", kind: Kind::Ignite | Kind::Response } } async fn on_ignite(&self, rocket: rocket::Rocket) -> fairing::Result { Ok(rocket.manage(HasDeferredActions)) } async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) { if res.status() == Status::Ok { if let Some(DeferAction(tx)) = req.local_cache(|| None) { // could have no receivers, that's not an error here tx.send(()).ok(); } } } } impl<'r> Sentinel for &'r DeferAction { fn abort(rocket: &rocket::Rocket) -> bool { rocket.state::().is_none() } } #[async_trait] impl<'r> FromRequest<'r> for &'r DeferAction { type Error = Infallible; async fn from_request(req: &'r Request<'_>) -> Outcome { Outcome::Success( req.local_cache(|| { let (tx, _) = channel(1); Some(DeferAction(tx)) }) .as_ref() .unwrap(), ) } } impl DeferAction { pub fn spawn_after_success(&self, context: &'static str, future: F) where F: Future> + Send + 'static, { let mut r = self.0.subscribe(); spawn_logged(context, async move { match r.recv().await { Ok(_) => { // the request finished with success, now wait for it to be dropped // to ensure that all other fairings have run to completion. r.recv().await.ok(); future.await }, Err(_) => Ok(()), } }); } }