From 2f8dce44d3f2be74b5c6ec0a2e7f4ceced715328 Mon Sep 17 00:00:00 2001 From: pennae Date: Wed, 13 Jul 2022 10:33:30 +0200 Subject: initial import --- src/push.rs | 198 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 198 insertions(+) create mode 100644 src/push.rs (limited to 'src/push.rs') diff --git a/src/push.rs b/src/push.rs new file mode 100644 index 0000000..6ee1afb --- /dev/null +++ b/src/push.rs @@ -0,0 +1,198 @@ +use anyhow::Result; +use rocket::http::uri::Absolute; +use serde_json::{json, Value}; +use std::time::Duration; +use std::{fs::File, io::Read, path::Path}; + +use serde::Serialize; +use web_push::{ + ContentEncoding, SubscriptionInfo, VapidSignatureBuilder, WebPushClient, WebPushMessageBuilder, +}; + +use crate::db::DbConn; +use crate::types::{Device, DeviceID, DevicePush, UserID}; + +pub(crate) struct PushClient { + key: Box<[u8]>, + client: WebPushClient, + subject: String, + base_uri: Absolute<'static>, + default_ttl: Duration, +} + +impl PushClient { + pub(crate) fn new>( + key: P, + subject: &str, + base_uri: Absolute<'static>, + default_ttl: Duration, + ) -> Result { + let mut key_bytes = vec![]; + File::open(key).and_then(|mut f| f.read_to_end(&mut key_bytes))?; + Ok(PushClient { + key: key_bytes.into_boxed_slice(), + client: WebPushClient::new()?, + subject: subject.to_string(), + base_uri, + default_ttl, + }) + } + + async fn push_raw(&self, to: &DevicePush, ttl: Duration, data: Option<&[u8]>) -> Result<()> { + let sub = SubscriptionInfo::new(&to.callback, &to.public_key, &to.auth_key); + let mut sig = VapidSignatureBuilder::from_pem(&*self.key, &sub)?; + // mozilla requires {aud,exp,sub} or message will get a 401 unauthorized. + // {aud,exp} are added automatically + sig.add_claim("sub", self.subject.as_str()); + let mut builder = WebPushMessageBuilder::new(&sub)?; + if let Some(data) = data { + builder.set_payload(ContentEncoding::Aes128Gcm, data); + } + builder.set_vapid_signature(sig.build()?); + builder.set_ttl(ttl.as_secs().min(u32::MAX as u64) as u32); + Ok(self.client.send(builder.build()?).await?) + } + + async fn push_one( + &self, + context: &str, + db: Option<&DbConn>, + to: &Device, + ttl: Duration, + data: Option<&[u8]>, + ) -> Result<()> { + match (to.push_expired, to.push.as_ref()) { + (false, Some(ep)) => match self.push_raw(ep, ttl, data).await { + Ok(()) => Ok(()), + Err(e) => { + warn!("{} push to {} failed: {}", context, &to.device_id, e); + if let Some(db) = db { + if let Err(e) = db.set_push_expired(&to.device_id).await { + warn!("failed to set {} push_endpoint_expired: {}", &to.device_id, e); + } + } + Err(e) + }, + }, + (_, None) => Err(anyhow!("no push callback")), + (true, _) => Err(anyhow!("push endpoint expired")), + } + } + + async fn push_all( + &self, + context: &str, + db: Option<&DbConn>, + to: &[Device], + ttl: Duration, + msg: impl Serialize, + ) { + let msg = serde_json::to_vec(&msg).expect("push message serialization failed"); + for dev in to { + // ignore errors here, except by logging them. we can't notify the client + // about anything and failing isn't an option either. + let _ = self.push_one(context, db, dev, ttl, Some(&msg)).await; + } + } + + pub(crate) async fn command_received( + &self, + db: &DbConn, + to: &Device, + command: &str, + index: i64, + sender: &Option, + ) -> Result<()> { + let url = + format!("{}/auth/v1/account/device/commands?index={}&limit=1", self.base_uri, index); + let msg = json!({ + "version": 1, + "command": "fxaccounts:command_received", + "data": { + "command": command, + "index": index, + "sender": sender, + "url": url, + }, + }); + let msg = serde_json::to_vec(&msg)?; + self.push_one("command_received", Some(db), to, self.default_ttl, Some(&msg)).await + } + + pub(crate) async fn device_connected(&self, db: &DbConn, to: &[Device], name: &str) { + let msg = json!({ + "version": 1, + "command": "fxaccounts:device_connected", + "data": { + "deviceName": name, + }, + }); + self.push_all("device_connected", Some(db), to, self.default_ttl, &msg).await; + } + + pub(crate) async fn device_disconnected(&self, db: &DbConn, to: &[Device], id: &DeviceID) { + let msg = json!({ + "version": 1, + "command": "fxaccounts:device_disconnected", + "data": { + "id": id, + }, + }); + self.push_all("device_disconnected", Some(db), to, self.default_ttl, &msg).await; + } + + pub(crate) async fn profile_updated(&self, db: &DbConn, to: &[Device]) { + let msg = json!({ + "version": 1, + "command": "fxaccounts:profile_updated", + }); + self.push_all("profile_updated", Some(db), to, self.default_ttl, &msg).await; + } + + pub(crate) async fn account_verified(&self, db: &DbConn, to: &[Device]) { + for dev in to { + // ignore errors here, except by logging them. we can't notify the client + // about anything and failing isn't an option either. + let _ = self.push_one("account_verified", Some(db), dev, Duration::ZERO, None).await; + } + } + + pub(crate) async fn account_destroyed(&self, to: &[Device], uid: &UserID) { + let msg = json!({ + "version": 1, + "command": "fxaccounts:account_destroyed", + "data": { + "uid": uid, + }, + }); + self.push_all("account_destroyed", None, to, self.default_ttl, &msg).await; + } + + pub(crate) async fn password_reset(&self, to: &[Device]) { + let msg = serde_json::to_vec(&json!({ + "version": 1u32, + "command": "fxaccounts:password_reset", + })) + .expect("serde failed"); + for dev in to { + // ignore errors here, except by logging them. we can't notify the client + // about anything and failing isn't an option either. + let _ = self.push_one("password_reset", None, dev, self.default_ttl, Some(&msg)).await; + // NOTE password_reset alone doesn't seem to do much, se we also disconnect + // each device explicitly. + let msg = serde_json::to_vec(&json!({ + "version": 1, + "command": "fxaccounts:device_disconnected", + "data": { + "id": dev.device_id, + }, + })) + .expect("serde failed"); + let _ = self.push_one("password_reset", None, dev, self.default_ttl, Some(&msg)).await; + } + } + + pub(crate) async fn push_any(&self, db: &DbConn, to: &[Device], ttl: Duration, payload: Value) { + self.push_all("push_any", Some(db), to, ttl, &payload).await; + } +} -- cgit v1.2.3