summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/api/auth/account.rs413
-rw-r--r--src/api/auth/device.rs455
-rw-r--r--src/api/auth/email.rs126
-rw-r--r--src/api/auth/invite.rs47
-rw-r--r--src/api/auth/mod.rs238
-rw-r--r--src/api/auth/oauth.rs433
-rw-r--r--src/api/auth/password.rs260
-rw-r--r--src/api/auth/session.rs107
-rw-r--r--src/api/mod.rs32
-rw-r--r--src/api/oauth.rs163
-rw-r--r--src/api/profile/mod.rs324
-rw-r--r--src/auth.rs241
-rw-r--r--src/bin/minorskulk.rs9
-rw-r--r--src/cache.rs42
-rw-r--r--src/crypto.rs408
-rw-r--r--src/db/mod.rs1026
-rw-r--r--src/js.rs53
-rw-r--r--src/lib.rs319
-rw-r--r--src/mailer.rs105
-rw-r--r--src/push.rs198
-rw-r--r--src/types.rs436
-rw-r--r--src/types/oauth.rs267
-rw-r--r--src/utils.rs124
23 files changed, 5826 insertions, 0 deletions
diff --git a/src/api/auth/account.rs b/src/api/auth/account.rs
new file mode 100644
index 0000000..51dd98e
--- /dev/null
+++ b/src/api/auth/account.rs
@@ -0,0 +1,413 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use chrono::{DateTime, Utc};
+use password_hash::SaltString;
+use rand::{thread_rng, Rng};
+use rocket::request::FromRequest;
+use rocket::State;
+use rocket::{serde::json::Json, Request};
+use serde::{Deserialize, Serialize};
+use validator::Validate;
+
+use crate::api::{Empty, EMPTY};
+use crate::db::{Db, DbConn};
+use crate::mailer::Mailer;
+use crate::push::PushClient;
+use crate::types::AccountResetID;
+use crate::utils::DeferAction;
+use crate::Config;
+use crate::{
+ api::{auth, serialize_dt},
+ auth::{AuthSource, Authenticated},
+ crypto::{AuthPW, KeyBundle, KeyFetchReq, SecretBytes, SessionCredentials},
+ types::{HawkKey, KeyFetchID, OauthToken, SecretKey, SessionID, User, UserID, VerifyHash},
+};
+
+// TODO better error handling
+
+// MISSING get /account/profile
+// MISSING get /account/status
+// MISSING post /account/status
+// MISSING post /account/reset
+
+#[derive(Deserialize, Debug, Validate)]
+#[serde(deny_unknown_fields)]
+#[allow(non_snake_case)]
+pub(crate) struct Create {
+ #[validate(email, length(min = 3, max = 256))]
+ email: String,
+ authPW: AuthPW,
+ // MISSING service
+ // MISSING redirectTo
+ // MISSING resume
+ // MISSING metricsContext
+ // NOTE we misuse style to communicate an invite token!
+ style: Option<String>,
+ // MISSING verificationMethod
+}
+
+#[derive(Serialize, Debug)]
+#[allow(non_snake_case)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct CreateResp {
+ uid: UserID,
+ sessionToken: SecretBytes<32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ keyFetchToken: Option<SecretBytes<32>>,
+ #[serde(serialize_with = "serialize_dt")]
+ authAt: DateTime<Utc>,
+ // MISSING verificationMethod
+}
+
+// MISSING arg: service
+#[post("/account/create?<keys>", data = "<data>")]
+pub(crate) async fn create(
+ db: &DbConn,
+ cfg: &State<Config>,
+ mailer: &State<Arc<Mailer>>,
+ keys: Option<bool>,
+ data: Json<Create>,
+) -> auth::Result<CreateResp> {
+ let keys = keys.unwrap_or(false);
+ let data = data.into_inner();
+ data.validate().map_err(|_| auth::Error::InvalidParameter)?;
+
+ if db.user_email_exists(&data.email).await? {
+ return Err(auth::Error::AccountExists);
+ }
+
+ match (cfg.invite_only, data.style) {
+ (false, Some(_)) => return Err(auth::Error::InvalidParameter),
+ (false, None) => (),
+ (true, None) => return Err(auth::Error::InviteOnly),
+ (true, Some(code)) => {
+ db.use_invite_code(&code).await.map_err(|e| match e {
+ sqlx::Error::RowNotFound => auth::Error::InviteNotFound,
+ e => auth::Error::Other(anyhow!(e)),
+ })?;
+ },
+ }
+
+ let ka = SecretBytes::generate();
+ let wrapwrap_kb = SecretBytes::generate();
+ let auth_salt = SaltString::generate(rand::rngs::OsRng);
+ let stretched = data.authPW.stretch(auth_salt.as_salt())?;
+ let verify_hash = stretched.verify_hash();
+ let session_token = SecretBytes::generate();
+ let session = SessionCredentials::derive(&session_token);
+ let key_fetch_token = if keys {
+ let key_fetch_token = SecretBytes::generate();
+ let req = KeyFetchReq::from_token(&key_fetch_token);
+ let wrapped = req.derive_resp().wrap_keys(&KeyBundle {
+ ka: ka.clone(),
+ wrap_kb: stretched.decrypt_wwkb(&wrapwrap_kb),
+ });
+ db.add_key_fetch(KeyFetchID(req.token_id.0), &HawkKey(req.req_hmac_key), &wrapped).await?;
+ Some(key_fetch_token)
+ } else {
+ None
+ };
+ let uid = db
+ .add_user(User {
+ auth_salt,
+ email: data.email.to_owned(),
+ ka: SecretKey(ka),
+ wrapwrap_kb: SecretKey(wrapwrap_kb),
+ verify_hash: VerifyHash(verify_hash),
+ display_name: None,
+ verified: false,
+ })
+ .await?;
+ let session_id = SessionID(session.token_id.0);
+ let auth_at = db
+ .add_session(session_id.clone(), &uid, HawkKey(session.req_hmac_key), false, None)
+ .await?;
+ let verify_code = hex::encode(&SecretBytes::<16>::generate().0);
+ db.add_verify_code(&uid, &session_id, &verify_code).await?;
+ // NOTE we send the email in this context rather than a spawn to signal
+ // send errors to the client.
+ mailer.send_account_verify(&uid, &data.email, &verify_code).await.map_err(|e| {
+ error!("failed to send email: {e}");
+ auth::Error::EmailFailed
+ })?;
+ Ok(Json(CreateResp {
+ uid,
+ sessionToken: session_token,
+ keyFetchToken: key_fetch_token,
+ authAt: auth_at,
+ }))
+}
+
+#[derive(Deserialize, Debug, Validate)]
+#[serde(deny_unknown_fields)]
+#[allow(non_snake_case)]
+pub(crate) struct Login {
+ #[validate(email, length(min = 3, max = 256))]
+ email: String,
+ authPW: AuthPW,
+ // MISSING service
+ // MISSING redirectTo
+ // MISSING resume
+ // MISSING reason
+ // MISSING unblockCode
+ // MISSING originalLoginEmail
+ // MISSING verificationMethod
+}
+
+#[derive(Serialize, Debug)]
+#[allow(non_snake_case)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct LoginResp {
+ uid: UserID,
+ sessionToken: SecretBytes<32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ keyFetchToken: Option<SecretBytes<32>>,
+ // MISSING verificationMethod
+ // MISSING verificationReason
+ // NOTE this is the *account* verified status, not the session status.
+ // the spec doesn't say.
+ verified: bool,
+ #[serde(serialize_with = "serialize_dt")]
+ authAt: DateTime<Utc>,
+ // MISSING metricsEnabled
+}
+
+// MISSING arg: service
+// MISSING arg: verificationMethod
+#[post("/account/login?<keys>", data = "<data>")]
+pub(crate) async fn login(
+ db: &DbConn,
+ mailer: &State<Arc<Mailer>>,
+ keys: Option<bool>,
+ data: Json<Login>,
+) -> auth::Result<LoginResp> {
+ let keys = keys.unwrap_or(false);
+ let data = data.into_inner();
+ data.validate().map_err(|_| auth::Error::InvalidParameter)?;
+
+ let (uid, user) = db.get_user(&data.email).await.map_err(|_| auth::Error::UnknownAccount)?;
+ if user.email != data.email {
+ return Err(auth::Error::IncorrectEmailCase);
+ }
+ if !user.verified {
+ return Err(auth::Error::UnverifiedAccount);
+ }
+
+ let stretched = data.authPW.stretch(user.auth_salt.as_salt())?;
+ if stretched.verify_hash() != user.verify_hash.0 {
+ return Err(auth::Error::IncorrectPassword);
+ }
+
+ let session_token = SecretBytes::generate();
+ let session = SessionCredentials::derive(&session_token);
+ let key_fetch_token = if keys {
+ let key_fetch_token = SecretBytes::generate();
+ let req = KeyFetchReq::from_token(&key_fetch_token);
+ let wrapped = req.derive_resp().wrap_keys(&KeyBundle {
+ ka: user.ka.0.clone(),
+ wrap_kb: stretched.decrypt_wwkb(&user.wrapwrap_kb.0),
+ });
+ db.add_key_fetch(KeyFetchID(req.token_id.0), &HawkKey(req.req_hmac_key), &wrapped).await?;
+ Some(key_fetch_token)
+ } else {
+ None
+ };
+
+ let session_id = SessionID(session.token_id.0);
+ let verify_code = format!("{:06}", thread_rng().gen_range(0..=999999));
+ let auth_at = db
+ .add_session(
+ session_id.clone(),
+ &uid,
+ HawkKey(session.req_hmac_key),
+ false,
+ Some(&verify_code),
+ )
+ .await?;
+ // NOTE we send the email in this context rather than a spawn to signal
+ // send errors to the client.
+ mailer.send_session_verify(&data.email, &verify_code).await.map_err(|e| {
+ error!("failed to send email: {e}");
+ auth::Error::EmailFailed
+ })?;
+ Ok(Json(LoginResp {
+ uid,
+ sessionToken: session_token,
+ keyFetchToken: key_fetch_token,
+ verified: true,
+ authAt: auth_at,
+ }))
+}
+
+#[derive(Deserialize, Debug, Validate)]
+#[serde(deny_unknown_fields)]
+#[allow(non_snake_case)]
+pub(crate) struct Destroy {
+ #[validate(email, length(min = 3, max = 256))]
+ email: String,
+ authPW: AuthPW,
+}
+
+// TODO may also be authenticated with a verified session
+#[post("/account/destroy", data = "<data>")]
+pub(crate) async fn destroy(
+ db: &DbConn,
+ db_pool: &Db,
+ defer: &DeferAction,
+ pc: &State<Arc<PushClient>>,
+ data: Json<Destroy>,
+) -> auth::Result<Empty> {
+ let data = data.into_inner();
+ data.validate().map_err(|_| auth::Error::InvalidParameter)?;
+
+ let (uid, user) = db.get_user(&data.email).await.map_err(|_| auth::Error::UnknownAccount)?;
+ if user.email != data.email {
+ return Err(auth::Error::IncorrectEmailCase);
+ }
+
+ let stretched = data.authPW.stretch(user.auth_salt.as_salt())?;
+ if stretched.verify_hash() != user.verify_hash.0 {
+ return Err(auth::Error::IncorrectPassword);
+ }
+
+ let devs = db.get_devices(&uid).await;
+ db.delete_user(&data.email).await?;
+ match devs {
+ Ok(devs) => defer.spawn_after_success("api::account/destroy(post)", {
+ let (pc, db) = (Arc::clone(pc), db_pool.clone());
+ async move {
+ let db = db.begin().await?;
+ pc.account_destroyed(&devs, &uid).await;
+ db.commit().await?;
+ Ok(())
+ }
+ }),
+ Err(e) => warn!("account_destroyed push failed: {e}"),
+ }
+
+ Ok(EMPTY)
+}
+
+#[derive(Deserialize, Serialize, Debug)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct KeysResp {
+ bundle: String,
+}
+
+// NOTE the key fetch endpoint must delete a key fetch token from the database
+// once it has identified it, regardless of whether the request succeeds or
+// fails. we'll do this with a single-use auth source that sets the db to always
+// commit. the auth source must not be used for anything else. we can get away
+// with using a request guard because we'll always commit even if the guard
+// fails, but this is only allowable because this is the only handler for the path.
+
+#[derive(Debug)]
+pub(crate) struct WithKeyFetch;
+
+#[async_trait]
+impl AuthSource for WithKeyFetch {
+ type ID = KeyFetchID;
+ type Context = Vec<u8>;
+ async fn hawk(r: &Request<'_>, id: &KeyFetchID) -> Result<(SecretBytes<32>, Self::Context)> {
+ let db = Authenticated::<(), Self>::get_conn(r).await?;
+ db.always_commit().await?;
+ Ok(db.finish_key_fetch(id).await.map(|(h, ks)| (h.0, ks))?)
+ }
+ async fn bearer_token(_: &Request<'_>, _: &OauthToken) -> Result<(KeyFetchID, Self::Context)> {
+ // key fetch tokens are only valid in hawk requests
+ bail!("invalid key fetch authentication")
+ }
+}
+
+#[get("/account/keys")]
+pub(crate) async fn keys(auth: Authenticated<(), WithKeyFetch>) -> Json<KeysResp> {
+ // NOTE contrary to its own api spec fxa does not delete a key fetch if the
+ // associated session is not verified. we don't duplicate this special case
+ // because we control the clients, and requesting keys on an unverified session
+ // can be interpreted as a protocol violation anyway.
+ Json(KeysResp { bundle: hex::encode(&auth.context) })
+}
+
+#[derive(Debug)]
+pub(crate) struct WithResetToken;
+
+#[async_trait]
+impl AuthSource for WithResetToken {
+ type ID = AccountResetID;
+ type Context = UserID;
+ async fn hawk(
+ r: &Request<'_>,
+ id: &AccountResetID,
+ ) -> Result<(SecretBytes<32>, Self::Context)> {
+ // unlike key fetch we'll use a separate transaction here since the body of the
+ // handler can fail.
+ let pool = <&Db as FromRequest>::from_request(r)
+ .await
+ .success_or_else(|| anyhow!("could not open db connection"))?;
+ let db = pool.begin().await?;
+ let result = db.finish_account_reset(id).await.map(|(h, ctx)| (h.0, ctx))?;
+ db.commit().await?;
+ Ok(result)
+ }
+ async fn bearer_token(
+ _: &Request<'_>,
+ _: &OauthToken,
+ ) -> Result<(AccountResetID, Self::Context)> {
+ bail!("invalid password change authentication")
+ }
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+#[allow(non_snake_case)]
+pub(crate) struct AccountResetReq {
+ authPW: AuthPW,
+ // MISSING wrapKb
+ // MISSING recoveryKeyId
+ // MISSING sessionToken
+}
+
+// NOTE resetting an account does not clear active sync data on the storage server,
+// so an account may be reported as disconnected for a while. this is not an error,
+// just an inconvenience we haven't found out how to fix yet.
+
+// MISSING arg: keys
+#[post("/account/reset", data = "<data>")]
+pub(crate) async fn reset(
+ db: &DbConn,
+ mailer: &State<Arc<Mailer>>,
+ client: &State<Arc<PushClient>>,
+ defer: &DeferAction,
+ data: Authenticated<AccountResetReq, WithResetToken>,
+) -> auth::Result<Empty> {
+ let user = db.get_user_by_id(&data.context).await?;
+
+ let notify_devs = db.get_devices(&data.context).await?;
+
+ let wrapwrap_kb = SecretBytes::generate();
+ let auth_salt = SaltString::generate(rand::rngs::OsRng);
+ let stretched = data.body.authPW.stretch(auth_salt.as_salt())?;
+ let verify_hash = stretched.verify_hash();
+
+ db.reset_user_auth(&data.context, auth_salt, SecretKey(wrapwrap_kb), VerifyHash(verify_hash))
+ .await?;
+
+ defer.spawn_after_success("api::auth/account/reset(post)", {
+ let client = Arc::clone(client);
+ async move {
+ client.password_reset(&notify_devs).await;
+ Ok(())
+ }
+ });
+
+ mailer
+ .send_account_reset(&user.email)
+ .await
+ .map_err(|e| {
+ warn!("account reset email send failed: {e}");
+ })
+ .ok();
+
+ Ok(EMPTY)
+}
diff --git a/src/api/auth/device.rs b/src/api/auth/device.rs
new file mode 100644
index 0000000..2b05e12
--- /dev/null
+++ b/src/api/auth/device.rs
@@ -0,0 +1,455 @@
+use std::time::Duration;
+use std::{collections::HashMap, sync::Arc};
+
+use chrono::{DateTime, Utc};
+use futures::future::join_all;
+use rocket::{serde::json::Json, State};
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+
+use crate::api::auth::{WithSession, WithVerifiedFxaLogin, WithVerifiedSession};
+use crate::api::{Empty, EMPTY};
+use crate::db::DbConn;
+use crate::push::PushClient;
+use crate::utils::DeferAction;
+use crate::{
+ api::{auth, serialize_dt_opt},
+ auth::Authenticated,
+ db::Db,
+ types::{
+ Device, DeviceCommand, DeviceCommands, DeviceID, DevicePush, DeviceUpdate, OauthTokenID,
+ SessionID,
+ },
+};
+
+fn map_error(e: sqlx::Error) -> auth::Error {
+ match &e {
+ // not-null violations can presumably only be caused by bad parameters
+ sqlx::Error::Database(de) if de.code().as_deref() == Some("23502") => {
+ auth::Error::MissingParameter
+ },
+ sqlx::Error::RowNotFound => auth::Error::UnknownDevice,
+ _ => auth::Error::Other(anyhow!(e)),
+ }
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
+#[allow(non_snake_case)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct Info {
+ isCurrentDevice: bool,
+ id: DeviceID,
+ lastAccessTime: i64,
+ name: String,
+ r#type: String,
+ pushCallback: Option<String>,
+ pushPublicKey: Option<String>,
+ pushAuthKey: Option<String>,
+ pushEndpointExpired: bool,
+ availableCommands: HashMap<String, String>,
+ // NOTE location is optional per the spec, but fenix crashes if it isn't present
+ location: Value,
+ // MISSING lastAccessTimeFormatted
+ // MISSING approximateLastAccessTime
+ // MISSING approximateLastAccessTimeFormatted
+}
+
+fn device_to_json(current: Option<&DeviceID>, dev: Device) -> Info {
+ let (pcb, ppk, pak) = match dev.push {
+ Some(p) => (Some(p.callback), Some(p.public_key), Some(p.auth_key)),
+ None => (None, None, None),
+ };
+ Info {
+ isCurrentDevice: Some(&dev.device_id) == current,
+ id: dev.device_id,
+ lastAccessTime: dev.last_active.timestamp(),
+ name: dev.name,
+ r#type: dev.type_,
+ pushCallback: pcb,
+ pushPublicKey: ppk,
+ pushAuthKey: pak,
+ pushEndpointExpired: dev.push_expired,
+ availableCommands: dev.available_commands.into_map(),
+ location: dev.location,
+ }
+}
+
+#[derive(Serialize, Deserialize, PartialEq)]
+#[serde(transparent)]
+pub(crate) struct ListResp(Vec<Info>);
+
+#[get("/account/devices")]
+pub(crate) async fn devices(
+ db: &DbConn,
+ auth: Authenticated<(), WithVerifiedSession>,
+) -> auth::Result<ListResp> {
+ let devs = db.get_devices(&auth.context.uid).await?;
+ Ok(Json(ListResp(
+ devs.into_iter().map(|dev| device_to_json(auth.context.device_id.as_ref(), dev)).collect(),
+ )))
+}
+
+#[derive(Debug, Deserialize)]
+#[allow(non_snake_case)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct DeviceReq {
+ id: Option<DeviceID>,
+ name: Option<String>,
+ r#type: Option<String>,
+ pushCallback: Option<String>,
+ pushPublicKey: Option<String>,
+ pushAuthKey: Option<String>,
+ availableCommands: Option<HashMap<String, String>>,
+ // present for legacy reasons, ignored
+ #[allow(dead_code)]
+ capabilities: Option<Vec<String>>,
+ location: Option<Value>,
+}
+
+#[post("/account/device", data = "<data>")]
+pub(crate) async fn device(
+ db: &DbConn,
+ db_pool: &Db,
+ defer: &DeferAction,
+ client: &State<Arc<PushClient>>,
+ // need to allow registrations to all sessions, otherwise the "now verified"
+ // notification can't be sent
+ data: Authenticated<DeviceReq, WithSession>,
+) -> auth::Result<Info> {
+ let dev = data.body;
+ if let (None, None, None) = (&dev.name, &dev.r#type, &dev.pushCallback) {
+ return Err(auth::Error::MissingParameter);
+ }
+
+ let push = dev.pushCallback.map(|pcb| DevicePush {
+ callback: pcb,
+ public_key: dev.pushPublicKey.unwrap_or_default(),
+ auth_key: dev.pushAuthKey.unwrap_or_default(),
+ });
+
+ let (own_id, changed_id, notify) = match (dev.id, data.context.device_id) {
+ (None, None) => {
+ let new = DeviceID::random();
+ (Some(new.clone()), new, true)
+ },
+ (None, Some(own)) => (Some(own.clone()), own, false),
+ (Some(other), own) => (own, other, false),
+ };
+ let result = db
+ .change_device(
+ &data.context.uid,
+ &changed_id,
+ DeviceUpdate {
+ name: dev.name.as_ref().map(AsRef::as_ref),
+ type_: dev.r#type.as_ref().map(AsRef::as_ref),
+ push,
+ available_commands: dev.availableCommands.map(DeviceCommands),
+ location: dev.location,
+ },
+ )
+ .await
+ .map_err(map_error)?;
+ if notify {
+ db.set_session_device(&data.session, Some(&changed_id)).await?;
+ match db.get_devices(&data.context.uid).await {
+ Err(e) => warn!("device_connected push failed: {e}"),
+ Ok(mut devs) => defer.spawn_after_success("api::auth/account/device(post)", {
+ devs.retain(|d| d.device_id != changed_id);
+ let (client, db) = (Arc::clone(client), db_pool.clone());
+ let name = result.name.clone();
+ async move {
+ let db = db.begin().await?;
+ client.device_connected(&db, &devs, &name).await;
+ db.commit().await?;
+ Ok(())
+ }
+ }),
+ };
+ }
+ Ok(Json(device_to_json(own_id.as_ref(), result)))
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct Command {
+ target: DeviceID,
+ command: String,
+ payload: Value,
+ ttl: Option<u32>,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+#[allow(non_snake_case)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct InvokeResp {
+ enqueued: bool,
+ notified: bool,
+ notifyError: Option<String>,
+}
+
+// NOTE fenix doesn't register a push callback for some reason, so receiving tabs
+// always requires opening the tab share menu or tab list first.
+#[post("/account/devices/invoke_command", data = "<cmd>")]
+pub(crate) async fn invoke(
+ client: &State<Arc<PushClient>>,
+ db: &DbConn,
+ cmd: Authenticated<Command, WithVerifiedSession>,
+) -> auth::Result<InvokeResp> {
+ let sender = cmd.context.device_id;
+ let dev = db.get_device(&cmd.context.uid, &cmd.body.target).await.map_err(map_error)?;
+ if dev.available_commands.get(&cmd.body.command).is_none() {
+ return Err(auth::Error::NoDeviceCommand);
+ }
+ let ttl = cmd.body.ttl.unwrap_or(30 * 86400).clamp(60, 30 * 86400);
+ let idx = db
+ .enqueue_command(&cmd.body.target, &sender, &cmd.body.command, &cmd.body.payload, ttl)
+ .await?;
+ let (notified, error) = client
+ .command_received(db, &dev, &cmd.body.command, idx, &sender)
+ .await
+ .map_or_else(|e| (false, Some(e.to_string())), |_| (true, None));
+ Ok(Json(InvokeResp { enqueued: true, notified, notifyError: error }))
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct CommandData {
+ command: String,
+ payload: Value,
+ sender: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct CommandsEntry {
+ index: i64,
+ data: CommandData,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct CommandsResp {
+ index: i64,
+ last: bool,
+ messages: Vec<CommandsEntry>,
+}
+
+fn map_command(c: DeviceCommand) -> CommandsEntry {
+ CommandsEntry {
+ index: c.index,
+ data: CommandData { command: c.command, payload: c.payload, sender: c.sender },
+ }
+}
+
+#[get("/account/device/commands?<index>&<limit>")]
+pub(crate) async fn commands(
+ db: &DbConn,
+ index: i64,
+ limit: Option<i64>,
+ auth: Authenticated<(), WithVerifiedSession>,
+) -> auth::Result<CommandsResp> {
+ let dev = auth.context.device_id.as_ref().ok_or(auth::Error::UnknownDevice)?;
+ let (more, cmds) =
+ db.get_commands(&auth.context.uid, dev, index, limit.unwrap_or(100).clamp(0, 100)).await?;
+ Ok(Json(CommandsResp {
+ index: cmds.iter().map(|c| c.index).max().unwrap_or(0),
+ last: !more,
+ messages: cmds.into_iter().map(map_command).collect(),
+ }))
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct DestroyReq {
+ id: DeviceID,
+}
+
+#[post("/account/device/destroy", data = "<req>")]
+pub(crate) async fn destroy(
+ db: &DbConn,
+ db_pool: &Db,
+ defer: &DeferAction,
+ client: &State<Arc<PushClient>>,
+ req: crate::auth::Authenticated<DestroyReq, WithVerifiedSession>,
+) -> auth::Result<Empty> {
+ db.delete_device(&req.context.uid, &req.body.id).await.map_err(map_error)?;
+ match db.get_devices(&req.context.uid).await {
+ Err(e) => warn!("device_disconnected push failed: {e}"),
+ Ok(devs) => defer.spawn_after_success("api::auth/account/device/destroy(post)", {
+ let (client, db) = (Arc::clone(client), db_pool.clone());
+ async move {
+ let db = db.begin().await?;
+ client.device_disconnected(&db, &devs, &req.body.id).await;
+ db.commit().await?;
+ Ok(())
+ }
+ }),
+ };
+ Ok(EMPTY)
+}
+
+#[derive(Debug, Deserialize)]
+pub(crate) enum NotifyTarget {
+ #[serde(rename = "all")]
+ All,
+}
+
+#[derive(Debug, Deserialize)]
+pub(crate) enum NotifyEPAction {
+ #[serde(rename = "accountVerify")]
+ AccountVerify,
+}
+
+#[derive(Debug, Deserialize)]
+#[allow(non_snake_case)]
+#[serde(untagged, deny_unknown_fields)]
+pub(crate) enum NotifyReq {
+ // deny_unknown_fields and flatten don't work together
+ All {
+ #[allow(dead_code)]
+ to: NotifyTarget,
+ _endpointAction: Option<NotifyEPAction>,
+ excluded: Option<Vec<DeviceID>>,
+ payload: Value,
+ TTL: Option<u32>,
+ },
+ Some {
+ to: Vec<DeviceID>,
+ _endpointAction: Option<NotifyEPAction>,
+ payload: Value,
+ TTL: Option<u32>,
+ },
+}
+
+#[post("/account/devices/notify", data = "<req>")]
+pub(crate) async fn notify(
+ db: &DbConn,
+ client: &State<Arc<PushClient>>,
+ req: Authenticated<NotifyReq, WithVerifiedSession>,
+) -> auth::Result<Empty> {
+ let (to, payload, ttl) = match req.body {
+ NotifyReq::All { excluded, payload, TTL: ttl, .. } => {
+ let excluded = excluded.unwrap_or_default();
+ let mut devs = db.get_devices(&req.context.uid).await?;
+ devs.retain(|d| !excluded.contains(&d.device_id));
+ (devs, payload, ttl)
+ },
+ NotifyReq::Some { to, payload, TTL: ttl, .. } => {
+ let to = join_all(to.iter().map(|id| db.get_device(&req.context.uid, id)))
+ .await
+ .into_iter()
+ .collect::<Result<Vec<_>, _>>()?;
+ (to, payload, ttl)
+ },
+ };
+ client.push_any(db, &to, Duration::from_secs(ttl.unwrap_or(0).into()), payload).await;
+ Ok(EMPTY)
+}
+
+#[derive(Debug, Serialize)]
+#[allow(non_snake_case)]
+pub(crate) struct AttachedClient {
+ clientId: Option<String>,
+ deviceId: Option<DeviceID>,
+ sessionTokenId: Option<SessionID>,
+ refreshTokenId: Option<OauthTokenID>,
+ isCurrentSession: bool,
+ deviceType: Option<String>,
+ name: Option<String>,
+ #[serde(serialize_with = "serialize_dt_opt")]
+ createdTime: Option<DateTime<Utc>>,
+ // MISSING createdTimeFormatted
+ #[serde(serialize_with = "serialize_dt_opt")]
+ lastAccessTime: Option<DateTime<Utc>>,
+ // MISSING lastAccessTimeFormatted
+ // MISSING approximateLastAccessTime
+ // MISSING approximateLastAccessTimeFormatted
+ scope: Option<String>,
+ // MISSING location
+ // MISSING userAgent
+ // MISSING os
+}
+
+// MISSING filterIdleDevicesTimestamp
+#[get("/account/attached_clients")]
+pub(crate) async fn attached_clients(
+ db: &DbConn,
+ auth: Authenticated<(), WithVerifiedFxaLogin>,
+) -> auth::Result<Vec<AttachedClient>> {
+ let clients = db.get_attached_clients(&auth.context.uid).await?;
+ Ok(Json(
+ clients
+ .into_iter()
+ .map(|dev| AttachedClient {
+ clientId: dev.client_id,
+ deviceId: dev.device_id,
+ refreshTokenId: dev.refresh_token_id,
+ isCurrentSession: dev.session_token_id.as_ref() == Some(&auth.session),
+ sessionTokenId: dev.session_token_id,
+ deviceType: dev.device_type,
+ name: dev.name,
+ createdTime: dev.created_time,
+ lastAccessTime: dev.last_access_time,
+ scope: dev.scope,
+ })
+ .collect::<Vec<_>>(),
+ ))
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+#[allow(non_snake_case)]
+pub(crate) struct DestroyAttachedClientReq {
+ // NOTE should be used to verify token deletion, but since we allow only a fixed
+ // number of clients that makes little sense.
+ #[allow(dead_code)]
+ clientId: Option<String>,
+ sessionTokenId: Option<SessionID>,
+ refreshTokenId: Option<OauthTokenID>,
+ deviceId: Option<DeviceID>,
+}
+
+#[post("/account/attached_client/destroy", data = "<req>")]
+pub(crate) async fn destroy_attached_client(
+ db: &DbConn,
+ db_pool: &Db,
+ defer: &DeferAction,
+ client: &State<Arc<PushClient>>,
+ req: Authenticated<DestroyAttachedClientReq, WithVerifiedFxaLogin>,
+) -> auth::Result<Empty> {
+ // only one id may be given, otherwise deleting things properly is more work.
+ if (req.body.sessionTokenId.is_some() as u32)
+ + (req.body.refreshTokenId.is_some() as u32)
+ + (req.body.deviceId.is_some() as u32)
+ != 1
+ {
+ return Err(auth::Error::InvalidParameter);
+ }
+
+ if let Some(dev) = req.body.deviceId {
+ let devs = db.get_devices(&req.context.uid).await;
+ db.delete_device(&req.context.uid, &dev).await?;
+ match devs {
+ Err(e) => warn!("device_disconnected push failed: {e}"),
+ Ok(devs) => {
+ defer.spawn_after_success("api::auth/account/attached_client/destroy(post)", {
+ let (client, db) = (Arc::clone(client), db_pool.clone());
+ async move {
+ let db = db.begin().await?;
+ client.device_disconnected(&db, &devs, &dev).await;
+ db.commit().await?;
+ Ok(())
+ }
+ })
+ },
+ };
+ }
+ if let Some(id) = req.body.sessionTokenId {
+ db.delete_session(&req.context.uid, &id).await?;
+ }
+ if let Some(id) = req.body.refreshTokenId {
+ db.delete_refresh_token(&id).await?;
+ }
+
+ Ok(EMPTY)
+}
diff --git a/src/api/auth/email.rs b/src/api/auth/email.rs
new file mode 100644
index 0000000..f206759
--- /dev/null
+++ b/src/api/auth/email.rs
@@ -0,0 +1,126 @@
+use std::sync::Arc;
+
+use rocket::{serde::json::Json, State};
+use serde::{Deserialize, Serialize};
+
+use crate::{
+ api::{
+ auth::{self, WithFxaLogin},
+ Empty, EMPTY,
+ },
+ auth::Authenticated,
+ db::{Db, DbConn},
+ mailer::Mailer,
+ push::PushClient,
+ types::UserID,
+ utils::DeferAction,
+};
+
+// MISSING get /recovery_emails
+// MISSING post /recovery_email
+// MISSING post /recovery_email/destroy
+// MISSING post /recovery_email/resend_code
+// MISSING post /recovery_email/set_primary
+// MISSING post /emails/reminders/cad
+// MISSING post /recovery_email/secondary/resend_code
+// MISSING post /recovery_email/secondary/verify_code
+
+#[derive(Debug, Serialize)]
+#[allow(non_snake_case)]
+pub(crate) struct StatusResp {
+ email: String,
+ verified: bool,
+ sessionVerified: bool,
+ emailVerified: bool,
+}
+
+// MISSING arg: reason
+#[get("/recovery_email/status")]
+pub(crate) async fn status(
+ db: &DbConn,
+ req: Authenticated<(), WithFxaLogin>,
+) -> auth::Result<StatusResp> {
+ let user = db.get_user_by_id(&req.context.uid).await?;
+ Ok(Json(StatusResp {
+ email: user.email,
+ verified: user.verified,
+ sessionVerified: req.context.verified,
+ emailVerified: user.verified,
+ }))
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct VerifyReq {
+ uid: UserID,
+ code: String,
+ // MISSING service
+ // MISSING reminder
+ // MISSING type
+ // MISSING style
+ // MISSING marketingOptIn
+ // MISSING newsletters
+}
+
+#[post("/recovery_email/verify_code", data = "<req>")]
+pub(crate) async fn verify_code(
+ db: &DbConn,
+ db_pool: &Db,
+ defer: &DeferAction,
+ pc: &State<Arc<PushClient>>,
+ req: Json<VerifyReq>,
+) -> auth::Result<Empty> {
+ let code = match db.try_use_verify_code(&req.uid, &req.code).await? {
+ Some(code) => code,
+ None => return Err(auth::Error::InvalidVerificationCode),
+ };
+ db.set_user_verified(&req.uid).await?;
+ if let Some(sid) = code.session_id {
+ db.set_session_verified(&sid).await?;
+ }
+ match db.get_devices(&req.uid).await {
+ Ok(devs) => defer.spawn_after_success("api::auth/recovery_email/verify_code(post)", {
+ let (pc, db) = (Arc::clone(pc), db_pool.clone());
+ async move {
+ let db = db.begin().await?;
+ pc.account_verified(&db, &devs).await;
+ db.commit().await?;
+ Ok(())
+ }
+ }),
+ Err(e) => warn!("account_verified push failed: {e}"),
+ }
+ Ok(EMPTY)
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct ResendReq {
+ // MISSING email
+ // MISSING service
+ // MISSING redirectTo
+ // MISSING resume
+ // MISSING style
+ // MISSING type
+}
+
+// MISSING arg: service
+// MISSING arg: type
+#[post("/recovery_email/resend_code", data = "<req>")]
+pub(crate) async fn resend_code(
+ db: &DbConn,
+ mailer: &State<Arc<Mailer>>,
+ req: Authenticated<ResendReq, WithFxaLogin>,
+) -> auth::Result<Empty> {
+ let (email, code) = match db.get_verify_code(&req.context.uid).await {
+ Ok(v) => v,
+ Err(_) => return Err(auth::Error::InvalidVerificationCode),
+ };
+ // NOTE we send the email in this context rather than a spawn to signal
+ // send errors to the client.
+ mailer.send_account_verify(&req.context.uid, &email, &code.code).await.map_err(|e| {
+ error!("failed to send email: {e}");
+ auth::Error::EmailFailed
+ })?;
+ Ok(EMPTY)
+}
diff --git a/src/api/auth/invite.rs b/src/api/auth/invite.rs
new file mode 100644
index 0000000..dd81540
--- /dev/null
+++ b/src/api/auth/invite.rs
@@ -0,0 +1,47 @@
+use base64::URL_SAFE_NO_PAD;
+use chrono::{Duration, Utc};
+use rocket::{http::uri::Reference, serde::json::Json, State};
+use serde::{Deserialize, Serialize};
+
+use crate::{api::auth, auth::Authenticated, crypto::SecretBytes, db::DbConn, Config};
+
+use super::WithVerifiedFxaLogin;
+
+pub(crate) async fn generate_invite_link(
+ db: &DbConn,
+ cfg: &Config,
+ ttl: Duration,
+) -> anyhow::Result<Reference<'static>> {
+ let code = base64::encode_config(&SecretBytes::<32>::generate().0, URL_SAFE_NO_PAD);
+ db.add_invite_code(&code, Utc::now() + ttl).await?;
+ Ok(Reference::parse_owned(format!("{}/#/register/{}", cfg.location, code))
+ .map_err(|e| anyhow!("url building failed at {e}"))?)
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct GenerateReq {
+ ttl_hours: u32,
+}
+
+#[derive(Debug, Serialize)]
+pub(crate) struct GenerateResp {
+ url: Reference<'static>,
+}
+
+#[post("/generate", data = "<req>")]
+pub(crate) async fn generate(
+ db: &DbConn,
+ cfg: &State<Config>,
+ req: Authenticated<GenerateReq, WithVerifiedFxaLogin>,
+) -> auth::Result<GenerateResp> {
+ if !req.context.verified {
+ return Err(auth::Error::UnverifiedSession);
+ }
+ let user = db.get_user_by_id(&req.context.uid).await?;
+ if user.email != cfg.invite_admin_address {
+ return Err(auth::Error::InvalidAuthToken);
+ }
+ let url = generate_invite_link(&db, &cfg, Duration::hours(req.body.ttl_hours as i64)).await?;
+ Ok(Json(GenerateResp { url }))
+}
diff --git a/src/api/auth/mod.rs b/src/api/auth/mod.rs
new file mode 100644
index 0000000..2c6d34d
--- /dev/null
+++ b/src/api/auth/mod.rs
@@ -0,0 +1,238 @@
+use rocket::{
+ http::Status,
+ response::{self, Responder},
+ serde::json::Json,
+ Request, Response,
+};
+use serde_json::json;
+
+use crate::{
+ auth::Authenticated,
+ crypto::SecretBytes,
+ types::{OauthToken, SessionID, UserSession},
+};
+
+pub(crate) mod account;
+pub(crate) mod device;
+pub(crate) mod email;
+pub(crate) mod invite;
+pub(crate) mod oauth;
+pub(crate) mod password;
+pub(crate) mod session;
+
+// we don't provide any additional fields. some we can't provide anyway (eg
+// invalid parameter `validation`), others are implied by the request body (eg
+// account exists `email`), and *our* client doesn't care about them anyway
+#[derive(Debug)]
+pub(crate) enum Error {
+ AccountExists,
+ UnknownAccount,
+ IncorrectPassword,
+ UnverifiedAccount,
+ InvalidVerificationCode,
+ InvalidBody,
+ InvalidParameter,
+ MissingParameter,
+ InvalidSignature,
+ InvalidAuthToken,
+ RequestTooLarge,
+ IncorrectEmailCase,
+ UnknownDevice,
+ UnverifiedSession,
+ EmailFailed,
+ NoDeviceCommand,
+ UnknownClientID,
+ ScopesNotAllowed,
+
+ InviteOnly,
+ InviteNotFound,
+
+ Other(anyhow::Error),
+ UnexpectedStatus(Status),
+}
+
+#[rustfmt::skip]
+impl<'r> Responder<'r, 'static> for Error {
+ fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> {
+ let (code, errno, msg) = match self {
+ Error::AccountExists => (Status::BadRequest, 101, "account already exists"),
+ Error::UnknownAccount => (Status::BadRequest, 102, "unknown account"),
+ Error::IncorrectPassword => (Status::BadRequest, 103, "incorrect password"),
+ Error::UnverifiedAccount => (Status::BadRequest, 104, "unverified account"),
+ Error::InvalidVerificationCode => (Status::BadRequest, 105, "invalid verification code"),
+ Error::InvalidBody => (Status::BadRequest, 106, "invalid json in request body"),
+ Error::InvalidParameter => (Status::BadRequest, 107, "invalid parameter in request body"),
+ Error::MissingParameter => (Status::BadRequest, 108, "missing parameter in request body"),
+ Error::InvalidSignature => (Status::Unauthorized, 109, "invalid request signature"),
+ Error::InvalidAuthToken => (Status::Unauthorized, 110, "invalid authentication token"),
+ Error::RequestTooLarge => (Status::PayloadTooLarge, 113, "request too large"),
+ Error::IncorrectEmailCase => (Status::BadRequest, 120, "incorrect email case"),
+ Error::UnknownDevice => (Status::BadRequest, 123, "unknown device"),
+ Error::UnverifiedSession => (Status::BadRequest, 138, "unverified session"),
+ Error::EmailFailed => (Status::UnprocessableEntity, 151, "failed to send email"),
+ Error::NoDeviceCommand => (Status::BadRequest, 157, "unavailable device command"),
+ Error::UnknownClientID => (Status::BadRequest, 162, "unknown client_id"),
+ Error::ScopesNotAllowed => (Status::BadRequest, 169, "requested scopes not allowed"),
+ Error::InviteOnly => (Status::BadRequest, -1, "invite code required"),
+ Error::InviteNotFound => (Status::BadRequest, -2, "invite code not found"),
+ Error::Other(e) => {
+ error!("non-api error during request: {:#?}", e);
+ (Status::InternalServerError, 999, "internal error")
+ },
+ Error::UnexpectedStatus(s) => (s, 999, ""),
+ };
+ let body = json!({
+ "code": code.code,
+ "errno": errno,
+ "error": code.reason_lossy(),
+ "message": msg
+ });
+ Response::build_from(Json(body).respond_to(request)?).status(code).ok()
+ }
+}
+
+impl From<sqlx::Error> for Error {
+ fn from(e: sqlx::Error) -> Self {
+ Error::Other(anyhow!(e))
+ }
+}
+
+impl From<anyhow::Error> for Error {
+ fn from(e: anyhow::Error) -> Self {
+ Error::Other(e)
+ }
+}
+
+pub(crate) type Result<T> = std::result::Result<Json<T>, Error>;
+
+// hack marker type to convey that auth failed due to an unverified session.
+// without this the catcher could convert the Unauthorized error we get from
+// auth failures into just one thing, even though we have multiple causes.
+#[derive(Clone, Copy, Debug)]
+struct UsedUnverifiedSession;
+
+#[catch(default)]
+pub(crate) fn catch_all(status: Status, req: &Request<'_>) -> Error {
+ match req.local_cache(|| None) {
+ Some(UsedUnverifiedSession) => Error::UnverifiedSession,
+ _ => {
+ match status.code {
+ 401 => Error::InvalidSignature,
+ // these three are caused by Json<T> errors
+ 400 => Error::InvalidBody,
+ 413 => Error::RequestTooLarge,
+ 422 => Error::InvalidParameter,
+ // generic unauthorized instead of 404 for eg wrong method or nonexistant endpoints
+ 404 => Error::InvalidSignature,
+ _ => {
+ error!("caught unexpected error {status}");
+ Error::UnexpectedStatus(status)
+ },
+ }
+ },
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct WithFxaLogin;
+
+#[async_trait]
+impl crate::auth::AuthSource for WithFxaLogin {
+ type ID = SessionID;
+ type Context = UserSession;
+ async fn hawk(
+ r: &Request<'_>,
+ id: &SessionID,
+ ) -> anyhow::Result<(SecretBytes<32>, Self::Context)> {
+ let db = Authenticated::<(), Self>::get_conn(r).await?;
+ let k = db.use_session(id).await?;
+ Ok((k.req_hmac_key.0.clone(), k))
+ }
+ async fn bearer_token(
+ _: &Request<'_>,
+ _: &OauthToken,
+ ) -> anyhow::Result<(SessionID, Self::Context)> {
+ bail!("refresh tokens not allowed here");
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct WithVerifiedFxaLogin;
+
+#[async_trait]
+impl crate::auth::AuthSource for WithVerifiedFxaLogin {
+ type ID = SessionID;
+ type Context = UserSession;
+ async fn hawk(
+ r: &Request<'_>,
+ id: &SessionID,
+ ) -> anyhow::Result<(SecretBytes<32>, Self::Context)> {
+ let res = WithFxaLogin::hawk(r, id).await?;
+ match res.1.verified {
+ true => Ok(res),
+ false => {
+ r.local_cache(|| Some(UsedUnverifiedSession));
+ bail!("session not verified");
+ },
+ }
+ }
+ async fn bearer_token(
+ _: &Request<'_>,
+ _: &OauthToken,
+ ) -> anyhow::Result<(SessionID, Self::Context)> {
+ bail!("refresh tokens not allowed here");
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct WithSession;
+
+#[rocket::async_trait]
+impl crate::auth::AuthSource for WithSession {
+ type ID = SessionID;
+ type Context = UserSession;
+ async fn hawk(
+ r: &Request<'_>,
+ id: &SessionID,
+ ) -> anyhow::Result<(SecretBytes<32>, Self::Context)> {
+ WithFxaLogin::hawk(r, id).await
+ }
+ async fn bearer_token(
+ r: &Request<'_>,
+ token: &OauthToken,
+ ) -> anyhow::Result<(SessionID, Self::Context)> {
+ let db = Authenticated::<(), Self>::get_conn(r).await?;
+ Ok(db.use_session_from_refresh(&token.hash()).await?)
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct WithVerifiedSession;
+
+#[rocket::async_trait]
+impl crate::auth::AuthSource for WithVerifiedSession {
+ type ID = SessionID;
+ type Context = UserSession;
+ async fn hawk(
+ r: &Request<'_>,
+ id: &SessionID,
+ ) -> anyhow::Result<(SecretBytes<32>, Self::Context)> {
+ WithVerifiedFxaLogin::hawk(r, id).await
+ }
+ async fn bearer_token(
+ r: &Request<'_>,
+ token: &OauthToken,
+ ) -> anyhow::Result<(SessionID, Self::Context)> {
+ let db = Authenticated::<(), Self>::get_conn(r).await?;
+ let res = db.use_session_from_refresh(&token.hash()).await?;
+ match res.1.verified {
+ true => Ok(res),
+ false => {
+ // technically unreachable because generating a refresh token requires a
+ // valid fxa session
+ r.local_cache(|| Some(UsedUnverifiedSession));
+ bail!("session not verified");
+ },
+ }
+ }
+}
diff --git a/src/api/auth/oauth.rs b/src/api/auth/oauth.rs
new file mode 100644
index 0000000..b0ed8ee
--- /dev/null
+++ b/src/api/auth/oauth.rs
@@ -0,0 +1,433 @@
+use std::collections::HashMap;
+
+use chrono::{DateTime, Duration, Local, Utc};
+use rocket::serde::json::Json;
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use sha2::Digest;
+use subtle::ConstantTimeEq;
+
+use crate::api::auth::WithVerifiedFxaLogin;
+use crate::db::DbConn;
+use crate::types::oauth::{Scope, ScopeSet};
+use crate::{
+ api::{auth, serialize_dt},
+ auth::Authenticated,
+ crypto::{SecretBytes, SessionCredentials},
+ types::{
+ HawkKey, OauthAccessToken, OauthAccessType, OauthAuthorization, OauthAuthorizationID,
+ OauthRefreshToken, OauthToken, OauthTokenID, SessionID, UserID,
+ },
+};
+
+// MISSING get /oauth/client/{client_id}
+
+pub(crate) struct OauthClient {
+ pub(crate) id: &'static str,
+ // NOTE not read so far, but good to have
+ #[allow(dead_code)]
+ pub(crate) name: &'static str,
+ pub(crate) scopes: &'static [Scope<'static>],
+}
+
+const SESSION_SCOPE: Scope = Scope::borrowed("https://identity.mozilla.com/tokens/session");
+
+// NOTE the telemetry scopes don't seem to be needed. since we'd have to give
+// out keys for them (fxa does) we'll exclude them entirely.
+// see fxa-auth-server/config/dev.json for lists of predefined clients and permissions.
+pub(crate) const OAUTH_CLIENTS: [OauthClient; 2] = [
+ OauthClient {
+ id: "5882386c6d801776",
+ name: "Firefox",
+ scopes: &[
+ Scope::borrowed("profile:write"),
+ Scope::borrowed("https://identity.mozilla.com/apps/oldsync"),
+ Scope::borrowed("https://identity.mozilla.com/tokens/session"),
+ // "https://identity.mozilla.com/ids/ecosystem_telemetry",
+ ],
+ },
+ OauthClient {
+ id: "a2270f727f45f648",
+ name: "Fenix",
+ scopes: &[
+ Scope::borrowed("profile"),
+ Scope::borrowed("https://identity.mozilla.com/apps/oldsync"),
+ Scope::borrowed("https://identity.mozilla.com/tokens/session"),
+ // "https://identity.mozilla.com/ids/ecosystem_telemetry",
+ ],
+ },
+];
+
+// NOTE fxa dev config allows scoped keys only for:
+// - https://identity.mozilla.com/apps/notes
+// - https://identity.mozilla.com/apps/oldsync
+// - https://identity.mozilla.com/ids/ecosystem_telemetry
+// - https://identity.mozilla.com/apps/send
+// we only implement sync because notes and send are dead and
+// telemetry is of no use to us
+const SCOPES_WITH_KEYS: [Scope; 1] = [Scope::borrowed("https://identity.mozilla.com/apps/oldsync")];
+
+fn check_client_and_scopes(
+ client_id: &str,
+ scope: &ScopeSet,
+) -> Result<&'static OauthClient, auth::Error> {
+ let desc = match OAUTH_CLIENTS.iter().find(|&s| s.id == client_id) {
+ Some(d) => d,
+ None => return Err(auth::Error::UnknownClientID),
+ };
+ if !scope.is_allowed_by(desc.scopes) {
+ return Err(auth::Error::ScopesNotAllowed);
+ }
+ Ok(desc)
+}
+
+#[derive(Debug, Deserialize)]
+pub(crate) enum PkceChallengeType {
+ S256,
+}
+
+#[derive(Debug, Deserialize)]
+pub(crate) enum AuthResponseType {
+ #[serde(rename = "code")]
+ Code,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct OauthAuthReq {
+ client_id: String,
+ state: String,
+ keys_jwe: Option<String>,
+ scope: ScopeSet,
+ access_type: OauthAccessType,
+ // NOTE we don't support confidential clients, so PKCE is mandatory
+ code_challenge: String,
+
+ // MISSING redirect_uri
+ // MISSING acr_value
+
+ // for validation during deserialization only
+ #[allow(dead_code)]
+ code_challenge_method: PkceChallengeType,
+ #[allow(dead_code)]
+ response_type: AuthResponseType,
+}
+
+#[derive(Debug, Serialize)]
+pub(crate) struct OauthAuthResp {
+ code: OauthAuthorizationID,
+ state: String,
+ // MISSING redirect
+}
+
+#[post("/oauth/authorization", data = "<req>")]
+pub(crate) async fn authorization(
+ db: &DbConn,
+ req: Authenticated<OauthAuthReq, WithVerifiedFxaLogin>,
+) -> auth::Result<OauthAuthResp> {
+ check_client_and_scopes(&req.body.client_id, &req.body.scope)?;
+ let id = OauthAuthorizationID::random();
+ db.add_oauth_authorization(
+ &id,
+ OauthAuthorization {
+ user_id: req.context.uid,
+ client_id: req.body.client_id,
+ scope: req.body.scope,
+ access_type: req.body.access_type,
+ code_challenge: req.body.code_challenge,
+ keys_jwe: req.body.keys_jwe,
+ auth_at: req.context.created_at,
+ },
+ )
+ .await?;
+ Ok(Json(OauthAuthResp { code: id, state: req.body.state }))
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct ScopedKeysReq {
+ client_id: String,
+ scope: ScopeSet,
+}
+
+#[derive(Debug, Serialize)]
+#[allow(non_snake_case)]
+pub(crate) struct ScopedKey {
+ identifier: String,
+ keyRotationSecret: &'static str,
+ keyRotationTimestamp: u64,
+}
+
+#[post("/account/scoped-key-data", data = "<data>")]
+pub(crate) async fn scoped_key_data(
+ data: Authenticated<ScopedKeysReq, WithVerifiedFxaLogin>,
+) -> auth::Result<HashMap<String, ScopedKey>> {
+ check_client_and_scopes(&data.body.client_id, &data.body.scope)?;
+ // like fxa we'll stub out key rotation handling entirely and return the same constants.
+ Ok(Json(
+ data.body
+ .scope
+ .split()
+ .filter(|s| SCOPES_WITH_KEYS.contains(s))
+ .map(|scope| {
+ (
+ scope.to_string(),
+ ScopedKey {
+ identifier: scope.to_string(),
+ keyRotationSecret:
+ "0000000000000000000000000000000000000000000000000000000000000000",
+ keyRotationTimestamp: 0,
+ },
+ )
+ })
+ .collect(),
+ ))
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct OauthDestroy {
+ client_id: String,
+ token: OauthToken,
+}
+
+#[post("/oauth/destroy", data = "<data>")]
+pub(crate) async fn destroy(db: &DbConn, data: Json<OauthDestroy>) -> auth::Result<()> {
+ // MISSING api spec allows an optional basic auth header, but what for?
+ // TODO fxa also checks the authorization header if present, but firefox doesn't send it
+ let client_id = if let Ok(t) = db.get_refresh_token(&data.token.hash()).await {
+ t.client_id
+ } else if let Ok(t) = db.get_access_token(&data.token.hash()).await {
+ t.client_id
+ } else {
+ return Err(auth::Error::InvalidParameter);
+ };
+ // fxa does constant-time checks for client_id, do that here too.
+ if client_id.as_bytes().ct_eq(data.client_id.as_bytes()).into() {
+ db.delete_oauth_token(&data.token.hash()).await?;
+ Ok(Json(()))
+ } else {
+ Err(auth::Error::InvalidParameter)
+ }
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(tag = "grant_type")]
+enum TokenReqDetails {
+ // we can't use deny_unknown_fields when flatten is involved, and multiple
+ // flattens in the same struct cause problems if one of them is greedy (like map).
+ // flatten an extra map into every variant instead and check each of them.
+ #[serde(rename = "authorization_code")]
+ AuthCode {
+ code: OauthAuthorizationID,
+ code_verifier: String,
+ // NOTE only useful with redirect flows, which we kinda don't support at all
+ #[allow(dead_code)]
+ redirect_uri: Option<String>,
+ #[serde(flatten)]
+ extra: HashMap<String, Value>,
+ },
+ #[serde(rename = "refresh_token")]
+ RefreshToken {
+ refresh_token: OauthToken,
+ scope: ScopeSet,
+ #[serde(flatten)]
+ extra: HashMap<String, Value>,
+ },
+ #[serde(rename = "fxa-credentials")]
+ FxaCreds {
+ scope: ScopeSet,
+ access_type: Option<OauthAccessType>,
+ #[serde(flatten)]
+ extra: HashMap<String, Value>,
+ },
+}
+
+impl TokenReqDetails {
+ fn extra_is_empty(&self) -> bool {
+ match self {
+ TokenReqDetails::AuthCode { extra, .. } => extra.is_empty(),
+ TokenReqDetails::RefreshToken { extra, .. } => extra.is_empty(),
+ TokenReqDetails::FxaCreds { extra, .. } => extra.is_empty(),
+ }
+ }
+}
+
+// TODO log errors in all the places
+
+#[derive(Debug, Deserialize)]
+pub(crate) struct TokenReq {
+ client_id: String,
+ ttl: Option<u32>,
+ #[serde(flatten)]
+ details: TokenReqDetails,
+ // MISSING client_secret
+ // MISSING redirect_uri
+ // MISSING ttl
+ // MISSING ppid_seed
+ // MISSING resource
+}
+
+#[derive(Debug, Serialize)]
+pub(crate) enum TokenType {
+ #[serde(rename = "bearer")]
+ Bearer,
+}
+
+#[derive(Debug, Serialize)]
+pub(crate) struct TokenResp {
+ access_token: OauthToken,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ refresh_token: Option<OauthToken>,
+ // MISSING id_token
+ #[serde(skip_serializing_if = "Option::is_none")]
+ session_token: Option<String>,
+ scope: ScopeSet,
+ token_type: TokenType,
+ expires_in: u32,
+ #[serde(serialize_with = "serialize_dt")]
+ auth_at: DateTime<Utc>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ keys_jwe: Option<String>,
+}
+
+#[post("/oauth/token", data = "<req>", rank = 1)]
+pub(crate) async fn token_authenticated(
+ db: &DbConn,
+ req: Authenticated<TokenReq, WithVerifiedFxaLogin>,
+) -> auth::Result<TokenResp> {
+ match &req.body.details {
+ TokenReqDetails::FxaCreds { .. } => (),
+ _ => return Err(auth::Error::InvalidParameter),
+ }
+ token_impl(
+ db,
+ Some(req.context.uid),
+ Some(req.context.created_at),
+ req.body,
+ None,
+ Some(req.session.clone()),
+ )
+ .await
+}
+
+#[post("/oauth/token", data = "<req>", rank = 2)]
+pub(crate) async fn token_unauthenticated(
+ db: &DbConn,
+ req: Json<TokenReq>,
+) -> auth::Result<TokenResp> {
+ let (parent_refresh, auth_at) = match &req.details {
+ TokenReqDetails::RefreshToken { refresh_token, .. } => {
+ let session = db.use_session_from_refresh(&refresh_token.hash()).await?;
+ (Some(refresh_token.hash()), Some(session.1.created_at))
+ },
+ TokenReqDetails::AuthCode { .. } => (None, None),
+ _ => return Err(auth::Error::InvalidParameter),
+ };
+ token_impl(db, None, auth_at, req.into_inner(), parent_refresh, None).await
+}
+
+async fn token_impl(
+ db: &DbConn,
+ user_id: Option<UserID>,
+ auth_at: Option<DateTime<Utc>>,
+ req: TokenReq,
+ parent_refresh: Option<OauthTokenID>,
+ parent_session: Option<SessionID>,
+) -> auth::Result<TokenResp> {
+ if !req.details.extra_is_empty() {
+ return Err(auth::Error::InvalidParameter);
+ }
+ let ttl = req.ttl.unwrap_or(3600).clamp(0, 7 * 86400);
+
+ let (auth_at, scope, keys_jwe, user_id, access_type) = match req.details {
+ TokenReqDetails::AuthCode { code, code_verifier, .. } => {
+ let auth = match db.take_oauth_authorization(&code).await {
+ Ok(a) => a,
+ Err(_) => return Err(auth::Error::InvalidAuthToken),
+ };
+ if !bool::from(auth.client_id.as_bytes().ct_eq(req.client_id.as_bytes())) {
+ return Err(auth::Error::UnknownClientID);
+ }
+ let mut sha = sha2::Sha256::new();
+ sha.update(code_verifier.as_bytes());
+ let challenge = base64::encode_config(&sha.finalize(), base64::URL_SAFE_NO_PAD);
+ if !bool::from(challenge.as_bytes().ct_eq(auth.code_challenge.as_bytes())) {
+ return Err(auth::Error::InvalidParameter);
+ }
+ (auth.auth_at, auth.scope, auth.keys_jwe, auth.user_id, Some(auth.access_type))
+ },
+ TokenReqDetails::RefreshToken { refresh_token, scope, .. } => {
+ let auth_at =
+ auth_at.expect("oauth token requests with refresh token must set auth_at");
+ let base = db.get_refresh_token(&refresh_token.hash()).await?;
+ if !bool::from(base.client_id.as_bytes().ct_eq(req.client_id.as_bytes())) {
+ return Err(auth::Error::UnknownClientID);
+ }
+ check_client_and_scopes(&req.client_id, &scope)?;
+ if !base.scope.implies_all(&scope) {
+ return Err(auth::Error::ScopesNotAllowed);
+ }
+ (auth_at, scope, None, base.user_id, None)
+ },
+ TokenReqDetails::FxaCreds { scope, access_type, .. } => {
+ let user_id = user_id.expect("oauth token requests with fxa must set user_id");
+ let auth_at = auth_at.expect("oauth token requests with fxa must set auth_at");
+ check_client_and_scopes(&req.client_id, &scope)?;
+ (auth_at, scope, None, user_id, access_type)
+ },
+ };
+
+ let access_token = OauthToken::random();
+ db.add_access_token(
+ &access_token.hash(),
+ OauthAccessToken {
+ user_id: user_id.clone(),
+ client_id: req.client_id.clone(),
+ scope: scope.clone(),
+ parent_refresh,
+ parent_session,
+ expires_at: (Local::now() + Duration::seconds(ttl.into())).into(),
+ },
+ )
+ .await?;
+
+ let (refresh_token, session_token) = if access_type == Some(OauthAccessType::Offline) {
+ let (session_token, session_id) = if scope.implies(&SESSION_SCOPE) {
+ let session_token = SecretBytes::generate();
+ let session = SessionCredentials::derive(&session_token);
+ let session_id = SessionID(session.token_id.0);
+ db.add_session(session_id.clone(), &user_id, HawkKey(session.req_hmac_key), true, None)
+ .await?;
+ (Some(session_token.0), Some(SessionID(session.token_id.0)))
+ } else {
+ (None, None)
+ };
+
+ let refresh_token = OauthToken::random();
+ db.add_refresh_token(
+ &refresh_token.hash(),
+ OauthRefreshToken {
+ user_id,
+ client_id: req.client_id,
+ scope: scope.remove(&SESSION_SCOPE),
+ session_id,
+ },
+ )
+ .await?;
+ (Some(refresh_token), session_token)
+ } else {
+ (None, None)
+ };
+
+ Ok(Json(TokenResp {
+ access_token,
+ refresh_token,
+ session_token: session_token.map(hex::encode),
+ scope: scope.remove(&SESSION_SCOPE),
+ token_type: TokenType::Bearer,
+ expires_in: ttl,
+ auth_at,
+ keys_jwe,
+ }))
+}
diff --git a/src/api/auth/password.rs b/src/api/auth/password.rs
new file mode 100644
index 0000000..0eeab4f
--- /dev/null
+++ b/src/api/auth/password.rs
@@ -0,0 +1,260 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use password_hash::SaltString;
+use rocket::{request::FromRequest, serde::json::Json, Request, State};
+use serde::{Deserialize, Serialize};
+use validator::Validate;
+
+use crate::{
+ api::auth,
+ auth::{AuthSource, Authenticated},
+ crypto::{AccountResetReq, AuthPW, KeyBundle, KeyFetchReq, PasswordChangeReq, SecretBytes},
+ db::{Db, DbConn},
+ mailer::Mailer,
+ types::{
+ AccountResetID, HawkKey, KeyFetchID, OauthToken, PasswordChangeID, SecretKey, UserID,
+ VerifyHash,
+ },
+};
+
+// MISSING get /password/forgot/status
+// MISSING post /password/create
+// MISSING post /password/forgot/resend_code
+
+#[derive(Debug, Deserialize, Validate)]
+#[serde(deny_unknown_fields)]
+#[allow(non_snake_case)]
+pub(crate) struct ChangeStartReq {
+ #[validate(email, length(min = 3, max = 256))]
+ email: String,
+ oldAuthPW: AuthPW,
+}
+
+#[derive(Debug, Serialize)]
+#[allow(non_snake_case)]
+pub(crate) struct ChangeStartResp {
+ keyFetchToken: SecretBytes<32>,
+ passwordChangeToken: SecretBytes<32>,
+}
+
+#[post("/password/change/start", data = "<data>")]
+pub(crate) async fn change_start(
+ db: &DbConn,
+ data: Json<ChangeStartReq>,
+) -> auth::Result<ChangeStartResp> {
+ let data = data.into_inner();
+ data.validate().map_err(|_| auth::Error::InvalidParameter)?;
+
+ let (uid, user) = db.get_user(&data.email).await.map_err(|_| auth::Error::UnknownAccount)?;
+ if user.email != data.email {
+ return Err(auth::Error::IncorrectEmailCase);
+ }
+ if !user.verified {
+ return Err(auth::Error::UnverifiedAccount);
+ }
+
+ let stretched = data.oldAuthPW.stretch(user.auth_salt.as_salt())?;
+ if stretched.verify_hash() != user.verify_hash.0 {
+ return Err(auth::Error::IncorrectPassword);
+ }
+
+ let change_token = SecretBytes::generate();
+ let change_req = PasswordChangeReq::from_change_token(&change_token);
+ let key_fetch_token = SecretBytes::generate();
+ let key_req = KeyFetchReq::from_token(&key_fetch_token);
+ let wrapped = key_req.derive_resp().wrap_keys(&KeyBundle {
+ ka: user.ka.0.clone(),
+ wrap_kb: stretched.decrypt_wwkb(&user.wrapwrap_kb.0),
+ });
+ db.add_key_fetch(KeyFetchID(key_req.token_id.0), &HawkKey(key_req.req_hmac_key), &wrapped)
+ .await?;
+ db.add_password_change(
+ &uid,
+ &PasswordChangeID(change_req.token_id.0),
+ &HawkKey(change_req.req_hmac_key),
+ None,
+ )
+ .await?;
+
+ Ok(Json(ChangeStartResp { keyFetchToken: key_fetch_token, passwordChangeToken: change_token }))
+}
+
+// NOTE we use a plain bool here and in the db instead of an enum because
+// enums aren't usable in const generics in stable.
+#[derive(Debug)]
+pub(crate) struct WithChangeToken<const IS_FORGOT: bool>;
+
+#[async_trait]
+impl<const IS_FORGOT: bool> AuthSource for WithChangeToken<IS_FORGOT> {
+ type ID = PasswordChangeID;
+ type Context = (UserID, Option<String>);
+ async fn hawk(
+ r: &Request<'_>,
+ id: &PasswordChangeID,
+ ) -> Result<(SecretBytes<32>, Self::Context)> {
+ // unlike key fetch we'll use a separate transaction here since the body of the
+ // handler can fail.
+ let pool = <&Db as FromRequest>::from_request(r)
+ .await
+ .success_or_else(|| anyhow!("could not open db connection"))?;
+ let db = pool.begin().await?;
+ let result = db.finish_password_change(id, IS_FORGOT).await.map(|(h, ctx)| (h.0, ctx))?;
+ db.commit().await?;
+ Ok(result)
+ }
+ async fn bearer_token(
+ _: &Request<'_>,
+ _: &OauthToken,
+ ) -> Result<(PasswordChangeID, Self::Context)> {
+ bail!("invalid password change authentication")
+ }
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+#[allow(non_snake_case)]
+pub(crate) struct ChangeFinishReq {
+ authPW: AuthPW,
+ wrapKb: SecretBytes<32>,
+ // MISSING sessionToken
+}
+
+#[derive(Debug, Serialize)]
+#[allow(non_snake_case)]
+pub(crate) struct ChangeFinishResp {
+ // NOTE we intentionally deviate from mozilla here. mozilla creates a new
+ // session if sessionToken is set in the request, but we use the "legacy"
+ // password change mechanism that leaves the requesting session and its
+ // device and keys intact. as such this struct is intentionally empty.
+ //
+ // MISSING uid
+ // MISSING sessionToken
+ // MISSING verified
+ // MISSING authAt
+ // MISSING keyFetchToken
+}
+
+#[post("/password/change/finish", data = "<data>")]
+pub(crate) async fn change_finish(
+ db: &DbConn,
+ mailer: &State<Arc<Mailer>>,
+ data: Authenticated<ChangeFinishReq, WithChangeToken<false>>,
+) -> auth::Result<ChangeFinishResp> {
+ let user = db.get_user_by_id(&data.context.0).await?;
+
+ let auth_salt = SaltString::generate(rand::rngs::OsRng);
+ let stretched = data.body.authPW.stretch(auth_salt.as_salt())?;
+ let verify_hash = stretched.verify_hash();
+ let wrapwrap_kb = stretched.rewrap_wkb(&data.body.wrapKb);
+
+ db.change_user_auth(
+ &data.context.0,
+ auth_salt,
+ SecretKey(wrapwrap_kb),
+ VerifyHash(verify_hash),
+ )
+ .await?;
+
+ // NOTE password_changed/password_reset pushes seem to have no effect, so skip them.
+
+ mailer
+ .send_password_changed(&user.email)
+ .await
+ .map_err(|e| {
+ warn!("password change email send failed: {e}");
+ })
+ .ok();
+
+ Ok(Json(ChangeFinishResp {}))
+}
+
+#[derive(Debug, Deserialize, Validate)]
+#[serde(deny_unknown_fields)]
+#[allow(non_snake_case)]
+pub(crate) struct ForgotStartReq {
+ #[validate(email, length(min = 3, max = 256))]
+ email: String,
+}
+
+#[derive(Debug, Serialize)]
+#[allow(non_snake_case)]
+pub(crate) struct ForgotStartResp {
+ passwordForgotToken: SecretBytes<32>,
+ ttl: u32,
+ codeLength: u32,
+ tries: u32,
+}
+
+#[post("/password/forgot/send_code", data = "<data>")]
+pub(crate) async fn forgot_start(
+ db: &DbConn,
+ mailer: &State<Arc<Mailer>>,
+ data: Json<ForgotStartReq>,
+) -> auth::Result<ForgotStartResp> {
+ let data = data.into_inner();
+ data.validate().map_err(|_| auth::Error::InvalidParameter)?;
+
+ let (uid, user) = db.get_user(&data.email).await.map_err(|_| auth::Error::UnknownAccount)?;
+ if user.email != data.email {
+ return Err(auth::Error::IncorrectEmailCase);
+ }
+ if !user.verified {
+ return Err(auth::Error::UnverifiedAccount);
+ }
+
+ let forgot_code = hex::encode(SecretBytes::<16>::generate().0);
+ let forgot_token = SecretBytes::generate();
+ let forgot_req = PasswordChangeReq::from_forgot_token(&forgot_token);
+ db.add_password_change(
+ &uid,
+ &PasswordChangeID(forgot_req.token_id.0),
+ &HawkKey(forgot_req.req_hmac_key),
+ Some(&forgot_code),
+ )
+ .await?;
+
+ mailer.send_password_forgot(&user.email, &forgot_code).await?;
+
+ Ok(Json(ForgotStartResp {
+ passwordForgotToken: forgot_token,
+ ttl: 300,
+ codeLength: 16,
+ tries: 1,
+ }))
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+#[allow(non_snake_case)]
+pub(crate) struct ForgotFinishReq {
+ code: String,
+ // MISSING accountResetWithRecoveryKey
+}
+
+#[derive(Debug, Serialize)]
+#[allow(non_snake_case)]
+pub(crate) struct ForgotFinishResp {
+ accountResetToken: SecretBytes<32>,
+}
+
+#[post("/password/forgot/verify_code", data = "<data>")]
+pub(crate) async fn forgot_finish(
+ db: &DbConn,
+ data: Authenticated<ForgotFinishReq, WithChangeToken<true>>,
+) -> auth::Result<ForgotFinishResp> {
+ if Some(data.body.code) != data.context.1 {
+ return Err(auth::Error::InvalidVerificationCode);
+ }
+
+ let reset_token = SecretBytes::generate();
+ let reset_req = AccountResetReq::from_token(&reset_token);
+ db.add_account_reset(
+ &data.context.0,
+ &AccountResetID(reset_req.token_id.0),
+ &HawkKey(reset_req.req_hmac_key),
+ )
+ .await?;
+
+ Ok(Json(ForgotFinishResp { accountResetToken: reset_token }))
+}
diff --git a/src/api/auth/session.rs b/src/api/auth/session.rs
new file mode 100644
index 0000000..5911b92
--- /dev/null
+++ b/src/api/auth/session.rs
@@ -0,0 +1,107 @@
+use std::sync::Arc;
+
+use rocket::serde::json::Json;
+use rocket::State;
+use serde::{Deserialize, Serialize};
+
+use crate::api::auth::WithFxaLogin;
+use crate::api::{auth, Empty, EMPTY};
+use crate::auth::Authenticated;
+use crate::db::Db;
+use crate::db::DbConn;
+use crate::mailer::Mailer;
+use crate::push::PushClient;
+use crate::types::{SessionID, UserID};
+use crate::utils::DeferAction;
+
+// MISSING post /session/duplicate
+// MISSING post /session/reauth
+// MISSING post /session/verify/send_push
+
+#[derive(Debug, Serialize)]
+pub(crate) struct StatusResp {
+ state: &'static str, // what does this *do*?
+ uid: UserID,
+}
+
+#[get("/session/status")]
+pub(crate) async fn status(req: Authenticated<(), WithFxaLogin>) -> auth::Result<StatusResp> {
+ Ok(Json(StatusResp { state: "", uid: req.context.uid }))
+}
+
+#[post("/session/resend_code", data = "<req>")]
+pub(crate) async fn resend_code(
+ db: &DbConn,
+ mailer: &State<Arc<Mailer>>,
+ req: Authenticated<Empty, WithFxaLogin>,
+) -> auth::Result<Empty> {
+ let code = match req.context.verify_code {
+ Some(code) => code,
+ _ => return Err(auth::Error::InvalidVerificationCode),
+ };
+
+ let user = db.get_user_by_id(&req.context.uid).await?;
+ mailer.send_session_verify(&user.email, &code).await.map_err(|e| {
+ error!("failed to send email: {e}");
+ auth::Error::EmailFailed
+ })?;
+ Ok(EMPTY)
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct VerifyReq {
+ code: String,
+ // MISSING service
+ // MISSING scopes
+ // MISSING marketingOptIn
+ // MISSING newsletters
+}
+
+#[post("/session/verify_code", data = "<req>")]
+pub(crate) async fn verify_code(
+ db: &DbConn,
+ req: Authenticated<VerifyReq, WithFxaLogin>,
+) -> auth::Result<Empty> {
+ if req.context.verify_code.as_ref() != Some(&req.body.code) {
+ return Err(auth::Error::InvalidVerificationCode);
+ }
+ db.set_session_verified(&req.session).await?;
+ Ok(EMPTY)
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct DestroyReq {
+ custom_session_id: Option<SessionID>,
+}
+
+#[post("/session/destroy", data = "<data>")]
+pub(crate) async fn destroy(
+ db: &DbConn,
+ db_pool: &Db,
+ defer: &DeferAction,
+ client: &State<Arc<PushClient>>,
+ data: Authenticated<DestroyReq, WithFxaLogin>,
+) -> auth::Result<Empty> {
+ if data.body.custom_session_id.is_some() && !data.context.verified {
+ return Err(auth::Error::UnverifiedSession);
+ }
+ let id = data.body.custom_session_id.as_ref().unwrap_or(&data.session);
+ db.delete_session(&data.context.uid, id).await.map_err(|_| auth::Error::UnknownDevice)?;
+ if let Some(id) = data.context.device_id {
+ match db.get_devices(&data.context.uid).await {
+ Err(e) => warn!("device_disconnected push failed: {e}"),
+ Ok(devs) => defer.spawn_after_success("api::auth/session/destroy(post)", {
+ let (client, db) = (Arc::clone(client), db_pool.clone());
+ async move {
+ let db = db.begin().await?;
+ client.device_disconnected(&db, &devs, &id).await;
+ db.commit().await?;
+ Ok(())
+ }
+ }),
+ };
+ }
+ Ok(EMPTY)
+}
diff --git a/src/api/mod.rs b/src/api/mod.rs
new file mode 100644
index 0000000..1831659
--- /dev/null
+++ b/src/api/mod.rs
@@ -0,0 +1,32 @@
+use chrono::{DateTime, TimeZone};
+use rocket::serde::json::Json;
+use serde::{Deserialize, Serialize, Serializer};
+
+pub(crate) mod auth;
+pub(crate) mod oauth;
+pub(crate) mod profile;
+
+pub fn serialize_dt<S, TZ>(dt: &DateTime<TZ>, ser: S) -> Result<S::Ok, S::Error>
+where
+ S: Serializer,
+ TZ: TimeZone,
+{
+ ser.serialize_i64(dt.timestamp())
+}
+
+pub fn serialize_dt_opt<S, TZ>(dt: &Option<DateTime<TZ>>, ser: S) -> Result<S::Ok, S::Error>
+where
+ S: Serializer,
+ TZ: TimeZone,
+{
+ match dt {
+ Some(dt) => serialize_dt(dt, ser),
+ None => ser.serialize_unit(),
+ }
+}
+
+#[derive(Clone, Copy, Serialize, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub struct Empty {}
+
+pub const EMPTY: Json<Empty> = Json(Empty {});
diff --git a/src/api/oauth.rs b/src/api/oauth.rs
new file mode 100644
index 0000000..0519125
--- /dev/null
+++ b/src/api/oauth.rs
@@ -0,0 +1,163 @@
+use rocket::{
+ http::Status,
+ response::{self, Responder},
+ serde::json::Json,
+ Request, Response,
+};
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+
+use crate::{
+ api::Empty,
+ types::{OauthToken, UserID},
+};
+use crate::{db::DbConn, types::oauth::Scope};
+
+use super::EMPTY;
+
+// we don't provide any additional fields. some we can't provide anyway (eg
+// invalid parameter `validation`), others are implied by the request body (eg
+// account exists `email`), and *our* client doesn't care about them anyway
+#[derive(Debug)]
+pub(crate) enum Error {
+ InvalidParameter,
+ Unauthorized,
+ PayloadTooLarge,
+
+ Other(anyhow::Error),
+ UnexpectedStatus(Status),
+}
+
+#[rustfmt::skip]
+impl<'r> Responder<'r, 'static> for Error {
+ fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> {
+ let (code, errno, msg) = match self {
+ Error::InvalidParameter => (Status::BadRequest, 109, "invalid request parameter"),
+ Error::Unauthorized => (Status::Forbidden, 111, "unauthorized"),
+ Error::PayloadTooLarge => (Status::PayloadTooLarge, 999, "payload too large"),
+ Error::Other(e) => {
+ error!("non-api error during request: {:?}", e);
+ (Status::InternalServerError, 999, "internal error")
+ },
+ Error::UnexpectedStatus(s) => (s, 999, ""),
+ };
+ let body = json!({
+ "code": code.code,
+ "errno": errno,
+ "error": code.reason_lossy(),
+ "message": msg
+ });
+ Response::build_from(Json(body).respond_to(request)?).status(code).ok()
+ }
+}
+
+impl From<sqlx::Error> for Error {
+ fn from(e: sqlx::Error) -> Self {
+ Error::Other(anyhow!(e))
+ }
+}
+
+impl From<anyhow::Error> for Error {
+ fn from(e: anyhow::Error) -> Self {
+ Error::Other(e)
+ }
+}
+
+pub(crate) type Result<T> = std::result::Result<Json<T>, Error>;
+
+#[catch(default)]
+pub(crate) fn catch_all(status: Status, _r: &Request<'_>) -> Error {
+ match status.code {
+ 401 => Error::Unauthorized,
+ // these three are caused by Json<T> errors
+ 400 => Error::InvalidParameter,
+ 413 => Error::PayloadTooLarge,
+ 422 => Error::InvalidParameter,
+ // generic unauthorized instead of 404 for eg wrong method or nonexistant endpoints
+ 404 => Error::Unauthorized,
+ _ => {
+ error!("caught unexpected error {status}");
+ Error::UnexpectedStatus(status)
+ },
+ }
+}
+
+fn map_error(e: sqlx::Error) -> Error {
+ match &e {
+ sqlx::Error::RowNotFound => Error::InvalidParameter,
+ _ => Error::Other(anyhow!(e)),
+ }
+}
+
+// MISSING GET /v1/authorization
+// MISSING POST /v1/authorization
+// MISSING POST /v1/authorized-clients
+// MISSING POST /v1/authorized-clients/destroy
+// MISSING GET /v1/client/:id
+// MISSING POST /v1/introspect
+// MISSING GET /v1/jwks
+// MISSING POST /v1/key-data
+// MISSING POST /v1/token
+// MISSING POST /v1/verify
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct DestroyReq {
+ access_token: Option<OauthToken>,
+ refresh_token: Option<OauthToken>,
+ // NOTE this field does not exist in the spec, but fenix sends it
+ token: Option<OauthToken>,
+ // MISSING client_id
+ // MISSING client_secret
+ // MISSING refresh_token_id
+}
+
+#[post("/destroy", data = "<req>")]
+pub(crate) async fn destroy(
+ db: &DbConn,
+ req: Json<DestroyReq>,
+) -> std::result::Result<Json<Empty>, Error> {
+ // MISSING spec says basic auth is allowed, but nothing seems to use it
+ if let Some(t) = req.0.access_token {
+ db.delete_oauth_token(&t.hash()).await?;
+ }
+ if let Some(t) = req.0.refresh_token {
+ db.delete_oauth_token(&t.hash()).await?;
+ }
+ if let Some(t) = req.0.token {
+ db.delete_oauth_token(&t.hash()).await?;
+ }
+ Ok(EMPTY)
+}
+
+#[get("/jwks")]
+pub(crate) async fn jwks() -> Json<Empty> {
+ // HACK we need to return *something* for /jwks, otherwise PyFxA fails.
+ // since syncstorage-rs uses PyFxA to check oauth tokens this is bad.
+ EMPTY
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub(crate) struct VerifyReq {
+ token: OauthToken,
+}
+
+#[derive(Debug, Serialize)]
+pub(crate) struct VerifyResp {
+ user: UserID,
+ client_id: String,
+ scope: Vec<Scope<'static>>,
+ // MISSING generation
+ // MISSING profile_changed_at
+}
+
+#[post("/verify", data = "<req>")]
+pub(crate) async fn verify(db: &DbConn, req: Json<VerifyReq>) -> Result<VerifyResp> {
+ let token = db.get_access_token(&req.token.hash()).await.map_err(map_error)?;
+ Ok(Json(VerifyResp {
+ user: token.user_id,
+ client_id: token.client_id,
+ scope: token.scope.split().map(|s| s.into_owned()).collect::<Vec<_>>(),
+ }))
+}
diff --git a/src/api/profile/mod.rs b/src/api/profile/mod.rs
new file mode 100644
index 0000000..28d1e03
--- /dev/null
+++ b/src/api/profile/mod.rs
@@ -0,0 +1,324 @@
+use std::sync::Arc;
+
+use either::Either;
+use rocket::{
+ data::ToByteUnit,
+ http::{uri::Absolute, ContentType, Status},
+ response::{self, Responder},
+ serde::json::Json,
+ Request, Response, State,
+};
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+use sha2::{Digest, Sha256};
+use Either::{Left, Right};
+
+use crate::{
+ api::Empty,
+ auth::{Authenticated, WithBearer, AuthenticatedRequest},
+ cache::Immutable,
+ db::Db,
+ types::{oauth::Scope, UserID},
+ utils::DeferAction,
+};
+use crate::{db::DbConn, types::AvatarID, Config};
+use crate::{push::PushClient, types::Avatar};
+
+use super::EMPTY;
+
+// we don't provide any additional fields. some we can't provide anyway (eg
+// invalid parameter `validation`), others are implied by the request body (eg
+// account exists `email`), and *our* client doesn't care about them anyway
+#[derive(Debug)]
+pub(crate) enum Error {
+ Unauthorized,
+ InvalidParameter,
+ PayloadTooLarge,
+ NotFound,
+
+ // this is actually a response from the auth api (not the profile api),
+ // but firefox needs the *exact response* of this auth error to refresh
+ // profile fetch oauth tokens for its ui. :(
+ InvalidAuthToken,
+
+ Other(anyhow::Error),
+ UnexpectedStatus(Status),
+}
+
+#[rustfmt::skip]
+impl<'r> Responder<'r, 'static> for Error {
+ fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> {
+ let (code, errno, msg) = match self {
+ Error::Unauthorized => (Status::Forbidden, 100, "unauthorized"),
+ Error::InvalidParameter => (Status::BadRequest, 101, "invalid parameter in request body"),
+ Error::PayloadTooLarge => (Status::PayloadTooLarge, 999, "payload too large"),
+ Error::NotFound => (Status::NotFound, 999, "not found"),
+
+ Error::InvalidAuthToken => (Status::Unauthorized, 110, "invalid authentication token"),
+
+ Error::Other(e) => {
+ error!("non-api error during request: {:?}", e);
+ (Status::InternalServerError, 999, "internal error")
+ },
+ Error::UnexpectedStatus(s) => (s, 999, ""),
+ };
+ let body = json!({
+ "code": code.code,
+ "errno": errno,
+ "error": code.reason_lossy(),
+ "message": msg
+ });
+ Response::build_from(Json(body).respond_to(request)?).status(code).ok()
+ }
+}
+
+impl From<sqlx::Error> for Error {
+ fn from(e: sqlx::Error) -> Self {
+ Error::Other(anyhow!(e))
+ }
+}
+
+impl From<anyhow::Error> for Error {
+ fn from(e: anyhow::Error) -> Self {
+ Error::Other(e)
+ }
+}
+
+pub(crate) type Result<T> = std::result::Result<Json<T>, Error>;
+
+#[catch(default)]
+pub(crate) fn catch_all(status: Status, r: &Request<'_>) -> Error {
+ match status.code {
+ // these three are caused by Json<T> errors
+ 400 | 422 => Error::InvalidParameter,
+ 413 => Error::PayloadTooLarge,
+ // translate forbidden-because-token to the auth api error for firefox
+ 401 if r.invalid_token_used() => Error::InvalidAuthToken,
+ // generic unauthorized instead of 404 for eg wrong method or nonexistant endpoints
+ 401 | 404 => Error::Unauthorized,
+ _ => {
+ error!("caught unexpected error {status}");
+ Error::UnexpectedStatus(status)
+ },
+ }
+}
+
+// MISSING GET /v1/email
+// MISSING GET /v1/subscriptions
+// MISSING GET /v1/uid
+// MISSING GET /v1/display_name
+// MISSING DELETE /v1/cache/:uid
+// MISSING send profile:change webchannel event an avatar/name changes
+
+#[derive(Debug, Serialize)]
+#[allow(non_snake_case)]
+pub(crate) struct ProfileResp {
+ uid: Option<UserID>,
+ email: Option<String>,
+ locale: Option<String>,
+ amrValues: Option<Vec<String>>,
+ twoFactorAuthentication: bool,
+ displayName: Option<String>,
+ // NOTE spec does not exist, fxa-profile-server schema says this field is optional,
+ // but fenix exceptions if it's null.
+ // NOTE it also *must* be a valid url, or fenix crashes entirely.
+ avatar: Absolute<'static>,
+ avatarDefault: bool,
+ subscriptions: Option<Vec<String>>,
+}
+
+#[get("/profile")]
+pub(crate) async fn profile(
+ db: &DbConn,
+ cfg: &State<Config>,
+ auth: Authenticated<(), WithBearer>,
+) -> Result<ProfileResp> {
+ let has_scope = |s| auth.context.implies(&Scope::borrowed(s));
+
+ let user = db.get_user_by_id(&auth.session).await?;
+ let (avatar, avatar_default) = if has_scope("profile:avatar") {
+ match db.get_user_avatar_id(&auth.session).await? {
+ Some(id) => (uri!(cfg.avatars_prefix(), avatar_get_img(id = id.to_string())), false),
+ None => (
+ uri!(cfg.avatars_prefix(), avatar_get_img("00000000000000000000000000000000")),
+ true,
+ ),
+ }
+ } else {
+ (uri!(cfg.avatars_prefix(), avatar_get_img("00000000000000000000000000000000")), true)
+ };
+ Ok(Json(ProfileResp {
+ uid: if has_scope("profile:uid") { Some(auth.session) } else { None },
+ email: if has_scope("profile:email") { Some(user.email) } else { None },
+ locale: None,
+ amrValues: None,
+ twoFactorAuthentication: false,
+ displayName: if has_scope("profile:display_name") { user.display_name } else { None },
+ avatar,
+ avatarDefault: avatar_default,
+ subscriptions: None,
+ }))
+}
+
+#[derive(Debug, Deserialize)]
+#[allow(non_snake_case)]
+pub(crate) struct DisplayNameReq {
+ displayName: String,
+}
+
+#[post("/display_name", data = "<req>")]
+pub(crate) async fn display_name_post(
+ db: &DbConn,
+ db_pool: &Db,
+ pc: &State<Arc<PushClient>>,
+ defer: &DeferAction,
+ req: Authenticated<DisplayNameReq, WithBearer>,
+) -> Result<Empty> {
+ if !req.context.implies(&Scope::borrowed("profile:display_name:write")) {
+ return Err(Error::Unauthorized);
+ }
+
+ db.set_user_name(&req.session, &req.body.displayName).await?;
+ match db.get_devices(&req.session).await {
+ Ok(devs) => defer.spawn_after_success("api::profile/display_name(post)", {
+ let (pc, db) = (Arc::clone(pc), db_pool.clone());
+ async move {
+ let db = db.begin().await?;
+ pc.profile_updated(&db, &devs).await;
+ db.commit().await?;
+ Ok(())
+ }
+ }),
+ Err(e) => warn!("profile_updated push failed: {e}"),
+ }
+ Ok(EMPTY)
+}
+
+#[derive(Serialize)]
+#[allow(non_snake_case)]
+pub(crate) struct AvatarResp {
+ id: AvatarID,
+ avatarDefault: bool,
+ avatar: Absolute<'static>,
+}
+
+#[get("/avatar")]
+pub(crate) async fn avatar_get(
+ db: &DbConn,
+ cfg: &State<Config>,
+ req: Authenticated<(), WithBearer>,
+) -> Result<AvatarResp> {
+ if !req.context.implies(&Scope::borrowed("profile:avatar")) {
+ return Err(Error::Unauthorized);
+ }
+
+ let resp = match db.get_user_avatar_id(&req.session).await? {
+ Some(id) => {
+ let url = uri!(cfg.avatars_prefix(), avatar_get_img(id = id.to_string()));
+ AvatarResp { id, avatarDefault: false, avatar: url }
+ },
+ None => {
+ let url =
+ uri!(cfg.avatars_prefix(), avatar_get_img("00000000000000000000000000000000"));
+ AvatarResp { id: AvatarID([0; 16]), avatarDefault: true, avatar: url }
+ },
+ };
+ Ok(Json(resp))
+}
+
+#[get("/<id>")]
+pub(crate) async fn avatar_get_img(
+ db: &DbConn,
+ id: &str,
+) -> std::result::Result<(ContentType, Immutable<Either<Vec<u8>, &'static [u8]>>), Error> {
+ let id = id.parse().map_err(|_| Error::NotFound)?;
+
+ if id == AvatarID([0; 16]) {
+ return Ok((
+ ContentType::SVG,
+ Immutable(Right(include_bytes!("../../../Raven-Silhouette.svg"))),
+ ));
+ }
+
+ match db.get_user_avatar(&id).await? {
+ Some(avatar) => {
+ let ct = avatar.content_type.parse().expect("invalid content type in db");
+ Ok((ct, Immutable(Left(avatar.data))))
+ },
+ None => Err(Error::NotFound),
+ }
+}
+
+#[derive(Serialize)]
+#[allow(non_snake_case)]
+pub(crate) struct AvatarUploadResp {
+ url: Absolute<'static>,
+}
+
+#[post("/avatar/upload", data = "<data>")]
+pub(crate) async fn avatar_upload(
+ db: &DbConn,
+ db_pool: &Db,
+ pc: &State<Arc<PushClient>>,
+ defer: &DeferAction,
+ cfg: &State<Config>,
+ ct: &ContentType,
+ req: Authenticated<(), WithBearer>,
+ data: Vec<u8>,
+) -> Result<AvatarUploadResp> {
+ if !req.context.implies(&Scope::borrowed("profile:avatar:write")) {
+ return Err(Error::Unauthorized);
+ }
+ if data.len() >= 128.kibibytes() {
+ return Err(Error::PayloadTooLarge);
+ }
+
+ if !ct.is_png()
+ && !ct.is_gif()
+ && !ct.is_bmp()
+ && !ct.is_jpeg()
+ && !ct.is_webp()
+ && !ct.is_avif()
+ && !ct.is_svg()
+ {
+ return Err(Error::InvalidParameter);
+ }
+
+ let mut sha = Sha256::new();
+ sha.update(&req.session.0);
+ sha.update(&data);
+ let id = AvatarID(sha.finalize()[0..16].try_into().unwrap());
+
+ db.set_user_avatar(&req.session, Avatar { id: id.clone(), data, content_type: ct.to_string() })
+ .await?;
+ match db.get_devices(&req.session).await {
+ Ok(devs) => defer.spawn_after_success("api::profile/avatar/upload(post)", {
+ let (pc, db) = (Arc::clone(pc), db_pool.clone());
+ async move {
+ let db = db.begin().await?;
+ pc.profile_updated(&db, &devs).await;
+ db.commit().await?;
+ Ok(())
+ }
+ }),
+ Err(e) => warn!("profile_updated push failed: {e}"),
+ }
+
+ let url = uri!(cfg.avatars_prefix(), avatar_get_img(id = id.to_string()));
+ Ok(Json(AvatarUploadResp { url }))
+}
+
+#[delete("/avatar/<id>")]
+pub(crate) async fn avatar_delete(
+ db: &DbConn,
+ id: &str,
+ req: Authenticated<(), WithBearer>,
+) -> Result<Empty> {
+ if !req.context.implies(&Scope::borrowed("profile:avatar:write")) {
+ return Err(Error::Unauthorized);
+ }
+ let id = id.parse().map_err(|_| Error::NotFound)?;
+
+ db.delete_user_avatar(&req.session, &id).await?;
+ Ok(EMPTY)
+}
diff --git a/src/auth.rs b/src/auth.rs
new file mode 100644
index 0000000..f56c5e2
--- /dev/null
+++ b/src/auth.rs
@@ -0,0 +1,241 @@
+use std::str::FromStr;
+use std::time::Duration;
+
+use anyhow::Result;
+use hawk::{DigestAlgorithm, Header, Key, PayloadHasher, RequestBuilder};
+use rocket::data::{self, FromData, ToByteUnit};
+use rocket::http::Status;
+use rocket::outcome::{try_outcome, Outcome};
+use rocket::request::{local_cache, FromRequest, Request};
+use rocket::{request, Data, Ignite, Phase, Rocket, Sentinel};
+use serde::Deserialize;
+use serde_json::error::Category;
+
+use crate::crypto::SecretBytes;
+use crate::db::DbConn;
+use crate::types::oauth::ScopeSet;
+use crate::types::{OauthToken, UserID};
+use crate::Config;
+
+#[rocket::async_trait]
+pub(crate) trait AuthSource {
+ type ID: FromStr + Send + Sync + Clone;
+ type Context: Send + Sync;
+ async fn hawk(r: &Request<'_>, id: &Self::ID) -> Result<(SecretBytes<32>, Self::Context)>;
+ async fn bearer_token(r: &Request<'_>, id: &OauthToken) -> Result<(Self::ID, Self::Context)>;
+}
+
+// marker trait and type to communicate that authentication has failed with invalid
+// tokens used. this is needed to properly translate these error for the profile api.
+pub(crate) trait AuthenticatedRequest {
+ fn invalid_token_used(&self) -> bool;
+}
+
+struct InvalidTokenUsed;
+
+impl<'r> AuthenticatedRequest for Request<'r> {
+ fn invalid_token_used(&self) -> bool {
+ self.local_cache(|| None as Option<InvalidTokenUsed>).is_some()
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct Authenticated<T, Src: AuthSource> {
+ pub body: T,
+ pub session: Src::ID,
+ pub context: Src::Context,
+}
+
+enum AuthKind<'a> {
+ Hawk { header: Header },
+ Token { token: &'a str },
+}
+
+fn drop_auth_prefix<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
+ if prefix.len() <= s.len() && s[..prefix.len()].eq_ignore_ascii_case(prefix) {
+ Some(&s[prefix.len()..])
+ } else {
+ None
+ }
+}
+
+impl<T, S: AuthSource> Sentinel for Authenticated<T, S> {
+ fn abort(rocket: &Rocket<Ignite>) -> bool {
+ // NOTE data sentinels are broken in rocket 0.5-rc2
+ Self::try_get_state(rocket).is_none() || <&DbConn as Sentinel>::abort(rocket)
+ }
+}
+
+impl<T, Src: AuthSource> Authenticated<T, Src> {
+ fn try_get_state<S: Phase>(r: &Rocket<S>) -> Option<&Config> {
+ r.state::<Config>()
+ }
+
+ fn state<S: Phase>(r: &Rocket<S>) -> &Config {
+ Self::try_get_state(r).unwrap()
+ }
+
+ async fn parse_auth<'a>(
+ request: &'a Request<'_>,
+ ) -> Outcome<AuthKind<'a>, (Status, anyhow::Error), ()> {
+ let auth = match request.headers().get("authorization").take(2).enumerate().last() {
+ Some((0, h)) => h,
+ Some((_, _)) => {
+ return Outcome::Failure((
+ Status::BadRequest,
+ anyhow!("multiple authorization headers present"),
+ ))
+ },
+ None => return Outcome::Forward(()),
+ };
+ if let Some(hawk) = drop_auth_prefix(auth, "hawk ") {
+ match Header::from_str(hawk) {
+ Ok(header) => Outcome::Success(AuthKind::Hawk { header }),
+ Err(e) => Outcome::Failure((
+ Status::Unauthorized,
+ anyhow!(e).context("malformed hawk header"),
+ )),
+ }
+ } else if let Some(token) = drop_auth_prefix(auth, "bearer ") {
+ Outcome::Success(AuthKind::Token { token })
+ } else {
+ Outcome::Forward(())
+ }
+ }
+
+ pub async fn get_conn<'r>(req: &'r Request<'_>) -> Result<&'r DbConn> {
+ match <&'r DbConn as FromRequest<'r>>::from_request(req).await {
+ Outcome::Success(db) => Ok(db),
+ Outcome::Failure((_, e)) => Err(e.context("get db connection")),
+ _ => Err(anyhow!("could not get db connection")),
+ }
+ }
+
+ async fn verify_hawk(
+ request: &Request<'_>,
+ hawk: Header,
+ data: Option<&str>,
+ ) -> Result<(Src::ID, Src::Context)> {
+ let cfg = Self::state(request.rocket());
+ let url = format!("{}{}", cfg.location, request.uri());
+ let url = url::Url::parse(&url).unwrap();
+ let hash = data
+ .map(|d| PayloadHasher::hash("application/json", DigestAlgorithm::Sha256, d))
+ .transpose()?;
+ let hawk_req = RequestBuilder::from_url(request.method().as_str(), &url)?;
+ let hawk_req = match hash.as_ref() {
+ Some(h) => hawk_req.hash(Some(h.as_ref())).request(),
+ _ => hawk_req.request(),
+ };
+ let id: Src::ID =
+ match hawk.id.clone().ok_or_else(|| anyhow!("missing hawk key id"))?.parse() {
+ Ok(id) => id,
+ Err(_) => bail!("malformed hawk key id"),
+ };
+ let (key, context) = Src::hawk(request, &id).await?;
+ let key = Key::new(&key.0, DigestAlgorithm::Sha256)?;
+ // large skew was taken from fxa-auth-server, large clock skews seem to happen
+ if !hawk_req.validate_header(&hawk, &key, Duration::from_secs(20 * 365 * 86400)) {
+ bail!("bad hawk signature");
+ }
+ Ok((id, context))
+ }
+
+ async fn verify_bearer_token(
+ request: &Request<'_>,
+ token: &str,
+ ) -> Result<(Src::ID, Src::Context)> {
+ let token = match token.parse() {
+ Ok(token) => token,
+ Err(_) => bail!("malformed oauth token"),
+ };
+ Src::bearer_token(request, &token).await
+ }
+}
+
+#[rocket::async_trait]
+impl<'r, Src: AuthSource> FromRequest<'r> for Authenticated<(), Src> {
+ type Error = anyhow::Error;
+
+ async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
+ let auth = try_outcome!(Self::parse_auth(request).await);
+ let result = match auth {
+ AuthKind::Hawk { header } => Self::verify_hawk(request, header, None).await,
+ AuthKind::Token { token } => Self::verify_bearer_token(request, token).await,
+ };
+ match result {
+ Ok((session, context)) => {
+ Outcome::Success(Authenticated { body: (), session, context })
+ },
+ Err(e) => {
+ request.local_cache(|| Some(InvalidTokenUsed));
+ Outcome::Failure((Status::Unauthorized, anyhow!(e)))
+ },
+ }
+ }
+}
+
+#[rocket::async_trait]
+impl<'r, T: Deserialize<'r>, Src: AuthSource> FromData<'r> for Authenticated<T, Src> {
+ type Error = anyhow::Error;
+
+ async fn from_data(request: &'r Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> {
+ let auth = try_outcome_data!(data, Self::parse_auth(request).await);
+ let limit =
+ request.rocket().config().limits.get("json").unwrap_or_else(|| 1u32.mebibytes());
+ let raw_json = match data.open(limit).into_string().await {
+ Ok(r) if r.is_complete() => local_cache!(request, r.into_inner()),
+ Ok(_) => {
+ return data::Outcome::Failure((
+ Status::PayloadTooLarge,
+ anyhow!("request too large"),
+ ))
+ },
+ Err(e) => return data::Outcome::Failure((Status::InternalServerError, e.into())),
+ };
+ let verify_result = match auth {
+ AuthKind::Hawk { header } => Self::verify_hawk(request, header, Some(raw_json)).await,
+ AuthKind::Token { token } => Self::verify_bearer_token(request, token).await,
+ };
+ let result = match verify_result {
+ Ok((session, context)) => {
+ serde_json::from_str(raw_json).map(|body| Authenticated { body, session, context })
+ },
+ Err(e) => {
+ request.local_cache(|| Some(InvalidTokenUsed));
+ return Outcome::Failure((Status::Unauthorized, anyhow!(e)));
+ },
+ };
+ match result {
+ Ok(r) => Outcome::Success(r),
+ Err(e) => {
+ // match Json<T> here to keep catchers generic
+ let status = match e.classify() {
+ Category::Data => Status::UnprocessableEntity,
+ _ => Status::BadRequest,
+ };
+ Outcome::Failure((status, anyhow!(e)))
+ },
+ }
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct WithBearer;
+
+#[rocket::async_trait]
+impl crate::auth::AuthSource for WithBearer {
+ type ID = UserID;
+ type Context = ScopeSet;
+ async fn hawk(_r: &Request<'_>, _id: &Self::ID) -> Result<(SecretBytes<32>, Self::Context)> {
+ bail!("hawk signatures not allowed here")
+ }
+ async fn bearer_token(
+ r: &Request<'_>,
+ token: &OauthToken,
+ ) -> Result<(Self::ID, Self::Context)> {
+ let db = Authenticated::<(), Self>::get_conn(r).await?;
+ let t = db.get_access_token(&token.hash()).await?;
+ Ok((t.user_id, t.scope))
+ }
+}
diff --git a/src/bin/minorskulk.rs b/src/bin/minorskulk.rs
new file mode 100644
index 0000000..1609edb
--- /dev/null
+++ b/src/bin/minorskulk.rs
@@ -0,0 +1,9 @@
+use minor_skulk::build;
+
+#[rocket::main]
+async fn main() -> anyhow::Result<()> {
+ dotenv::dotenv().ok();
+
+ let _ = build().await?.launch().await?;
+ Ok(())
+}
diff --git a/src/cache.rs b/src/cache.rs
new file mode 100644
index 0000000..680d9da
--- /dev/null
+++ b/src/cache.rs
@@ -0,0 +1,42 @@
+use std::borrow::Cow;
+
+use rocket::{
+ http::Header,
+ request::{self, FromRequest},
+ response::{self, Responder},
+ Request,
+};
+
+pub(crate) struct Etagged<'r, T>(pub T, pub Cow<'r, str>);
+
+impl<'r, 'o: 'r, T: Responder<'r, 'o>> Responder<'r, 'o> for Etagged<'o, T> {
+ fn respond_to(self, r: &'r Request<'_>) -> response::Result<'o> {
+ let mut resp = self.0.respond_to(r)?;
+ resp.set_header(Header::new("etag", self.1));
+ Ok(resp)
+ }
+}
+
+pub(crate) struct Immutable<T>(pub T);
+
+impl<'r, 'o: 'r, T: Responder<'r, 'o>> Responder<'r, 'o> for Immutable<T> {
+ fn respond_to(self, r: &'r Request<'_>) -> response::Result<'o> {
+ let mut resp = self.0.respond_to(r)?;
+ resp.set_header(Header::new("cache-control", "public, max-age=604800, immutable"));
+ Ok(resp)
+ }
+}
+
+pub(crate) struct IfNoneMatch<'r>(pub &'r str);
+
+#[async_trait]
+impl<'r> FromRequest<'r> for IfNoneMatch<'r> {
+ type Error = ();
+
+ async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
+ match req.headers().get_one("if-none-match") {
+ Some(h) => request::Outcome::Success(Self(h)),
+ None => request::Outcome::Forward(()),
+ }
+ }
+}
diff --git a/src/crypto.rs b/src/crypto.rs
new file mode 100644
index 0000000..cf1044e
--- /dev/null
+++ b/src/crypto.rs
@@ -0,0 +1,408 @@
+#![deny(clippy::pedantic)]
+#![deny(clippy::restriction)]
+#![allow(clippy::blanket_clippy_restriction_lints)]
+#![allow(clippy::implicit_return)]
+#![allow(clippy::missing_docs_in_private_items)]
+#![allow(clippy::shadow_reuse)]
+
+use std::fmt::Debug;
+
+use hmac::{Hmac, Mac};
+use password_hash::{Output, Salt};
+use rand::RngCore;
+use scrypt::scrypt;
+use serde::{Deserialize, Serialize};
+use sha2::Sha256;
+
+const NAMESPACE: &[u8] = b"identity.mozilla.com/picl/v1/";
+
+#[derive(Clone, PartialEq, Eq, Zeroize, Serialize, Deserialize)]
+#[serde(try_from = "String", into = "String")]
+pub struct SecretBytes<const N: usize>(pub [u8; N]);
+
+impl<const N: usize> Drop for SecretBytes<N> {
+ fn drop(&mut self) {
+ self.zeroize();
+ }
+}
+
+#[derive(Clone, PartialEq, Eq)]
+pub struct TokenID(pub [u8; 32]);
+
+impl<const N: usize> Debug for SecretBytes<N> {
+ fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
+ fmt.write_fmt(format_args!("SecretBytes {{ raw: {} }}", hex::encode(&self.0)))
+ }
+}
+
+impl<const N: usize> SecretBytes<N> {
+ fn xor(&self, other: &Self) -> Self {
+ let mut result = self.clone();
+ for (a, b) in result.0.iter_mut().zip(other.0.iter()) {
+ *a ^= b;
+ }
+ result
+ }
+}
+
+impl<const N: usize> From<SecretBytes<N>> for String {
+ fn from(sb: SecretBytes<N>) -> Self {
+ hex::encode(&sb.0)
+ }
+}
+
+impl<const N: usize> TryFrom<String> for SecretBytes<N> {
+ type Error = hex::FromHexError;
+
+ fn try_from(value: String) -> Result<Self, Self::Error> {
+ let mut result = Self([0; N]);
+ hex::decode_to_slice(value, &mut result.0)?;
+ Ok(result)
+ }
+}
+
+impl From<SecretBytes<32>> for Output {
+ fn from(s: SecretBytes<32>) -> Output {
+ #[allow(clippy::unwrap_used)]
+ Output::new(&s.0).unwrap()
+ }
+}
+
+mod from_hkdf {
+ use hkdf::Hkdf;
+ use sha2::Sha256;
+
+ // sealing lets us guarantee that SIZE is always correct,
+ // which means that from_hkdf always receives correctly sized slices
+ // and copies never fail
+ mod private {
+ pub trait Seal {}
+ impl<const N: usize> Seal for super::super::SecretBytes<N> {}
+ impl Seal for super::super::TokenID {}
+ impl<L: Seal, R: Seal> Seal for (L, R) {}
+ }
+
+ pub trait FromHkdf: private::Seal {
+ const SIZE: usize;
+ fn from_hkdf(bytes: &[u8]) -> Self;
+ }
+
+ impl<const N: usize> FromHkdf for super::SecretBytes<N> {
+ const SIZE: usize = N;
+ fn from_hkdf(bytes: &[u8]) -> Self {
+ #[allow(clippy::unwrap_used)]
+ Self(bytes.try_into().unwrap())
+ }
+ }
+
+ impl FromHkdf for super::TokenID {
+ const SIZE: usize = 32;
+ fn from_hkdf(bytes: &[u8]) -> Self {
+ #[allow(clippy::expect_used)]
+ Self(bytes.try_into().expect("hkdf failed"))
+ }
+ }
+
+ impl<L: FromHkdf, R: FromHkdf> FromHkdf for (L, R) {
+ const SIZE: usize = L::SIZE + R::SIZE;
+ #[allow(clippy::indexing_slicing)]
+ fn from_hkdf(bytes: &[u8]) -> Self {
+ (L::from_hkdf(&bytes[0..L::SIZE]), R::from_hkdf(&bytes[L::SIZE..]))
+ }
+ }
+
+ pub fn from_hkdf<O: FromHkdf>(key: &[u8], info: &[&[u8]]) -> O {
+ let hk = Hkdf::<Sha256>::new(None, key);
+ let mut buf = vec![0; O::SIZE];
+ #[allow(clippy::expect_used)]
+ // worth keeping an eye out for very large results (>255*32 bytes)
+ hk.expand_multi_info(info, buf.as_mut_slice()).expect("hkdf failed");
+ O::from_hkdf(&buf)
+ }
+}
+
+use from_hkdf::from_hkdf;
+use zeroize::Zeroize;
+
+impl<const N: usize> SecretBytes<N> {
+ pub fn generate() -> Self {
+ let mut result = Self([0; N]);
+ rand::rngs::OsRng.fill_bytes(&mut result.0);
+ result
+ }
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+#[serde(transparent)]
+pub struct AuthPW {
+ pub pw: SecretBytes<32>,
+}
+
+pub struct StretchedPW {
+ pub pw: Output,
+}
+
+impl AuthPW {
+ pub fn stretch(&self, salt: Salt) -> anyhow::Result<StretchedPW> {
+ let mut result = [0; 32];
+ let mut buf = [0; Salt::MAX_LENGTH];
+ let params = scrypt::Params::new(16, 8, 1)?;
+ let salt = salt.b64_decode(&mut buf)?;
+ scrypt(&self.pw.0, salt, &params, &mut result)?;
+ Ok(StretchedPW { pw: Output::new(&result)? })
+ }
+}
+
+impl StretchedPW {
+ pub fn verify_hash(&self) -> Output {
+ let raw: SecretBytes<32> = from_hkdf(self.pw.as_bytes(), &[NAMESPACE, b"verifyHash"]);
+ raw.into()
+ }
+
+ fn wrap_wrap_key(&self) -> SecretBytes<32> {
+ from_hkdf(self.pw.as_bytes(), &[NAMESPACE, b"wrapwrapKey"])
+ }
+
+ pub fn decrypt_wwkb(&self, wwkb: &SecretBytes<32>) -> SecretBytes<32> {
+ wwkb.xor(&self.wrap_wrap_key())
+ }
+
+ pub fn rewrap_wkb(&self, wkb: &SecretBytes<32>) -> SecretBytes<32> {
+ wkb.xor(&self.wrap_wrap_key())
+ }
+}
+
+pub struct SessionCredentials {
+ pub token_id: TokenID,
+ pub req_hmac_key: SecretBytes<32>,
+}
+
+impl SessionCredentials {
+ pub fn derive(seed: &SecretBytes<32>) -> Self {
+ let (token_id, req_hmac_key) = from_hkdf(&seed.0, &[NAMESPACE, b"sessionToken"]);
+ Self { token_id, req_hmac_key }
+ }
+}
+
+pub struct KeyFetchReq {
+ pub token_id: TokenID,
+ pub req_hmac_key: SecretBytes<32>,
+ pub key_request_key: SecretBytes<32>,
+}
+
+impl KeyFetchReq {
+ pub fn from_token(key_fetch_token: &SecretBytes<32>) -> Self {
+ let (token_id, (req_hmac_key, key_request_key)) =
+ from_hkdf(&key_fetch_token.0, &[NAMESPACE, b"keyFetchToken"]);
+ Self { token_id, req_hmac_key, key_request_key }
+ }
+
+ pub fn derive_resp(&self) -> KeyFetchResp {
+ let (resp_hmac_key, resp_xor_key) =
+ from_hkdf(&self.key_request_key.0, &[NAMESPACE, b"account/keys"]);
+ KeyFetchResp { resp_hmac_key, resp_xor_key }
+ }
+}
+
+pub struct KeyFetchResp {
+ pub resp_hmac_key: SecretBytes<32>,
+ pub resp_xor_key: SecretBytes<64>,
+}
+
+impl KeyFetchResp {
+ pub fn wrap_keys(&self, keys: &KeyBundle) -> WrappedKeyBundle {
+ let ciphertext = self.resp_xor_key.xor(&keys.to_bytes());
+ #[allow(clippy::unwrap_used)]
+ let mut hmac = Hmac::<Sha256>::new_from_slice(&self.resp_hmac_key.0).unwrap();
+ hmac.update(&ciphertext.0);
+ let hmac = *hmac.finalize().into_bytes().as_ref();
+ WrappedKeyBundle { ciphertext, hmac }
+ }
+}
+
+pub struct KeyBundle {
+ pub ka: SecretBytes<32>,
+ pub wrap_kb: SecretBytes<32>,
+}
+
+impl KeyBundle {
+ pub fn to_bytes(&self) -> SecretBytes<64> {
+ let mut result = SecretBytes([0; 64]);
+ result.0[0..32].copy_from_slice(&self.ka.0);
+ result.0[32..].copy_from_slice(&self.wrap_kb.0);
+ result
+ }
+}
+
+#[derive(Debug)]
+pub struct WrappedKeyBundle {
+ pub ciphertext: SecretBytes<64>,
+ pub hmac: [u8; 32],
+}
+
+impl WrappedKeyBundle {
+ pub fn to_bytes(&self) -> [u8; 96] {
+ let mut result = [0; 96];
+ result[0..64].copy_from_slice(&self.ciphertext.0);
+ result[64..].copy_from_slice(&self.hmac);
+ result
+ }
+}
+
+pub struct PasswordChangeReq {
+ pub token_id: TokenID,
+ pub req_hmac_key: SecretBytes<32>,
+}
+
+impl PasswordChangeReq {
+ pub fn from_change_token(token: &SecretBytes<32>) -> Self {
+ let (token_id, req_hmac_key) = from_hkdf(&token.0, &[NAMESPACE, b"passwordChangeToken"]);
+ Self { token_id, req_hmac_key }
+ }
+
+ pub fn from_forgot_token(token: &SecretBytes<32>) -> Self {
+ let (token_id, req_hmac_key) = from_hkdf(&token.0, &[NAMESPACE, b"passwordForgotToken"]);
+ Self { token_id, req_hmac_key }
+ }
+}
+
+pub struct AccountResetReq {
+ pub token_id: TokenID,
+ pub req_hmac_key: SecretBytes<32>,
+}
+
+impl AccountResetReq {
+ pub fn from_token(token: &SecretBytes<32>) -> Self {
+ let (token_id, req_hmac_key) = from_hkdf(&token.0, &[NAMESPACE, b"accountResetToken"]);
+ Self { token_id, req_hmac_key }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use hex_literal::hex;
+ use password_hash::{Output, SaltString};
+
+ use crate::crypto::{KeyBundle, KeyFetchReq, SessionCredentials};
+
+ use super::{AuthPW, SecretBytes};
+
+ macro_rules! shex {
+ ( $s: literal ) => {
+ SecretBytes(hex!($s))
+ };
+ }
+
+ #[test]
+ fn test_derive_session() {
+ let creds = SessionCredentials::derive(&SecretBytes(hex!(
+ "a0a1a2a3a4a5a6a7 a8a9aaabacadaeaf b0b1b2b3b4b5b6b7 b8b9babbbcbdbebf"
+ )));
+ assert_eq!(
+ creds.token_id.0,
+ hex!("c0a29dcf46174973da1378696e4c82ae10f723cf4f4d9f75e39f4ae3851595ab")
+ );
+ assert_eq!(
+ creds.req_hmac_key.0,
+ hex!("9d8f22998ee7f579 8b887042466b72d5 3e56ab0c094388bf 65831f702d2febc0")
+ );
+ }
+
+ #[test]
+ fn test_key_fetch() {
+ let key_fetch = KeyFetchReq::from_token(&shex!(
+ "8081828384858687 88898a8b8c8d8e8f 9091929394959697 98999a9b9c9d9e9f"
+ ));
+ assert_eq!(
+ key_fetch.token_id.0,
+ hex!("3d0a7c02a15a62a2882f76e39b6494b500c022a8816e048625a495718998ba60")
+ );
+ assert_eq!(
+ key_fetch.req_hmac_key.0,
+ hex!("87b8937f61d38d0e 29cd2d5600b3f4da 0aa48ac41de36a0e fe84bb4a9872ceb7")
+ );
+ assert_eq!(
+ key_fetch.key_request_key.0,
+ hex!("14f338a9e8c6324d 9e102d4e6ee83b20 9796d5c74bb734a4 10e729e014a4a546")
+ );
+
+ let resp = key_fetch.derive_resp();
+ assert_eq!(
+ resp.resp_hmac_key.0,
+ hex!("f824d2953aab9faf 51a1cb65ba9e7f9e 5bf91c8d8fd1ac1c 8c2d31853a8a1210")
+ );
+ assert_eq!(
+ resp.resp_xor_key.0,
+ hex!(
+ "ce7d7aa77859b235 9932970bbe2101f2 e80d01faf9191bd5 ee52181d2f0b7809
+ 8281ba8cff392543 3a89f7c3095e0c89 900a469d60790c83 3281c4df1a11c763"
+ )
+ );
+
+ let bundle = KeyBundle {
+ ka: shex!("2021222324252627 28292a2b2c2d2e2f 3031323334353637 38393a3b3c3d3e3f"),
+ wrap_kb: shex!("7effe354abecbcb2 34a8dfc2d7644b4a d339b525589738f2 d27341bb8622ecd8"),
+ };
+ assert_eq!(
+ bundle.to_bytes().0,
+ hex!(
+ "2021222324252627 28292a2b2c2d2e2f 3031323334353637 38393a3b3c3d3e3f
+ 7effe354abecbcb2 34a8dfc2d7644b4a d339b525589738f2 d27341bb8622ecd8"
+ )
+ );
+
+ let wrapped = resp.wrap_keys(&bundle);
+ assert_eq!(
+ wrapped.ciphertext.0,
+ hex!(
+ "ee5c58845c7c9412 b11bbd20920c2fdd d83c33c9cd2c2de2 d66b222613364636
+ fc7e59d854d599f1 0e212801de3a47c3 4333f3b838ee3471 e0f285649c332bbb"
+ )
+ );
+ assert_eq!(
+ wrapped.hmac,
+ hex!("4c17f42a0b319bbb a327d2b326ad23e9 37219b4de32e3ec7 b3e3f740522ad6ef")
+ );
+ assert_eq!(
+ wrapped.to_bytes(),
+ hex!(
+ "ee5c58845c7c9412 b11bbd20920c2fdd d83c33c9cd2c2de2 d66b222613364636
+ fc7e59d854d599f1 0e212801de3a47c3 4333f3b838ee3471 e0f285649c332bbb
+ 4c17f42a0b319bbb a327d2b326ad23e9 37219b4de32e3ec7 b3e3f740522ad6ef"
+ )
+ );
+ }
+
+ #[test]
+ fn test_stretch() -> anyhow::Result<()> {
+ let auth_pw = AuthPW {
+ pw: SecretBytes(hex!(
+ "247b675ffb4c4631 0bc87e26d712153a be5e1c90ef00a478 4594f97ef54f2375"
+ )),
+ };
+
+ let stretched = auth_pw.stretch(
+ SaltString::b64_encode(&hex!(
+ "00f0000000000000 0000000000000000 0000000000000000 0000000000000000"
+ ))?
+ .as_salt(),
+ )?;
+ assert_eq!(
+ stretched.pw,
+ Output::new(&hex!(
+ "441509e25c92ee10 3d5a1a874e6f155d f25a44d06e61c894 616c9e85181dba97"
+ ))?
+ );
+
+ assert_eq!(
+ stretched.verify_hash().as_bytes(),
+ hex!("a4765bf103dc057f 4cf4bc2c131ddb67 16e8a4333cc55e1d 3c449f31f0eec4f1")
+ );
+
+ assert_eq!(
+ stretched.wrap_wrap_key().0,
+ hex!("3ebea117efa9faf5 7ce195899b290505 8368e7760cc26ea5 8a2a1be0da7fb287")
+ );
+ Ok(())
+ }
+}
diff --git a/src/db/mod.rs b/src/db/mod.rs
new file mode 100644
index 0000000..040507f
--- /dev/null
+++ b/src/db/mod.rs
@@ -0,0 +1,1026 @@
+use std::{error::Error, mem::replace, sync::Arc};
+
+use anyhow::Result;
+use chrono::{DateTime, Duration, Utc};
+use password_hash::SaltString;
+use rocket::{
+ fairing::{self, Fairing},
+ futures::lock::{MappedMutexGuard, Mutex, MutexGuard},
+ http::Status,
+ request::{FromRequest, Outcome},
+ Request, Response, Sentinel,
+};
+use serde_json::Value;
+use sqlx::{query, query_as, query_scalar, PgPool, Postgres, Transaction};
+
+use crate::{
+ crypto::WrappedKeyBundle,
+ types::{
+ oauth::ScopeSet, AccountResetID, AttachedClient, Avatar, AvatarID, Device, DeviceCommand,
+ DeviceCommands, DeviceID, DevicePush, DeviceUpdate, HawkKey, KeyFetchID, OauthAccessToken,
+ OauthAccessType, OauthAuthorization, OauthAuthorizationID, OauthRefreshToken, OauthTokenID,
+ PasswordChangeID, SecretKey, SessionID, User, UserID, UserSession, VerifyCode, VerifyHash,
+ },
+};
+
+// we implement a completely custom db type set instead of using rocket_db_pools for two reasons:
+// 1. rocket_db_pools doesn't support sqlx 0.6
+// 2. we want one transaction per *request*, including all guards
+
+#[derive(Clone)]
+pub struct Db {
+ db: Arc<PgPool>,
+}
+
+impl Db {
+ pub async fn connect(db: &str) -> Result<Self> {
+ Ok(Self { db: Arc::new(PgPool::connect(db).await?) })
+ }
+
+ pub async fn begin(&self) -> Result<DbConn> {
+ Ok(DbConn(Mutex::new(ConnState::Capable(self.clone()))))
+ }
+
+ pub async fn migrate(&self) -> Result<()> {
+ sqlx::migrate!().run(&*self.db).await?;
+ Ok(())
+ }
+}
+
+struct ActiveConn {
+ tx: Transaction<'static, Postgres>,
+ always_commit: bool,
+}
+
+#[allow(clippy::large_enum_variant)]
+enum ConnState {
+ None,
+ Capable(Db),
+ Active(ActiveConn),
+ Done,
+}
+
+pub struct DbConn(Mutex<ConnState>);
+
+struct DbWrap(Db);
+struct DbConnWrap(DbConn);
+type DbLock<'c> = MappedMutexGuard<'c, ConnState, ActiveConn>;
+
+#[rocket::async_trait]
+impl Fairing for Db {
+ fn info(&self) -> fairing::Info {
+ fairing::Info {
+ name: "db access",
+ kind: fairing::Kind::Ignite | fairing::Kind::Response | fairing::Kind::Singleton,
+ }
+ }
+
+ async fn on_ignite(&self, rocket: rocket::Rocket<rocket::Build>) -> fairing::Result {
+ Ok(rocket.manage(DbWrap(self.clone())).manage(self.clone()))
+ }
+
+ async fn on_response<'r>(&self, req: &'r Request<'_>, resp: &mut Response<'r>) {
+ let conn = req.local_cache(|| DbConnWrap(DbConn(Mutex::new(ConnState::None))));
+ let s = replace(&mut *conn.0 .0.lock().await, ConnState::Done);
+ if let ConnState::Active(ActiveConn { tx, always_commit }) = s {
+ // don't commit if the request failed, unless explicitly asked to.
+ // this is used by key fetch to invalidate tokens even if the header
+ // signature of the request is incorrect.
+ if !always_commit && resp.status() != Status::Ok {
+ return;
+ }
+ if let Err(e) = tx.commit().await {
+ resp.set_status(Status::InternalServerError);
+ error!("commit failed: {}", e);
+ }
+ }
+ }
+}
+
+impl DbConn {
+ pub async fn commit(self) -> Result<(), sqlx::Error> {
+ match self.0.into_inner() {
+ ConnState::None => Ok(()),
+ ConnState::Capable(_) => Ok(()),
+ ConnState::Active(ac) => ac.tx.commit().await,
+ ConnState::Done => Ok(()),
+ }
+ }
+
+ fn register<'r>(req: &'r Request<'_>) -> &'r DbConn {
+ &req.local_cache(|| {
+ let db = req.rocket().state::<DbWrap>().unwrap().0.clone();
+ DbConnWrap(DbConn(Mutex::new(ConnState::Capable(db))))
+ })
+ .0
+ }
+
+ // defer opening a transaction and thus locking the mutex holding it.
+ // without deferral the order of arguments in route signatures is important,
+ // which may be surprising: placing a DbConn before a guard that also uses
+ // the database causes a concurrent transaction error.
+ // HACK maybe we should return errors instead of panicking, but there's no
+ // way an error here is not a severe bug
+ async fn get(&self) -> sqlx::Result<DbLock<'_>> {
+ let mut m = match self.0.try_lock() {
+ Some(m) => m,
+ None => panic!("attempted to open concurrent transactions"),
+ };
+ match &*m {
+ ConnState::Capable(db) => {
+ *m = ConnState::Active(ActiveConn {
+ tx: db.db.begin().await?,
+ always_commit: false,
+ });
+ },
+ ConnState::None | ConnState::Done => panic!("db connection requested after teardown"),
+ _ => (),
+ }
+ Ok(MutexGuard::map(m, |g| match g {
+ ConnState::Active(ref mut tx) => tx,
+ _ => unreachable!(),
+ }))
+ }
+}
+
+impl<'r> Sentinel for &'r Db {
+ fn abort(rocket: &rocket::Rocket<rocket::Ignite>) -> bool {
+ rocket.state::<DbWrap>().is_none()
+ }
+}
+
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for &'r Db {
+ type Error = anyhow::Error;
+
+ async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+ Outcome::Success(&req.rocket().state::<DbWrap>().unwrap().0)
+ }
+}
+
+impl<'r> Sentinel for &'r DbConn {
+ fn abort(rocket: &rocket::Rocket<rocket::Ignite>) -> bool {
+ rocket.state::<DbWrap>().is_none()
+ }
+}
+
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for &'r DbConn {
+ type Error = anyhow::Error;
+
+ async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+ Outcome::Success(DbConn::register(req))
+ }
+}
+
+//
+//
+//
+
+//
+//
+//
+
+impl DbConn {
+ pub(crate) async fn always_commit(&self) -> Result<()> {
+ self.get().await?.always_commit = true;
+ Ok(())
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn add_session(
+ &self,
+ id: SessionID,
+ user_id: &UserID,
+ key: HawkKey,
+ verified: bool,
+ verify_code: Option<&str>,
+ ) -> sqlx::Result<DateTime<Utc>> {
+ query_scalar!(
+ r#"insert into user_session (session_id, user_id, req_hmac_key, device_id, verified,
+ verify_code)
+ values ($1, $2, $3, null, $4, $5)
+ returning created_at"#,
+ id as _,
+ user_id as _,
+ key as _,
+ verified,
+ verify_code,
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn use_session(&self, id: &SessionID) -> sqlx::Result<UserSession> {
+ query_as!(
+ UserSession,
+ r#"update user_session
+ set last_active = now()
+ where session_id = $1
+ returning user_id as "uid: UserID", req_hmac_key as "req_hmac_key: HawkKey",
+ device_id as "device_id: DeviceID", created_at, verified, verify_code"#,
+ id as _
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn use_session_from_refresh(
+ &self,
+ id: &OauthTokenID,
+ ) -> sqlx::Result<(SessionID, UserSession)> {
+ query!(
+ r#"update user_session
+ set last_active = now()
+ where session_id = (
+ select session_id from oauth_token where kind = 'refresh' and id = $1
+ )
+ returning user_id as "uid: UserID", req_hmac_key as "req_hmac_key: HawkKey",
+ device_id as "device_id: DeviceID", session_id as "session_id: SessionID",
+ created_at, verified, verify_code"#,
+ id as _
+ )
+ .map(|r| {
+ (
+ r.session_id.clone(),
+ UserSession {
+ uid: r.uid,
+ req_hmac_key: r.req_hmac_key,
+ device_id: r.device_id,
+ created_at: r.created_at,
+ verified: r.verified,
+ verify_code: r.verify_code,
+ },
+ )
+ })
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn delete_session(&self, user: &UserID, id: &SessionID) -> sqlx::Result<()> {
+ query_scalar!(
+ r#"delete from user_session
+ where user_id = $1 and session_id = $2
+ returning 1"#,
+ user as _,
+ id as _
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn set_session_device<'d>(
+ &self,
+ id: &SessionID,
+ dev: Option<&'d DeviceID>,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"update user_session set device_id = $1 where session_id = $2"#,
+ dev as _,
+ id as _
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn set_session_verified(&self, id: &SessionID) -> sqlx::Result<()> {
+ query_scalar!(
+ r#"update user_session
+ set verified = true, verify_code = null
+ where session_id = $1
+ returning 1"#,
+ id as _
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn enqueue_command(
+ &self,
+ target: &DeviceID,
+ sender: &Option<DeviceID>,
+ command: &str,
+ payload: &Value,
+ ttl: u32,
+ ) -> sqlx::Result<i64> {
+ let expires = Utc::now() + Duration::seconds(ttl as i64);
+ query!(
+ r#"insert into device_commands (device_id, command, payload, expires, sender)
+ values ($1, $2, $3, $4, $5)
+ returning index"#,
+ target as _,
+ command,
+ payload,
+ expires,
+ sender.as_ref().map(ToString::to_string)
+ )
+ .map(|x| x.index)
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn get_commands(
+ &self,
+ user: &UserID,
+ dev: &DeviceID,
+ min_index: i64,
+ limit: i64,
+ ) -> sqlx::Result<(bool, Vec<DeviceCommand>)> {
+ // NOTE while fxa api docs state that command queries return only commands enqueued
+ // *after* index, what pushbox actually does is return commands *starting at* index!
+ let mut results = query_as!(
+ DeviceCommand,
+ r#"select index, command, payload, expires, sender
+ from device_commands join device using (device_id)
+ where index >= $1 and device_id = $2 and user_id = $3
+ order by index
+ limit $4"#,
+ min_index,
+ dev as _,
+ user as _,
+ limit + 1
+ )
+ .fetch_all(&mut self.get().await?.tx)
+ .await?;
+ let more = results.len() > limit as usize;
+ results.truncate(limit as usize);
+ Ok((more, results))
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn get_devices(&self, user: &UserID) -> sqlx::Result<Vec<Device>> {
+ query_as!(
+ Device,
+ r#"select d.device_id as "device_id: DeviceID", d.name, d.type as type_,
+ d.push as "push: DevicePush",
+ d.available_commands as "available_commands: DeviceCommands",
+ d.push_expired, d.location, coalesce(us.last_active, to_timestamp(0)) as "last_active!"
+ from device d left join user_session us using (device_id)
+ where d.user_id = $1"#,
+ user as _
+ )
+ .fetch_all(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn get_device(&self, user: &UserID, dev: &DeviceID) -> sqlx::Result<Device> {
+ query_as!(
+ Device,
+ r#"select d.device_id as "device_id: DeviceID", d.name, d.type as type_,
+ d.push as "push: DevicePush",
+ d.available_commands as "available_commands: DeviceCommands",
+ d.push_expired, d.location, coalesce(us.last_active, to_timestamp(0)) as "last_active!"
+ from device d left join user_session us using (device_id)
+ where d.user_id = $1 and d.device_id = $2"#,
+ user as _,
+ dev as _
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn change_device<'d>(
+ &self,
+ user: &UserID,
+ id: &DeviceID,
+ dev: DeviceUpdate<'d>,
+ ) -> sqlx::Result<Device> {
+ query_as!(
+ Device,
+ r#"select device_id as "device_id!: DeviceID", name as "name!", type as "type_!",
+ push as "push: DevicePush",
+ available_commands as "available_commands!: DeviceCommands",
+ push_expired as "push_expired!", location as "location!",
+ coalesce(last_active, to_timestamp(0)) as "last_active!"
+ from insert_or_update_device($1, $2, $3, $4, $5, $6, $7) as iud
+ left join user_session using (device_id)"#,
+ id as _,
+ user as _,
+ // these two are not optional but are Option anyway. the db will
+ // refuse insertions that don't have them set.
+ dev.name,
+ dev.type_,
+ dev.push as _,
+ dev.available_commands as _,
+ dev.location
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn set_push_expired(&self, dev: &DeviceID) -> sqlx::Result<()> {
+ query!(
+ r#"update device
+ set push_expired = true
+ where device_id = $1"#,
+ dev as _
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn delete_device(&self, user: &UserID, dev: &DeviceID) -> sqlx::Result<()> {
+ query_scalar!(
+ r#"delete from device where user_id = $1 and device_id = $2 returning 1"#,
+ user as _,
+ dev as _
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn add_key_fetch(
+ &self,
+ id: KeyFetchID,
+ hmac_key: &HawkKey,
+ keys: &WrappedKeyBundle,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"insert into key_fetch (id, hmac_key, keys) values ($1, $2, $3)"#,
+ id as _,
+ hmac_key as _,
+ &keys.to_bytes()[..]
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn finish_key_fetch(
+ &self,
+ id: &KeyFetchID,
+ ) -> sqlx::Result<(HawkKey, Vec<u8>)> {
+ query!(
+ r#"delete from key_fetch
+ where id = $1 and expires_at > now()
+ returning hmac_key as "hmac_key: HawkKey", keys"#,
+ id as _
+ )
+ .map(|r| (r.hmac_key, r.keys))
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn get_refresh_token(
+ &self,
+ id: &OauthTokenID,
+ ) -> sqlx::Result<OauthRefreshToken> {
+ query_as!(
+ OauthRefreshToken,
+ r#"select user_id as "user_id: UserID", client_id, scope as "scope: ScopeSet",
+ session_id as "session_id: SessionID"
+ from oauth_token
+ where id = $1 and kind = 'refresh'"#,
+ id as _
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn get_access_token(
+ &self,
+ id: &OauthTokenID,
+ ) -> sqlx::Result<OauthAccessToken> {
+ query_as!(
+ OauthAccessToken,
+ r#"select user_id as "user_id: UserID", client_id, scope as "scope: ScopeSet",
+ parent_refresh as "parent_refresh: OauthTokenID",
+ parent_session as "parent_session: SessionID",
+ expires_at as "expires_at!"
+ from oauth_token
+ where id = $1 and kind = 'access' and expires_at > now()"#,
+ id as _
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn add_refresh_token(
+ &self,
+ id: &OauthTokenID,
+ token: OauthRefreshToken,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"insert into oauth_token (id, kind, user_id, client_id, scope, session_id)
+ values ($1, 'refresh', $2, $3, $4, $5)"#,
+ id as _,
+ token.user_id as _,
+ token.client_id,
+ token.scope as _,
+ token.session_id as _
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn add_access_token(
+ &self,
+ id: &OauthTokenID,
+ token: OauthAccessToken,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"insert into oauth_token (id, kind, user_id, client_id, scope, session_id,
+ parent_refresh, parent_session, expires_at)
+ values ($1, 'access', $2, $3, $4, null, $5, $6, $7)"#,
+ id as _,
+ token.user_id as _,
+ token.client_id,
+ token.scope as _,
+ token.parent_refresh as _,
+ token.parent_session as _,
+ token.expires_at,
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn delete_oauth_token(&self, id: &OauthTokenID) -> sqlx::Result<()> {
+ query!(r#"delete from oauth_token where id = $1"#, id as _)
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn delete_refresh_token(&self, id: &OauthTokenID) -> sqlx::Result<()> {
+ query!(r#"delete from oauth_token where id = $1 and kind = 'refresh'"#, id as _)
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn add_oauth_authorization(
+ &self,
+ id: &OauthAuthorizationID,
+ auth: OauthAuthorization,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"insert into oauth_authorization (id, user_id, client_id, scope, access_type,
+ code_challenge, keys_jwe, auth_at)
+ values ($1, $2, $3, $4, $5, $6, $7, $8)"#,
+ id as _,
+ auth.user_id as _,
+ auth.client_id,
+ auth.scope as _,
+ auth.access_type as _,
+ auth.code_challenge,
+ auth.keys_jwe,
+ auth.auth_at,
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn take_oauth_authorization(
+ &self,
+ id: &OauthAuthorizationID,
+ ) -> sqlx::Result<OauthAuthorization> {
+ query_as!(
+ OauthAuthorization,
+ r#"delete from oauth_authorization
+ where id = $1 and expires_at > now()
+ returning user_id as "user_id: UserID", client_id, scope as "scope: ScopeSet",
+ access_type as "access_type: OauthAccessType",
+ code_challenge, keys_jwe, auth_at"#,
+ id as _
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn user_email_exists(&self, email: &str) -> sqlx::Result<bool> {
+ Ok(query_scalar!(r#"select 1 from users where email = lower($1)"#, email)
+ .fetch_optional(&mut self.get().await?.tx)
+ .await?
+ .is_some())
+ }
+
+ pub(crate) async fn add_user(&self, user: User) -> sqlx::Result<UserID> {
+ let id = UserID::random();
+ query_scalar!(
+ r#"insert into users (user_id, auth_salt, email, ka, wrapwrap_kb, verify_hash,
+ display_name)
+ values ($1, $2, $3, $4, $5, $6, $7)"#,
+ id as _,
+ user.auth_salt.as_str(),
+ user.email,
+ user.ka as _,
+ user.wrapwrap_kb as _,
+ user.verify_hash as _,
+ user.display_name,
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(id)
+ }
+
+ pub(crate) async fn get_user(&self, email: &str) -> sqlx::Result<(UserID, User)> {
+ query!(
+ r#"select user_id as "id: UserID", auth_salt as "auth_salt: String", email,
+ ka as "ka: SecretKey", wrapwrap_kb as "wrapwrap_kb: SecretKey",
+ verify_hash as "verify_hash: VerifyHash", display_name, verified
+ from users
+ where email = lower($1)"#,
+ email
+ )
+ .try_map(|r| {
+ Ok((
+ r.id,
+ User {
+ auth_salt: SaltString::new(&r.auth_salt).map_err(decode_err("auth_salt"))?,
+ email: r.email,
+ ka: r.ka,
+ wrapwrap_kb: r.wrapwrap_kb,
+ verify_hash: r.verify_hash,
+ display_name: r.display_name,
+ verified: r.verified,
+ },
+ ))
+ })
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn get_user_by_id(&self, id: &UserID) -> sqlx::Result<User> {
+ query!(
+ r#"select auth_salt as "auth_salt: String", email,
+ ka as "ka: SecretKey", wrapwrap_kb as "wrapwrap_kb: SecretKey",
+ verify_hash as "verify_hash: VerifyHash", display_name, verified
+ from users
+ where user_id = $1"#,
+ id as _
+ )
+ .try_map(|r| {
+ Ok(User {
+ auth_salt: SaltString::new(&r.auth_salt).map_err(decode_err("auth_salt"))?,
+ email: r.email,
+ ka: r.ka,
+ wrapwrap_kb: r.wrapwrap_kb,
+ verify_hash: r.verify_hash,
+ display_name: r.display_name,
+ verified: r.verified,
+ })
+ })
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn set_user_name(&self, id: &UserID, name: &str) -> sqlx::Result<()> {
+ query!(
+ "update users
+ set display_name = $2
+ where user_id = $1",
+ id as _,
+ name,
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn set_user_verified(&self, id: &UserID) -> sqlx::Result<()> {
+ query_scalar!("update users set verified = true where user_id = $1 returning 1", id as _)
+ .fetch_one(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn delete_user(&self, email: &str) -> sqlx::Result<()> {
+ query_scalar!(r#"delete from users where email = lower($1)"#, email)
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn change_user_auth(
+ &self,
+ uid: &UserID,
+ salt: SaltString,
+ wwkb: SecretKey,
+ verify_hash: VerifyHash,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"update users
+ set auth_salt = $2, wrapwrap_kb = $3, verify_hash = $4
+ where user_id = $1"#,
+ uid as _,
+ salt.to_string(),
+ wwkb as _,
+ verify_hash as _,
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn reset_user_auth(
+ &self,
+ uid: &UserID,
+ salt: SaltString,
+ wwkb: SecretKey,
+ verify_hash: VerifyHash,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"call reset_user_auth($1, $2, $3, $4)"#,
+ uid as _,
+ salt.to_string(),
+ wwkb as _,
+ verify_hash as _,
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn get_attached_clients(
+ &self,
+ id: &UserID,
+ ) -> sqlx::Result<Vec<AttachedClient>> {
+ query_as!(
+ AttachedClient,
+ r#"select
+ ot.client_id as "client_id?",
+ d.device_id as "device_id?: DeviceID",
+ us.session_id as "session_token_id?: SessionID",
+ ot.id as "refresh_token_id?: OauthTokenID",
+ d.type as "device_type?",
+ d.name as "name?",
+ coalesce(d.created_at, us.created_at, ot.created_at) as "created_time?",
+ us.last_active as "last_access_time?",
+ ot.scope as "scope?"
+ from device d
+ full outer join user_session us on (d.device_id = us.device_id)
+ full outer join oauth_token ot on (us.session_id = ot.session_id)
+ where
+ (ot.kind is null or ot.kind = 'refresh')
+ and $1 in (d.user_id, us.user_id, ot.user_id)
+ order by d.device_id"#,
+ id as _,
+ )
+ .fetch_all(&mut self.get().await?.tx)
+ .await
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn get_user_avatar_id(&self, id: &UserID) -> sqlx::Result<Option<AvatarID>> {
+ query!(r#"select id as "id: AvatarID" from user_avatars where user_id = $1"#, id as _,)
+ .map(|r| r.id)
+ .fetch_optional(&mut *self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn get_user_avatar(&self, id: &AvatarID) -> sqlx::Result<Option<Avatar>> {
+ query_as!(
+ Avatar,
+ r#"select id as "id: AvatarID", data, content_type
+ from user_avatars
+ where id = $1"#,
+ id as _,
+ )
+ .fetch_optional(&mut *self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn set_user_avatar(&self, id: &UserID, avatar: Avatar) -> sqlx::Result<()> {
+ query!(
+ r#"insert into user_avatars (user_id, id, data, content_type)
+ values ($1, $2, $3, $4)
+ on conflict (user_id) do update set
+ id = $2, data = $3, content_type = $4"#,
+ id as _,
+ avatar.id as _,
+ avatar.data,
+ avatar.content_type,
+ )
+ .execute(&mut *self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn delete_user_avatar(
+ &self,
+ user: &UserID,
+ id: &AvatarID,
+ ) -> sqlx::Result<()> {
+ query!(r#"delete from user_avatars where user_id = $1 and id = $2"#, user as _, id as _,)
+ .execute(&mut *self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn add_verify_code(
+ &self,
+ user: &UserID,
+ session: &SessionID,
+ code: &str,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"insert into verify_codes (user_id, session_id, code)
+ values ($1, $2, $3)"#,
+ user as _,
+ session as _,
+ code,
+ )
+ .execute(&mut *self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn get_verify_code(
+ &self,
+ user: &UserID,
+ ) -> sqlx::Result<(String, VerifyCode)> {
+ query!(
+ r#"select user_id as "user_id: UserID", session_id as "session_id: SessionID", code,
+ email
+ from verify_codes join users using (user_id)
+ where user_id = $1"#,
+ user as _,
+ )
+ .map(|r| {
+ (r.email, VerifyCode { user_id: r.user_id, session_id: r.session_id, code: r.code })
+ })
+ .fetch_one(&mut *self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn try_use_verify_code(
+ &self,
+ user: &UserID,
+ code: &str,
+ ) -> sqlx::Result<Option<VerifyCode>> {
+ query_as!(
+ VerifyCode,
+ r#"delete from verify_codes
+ where user_id = $1 and code = $2
+ returning user_id as "user_id: UserID", session_id as "session_id: SessionID",
+ code"#,
+ user as _,
+ code,
+ )
+ .fetch_optional(&mut *self.get().await?.tx)
+ .await
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn add_password_change(
+ &self,
+ user: &UserID,
+ id: &PasswordChangeID,
+ key: &HawkKey,
+ forgot_code: Option<&str>,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"insert into password_change_tokens (id, user_id, hmac_key, forgot_code)
+ values ($1, $2, $3, $4)
+ on conflict (user_id) do update set id = $1, hmac_key = $3, forgot_code = $4,
+ expires_at = default"#,
+ id as _,
+ user as _,
+ key as _,
+ forgot_code,
+ )
+ .execute(&mut *self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn finish_password_change(
+ &self,
+ id: &PasswordChangeID,
+ is_forgot: bool,
+ ) -> sqlx::Result<(HawkKey, (UserID, Option<String>))> {
+ query!(
+ r#"delete from password_change_tokens
+ where id = $1 and expires_at > now() and (forgot_code is not null) = $2
+ returning hmac_key as "hmac_key: HawkKey", user_id as "user_id: UserID",
+ forgot_code"#,
+ id as _,
+ is_forgot,
+ )
+ .map(|r| (r.hmac_key, (r.user_id, r.forgot_code)))
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn add_account_reset(
+ &self,
+ user: &UserID,
+ id: &AccountResetID,
+ key: &HawkKey,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"insert into account_reset_tokens (id, user_id, hmac_key)
+ values ($1, $2, $3)
+ on conflict (user_id) do update set id = $1, hmac_key = $3, expires_at = default"#,
+ id as _,
+ user as _,
+ key as _,
+ )
+ .execute(&mut *self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn finish_account_reset(
+ &self,
+ id: &AccountResetID,
+ ) -> sqlx::Result<(HawkKey, UserID)> {
+ query!(
+ r#"delete from account_reset_tokens
+ where id = $1 and expires_at > now()
+ returning hmac_key as "hmac_key: HawkKey", user_id as "user_id: UserID""#,
+ id as _,
+ )
+ .map(|r| (r.hmac_key, r.user_id))
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ //
+ //
+ //
+
+ pub async fn add_invite_code(&self, code: &str, expires: DateTime<Utc>) -> sqlx::Result<()> {
+ query!(r#"insert into invite_codes (code, expires_at) values ($1, $2)"#, code, expires,)
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn use_invite_code(&self, code: &str) -> sqlx::Result<()> {
+ query_scalar!(
+ r#"delete from invite_codes where code = $1 and expires_at > now() returning 1"#,
+ code,
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn prune_expired_tokens(&self) -> sqlx::Result<()> {
+ query!("call prune_expired_tokens()").execute(&mut self.get().await?.tx).await?;
+ Ok(())
+ }
+
+ pub(crate) async fn prune_expired_verify_codes(&self) -> sqlx::Result<()> {
+ query!("call prune_expired_verify_codes()").execute(&mut self.get().await?.tx).await?;
+ Ok(())
+ }
+}
+
+fn decode_err<E: Error + Send + Sync + 'static>(c: &str) -> impl FnOnce(E) -> sqlx::Error {
+ let index = c.to_string();
+ move |e| sqlx::Error::ColumnDecode { index, source: Box::new(e) }
+}
diff --git a/src/js.rs b/src/js.rs
new file mode 100644
index 0000000..f3d662c
--- /dev/null
+++ b/src/js.rs
@@ -0,0 +1,53 @@
+use std::{collections::HashMap, path::PathBuf};
+
+use rocket::http::{ContentType, Status};
+use sha2::{Digest, Sha256};
+
+use crate::cache::{Etagged, IfNoneMatch};
+
+struct Entry {
+ data: &'static str,
+ hash: String,
+}
+
+fn enter(data: &'static str) -> Entry {
+ let mut sha = Sha256::new();
+ sha.update(data.as_bytes());
+ let hash = base64::encode(sha.finalize());
+ Entry { data, hash }
+}
+
+lazy_static! {
+ static ref JS: HashMap<&'static str, Entry> = {
+ let mut m = HashMap::new();
+ m.insert("main", enter(include_str!("../web/js//main.js")));
+ m.insert("crypto", enter(include_str!("../web/js//crypto.js")));
+ m.insert("auth-client/browser", enter(include_str!("../web/js//browser/browser.js")));
+ m.insert("auth-client/lib/client", enter(include_str!("../web/js//browser/lib/client.js")));
+ m.insert("auth-client/lib/crypto", enter(include_str!("../web/js//browser/lib/crypto.js")));
+ m.insert("auth-client/lib/hawk", enter(include_str!("../web/js//browser/lib/hawk.js")));
+ m.insert(
+ "auth-client/lib/recoveryKey",
+ enter(include_str!("../web/js//browser/lib/recoveryKey.js")),
+ );
+ m.insert("auth-client/lib/utils", enter(include_str!("../web/js//browser/lib/utils.js")));
+ m
+ };
+}
+
+#[get("/<name..>")]
+pub(crate) async fn static_js(
+ name: PathBuf,
+ inm: Option<IfNoneMatch<'_>>,
+) -> (Status, Result<(ContentType, Etagged<'_, &'static str>), ()>) {
+ let entry = JS.get(name.to_string_lossy().as_ref());
+ match entry {
+ Some(e) => match inm {
+ Some(h) if h.0 == e.hash => (Status::NotModified, Err(())),
+ _ => {
+ (Status::Ok, Ok((ContentType::JavaScript, Etagged(e.data, e.hash.as_str().into()))))
+ },
+ },
+ _ => (Status::NotFound, Err(())),
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..1e6fa31
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,319 @@
+use std::{
+ path::PathBuf,
+ sync::Arc,
+ time::{Duration as StdDuration, SystemTime, UNIX_EPOCH},
+};
+
+use anyhow::Context;
+use chrono::Duration;
+use db::Db;
+use futures::Future;
+use lettre::message::Mailbox;
+use mailer::Mailer;
+use push::PushClient;
+use rocket::{
+ fairing::AdHoc,
+ http::{uri::Absolute, ContentType, Header},
+ request::{self, FromRequest},
+ response::Redirect,
+ tokio::{
+ spawn,
+ time::{interval_at, Instant, MissedTickBehavior},
+ },
+ Request, State,
+};
+use serde_json::{json, Value};
+use utils::DeferredActions;
+
+use crate::api::auth::invite::generate_invite_link;
+
+#[macro_use]
+extern crate rocket;
+#[macro_use]
+extern crate anyhow;
+#[macro_use]
+extern crate lazy_static;
+
+#[macro_use]
+pub(crate) mod utils;
+pub(crate) mod api;
+mod auth;
+mod cache;
+mod crypto;
+pub mod db;
+mod js;
+mod mailer;
+mod push;
+mod types;
+
+fn default_push_ttl() -> std::time::Duration {
+ std::time::Duration::from_secs(2 * 86400)
+}
+
+fn default_task_interval() -> std::time::Duration {
+ std::time::Duration::from_secs(5 * 60)
+}
+
+#[derive(serde::Deserialize)]
+struct Config {
+ database_url: String,
+ location: Absolute<'static>,
+ token_server_location: Absolute<'static>,
+ vapid_key: PathBuf,
+ vapid_subject: String,
+ #[serde(default = "default_push_ttl", with = "humantime_serde")]
+ default_push_ttl: std::time::Duration,
+ #[serde(default = "default_task_interval", with = "humantime_serde")]
+ prune_expired_interval: std::time::Duration,
+
+ mail_from: Mailbox,
+ mail_host: Option<String>,
+ mail_port: Option<u16>,
+
+ #[serde(default)]
+ invite_only: bool,
+ #[serde(default)]
+ invite_admin_address: String,
+}
+
+impl Config {
+ pub fn avatars_prefix(&self) -> Absolute<'static> {
+ Absolute::parse_owned(format!("{}/avatars", self.location)).unwrap()
+ }
+}
+
+#[get("/")]
+async fn root() -> (ContentType, &'static str) {
+ (ContentType::HTML, include_str!("../web/index.html"))
+}
+
+#[get("/settings/<_..>")]
+async fn settings() -> Redirect {
+ Redirect::to(uri!("/#/settings"))
+}
+
+#[get("/auth/v1/authorization")]
+async fn auth_auth() -> (ContentType, &'static str) {
+ root().await
+}
+
+#[get("/force_auth")]
+async fn force_auth() -> Redirect {
+ Redirect::to(uri!("/#/force_auth"))
+}
+
+#[derive(Debug)]
+struct IsFenix(bool);
+
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for IsFenix {
+ type Error = std::convert::Infallible;
+
+ async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
+ let ua = request.headers().get_one("user-agent");
+ request::Outcome::Success(IsFenix(
+ ua.map(|ua| ua.contains("Firefox") && ua.contains("Android")).unwrap_or(false),
+ ))
+ }
+}
+
+#[get("/.well-known/fxa-client-configuration")]
+async fn fxa_client_configuration(cfg: &State<Config>, is_fenix: IsFenix) -> Value {
+ let base = &cfg.location;
+ json!({
+ "auth_server_base_url": format!("{base}/auth"),
+ "oauth_server_base_url": format!("{base}/oauth"),
+ "pairing_server_base_uri": format!("{base}/pairing"),
+ "profile_server_base_url": format!("{base}/profile"),
+ // NOTE trailing slash is *essential*, otherwise fenix will refuse to sync.
+ // likewise firefox desktop seems to misbehave if there *is* a trailing slash.
+ "sync_tokenserver_base_url": format!("{}{}", cfg.token_server_location, if is_fenix.0 { "/" } else { "" })
+ })
+}
+
+// NOTE it looks like firefox does not implement refresh token rotation.
+// since it also looks like it doesn't implement MTLS we can't secure
+// refresh tokens against being stolen as advised by
+// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics
+// section 2.2.2
+
+// NOTE firefox "oldsync" scope is the current version?
+// https://github.com/mozilla/fxa/blob/main/packages/fxa-auth-server/docs/oauth/scopes.md
+// https://mozilla.github.io/ecosystem-platform/explanation/onepw-protocol
+// https://mozilla.github.io/ecosystem-platform/api
+// https://github.com/mozilla/fxa/blob/main/packages/fxa-auth-server/docs/device_registration.md
+// -> push for everything
+// https://mozilla.github.io/ecosystem-platform/explanation/scoped-keys
+
+#[get("/.well-known/openid-configuration")]
+fn oid(cfg: &State<Config>) -> Value {
+ let base = &cfg.location;
+ json!({
+ "authorization_endpoint": format!("{base}/auth/v1/authorization"),
+ "introspection_endpoint": format!("{base}/oauth/v1/introspect"),
+ "issuer": base.to_string(),
+ "jwks_uri": format!("{base}/oauth/v1/jwks"),
+ "revocation_endpoint": format!("{base}/oauth/v1/destroy"),
+ "token_endpoint": format!("{base}/auth/v1/oauth/token"),
+ "userinfo_endpoint": format!("{base}/profile/v1/profile"),
+ "claims_supported": ["aud","exp","iat","iss","sub"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ "response_types_supported": ["code","token"],
+ "scopes_supported": ["openid","profile","email"],
+ "subject_types_supported": ["public"],
+ "token_endpoint_auth_methods_supported": ["client_secret_post"],
+ })
+}
+
+fn spawn_periodic<A, P, F>(context: &'static str, t: StdDuration, p: P, f: A)
+where
+ A: Fn(P) -> F + Send + Sync + Sized + 'static,
+ P: Clone + Send + Sync + 'static,
+ F: Future<Output = anyhow::Result<()>> + Send + Sized,
+{
+ let mut interval = interval_at(Instant::now() + t, t);
+ interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
+
+ spawn(async move {
+ loop {
+ interval.tick().await;
+ info!("starting periodic {context}");
+ if let Err(e) = f(p.clone()).await {
+ error!("periodic {context} failed: {e}");
+ }
+ }
+ });
+}
+
+async fn ensure_invite_admin(db: &Db, cfg: &Config) -> anyhow::Result<()> {
+ if !cfg.invite_only {
+ return Ok(());
+ }
+
+ let tx = db.begin().await?;
+ match tx.get_user(&cfg.invite_admin_address).await {
+ Err(sqlx::Error::RowNotFound) => {
+ let url = generate_invite_link(&tx, cfg, Duration::hours(1)).await?;
+ tx.commit().await?;
+ warn!("admin user {} does not exist, register at {url}", cfg.invite_admin_address);
+ Ok(())
+ },
+ Err(e) => Err(anyhow!(e)),
+ Ok(_) => Ok(()),
+ }
+}
+
+pub async fn build() -> anyhow::Result<rocket::Rocket<rocket::Build>> {
+ let rocket = rocket::build();
+ let config = rocket.figment().extract::<Config>().context("reading config")?;
+ let db = Arc::new(Db::connect(&config.database_url).await.unwrap());
+
+ db.migrate().await.context("running db migrations")?;
+
+ ensure_invite_admin(&db, &config).await?;
+ let push = Arc::new(
+ PushClient::new(
+ &config.vapid_key,
+ &config.vapid_subject,
+ config.location.clone(),
+ config.default_push_ttl,
+ )
+ .context("setting up push notifications")?,
+ );
+ let mailer = Arc::new(
+ Mailer::new(
+ config.mail_from.clone(),
+ config.mail_host.as_deref().unwrap_or("localhost"),
+ config.mail_port.unwrap_or(25),
+ config.location.clone(),
+ )
+ .context("setting up mail notifications")?,
+ );
+ spawn_periodic("verify code prune", StdDuration::from_secs(5 * 60), Arc::clone(&db), {
+ |db| async move {
+ let tx = db.begin().await?;
+ tx.prune_expired_verify_codes().await?;
+ tx.commit().await?;
+ Ok(())
+ }
+ });
+ spawn_periodic("expired token prune", config.prune_expired_interval, Arc::clone(&db), {
+ |db| async move {
+ let tx = db.begin().await?;
+ tx.prune_expired_tokens().await?;
+ tx.commit().await?;
+ Ok(())
+ }
+ });
+ let rocket = rocket
+ .manage(config)
+ .manage(push)
+ .manage(mailer)
+ .attach(db)
+ .attach(DeferredActions)
+ .mount("/", routes![root, settings, oid, auth_auth, force_auth, fxa_client_configuration,])
+ .register("/auth/v1", catchers![api::auth::catch_all,])
+ .mount(
+ "/auth/v1",
+ routes![
+ api::auth::account::create,
+ api::auth::account::login,
+ api::auth::account::destroy,
+ api::auth::account::keys,
+ api::auth::account::reset,
+ api::auth::oauth::token_authenticated,
+ api::auth::oauth::token_unauthenticated,
+ api::auth::oauth::destroy,
+ api::auth::oauth::scoped_key_data,
+ api::auth::device::devices,
+ api::auth::device::device,
+ api::auth::device::invoke,
+ api::auth::device::commands,
+ api::auth::session::status,
+ api::auth::session::resend_code,
+ api::auth::session::verify_code,
+ api::auth::session::destroy,
+ api::auth::oauth::authorization,
+ api::auth::device::destroy,
+ api::auth::device::notify,
+ api::auth::device::attached_clients,
+ api::auth::device::destroy_attached_client,
+ api::auth::email::status,
+ api::auth::email::verify_code,
+ api::auth::email::resend_code,
+ api::auth::password::change_start,
+ api::auth::password::change_finish,
+ api::auth::password::forgot_start,
+ api::auth::password::forgot_finish,
+ ],
+ )
+ // slight hack to allow the js auth client to "just work"
+ .register("/_invite/v1", catchers![api::auth::catch_all,])
+ .mount("/_invite/v1", routes![api::auth::invite::generate,])
+ .attach(AdHoc::on_response("/auth Timestamp", |req, resp| {
+ Box::pin(async move {
+ if req.uri().path().as_str().starts_with("/auth/v1/") {
+ if let Ok(ts) = SystemTime::now().duration_since(UNIX_EPOCH) {
+ resp.set_header(Header::new("timestamp", ts.as_secs().to_string()));
+ }
+ }
+ })
+ }))
+ .register("/profile", catchers![api::profile::catch_all,])
+ .mount(
+ "/profile/v1",
+ routes![
+ api::profile::profile,
+ api::profile::display_name_post,
+ api::profile::avatar_get,
+ api::profile::avatar_upload,
+ api::profile::avatar_delete,
+ ],
+ )
+ .register("/avatars", catchers![api::profile::catch_all,])
+ .mount("/avatars", routes![api::profile::avatar_get_img])
+ .register("/oauth/v1", catchers![api::oauth::catch_all,])
+ .mount("/oauth/v1", routes![api::oauth::destroy, api::oauth::jwks, api::oauth::verify,])
+ .mount("/js", routes![js::static_js]);
+ Ok(rocket)
+}
diff --git a/src/mailer.rs b/src/mailer.rs
new file mode 100644
index 0000000..1ea1a8b
--- /dev/null
+++ b/src/mailer.rs
@@ -0,0 +1,105 @@
+use std::time::Duration;
+
+use lettre::{
+ message::Mailbox,
+ transport::smtp::client::{Tls, TlsParameters},
+ AsyncSmtpTransport, Message, Tokio1Executor,
+};
+use rocket::http::uri::Absolute;
+use serde_json::json;
+
+use crate::types::UserID;
+
+pub struct Mailer {
+ from: Mailbox,
+ verify_base: Absolute<'static>,
+ transport: AsyncSmtpTransport<Tokio1Executor>,
+}
+
+impl Mailer {
+ pub fn new(
+ from: Mailbox,
+ host: &str,
+ port: u16,
+ verify_base: Absolute<'static>,
+ ) -> anyhow::Result<Self> {
+ Ok(Mailer {
+ from,
+ verify_base,
+ transport: AsyncSmtpTransport::<Tokio1Executor>::builder_dangerous(host)
+ .port(port)
+ .tls(Tls::Opportunistic(TlsParameters::new(host.to_string())?))
+ .timeout(Some(Duration::from_secs(5)))
+ .build(),
+ })
+ }
+
+ pub(crate) async fn send_account_verify(
+ &self,
+ uid: &UserID,
+ to: &str,
+ code: &str,
+ ) -> anyhow::Result<()> {
+ let fragment = base64::encode_config(
+ serde_json::to_string(&json!({
+ "uid": uid,
+ "email": to,
+ "code": code,
+ }))?,
+ base64::URL_SAFE,
+ );
+ let email = Message::builder()
+ .from(self.from.clone())
+ .to(to.parse()?)
+ .subject("account verify code")
+ .body(format!("{}/#/verify/{fragment}", self.verify_base))?;
+ lettre::AsyncTransport::send(&self.transport, email).await?;
+ Ok(())
+ }
+
+ pub(crate) async fn send_session_verify(&self, to: &str, code: &str) -> anyhow::Result<()> {
+ let email = Message::builder()
+ .from(self.from.clone())
+ .to(to.parse()?)
+ .subject("session verify code")
+ .body(format!("{code}"))?;
+ lettre::AsyncTransport::send(&self.transport, email).await?;
+ Ok(())
+ }
+
+ pub(crate) async fn send_password_changed(&self, to: &str) -> anyhow::Result<()> {
+ let email = Message::builder()
+ .from(self.from.clone())
+ .to(to.parse()?)
+ .subject("account password has been changed")
+ .body(String::from(
+ "your account password has been changed. if you haven't done this, \
+ you're probably in trouble now.",
+ ))?;
+ lettre::AsyncTransport::send(&self.transport, email).await?;
+ Ok(())
+ }
+
+ pub(crate) async fn send_password_forgot(&self, to: &str, code: &str) -> anyhow::Result<()> {
+ let email = Message::builder()
+ .from(self.from.clone())
+ .to(to.parse()?)
+ .subject("account reset code")
+ .body(code.to_string())?;
+ lettre::AsyncTransport::send(&self.transport, email).await?;
+ Ok(())
+ }
+
+ pub(crate) async fn send_account_reset(&self, to: &str) -> anyhow::Result<()> {
+ let email = Message::builder()
+ .from(self.from.clone())
+ .to(to.parse()?)
+ .subject("account has been reset")
+ .body(String::from(
+ "your account has been reset. if you haven't done this, \
+ you're probably in trouble now.",
+ ))?;
+ lettre::AsyncTransport::send(&self.transport, email).await?;
+ Ok(())
+ }
+}
diff --git a/src/push.rs b/src/push.rs
new file mode 100644
index 0000000..6ee1afb
--- /dev/null
+++ b/src/push.rs
@@ -0,0 +1,198 @@
+use anyhow::Result;
+use rocket::http::uri::Absolute;
+use serde_json::{json, Value};
+use std::time::Duration;
+use std::{fs::File, io::Read, path::Path};
+
+use serde::Serialize;
+use web_push::{
+ ContentEncoding, SubscriptionInfo, VapidSignatureBuilder, WebPushClient, WebPushMessageBuilder,
+};
+
+use crate::db::DbConn;
+use crate::types::{Device, DeviceID, DevicePush, UserID};
+
+pub(crate) struct PushClient {
+ key: Box<[u8]>,
+ client: WebPushClient,
+ subject: String,
+ base_uri: Absolute<'static>,
+ default_ttl: Duration,
+}
+
+impl PushClient {
+ pub(crate) fn new<P: AsRef<Path>>(
+ key: P,
+ subject: &str,
+ base_uri: Absolute<'static>,
+ default_ttl: Duration,
+ ) -> Result<Self> {
+ let mut key_bytes = vec![];
+ File::open(key).and_then(|mut f| f.read_to_end(&mut key_bytes))?;
+ Ok(PushClient {
+ key: key_bytes.into_boxed_slice(),
+ client: WebPushClient::new()?,
+ subject: subject.to_string(),
+ base_uri,
+ default_ttl,
+ })
+ }
+
+ async fn push_raw(&self, to: &DevicePush, ttl: Duration, data: Option<&[u8]>) -> Result<()> {
+ let sub = SubscriptionInfo::new(&to.callback, &to.public_key, &to.auth_key);
+ let mut sig = VapidSignatureBuilder::from_pem(&*self.key, &sub)?;
+ // mozilla requires {aud,exp,sub} or message will get a 401 unauthorized.
+ // {aud,exp} are added automatically
+ sig.add_claim("sub", self.subject.as_str());
+ let mut builder = WebPushMessageBuilder::new(&sub)?;
+ if let Some(data) = data {
+ builder.set_payload(ContentEncoding::Aes128Gcm, data);
+ }
+ builder.set_vapid_signature(sig.build()?);
+ builder.set_ttl(ttl.as_secs().min(u32::MAX as u64) as u32);
+ Ok(self.client.send(builder.build()?).await?)
+ }
+
+ async fn push_one(
+ &self,
+ context: &str,
+ db: Option<&DbConn>,
+ to: &Device,
+ ttl: Duration,
+ data: Option<&[u8]>,
+ ) -> Result<()> {
+ match (to.push_expired, to.push.as_ref()) {
+ (false, Some(ep)) => match self.push_raw(ep, ttl, data).await {
+ Ok(()) => Ok(()),
+ Err(e) => {
+ warn!("{} push to {} failed: {}", context, &to.device_id, e);
+ if let Some(db) = db {
+ if let Err(e) = db.set_push_expired(&to.device_id).await {
+ warn!("failed to set {} push_endpoint_expired: {}", &to.device_id, e);
+ }
+ }
+ Err(e)
+ },
+ },
+ (_, None) => Err(anyhow!("no push callback")),
+ (true, _) => Err(anyhow!("push endpoint expired")),
+ }
+ }
+
+ async fn push_all(
+ &self,
+ context: &str,
+ db: Option<&DbConn>,
+ to: &[Device],
+ ttl: Duration,
+ msg: impl Serialize,
+ ) {
+ let msg = serde_json::to_vec(&msg).expect("push message serialization failed");
+ for dev in to {
+ // ignore errors here, except by logging them. we can't notify the client
+ // about anything and failing isn't an option either.
+ let _ = self.push_one(context, db, dev, ttl, Some(&msg)).await;
+ }
+ }
+
+ pub(crate) async fn command_received(
+ &self,
+ db: &DbConn,
+ to: &Device,
+ command: &str,
+ index: i64,
+ sender: &Option<DeviceID>,
+ ) -> Result<()> {
+ let url =
+ format!("{}/auth/v1/account/device/commands?index={}&limit=1", self.base_uri, index);
+ let msg = json!({
+ "version": 1,
+ "command": "fxaccounts:command_received",
+ "data": {
+ "command": command,
+ "index": index,
+ "sender": sender,
+ "url": url,
+ },
+ });
+ let msg = serde_json::to_vec(&msg)?;
+ self.push_one("command_received", Some(db), to, self.default_ttl, Some(&msg)).await
+ }
+
+ pub(crate) async fn device_connected(&self, db: &DbConn, to: &[Device], name: &str) {
+ let msg = json!({
+ "version": 1,
+ "command": "fxaccounts:device_connected",
+ "data": {
+ "deviceName": name,
+ },
+ });
+ self.push_all("device_connected", Some(db), to, self.default_ttl, &msg).await;
+ }
+
+ pub(crate) async fn device_disconnected(&self, db: &DbConn, to: &[Device], id: &DeviceID) {
+ let msg = json!({
+ "version": 1,
+ "command": "fxaccounts:device_disconnected",
+ "data": {
+ "id": id,
+ },
+ });
+ self.push_all("device_disconnected", Some(db), to, self.default_ttl, &msg).await;
+ }
+
+ pub(crate) async fn profile_updated(&self, db: &DbConn, to: &[Device]) {
+ let msg = json!({
+ "version": 1,
+ "command": "fxaccounts:profile_updated",
+ });
+ self.push_all("profile_updated", Some(db), to, self.default_ttl, &msg).await;
+ }
+
+ pub(crate) async fn account_verified(&self, db: &DbConn, to: &[Device]) {
+ for dev in to {
+ // ignore errors here, except by logging them. we can't notify the client
+ // about anything and failing isn't an option either.
+ let _ = self.push_one("account_verified", Some(db), dev, Duration::ZERO, None).await;
+ }
+ }
+
+ pub(crate) async fn account_destroyed(&self, to: &[Device], uid: &UserID) {
+ let msg = json!({
+ "version": 1,
+ "command": "fxaccounts:account_destroyed",
+ "data": {
+ "uid": uid,
+ },
+ });
+ self.push_all("account_destroyed", None, to, self.default_ttl, &msg).await;
+ }
+
+ pub(crate) async fn password_reset(&self, to: &[Device]) {
+ let msg = serde_json::to_vec(&json!({
+ "version": 1u32,
+ "command": "fxaccounts:password_reset",
+ }))
+ .expect("serde failed");
+ for dev in to {
+ // ignore errors here, except by logging them. we can't notify the client
+ // about anything and failing isn't an option either.
+ let _ = self.push_one("password_reset", None, dev, self.default_ttl, Some(&msg)).await;
+ // NOTE password_reset alone doesn't seem to do much, se we also disconnect
+ // each device explicitly.
+ let msg = serde_json::to_vec(&json!({
+ "version": 1,
+ "command": "fxaccounts:device_disconnected",
+ "data": {
+ "id": dev.device_id,
+ },
+ }))
+ .expect("serde failed");
+ let _ = self.push_one("password_reset", None, dev, self.default_ttl, Some(&msg)).await;
+ }
+ }
+
+ pub(crate) async fn push_any(&self, db: &DbConn, to: &[Device], ttl: Duration, payload: Value) {
+ self.push_all("push_any", Some(db), to, ttl, &payload).await;
+ }
+}
diff --git a/src/types.rs b/src/types.rs
new file mode 100644
index 0000000..c0c5dfe
--- /dev/null
+++ b/src/types.rs
@@ -0,0 +1,436 @@
+use crate::crypto::SecretBytes;
+use chrono::{DateTime, Utc};
+use password_hash::{rand_core::OsRng, Output, SaltString};
+use rand::RngCore;
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use sha2::{Digest, Sha256};
+use sqlx::{
+ postgres::{PgArgumentBuffer, PgTypeInfo, PgValueRef},
+ Decode, Encode, Postgres, Type,
+};
+use std::{
+ collections::HashMap,
+ fmt::{Debug, Display},
+ ops::Deref,
+ str::FromStr,
+};
+
+use self::oauth::ScopeSet;
+
+pub(crate) mod oauth;
+
+macro_rules! array_type {
+ (
+ $( #[ $attr:meta ] )*
+ $name:ident($inner:ty) as $sql_name:ident {
+ $( $body:tt )*
+ }
+ ) => {
+ $( #[ $attr ] )*
+ pub(crate) struct $name(pub(crate) $inner);
+
+ impl $name {
+ $( $body )*
+ }
+
+ impl Type<Postgres> for $name {
+ fn type_info() -> PgTypeInfo {
+ PgTypeInfo::with_name(stringify!($sql_name))
+ }
+ }
+
+ impl Encode<'_, Postgres> for $name {
+ fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> sqlx::encode::IsNull {
+ let raw = self.0.iter().map(Self::encode_elem).collect::<Vec<_>>();
+ Encode::<'_, Postgres>::encode_by_ref(&raw, buf)
+ }
+ }
+
+ impl Decode<'_, Postgres> for $name {
+ fn decode(value: PgValueRef) -> Result<Self, sqlx::error::BoxDynError> {
+ Ok(Self::decode_elems(Decode::<'_, Postgres>::decode(value)?)?)
+ }
+ }
+ }
+}
+
+macro_rules! bytea_types {
+ () => {};
+ (
+ #[simple_array]
+ struct $name:ident($inner:ty) as $sql_name:ident;
+
+ $( $rest:tt )*
+ ) => {
+ bytea_types!{
+ #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+ #[serde(try_from = "String", into = "String")]
+ struct $name($inner) as $sql_name {
+ fn decode(v) -> _ { &v.0[..] }
+ fn encode(v) -> _ { v }
+ }
+
+ impl FromStr for $name {}
+ impl ToString for $name {}
+ impl Debug for $name {}
+
+ $( $rest )*
+ }
+ };
+ (
+ $( #[ $attr:meta ] )*
+ struct $name:ident($inner:ty) as $sql_name:ident {
+ $( fn arbitrary($a:ident) -> _ { $ae:expr } )?
+ fn decode($d:ident) -> _ { $de:expr }
+ fn encode($e:ident) -> _ { $ee:expr }
+
+ $( $impls:tt )*
+ }
+
+ $( $rest:tt )*
+ ) => {
+ $( #[ $attr ] )*
+ pub(crate) struct $name(pub(crate) $inner);
+
+ impl $name {
+ $( $impls )*
+ }
+
+ impl Type<Postgres> for $name {
+ fn type_info() -> PgTypeInfo {
+ PgTypeInfo::with_name(stringify!($sql_name))
+ }
+ }
+
+ impl Encode<'_, Postgres> for $name {
+ fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> sqlx::encode::IsNull {
+ let $d = self;
+ <&[u8] as Encode<'_, Postgres>>::encode_by_ref(&$de, buf)
+ }
+ }
+
+ impl Decode<'_, Postgres> for $name {
+ fn decode(value: PgValueRef) -> Result<Self, sqlx::error::BoxDynError> {
+ let $e = <&[u8] as Decode<'_, Postgres>>::decode(value)?.try_into()?;
+ Ok($name($ee))
+ }
+ }
+
+ bytea_types!{ $( $rest )* }
+ };
+ ( impl ToString for $name:ident {} $( $rest:tt )* ) => {
+ impl Display for $name {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
+ f.write_str(&hex::encode(&self.0))
+ }
+ }
+ impl From<$name> for String {
+ fn from(s: $name) -> String {
+ format!("{}", s)
+ }
+ }
+ bytea_types!{ $( $rest )* }
+ };
+ ( impl Debug for $name:ident {} $( $rest:tt )* ) => {
+ impl Debug for $name {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_tuple(stringify!($name)).field(&self.to_string()).finish()
+ }
+ }
+ bytea_types!{ $( $rest )* }
+ };
+ ( impl FromStr for $name:ident {} $( $rest:tt )* ) => {
+ impl FromStr for $name {
+ type Err = anyhow::Error;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ Ok(Self(hex::decode(s)?.as_slice().try_into()?))
+ }
+ }
+ impl TryFrom<String> for $name {
+ type Error = anyhow::Error;
+
+ fn try_from(s: String) -> Result<Self, Self::Error> {
+ s.parse()
+ }
+ }
+ bytea_types!{ $( $rest )* }
+ }
+}
+
+//
+//
+//
+
+bytea_types! {
+ #[derive(Clone, Debug, PartialEq, Eq)]
+ struct HawkKey(SecretBytes<32>) as hawk_key {
+ fn arbitrary(a) -> _ { SecretBytes(a) }
+ fn decode(v) -> _ { v.0.0.as_ref() }
+ fn encode(v) -> _ { SecretBytes(v) }
+ }
+
+ #[simple_array]
+ struct SessionID([u8; 32]) as session_id;
+
+ #[simple_array]
+ struct DeviceID([u8; 16]) as device_id;
+
+ #[simple_array]
+ struct UserID([u8; 16]) as user_id;
+
+ #[simple_array]
+ struct KeyFetchID([u8; 32]) as key_fetch_id;
+
+ #[simple_array]
+ struct OauthTokenID([u8; 32]) as oauth_token_id;
+
+ #[simple_array]
+ struct OauthAuthorizationID([u8; 32]) as oauth_auth_id;
+
+ #[simple_array]
+ struct PasswordChangeID([u8; 32]) as password_change_id;
+
+ #[simple_array]
+ struct AccountResetID([u8; 32]) as account_reset_id;
+
+ #[simple_array]
+ struct AvatarID([u8; 16]) as avatar_id;
+
+ #[derive(Clone, Debug, PartialEq, Eq)]
+ struct SecretKey(SecretBytes<32>) as secret_key {
+ fn arbitrary(a) -> _ { SecretBytes(a) }
+ fn decode(v) -> _ { v.0.0.as_ref() }
+ fn encode(v) -> _ { SecretBytes(v) }
+ }
+
+ #[derive(Clone, Debug, PartialEq, Eq)]
+ struct VerifyHash(Output) as verify_hash {
+ fn arbitrary(a) -> _ { Output::new(<[u8; 32]>::as_ref(&a)).unwrap() }
+ fn decode(v) -> _ { v.0.as_ref() }
+ fn encode(v) -> _ { v }
+ }
+}
+
+impl DeviceID {
+ pub fn random() -> Self {
+ let mut result = Self([0; 16]);
+ OsRng.fill_bytes(&mut result.0);
+ result
+ }
+}
+
+impl UserID {
+ pub fn random() -> Self {
+ let mut result = Self([0; 16]);
+ OsRng.fill_bytes(&mut result.0);
+ result
+ }
+}
+
+impl OauthAuthorizationID {
+ pub fn random() -> Self {
+ let mut result = Self([0; 32]);
+ OsRng.fill_bytes(&mut result.0);
+ result
+ }
+}
+
+#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(try_from = "String", into = "String")]
+pub(crate) struct OauthToken([u8; 32]);
+
+impl OauthToken {
+ pub fn random() -> Self {
+ let mut result = Self([0; 32]);
+ OsRng.fill_bytes(&mut result.0);
+ result
+ }
+
+ pub fn hash(&self) -> OauthTokenID {
+ let mut sha = Sha256::new();
+ sha.update(&self.0);
+ OauthTokenID(*sha.finalize().as_ref())
+ }
+}
+
+bytea_types! {
+ impl Debug for OauthToken {}
+ impl FromStr for OauthToken {}
+ impl ToString for OauthToken {}
+}
+
+#[derive(Debug, Deserialize, PartialEq, Eq, Type)]
+#[sqlx(type_name = "oauth_access_type", rename_all = "lowercase")]
+#[serde(rename_all = "lowercase")]
+pub enum OauthAccessType {
+ Online,
+ Offline,
+}
+
+#[derive(Debug)]
+pub(crate) struct UserSession {
+ pub(crate) uid: UserID,
+ pub(crate) req_hmac_key: HawkKey,
+ pub(crate) device_id: Option<DeviceID>,
+ pub(crate) created_at: DateTime<Utc>,
+ pub(crate) verified: bool,
+ pub(crate) verify_code: Option<String>,
+}
+
+#[derive(Clone, Debug)]
+pub(crate) struct DeviceCommand {
+ pub(crate) index: i64,
+ pub(crate) command: String,
+ pub(crate) payload: Value,
+ #[allow(dead_code)]
+ pub(crate) expires: DateTime<Utc>,
+ // NOTE this is a device ID, but we don't link it to the actual sender device
+ // because removing a device would also remove its queued commands. this mirrors
+ // what fxa does.
+ pub(crate) sender: Option<String>,
+}
+
+#[derive(Clone, Debug, PartialEq, sqlx::Type)]
+#[sqlx(type_name = "device_push_info")]
+pub(crate) struct DevicePush {
+ pub(crate) callback: String,
+ pub(crate) public_key: String,
+ pub(crate) auth_key: String,
+}
+
+#[derive(Clone, Debug, PartialEq, sqlx::Type)]
+#[sqlx(type_name = "device_command")]
+struct DeviceCommandsEntry {
+ name: String,
+ body: String,
+}
+
+array_type! {
+ #[derive(Clone, Debug, PartialEq)]
+ DeviceCommands(HashMap<String, String>) as _device_command {
+ fn encode_elem(e: (&String, &String)) -> DeviceCommandsEntry {
+ DeviceCommandsEntry { name: e.0.clone(), body: e.1.clone() }
+ }
+ fn decode_elems(e: Vec<DeviceCommandsEntry>) -> anyhow::Result<Self> {
+ Ok(Self(e.into_iter().map(|e| (e.name, e.body)).collect()))
+ }
+
+ pub(crate) fn into_map(self) -> HashMap<String, String> {
+ self.0
+ }
+ }
+}
+
+impl Deref for DeviceCommands {
+ type Target = HashMap<String, String>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+#[derive(Clone, Debug)]
+pub(crate) struct Device {
+ pub(crate) device_id: DeviceID,
+ // taken from session, otherwise UNIX_EPOCH
+ pub(crate) last_active: DateTime<Utc>,
+ pub(crate) name: String,
+ pub(crate) type_: String,
+ pub(crate) push: Option<DevicePush>,
+ pub(crate) available_commands: DeviceCommands,
+ pub(crate) push_expired: bool,
+ // actually a str->str map, but we treat it as opaque for simplicity.
+ // writing a HashMap<String, String> to the db through sqlx is an immense pain,
+ // and we don't care about the value anyway—it only has to exist for fenix.
+ pub(crate) location: Value,
+}
+
+#[derive(Clone, Debug)]
+pub(crate) struct DeviceUpdate<'a> {
+ pub(crate) name: Option<&'a str>,
+ pub(crate) type_: Option<&'a str>,
+ pub(crate) push: Option<DevicePush>,
+ pub(crate) available_commands: Option<DeviceCommands>,
+ pub(crate) location: Option<Value>,
+}
+
+#[derive(Debug, sqlx::Type)]
+#[sqlx(type_name = "oauth_token_kind", rename_all = "lowercase")]
+pub(crate) enum OauthTokenKind {
+ Access,
+ Refresh,
+}
+
+#[derive(Debug)]
+pub(crate) struct OauthAccessToken {
+ pub(crate) user_id: UserID,
+ pub(crate) client_id: String,
+ pub(crate) scope: ScopeSet,
+ pub(crate) parent_refresh: Option<OauthTokenID>,
+ pub(crate) parent_session: Option<SessionID>,
+ pub(crate) expires_at: DateTime<Utc>,
+}
+
+#[derive(Debug)]
+pub(crate) struct OauthRefreshToken {
+ pub(crate) user_id: UserID,
+ pub(crate) client_id: String,
+ pub(crate) scope: ScopeSet,
+ pub(crate) session_id: Option<SessionID>,
+}
+
+#[derive(Debug)]
+pub(crate) struct OauthAuthorization {
+ pub(crate) user_id: UserID,
+ pub(crate) client_id: String,
+ pub(crate) scope: ScopeSet,
+ pub(crate) access_type: OauthAccessType,
+ pub(crate) code_challenge: String,
+ pub(crate) keys_jwe: Option<String>,
+ pub(crate) auth_at: DateTime<Utc>,
+}
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(Clone))]
+pub(crate) struct User {
+ pub(crate) auth_salt: SaltString,
+ pub(crate) email: String,
+ pub(crate) display_name: Option<String>,
+ pub(crate) ka: SecretKey,
+ pub(crate) wrapwrap_kb: SecretKey,
+ pub(crate) verify_hash: VerifyHash,
+ pub(crate) verified: bool,
+}
+
+// MISSING user secondary email addresses
+
+#[derive(Debug)]
+pub(crate) struct Avatar {
+ pub(crate) id: AvatarID,
+ pub(crate) data: Vec<u8>,
+ pub(crate) content_type: String,
+}
+
+#[derive(Debug)]
+pub(crate) struct AttachedClient {
+ pub(crate) client_id: Option<String>,
+ pub(crate) device_id: Option<DeviceID>,
+ pub(crate) session_token_id: Option<SessionID>,
+ pub(crate) refresh_token_id: Option<OauthTokenID>,
+ pub(crate) device_type: Option<String>,
+ pub(crate) name: Option<String>,
+ pub(crate) created_time: Option<DateTime<Utc>>,
+ pub(crate) last_access_time: Option<DateTime<Utc>>,
+ pub(crate) scope: Option<String>,
+}
+
+#[derive(Debug)]
+pub(crate) struct VerifyCode {
+ #[allow(dead_code)]
+ pub(crate) user_id: UserID,
+ pub(crate) session_id: Option<SessionID>,
+ #[allow(dead_code)]
+ pub(crate) code: String,
+}
diff --git a/src/types/oauth.rs b/src/types/oauth.rs
new file mode 100644
index 0000000..222c567
--- /dev/null
+++ b/src/types/oauth.rs
@@ -0,0 +1,267 @@
+use std::{borrow::Cow, fmt::Display};
+
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(transparent)]
+pub(crate) struct Scope<'a>(pub Cow<'a, str>);
+
+impl<'a> Scope<'a> {
+ pub const fn borrowed(s: &'a str) -> Self {
+ Self(Cow::Borrowed(s))
+ }
+
+ pub fn into_owned(self) -> Scope<'static> {
+ Scope(Cow::Owned(self.0.into_owned()))
+ }
+
+ pub fn implies(&self, other: &Scope) -> bool {
+ let (a, b) = (&*self.0, &*other.0);
+ match (a.strip_prefix("https://"), b.strip_prefix("https://")) {
+ (Some(a), Some(b)) => {
+ let (a_origin, a_path) = a.split_once('/').unwrap_or((a, ""));
+ let (b_origin, b_path) = b.split_once('/').unwrap_or((b, ""));
+ if a_origin != b_origin {
+ false
+ } else {
+ let (a_path, a_frag) = match a_path.split_once('#') {
+ Some((p, f)) => (p, Some(f)),
+ None => (a_path, None),
+ };
+ let (b_path, b_frag) = match b_path.split_once('#') {
+ Some((p, f)) => (p, Some(f)),
+ None => (b_path, None),
+ };
+ if b_path
+ .strip_prefix(a_path)
+ .map_or(false, |br| br.is_empty() || br.starts_with('/'))
+ {
+ match (a_frag, b_frag) {
+ (Some(af), Some(bf)) => af == bf,
+ (Some(_), None) => false,
+ _ => true,
+ }
+ } else {
+ false
+ }
+ }
+ },
+ (None, None) => {
+ let (a, a_write) =
+ a.strip_suffix(":write").map(|s| (s, true)).unwrap_or((a, false));
+ let (b, b_write) =
+ b.strip_suffix(":write").map(|s| (s, true)).unwrap_or((b, false));
+ if b_write && !a_write {
+ false
+ } else {
+ b.strip_prefix(a).map_or(false, |br| br.is_empty() || br.starts_with(':'))
+ }
+ },
+ _ => false,
+ }
+ }
+}
+
+impl<'a> Display for Scope<'a> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, sqlx::Type)]
+#[serde(transparent)]
+#[sqlx(transparent)]
+pub(crate) struct ScopeSet(String);
+
+impl ScopeSet {
+ pub fn split(&self) -> impl Iterator<Item = Scope> {
+ // not using split_whitespace because the oauth spec explicitly says to split on SP
+ self.0.split(' ').filter(|s| !s.is_empty()).map(Scope::borrowed)
+ }
+
+ pub fn implies(&self, scope: &Scope) -> bool {
+ self.split().any(|a| a.implies(scope))
+ }
+
+ pub fn implies_all(&self, scopes: &ScopeSet) -> bool {
+ scopes.split().all(|b| self.implies(&b))
+ }
+
+ pub fn is_allowed_by(&self, allowed: &[Scope]) -> bool {
+ self.split().all(|scope| allowed.iter().any(|perm| perm.implies(&scope)))
+ }
+
+ pub fn remove(&self, remove: &Scope) -> ScopeSet {
+ let remaining = self.split().filter(|s| !remove.implies(s));
+ ScopeSet(remaining.map(|s| s.0).collect::<Vec<_>>().join(" "))
+ }
+}
+
+impl PartialEq for ScopeSet {
+ fn eq(&self, other: &Self) -> bool {
+ let (mut a, mut b) = (self.split().collect::<Vec<_>>(), other.split().collect::<Vec<_>>());
+ a.sort_by(|a, b| a.0.cmp(&b.0));
+ b.sort_by(|a, b| a.0.cmp(&b.0));
+ a.eq(&b)
+ }
+}
+
+impl Eq for ScopeSet {}
+
+impl Display for ScopeSet {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::{Scope, ScopeSet};
+
+ #[test]
+ fn test_scope_implies() {
+ assert!(ScopeSet("profile:write".to_string()).implies(&Scope::borrowed("profile")));
+ assert!(ScopeSet("profile".to_string()).implies(&Scope::borrowed("profile:email")));
+ assert!(ScopeSet("profile:write".to_string()).implies(&Scope::borrowed("profile:email")));
+ assert!(
+ ScopeSet("profile:write".to_string()).implies(&Scope::borrowed("profile:email:write"))
+ );
+ assert!(
+ ScopeSet("profile:email:write".to_string()).implies(&Scope::borrowed("profile:email"))
+ );
+ assert!(ScopeSet("profile profile:email:write".to_string())
+ .implies(&Scope::borrowed("profile:email")));
+ assert!(ScopeSet("profile profile:email:write".to_string())
+ .implies(&Scope::borrowed("profile:display_name")));
+ assert!(ScopeSet("profile https://identity.mozilla.com/apps/oldsync".to_string())
+ .implies(&Scope::borrowed("profile")));
+ assert!(ScopeSet("profile https://identity.mozilla.com/apps/oldsync".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync")));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync#read")));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks")));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks#read")));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync#read".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks#read")));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync#read profile".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks#read")));
+
+ assert!(!ScopeSet("profile:email:write".to_string()).implies(&Scope::borrowed("profile")));
+ assert!(
+ !ScopeSet("profile:email:write".to_string()).implies(&Scope::borrowed("profile:write"))
+ );
+ assert!(!ScopeSet("profile:email".to_string())
+ .implies(&Scope::borrowed("profile:display_name")));
+ assert!(!ScopeSet("profilebogey".to_string()).implies(&Scope::borrowed("profile")));
+ assert!(!ScopeSet("profile:write".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync")));
+ assert!(!ScopeSet("profile profile:email:write".to_string())
+ .implies(&Scope::borrowed("profile:write")));
+ assert!(!ScopeSet("https".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync")));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string())
+ .implies(&Scope::borrowed("profile")));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync#read".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks")));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync#write".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks#read")));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync/bookmarks".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync")));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync/bookmarks".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/passwords")));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsyncer".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync")));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsyncer")));
+ assert!(!ScopeSet("https://identity.mozilla.org/apps/oldsync".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync")));
+ }
+
+ #[test]
+ fn test_scopes_allowed_by() {
+ const ALLOWED: [Scope; 2] = [
+ Scope::borrowed("profile:write"),
+ Scope::borrowed("https://identity.mozilla.com/apps/oldsync"),
+ ];
+
+ assert!(ScopeSet("profile".to_string()).is_allowed_by(&ALLOWED));
+ assert!(ScopeSet("profile:write".to_string()).is_allowed_by(&ALLOWED));
+ assert!(ScopeSet("profile:email".to_string()).is_allowed_by(&ALLOWED));
+ assert!(ScopeSet("profile:email:write".to_string()).is_allowed_by(&ALLOWED));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string())
+ .is_allowed_by(&ALLOWED));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync#read".to_string())
+ .is_allowed_by(&ALLOWED));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync/bookmarks".to_string())
+ .is_allowed_by(&ALLOWED));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync/bookmarks#read".to_string())
+ .is_allowed_by(&ALLOWED));
+ assert!(ScopeSet("profile https://identity.mozilla.com/apps/oldsync".to_string())
+ .is_allowed_by(&ALLOWED));
+
+ assert!(!ScopeSet("storage".to_string()).is_allowed_by(&ALLOWED));
+ assert!(!ScopeSet("storage:write".to_string()).is_allowed_by(&ALLOWED));
+ assert!(!ScopeSet("storage:email".to_string()).is_allowed_by(&ALLOWED));
+ assert!(!ScopeSet("storage:email:write".to_string()).is_allowed_by(&ALLOWED));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/newsync".to_string())
+ .is_allowed_by(&ALLOWED));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/newsync#read".to_string())
+ .is_allowed_by(&ALLOWED));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/newsync/bookmarks".to_string())
+ .is_allowed_by(&ALLOWED));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/newsync/bookmarks#read".to_string())
+ .is_allowed_by(&ALLOWED));
+ assert!(!ScopeSet("storage https://identity.mozilla.com/apps/newsync".to_string())
+ .is_allowed_by(&ALLOWED));
+ }
+
+ #[test]
+ fn test_scopes_remove() {
+ assert_eq!(
+ ScopeSet("profile foo".to_string()).remove(&Scope::borrowed("profile")),
+ ScopeSet("foo".to_string())
+ );
+ assert_ne!(
+ ScopeSet("profile:write foo".to_string()).remove(&Scope::borrowed("profile")),
+ ScopeSet("foo".to_string())
+ );
+ assert_eq!(
+ ScopeSet("profile:write foo".to_string()).remove(&Scope::borrowed("profile:write")),
+ ScopeSet("foo".to_string())
+ );
+ assert_eq!(
+ ScopeSet("profile:x foo".to_string()).remove(&Scope::borrowed("profile")),
+ ScopeSet("foo".to_string())
+ );
+ assert_ne!(
+ ScopeSet("profile:x:write foo".to_string()).remove(&Scope::borrowed("profile")),
+ ScopeSet("foo".to_string())
+ );
+ assert_eq!(
+ ScopeSet("profile:x:write foo".to_string()).remove(&Scope::borrowed("profile:write")),
+ ScopeSet("foo".to_string())
+ );
+
+ assert_eq!(
+ ScopeSet("https://foo/bar foo".to_string()).remove(&Scope::borrowed("https://foo/bar")),
+ ScopeSet("foo".to_string())
+ );
+ assert_eq!(
+ ScopeSet("https://foo/bar/baz foo".to_string())
+ .remove(&Scope::borrowed("https://foo/bar")),
+ ScopeSet("foo".to_string())
+ );
+ assert_eq!(
+ ScopeSet("https://foo/bar#read foo".to_string())
+ .remove(&Scope::borrowed("https://foo/bar")),
+ ScopeSet("foo".to_string())
+ );
+ assert_eq!(
+ ScopeSet("https://foo/bar/baz#read foo".to_string())
+ .remove(&Scope::borrowed("https://foo/bar")),
+ ScopeSet("foo".to_string())
+ );
+ }
+}
diff --git a/src/utils.rs b/src/utils.rs
new file mode 100644
index 0000000..5a2407a
--- /dev/null
+++ b/src/utils.rs
@@ -0,0 +1,124 @@
+use std::convert::Infallible;
+
+use futures::Future;
+use rocket::{
+ fairing::{self, Fairing, Info, Kind},
+ http::Status,
+ request::{FromRequest, Outcome},
+ tokio::{
+ self,
+ sync::broadcast::{channel, Sender},
+ task::JoinHandle,
+ },
+ Request, Response, Sentinel,
+};
+
+// does the same as try_outcome!(...).map_forward(|_| data),
+// but without moving data in non-failure cases.
+macro_rules! try_outcome_data {
+ ($data:expr, $e:expr) => {
+ match $e {
+ ::rocket::outcome::Outcome::Success(val) => val,
+ ::rocket::outcome::Outcome::Failure(e) => {
+ return ::rocket::outcome::Outcome::Failure(::std::convert::From::from(e))
+ },
+ ::rocket::outcome::Outcome::Forward(()) => {
+ return ::rocket::outcome::Outcome::Forward($data)
+ },
+ }
+ };
+}
+
+pub trait CanBeLogged: Send + 'static {
+ fn into_message(self) -> Option<String>;
+ fn not_run() -> Self;
+}
+
+impl CanBeLogged for Result<(), anyhow::Error> {
+ fn into_message(self) -> Option<String> {
+ self.err().map(|e| e.to_string())
+ }
+ fn not_run() -> Self {
+ Ok(())
+ }
+}
+
+pub fn spawn_logged<T>(context: &'static str, future: T) -> JoinHandle<()>
+where
+ T: Future + Send + 'static,
+ T::Output: CanBeLogged + Send + 'static,
+{
+ tokio::spawn(async move {
+ if let Some(msg) = future.await.into_message() {
+ warn!("{context}: {msg}");
+ }
+ })
+}
+
+pub struct DeferredActions;
+struct HasDeferredActions;
+
+pub struct DeferAction(Sender<()>);
+
+#[async_trait]
+impl Fairing for DeferredActions {
+ fn info(&self) -> Info {
+ Info { name: "deferred actions", kind: Kind::Ignite | Kind::Response }
+ }
+
+ async fn on_ignite(&self, rocket: rocket::Rocket<rocket::Build>) -> fairing::Result {
+ Ok(rocket.manage(HasDeferredActions))
+ }
+
+ async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) {
+ if res.status() == Status::Ok {
+ if let Some(DeferAction(tx)) = req.local_cache(|| None) {
+ // could have no receivers, that's not an error here
+ tx.send(()).ok();
+ }
+ }
+ }
+}
+
+impl<'r> Sentinel for &'r DeferAction {
+ fn abort(rocket: &rocket::Rocket<rocket::Ignite>) -> bool {
+ rocket.state::<HasDeferredActions>().is_none()
+ }
+}
+
+#[async_trait]
+impl<'r> FromRequest<'r> for &'r DeferAction {
+ type Error = Infallible;
+
+ async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+ Outcome::Success(
+ req.local_cache(|| {
+ let (tx, _) = channel(1);
+ Some(DeferAction(tx))
+ })
+ .as_ref()
+ .unwrap(),
+ )
+ }
+}
+
+impl DeferAction {
+ pub fn spawn_after_success<T>(&self, context: &'static str, future: T)
+ where
+ T: Future + Send + 'static,
+ T::Output: CanBeLogged + Send + 'static,
+ {
+ let mut r = self.0.subscribe();
+ spawn_logged(context, async move {
+ match r.recv().await {
+ Ok(_) => {
+ // the request finished with success, now wait for it to be dropped
+ // to ensure that all other fairings have run to completion.
+ r.recv().await.ok();
+ future.await
+ },
+ Err(_) => CanBeLogged::not_run(),
+ }
+ });
+ }
+}