From 2f8dce44d3f2be74b5c6ec0a2e7f4ceced715328 Mon Sep 17 00:00:00 2001 From: pennae Date: Wed, 13 Jul 2022 10:33:30 +0200 Subject: initial import --- src/utils.rs | 124 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 src/utils.rs (limited to 'src/utils.rs') diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..5a2407a --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,124 @@ +use std::convert::Infallible; + +use futures::Future; +use rocket::{ + fairing::{self, Fairing, Info, Kind}, + http::Status, + request::{FromRequest, Outcome}, + tokio::{ + self, + sync::broadcast::{channel, Sender}, + task::JoinHandle, + }, + Request, Response, Sentinel, +}; + +// does the same as try_outcome!(...).map_forward(|_| data), +// but without moving data in non-failure cases. +macro_rules! try_outcome_data { + ($data:expr, $e:expr) => { + match $e { + ::rocket::outcome::Outcome::Success(val) => val, + ::rocket::outcome::Outcome::Failure(e) => { + return ::rocket::outcome::Outcome::Failure(::std::convert::From::from(e)) + }, + ::rocket::outcome::Outcome::Forward(()) => { + return ::rocket::outcome::Outcome::Forward($data) + }, + } + }; +} + +pub trait CanBeLogged: Send + 'static { + fn into_message(self) -> Option; + fn not_run() -> Self; +} + +impl CanBeLogged for Result<(), anyhow::Error> { + fn into_message(self) -> Option { + self.err().map(|e| e.to_string()) + } + fn not_run() -> Self { + Ok(()) + } +} + +pub fn spawn_logged(context: &'static str, future: T) -> JoinHandle<()> +where + T: Future + Send + 'static, + T::Output: CanBeLogged + Send + 'static, +{ + tokio::spawn(async move { + if let Some(msg) = future.await.into_message() { + warn!("{context}: {msg}"); + } + }) +} + +pub struct DeferredActions; +struct HasDeferredActions; + +pub struct DeferAction(Sender<()>); + +#[async_trait] +impl Fairing for DeferredActions { + fn info(&self) -> Info { + Info { name: "deferred actions", kind: Kind::Ignite | Kind::Response } + } + + async fn on_ignite(&self, rocket: rocket::Rocket) -> fairing::Result { + Ok(rocket.manage(HasDeferredActions)) + } + + async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) { + if res.status() == Status::Ok { + if let Some(DeferAction(tx)) = req.local_cache(|| None) { + // could have no receivers, that's not an error here + tx.send(()).ok(); + } + } + } +} + +impl<'r> Sentinel for &'r DeferAction { + fn abort(rocket: &rocket::Rocket) -> bool { + rocket.state::().is_none() + } +} + +#[async_trait] +impl<'r> FromRequest<'r> for &'r DeferAction { + type Error = Infallible; + + async fn from_request(req: &'r Request<'_>) -> Outcome { + Outcome::Success( + req.local_cache(|| { + let (tx, _) = channel(1); + Some(DeferAction(tx)) + }) + .as_ref() + .unwrap(), + ) + } +} + +impl DeferAction { + pub fn spawn_after_success(&self, context: &'static str, future: T) + where + T: Future + Send + 'static, + T::Output: CanBeLogged + Send + 'static, + { + let mut r = self.0.subscribe(); + spawn_logged(context, async move { + match r.recv().await { + Ok(_) => { + // the request finished with success, now wait for it to be dropped + // to ensure that all other fairings have run to completion. + r.recv().await.ok(); + future.await + }, + Err(_) => CanBeLogged::not_run(), + } + }); + } +} -- cgit v1.2.3