From 2f8dce44d3f2be74b5c6ec0a2e7f4ceced715328 Mon Sep 17 00:00:00 2001 From: pennae Date: Wed, 13 Jul 2022 10:33:30 +0200 Subject: initial import --- src/api/auth/account.rs | 413 +++++++++++++++++++ src/api/auth/device.rs | 455 ++++++++++++++++++++ src/api/auth/email.rs | 126 ++++++ src/api/auth/invite.rs | 47 +++ src/api/auth/mod.rs | 238 +++++++++++ src/api/auth/oauth.rs | 433 +++++++++++++++++++ src/api/auth/password.rs | 260 ++++++++++++ src/api/auth/session.rs | 107 +++++ src/api/mod.rs | 32 ++ src/api/oauth.rs | 163 ++++++++ src/api/profile/mod.rs | 324 +++++++++++++++ src/auth.rs | 241 +++++++++++ src/bin/minorskulk.rs | 9 + src/cache.rs | 42 ++ src/crypto.rs | 408 ++++++++++++++++++ src/db/mod.rs | 1026 ++++++++++++++++++++++++++++++++++++++++++++++ src/js.rs | 53 +++ src/lib.rs | 319 ++++++++++++++ src/mailer.rs | 105 +++++ src/push.rs | 198 +++++++++ src/types.rs | 436 ++++++++++++++++++++ src/types/oauth.rs | 267 ++++++++++++ src/utils.rs | 124 ++++++ 23 files changed, 5826 insertions(+) create mode 100644 src/api/auth/account.rs create mode 100644 src/api/auth/device.rs create mode 100644 src/api/auth/email.rs create mode 100644 src/api/auth/invite.rs create mode 100644 src/api/auth/mod.rs create mode 100644 src/api/auth/oauth.rs create mode 100644 src/api/auth/password.rs create mode 100644 src/api/auth/session.rs create mode 100644 src/api/mod.rs create mode 100644 src/api/oauth.rs create mode 100644 src/api/profile/mod.rs create mode 100644 src/auth.rs create mode 100644 src/bin/minorskulk.rs create mode 100644 src/cache.rs create mode 100644 src/crypto.rs create mode 100644 src/db/mod.rs create mode 100644 src/js.rs create mode 100644 src/lib.rs create mode 100644 src/mailer.rs create mode 100644 src/push.rs create mode 100644 src/types.rs create mode 100644 src/types/oauth.rs create mode 100644 src/utils.rs (limited to 'src') diff --git a/src/api/auth/account.rs b/src/api/auth/account.rs new file mode 100644 index 0000000..51dd98e --- /dev/null +++ b/src/api/auth/account.rs @@ -0,0 +1,413 @@ +use std::sync::Arc; + +use anyhow::Result; +use chrono::{DateTime, Utc}; +use password_hash::SaltString; +use rand::{thread_rng, Rng}; +use rocket::request::FromRequest; +use rocket::State; +use rocket::{serde::json::Json, Request}; +use serde::{Deserialize, Serialize}; +use validator::Validate; + +use crate::api::{Empty, EMPTY}; +use crate::db::{Db, DbConn}; +use crate::mailer::Mailer; +use crate::push::PushClient; +use crate::types::AccountResetID; +use crate::utils::DeferAction; +use crate::Config; +use crate::{ + api::{auth, serialize_dt}, + auth::{AuthSource, Authenticated}, + crypto::{AuthPW, KeyBundle, KeyFetchReq, SecretBytes, SessionCredentials}, + types::{HawkKey, KeyFetchID, OauthToken, SecretKey, SessionID, User, UserID, VerifyHash}, +}; + +// TODO better error handling + +// MISSING get /account/profile +// MISSING get /account/status +// MISSING post /account/status +// MISSING post /account/reset + +#[derive(Deserialize, Debug, Validate)] +#[serde(deny_unknown_fields)] +#[allow(non_snake_case)] +pub(crate) struct Create { + #[validate(email, length(min = 3, max = 256))] + email: String, + authPW: AuthPW, + // MISSING service + // MISSING redirectTo + // MISSING resume + // MISSING metricsContext + // NOTE we misuse style to communicate an invite token! + style: Option, + // MISSING verificationMethod +} + +#[derive(Serialize, Debug)] +#[allow(non_snake_case)] +#[serde(deny_unknown_fields)] +pub(crate) struct CreateResp { + uid: UserID, + sessionToken: SecretBytes<32>, + #[serde(skip_serializing_if = "Option::is_none")] + keyFetchToken: Option>, + #[serde(serialize_with = "serialize_dt")] + authAt: DateTime, + // MISSING verificationMethod +} + +// MISSING arg: service +#[post("/account/create?", data = "")] +pub(crate) async fn create( + db: &DbConn, + cfg: &State, + mailer: &State>, + keys: Option, + data: Json, +) -> auth::Result { + let keys = keys.unwrap_or(false); + let data = data.into_inner(); + data.validate().map_err(|_| auth::Error::InvalidParameter)?; + + if db.user_email_exists(&data.email).await? { + return Err(auth::Error::AccountExists); + } + + match (cfg.invite_only, data.style) { + (false, Some(_)) => return Err(auth::Error::InvalidParameter), + (false, None) => (), + (true, None) => return Err(auth::Error::InviteOnly), + (true, Some(code)) => { + db.use_invite_code(&code).await.map_err(|e| match e { + sqlx::Error::RowNotFound => auth::Error::InviteNotFound, + e => auth::Error::Other(anyhow!(e)), + })?; + }, + } + + let ka = SecretBytes::generate(); + let wrapwrap_kb = SecretBytes::generate(); + let auth_salt = SaltString::generate(rand::rngs::OsRng); + let stretched = data.authPW.stretch(auth_salt.as_salt())?; + let verify_hash = stretched.verify_hash(); + let session_token = SecretBytes::generate(); + let session = SessionCredentials::derive(&session_token); + let key_fetch_token = if keys { + let key_fetch_token = SecretBytes::generate(); + let req = KeyFetchReq::from_token(&key_fetch_token); + let wrapped = req.derive_resp().wrap_keys(&KeyBundle { + ka: ka.clone(), + wrap_kb: stretched.decrypt_wwkb(&wrapwrap_kb), + }); + db.add_key_fetch(KeyFetchID(req.token_id.0), &HawkKey(req.req_hmac_key), &wrapped).await?; + Some(key_fetch_token) + } else { + None + }; + let uid = db + .add_user(User { + auth_salt, + email: data.email.to_owned(), + ka: SecretKey(ka), + wrapwrap_kb: SecretKey(wrapwrap_kb), + verify_hash: VerifyHash(verify_hash), + display_name: None, + verified: false, + }) + .await?; + let session_id = SessionID(session.token_id.0); + let auth_at = db + .add_session(session_id.clone(), &uid, HawkKey(session.req_hmac_key), false, None) + .await?; + let verify_code = hex::encode(&SecretBytes::<16>::generate().0); + db.add_verify_code(&uid, &session_id, &verify_code).await?; + // NOTE we send the email in this context rather than a spawn to signal + // send errors to the client. + mailer.send_account_verify(&uid, &data.email, &verify_code).await.map_err(|e| { + error!("failed to send email: {e}"); + auth::Error::EmailFailed + })?; + Ok(Json(CreateResp { + uid, + sessionToken: session_token, + keyFetchToken: key_fetch_token, + authAt: auth_at, + })) +} + +#[derive(Deserialize, Debug, Validate)] +#[serde(deny_unknown_fields)] +#[allow(non_snake_case)] +pub(crate) struct Login { + #[validate(email, length(min = 3, max = 256))] + email: String, + authPW: AuthPW, + // MISSING service + // MISSING redirectTo + // MISSING resume + // MISSING reason + // MISSING unblockCode + // MISSING originalLoginEmail + // MISSING verificationMethod +} + +#[derive(Serialize, Debug)] +#[allow(non_snake_case)] +#[serde(deny_unknown_fields)] +pub(crate) struct LoginResp { + uid: UserID, + sessionToken: SecretBytes<32>, + #[serde(skip_serializing_if = "Option::is_none")] + keyFetchToken: Option>, + // MISSING verificationMethod + // MISSING verificationReason + // NOTE this is the *account* verified status, not the session status. + // the spec doesn't say. + verified: bool, + #[serde(serialize_with = "serialize_dt")] + authAt: DateTime, + // MISSING metricsEnabled +} + +// MISSING arg: service +// MISSING arg: verificationMethod +#[post("/account/login?", data = "")] +pub(crate) async fn login( + db: &DbConn, + mailer: &State>, + keys: Option, + data: Json, +) -> auth::Result { + let keys = keys.unwrap_or(false); + let data = data.into_inner(); + data.validate().map_err(|_| auth::Error::InvalidParameter)?; + + let (uid, user) = db.get_user(&data.email).await.map_err(|_| auth::Error::UnknownAccount)?; + if user.email != data.email { + return Err(auth::Error::IncorrectEmailCase); + } + if !user.verified { + return Err(auth::Error::UnverifiedAccount); + } + + let stretched = data.authPW.stretch(user.auth_salt.as_salt())?; + if stretched.verify_hash() != user.verify_hash.0 { + return Err(auth::Error::IncorrectPassword); + } + + let session_token = SecretBytes::generate(); + let session = SessionCredentials::derive(&session_token); + let key_fetch_token = if keys { + let key_fetch_token = SecretBytes::generate(); + let req = KeyFetchReq::from_token(&key_fetch_token); + let wrapped = req.derive_resp().wrap_keys(&KeyBundle { + ka: user.ka.0.clone(), + wrap_kb: stretched.decrypt_wwkb(&user.wrapwrap_kb.0), + }); + db.add_key_fetch(KeyFetchID(req.token_id.0), &HawkKey(req.req_hmac_key), &wrapped).await?; + Some(key_fetch_token) + } else { + None + }; + + let session_id = SessionID(session.token_id.0); + let verify_code = format!("{:06}", thread_rng().gen_range(0..=999999)); + let auth_at = db + .add_session( + session_id.clone(), + &uid, + HawkKey(session.req_hmac_key), + false, + Some(&verify_code), + ) + .await?; + // NOTE we send the email in this context rather than a spawn to signal + // send errors to the client. + mailer.send_session_verify(&data.email, &verify_code).await.map_err(|e| { + error!("failed to send email: {e}"); + auth::Error::EmailFailed + })?; + Ok(Json(LoginResp { + uid, + sessionToken: session_token, + keyFetchToken: key_fetch_token, + verified: true, + authAt: auth_at, + })) +} + +#[derive(Deserialize, Debug, Validate)] +#[serde(deny_unknown_fields)] +#[allow(non_snake_case)] +pub(crate) struct Destroy { + #[validate(email, length(min = 3, max = 256))] + email: String, + authPW: AuthPW, +} + +// TODO may also be authenticated with a verified session +#[post("/account/destroy", data = "")] +pub(crate) async fn destroy( + db: &DbConn, + db_pool: &Db, + defer: &DeferAction, + pc: &State>, + data: Json, +) -> auth::Result { + let data = data.into_inner(); + data.validate().map_err(|_| auth::Error::InvalidParameter)?; + + let (uid, user) = db.get_user(&data.email).await.map_err(|_| auth::Error::UnknownAccount)?; + if user.email != data.email { + return Err(auth::Error::IncorrectEmailCase); + } + + let stretched = data.authPW.stretch(user.auth_salt.as_salt())?; + if stretched.verify_hash() != user.verify_hash.0 { + return Err(auth::Error::IncorrectPassword); + } + + let devs = db.get_devices(&uid).await; + db.delete_user(&data.email).await?; + match devs { + Ok(devs) => defer.spawn_after_success("api::account/destroy(post)", { + let (pc, db) = (Arc::clone(pc), db_pool.clone()); + async move { + let db = db.begin().await?; + pc.account_destroyed(&devs, &uid).await; + db.commit().await?; + Ok(()) + } + }), + Err(e) => warn!("account_destroyed push failed: {e}"), + } + + Ok(EMPTY) +} + +#[derive(Deserialize, Serialize, Debug)] +#[serde(deny_unknown_fields)] +pub(crate) struct KeysResp { + bundle: String, +} + +// NOTE the key fetch endpoint must delete a key fetch token from the database +// once it has identified it, regardless of whether the request succeeds or +// fails. we'll do this with a single-use auth source that sets the db to always +// commit. the auth source must not be used for anything else. we can get away +// with using a request guard because we'll always commit even if the guard +// fails, but this is only allowable because this is the only handler for the path. + +#[derive(Debug)] +pub(crate) struct WithKeyFetch; + +#[async_trait] +impl AuthSource for WithKeyFetch { + type ID = KeyFetchID; + type Context = Vec; + async fn hawk(r: &Request<'_>, id: &KeyFetchID) -> Result<(SecretBytes<32>, Self::Context)> { + let db = Authenticated::<(), Self>::get_conn(r).await?; + db.always_commit().await?; + Ok(db.finish_key_fetch(id).await.map(|(h, ks)| (h.0, ks))?) + } + async fn bearer_token(_: &Request<'_>, _: &OauthToken) -> Result<(KeyFetchID, Self::Context)> { + // key fetch tokens are only valid in hawk requests + bail!("invalid key fetch authentication") + } +} + +#[get("/account/keys")] +pub(crate) async fn keys(auth: Authenticated<(), WithKeyFetch>) -> Json { + // NOTE contrary to its own api spec fxa does not delete a key fetch if the + // associated session is not verified. we don't duplicate this special case + // because we control the clients, and requesting keys on an unverified session + // can be interpreted as a protocol violation anyway. + Json(KeysResp { bundle: hex::encode(&auth.context) }) +} + +#[derive(Debug)] +pub(crate) struct WithResetToken; + +#[async_trait] +impl AuthSource for WithResetToken { + type ID = AccountResetID; + type Context = UserID; + async fn hawk( + r: &Request<'_>, + id: &AccountResetID, + ) -> Result<(SecretBytes<32>, Self::Context)> { + // unlike key fetch we'll use a separate transaction here since the body of the + // handler can fail. + let pool = <&Db as FromRequest>::from_request(r) + .await + .success_or_else(|| anyhow!("could not open db connection"))?; + let db = pool.begin().await?; + let result = db.finish_account_reset(id).await.map(|(h, ctx)| (h.0, ctx))?; + db.commit().await?; + Ok(result) + } + async fn bearer_token( + _: &Request<'_>, + _: &OauthToken, + ) -> Result<(AccountResetID, Self::Context)> { + bail!("invalid password change authentication") + } +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +#[allow(non_snake_case)] +pub(crate) struct AccountResetReq { + authPW: AuthPW, + // MISSING wrapKb + // MISSING recoveryKeyId + // MISSING sessionToken +} + +// NOTE resetting an account does not clear active sync data on the storage server, +// so an account may be reported as disconnected for a while. this is not an error, +// just an inconvenience we haven't found out how to fix yet. + +// MISSING arg: keys +#[post("/account/reset", data = "")] +pub(crate) async fn reset( + db: &DbConn, + mailer: &State>, + client: &State>, + defer: &DeferAction, + data: Authenticated, +) -> auth::Result { + let user = db.get_user_by_id(&data.context).await?; + + let notify_devs = db.get_devices(&data.context).await?; + + let wrapwrap_kb = SecretBytes::generate(); + let auth_salt = SaltString::generate(rand::rngs::OsRng); + let stretched = data.body.authPW.stretch(auth_salt.as_salt())?; + let verify_hash = stretched.verify_hash(); + + db.reset_user_auth(&data.context, auth_salt, SecretKey(wrapwrap_kb), VerifyHash(verify_hash)) + .await?; + + defer.spawn_after_success("api::auth/account/reset(post)", { + let client = Arc::clone(client); + async move { + client.password_reset(¬ify_devs).await; + Ok(()) + } + }); + + mailer + .send_account_reset(&user.email) + .await + .map_err(|e| { + warn!("account reset email send failed: {e}"); + }) + .ok(); + + Ok(EMPTY) +} diff --git a/src/api/auth/device.rs b/src/api/auth/device.rs new file mode 100644 index 0000000..2b05e12 --- /dev/null +++ b/src/api/auth/device.rs @@ -0,0 +1,455 @@ +use std::time::Duration; +use std::{collections::HashMap, sync::Arc}; + +use chrono::{DateTime, Utc}; +use futures::future::join_all; +use rocket::{serde::json::Json, State}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::api::auth::{WithSession, WithVerifiedFxaLogin, WithVerifiedSession}; +use crate::api::{Empty, EMPTY}; +use crate::db::DbConn; +use crate::push::PushClient; +use crate::utils::DeferAction; +use crate::{ + api::{auth, serialize_dt_opt}, + auth::Authenticated, + db::Db, + types::{ + Device, DeviceCommand, DeviceCommands, DeviceID, DevicePush, DeviceUpdate, OauthTokenID, + SessionID, + }, +}; + +fn map_error(e: sqlx::Error) -> auth::Error { + match &e { + // not-null violations can presumably only be caused by bad parameters + sqlx::Error::Database(de) if de.code().as_deref() == Some("23502") => { + auth::Error::MissingParameter + }, + sqlx::Error::RowNotFound => auth::Error::UnknownDevice, + _ => auth::Error::Other(anyhow!(e)), + } +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +#[allow(non_snake_case)] +#[serde(deny_unknown_fields)] +pub(crate) struct Info { + isCurrentDevice: bool, + id: DeviceID, + lastAccessTime: i64, + name: String, + r#type: String, + pushCallback: Option, + pushPublicKey: Option, + pushAuthKey: Option, + pushEndpointExpired: bool, + availableCommands: HashMap, + // NOTE location is optional per the spec, but fenix crashes if it isn't present + location: Value, + // MISSING lastAccessTimeFormatted + // MISSING approximateLastAccessTime + // MISSING approximateLastAccessTimeFormatted +} + +fn device_to_json(current: Option<&DeviceID>, dev: Device) -> Info { + let (pcb, ppk, pak) = match dev.push { + Some(p) => (Some(p.callback), Some(p.public_key), Some(p.auth_key)), + None => (None, None, None), + }; + Info { + isCurrentDevice: Some(&dev.device_id) == current, + id: dev.device_id, + lastAccessTime: dev.last_active.timestamp(), + name: dev.name, + r#type: dev.type_, + pushCallback: pcb, + pushPublicKey: ppk, + pushAuthKey: pak, + pushEndpointExpired: dev.push_expired, + availableCommands: dev.available_commands.into_map(), + location: dev.location, + } +} + +#[derive(Serialize, Deserialize, PartialEq)] +#[serde(transparent)] +pub(crate) struct ListResp(Vec); + +#[get("/account/devices")] +pub(crate) async fn devices( + db: &DbConn, + auth: Authenticated<(), WithVerifiedSession>, +) -> auth::Result { + let devs = db.get_devices(&auth.context.uid).await?; + Ok(Json(ListResp( + devs.into_iter().map(|dev| device_to_json(auth.context.device_id.as_ref(), dev)).collect(), + ))) +} + +#[derive(Debug, Deserialize)] +#[allow(non_snake_case)] +#[serde(deny_unknown_fields)] +pub(crate) struct DeviceReq { + id: Option, + name: Option, + r#type: Option, + pushCallback: Option, + pushPublicKey: Option, + pushAuthKey: Option, + availableCommands: Option>, + // present for legacy reasons, ignored + #[allow(dead_code)] + capabilities: Option>, + location: Option, +} + +#[post("/account/device", data = "")] +pub(crate) async fn device( + db: &DbConn, + db_pool: &Db, + defer: &DeferAction, + client: &State>, + // need to allow registrations to all sessions, otherwise the "now verified" + // notification can't be sent + data: Authenticated, +) -> auth::Result { + let dev = data.body; + if let (None, None, None) = (&dev.name, &dev.r#type, &dev.pushCallback) { + return Err(auth::Error::MissingParameter); + } + + let push = dev.pushCallback.map(|pcb| DevicePush { + callback: pcb, + public_key: dev.pushPublicKey.unwrap_or_default(), + auth_key: dev.pushAuthKey.unwrap_or_default(), + }); + + let (own_id, changed_id, notify) = match (dev.id, data.context.device_id) { + (None, None) => { + let new = DeviceID::random(); + (Some(new.clone()), new, true) + }, + (None, Some(own)) => (Some(own.clone()), own, false), + (Some(other), own) => (own, other, false), + }; + let result = db + .change_device( + &data.context.uid, + &changed_id, + DeviceUpdate { + name: dev.name.as_ref().map(AsRef::as_ref), + type_: dev.r#type.as_ref().map(AsRef::as_ref), + push, + available_commands: dev.availableCommands.map(DeviceCommands), + location: dev.location, + }, + ) + .await + .map_err(map_error)?; + if notify { + db.set_session_device(&data.session, Some(&changed_id)).await?; + match db.get_devices(&data.context.uid).await { + Err(e) => warn!("device_connected push failed: {e}"), + Ok(mut devs) => defer.spawn_after_success("api::auth/account/device(post)", { + devs.retain(|d| d.device_id != changed_id); + let (client, db) = (Arc::clone(client), db_pool.clone()); + let name = result.name.clone(); + async move { + let db = db.begin().await?; + client.device_connected(&db, &devs, &name).await; + db.commit().await?; + Ok(()) + } + }), + }; + } + Ok(Json(device_to_json(own_id.as_ref(), result))) +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +pub(crate) struct Command { + target: DeviceID, + command: String, + payload: Value, + ttl: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +#[allow(non_snake_case)] +#[serde(deny_unknown_fields)] +pub(crate) struct InvokeResp { + enqueued: bool, + notified: bool, + notifyError: Option, +} + +// NOTE fenix doesn't register a push callback for some reason, so receiving tabs +// always requires opening the tab share menu or tab list first. +#[post("/account/devices/invoke_command", data = "")] +pub(crate) async fn invoke( + client: &State>, + db: &DbConn, + cmd: Authenticated, +) -> auth::Result { + let sender = cmd.context.device_id; + let dev = db.get_device(&cmd.context.uid, &cmd.body.target).await.map_err(map_error)?; + if dev.available_commands.get(&cmd.body.command).is_none() { + return Err(auth::Error::NoDeviceCommand); + } + let ttl = cmd.body.ttl.unwrap_or(30 * 86400).clamp(60, 30 * 86400); + let idx = db + .enqueue_command(&cmd.body.target, &sender, &cmd.body.command, &cmd.body.payload, ttl) + .await?; + let (notified, error) = client + .command_received(db, &dev, &cmd.body.command, idx, &sender) + .await + .map_or_else(|e| (false, Some(e.to_string())), |_| (true, None)); + Ok(Json(InvokeResp { enqueued: true, notified, notifyError: error })) +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +#[serde(deny_unknown_fields)] +pub(crate) struct CommandData { + command: String, + payload: Value, + sender: Option, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +#[serde(deny_unknown_fields)] +pub(crate) struct CommandsEntry { + index: i64, + data: CommandData, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +#[serde(deny_unknown_fields)] +pub(crate) struct CommandsResp { + index: i64, + last: bool, + messages: Vec, +} + +fn map_command(c: DeviceCommand) -> CommandsEntry { + CommandsEntry { + index: c.index, + data: CommandData { command: c.command, payload: c.payload, sender: c.sender }, + } +} + +#[get("/account/device/commands?&")] +pub(crate) async fn commands( + db: &DbConn, + index: i64, + limit: Option, + auth: Authenticated<(), WithVerifiedSession>, +) -> auth::Result { + let dev = auth.context.device_id.as_ref().ok_or(auth::Error::UnknownDevice)?; + let (more, cmds) = + db.get_commands(&auth.context.uid, dev, index, limit.unwrap_or(100).clamp(0, 100)).await?; + Ok(Json(CommandsResp { + index: cmds.iter().map(|c| c.index).max().unwrap_or(0), + last: !more, + messages: cmds.into_iter().map(map_command).collect(), + })) +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub(crate) struct DestroyReq { + id: DeviceID, +} + +#[post("/account/device/destroy", data = "")] +pub(crate) async fn destroy( + db: &DbConn, + db_pool: &Db, + defer: &DeferAction, + client: &State>, + req: crate::auth::Authenticated, +) -> auth::Result { + db.delete_device(&req.context.uid, &req.body.id).await.map_err(map_error)?; + match db.get_devices(&req.context.uid).await { + Err(e) => warn!("device_disconnected push failed: {e}"), + Ok(devs) => defer.spawn_after_success("api::auth/account/device/destroy(post)", { + let (client, db) = (Arc::clone(client), db_pool.clone()); + async move { + let db = db.begin().await?; + client.device_disconnected(&db, &devs, &req.body.id).await; + db.commit().await?; + Ok(()) + } + }), + }; + Ok(EMPTY) +} + +#[derive(Debug, Deserialize)] +pub(crate) enum NotifyTarget { + #[serde(rename = "all")] + All, +} + +#[derive(Debug, Deserialize)] +pub(crate) enum NotifyEPAction { + #[serde(rename = "accountVerify")] + AccountVerify, +} + +#[derive(Debug, Deserialize)] +#[allow(non_snake_case)] +#[serde(untagged, deny_unknown_fields)] +pub(crate) enum NotifyReq { + // deny_unknown_fields and flatten don't work together + All { + #[allow(dead_code)] + to: NotifyTarget, + _endpointAction: Option, + excluded: Option>, + payload: Value, + TTL: Option, + }, + Some { + to: Vec, + _endpointAction: Option, + payload: Value, + TTL: Option, + }, +} + +#[post("/account/devices/notify", data = "")] +pub(crate) async fn notify( + db: &DbConn, + client: &State>, + req: Authenticated, +) -> auth::Result { + let (to, payload, ttl) = match req.body { + NotifyReq::All { excluded, payload, TTL: ttl, .. } => { + let excluded = excluded.unwrap_or_default(); + let mut devs = db.get_devices(&req.context.uid).await?; + devs.retain(|d| !excluded.contains(&d.device_id)); + (devs, payload, ttl) + }, + NotifyReq::Some { to, payload, TTL: ttl, .. } => { + let to = join_all(to.iter().map(|id| db.get_device(&req.context.uid, id))) + .await + .into_iter() + .collect::, _>>()?; + (to, payload, ttl) + }, + }; + client.push_any(db, &to, Duration::from_secs(ttl.unwrap_or(0).into()), payload).await; + Ok(EMPTY) +} + +#[derive(Debug, Serialize)] +#[allow(non_snake_case)] +pub(crate) struct AttachedClient { + clientId: Option, + deviceId: Option, + sessionTokenId: Option, + refreshTokenId: Option, + isCurrentSession: bool, + deviceType: Option, + name: Option, + #[serde(serialize_with = "serialize_dt_opt")] + createdTime: Option>, + // MISSING createdTimeFormatted + #[serde(serialize_with = "serialize_dt_opt")] + lastAccessTime: Option>, + // MISSING lastAccessTimeFormatted + // MISSING approximateLastAccessTime + // MISSING approximateLastAccessTimeFormatted + scope: Option, + // MISSING location + // MISSING userAgent + // MISSING os +} + +// MISSING filterIdleDevicesTimestamp +#[get("/account/attached_clients")] +pub(crate) async fn attached_clients( + db: &DbConn, + auth: Authenticated<(), WithVerifiedFxaLogin>, +) -> auth::Result> { + let clients = db.get_attached_clients(&auth.context.uid).await?; + Ok(Json( + clients + .into_iter() + .map(|dev| AttachedClient { + clientId: dev.client_id, + deviceId: dev.device_id, + refreshTokenId: dev.refresh_token_id, + isCurrentSession: dev.session_token_id.as_ref() == Some(&auth.session), + sessionTokenId: dev.session_token_id, + deviceType: dev.device_type, + name: dev.name, + createdTime: dev.created_time, + lastAccessTime: dev.last_access_time, + scope: dev.scope, + }) + .collect::>(), + )) +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +#[allow(non_snake_case)] +pub(crate) struct DestroyAttachedClientReq { + // NOTE should be used to verify token deletion, but since we allow only a fixed + // number of clients that makes little sense. + #[allow(dead_code)] + clientId: Option, + sessionTokenId: Option, + refreshTokenId: Option, + deviceId: Option, +} + +#[post("/account/attached_client/destroy", data = "")] +pub(crate) async fn destroy_attached_client( + db: &DbConn, + db_pool: &Db, + defer: &DeferAction, + client: &State>, + req: Authenticated, +) -> auth::Result { + // only one id may be given, otherwise deleting things properly is more work. + if (req.body.sessionTokenId.is_some() as u32) + + (req.body.refreshTokenId.is_some() as u32) + + (req.body.deviceId.is_some() as u32) + != 1 + { + return Err(auth::Error::InvalidParameter); + } + + if let Some(dev) = req.body.deviceId { + let devs = db.get_devices(&req.context.uid).await; + db.delete_device(&req.context.uid, &dev).await?; + match devs { + Err(e) => warn!("device_disconnected push failed: {e}"), + Ok(devs) => { + defer.spawn_after_success("api::auth/account/attached_client/destroy(post)", { + let (client, db) = (Arc::clone(client), db_pool.clone()); + async move { + let db = db.begin().await?; + client.device_disconnected(&db, &devs, &dev).await; + db.commit().await?; + Ok(()) + } + }) + }, + }; + } + if let Some(id) = req.body.sessionTokenId { + db.delete_session(&req.context.uid, &id).await?; + } + if let Some(id) = req.body.refreshTokenId { + db.delete_refresh_token(&id).await?; + } + + Ok(EMPTY) +} diff --git a/src/api/auth/email.rs b/src/api/auth/email.rs new file mode 100644 index 0000000..f206759 --- /dev/null +++ b/src/api/auth/email.rs @@ -0,0 +1,126 @@ +use std::sync::Arc; + +use rocket::{serde::json::Json, State}; +use serde::{Deserialize, Serialize}; + +use crate::{ + api::{ + auth::{self, WithFxaLogin}, + Empty, EMPTY, + }, + auth::Authenticated, + db::{Db, DbConn}, + mailer::Mailer, + push::PushClient, + types::UserID, + utils::DeferAction, +}; + +// MISSING get /recovery_emails +// MISSING post /recovery_email +// MISSING post /recovery_email/destroy +// MISSING post /recovery_email/resend_code +// MISSING post /recovery_email/set_primary +// MISSING post /emails/reminders/cad +// MISSING post /recovery_email/secondary/resend_code +// MISSING post /recovery_email/secondary/verify_code + +#[derive(Debug, Serialize)] +#[allow(non_snake_case)] +pub(crate) struct StatusResp { + email: String, + verified: bool, + sessionVerified: bool, + emailVerified: bool, +} + +// MISSING arg: reason +#[get("/recovery_email/status")] +pub(crate) async fn status( + db: &DbConn, + req: Authenticated<(), WithFxaLogin>, +) -> auth::Result { + let user = db.get_user_by_id(&req.context.uid).await?; + Ok(Json(StatusResp { + email: user.email, + verified: user.verified, + sessionVerified: req.context.verified, + emailVerified: user.verified, + })) +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub(crate) struct VerifyReq { + uid: UserID, + code: String, + // MISSING service + // MISSING reminder + // MISSING type + // MISSING style + // MISSING marketingOptIn + // MISSING newsletters +} + +#[post("/recovery_email/verify_code", data = "")] +pub(crate) async fn verify_code( + db: &DbConn, + db_pool: &Db, + defer: &DeferAction, + pc: &State>, + req: Json, +) -> auth::Result { + let code = match db.try_use_verify_code(&req.uid, &req.code).await? { + Some(code) => code, + None => return Err(auth::Error::InvalidVerificationCode), + }; + db.set_user_verified(&req.uid).await?; + if let Some(sid) = code.session_id { + db.set_session_verified(&sid).await?; + } + match db.get_devices(&req.uid).await { + Ok(devs) => defer.spawn_after_success("api::auth/recovery_email/verify_code(post)", { + let (pc, db) = (Arc::clone(pc), db_pool.clone()); + async move { + let db = db.begin().await?; + pc.account_verified(&db, &devs).await; + db.commit().await?; + Ok(()) + } + }), + Err(e) => warn!("account_verified push failed: {e}"), + } + Ok(EMPTY) +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub(crate) struct ResendReq { + // MISSING email + // MISSING service + // MISSING redirectTo + // MISSING resume + // MISSING style + // MISSING type +} + +// MISSING arg: service +// MISSING arg: type +#[post("/recovery_email/resend_code", data = "")] +pub(crate) async fn resend_code( + db: &DbConn, + mailer: &State>, + req: Authenticated, +) -> auth::Result { + let (email, code) = match db.get_verify_code(&req.context.uid).await { + Ok(v) => v, + Err(_) => return Err(auth::Error::InvalidVerificationCode), + }; + // NOTE we send the email in this context rather than a spawn to signal + // send errors to the client. + mailer.send_account_verify(&req.context.uid, &email, &code.code).await.map_err(|e| { + error!("failed to send email: {e}"); + auth::Error::EmailFailed + })?; + Ok(EMPTY) +} diff --git a/src/api/auth/invite.rs b/src/api/auth/invite.rs new file mode 100644 index 0000000..dd81540 --- /dev/null +++ b/src/api/auth/invite.rs @@ -0,0 +1,47 @@ +use base64::URL_SAFE_NO_PAD; +use chrono::{Duration, Utc}; +use rocket::{http::uri::Reference, serde::json::Json, State}; +use serde::{Deserialize, Serialize}; + +use crate::{api::auth, auth::Authenticated, crypto::SecretBytes, db::DbConn, Config}; + +use super::WithVerifiedFxaLogin; + +pub(crate) async fn generate_invite_link( + db: &DbConn, + cfg: &Config, + ttl: Duration, +) -> anyhow::Result> { + let code = base64::encode_config(&SecretBytes::<32>::generate().0, URL_SAFE_NO_PAD); + db.add_invite_code(&code, Utc::now() + ttl).await?; + Ok(Reference::parse_owned(format!("{}/#/register/{}", cfg.location, code)) + .map_err(|e| anyhow!("url building failed at {e}"))?) +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub(crate) struct GenerateReq { + ttl_hours: u32, +} + +#[derive(Debug, Serialize)] +pub(crate) struct GenerateResp { + url: Reference<'static>, +} + +#[post("/generate", data = "")] +pub(crate) async fn generate( + db: &DbConn, + cfg: &State, + req: Authenticated, +) -> auth::Result { + if !req.context.verified { + return Err(auth::Error::UnverifiedSession); + } + let user = db.get_user_by_id(&req.context.uid).await?; + if user.email != cfg.invite_admin_address { + return Err(auth::Error::InvalidAuthToken); + } + let url = generate_invite_link(&db, &cfg, Duration::hours(req.body.ttl_hours as i64)).await?; + Ok(Json(GenerateResp { url })) +} diff --git a/src/api/auth/mod.rs b/src/api/auth/mod.rs new file mode 100644 index 0000000..2c6d34d --- /dev/null +++ b/src/api/auth/mod.rs @@ -0,0 +1,238 @@ +use rocket::{ + http::Status, + response::{self, Responder}, + serde::json::Json, + Request, Response, +}; +use serde_json::json; + +use crate::{ + auth::Authenticated, + crypto::SecretBytes, + types::{OauthToken, SessionID, UserSession}, +}; + +pub(crate) mod account; +pub(crate) mod device; +pub(crate) mod email; +pub(crate) mod invite; +pub(crate) mod oauth; +pub(crate) mod password; +pub(crate) mod session; + +// we don't provide any additional fields. some we can't provide anyway (eg +// invalid parameter `validation`), others are implied by the request body (eg +// account exists `email`), and *our* client doesn't care about them anyway +#[derive(Debug)] +pub(crate) enum Error { + AccountExists, + UnknownAccount, + IncorrectPassword, + UnverifiedAccount, + InvalidVerificationCode, + InvalidBody, + InvalidParameter, + MissingParameter, + InvalidSignature, + InvalidAuthToken, + RequestTooLarge, + IncorrectEmailCase, + UnknownDevice, + UnverifiedSession, + EmailFailed, + NoDeviceCommand, + UnknownClientID, + ScopesNotAllowed, + + InviteOnly, + InviteNotFound, + + Other(anyhow::Error), + UnexpectedStatus(Status), +} + +#[rustfmt::skip] +impl<'r> Responder<'r, 'static> for Error { + fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> { + let (code, errno, msg) = match self { + Error::AccountExists => (Status::BadRequest, 101, "account already exists"), + Error::UnknownAccount => (Status::BadRequest, 102, "unknown account"), + Error::IncorrectPassword => (Status::BadRequest, 103, "incorrect password"), + Error::UnverifiedAccount => (Status::BadRequest, 104, "unverified account"), + Error::InvalidVerificationCode => (Status::BadRequest, 105, "invalid verification code"), + Error::InvalidBody => (Status::BadRequest, 106, "invalid json in request body"), + Error::InvalidParameter => (Status::BadRequest, 107, "invalid parameter in request body"), + Error::MissingParameter => (Status::BadRequest, 108, "missing parameter in request body"), + Error::InvalidSignature => (Status::Unauthorized, 109, "invalid request signature"), + Error::InvalidAuthToken => (Status::Unauthorized, 110, "invalid authentication token"), + Error::RequestTooLarge => (Status::PayloadTooLarge, 113, "request too large"), + Error::IncorrectEmailCase => (Status::BadRequest, 120, "incorrect email case"), + Error::UnknownDevice => (Status::BadRequest, 123, "unknown device"), + Error::UnverifiedSession => (Status::BadRequest, 138, "unverified session"), + Error::EmailFailed => (Status::UnprocessableEntity, 151, "failed to send email"), + Error::NoDeviceCommand => (Status::BadRequest, 157, "unavailable device command"), + Error::UnknownClientID => (Status::BadRequest, 162, "unknown client_id"), + Error::ScopesNotAllowed => (Status::BadRequest, 169, "requested scopes not allowed"), + Error::InviteOnly => (Status::BadRequest, -1, "invite code required"), + Error::InviteNotFound => (Status::BadRequest, -2, "invite code not found"), + Error::Other(e) => { + error!("non-api error during request: {:#?}", e); + (Status::InternalServerError, 999, "internal error") + }, + Error::UnexpectedStatus(s) => (s, 999, ""), + }; + let body = json!({ + "code": code.code, + "errno": errno, + "error": code.reason_lossy(), + "message": msg + }); + Response::build_from(Json(body).respond_to(request)?).status(code).ok() + } +} + +impl From for Error { + fn from(e: sqlx::Error) -> Self { + Error::Other(anyhow!(e)) + } +} + +impl From for Error { + fn from(e: anyhow::Error) -> Self { + Error::Other(e) + } +} + +pub(crate) type Result = std::result::Result, Error>; + +// hack marker type to convey that auth failed due to an unverified session. +// without this the catcher could convert the Unauthorized error we get from +// auth failures into just one thing, even though we have multiple causes. +#[derive(Clone, Copy, Debug)] +struct UsedUnverifiedSession; + +#[catch(default)] +pub(crate) fn catch_all(status: Status, req: &Request<'_>) -> Error { + match req.local_cache(|| None) { + Some(UsedUnverifiedSession) => Error::UnverifiedSession, + _ => { + match status.code { + 401 => Error::InvalidSignature, + // these three are caused by Json errors + 400 => Error::InvalidBody, + 413 => Error::RequestTooLarge, + 422 => Error::InvalidParameter, + // generic unauthorized instead of 404 for eg wrong method or nonexistant endpoints + 404 => Error::InvalidSignature, + _ => { + error!("caught unexpected error {status}"); + Error::UnexpectedStatus(status) + }, + } + }, + } +} + +#[derive(Debug)] +pub(crate) struct WithFxaLogin; + +#[async_trait] +impl crate::auth::AuthSource for WithFxaLogin { + type ID = SessionID; + type Context = UserSession; + async fn hawk( + r: &Request<'_>, + id: &SessionID, + ) -> anyhow::Result<(SecretBytes<32>, Self::Context)> { + let db = Authenticated::<(), Self>::get_conn(r).await?; + let k = db.use_session(id).await?; + Ok((k.req_hmac_key.0.clone(), k)) + } + async fn bearer_token( + _: &Request<'_>, + _: &OauthToken, + ) -> anyhow::Result<(SessionID, Self::Context)> { + bail!("refresh tokens not allowed here"); + } +} + +#[derive(Debug)] +pub(crate) struct WithVerifiedFxaLogin; + +#[async_trait] +impl crate::auth::AuthSource for WithVerifiedFxaLogin { + type ID = SessionID; + type Context = UserSession; + async fn hawk( + r: &Request<'_>, + id: &SessionID, + ) -> anyhow::Result<(SecretBytes<32>, Self::Context)> { + let res = WithFxaLogin::hawk(r, id).await?; + match res.1.verified { + true => Ok(res), + false => { + r.local_cache(|| Some(UsedUnverifiedSession)); + bail!("session not verified"); + }, + } + } + async fn bearer_token( + _: &Request<'_>, + _: &OauthToken, + ) -> anyhow::Result<(SessionID, Self::Context)> { + bail!("refresh tokens not allowed here"); + } +} + +#[derive(Debug)] +pub(crate) struct WithSession; + +#[rocket::async_trait] +impl crate::auth::AuthSource for WithSession { + type ID = SessionID; + type Context = UserSession; + async fn hawk( + r: &Request<'_>, + id: &SessionID, + ) -> anyhow::Result<(SecretBytes<32>, Self::Context)> { + WithFxaLogin::hawk(r, id).await + } + async fn bearer_token( + r: &Request<'_>, + token: &OauthToken, + ) -> anyhow::Result<(SessionID, Self::Context)> { + let db = Authenticated::<(), Self>::get_conn(r).await?; + Ok(db.use_session_from_refresh(&token.hash()).await?) + } +} + +#[derive(Debug)] +pub(crate) struct WithVerifiedSession; + +#[rocket::async_trait] +impl crate::auth::AuthSource for WithVerifiedSession { + type ID = SessionID; + type Context = UserSession; + async fn hawk( + r: &Request<'_>, + id: &SessionID, + ) -> anyhow::Result<(SecretBytes<32>, Self::Context)> { + WithVerifiedFxaLogin::hawk(r, id).await + } + async fn bearer_token( + r: &Request<'_>, + token: &OauthToken, + ) -> anyhow::Result<(SessionID, Self::Context)> { + let db = Authenticated::<(), Self>::get_conn(r).await?; + let res = db.use_session_from_refresh(&token.hash()).await?; + match res.1.verified { + true => Ok(res), + false => { + // technically unreachable because generating a refresh token requires a + // valid fxa session + r.local_cache(|| Some(UsedUnverifiedSession)); + bail!("session not verified"); + }, + } + } +} diff --git a/src/api/auth/oauth.rs b/src/api/auth/oauth.rs new file mode 100644 index 0000000..b0ed8ee --- /dev/null +++ b/src/api/auth/oauth.rs @@ -0,0 +1,433 @@ +use std::collections::HashMap; + +use chrono::{DateTime, Duration, Local, Utc}; +use rocket::serde::json::Json; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use sha2::Digest; +use subtle::ConstantTimeEq; + +use crate::api::auth::WithVerifiedFxaLogin; +use crate::db::DbConn; +use crate::types::oauth::{Scope, ScopeSet}; +use crate::{ + api::{auth, serialize_dt}, + auth::Authenticated, + crypto::{SecretBytes, SessionCredentials}, + types::{ + HawkKey, OauthAccessToken, OauthAccessType, OauthAuthorization, OauthAuthorizationID, + OauthRefreshToken, OauthToken, OauthTokenID, SessionID, UserID, + }, +}; + +// MISSING get /oauth/client/{client_id} + +pub(crate) struct OauthClient { + pub(crate) id: &'static str, + // NOTE not read so far, but good to have + #[allow(dead_code)] + pub(crate) name: &'static str, + pub(crate) scopes: &'static [Scope<'static>], +} + +const SESSION_SCOPE: Scope = Scope::borrowed("https://identity.mozilla.com/tokens/session"); + +// NOTE the telemetry scopes don't seem to be needed. since we'd have to give +// out keys for them (fxa does) we'll exclude them entirely. +// see fxa-auth-server/config/dev.json for lists of predefined clients and permissions. +pub(crate) const OAUTH_CLIENTS: [OauthClient; 2] = [ + OauthClient { + id: "5882386c6d801776", + name: "Firefox", + scopes: &[ + Scope::borrowed("profile:write"), + Scope::borrowed("https://identity.mozilla.com/apps/oldsync"), + Scope::borrowed("https://identity.mozilla.com/tokens/session"), + // "https://identity.mozilla.com/ids/ecosystem_telemetry", + ], + }, + OauthClient { + id: "a2270f727f45f648", + name: "Fenix", + scopes: &[ + Scope::borrowed("profile"), + Scope::borrowed("https://identity.mozilla.com/apps/oldsync"), + Scope::borrowed("https://identity.mozilla.com/tokens/session"), + // "https://identity.mozilla.com/ids/ecosystem_telemetry", + ], + }, +]; + +// NOTE fxa dev config allows scoped keys only for: +// - https://identity.mozilla.com/apps/notes +// - https://identity.mozilla.com/apps/oldsync +// - https://identity.mozilla.com/ids/ecosystem_telemetry +// - https://identity.mozilla.com/apps/send +// we only implement sync because notes and send are dead and +// telemetry is of no use to us +const SCOPES_WITH_KEYS: [Scope; 1] = [Scope::borrowed("https://identity.mozilla.com/apps/oldsync")]; + +fn check_client_and_scopes( + client_id: &str, + scope: &ScopeSet, +) -> Result<&'static OauthClient, auth::Error> { + let desc = match OAUTH_CLIENTS.iter().find(|&s| s.id == client_id) { + Some(d) => d, + None => return Err(auth::Error::UnknownClientID), + }; + if !scope.is_allowed_by(desc.scopes) { + return Err(auth::Error::ScopesNotAllowed); + } + Ok(desc) +} + +#[derive(Debug, Deserialize)] +pub(crate) enum PkceChallengeType { + S256, +} + +#[derive(Debug, Deserialize)] +pub(crate) enum AuthResponseType { + #[serde(rename = "code")] + Code, +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub(crate) struct OauthAuthReq { + client_id: String, + state: String, + keys_jwe: Option, + scope: ScopeSet, + access_type: OauthAccessType, + // NOTE we don't support confidential clients, so PKCE is mandatory + code_challenge: String, + + // MISSING redirect_uri + // MISSING acr_value + + // for validation during deserialization only + #[allow(dead_code)] + code_challenge_method: PkceChallengeType, + #[allow(dead_code)] + response_type: AuthResponseType, +} + +#[derive(Debug, Serialize)] +pub(crate) struct OauthAuthResp { + code: OauthAuthorizationID, + state: String, + // MISSING redirect +} + +#[post("/oauth/authorization", data = "")] +pub(crate) async fn authorization( + db: &DbConn, + req: Authenticated, +) -> auth::Result { + check_client_and_scopes(&req.body.client_id, &req.body.scope)?; + let id = OauthAuthorizationID::random(); + db.add_oauth_authorization( + &id, + OauthAuthorization { + user_id: req.context.uid, + client_id: req.body.client_id, + scope: req.body.scope, + access_type: req.body.access_type, + code_challenge: req.body.code_challenge, + keys_jwe: req.body.keys_jwe, + auth_at: req.context.created_at, + }, + ) + .await?; + Ok(Json(OauthAuthResp { code: id, state: req.body.state })) +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub(crate) struct ScopedKeysReq { + client_id: String, + scope: ScopeSet, +} + +#[derive(Debug, Serialize)] +#[allow(non_snake_case)] +pub(crate) struct ScopedKey { + identifier: String, + keyRotationSecret: &'static str, + keyRotationTimestamp: u64, +} + +#[post("/account/scoped-key-data", data = "")] +pub(crate) async fn scoped_key_data( + data: Authenticated, +) -> auth::Result> { + check_client_and_scopes(&data.body.client_id, &data.body.scope)?; + // like fxa we'll stub out key rotation handling entirely and return the same constants. + Ok(Json( + data.body + .scope + .split() + .filter(|s| SCOPES_WITH_KEYS.contains(s)) + .map(|scope| { + ( + scope.to_string(), + ScopedKey { + identifier: scope.to_string(), + keyRotationSecret: + "0000000000000000000000000000000000000000000000000000000000000000", + keyRotationTimestamp: 0, + }, + ) + }) + .collect(), + )) +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub(crate) struct OauthDestroy { + client_id: String, + token: OauthToken, +} + +#[post("/oauth/destroy", data = "")] +pub(crate) async fn destroy(db: &DbConn, data: Json) -> auth::Result<()> { + // MISSING api spec allows an optional basic auth header, but what for? + // TODO fxa also checks the authorization header if present, but firefox doesn't send it + let client_id = if let Ok(t) = db.get_refresh_token(&data.token.hash()).await { + t.client_id + } else if let Ok(t) = db.get_access_token(&data.token.hash()).await { + t.client_id + } else { + return Err(auth::Error::InvalidParameter); + }; + // fxa does constant-time checks for client_id, do that here too. + if client_id.as_bytes().ct_eq(data.client_id.as_bytes()).into() { + db.delete_oauth_token(&data.token.hash()).await?; + Ok(Json(())) + } else { + Err(auth::Error::InvalidParameter) + } +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "grant_type")] +enum TokenReqDetails { + // we can't use deny_unknown_fields when flatten is involved, and multiple + // flattens in the same struct cause problems if one of them is greedy (like map). + // flatten an extra map into every variant instead and check each of them. + #[serde(rename = "authorization_code")] + AuthCode { + code: OauthAuthorizationID, + code_verifier: String, + // NOTE only useful with redirect flows, which we kinda don't support at all + #[allow(dead_code)] + redirect_uri: Option, + #[serde(flatten)] + extra: HashMap, + }, + #[serde(rename = "refresh_token")] + RefreshToken { + refresh_token: OauthToken, + scope: ScopeSet, + #[serde(flatten)] + extra: HashMap, + }, + #[serde(rename = "fxa-credentials")] + FxaCreds { + scope: ScopeSet, + access_type: Option, + #[serde(flatten)] + extra: HashMap, + }, +} + +impl TokenReqDetails { + fn extra_is_empty(&self) -> bool { + match self { + TokenReqDetails::AuthCode { extra, .. } => extra.is_empty(), + TokenReqDetails::RefreshToken { extra, .. } => extra.is_empty(), + TokenReqDetails::FxaCreds { extra, .. } => extra.is_empty(), + } + } +} + +// TODO log errors in all the places + +#[derive(Debug, Deserialize)] +pub(crate) struct TokenReq { + client_id: String, + ttl: Option, + #[serde(flatten)] + details: TokenReqDetails, + // MISSING client_secret + // MISSING redirect_uri + // MISSING ttl + // MISSING ppid_seed + // MISSING resource +} + +#[derive(Debug, Serialize)] +pub(crate) enum TokenType { + #[serde(rename = "bearer")] + Bearer, +} + +#[derive(Debug, Serialize)] +pub(crate) struct TokenResp { + access_token: OauthToken, + #[serde(skip_serializing_if = "Option::is_none")] + refresh_token: Option, + // MISSING id_token + #[serde(skip_serializing_if = "Option::is_none")] + session_token: Option, + scope: ScopeSet, + token_type: TokenType, + expires_in: u32, + #[serde(serialize_with = "serialize_dt")] + auth_at: DateTime, + #[serde(skip_serializing_if = "Option::is_none")] + keys_jwe: Option, +} + +#[post("/oauth/token", data = "", rank = 1)] +pub(crate) async fn token_authenticated( + db: &DbConn, + req: Authenticated, +) -> auth::Result { + match &req.body.details { + TokenReqDetails::FxaCreds { .. } => (), + _ => return Err(auth::Error::InvalidParameter), + } + token_impl( + db, + Some(req.context.uid), + Some(req.context.created_at), + req.body, + None, + Some(req.session.clone()), + ) + .await +} + +#[post("/oauth/token", data = "", rank = 2)] +pub(crate) async fn token_unauthenticated( + db: &DbConn, + req: Json, +) -> auth::Result { + let (parent_refresh, auth_at) = match &req.details { + TokenReqDetails::RefreshToken { refresh_token, .. } => { + let session = db.use_session_from_refresh(&refresh_token.hash()).await?; + (Some(refresh_token.hash()), Some(session.1.created_at)) + }, + TokenReqDetails::AuthCode { .. } => (None, None), + _ => return Err(auth::Error::InvalidParameter), + }; + token_impl(db, None, auth_at, req.into_inner(), parent_refresh, None).await +} + +async fn token_impl( + db: &DbConn, + user_id: Option, + auth_at: Option>, + req: TokenReq, + parent_refresh: Option, + parent_session: Option, +) -> auth::Result { + if !req.details.extra_is_empty() { + return Err(auth::Error::InvalidParameter); + } + let ttl = req.ttl.unwrap_or(3600).clamp(0, 7 * 86400); + + let (auth_at, scope, keys_jwe, user_id, access_type) = match req.details { + TokenReqDetails::AuthCode { code, code_verifier, .. } => { + let auth = match db.take_oauth_authorization(&code).await { + Ok(a) => a, + Err(_) => return Err(auth::Error::InvalidAuthToken), + }; + if !bool::from(auth.client_id.as_bytes().ct_eq(req.client_id.as_bytes())) { + return Err(auth::Error::UnknownClientID); + } + let mut sha = sha2::Sha256::new(); + sha.update(code_verifier.as_bytes()); + let challenge = base64::encode_config(&sha.finalize(), base64::URL_SAFE_NO_PAD); + if !bool::from(challenge.as_bytes().ct_eq(auth.code_challenge.as_bytes())) { + return Err(auth::Error::InvalidParameter); + } + (auth.auth_at, auth.scope, auth.keys_jwe, auth.user_id, Some(auth.access_type)) + }, + TokenReqDetails::RefreshToken { refresh_token, scope, .. } => { + let auth_at = + auth_at.expect("oauth token requests with refresh token must set auth_at"); + let base = db.get_refresh_token(&refresh_token.hash()).await?; + if !bool::from(base.client_id.as_bytes().ct_eq(req.client_id.as_bytes())) { + return Err(auth::Error::UnknownClientID); + } + check_client_and_scopes(&req.client_id, &scope)?; + if !base.scope.implies_all(&scope) { + return Err(auth::Error::ScopesNotAllowed); + } + (auth_at, scope, None, base.user_id, None) + }, + TokenReqDetails::FxaCreds { scope, access_type, .. } => { + let user_id = user_id.expect("oauth token requests with fxa must set user_id"); + let auth_at = auth_at.expect("oauth token requests with fxa must set auth_at"); + check_client_and_scopes(&req.client_id, &scope)?; + (auth_at, scope, None, user_id, access_type) + }, + }; + + let access_token = OauthToken::random(); + db.add_access_token( + &access_token.hash(), + OauthAccessToken { + user_id: user_id.clone(), + client_id: req.client_id.clone(), + scope: scope.clone(), + parent_refresh, + parent_session, + expires_at: (Local::now() + Duration::seconds(ttl.into())).into(), + }, + ) + .await?; + + let (refresh_token, session_token) = if access_type == Some(OauthAccessType::Offline) { + let (session_token, session_id) = if scope.implies(&SESSION_SCOPE) { + let session_token = SecretBytes::generate(); + let session = SessionCredentials::derive(&session_token); + let session_id = SessionID(session.token_id.0); + db.add_session(session_id.clone(), &user_id, HawkKey(session.req_hmac_key), true, None) + .await?; + (Some(session_token.0), Some(SessionID(session.token_id.0))) + } else { + (None, None) + }; + + let refresh_token = OauthToken::random(); + db.add_refresh_token( + &refresh_token.hash(), + OauthRefreshToken { + user_id, + client_id: req.client_id, + scope: scope.remove(&SESSION_SCOPE), + session_id, + }, + ) + .await?; + (Some(refresh_token), session_token) + } else { + (None, None) + }; + + Ok(Json(TokenResp { + access_token, + refresh_token, + session_token: session_token.map(hex::encode), + scope: scope.remove(&SESSION_SCOPE), + token_type: TokenType::Bearer, + expires_in: ttl, + auth_at, + keys_jwe, + })) +} diff --git a/src/api/auth/password.rs b/src/api/auth/password.rs new file mode 100644 index 0000000..0eeab4f --- /dev/null +++ b/src/api/auth/password.rs @@ -0,0 +1,260 @@ +use std::sync::Arc; + +use anyhow::Result; +use password_hash::SaltString; +use rocket::{request::FromRequest, serde::json::Json, Request, State}; +use serde::{Deserialize, Serialize}; +use validator::Validate; + +use crate::{ + api::auth, + auth::{AuthSource, Authenticated}, + crypto::{AccountResetReq, AuthPW, KeyBundle, KeyFetchReq, PasswordChangeReq, SecretBytes}, + db::{Db, DbConn}, + mailer::Mailer, + types::{ + AccountResetID, HawkKey, KeyFetchID, OauthToken, PasswordChangeID, SecretKey, UserID, + VerifyHash, + }, +}; + +// MISSING get /password/forgot/status +// MISSING post /password/create +// MISSING post /password/forgot/resend_code + +#[derive(Debug, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +#[allow(non_snake_case)] +pub(crate) struct ChangeStartReq { + #[validate(email, length(min = 3, max = 256))] + email: String, + oldAuthPW: AuthPW, +} + +#[derive(Debug, Serialize)] +#[allow(non_snake_case)] +pub(crate) struct ChangeStartResp { + keyFetchToken: SecretBytes<32>, + passwordChangeToken: SecretBytes<32>, +} + +#[post("/password/change/start", data = "")] +pub(crate) async fn change_start( + db: &DbConn, + data: Json, +) -> auth::Result { + let data = data.into_inner(); + data.validate().map_err(|_| auth::Error::InvalidParameter)?; + + let (uid, user) = db.get_user(&data.email).await.map_err(|_| auth::Error::UnknownAccount)?; + if user.email != data.email { + return Err(auth::Error::IncorrectEmailCase); + } + if !user.verified { + return Err(auth::Error::UnverifiedAccount); + } + + let stretched = data.oldAuthPW.stretch(user.auth_salt.as_salt())?; + if stretched.verify_hash() != user.verify_hash.0 { + return Err(auth::Error::IncorrectPassword); + } + + let change_token = SecretBytes::generate(); + let change_req = PasswordChangeReq::from_change_token(&change_token); + let key_fetch_token = SecretBytes::generate(); + let key_req = KeyFetchReq::from_token(&key_fetch_token); + let wrapped = key_req.derive_resp().wrap_keys(&KeyBundle { + ka: user.ka.0.clone(), + wrap_kb: stretched.decrypt_wwkb(&user.wrapwrap_kb.0), + }); + db.add_key_fetch(KeyFetchID(key_req.token_id.0), &HawkKey(key_req.req_hmac_key), &wrapped) + .await?; + db.add_password_change( + &uid, + &PasswordChangeID(change_req.token_id.0), + &HawkKey(change_req.req_hmac_key), + None, + ) + .await?; + + Ok(Json(ChangeStartResp { keyFetchToken: key_fetch_token, passwordChangeToken: change_token })) +} + +// NOTE we use a plain bool here and in the db instead of an enum because +// enums aren't usable in const generics in stable. +#[derive(Debug)] +pub(crate) struct WithChangeToken; + +#[async_trait] +impl AuthSource for WithChangeToken { + type ID = PasswordChangeID; + type Context = (UserID, Option); + async fn hawk( + r: &Request<'_>, + id: &PasswordChangeID, + ) -> Result<(SecretBytes<32>, Self::Context)> { + // unlike key fetch we'll use a separate transaction here since the body of the + // handler can fail. + let pool = <&Db as FromRequest>::from_request(r) + .await + .success_or_else(|| anyhow!("could not open db connection"))?; + let db = pool.begin().await?; + let result = db.finish_password_change(id, IS_FORGOT).await.map(|(h, ctx)| (h.0, ctx))?; + db.commit().await?; + Ok(result) + } + async fn bearer_token( + _: &Request<'_>, + _: &OauthToken, + ) -> Result<(PasswordChangeID, Self::Context)> { + bail!("invalid password change authentication") + } +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +#[allow(non_snake_case)] +pub(crate) struct ChangeFinishReq { + authPW: AuthPW, + wrapKb: SecretBytes<32>, + // MISSING sessionToken +} + +#[derive(Debug, Serialize)] +#[allow(non_snake_case)] +pub(crate) struct ChangeFinishResp { + // NOTE we intentionally deviate from mozilla here. mozilla creates a new + // session if sessionToken is set in the request, but we use the "legacy" + // password change mechanism that leaves the requesting session and its + // device and keys intact. as such this struct is intentionally empty. + // + // MISSING uid + // MISSING sessionToken + // MISSING verified + // MISSING authAt + // MISSING keyFetchToken +} + +#[post("/password/change/finish", data = "")] +pub(crate) async fn change_finish( + db: &DbConn, + mailer: &State>, + data: Authenticated>, +) -> auth::Result { + let user = db.get_user_by_id(&data.context.0).await?; + + let auth_salt = SaltString::generate(rand::rngs::OsRng); + let stretched = data.body.authPW.stretch(auth_salt.as_salt())?; + let verify_hash = stretched.verify_hash(); + let wrapwrap_kb = stretched.rewrap_wkb(&data.body.wrapKb); + + db.change_user_auth( + &data.context.0, + auth_salt, + SecretKey(wrapwrap_kb), + VerifyHash(verify_hash), + ) + .await?; + + // NOTE password_changed/password_reset pushes seem to have no effect, so skip them. + + mailer + .send_password_changed(&user.email) + .await + .map_err(|e| { + warn!("password change email send failed: {e}"); + }) + .ok(); + + Ok(Json(ChangeFinishResp {})) +} + +#[derive(Debug, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +#[allow(non_snake_case)] +pub(crate) struct ForgotStartReq { + #[validate(email, length(min = 3, max = 256))] + email: String, +} + +#[derive(Debug, Serialize)] +#[allow(non_snake_case)] +pub(crate) struct ForgotStartResp { + passwordForgotToken: SecretBytes<32>, + ttl: u32, + codeLength: u32, + tries: u32, +} + +#[post("/password/forgot/send_code", data = "")] +pub(crate) async fn forgot_start( + db: &DbConn, + mailer: &State>, + data: Json, +) -> auth::Result { + let data = data.into_inner(); + data.validate().map_err(|_| auth::Error::InvalidParameter)?; + + let (uid, user) = db.get_user(&data.email).await.map_err(|_| auth::Error::UnknownAccount)?; + if user.email != data.email { + return Err(auth::Error::IncorrectEmailCase); + } + if !user.verified { + return Err(auth::Error::UnverifiedAccount); + } + + let forgot_code = hex::encode(SecretBytes::<16>::generate().0); + let forgot_token = SecretBytes::generate(); + let forgot_req = PasswordChangeReq::from_forgot_token(&forgot_token); + db.add_password_change( + &uid, + &PasswordChangeID(forgot_req.token_id.0), + &HawkKey(forgot_req.req_hmac_key), + Some(&forgot_code), + ) + .await?; + + mailer.send_password_forgot(&user.email, &forgot_code).await?; + + Ok(Json(ForgotStartResp { + passwordForgotToken: forgot_token, + ttl: 300, + codeLength: 16, + tries: 1, + })) +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +#[allow(non_snake_case)] +pub(crate) struct ForgotFinishReq { + code: String, + // MISSING accountResetWithRecoveryKey +} + +#[derive(Debug, Serialize)] +#[allow(non_snake_case)] +pub(crate) struct ForgotFinishResp { + accountResetToken: SecretBytes<32>, +} + +#[post("/password/forgot/verify_code", data = "")] +pub(crate) async fn forgot_finish( + db: &DbConn, + data: Authenticated>, +) -> auth::Result { + if Some(data.body.code) != data.context.1 { + return Err(auth::Error::InvalidVerificationCode); + } + + let reset_token = SecretBytes::generate(); + let reset_req = AccountResetReq::from_token(&reset_token); + db.add_account_reset( + &data.context.0, + &AccountResetID(reset_req.token_id.0), + &HawkKey(reset_req.req_hmac_key), + ) + .await?; + + Ok(Json(ForgotFinishResp { accountResetToken: reset_token })) +} diff --git a/src/api/auth/session.rs b/src/api/auth/session.rs new file mode 100644 index 0000000..5911b92 --- /dev/null +++ b/src/api/auth/session.rs @@ -0,0 +1,107 @@ +use std::sync::Arc; + +use rocket::serde::json::Json; +use rocket::State; +use serde::{Deserialize, Serialize}; + +use crate::api::auth::WithFxaLogin; +use crate::api::{auth, Empty, EMPTY}; +use crate::auth::Authenticated; +use crate::db::Db; +use crate::db::DbConn; +use crate::mailer::Mailer; +use crate::push::PushClient; +use crate::types::{SessionID, UserID}; +use crate::utils::DeferAction; + +// MISSING post /session/duplicate +// MISSING post /session/reauth +// MISSING post /session/verify/send_push + +#[derive(Debug, Serialize)] +pub(crate) struct StatusResp { + state: &'static str, // what does this *do*? + uid: UserID, +} + +#[get("/session/status")] +pub(crate) async fn status(req: Authenticated<(), WithFxaLogin>) -> auth::Result { + Ok(Json(StatusResp { state: "", uid: req.context.uid })) +} + +#[post("/session/resend_code", data = "")] +pub(crate) async fn resend_code( + db: &DbConn, + mailer: &State>, + req: Authenticated, +) -> auth::Result { + let code = match req.context.verify_code { + Some(code) => code, + _ => return Err(auth::Error::InvalidVerificationCode), + }; + + let user = db.get_user_by_id(&req.context.uid).await?; + mailer.send_session_verify(&user.email, &code).await.map_err(|e| { + error!("failed to send email: {e}"); + auth::Error::EmailFailed + })?; + Ok(EMPTY) +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub(crate) struct VerifyReq { + code: String, + // MISSING service + // MISSING scopes + // MISSING marketingOptIn + // MISSING newsletters +} + +#[post("/session/verify_code", data = "")] +pub(crate) async fn verify_code( + db: &DbConn, + req: Authenticated, +) -> auth::Result { + if req.context.verify_code.as_ref() != Some(&req.body.code) { + return Err(auth::Error::InvalidVerificationCode); + } + db.set_session_verified(&req.session).await?; + Ok(EMPTY) +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub(crate) struct DestroyReq { + custom_session_id: Option, +} + +#[post("/session/destroy", data = "")] +pub(crate) async fn destroy( + db: &DbConn, + db_pool: &Db, + defer: &DeferAction, + client: &State>, + data: Authenticated, +) -> auth::Result { + if data.body.custom_session_id.is_some() && !data.context.verified { + return Err(auth::Error::UnverifiedSession); + } + let id = data.body.custom_session_id.as_ref().unwrap_or(&data.session); + db.delete_session(&data.context.uid, id).await.map_err(|_| auth::Error::UnknownDevice)?; + if let Some(id) = data.context.device_id { + match db.get_devices(&data.context.uid).await { + Err(e) => warn!("device_disconnected push failed: {e}"), + Ok(devs) => defer.spawn_after_success("api::auth/session/destroy(post)", { + let (client, db) = (Arc::clone(client), db_pool.clone()); + async move { + let db = db.begin().await?; + client.device_disconnected(&db, &devs, &id).await; + db.commit().await?; + Ok(()) + } + }), + }; + } + Ok(EMPTY) +} diff --git a/src/api/mod.rs b/src/api/mod.rs new file mode 100644 index 0000000..1831659 --- /dev/null +++ b/src/api/mod.rs @@ -0,0 +1,32 @@ +use chrono::{DateTime, TimeZone}; +use rocket::serde::json::Json; +use serde::{Deserialize, Serialize, Serializer}; + +pub(crate) mod auth; +pub(crate) mod oauth; +pub(crate) mod profile; + +pub fn serialize_dt(dt: &DateTime, ser: S) -> Result +where + S: Serializer, + TZ: TimeZone, +{ + ser.serialize_i64(dt.timestamp()) +} + +pub fn serialize_dt_opt(dt: &Option>, ser: S) -> Result +where + S: Serializer, + TZ: TimeZone, +{ + match dt { + Some(dt) => serialize_dt(dt, ser), + None => ser.serialize_unit(), + } +} + +#[derive(Clone, Copy, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct Empty {} + +pub const EMPTY: Json = Json(Empty {}); diff --git a/src/api/oauth.rs b/src/api/oauth.rs new file mode 100644 index 0000000..0519125 --- /dev/null +++ b/src/api/oauth.rs @@ -0,0 +1,163 @@ +use rocket::{ + http::Status, + response::{self, Responder}, + serde::json::Json, + Request, Response, +}; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +use crate::{ + api::Empty, + types::{OauthToken, UserID}, +}; +use crate::{db::DbConn, types::oauth::Scope}; + +use super::EMPTY; + +// we don't provide any additional fields. some we can't provide anyway (eg +// invalid parameter `validation`), others are implied by the request body (eg +// account exists `email`), and *our* client doesn't care about them anyway +#[derive(Debug)] +pub(crate) enum Error { + InvalidParameter, + Unauthorized, + PayloadTooLarge, + + Other(anyhow::Error), + UnexpectedStatus(Status), +} + +#[rustfmt::skip] +impl<'r> Responder<'r, 'static> for Error { + fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> { + let (code, errno, msg) = match self { + Error::InvalidParameter => (Status::BadRequest, 109, "invalid request parameter"), + Error::Unauthorized => (Status::Forbidden, 111, "unauthorized"), + Error::PayloadTooLarge => (Status::PayloadTooLarge, 999, "payload too large"), + Error::Other(e) => { + error!("non-api error during request: {:?}", e); + (Status::InternalServerError, 999, "internal error") + }, + Error::UnexpectedStatus(s) => (s, 999, ""), + }; + let body = json!({ + "code": code.code, + "errno": errno, + "error": code.reason_lossy(), + "message": msg + }); + Response::build_from(Json(body).respond_to(request)?).status(code).ok() + } +} + +impl From for Error { + fn from(e: sqlx::Error) -> Self { + Error::Other(anyhow!(e)) + } +} + +impl From for Error { + fn from(e: anyhow::Error) -> Self { + Error::Other(e) + } +} + +pub(crate) type Result = std::result::Result, Error>; + +#[catch(default)] +pub(crate) fn catch_all(status: Status, _r: &Request<'_>) -> Error { + match status.code { + 401 => Error::Unauthorized, + // these three are caused by Json errors + 400 => Error::InvalidParameter, + 413 => Error::PayloadTooLarge, + 422 => Error::InvalidParameter, + // generic unauthorized instead of 404 for eg wrong method or nonexistant endpoints + 404 => Error::Unauthorized, + _ => { + error!("caught unexpected error {status}"); + Error::UnexpectedStatus(status) + }, + } +} + +fn map_error(e: sqlx::Error) -> Error { + match &e { + sqlx::Error::RowNotFound => Error::InvalidParameter, + _ => Error::Other(anyhow!(e)), + } +} + +// MISSING GET /v1/authorization +// MISSING POST /v1/authorization +// MISSING POST /v1/authorized-clients +// MISSING POST /v1/authorized-clients/destroy +// MISSING GET /v1/client/:id +// MISSING POST /v1/introspect +// MISSING GET /v1/jwks +// MISSING POST /v1/key-data +// MISSING POST /v1/token +// MISSING POST /v1/verify + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub(crate) struct DestroyReq { + access_token: Option, + refresh_token: Option, + // NOTE this field does not exist in the spec, but fenix sends it + token: Option, + // MISSING client_id + // MISSING client_secret + // MISSING refresh_token_id +} + +#[post("/destroy", data = "")] +pub(crate) async fn destroy( + db: &DbConn, + req: Json, +) -> std::result::Result, Error> { + // MISSING spec says basic auth is allowed, but nothing seems to use it + if let Some(t) = req.0.access_token { + db.delete_oauth_token(&t.hash()).await?; + } + if let Some(t) = req.0.refresh_token { + db.delete_oauth_token(&t.hash()).await?; + } + if let Some(t) = req.0.token { + db.delete_oauth_token(&t.hash()).await?; + } + Ok(EMPTY) +} + +#[get("/jwks")] +pub(crate) async fn jwks() -> Json { + // HACK we need to return *something* for /jwks, otherwise PyFxA fails. + // since syncstorage-rs uses PyFxA to check oauth tokens this is bad. + EMPTY +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub(crate) struct VerifyReq { + token: OauthToken, +} + +#[derive(Debug, Serialize)] +pub(crate) struct VerifyResp { + user: UserID, + client_id: String, + scope: Vec>, + // MISSING generation + // MISSING profile_changed_at +} + +#[post("/verify", data = "")] +pub(crate) async fn verify(db: &DbConn, req: Json) -> Result { + let token = db.get_access_token(&req.token.hash()).await.map_err(map_error)?; + Ok(Json(VerifyResp { + user: token.user_id, + client_id: token.client_id, + scope: token.scope.split().map(|s| s.into_owned()).collect::>(), + })) +} diff --git a/src/api/profile/mod.rs b/src/api/profile/mod.rs new file mode 100644 index 0000000..28d1e03 --- /dev/null +++ b/src/api/profile/mod.rs @@ -0,0 +1,324 @@ +use std::sync::Arc; + +use either::Either; +use rocket::{ + data::ToByteUnit, + http::{uri::Absolute, ContentType, Status}, + response::{self, Responder}, + serde::json::Json, + Request, Response, State, +}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use sha2::{Digest, Sha256}; +use Either::{Left, Right}; + +use crate::{ + api::Empty, + auth::{Authenticated, WithBearer, AuthenticatedRequest}, + cache::Immutable, + db::Db, + types::{oauth::Scope, UserID}, + utils::DeferAction, +}; +use crate::{db::DbConn, types::AvatarID, Config}; +use crate::{push::PushClient, types::Avatar}; + +use super::EMPTY; + +// we don't provide any additional fields. some we can't provide anyway (eg +// invalid parameter `validation`), others are implied by the request body (eg +// account exists `email`), and *our* client doesn't care about them anyway +#[derive(Debug)] +pub(crate) enum Error { + Unauthorized, + InvalidParameter, + PayloadTooLarge, + NotFound, + + // this is actually a response from the auth api (not the profile api), + // but firefox needs the *exact response* of this auth error to refresh + // profile fetch oauth tokens for its ui. :( + InvalidAuthToken, + + Other(anyhow::Error), + UnexpectedStatus(Status), +} + +#[rustfmt::skip] +impl<'r> Responder<'r, 'static> for Error { + fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> { + let (code, errno, msg) = match self { + Error::Unauthorized => (Status::Forbidden, 100, "unauthorized"), + Error::InvalidParameter => (Status::BadRequest, 101, "invalid parameter in request body"), + Error::PayloadTooLarge => (Status::PayloadTooLarge, 999, "payload too large"), + Error::NotFound => (Status::NotFound, 999, "not found"), + + Error::InvalidAuthToken => (Status::Unauthorized, 110, "invalid authentication token"), + + Error::Other(e) => { + error!("non-api error during request: {:?}", e); + (Status::InternalServerError, 999, "internal error") + }, + Error::UnexpectedStatus(s) => (s, 999, ""), + }; + let body = json!({ + "code": code.code, + "errno": errno, + "error": code.reason_lossy(), + "message": msg + }); + Response::build_from(Json(body).respond_to(request)?).status(code).ok() + } +} + +impl From for Error { + fn from(e: sqlx::Error) -> Self { + Error::Other(anyhow!(e)) + } +} + +impl From for Error { + fn from(e: anyhow::Error) -> Self { + Error::Other(e) + } +} + +pub(crate) type Result = std::result::Result, Error>; + +#[catch(default)] +pub(crate) fn catch_all(status: Status, r: &Request<'_>) -> Error { + match status.code { + // these three are caused by Json errors + 400 | 422 => Error::InvalidParameter, + 413 => Error::PayloadTooLarge, + // translate forbidden-because-token to the auth api error for firefox + 401 if r.invalid_token_used() => Error::InvalidAuthToken, + // generic unauthorized instead of 404 for eg wrong method or nonexistant endpoints + 401 | 404 => Error::Unauthorized, + _ => { + error!("caught unexpected error {status}"); + Error::UnexpectedStatus(status) + }, + } +} + +// MISSING GET /v1/email +// MISSING GET /v1/subscriptions +// MISSING GET /v1/uid +// MISSING GET /v1/display_name +// MISSING DELETE /v1/cache/:uid +// MISSING send profile:change webchannel event an avatar/name changes + +#[derive(Debug, Serialize)] +#[allow(non_snake_case)] +pub(crate) struct ProfileResp { + uid: Option, + email: Option, + locale: Option, + amrValues: Option>, + twoFactorAuthentication: bool, + displayName: Option, + // NOTE spec does not exist, fxa-profile-server schema says this field is optional, + // but fenix exceptions if it's null. + // NOTE it also *must* be a valid url, or fenix crashes entirely. + avatar: Absolute<'static>, + avatarDefault: bool, + subscriptions: Option>, +} + +#[get("/profile")] +pub(crate) async fn profile( + db: &DbConn, + cfg: &State, + auth: Authenticated<(), WithBearer>, +) -> Result { + let has_scope = |s| auth.context.implies(&Scope::borrowed(s)); + + let user = db.get_user_by_id(&auth.session).await?; + let (avatar, avatar_default) = if has_scope("profile:avatar") { + match db.get_user_avatar_id(&auth.session).await? { + Some(id) => (uri!(cfg.avatars_prefix(), avatar_get_img(id = id.to_string())), false), + None => ( + uri!(cfg.avatars_prefix(), avatar_get_img("00000000000000000000000000000000")), + true, + ), + } + } else { + (uri!(cfg.avatars_prefix(), avatar_get_img("00000000000000000000000000000000")), true) + }; + Ok(Json(ProfileResp { + uid: if has_scope("profile:uid") { Some(auth.session) } else { None }, + email: if has_scope("profile:email") { Some(user.email) } else { None }, + locale: None, + amrValues: None, + twoFactorAuthentication: false, + displayName: if has_scope("profile:display_name") { user.display_name } else { None }, + avatar, + avatarDefault: avatar_default, + subscriptions: None, + })) +} + +#[derive(Debug, Deserialize)] +#[allow(non_snake_case)] +pub(crate) struct DisplayNameReq { + displayName: String, +} + +#[post("/display_name", data = "")] +pub(crate) async fn display_name_post( + db: &DbConn, + db_pool: &Db, + pc: &State>, + defer: &DeferAction, + req: Authenticated, +) -> Result { + if !req.context.implies(&Scope::borrowed("profile:display_name:write")) { + return Err(Error::Unauthorized); + } + + db.set_user_name(&req.session, &req.body.displayName).await?; + match db.get_devices(&req.session).await { + Ok(devs) => defer.spawn_after_success("api::profile/display_name(post)", { + let (pc, db) = (Arc::clone(pc), db_pool.clone()); + async move { + let db = db.begin().await?; + pc.profile_updated(&db, &devs).await; + db.commit().await?; + Ok(()) + } + }), + Err(e) => warn!("profile_updated push failed: {e}"), + } + Ok(EMPTY) +} + +#[derive(Serialize)] +#[allow(non_snake_case)] +pub(crate) struct AvatarResp { + id: AvatarID, + avatarDefault: bool, + avatar: Absolute<'static>, +} + +#[get("/avatar")] +pub(crate) async fn avatar_get( + db: &DbConn, + cfg: &State, + req: Authenticated<(), WithBearer>, +) -> Result { + if !req.context.implies(&Scope::borrowed("profile:avatar")) { + return Err(Error::Unauthorized); + } + + let resp = match db.get_user_avatar_id(&req.session).await? { + Some(id) => { + let url = uri!(cfg.avatars_prefix(), avatar_get_img(id = id.to_string())); + AvatarResp { id, avatarDefault: false, avatar: url } + }, + None => { + let url = + uri!(cfg.avatars_prefix(), avatar_get_img("00000000000000000000000000000000")); + AvatarResp { id: AvatarID([0; 16]), avatarDefault: true, avatar: url } + }, + }; + Ok(Json(resp)) +} + +#[get("/")] +pub(crate) async fn avatar_get_img( + db: &DbConn, + id: &str, +) -> std::result::Result<(ContentType, Immutable, &'static [u8]>>), Error> { + let id = id.parse().map_err(|_| Error::NotFound)?; + + if id == AvatarID([0; 16]) { + return Ok(( + ContentType::SVG, + Immutable(Right(include_bytes!("../../../Raven-Silhouette.svg"))), + )); + } + + match db.get_user_avatar(&id).await? { + Some(avatar) => { + let ct = avatar.content_type.parse().expect("invalid content type in db"); + Ok((ct, Immutable(Left(avatar.data)))) + }, + None => Err(Error::NotFound), + } +} + +#[derive(Serialize)] +#[allow(non_snake_case)] +pub(crate) struct AvatarUploadResp { + url: Absolute<'static>, +} + +#[post("/avatar/upload", data = "")] +pub(crate) async fn avatar_upload( + db: &DbConn, + db_pool: &Db, + pc: &State>, + defer: &DeferAction, + cfg: &State, + ct: &ContentType, + req: Authenticated<(), WithBearer>, + data: Vec, +) -> Result { + if !req.context.implies(&Scope::borrowed("profile:avatar:write")) { + return Err(Error::Unauthorized); + } + if data.len() >= 128.kibibytes() { + return Err(Error::PayloadTooLarge); + } + + if !ct.is_png() + && !ct.is_gif() + && !ct.is_bmp() + && !ct.is_jpeg() + && !ct.is_webp() + && !ct.is_avif() + && !ct.is_svg() + { + return Err(Error::InvalidParameter); + } + + let mut sha = Sha256::new(); + sha.update(&req.session.0); + sha.update(&data); + let id = AvatarID(sha.finalize()[0..16].try_into().unwrap()); + + db.set_user_avatar(&req.session, Avatar { id: id.clone(), data, content_type: ct.to_string() }) + .await?; + match db.get_devices(&req.session).await { + Ok(devs) => defer.spawn_after_success("api::profile/avatar/upload(post)", { + let (pc, db) = (Arc::clone(pc), db_pool.clone()); + async move { + let db = db.begin().await?; + pc.profile_updated(&db, &devs).await; + db.commit().await?; + Ok(()) + } + }), + Err(e) => warn!("profile_updated push failed: {e}"), + } + + let url = uri!(cfg.avatars_prefix(), avatar_get_img(id = id.to_string())); + Ok(Json(AvatarUploadResp { url })) +} + +#[delete("/avatar/")] +pub(crate) async fn avatar_delete( + db: &DbConn, + id: &str, + req: Authenticated<(), WithBearer>, +) -> Result { + if !req.context.implies(&Scope::borrowed("profile:avatar:write")) { + return Err(Error::Unauthorized); + } + let id = id.parse().map_err(|_| Error::NotFound)?; + + db.delete_user_avatar(&req.session, &id).await?; + Ok(EMPTY) +} diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..f56c5e2 --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,241 @@ +use std::str::FromStr; +use std::time::Duration; + +use anyhow::Result; +use hawk::{DigestAlgorithm, Header, Key, PayloadHasher, RequestBuilder}; +use rocket::data::{self, FromData, ToByteUnit}; +use rocket::http::Status; +use rocket::outcome::{try_outcome, Outcome}; +use rocket::request::{local_cache, FromRequest, Request}; +use rocket::{request, Data, Ignite, Phase, Rocket, Sentinel}; +use serde::Deserialize; +use serde_json::error::Category; + +use crate::crypto::SecretBytes; +use crate::db::DbConn; +use crate::types::oauth::ScopeSet; +use crate::types::{OauthToken, UserID}; +use crate::Config; + +#[rocket::async_trait] +pub(crate) trait AuthSource { + type ID: FromStr + Send + Sync + Clone; + type Context: Send + Sync; + async fn hawk(r: &Request<'_>, id: &Self::ID) -> Result<(SecretBytes<32>, Self::Context)>; + async fn bearer_token(r: &Request<'_>, id: &OauthToken) -> Result<(Self::ID, Self::Context)>; +} + +// marker trait and type to communicate that authentication has failed with invalid +// tokens used. this is needed to properly translate these error for the profile api. +pub(crate) trait AuthenticatedRequest { + fn invalid_token_used(&self) -> bool; +} + +struct InvalidTokenUsed; + +impl<'r> AuthenticatedRequest for Request<'r> { + fn invalid_token_used(&self) -> bool { + self.local_cache(|| None as Option).is_some() + } +} + +#[derive(Debug)] +pub(crate) struct Authenticated { + pub body: T, + pub session: Src::ID, + pub context: Src::Context, +} + +enum AuthKind<'a> { + Hawk { header: Header }, + Token { token: &'a str }, +} + +fn drop_auth_prefix<'a>(s: &'a str, prefix: &str) -> Option<&'a str> { + if prefix.len() <= s.len() && s[..prefix.len()].eq_ignore_ascii_case(prefix) { + Some(&s[prefix.len()..]) + } else { + None + } +} + +impl Sentinel for Authenticated { + fn abort(rocket: &Rocket) -> bool { + // NOTE data sentinels are broken in rocket 0.5-rc2 + Self::try_get_state(rocket).is_none() || <&DbConn as Sentinel>::abort(rocket) + } +} + +impl Authenticated { + fn try_get_state(r: &Rocket) -> Option<&Config> { + r.state::() + } + + fn state(r: &Rocket) -> &Config { + Self::try_get_state(r).unwrap() + } + + async fn parse_auth<'a>( + request: &'a Request<'_>, + ) -> Outcome, (Status, anyhow::Error), ()> { + let auth = match request.headers().get("authorization").take(2).enumerate().last() { + Some((0, h)) => h, + Some((_, _)) => { + return Outcome::Failure(( + Status::BadRequest, + anyhow!("multiple authorization headers present"), + )) + }, + None => return Outcome::Forward(()), + }; + if let Some(hawk) = drop_auth_prefix(auth, "hawk ") { + match Header::from_str(hawk) { + Ok(header) => Outcome::Success(AuthKind::Hawk { header }), + Err(e) => Outcome::Failure(( + Status::Unauthorized, + anyhow!(e).context("malformed hawk header"), + )), + } + } else if let Some(token) = drop_auth_prefix(auth, "bearer ") { + Outcome::Success(AuthKind::Token { token }) + } else { + Outcome::Forward(()) + } + } + + pub async fn get_conn<'r>(req: &'r Request<'_>) -> Result<&'r DbConn> { + match <&'r DbConn as FromRequest<'r>>::from_request(req).await { + Outcome::Success(db) => Ok(db), + Outcome::Failure((_, e)) => Err(e.context("get db connection")), + _ => Err(anyhow!("could not get db connection")), + } + } + + async fn verify_hawk( + request: &Request<'_>, + hawk: Header, + data: Option<&str>, + ) -> Result<(Src::ID, Src::Context)> { + let cfg = Self::state(request.rocket()); + let url = format!("{}{}", cfg.location, request.uri()); + let url = url::Url::parse(&url).unwrap(); + let hash = data + .map(|d| PayloadHasher::hash("application/json", DigestAlgorithm::Sha256, d)) + .transpose()?; + let hawk_req = RequestBuilder::from_url(request.method().as_str(), &url)?; + let hawk_req = match hash.as_ref() { + Some(h) => hawk_req.hash(Some(h.as_ref())).request(), + _ => hawk_req.request(), + }; + let id: Src::ID = + match hawk.id.clone().ok_or_else(|| anyhow!("missing hawk key id"))?.parse() { + Ok(id) => id, + Err(_) => bail!("malformed hawk key id"), + }; + let (key, context) = Src::hawk(request, &id).await?; + let key = Key::new(&key.0, DigestAlgorithm::Sha256)?; + // large skew was taken from fxa-auth-server, large clock skews seem to happen + if !hawk_req.validate_header(&hawk, &key, Duration::from_secs(20 * 365 * 86400)) { + bail!("bad hawk signature"); + } + Ok((id, context)) + } + + async fn verify_bearer_token( + request: &Request<'_>, + token: &str, + ) -> Result<(Src::ID, Src::Context)> { + let token = match token.parse() { + Ok(token) => token, + Err(_) => bail!("malformed oauth token"), + }; + Src::bearer_token(request, &token).await + } +} + +#[rocket::async_trait] +impl<'r, Src: AuthSource> FromRequest<'r> for Authenticated<(), Src> { + type Error = anyhow::Error; + + async fn from_request(request: &'r Request<'_>) -> request::Outcome { + let auth = try_outcome!(Self::parse_auth(request).await); + let result = match auth { + AuthKind::Hawk { header } => Self::verify_hawk(request, header, None).await, + AuthKind::Token { token } => Self::verify_bearer_token(request, token).await, + }; + match result { + Ok((session, context)) => { + Outcome::Success(Authenticated { body: (), session, context }) + }, + Err(e) => { + request.local_cache(|| Some(InvalidTokenUsed)); + Outcome::Failure((Status::Unauthorized, anyhow!(e))) + }, + } + } +} + +#[rocket::async_trait] +impl<'r, T: Deserialize<'r>, Src: AuthSource> FromData<'r> for Authenticated { + type Error = anyhow::Error; + + async fn from_data(request: &'r Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> { + let auth = try_outcome_data!(data, Self::parse_auth(request).await); + let limit = + request.rocket().config().limits.get("json").unwrap_or_else(|| 1u32.mebibytes()); + let raw_json = match data.open(limit).into_string().await { + Ok(r) if r.is_complete() => local_cache!(request, r.into_inner()), + Ok(_) => { + return data::Outcome::Failure(( + Status::PayloadTooLarge, + anyhow!("request too large"), + )) + }, + Err(e) => return data::Outcome::Failure((Status::InternalServerError, e.into())), + }; + let verify_result = match auth { + AuthKind::Hawk { header } => Self::verify_hawk(request, header, Some(raw_json)).await, + AuthKind::Token { token } => Self::verify_bearer_token(request, token).await, + }; + let result = match verify_result { + Ok((session, context)) => { + serde_json::from_str(raw_json).map(|body| Authenticated { body, session, context }) + }, + Err(e) => { + request.local_cache(|| Some(InvalidTokenUsed)); + return Outcome::Failure((Status::Unauthorized, anyhow!(e))); + }, + }; + match result { + Ok(r) => Outcome::Success(r), + Err(e) => { + // match Json here to keep catchers generic + let status = match e.classify() { + Category::Data => Status::UnprocessableEntity, + _ => Status::BadRequest, + }; + Outcome::Failure((status, anyhow!(e))) + }, + } + } +} + +#[derive(Debug)] +pub(crate) struct WithBearer; + +#[rocket::async_trait] +impl crate::auth::AuthSource for WithBearer { + type ID = UserID; + type Context = ScopeSet; + async fn hawk(_r: &Request<'_>, _id: &Self::ID) -> Result<(SecretBytes<32>, Self::Context)> { + bail!("hawk signatures not allowed here") + } + async fn bearer_token( + r: &Request<'_>, + token: &OauthToken, + ) -> Result<(Self::ID, Self::Context)> { + let db = Authenticated::<(), Self>::get_conn(r).await?; + let t = db.get_access_token(&token.hash()).await?; + Ok((t.user_id, t.scope)) + } +} diff --git a/src/bin/minorskulk.rs b/src/bin/minorskulk.rs new file mode 100644 index 0000000..1609edb --- /dev/null +++ b/src/bin/minorskulk.rs @@ -0,0 +1,9 @@ +use minor_skulk::build; + +#[rocket::main] +async fn main() -> anyhow::Result<()> { + dotenv::dotenv().ok(); + + let _ = build().await?.launch().await?; + Ok(()) +} diff --git a/src/cache.rs b/src/cache.rs new file mode 100644 index 0000000..680d9da --- /dev/null +++ b/src/cache.rs @@ -0,0 +1,42 @@ +use std::borrow::Cow; + +use rocket::{ + http::Header, + request::{self, FromRequest}, + response::{self, Responder}, + Request, +}; + +pub(crate) struct Etagged<'r, T>(pub T, pub Cow<'r, str>); + +impl<'r, 'o: 'r, T: Responder<'r, 'o>> Responder<'r, 'o> for Etagged<'o, T> { + fn respond_to(self, r: &'r Request<'_>) -> response::Result<'o> { + let mut resp = self.0.respond_to(r)?; + resp.set_header(Header::new("etag", self.1)); + Ok(resp) + } +} + +pub(crate) struct Immutable(pub T); + +impl<'r, 'o: 'r, T: Responder<'r, 'o>> Responder<'r, 'o> for Immutable { + fn respond_to(self, r: &'r Request<'_>) -> response::Result<'o> { + let mut resp = self.0.respond_to(r)?; + resp.set_header(Header::new("cache-control", "public, max-age=604800, immutable")); + Ok(resp) + } +} + +pub(crate) struct IfNoneMatch<'r>(pub &'r str); + +#[async_trait] +impl<'r> FromRequest<'r> for IfNoneMatch<'r> { + type Error = (); + + async fn from_request(req: &'r Request<'_>) -> request::Outcome { + match req.headers().get_one("if-none-match") { + Some(h) => request::Outcome::Success(Self(h)), + None => request::Outcome::Forward(()), + } + } +} diff --git a/src/crypto.rs b/src/crypto.rs new file mode 100644 index 0000000..cf1044e --- /dev/null +++ b/src/crypto.rs @@ -0,0 +1,408 @@ +#![deny(clippy::pedantic)] +#![deny(clippy::restriction)] +#![allow(clippy::blanket_clippy_restriction_lints)] +#![allow(clippy::implicit_return)] +#![allow(clippy::missing_docs_in_private_items)] +#![allow(clippy::shadow_reuse)] + +use std::fmt::Debug; + +use hmac::{Hmac, Mac}; +use password_hash::{Output, Salt}; +use rand::RngCore; +use scrypt::scrypt; +use serde::{Deserialize, Serialize}; +use sha2::Sha256; + +const NAMESPACE: &[u8] = b"identity.mozilla.com/picl/v1/"; + +#[derive(Clone, PartialEq, Eq, Zeroize, Serialize, Deserialize)] +#[serde(try_from = "String", into = "String")] +pub struct SecretBytes(pub [u8; N]); + +impl Drop for SecretBytes { + fn drop(&mut self) { + self.zeroize(); + } +} + +#[derive(Clone, PartialEq, Eq)] +pub struct TokenID(pub [u8; 32]); + +impl Debug for SecretBytes { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + fmt.write_fmt(format_args!("SecretBytes {{ raw: {} }}", hex::encode(&self.0))) + } +} + +impl SecretBytes { + fn xor(&self, other: &Self) -> Self { + let mut result = self.clone(); + for (a, b) in result.0.iter_mut().zip(other.0.iter()) { + *a ^= b; + } + result + } +} + +impl From> for String { + fn from(sb: SecretBytes) -> Self { + hex::encode(&sb.0) + } +} + +impl TryFrom for SecretBytes { + type Error = hex::FromHexError; + + fn try_from(value: String) -> Result { + let mut result = Self([0; N]); + hex::decode_to_slice(value, &mut result.0)?; + Ok(result) + } +} + +impl From> for Output { + fn from(s: SecretBytes<32>) -> Output { + #[allow(clippy::unwrap_used)] + Output::new(&s.0).unwrap() + } +} + +mod from_hkdf { + use hkdf::Hkdf; + use sha2::Sha256; + + // sealing lets us guarantee that SIZE is always correct, + // which means that from_hkdf always receives correctly sized slices + // and copies never fail + mod private { + pub trait Seal {} + impl Seal for super::super::SecretBytes {} + impl Seal for super::super::TokenID {} + impl Seal for (L, R) {} + } + + pub trait FromHkdf: private::Seal { + const SIZE: usize; + fn from_hkdf(bytes: &[u8]) -> Self; + } + + impl FromHkdf for super::SecretBytes { + const SIZE: usize = N; + fn from_hkdf(bytes: &[u8]) -> Self { + #[allow(clippy::unwrap_used)] + Self(bytes.try_into().unwrap()) + } + } + + impl FromHkdf for super::TokenID { + const SIZE: usize = 32; + fn from_hkdf(bytes: &[u8]) -> Self { + #[allow(clippy::expect_used)] + Self(bytes.try_into().expect("hkdf failed")) + } + } + + impl FromHkdf for (L, R) { + const SIZE: usize = L::SIZE + R::SIZE; + #[allow(clippy::indexing_slicing)] + fn from_hkdf(bytes: &[u8]) -> Self { + (L::from_hkdf(&bytes[0..L::SIZE]), R::from_hkdf(&bytes[L::SIZE..])) + } + } + + pub fn from_hkdf(key: &[u8], info: &[&[u8]]) -> O { + let hk = Hkdf::::new(None, key); + let mut buf = vec![0; O::SIZE]; + #[allow(clippy::expect_used)] + // worth keeping an eye out for very large results (>255*32 bytes) + hk.expand_multi_info(info, buf.as_mut_slice()).expect("hkdf failed"); + O::from_hkdf(&buf) + } +} + +use from_hkdf::from_hkdf; +use zeroize::Zeroize; + +impl SecretBytes { + pub fn generate() -> Self { + let mut result = Self([0; N]); + rand::rngs::OsRng.fill_bytes(&mut result.0); + result + } +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(transparent)] +pub struct AuthPW { + pub pw: SecretBytes<32>, +} + +pub struct StretchedPW { + pub pw: Output, +} + +impl AuthPW { + pub fn stretch(&self, salt: Salt) -> anyhow::Result { + let mut result = [0; 32]; + let mut buf = [0; Salt::MAX_LENGTH]; + let params = scrypt::Params::new(16, 8, 1)?; + let salt = salt.b64_decode(&mut buf)?; + scrypt(&self.pw.0, salt, ¶ms, &mut result)?; + Ok(StretchedPW { pw: Output::new(&result)? }) + } +} + +impl StretchedPW { + pub fn verify_hash(&self) -> Output { + let raw: SecretBytes<32> = from_hkdf(self.pw.as_bytes(), &[NAMESPACE, b"verifyHash"]); + raw.into() + } + + fn wrap_wrap_key(&self) -> SecretBytes<32> { + from_hkdf(self.pw.as_bytes(), &[NAMESPACE, b"wrapwrapKey"]) + } + + pub fn decrypt_wwkb(&self, wwkb: &SecretBytes<32>) -> SecretBytes<32> { + wwkb.xor(&self.wrap_wrap_key()) + } + + pub fn rewrap_wkb(&self, wkb: &SecretBytes<32>) -> SecretBytes<32> { + wkb.xor(&self.wrap_wrap_key()) + } +} + +pub struct SessionCredentials { + pub token_id: TokenID, + pub req_hmac_key: SecretBytes<32>, +} + +impl SessionCredentials { + pub fn derive(seed: &SecretBytes<32>) -> Self { + let (token_id, req_hmac_key) = from_hkdf(&seed.0, &[NAMESPACE, b"sessionToken"]); + Self { token_id, req_hmac_key } + } +} + +pub struct KeyFetchReq { + pub token_id: TokenID, + pub req_hmac_key: SecretBytes<32>, + pub key_request_key: SecretBytes<32>, +} + +impl KeyFetchReq { + pub fn from_token(key_fetch_token: &SecretBytes<32>) -> Self { + let (token_id, (req_hmac_key, key_request_key)) = + from_hkdf(&key_fetch_token.0, &[NAMESPACE, b"keyFetchToken"]); + Self { token_id, req_hmac_key, key_request_key } + } + + pub fn derive_resp(&self) -> KeyFetchResp { + let (resp_hmac_key, resp_xor_key) = + from_hkdf(&self.key_request_key.0, &[NAMESPACE, b"account/keys"]); + KeyFetchResp { resp_hmac_key, resp_xor_key } + } +} + +pub struct KeyFetchResp { + pub resp_hmac_key: SecretBytes<32>, + pub resp_xor_key: SecretBytes<64>, +} + +impl KeyFetchResp { + pub fn wrap_keys(&self, keys: &KeyBundle) -> WrappedKeyBundle { + let ciphertext = self.resp_xor_key.xor(&keys.to_bytes()); + #[allow(clippy::unwrap_used)] + let mut hmac = Hmac::::new_from_slice(&self.resp_hmac_key.0).unwrap(); + hmac.update(&ciphertext.0); + let hmac = *hmac.finalize().into_bytes().as_ref(); + WrappedKeyBundle { ciphertext, hmac } + } +} + +pub struct KeyBundle { + pub ka: SecretBytes<32>, + pub wrap_kb: SecretBytes<32>, +} + +impl KeyBundle { + pub fn to_bytes(&self) -> SecretBytes<64> { + let mut result = SecretBytes([0; 64]); + result.0[0..32].copy_from_slice(&self.ka.0); + result.0[32..].copy_from_slice(&self.wrap_kb.0); + result + } +} + +#[derive(Debug)] +pub struct WrappedKeyBundle { + pub ciphertext: SecretBytes<64>, + pub hmac: [u8; 32], +} + +impl WrappedKeyBundle { + pub fn to_bytes(&self) -> [u8; 96] { + let mut result = [0; 96]; + result[0..64].copy_from_slice(&self.ciphertext.0); + result[64..].copy_from_slice(&self.hmac); + result + } +} + +pub struct PasswordChangeReq { + pub token_id: TokenID, + pub req_hmac_key: SecretBytes<32>, +} + +impl PasswordChangeReq { + pub fn from_change_token(token: &SecretBytes<32>) -> Self { + let (token_id, req_hmac_key) = from_hkdf(&token.0, &[NAMESPACE, b"passwordChangeToken"]); + Self { token_id, req_hmac_key } + } + + pub fn from_forgot_token(token: &SecretBytes<32>) -> Self { + let (token_id, req_hmac_key) = from_hkdf(&token.0, &[NAMESPACE, b"passwordForgotToken"]); + Self { token_id, req_hmac_key } + } +} + +pub struct AccountResetReq { + pub token_id: TokenID, + pub req_hmac_key: SecretBytes<32>, +} + +impl AccountResetReq { + pub fn from_token(token: &SecretBytes<32>) -> Self { + let (token_id, req_hmac_key) = from_hkdf(&token.0, &[NAMESPACE, b"accountResetToken"]); + Self { token_id, req_hmac_key } + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use password_hash::{Output, SaltString}; + + use crate::crypto::{KeyBundle, KeyFetchReq, SessionCredentials}; + + use super::{AuthPW, SecretBytes}; + + macro_rules! shex { + ( $s: literal ) => { + SecretBytes(hex!($s)) + }; + } + + #[test] + fn test_derive_session() { + let creds = SessionCredentials::derive(&SecretBytes(hex!( + "a0a1a2a3a4a5a6a7 a8a9aaabacadaeaf b0b1b2b3b4b5b6b7 b8b9babbbcbdbebf" + ))); + assert_eq!( + creds.token_id.0, + hex!("c0a29dcf46174973da1378696e4c82ae10f723cf4f4d9f75e39f4ae3851595ab") + ); + assert_eq!( + creds.req_hmac_key.0, + hex!("9d8f22998ee7f579 8b887042466b72d5 3e56ab0c094388bf 65831f702d2febc0") + ); + } + + #[test] + fn test_key_fetch() { + let key_fetch = KeyFetchReq::from_token(&shex!( + "8081828384858687 88898a8b8c8d8e8f 9091929394959697 98999a9b9c9d9e9f" + )); + assert_eq!( + key_fetch.token_id.0, + hex!("3d0a7c02a15a62a2882f76e39b6494b500c022a8816e048625a495718998ba60") + ); + assert_eq!( + key_fetch.req_hmac_key.0, + hex!("87b8937f61d38d0e 29cd2d5600b3f4da 0aa48ac41de36a0e fe84bb4a9872ceb7") + ); + assert_eq!( + key_fetch.key_request_key.0, + hex!("14f338a9e8c6324d 9e102d4e6ee83b20 9796d5c74bb734a4 10e729e014a4a546") + ); + + let resp = key_fetch.derive_resp(); + assert_eq!( + resp.resp_hmac_key.0, + hex!("f824d2953aab9faf 51a1cb65ba9e7f9e 5bf91c8d8fd1ac1c 8c2d31853a8a1210") + ); + assert_eq!( + resp.resp_xor_key.0, + hex!( + "ce7d7aa77859b235 9932970bbe2101f2 e80d01faf9191bd5 ee52181d2f0b7809 + 8281ba8cff392543 3a89f7c3095e0c89 900a469d60790c83 3281c4df1a11c763" + ) + ); + + let bundle = KeyBundle { + ka: shex!("2021222324252627 28292a2b2c2d2e2f 3031323334353637 38393a3b3c3d3e3f"), + wrap_kb: shex!("7effe354abecbcb2 34a8dfc2d7644b4a d339b525589738f2 d27341bb8622ecd8"), + }; + assert_eq!( + bundle.to_bytes().0, + hex!( + "2021222324252627 28292a2b2c2d2e2f 3031323334353637 38393a3b3c3d3e3f + 7effe354abecbcb2 34a8dfc2d7644b4a d339b525589738f2 d27341bb8622ecd8" + ) + ); + + let wrapped = resp.wrap_keys(&bundle); + assert_eq!( + wrapped.ciphertext.0, + hex!( + "ee5c58845c7c9412 b11bbd20920c2fdd d83c33c9cd2c2de2 d66b222613364636 + fc7e59d854d599f1 0e212801de3a47c3 4333f3b838ee3471 e0f285649c332bbb" + ) + ); + assert_eq!( + wrapped.hmac, + hex!("4c17f42a0b319bbb a327d2b326ad23e9 37219b4de32e3ec7 b3e3f740522ad6ef") + ); + assert_eq!( + wrapped.to_bytes(), + hex!( + "ee5c58845c7c9412 b11bbd20920c2fdd d83c33c9cd2c2de2 d66b222613364636 + fc7e59d854d599f1 0e212801de3a47c3 4333f3b838ee3471 e0f285649c332bbb + 4c17f42a0b319bbb a327d2b326ad23e9 37219b4de32e3ec7 b3e3f740522ad6ef" + ) + ); + } + + #[test] + fn test_stretch() -> anyhow::Result<()> { + let auth_pw = AuthPW { + pw: SecretBytes(hex!( + "247b675ffb4c4631 0bc87e26d712153a be5e1c90ef00a478 4594f97ef54f2375" + )), + }; + + let stretched = auth_pw.stretch( + SaltString::b64_encode(&hex!( + "00f0000000000000 0000000000000000 0000000000000000 0000000000000000" + ))? + .as_salt(), + )?; + assert_eq!( + stretched.pw, + Output::new(&hex!( + "441509e25c92ee10 3d5a1a874e6f155d f25a44d06e61c894 616c9e85181dba97" + ))? + ); + + assert_eq!( + stretched.verify_hash().as_bytes(), + hex!("a4765bf103dc057f 4cf4bc2c131ddb67 16e8a4333cc55e1d 3c449f31f0eec4f1") + ); + + assert_eq!( + stretched.wrap_wrap_key().0, + hex!("3ebea117efa9faf5 7ce195899b290505 8368e7760cc26ea5 8a2a1be0da7fb287") + ); + Ok(()) + } +} diff --git a/src/db/mod.rs b/src/db/mod.rs new file mode 100644 index 0000000..040507f --- /dev/null +++ b/src/db/mod.rs @@ -0,0 +1,1026 @@ +use std::{error::Error, mem::replace, sync::Arc}; + +use anyhow::Result; +use chrono::{DateTime, Duration, Utc}; +use password_hash::SaltString; +use rocket::{ + fairing::{self, Fairing}, + futures::lock::{MappedMutexGuard, Mutex, MutexGuard}, + http::Status, + request::{FromRequest, Outcome}, + Request, Response, Sentinel, +}; +use serde_json::Value; +use sqlx::{query, query_as, query_scalar, PgPool, Postgres, Transaction}; + +use crate::{ + crypto::WrappedKeyBundle, + types::{ + oauth::ScopeSet, AccountResetID, AttachedClient, Avatar, AvatarID, Device, DeviceCommand, + DeviceCommands, DeviceID, DevicePush, DeviceUpdate, HawkKey, KeyFetchID, OauthAccessToken, + OauthAccessType, OauthAuthorization, OauthAuthorizationID, OauthRefreshToken, OauthTokenID, + PasswordChangeID, SecretKey, SessionID, User, UserID, UserSession, VerifyCode, VerifyHash, + }, +}; + +// we implement a completely custom db type set instead of using rocket_db_pools for two reasons: +// 1. rocket_db_pools doesn't support sqlx 0.6 +// 2. we want one transaction per *request*, including all guards + +#[derive(Clone)] +pub struct Db { + db: Arc, +} + +impl Db { + pub async fn connect(db: &str) -> Result { + Ok(Self { db: Arc::new(PgPool::connect(db).await?) }) + } + + pub async fn begin(&self) -> Result { + Ok(DbConn(Mutex::new(ConnState::Capable(self.clone())))) + } + + pub async fn migrate(&self) -> Result<()> { + sqlx::migrate!().run(&*self.db).await?; + Ok(()) + } +} + +struct ActiveConn { + tx: Transaction<'static, Postgres>, + always_commit: bool, +} + +#[allow(clippy::large_enum_variant)] +enum ConnState { + None, + Capable(Db), + Active(ActiveConn), + Done, +} + +pub struct DbConn(Mutex); + +struct DbWrap(Db); +struct DbConnWrap(DbConn); +type DbLock<'c> = MappedMutexGuard<'c, ConnState, ActiveConn>; + +#[rocket::async_trait] +impl Fairing for Db { + fn info(&self) -> fairing::Info { + fairing::Info { + name: "db access", + kind: fairing::Kind::Ignite | fairing::Kind::Response | fairing::Kind::Singleton, + } + } + + async fn on_ignite(&self, rocket: rocket::Rocket) -> fairing::Result { + Ok(rocket.manage(DbWrap(self.clone())).manage(self.clone())) + } + + async fn on_response<'r>(&self, req: &'r Request<'_>, resp: &mut Response<'r>) { + let conn = req.local_cache(|| DbConnWrap(DbConn(Mutex::new(ConnState::None)))); + let s = replace(&mut *conn.0 .0.lock().await, ConnState::Done); + if let ConnState::Active(ActiveConn { tx, always_commit }) = s { + // don't commit if the request failed, unless explicitly asked to. + // this is used by key fetch to invalidate tokens even if the header + // signature of the request is incorrect. + if !always_commit && resp.status() != Status::Ok { + return; + } + if let Err(e) = tx.commit().await { + resp.set_status(Status::InternalServerError); + error!("commit failed: {}", e); + } + } + } +} + +impl DbConn { + pub async fn commit(self) -> Result<(), sqlx::Error> { + match self.0.into_inner() { + ConnState::None => Ok(()), + ConnState::Capable(_) => Ok(()), + ConnState::Active(ac) => ac.tx.commit().await, + ConnState::Done => Ok(()), + } + } + + fn register<'r>(req: &'r Request<'_>) -> &'r DbConn { + &req.local_cache(|| { + let db = req.rocket().state::().unwrap().0.clone(); + DbConnWrap(DbConn(Mutex::new(ConnState::Capable(db)))) + }) + .0 + } + + // defer opening a transaction and thus locking the mutex holding it. + // without deferral the order of arguments in route signatures is important, + // which may be surprising: placing a DbConn before a guard that also uses + // the database causes a concurrent transaction error. + // HACK maybe we should return errors instead of panicking, but there's no + // way an error here is not a severe bug + async fn get(&self) -> sqlx::Result> { + let mut m = match self.0.try_lock() { + Some(m) => m, + None => panic!("attempted to open concurrent transactions"), + }; + match &*m { + ConnState::Capable(db) => { + *m = ConnState::Active(ActiveConn { + tx: db.db.begin().await?, + always_commit: false, + }); + }, + ConnState::None | ConnState::Done => panic!("db connection requested after teardown"), + _ => (), + } + Ok(MutexGuard::map(m, |g| match g { + ConnState::Active(ref mut tx) => tx, + _ => unreachable!(), + })) + } +} + +impl<'r> Sentinel for &'r Db { + fn abort(rocket: &rocket::Rocket) -> bool { + rocket.state::().is_none() + } +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for &'r Db { + type Error = anyhow::Error; + + async fn from_request(req: &'r Request<'_>) -> Outcome { + Outcome::Success(&req.rocket().state::().unwrap().0) + } +} + +impl<'r> Sentinel for &'r DbConn { + fn abort(rocket: &rocket::Rocket) -> bool { + rocket.state::().is_none() + } +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for &'r DbConn { + type Error = anyhow::Error; + + async fn from_request(req: &'r Request<'_>) -> Outcome { + Outcome::Success(DbConn::register(req)) + } +} + +// +// +// + +// +// +// + +impl DbConn { + pub(crate) async fn always_commit(&self) -> Result<()> { + self.get().await?.always_commit = true; + Ok(()) + } + + // + // + // + + pub(crate) async fn add_session( + &self, + id: SessionID, + user_id: &UserID, + key: HawkKey, + verified: bool, + verify_code: Option<&str>, + ) -> sqlx::Result> { + query_scalar!( + r#"insert into user_session (session_id, user_id, req_hmac_key, device_id, verified, + verify_code) + values ($1, $2, $3, null, $4, $5) + returning created_at"#, + id as _, + user_id as _, + key as _, + verified, + verify_code, + ) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn use_session(&self, id: &SessionID) -> sqlx::Result { + query_as!( + UserSession, + r#"update user_session + set last_active = now() + where session_id = $1 + returning user_id as "uid: UserID", req_hmac_key as "req_hmac_key: HawkKey", + device_id as "device_id: DeviceID", created_at, verified, verify_code"#, + id as _ + ) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn use_session_from_refresh( + &self, + id: &OauthTokenID, + ) -> sqlx::Result<(SessionID, UserSession)> { + query!( + r#"update user_session + set last_active = now() + where session_id = ( + select session_id from oauth_token where kind = 'refresh' and id = $1 + ) + returning user_id as "uid: UserID", req_hmac_key as "req_hmac_key: HawkKey", + device_id as "device_id: DeviceID", session_id as "session_id: SessionID", + created_at, verified, verify_code"#, + id as _ + ) + .map(|r| { + ( + r.session_id.clone(), + UserSession { + uid: r.uid, + req_hmac_key: r.req_hmac_key, + device_id: r.device_id, + created_at: r.created_at, + verified: r.verified, + verify_code: r.verify_code, + }, + ) + }) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn delete_session(&self, user: &UserID, id: &SessionID) -> sqlx::Result<()> { + query_scalar!( + r#"delete from user_session + where user_id = $1 and session_id = $2 + returning 1"#, + user as _, + id as _ + ) + .fetch_one(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn set_session_device<'d>( + &self, + id: &SessionID, + dev: Option<&'d DeviceID>, + ) -> sqlx::Result<()> { + query!( + r#"update user_session set device_id = $1 where session_id = $2"#, + dev as _, + id as _ + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn set_session_verified(&self, id: &SessionID) -> sqlx::Result<()> { + query_scalar!( + r#"update user_session + set verified = true, verify_code = null + where session_id = $1 + returning 1"#, + id as _ + ) + .fetch_one(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + // + // + // + + pub(crate) async fn enqueue_command( + &self, + target: &DeviceID, + sender: &Option, + command: &str, + payload: &Value, + ttl: u32, + ) -> sqlx::Result { + let expires = Utc::now() + Duration::seconds(ttl as i64); + query!( + r#"insert into device_commands (device_id, command, payload, expires, sender) + values ($1, $2, $3, $4, $5) + returning index"#, + target as _, + command, + payload, + expires, + sender.as_ref().map(ToString::to_string) + ) + .map(|x| x.index) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn get_commands( + &self, + user: &UserID, + dev: &DeviceID, + min_index: i64, + limit: i64, + ) -> sqlx::Result<(bool, Vec)> { + // NOTE while fxa api docs state that command queries return only commands enqueued + // *after* index, what pushbox actually does is return commands *starting at* index! + let mut results = query_as!( + DeviceCommand, + r#"select index, command, payload, expires, sender + from device_commands join device using (device_id) + where index >= $1 and device_id = $2 and user_id = $3 + order by index + limit $4"#, + min_index, + dev as _, + user as _, + limit + 1 + ) + .fetch_all(&mut self.get().await?.tx) + .await?; + let more = results.len() > limit as usize; + results.truncate(limit as usize); + Ok((more, results)) + } + + // + // + // + + pub(crate) async fn get_devices(&self, user: &UserID) -> sqlx::Result> { + query_as!( + Device, + r#"select d.device_id as "device_id: DeviceID", d.name, d.type as type_, + d.push as "push: DevicePush", + d.available_commands as "available_commands: DeviceCommands", + d.push_expired, d.location, coalesce(us.last_active, to_timestamp(0)) as "last_active!" + from device d left join user_session us using (device_id) + where d.user_id = $1"#, + user as _ + ) + .fetch_all(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn get_device(&self, user: &UserID, dev: &DeviceID) -> sqlx::Result { + query_as!( + Device, + r#"select d.device_id as "device_id: DeviceID", d.name, d.type as type_, + d.push as "push: DevicePush", + d.available_commands as "available_commands: DeviceCommands", + d.push_expired, d.location, coalesce(us.last_active, to_timestamp(0)) as "last_active!" + from device d left join user_session us using (device_id) + where d.user_id = $1 and d.device_id = $2"#, + user as _, + dev as _ + ) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn change_device<'d>( + &self, + user: &UserID, + id: &DeviceID, + dev: DeviceUpdate<'d>, + ) -> sqlx::Result { + query_as!( + Device, + r#"select device_id as "device_id!: DeviceID", name as "name!", type as "type_!", + push as "push: DevicePush", + available_commands as "available_commands!: DeviceCommands", + push_expired as "push_expired!", location as "location!", + coalesce(last_active, to_timestamp(0)) as "last_active!" + from insert_or_update_device($1, $2, $3, $4, $5, $6, $7) as iud + left join user_session using (device_id)"#, + id as _, + user as _, + // these two are not optional but are Option anyway. the db will + // refuse insertions that don't have them set. + dev.name, + dev.type_, + dev.push as _, + dev.available_commands as _, + dev.location + ) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn set_push_expired(&self, dev: &DeviceID) -> sqlx::Result<()> { + query!( + r#"update device + set push_expired = true + where device_id = $1"#, + dev as _ + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn delete_device(&self, user: &UserID, dev: &DeviceID) -> sqlx::Result<()> { + query_scalar!( + r#"delete from device where user_id = $1 and device_id = $2 returning 1"#, + user as _, + dev as _ + ) + .fetch_one(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + // + // + // + + pub(crate) async fn add_key_fetch( + &self, + id: KeyFetchID, + hmac_key: &HawkKey, + keys: &WrappedKeyBundle, + ) -> sqlx::Result<()> { + query!( + r#"insert into key_fetch (id, hmac_key, keys) values ($1, $2, $3)"#, + id as _, + hmac_key as _, + &keys.to_bytes()[..] + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn finish_key_fetch( + &self, + id: &KeyFetchID, + ) -> sqlx::Result<(HawkKey, Vec)> { + query!( + r#"delete from key_fetch + where id = $1 and expires_at > now() + returning hmac_key as "hmac_key: HawkKey", keys"#, + id as _ + ) + .map(|r| (r.hmac_key, r.keys)) + .fetch_one(&mut self.get().await?.tx) + .await + } + + // + // + // + + pub(crate) async fn get_refresh_token( + &self, + id: &OauthTokenID, + ) -> sqlx::Result { + query_as!( + OauthRefreshToken, + r#"select user_id as "user_id: UserID", client_id, scope as "scope: ScopeSet", + session_id as "session_id: SessionID" + from oauth_token + where id = $1 and kind = 'refresh'"#, + id as _ + ) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn get_access_token( + &self, + id: &OauthTokenID, + ) -> sqlx::Result { + query_as!( + OauthAccessToken, + r#"select user_id as "user_id: UserID", client_id, scope as "scope: ScopeSet", + parent_refresh as "parent_refresh: OauthTokenID", + parent_session as "parent_session: SessionID", + expires_at as "expires_at!" + from oauth_token + where id = $1 and kind = 'access' and expires_at > now()"#, + id as _ + ) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn add_refresh_token( + &self, + id: &OauthTokenID, + token: OauthRefreshToken, + ) -> sqlx::Result<()> { + query!( + r#"insert into oauth_token (id, kind, user_id, client_id, scope, session_id) + values ($1, 'refresh', $2, $3, $4, $5)"#, + id as _, + token.user_id as _, + token.client_id, + token.scope as _, + token.session_id as _ + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn add_access_token( + &self, + id: &OauthTokenID, + token: OauthAccessToken, + ) -> sqlx::Result<()> { + query!( + r#"insert into oauth_token (id, kind, user_id, client_id, scope, session_id, + parent_refresh, parent_session, expires_at) + values ($1, 'access', $2, $3, $4, null, $5, $6, $7)"#, + id as _, + token.user_id as _, + token.client_id, + token.scope as _, + token.parent_refresh as _, + token.parent_session as _, + token.expires_at, + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn delete_oauth_token(&self, id: &OauthTokenID) -> sqlx::Result<()> { + query!(r#"delete from oauth_token where id = $1"#, id as _) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn delete_refresh_token(&self, id: &OauthTokenID) -> sqlx::Result<()> { + query!(r#"delete from oauth_token where id = $1 and kind = 'refresh'"#, id as _) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + // + // + // + + pub(crate) async fn add_oauth_authorization( + &self, + id: &OauthAuthorizationID, + auth: OauthAuthorization, + ) -> sqlx::Result<()> { + query!( + r#"insert into oauth_authorization (id, user_id, client_id, scope, access_type, + code_challenge, keys_jwe, auth_at) + values ($1, $2, $3, $4, $5, $6, $7, $8)"#, + id as _, + auth.user_id as _, + auth.client_id, + auth.scope as _, + auth.access_type as _, + auth.code_challenge, + auth.keys_jwe, + auth.auth_at, + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn take_oauth_authorization( + &self, + id: &OauthAuthorizationID, + ) -> sqlx::Result { + query_as!( + OauthAuthorization, + r#"delete from oauth_authorization + where id = $1 and expires_at > now() + returning user_id as "user_id: UserID", client_id, scope as "scope: ScopeSet", + access_type as "access_type: OauthAccessType", + code_challenge, keys_jwe, auth_at"#, + id as _ + ) + .fetch_one(&mut self.get().await?.tx) + .await + } + + // + // + // + + pub(crate) async fn user_email_exists(&self, email: &str) -> sqlx::Result { + Ok(query_scalar!(r#"select 1 from users where email = lower($1)"#, email) + .fetch_optional(&mut self.get().await?.tx) + .await? + .is_some()) + } + + pub(crate) async fn add_user(&self, user: User) -> sqlx::Result { + let id = UserID::random(); + query_scalar!( + r#"insert into users (user_id, auth_salt, email, ka, wrapwrap_kb, verify_hash, + display_name) + values ($1, $2, $3, $4, $5, $6, $7)"#, + id as _, + user.auth_salt.as_str(), + user.email, + user.ka as _, + user.wrapwrap_kb as _, + user.verify_hash as _, + user.display_name, + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(id) + } + + pub(crate) async fn get_user(&self, email: &str) -> sqlx::Result<(UserID, User)> { + query!( + r#"select user_id as "id: UserID", auth_salt as "auth_salt: String", email, + ka as "ka: SecretKey", wrapwrap_kb as "wrapwrap_kb: SecretKey", + verify_hash as "verify_hash: VerifyHash", display_name, verified + from users + where email = lower($1)"#, + email + ) + .try_map(|r| { + Ok(( + r.id, + User { + auth_salt: SaltString::new(&r.auth_salt).map_err(decode_err("auth_salt"))?, + email: r.email, + ka: r.ka, + wrapwrap_kb: r.wrapwrap_kb, + verify_hash: r.verify_hash, + display_name: r.display_name, + verified: r.verified, + }, + )) + }) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn get_user_by_id(&self, id: &UserID) -> sqlx::Result { + query!( + r#"select auth_salt as "auth_salt: String", email, + ka as "ka: SecretKey", wrapwrap_kb as "wrapwrap_kb: SecretKey", + verify_hash as "verify_hash: VerifyHash", display_name, verified + from users + where user_id = $1"#, + id as _ + ) + .try_map(|r| { + Ok(User { + auth_salt: SaltString::new(&r.auth_salt).map_err(decode_err("auth_salt"))?, + email: r.email, + ka: r.ka, + wrapwrap_kb: r.wrapwrap_kb, + verify_hash: r.verify_hash, + display_name: r.display_name, + verified: r.verified, + }) + }) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn set_user_name(&self, id: &UserID, name: &str) -> sqlx::Result<()> { + query!( + "update users + set display_name = $2 + where user_id = $1", + id as _, + name, + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn set_user_verified(&self, id: &UserID) -> sqlx::Result<()> { + query_scalar!("update users set verified = true where user_id = $1 returning 1", id as _) + .fetch_one(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn delete_user(&self, email: &str) -> sqlx::Result<()> { + query_scalar!(r#"delete from users where email = lower($1)"#, email) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn change_user_auth( + &self, + uid: &UserID, + salt: SaltString, + wwkb: SecretKey, + verify_hash: VerifyHash, + ) -> sqlx::Result<()> { + query!( + r#"update users + set auth_salt = $2, wrapwrap_kb = $3, verify_hash = $4 + where user_id = $1"#, + uid as _, + salt.to_string(), + wwkb as _, + verify_hash as _, + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn reset_user_auth( + &self, + uid: &UserID, + salt: SaltString, + wwkb: SecretKey, + verify_hash: VerifyHash, + ) -> sqlx::Result<()> { + query!( + r#"call reset_user_auth($1, $2, $3, $4)"#, + uid as _, + salt.to_string(), + wwkb as _, + verify_hash as _, + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + // + // + // + + pub(crate) async fn get_attached_clients( + &self, + id: &UserID, + ) -> sqlx::Result> { + query_as!( + AttachedClient, + r#"select + ot.client_id as "client_id?", + d.device_id as "device_id?: DeviceID", + us.session_id as "session_token_id?: SessionID", + ot.id as "refresh_token_id?: OauthTokenID", + d.type as "device_type?", + d.name as "name?", + coalesce(d.created_at, us.created_at, ot.created_at) as "created_time?", + us.last_active as "last_access_time?", + ot.scope as "scope?" + from device d + full outer join user_session us on (d.device_id = us.device_id) + full outer join oauth_token ot on (us.session_id = ot.session_id) + where + (ot.kind is null or ot.kind = 'refresh') + and $1 in (d.user_id, us.user_id, ot.user_id) + order by d.device_id"#, + id as _, + ) + .fetch_all(&mut self.get().await?.tx) + .await + } + + // + // + // + + pub(crate) async fn get_user_avatar_id(&self, id: &UserID) -> sqlx::Result> { + query!(r#"select id as "id: AvatarID" from user_avatars where user_id = $1"#, id as _,) + .map(|r| r.id) + .fetch_optional(&mut *self.get().await?.tx) + .await + } + + pub(crate) async fn get_user_avatar(&self, id: &AvatarID) -> sqlx::Result> { + query_as!( + Avatar, + r#"select id as "id: AvatarID", data, content_type + from user_avatars + where id = $1"#, + id as _, + ) + .fetch_optional(&mut *self.get().await?.tx) + .await + } + + pub(crate) async fn set_user_avatar(&self, id: &UserID, avatar: Avatar) -> sqlx::Result<()> { + query!( + r#"insert into user_avatars (user_id, id, data, content_type) + values ($1, $2, $3, $4) + on conflict (user_id) do update set + id = $2, data = $3, content_type = $4"#, + id as _, + avatar.id as _, + avatar.data, + avatar.content_type, + ) + .execute(&mut *self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn delete_user_avatar( + &self, + user: &UserID, + id: &AvatarID, + ) -> sqlx::Result<()> { + query!(r#"delete from user_avatars where user_id = $1 and id = $2"#, user as _, id as _,) + .execute(&mut *self.get().await?.tx) + .await?; + Ok(()) + } + + // + // + // + + pub(crate) async fn add_verify_code( + &self, + user: &UserID, + session: &SessionID, + code: &str, + ) -> sqlx::Result<()> { + query!( + r#"insert into verify_codes (user_id, session_id, code) + values ($1, $2, $3)"#, + user as _, + session as _, + code, + ) + .execute(&mut *self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn get_verify_code( + &self, + user: &UserID, + ) -> sqlx::Result<(String, VerifyCode)> { + query!( + r#"select user_id as "user_id: UserID", session_id as "session_id: SessionID", code, + email + from verify_codes join users using (user_id) + where user_id = $1"#, + user as _, + ) + .map(|r| { + (r.email, VerifyCode { user_id: r.user_id, session_id: r.session_id, code: r.code }) + }) + .fetch_one(&mut *self.get().await?.tx) + .await + } + + pub(crate) async fn try_use_verify_code( + &self, + user: &UserID, + code: &str, + ) -> sqlx::Result> { + query_as!( + VerifyCode, + r#"delete from verify_codes + where user_id = $1 and code = $2 + returning user_id as "user_id: UserID", session_id as "session_id: SessionID", + code"#, + user as _, + code, + ) + .fetch_optional(&mut *self.get().await?.tx) + .await + } + + // + // + // + + pub(crate) async fn add_password_change( + &self, + user: &UserID, + id: &PasswordChangeID, + key: &HawkKey, + forgot_code: Option<&str>, + ) -> sqlx::Result<()> { + query!( + r#"insert into password_change_tokens (id, user_id, hmac_key, forgot_code) + values ($1, $2, $3, $4) + on conflict (user_id) do update set id = $1, hmac_key = $3, forgot_code = $4, + expires_at = default"#, + id as _, + user as _, + key as _, + forgot_code, + ) + .execute(&mut *self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn finish_password_change( + &self, + id: &PasswordChangeID, + is_forgot: bool, + ) -> sqlx::Result<(HawkKey, (UserID, Option))> { + query!( + r#"delete from password_change_tokens + where id = $1 and expires_at > now() and (forgot_code is not null) = $2 + returning hmac_key as "hmac_key: HawkKey", user_id as "user_id: UserID", + forgot_code"#, + id as _, + is_forgot, + ) + .map(|r| (r.hmac_key, (r.user_id, r.forgot_code))) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn add_account_reset( + &self, + user: &UserID, + id: &AccountResetID, + key: &HawkKey, + ) -> sqlx::Result<()> { + query!( + r#"insert into account_reset_tokens (id, user_id, hmac_key) + values ($1, $2, $3) + on conflict (user_id) do update set id = $1, hmac_key = $3, expires_at = default"#, + id as _, + user as _, + key as _, + ) + .execute(&mut *self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn finish_account_reset( + &self, + id: &AccountResetID, + ) -> sqlx::Result<(HawkKey, UserID)> { + query!( + r#"delete from account_reset_tokens + where id = $1 and expires_at > now() + returning hmac_key as "hmac_key: HawkKey", user_id as "user_id: UserID""#, + id as _, + ) + .map(|r| (r.hmac_key, r.user_id)) + .fetch_one(&mut self.get().await?.tx) + .await + } + + // + // + // + + pub async fn add_invite_code(&self, code: &str, expires: DateTime) -> sqlx::Result<()> { + query!(r#"insert into invite_codes (code, expires_at) values ($1, $2)"#, code, expires,) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn use_invite_code(&self, code: &str) -> sqlx::Result<()> { + query_scalar!( + r#"delete from invite_codes where code = $1 and expires_at > now() returning 1"#, + code, + ) + .fetch_one(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + // + // + // + + pub(crate) async fn prune_expired_tokens(&self) -> sqlx::Result<()> { + query!("call prune_expired_tokens()").execute(&mut self.get().await?.tx).await?; + Ok(()) + } + + pub(crate) async fn prune_expired_verify_codes(&self) -> sqlx::Result<()> { + query!("call prune_expired_verify_codes()").execute(&mut self.get().await?.tx).await?; + Ok(()) + } +} + +fn decode_err(c: &str) -> impl FnOnce(E) -> sqlx::Error { + let index = c.to_string(); + move |e| sqlx::Error::ColumnDecode { index, source: Box::new(e) } +} diff --git a/src/js.rs b/src/js.rs new file mode 100644 index 0000000..f3d662c --- /dev/null +++ b/src/js.rs @@ -0,0 +1,53 @@ +use std::{collections::HashMap, path::PathBuf}; + +use rocket::http::{ContentType, Status}; +use sha2::{Digest, Sha256}; + +use crate::cache::{Etagged, IfNoneMatch}; + +struct Entry { + data: &'static str, + hash: String, +} + +fn enter(data: &'static str) -> Entry { + let mut sha = Sha256::new(); + sha.update(data.as_bytes()); + let hash = base64::encode(sha.finalize()); + Entry { data, hash } +} + +lazy_static! { + static ref JS: HashMap<&'static str, Entry> = { + let mut m = HashMap::new(); + m.insert("main", enter(include_str!("../web/js//main.js"))); + m.insert("crypto", enter(include_str!("../web/js//crypto.js"))); + m.insert("auth-client/browser", enter(include_str!("../web/js//browser/browser.js"))); + m.insert("auth-client/lib/client", enter(include_str!("../web/js//browser/lib/client.js"))); + m.insert("auth-client/lib/crypto", enter(include_str!("../web/js//browser/lib/crypto.js"))); + m.insert("auth-client/lib/hawk", enter(include_str!("../web/js//browser/lib/hawk.js"))); + m.insert( + "auth-client/lib/recoveryKey", + enter(include_str!("../web/js//browser/lib/recoveryKey.js")), + ); + m.insert("auth-client/lib/utils", enter(include_str!("../web/js//browser/lib/utils.js"))); + m + }; +} + +#[get("/")] +pub(crate) async fn static_js( + name: PathBuf, + inm: Option>, +) -> (Status, Result<(ContentType, Etagged<'_, &'static str>), ()>) { + let entry = JS.get(name.to_string_lossy().as_ref()); + match entry { + Some(e) => match inm { + Some(h) if h.0 == e.hash => (Status::NotModified, Err(())), + _ => { + (Status::Ok, Ok((ContentType::JavaScript, Etagged(e.data, e.hash.as_str().into())))) + }, + }, + _ => (Status::NotFound, Err(())), + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..1e6fa31 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,319 @@ +use std::{ + path::PathBuf, + sync::Arc, + time::{Duration as StdDuration, SystemTime, UNIX_EPOCH}, +}; + +use anyhow::Context; +use chrono::Duration; +use db::Db; +use futures::Future; +use lettre::message::Mailbox; +use mailer::Mailer; +use push::PushClient; +use rocket::{ + fairing::AdHoc, + http::{uri::Absolute, ContentType, Header}, + request::{self, FromRequest}, + response::Redirect, + tokio::{ + spawn, + time::{interval_at, Instant, MissedTickBehavior}, + }, + Request, State, +}; +use serde_json::{json, Value}; +use utils::DeferredActions; + +use crate::api::auth::invite::generate_invite_link; + +#[macro_use] +extern crate rocket; +#[macro_use] +extern crate anyhow; +#[macro_use] +extern crate lazy_static; + +#[macro_use] +pub(crate) mod utils; +pub(crate) mod api; +mod auth; +mod cache; +mod crypto; +pub mod db; +mod js; +mod mailer; +mod push; +mod types; + +fn default_push_ttl() -> std::time::Duration { + std::time::Duration::from_secs(2 * 86400) +} + +fn default_task_interval() -> std::time::Duration { + std::time::Duration::from_secs(5 * 60) +} + +#[derive(serde::Deserialize)] +struct Config { + database_url: String, + location: Absolute<'static>, + token_server_location: Absolute<'static>, + vapid_key: PathBuf, + vapid_subject: String, + #[serde(default = "default_push_ttl", with = "humantime_serde")] + default_push_ttl: std::time::Duration, + #[serde(default = "default_task_interval", with = "humantime_serde")] + prune_expired_interval: std::time::Duration, + + mail_from: Mailbox, + mail_host: Option, + mail_port: Option, + + #[serde(default)] + invite_only: bool, + #[serde(default)] + invite_admin_address: String, +} + +impl Config { + pub fn avatars_prefix(&self) -> Absolute<'static> { + Absolute::parse_owned(format!("{}/avatars", self.location)).unwrap() + } +} + +#[get("/")] +async fn root() -> (ContentType, &'static str) { + (ContentType::HTML, include_str!("../web/index.html")) +} + +#[get("/settings/<_..>")] +async fn settings() -> Redirect { + Redirect::to(uri!("/#/settings")) +} + +#[get("/auth/v1/authorization")] +async fn auth_auth() -> (ContentType, &'static str) { + root().await +} + +#[get("/force_auth")] +async fn force_auth() -> Redirect { + Redirect::to(uri!("/#/force_auth")) +} + +#[derive(Debug)] +struct IsFenix(bool); + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for IsFenix { + type Error = std::convert::Infallible; + + async fn from_request(request: &'r Request<'_>) -> request::Outcome { + let ua = request.headers().get_one("user-agent"); + request::Outcome::Success(IsFenix( + ua.map(|ua| ua.contains("Firefox") && ua.contains("Android")).unwrap_or(false), + )) + } +} + +#[get("/.well-known/fxa-client-configuration")] +async fn fxa_client_configuration(cfg: &State, is_fenix: IsFenix) -> Value { + let base = &cfg.location; + json!({ + "auth_server_base_url": format!("{base}/auth"), + "oauth_server_base_url": format!("{base}/oauth"), + "pairing_server_base_uri": format!("{base}/pairing"), + "profile_server_base_url": format!("{base}/profile"), + // NOTE trailing slash is *essential*, otherwise fenix will refuse to sync. + // likewise firefox desktop seems to misbehave if there *is* a trailing slash. + "sync_tokenserver_base_url": format!("{}{}", cfg.token_server_location, if is_fenix.0 { "/" } else { "" }) + }) +} + +// NOTE it looks like firefox does not implement refresh token rotation. +// since it also looks like it doesn't implement MTLS we can't secure +// refresh tokens against being stolen as advised by +// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics +// section 2.2.2 + +// NOTE firefox "oldsync" scope is the current version? +// https://github.com/mozilla/fxa/blob/main/packages/fxa-auth-server/docs/oauth/scopes.md +// https://mozilla.github.io/ecosystem-platform/explanation/onepw-protocol +// https://mozilla.github.io/ecosystem-platform/api +// https://github.com/mozilla/fxa/blob/main/packages/fxa-auth-server/docs/device_registration.md +// -> push for everything +// https://mozilla.github.io/ecosystem-platform/explanation/scoped-keys + +#[get("/.well-known/openid-configuration")] +fn oid(cfg: &State) -> Value { + let base = &cfg.location; + json!({ + "authorization_endpoint": format!("{base}/auth/v1/authorization"), + "introspection_endpoint": format!("{base}/oauth/v1/introspect"), + "issuer": base.to_string(), + "jwks_uri": format!("{base}/oauth/v1/jwks"), + "revocation_endpoint": format!("{base}/oauth/v1/destroy"), + "token_endpoint": format!("{base}/auth/v1/oauth/token"), + "userinfo_endpoint": format!("{base}/profile/v1/profile"), + "claims_supported": ["aud","exp","iat","iss","sub"], + "id_token_signing_alg_values_supported": ["RS256"], + "response_types_supported": ["code","token"], + "scopes_supported": ["openid","profile","email"], + "subject_types_supported": ["public"], + "token_endpoint_auth_methods_supported": ["client_secret_post"], + }) +} + +fn spawn_periodic(context: &'static str, t: StdDuration, p: P, f: A) +where + A: Fn(P) -> F + Send + Sync + Sized + 'static, + P: Clone + Send + Sync + 'static, + F: Future> + Send + Sized, +{ + let mut interval = interval_at(Instant::now() + t, t); + interval.set_missed_tick_behavior(MissedTickBehavior::Skip); + + spawn(async move { + loop { + interval.tick().await; + info!("starting periodic {context}"); + if let Err(e) = f(p.clone()).await { + error!("periodic {context} failed: {e}"); + } + } + }); +} + +async fn ensure_invite_admin(db: &Db, cfg: &Config) -> anyhow::Result<()> { + if !cfg.invite_only { + return Ok(()); + } + + let tx = db.begin().await?; + match tx.get_user(&cfg.invite_admin_address).await { + Err(sqlx::Error::RowNotFound) => { + let url = generate_invite_link(&tx, cfg, Duration::hours(1)).await?; + tx.commit().await?; + warn!("admin user {} does not exist, register at {url}", cfg.invite_admin_address); + Ok(()) + }, + Err(e) => Err(anyhow!(e)), + Ok(_) => Ok(()), + } +} + +pub async fn build() -> anyhow::Result> { + let rocket = rocket::build(); + let config = rocket.figment().extract::().context("reading config")?; + let db = Arc::new(Db::connect(&config.database_url).await.unwrap()); + + db.migrate().await.context("running db migrations")?; + + ensure_invite_admin(&db, &config).await?; + let push = Arc::new( + PushClient::new( + &config.vapid_key, + &config.vapid_subject, + config.location.clone(), + config.default_push_ttl, + ) + .context("setting up push notifications")?, + ); + let mailer = Arc::new( + Mailer::new( + config.mail_from.clone(), + config.mail_host.as_deref().unwrap_or("localhost"), + config.mail_port.unwrap_or(25), + config.location.clone(), + ) + .context("setting up mail notifications")?, + ); + spawn_periodic("verify code prune", StdDuration::from_secs(5 * 60), Arc::clone(&db), { + |db| async move { + let tx = db.begin().await?; + tx.prune_expired_verify_codes().await?; + tx.commit().await?; + Ok(()) + } + }); + spawn_periodic("expired token prune", config.prune_expired_interval, Arc::clone(&db), { + |db| async move { + let tx = db.begin().await?; + tx.prune_expired_tokens().await?; + tx.commit().await?; + Ok(()) + } + }); + let rocket = rocket + .manage(config) + .manage(push) + .manage(mailer) + .attach(db) + .attach(DeferredActions) + .mount("/", routes![root, settings, oid, auth_auth, force_auth, fxa_client_configuration,]) + .register("/auth/v1", catchers![api::auth::catch_all,]) + .mount( + "/auth/v1", + routes![ + api::auth::account::create, + api::auth::account::login, + api::auth::account::destroy, + api::auth::account::keys, + api::auth::account::reset, + api::auth::oauth::token_authenticated, + api::auth::oauth::token_unauthenticated, + api::auth::oauth::destroy, + api::auth::oauth::scoped_key_data, + api::auth::device::devices, + api::auth::device::device, + api::auth::device::invoke, + api::auth::device::commands, + api::auth::session::status, + api::auth::session::resend_code, + api::auth::session::verify_code, + api::auth::session::destroy, + api::auth::oauth::authorization, + api::auth::device::destroy, + api::auth::device::notify, + api::auth::device::attached_clients, + api::auth::device::destroy_attached_client, + api::auth::email::status, + api::auth::email::verify_code, + api::auth::email::resend_code, + api::auth::password::change_start, + api::auth::password::change_finish, + api::auth::password::forgot_start, + api::auth::password::forgot_finish, + ], + ) + // slight hack to allow the js auth client to "just work" + .register("/_invite/v1", catchers![api::auth::catch_all,]) + .mount("/_invite/v1", routes![api::auth::invite::generate,]) + .attach(AdHoc::on_response("/auth Timestamp", |req, resp| { + Box::pin(async move { + if req.uri().path().as_str().starts_with("/auth/v1/") { + if let Ok(ts) = SystemTime::now().duration_since(UNIX_EPOCH) { + resp.set_header(Header::new("timestamp", ts.as_secs().to_string())); + } + } + }) + })) + .register("/profile", catchers![api::profile::catch_all,]) + .mount( + "/profile/v1", + routes![ + api::profile::profile, + api::profile::display_name_post, + api::profile::avatar_get, + api::profile::avatar_upload, + api::profile::avatar_delete, + ], + ) + .register("/avatars", catchers![api::profile::catch_all,]) + .mount("/avatars", routes![api::profile::avatar_get_img]) + .register("/oauth/v1", catchers![api::oauth::catch_all,]) + .mount("/oauth/v1", routes![api::oauth::destroy, api::oauth::jwks, api::oauth::verify,]) + .mount("/js", routes![js::static_js]); + Ok(rocket) +} diff --git a/src/mailer.rs b/src/mailer.rs new file mode 100644 index 0000000..1ea1a8b --- /dev/null +++ b/src/mailer.rs @@ -0,0 +1,105 @@ +use std::time::Duration; + +use lettre::{ + message::Mailbox, + transport::smtp::client::{Tls, TlsParameters}, + AsyncSmtpTransport, Message, Tokio1Executor, +}; +use rocket::http::uri::Absolute; +use serde_json::json; + +use crate::types::UserID; + +pub struct Mailer { + from: Mailbox, + verify_base: Absolute<'static>, + transport: AsyncSmtpTransport, +} + +impl Mailer { + pub fn new( + from: Mailbox, + host: &str, + port: u16, + verify_base: Absolute<'static>, + ) -> anyhow::Result { + Ok(Mailer { + from, + verify_base, + transport: AsyncSmtpTransport::::builder_dangerous(host) + .port(port) + .tls(Tls::Opportunistic(TlsParameters::new(host.to_string())?)) + .timeout(Some(Duration::from_secs(5))) + .build(), + }) + } + + pub(crate) async fn send_account_verify( + &self, + uid: &UserID, + to: &str, + code: &str, + ) -> anyhow::Result<()> { + let fragment = base64::encode_config( + serde_json::to_string(&json!({ + "uid": uid, + "email": to, + "code": code, + }))?, + base64::URL_SAFE, + ); + let email = Message::builder() + .from(self.from.clone()) + .to(to.parse()?) + .subject("account verify code") + .body(format!("{}/#/verify/{fragment}", self.verify_base))?; + lettre::AsyncTransport::send(&self.transport, email).await?; + Ok(()) + } + + pub(crate) async fn send_session_verify(&self, to: &str, code: &str) -> anyhow::Result<()> { + let email = Message::builder() + .from(self.from.clone()) + .to(to.parse()?) + .subject("session verify code") + .body(format!("{code}"))?; + lettre::AsyncTransport::send(&self.transport, email).await?; + Ok(()) + } + + pub(crate) async fn send_password_changed(&self, to: &str) -> anyhow::Result<()> { + let email = Message::builder() + .from(self.from.clone()) + .to(to.parse()?) + .subject("account password has been changed") + .body(String::from( + "your account password has been changed. if you haven't done this, \ + you're probably in trouble now.", + ))?; + lettre::AsyncTransport::send(&self.transport, email).await?; + Ok(()) + } + + pub(crate) async fn send_password_forgot(&self, to: &str, code: &str) -> anyhow::Result<()> { + let email = Message::builder() + .from(self.from.clone()) + .to(to.parse()?) + .subject("account reset code") + .body(code.to_string())?; + lettre::AsyncTransport::send(&self.transport, email).await?; + Ok(()) + } + + pub(crate) async fn send_account_reset(&self, to: &str) -> anyhow::Result<()> { + let email = Message::builder() + .from(self.from.clone()) + .to(to.parse()?) + .subject("account has been reset") + .body(String::from( + "your account has been reset. if you haven't done this, \ + you're probably in trouble now.", + ))?; + lettre::AsyncTransport::send(&self.transport, email).await?; + Ok(()) + } +} diff --git a/src/push.rs b/src/push.rs new file mode 100644 index 0000000..6ee1afb --- /dev/null +++ b/src/push.rs @@ -0,0 +1,198 @@ +use anyhow::Result; +use rocket::http::uri::Absolute; +use serde_json::{json, Value}; +use std::time::Duration; +use std::{fs::File, io::Read, path::Path}; + +use serde::Serialize; +use web_push::{ + ContentEncoding, SubscriptionInfo, VapidSignatureBuilder, WebPushClient, WebPushMessageBuilder, +}; + +use crate::db::DbConn; +use crate::types::{Device, DeviceID, DevicePush, UserID}; + +pub(crate) struct PushClient { + key: Box<[u8]>, + client: WebPushClient, + subject: String, + base_uri: Absolute<'static>, + default_ttl: Duration, +} + +impl PushClient { + pub(crate) fn new>( + key: P, + subject: &str, + base_uri: Absolute<'static>, + default_ttl: Duration, + ) -> Result { + let mut key_bytes = vec![]; + File::open(key).and_then(|mut f| f.read_to_end(&mut key_bytes))?; + Ok(PushClient { + key: key_bytes.into_boxed_slice(), + client: WebPushClient::new()?, + subject: subject.to_string(), + base_uri, + default_ttl, + }) + } + + async fn push_raw(&self, to: &DevicePush, ttl: Duration, data: Option<&[u8]>) -> Result<()> { + let sub = SubscriptionInfo::new(&to.callback, &to.public_key, &to.auth_key); + let mut sig = VapidSignatureBuilder::from_pem(&*self.key, &sub)?; + // mozilla requires {aud,exp,sub} or message will get a 401 unauthorized. + // {aud,exp} are added automatically + sig.add_claim("sub", self.subject.as_str()); + let mut builder = WebPushMessageBuilder::new(&sub)?; + if let Some(data) = data { + builder.set_payload(ContentEncoding::Aes128Gcm, data); + } + builder.set_vapid_signature(sig.build()?); + builder.set_ttl(ttl.as_secs().min(u32::MAX as u64) as u32); + Ok(self.client.send(builder.build()?).await?) + } + + async fn push_one( + &self, + context: &str, + db: Option<&DbConn>, + to: &Device, + ttl: Duration, + data: Option<&[u8]>, + ) -> Result<()> { + match (to.push_expired, to.push.as_ref()) { + (false, Some(ep)) => match self.push_raw(ep, ttl, data).await { + Ok(()) => Ok(()), + Err(e) => { + warn!("{} push to {} failed: {}", context, &to.device_id, e); + if let Some(db) = db { + if let Err(e) = db.set_push_expired(&to.device_id).await { + warn!("failed to set {} push_endpoint_expired: {}", &to.device_id, e); + } + } + Err(e) + }, + }, + (_, None) => Err(anyhow!("no push callback")), + (true, _) => Err(anyhow!("push endpoint expired")), + } + } + + async fn push_all( + &self, + context: &str, + db: Option<&DbConn>, + to: &[Device], + ttl: Duration, + msg: impl Serialize, + ) { + let msg = serde_json::to_vec(&msg).expect("push message serialization failed"); + for dev in to { + // ignore errors here, except by logging them. we can't notify the client + // about anything and failing isn't an option either. + let _ = self.push_one(context, db, dev, ttl, Some(&msg)).await; + } + } + + pub(crate) async fn command_received( + &self, + db: &DbConn, + to: &Device, + command: &str, + index: i64, + sender: &Option, + ) -> Result<()> { + let url = + format!("{}/auth/v1/account/device/commands?index={}&limit=1", self.base_uri, index); + let msg = json!({ + "version": 1, + "command": "fxaccounts:command_received", + "data": { + "command": command, + "index": index, + "sender": sender, + "url": url, + }, + }); + let msg = serde_json::to_vec(&msg)?; + self.push_one("command_received", Some(db), to, self.default_ttl, Some(&msg)).await + } + + pub(crate) async fn device_connected(&self, db: &DbConn, to: &[Device], name: &str) { + let msg = json!({ + "version": 1, + "command": "fxaccounts:device_connected", + "data": { + "deviceName": name, + }, + }); + self.push_all("device_connected", Some(db), to, self.default_ttl, &msg).await; + } + + pub(crate) async fn device_disconnected(&self, db: &DbConn, to: &[Device], id: &DeviceID) { + let msg = json!({ + "version": 1, + "command": "fxaccounts:device_disconnected", + "data": { + "id": id, + }, + }); + self.push_all("device_disconnected", Some(db), to, self.default_ttl, &msg).await; + } + + pub(crate) async fn profile_updated(&self, db: &DbConn, to: &[Device]) { + let msg = json!({ + "version": 1, + "command": "fxaccounts:profile_updated", + }); + self.push_all("profile_updated", Some(db), to, self.default_ttl, &msg).await; + } + + pub(crate) async fn account_verified(&self, db: &DbConn, to: &[Device]) { + for dev in to { + // ignore errors here, except by logging them. we can't notify the client + // about anything and failing isn't an option either. + let _ = self.push_one("account_verified", Some(db), dev, Duration::ZERO, None).await; + } + } + + pub(crate) async fn account_destroyed(&self, to: &[Device], uid: &UserID) { + let msg = json!({ + "version": 1, + "command": "fxaccounts:account_destroyed", + "data": { + "uid": uid, + }, + }); + self.push_all("account_destroyed", None, to, self.default_ttl, &msg).await; + } + + pub(crate) async fn password_reset(&self, to: &[Device]) { + let msg = serde_json::to_vec(&json!({ + "version": 1u32, + "command": "fxaccounts:password_reset", + })) + .expect("serde failed"); + for dev in to { + // ignore errors here, except by logging them. we can't notify the client + // about anything and failing isn't an option either. + let _ = self.push_one("password_reset", None, dev, self.default_ttl, Some(&msg)).await; + // NOTE password_reset alone doesn't seem to do much, se we also disconnect + // each device explicitly. + let msg = serde_json::to_vec(&json!({ + "version": 1, + "command": "fxaccounts:device_disconnected", + "data": { + "id": dev.device_id, + }, + })) + .expect("serde failed"); + let _ = self.push_one("password_reset", None, dev, self.default_ttl, Some(&msg)).await; + } + } + + pub(crate) async fn push_any(&self, db: &DbConn, to: &[Device], ttl: Duration, payload: Value) { + self.push_all("push_any", Some(db), to, ttl, &payload).await; + } +} diff --git a/src/types.rs b/src/types.rs new file mode 100644 index 0000000..c0c5dfe --- /dev/null +++ b/src/types.rs @@ -0,0 +1,436 @@ +use crate::crypto::SecretBytes; +use chrono::{DateTime, Utc}; +use password_hash::{rand_core::OsRng, Output, SaltString}; +use rand::RngCore; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use sha2::{Digest, Sha256}; +use sqlx::{ + postgres::{PgArgumentBuffer, PgTypeInfo, PgValueRef}, + Decode, Encode, Postgres, Type, +}; +use std::{ + collections::HashMap, + fmt::{Debug, Display}, + ops::Deref, + str::FromStr, +}; + +use self::oauth::ScopeSet; + +pub(crate) mod oauth; + +macro_rules! array_type { + ( + $( #[ $attr:meta ] )* + $name:ident($inner:ty) as $sql_name:ident { + $( $body:tt )* + } + ) => { + $( #[ $attr ] )* + pub(crate) struct $name(pub(crate) $inner); + + impl $name { + $( $body )* + } + + impl Type for $name { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name(stringify!($sql_name)) + } + } + + impl Encode<'_, Postgres> for $name { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> sqlx::encode::IsNull { + let raw = self.0.iter().map(Self::encode_elem).collect::>(); + Encode::<'_, Postgres>::encode_by_ref(&raw, buf) + } + } + + impl Decode<'_, Postgres> for $name { + fn decode(value: PgValueRef) -> Result { + Ok(Self::decode_elems(Decode::<'_, Postgres>::decode(value)?)?) + } + } + } +} + +macro_rules! bytea_types { + () => {}; + ( + #[simple_array] + struct $name:ident($inner:ty) as $sql_name:ident; + + $( $rest:tt )* + ) => { + bytea_types!{ + #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] + #[serde(try_from = "String", into = "String")] + struct $name($inner) as $sql_name { + fn decode(v) -> _ { &v.0[..] } + fn encode(v) -> _ { v } + } + + impl FromStr for $name {} + impl ToString for $name {} + impl Debug for $name {} + + $( $rest )* + } + }; + ( + $( #[ $attr:meta ] )* + struct $name:ident($inner:ty) as $sql_name:ident { + $( fn arbitrary($a:ident) -> _ { $ae:expr } )? + fn decode($d:ident) -> _ { $de:expr } + fn encode($e:ident) -> _ { $ee:expr } + + $( $impls:tt )* + } + + $( $rest:tt )* + ) => { + $( #[ $attr ] )* + pub(crate) struct $name(pub(crate) $inner); + + impl $name { + $( $impls )* + } + + impl Type for $name { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name(stringify!($sql_name)) + } + } + + impl Encode<'_, Postgres> for $name { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> sqlx::encode::IsNull { + let $d = self; + <&[u8] as Encode<'_, Postgres>>::encode_by_ref(&$de, buf) + } + } + + impl Decode<'_, Postgres> for $name { + fn decode(value: PgValueRef) -> Result { + let $e = <&[u8] as Decode<'_, Postgres>>::decode(value)?.try_into()?; + Ok($name($ee)) + } + } + + bytea_types!{ $( $rest )* } + }; + ( impl ToString for $name:ident {} $( $rest:tt )* ) => { + impl Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + f.write_str(&hex::encode(&self.0)) + } + } + impl From<$name> for String { + fn from(s: $name) -> String { + format!("{}", s) + } + } + bytea_types!{ $( $rest )* } + }; + ( impl Debug for $name:ident {} $( $rest:tt )* ) => { + impl Debug for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple(stringify!($name)).field(&self.to_string()).finish() + } + } + bytea_types!{ $( $rest )* } + }; + ( impl FromStr for $name:ident {} $( $rest:tt )* ) => { + impl FromStr for $name { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + Ok(Self(hex::decode(s)?.as_slice().try_into()?)) + } + } + impl TryFrom for $name { + type Error = anyhow::Error; + + fn try_from(s: String) -> Result { + s.parse() + } + } + bytea_types!{ $( $rest )* } + } +} + +// +// +// + +bytea_types! { + #[derive(Clone, Debug, PartialEq, Eq)] + struct HawkKey(SecretBytes<32>) as hawk_key { + fn arbitrary(a) -> _ { SecretBytes(a) } + fn decode(v) -> _ { v.0.0.as_ref() } + fn encode(v) -> _ { SecretBytes(v) } + } + + #[simple_array] + struct SessionID([u8; 32]) as session_id; + + #[simple_array] + struct DeviceID([u8; 16]) as device_id; + + #[simple_array] + struct UserID([u8; 16]) as user_id; + + #[simple_array] + struct KeyFetchID([u8; 32]) as key_fetch_id; + + #[simple_array] + struct OauthTokenID([u8; 32]) as oauth_token_id; + + #[simple_array] + struct OauthAuthorizationID([u8; 32]) as oauth_auth_id; + + #[simple_array] + struct PasswordChangeID([u8; 32]) as password_change_id; + + #[simple_array] + struct AccountResetID([u8; 32]) as account_reset_id; + + #[simple_array] + struct AvatarID([u8; 16]) as avatar_id; + + #[derive(Clone, Debug, PartialEq, Eq)] + struct SecretKey(SecretBytes<32>) as secret_key { + fn arbitrary(a) -> _ { SecretBytes(a) } + fn decode(v) -> _ { v.0.0.as_ref() } + fn encode(v) -> _ { SecretBytes(v) } + } + + #[derive(Clone, Debug, PartialEq, Eq)] + struct VerifyHash(Output) as verify_hash { + fn arbitrary(a) -> _ { Output::new(<[u8; 32]>::as_ref(&a)).unwrap() } + fn decode(v) -> _ { v.0.as_ref() } + fn encode(v) -> _ { v } + } +} + +impl DeviceID { + pub fn random() -> Self { + let mut result = Self([0; 16]); + OsRng.fill_bytes(&mut result.0); + result + } +} + +impl UserID { + pub fn random() -> Self { + let mut result = Self([0; 16]); + OsRng.fill_bytes(&mut result.0); + result + } +} + +impl OauthAuthorizationID { + pub fn random() -> Self { + let mut result = Self([0; 32]); + OsRng.fill_bytes(&mut result.0); + result + } +} + +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(try_from = "String", into = "String")] +pub(crate) struct OauthToken([u8; 32]); + +impl OauthToken { + pub fn random() -> Self { + let mut result = Self([0; 32]); + OsRng.fill_bytes(&mut result.0); + result + } + + pub fn hash(&self) -> OauthTokenID { + let mut sha = Sha256::new(); + sha.update(&self.0); + OauthTokenID(*sha.finalize().as_ref()) + } +} + +bytea_types! { + impl Debug for OauthToken {} + impl FromStr for OauthToken {} + impl ToString for OauthToken {} +} + +#[derive(Debug, Deserialize, PartialEq, Eq, Type)] +#[sqlx(type_name = "oauth_access_type", rename_all = "lowercase")] +#[serde(rename_all = "lowercase")] +pub enum OauthAccessType { + Online, + Offline, +} + +#[derive(Debug)] +pub(crate) struct UserSession { + pub(crate) uid: UserID, + pub(crate) req_hmac_key: HawkKey, + pub(crate) device_id: Option, + pub(crate) created_at: DateTime, + pub(crate) verified: bool, + pub(crate) verify_code: Option, +} + +#[derive(Clone, Debug)] +pub(crate) struct DeviceCommand { + pub(crate) index: i64, + pub(crate) command: String, + pub(crate) payload: Value, + #[allow(dead_code)] + pub(crate) expires: DateTime, + // NOTE this is a device ID, but we don't link it to the actual sender device + // because removing a device would also remove its queued commands. this mirrors + // what fxa does. + pub(crate) sender: Option, +} + +#[derive(Clone, Debug, PartialEq, sqlx::Type)] +#[sqlx(type_name = "device_push_info")] +pub(crate) struct DevicePush { + pub(crate) callback: String, + pub(crate) public_key: String, + pub(crate) auth_key: String, +} + +#[derive(Clone, Debug, PartialEq, sqlx::Type)] +#[sqlx(type_name = "device_command")] +struct DeviceCommandsEntry { + name: String, + body: String, +} + +array_type! { + #[derive(Clone, Debug, PartialEq)] + DeviceCommands(HashMap) as _device_command { + fn encode_elem(e: (&String, &String)) -> DeviceCommandsEntry { + DeviceCommandsEntry { name: e.0.clone(), body: e.1.clone() } + } + fn decode_elems(e: Vec) -> anyhow::Result { + Ok(Self(e.into_iter().map(|e| (e.name, e.body)).collect())) + } + + pub(crate) fn into_map(self) -> HashMap { + self.0 + } + } +} + +impl Deref for DeviceCommands { + type Target = HashMap; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Clone, Debug)] +pub(crate) struct Device { + pub(crate) device_id: DeviceID, + // taken from session, otherwise UNIX_EPOCH + pub(crate) last_active: DateTime, + pub(crate) name: String, + pub(crate) type_: String, + pub(crate) push: Option, + pub(crate) available_commands: DeviceCommands, + pub(crate) push_expired: bool, + // actually a str->str map, but we treat it as opaque for simplicity. + // writing a HashMap to the db through sqlx is an immense pain, + // and we don't care about the value anyway—it only has to exist for fenix. + pub(crate) location: Value, +} + +#[derive(Clone, Debug)] +pub(crate) struct DeviceUpdate<'a> { + pub(crate) name: Option<&'a str>, + pub(crate) type_: Option<&'a str>, + pub(crate) push: Option, + pub(crate) available_commands: Option, + pub(crate) location: Option, +} + +#[derive(Debug, sqlx::Type)] +#[sqlx(type_name = "oauth_token_kind", rename_all = "lowercase")] +pub(crate) enum OauthTokenKind { + Access, + Refresh, +} + +#[derive(Debug)] +pub(crate) struct OauthAccessToken { + pub(crate) user_id: UserID, + pub(crate) client_id: String, + pub(crate) scope: ScopeSet, + pub(crate) parent_refresh: Option, + pub(crate) parent_session: Option, + pub(crate) expires_at: DateTime, +} + +#[derive(Debug)] +pub(crate) struct OauthRefreshToken { + pub(crate) user_id: UserID, + pub(crate) client_id: String, + pub(crate) scope: ScopeSet, + pub(crate) session_id: Option, +} + +#[derive(Debug)] +pub(crate) struct OauthAuthorization { + pub(crate) user_id: UserID, + pub(crate) client_id: String, + pub(crate) scope: ScopeSet, + pub(crate) access_type: OauthAccessType, + pub(crate) code_challenge: String, + pub(crate) keys_jwe: Option, + pub(crate) auth_at: DateTime, +} + +#[derive(Debug)] +#[cfg_attr(test, derive(Clone))] +pub(crate) struct User { + pub(crate) auth_salt: SaltString, + pub(crate) email: String, + pub(crate) display_name: Option, + pub(crate) ka: SecretKey, + pub(crate) wrapwrap_kb: SecretKey, + pub(crate) verify_hash: VerifyHash, + pub(crate) verified: bool, +} + +// MISSING user secondary email addresses + +#[derive(Debug)] +pub(crate) struct Avatar { + pub(crate) id: AvatarID, + pub(crate) data: Vec, + pub(crate) content_type: String, +} + +#[derive(Debug)] +pub(crate) struct AttachedClient { + pub(crate) client_id: Option, + pub(crate) device_id: Option, + pub(crate) session_token_id: Option, + pub(crate) refresh_token_id: Option, + pub(crate) device_type: Option, + pub(crate) name: Option, + pub(crate) created_time: Option>, + pub(crate) last_access_time: Option>, + pub(crate) scope: Option, +} + +#[derive(Debug)] +pub(crate) struct VerifyCode { + #[allow(dead_code)] + pub(crate) user_id: UserID, + pub(crate) session_id: Option, + #[allow(dead_code)] + pub(crate) code: String, +} diff --git a/src/types/oauth.rs b/src/types/oauth.rs new file mode 100644 index 0000000..222c567 --- /dev/null +++ b/src/types/oauth.rs @@ -0,0 +1,267 @@ +use std::{borrow::Cow, fmt::Display}; + +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(transparent)] +pub(crate) struct Scope<'a>(pub Cow<'a, str>); + +impl<'a> Scope<'a> { + pub const fn borrowed(s: &'a str) -> Self { + Self(Cow::Borrowed(s)) + } + + pub fn into_owned(self) -> Scope<'static> { + Scope(Cow::Owned(self.0.into_owned())) + } + + pub fn implies(&self, other: &Scope) -> bool { + let (a, b) = (&*self.0, &*other.0); + match (a.strip_prefix("https://"), b.strip_prefix("https://")) { + (Some(a), Some(b)) => { + let (a_origin, a_path) = a.split_once('/').unwrap_or((a, "")); + let (b_origin, b_path) = b.split_once('/').unwrap_or((b, "")); + if a_origin != b_origin { + false + } else { + let (a_path, a_frag) = match a_path.split_once('#') { + Some((p, f)) => (p, Some(f)), + None => (a_path, None), + }; + let (b_path, b_frag) = match b_path.split_once('#') { + Some((p, f)) => (p, Some(f)), + None => (b_path, None), + }; + if b_path + .strip_prefix(a_path) + .map_or(false, |br| br.is_empty() || br.starts_with('/')) + { + match (a_frag, b_frag) { + (Some(af), Some(bf)) => af == bf, + (Some(_), None) => false, + _ => true, + } + } else { + false + } + } + }, + (None, None) => { + let (a, a_write) = + a.strip_suffix(":write").map(|s| (s, true)).unwrap_or((a, false)); + let (b, b_write) = + b.strip_suffix(":write").map(|s| (s, true)).unwrap_or((b, false)); + if b_write && !a_write { + false + } else { + b.strip_prefix(a).map_or(false, |br| br.is_empty() || br.starts_with(':')) + } + }, + _ => false, + } + } +} + +impl<'a> Display for Scope<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, sqlx::Type)] +#[serde(transparent)] +#[sqlx(transparent)] +pub(crate) struct ScopeSet(String); + +impl ScopeSet { + pub fn split(&self) -> impl Iterator { + // not using split_whitespace because the oauth spec explicitly says to split on SP + self.0.split(' ').filter(|s| !s.is_empty()).map(Scope::borrowed) + } + + pub fn implies(&self, scope: &Scope) -> bool { + self.split().any(|a| a.implies(scope)) + } + + pub fn implies_all(&self, scopes: &ScopeSet) -> bool { + scopes.split().all(|b| self.implies(&b)) + } + + pub fn is_allowed_by(&self, allowed: &[Scope]) -> bool { + self.split().all(|scope| allowed.iter().any(|perm| perm.implies(&scope))) + } + + pub fn remove(&self, remove: &Scope) -> ScopeSet { + let remaining = self.split().filter(|s| !remove.implies(s)); + ScopeSet(remaining.map(|s| s.0).collect::>().join(" ")) + } +} + +impl PartialEq for ScopeSet { + fn eq(&self, other: &Self) -> bool { + let (mut a, mut b) = (self.split().collect::>(), other.split().collect::>()); + a.sort_by(|a, b| a.0.cmp(&b.0)); + b.sort_by(|a, b| a.0.cmp(&b.0)); + a.eq(&b) + } +} + +impl Eq for ScopeSet {} + +impl Display for ScopeSet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[cfg(test)] +mod test { + use super::{Scope, ScopeSet}; + + #[test] + fn test_scope_implies() { + assert!(ScopeSet("profile:write".to_string()).implies(&Scope::borrowed("profile"))); + assert!(ScopeSet("profile".to_string()).implies(&Scope::borrowed("profile:email"))); + assert!(ScopeSet("profile:write".to_string()).implies(&Scope::borrowed("profile:email"))); + assert!( + ScopeSet("profile:write".to_string()).implies(&Scope::borrowed("profile:email:write")) + ); + assert!( + ScopeSet("profile:email:write".to_string()).implies(&Scope::borrowed("profile:email")) + ); + assert!(ScopeSet("profile profile:email:write".to_string()) + .implies(&Scope::borrowed("profile:email"))); + assert!(ScopeSet("profile profile:email:write".to_string()) + .implies(&Scope::borrowed("profile:display_name"))); + assert!(ScopeSet("profile https://identity.mozilla.com/apps/oldsync".to_string()) + .implies(&Scope::borrowed("profile"))); + assert!(ScopeSet("profile https://identity.mozilla.com/apps/oldsync".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync"))); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync#read"))); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks"))); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks#read"))); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync#read".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks#read"))); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync#read profile".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks#read"))); + + assert!(!ScopeSet("profile:email:write".to_string()).implies(&Scope::borrowed("profile"))); + assert!( + !ScopeSet("profile:email:write".to_string()).implies(&Scope::borrowed("profile:write")) + ); + assert!(!ScopeSet("profile:email".to_string()) + .implies(&Scope::borrowed("profile:display_name"))); + assert!(!ScopeSet("profilebogey".to_string()).implies(&Scope::borrowed("profile"))); + assert!(!ScopeSet("profile:write".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync"))); + assert!(!ScopeSet("profile profile:email:write".to_string()) + .implies(&Scope::borrowed("profile:write"))); + assert!(!ScopeSet("https".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync"))); + assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string()) + .implies(&Scope::borrowed("profile"))); + assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync#read".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks"))); + assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync#write".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks#read"))); + assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync/bookmarks".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync"))); + assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync/bookmarks".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/passwords"))); + assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsyncer".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync"))); + assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsyncer"))); + assert!(!ScopeSet("https://identity.mozilla.org/apps/oldsync".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync"))); + } + + #[test] + fn test_scopes_allowed_by() { + const ALLOWED: [Scope; 2] = [ + Scope::borrowed("profile:write"), + Scope::borrowed("https://identity.mozilla.com/apps/oldsync"), + ]; + + assert!(ScopeSet("profile".to_string()).is_allowed_by(&ALLOWED)); + assert!(ScopeSet("profile:write".to_string()).is_allowed_by(&ALLOWED)); + assert!(ScopeSet("profile:email".to_string()).is_allowed_by(&ALLOWED)); + assert!(ScopeSet("profile:email:write".to_string()).is_allowed_by(&ALLOWED)); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string()) + .is_allowed_by(&ALLOWED)); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync#read".to_string()) + .is_allowed_by(&ALLOWED)); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync/bookmarks".to_string()) + .is_allowed_by(&ALLOWED)); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync/bookmarks#read".to_string()) + .is_allowed_by(&ALLOWED)); + assert!(ScopeSet("profile https://identity.mozilla.com/apps/oldsync".to_string()) + .is_allowed_by(&ALLOWED)); + + assert!(!ScopeSet("storage".to_string()).is_allowed_by(&ALLOWED)); + assert!(!ScopeSet("storage:write".to_string()).is_allowed_by(&ALLOWED)); + assert!(!ScopeSet("storage:email".to_string()).is_allowed_by(&ALLOWED)); + assert!(!ScopeSet("storage:email:write".to_string()).is_allowed_by(&ALLOWED)); + assert!(!ScopeSet("https://identity.mozilla.com/apps/newsync".to_string()) + .is_allowed_by(&ALLOWED)); + assert!(!ScopeSet("https://identity.mozilla.com/apps/newsync#read".to_string()) + .is_allowed_by(&ALLOWED)); + assert!(!ScopeSet("https://identity.mozilla.com/apps/newsync/bookmarks".to_string()) + .is_allowed_by(&ALLOWED)); + assert!(!ScopeSet("https://identity.mozilla.com/apps/newsync/bookmarks#read".to_string()) + .is_allowed_by(&ALLOWED)); + assert!(!ScopeSet("storage https://identity.mozilla.com/apps/newsync".to_string()) + .is_allowed_by(&ALLOWED)); + } + + #[test] + fn test_scopes_remove() { + assert_eq!( + ScopeSet("profile foo".to_string()).remove(&Scope::borrowed("profile")), + ScopeSet("foo".to_string()) + ); + assert_ne!( + ScopeSet("profile:write foo".to_string()).remove(&Scope::borrowed("profile")), + ScopeSet("foo".to_string()) + ); + assert_eq!( + ScopeSet("profile:write foo".to_string()).remove(&Scope::borrowed("profile:write")), + ScopeSet("foo".to_string()) + ); + assert_eq!( + ScopeSet("profile:x foo".to_string()).remove(&Scope::borrowed("profile")), + ScopeSet("foo".to_string()) + ); + assert_ne!( + ScopeSet("profile:x:write foo".to_string()).remove(&Scope::borrowed("profile")), + ScopeSet("foo".to_string()) + ); + assert_eq!( + ScopeSet("profile:x:write foo".to_string()).remove(&Scope::borrowed("profile:write")), + ScopeSet("foo".to_string()) + ); + + assert_eq!( + ScopeSet("https://foo/bar foo".to_string()).remove(&Scope::borrowed("https://foo/bar")), + ScopeSet("foo".to_string()) + ); + assert_eq!( + ScopeSet("https://foo/bar/baz foo".to_string()) + .remove(&Scope::borrowed("https://foo/bar")), + ScopeSet("foo".to_string()) + ); + assert_eq!( + ScopeSet("https://foo/bar#read foo".to_string()) + .remove(&Scope::borrowed("https://foo/bar")), + ScopeSet("foo".to_string()) + ); + assert_eq!( + ScopeSet("https://foo/bar/baz#read foo".to_string()) + .remove(&Scope::borrowed("https://foo/bar")), + ScopeSet("foo".to_string()) + ); + } +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..5a2407a --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,124 @@ +use std::convert::Infallible; + +use futures::Future; +use rocket::{ + fairing::{self, Fairing, Info, Kind}, + http::Status, + request::{FromRequest, Outcome}, + tokio::{ + self, + sync::broadcast::{channel, Sender}, + task::JoinHandle, + }, + Request, Response, Sentinel, +}; + +// does the same as try_outcome!(...).map_forward(|_| data), +// but without moving data in non-failure cases. +macro_rules! try_outcome_data { + ($data:expr, $e:expr) => { + match $e { + ::rocket::outcome::Outcome::Success(val) => val, + ::rocket::outcome::Outcome::Failure(e) => { + return ::rocket::outcome::Outcome::Failure(::std::convert::From::from(e)) + }, + ::rocket::outcome::Outcome::Forward(()) => { + return ::rocket::outcome::Outcome::Forward($data) + }, + } + }; +} + +pub trait CanBeLogged: Send + 'static { + fn into_message(self) -> Option; + fn not_run() -> Self; +} + +impl CanBeLogged for Result<(), anyhow::Error> { + fn into_message(self) -> Option { + self.err().map(|e| e.to_string()) + } + fn not_run() -> Self { + Ok(()) + } +} + +pub fn spawn_logged(context: &'static str, future: T) -> JoinHandle<()> +where + T: Future + Send + 'static, + T::Output: CanBeLogged + Send + 'static, +{ + tokio::spawn(async move { + if let Some(msg) = future.await.into_message() { + warn!("{context}: {msg}"); + } + }) +} + +pub struct DeferredActions; +struct HasDeferredActions; + +pub struct DeferAction(Sender<()>); + +#[async_trait] +impl Fairing for DeferredActions { + fn info(&self) -> Info { + Info { name: "deferred actions", kind: Kind::Ignite | Kind::Response } + } + + async fn on_ignite(&self, rocket: rocket::Rocket) -> fairing::Result { + Ok(rocket.manage(HasDeferredActions)) + } + + async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) { + if res.status() == Status::Ok { + if let Some(DeferAction(tx)) = req.local_cache(|| None) { + // could have no receivers, that's not an error here + tx.send(()).ok(); + } + } + } +} + +impl<'r> Sentinel for &'r DeferAction { + fn abort(rocket: &rocket::Rocket) -> bool { + rocket.state::().is_none() + } +} + +#[async_trait] +impl<'r> FromRequest<'r> for &'r DeferAction { + type Error = Infallible; + + async fn from_request(req: &'r Request<'_>) -> Outcome { + Outcome::Success( + req.local_cache(|| { + let (tx, _) = channel(1); + Some(DeferAction(tx)) + }) + .as_ref() + .unwrap(), + ) + } +} + +impl DeferAction { + pub fn spawn_after_success(&self, context: &'static str, future: T) + where + T: Future + Send + 'static, + T::Output: CanBeLogged + Send + 'static, + { + let mut r = self.0.subscribe(); + spawn_logged(context, async move { + match r.recv().await { + Ok(_) => { + // the request finished with success, now wait for it to be dropped + // to ensure that all other fairings have run to completion. + r.recv().await.ok(); + future.await + }, + Err(_) => CanBeLogged::not_run(), + } + }); + } +} -- cgit v1.2.3