summaryrefslogtreecommitdiff
path: root/src/utils.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/utils.rs')
-rw-r--r--src/utils.rs124
1 files changed, 124 insertions, 0 deletions
diff --git a/src/utils.rs b/src/utils.rs
new file mode 100644
index 0000000..5a2407a
--- /dev/null
+++ b/src/utils.rs
@@ -0,0 +1,124 @@
+use std::convert::Infallible;
+
+use futures::Future;
+use rocket::{
+ fairing::{self, Fairing, Info, Kind},
+ http::Status,
+ request::{FromRequest, Outcome},
+ tokio::{
+ self,
+ sync::broadcast::{channel, Sender},
+ task::JoinHandle,
+ },
+ Request, Response, Sentinel,
+};
+
+// does the same as try_outcome!(...).map_forward(|_| data),
+// but without moving data in non-failure cases.
+macro_rules! try_outcome_data {
+ ($data:expr, $e:expr) => {
+ match $e {
+ ::rocket::outcome::Outcome::Success(val) => val,
+ ::rocket::outcome::Outcome::Failure(e) => {
+ return ::rocket::outcome::Outcome::Failure(::std::convert::From::from(e))
+ },
+ ::rocket::outcome::Outcome::Forward(()) => {
+ return ::rocket::outcome::Outcome::Forward($data)
+ },
+ }
+ };
+}
+
+pub trait CanBeLogged: Send + 'static {
+ fn into_message(self) -> Option<String>;
+ fn not_run() -> Self;
+}
+
+impl CanBeLogged for Result<(), anyhow::Error> {
+ fn into_message(self) -> Option<String> {
+ self.err().map(|e| e.to_string())
+ }
+ fn not_run() -> Self {
+ Ok(())
+ }
+}
+
+pub fn spawn_logged<T>(context: &'static str, future: T) -> JoinHandle<()>
+where
+ T: Future + Send + 'static,
+ T::Output: CanBeLogged + Send + 'static,
+{
+ tokio::spawn(async move {
+ if let Some(msg) = future.await.into_message() {
+ warn!("{context}: {msg}");
+ }
+ })
+}
+
+pub struct DeferredActions;
+struct HasDeferredActions;
+
+pub struct DeferAction(Sender<()>);
+
+#[async_trait]
+impl Fairing for DeferredActions {
+ fn info(&self) -> Info {
+ Info { name: "deferred actions", kind: Kind::Ignite | Kind::Response }
+ }
+
+ async fn on_ignite(&self, rocket: rocket::Rocket<rocket::Build>) -> fairing::Result {
+ Ok(rocket.manage(HasDeferredActions))
+ }
+
+ async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) {
+ if res.status() == Status::Ok {
+ if let Some(DeferAction(tx)) = req.local_cache(|| None) {
+ // could have no receivers, that's not an error here
+ tx.send(()).ok();
+ }
+ }
+ }
+}
+
+impl<'r> Sentinel for &'r DeferAction {
+ fn abort(rocket: &rocket::Rocket<rocket::Ignite>) -> bool {
+ rocket.state::<HasDeferredActions>().is_none()
+ }
+}
+
+#[async_trait]
+impl<'r> FromRequest<'r> for &'r DeferAction {
+ type Error = Infallible;
+
+ async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+ Outcome::Success(
+ req.local_cache(|| {
+ let (tx, _) = channel(1);
+ Some(DeferAction(tx))
+ })
+ .as_ref()
+ .unwrap(),
+ )
+ }
+}
+
+impl DeferAction {
+ pub fn spawn_after_success<T>(&self, context: &'static str, future: T)
+ where
+ T: Future + Send + 'static,
+ T::Output: CanBeLogged + Send + 'static,
+ {
+ let mut r = self.0.subscribe();
+ spawn_logged(context, async move {
+ match r.recv().await {
+ Ok(_) => {
+ // the request finished with success, now wait for it to be dropped
+ // to ensure that all other fairings have run to completion.
+ r.recv().await.ok();
+ future.await
+ },
+ Err(_) => CanBeLogged::not_run(),
+ }
+ });
+ }
+}