diff options
-rw-r--r-- | src/api/auth/account.rs | 41 | ||||
-rw-r--r-- | src/api/auth/oauth.rs | 16 | ||||
-rw-r--r-- | src/api/auth/password.rs | 39 | ||||
-rw-r--r-- | src/crypto.rs | 139 | ||||
-rw-r--r-- | src/lib.rs | 3 | ||||
-rw-r--r-- | src/serde.rs | 21 |
6 files changed, 175 insertions, 84 deletions
diff --git a/src/api/auth/account.rs b/src/api/auth/account.rs index 0f12d49..c0c2099 100644 --- a/src/api/auth/account.rs +++ b/src/api/auth/account.rs @@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize}; use validator::Validate; use crate::api::{Empty, EMPTY}; +use crate::crypto::{KeyFetchToken, SessionToken}; use crate::db::{Db, DbConn}; use crate::mailer::Mailer; use crate::push::PushClient; @@ -21,7 +22,7 @@ use crate::{ api::{auth, serialize_dt}, auth::{AuthSource, Authenticated}, crypto::{AuthPW, KeyBundle, KeyFetchReq, SecretBytes, SessionCredentials}, - types::{HawkKey, KeyFetchID, OauthToken, SecretKey, SessionID, User, UserID, VerifyHash}, + types::{HawkKey, KeyFetchID, OauthToken, SecretKey, User, UserID, VerifyHash}, }; // TODO better error handling @@ -52,9 +53,9 @@ pub(crate) struct Create { #[serde(deny_unknown_fields)] pub(crate) struct CreateResp { uid: UserID, - sessionToken: SecretBytes<32>, + sessionToken: SessionToken, #[serde(skip_serializing_if = "Option::is_none")] - keyFetchToken: Option<SecretBytes<32>>, + keyFetchToken: Option<KeyFetchToken>, #[serde(serialize_with = "serialize_dt")] authAt: DateTime<Utc>, // MISSING verificationMethod @@ -94,17 +95,16 @@ pub(crate) async fn create( let auth_salt = SaltString::generate(rand::rngs::OsRng); let stretched = data.authPW.stretch(auth_salt.as_salt())?; let verify_hash = stretched.verify_hash(); - let session_token = SecretBytes::generate(); - let session = SessionCredentials::derive(&session_token); + let session_token = SessionToken::generate(); + let session = SessionCredentials::derive_from(&session_token); let key_fetch_token = if keys { - let key_fetch_token = SecretBytes::generate(); - let req = KeyFetchReq::from_token(&key_fetch_token); + let key_fetch_token = KeyFetchToken::generate(); + let req = KeyFetchReq::derive_from(&key_fetch_token); let wrapped = req.derive_resp().wrap_keys(&KeyBundle { ka: ka.clone(), wrap_kb: stretched.decrypt_wwkb(&wrapwrap_kb), }); - db.add_key_fetch(KeyFetchID(req.token_id.0), &HawkKey(req.req_hmac_key.0), &wrapped) - .await?; + db.add_key_fetch(req.token_id, &HawkKey(req.req_hmac_key.0), &wrapped).await?; Some(key_fetch_token) } else { None @@ -120,12 +120,11 @@ pub(crate) async fn create( verified: false, }) .await?; - let session_id = SessionID(session.token_id.0); let auth_at = db - .add_session(session_id.clone(), &uid, HawkKey(session.req_hmac_key.0), false, None) + .add_session(session.token_id.clone(), &uid, HawkKey(session.req_hmac_key.0), false, None) .await?; let verify_code = hex::encode(&SecretBytes::<16>::generate().0); - db.add_verify_code(&uid, &session_id, &verify_code).await?; + db.add_verify_code(&uid, &session.token_id, &verify_code).await?; // NOTE we send the email in this context rather than a spawn to signal // send errors to the client. mailer.send_account_verify(&uid, &data.email, &verify_code).await.map_err(|e| { @@ -161,9 +160,9 @@ pub(crate) struct Login { #[serde(deny_unknown_fields)] pub(crate) struct LoginResp { uid: UserID, - sessionToken: SecretBytes<32>, + sessionToken: SessionToken, #[serde(skip_serializing_if = "Option::is_none")] - keyFetchToken: Option<SecretBytes<32>>, + keyFetchToken: Option<KeyFetchToken>, // MISSING verificationMethod // MISSING verificationReason // NOTE this is the *account* verified status, not the session status. @@ -200,27 +199,25 @@ pub(crate) async fn login( return Err(auth::Error::IncorrectPassword); } - let session_token = SecretBytes::generate(); - let session = SessionCredentials::derive(&session_token); + let session_token = SessionToken::generate(); + let session = SessionCredentials::derive_from(&session_token); let key_fetch_token = if keys { - let key_fetch_token = SecretBytes::generate(); - let req = KeyFetchReq::from_token(&key_fetch_token); + let key_fetch_token = KeyFetchToken::generate(); + let req = KeyFetchReq::derive_from(&key_fetch_token); let wrapped = req.derive_resp().wrap_keys(&KeyBundle { ka: SecretBytes(user.ka.0), wrap_kb: stretched.decrypt_wwkb(&SecretBytes(user.wrapwrap_kb.0)), }); - db.add_key_fetch(KeyFetchID(req.token_id.0), &HawkKey(req.req_hmac_key.0), &wrapped) - .await?; + db.add_key_fetch(req.token_id, &HawkKey(req.req_hmac_key.0), &wrapped).await?; Some(key_fetch_token) } else { None }; - let session_id = SessionID(session.token_id.0); let verify_code = format!("{:06}", thread_rng().gen_range(0..=999999)); let auth_at = db .add_session( - session_id.clone(), + session.token_id.clone(), &uid, HawkKey(session.req_hmac_key.0), false, diff --git a/src/api/auth/oauth.rs b/src/api/auth/oauth.rs index 6d2f700..c159352 100644 --- a/src/api/auth/oauth.rs +++ b/src/api/auth/oauth.rs @@ -9,12 +9,13 @@ use subtle::ConstantTimeEq; use crate::api::auth::WithVerifiedFxaLogin; use crate::api::{Empty, EMPTY}; +use crate::crypto::SessionToken; use crate::db::DbConn; use crate::types::oauth::{Scope, ScopeSet}; use crate::{ api::{auth, serialize_dt}, auth::Authenticated, - crypto::{SecretBytes, SessionCredentials}, + crypto::SessionCredentials, types::{ HawkKey, OauthAccessToken, OauthAccessType, OauthAuthorization, OauthAuthorizationID, OauthRefreshToken, OauthToken, OauthTokenID, SessionID, UserID, @@ -278,7 +279,7 @@ pub(crate) struct TokenResp { refresh_token: Option<OauthToken>, // MISSING id_token #[serde(skip_serializing_if = "Option::is_none")] - session_token: Option<String>, + session_token: Option<SessionToken>, scope: ScopeSet, token_type: TokenType, expires_in: u32, @@ -391,18 +392,17 @@ async fn token_impl( let (refresh_token, session_token) = if access_type == Some(OauthAccessType::Offline) { let (session_token, session_id) = if scope.implies(&SESSION_SCOPE) { - let session_token = SecretBytes::generate(); - let session = SessionCredentials::derive(&session_token); - let session_id = SessionID(session.token_id.0); + let session_token = SessionToken::generate(); + let session = SessionCredentials::derive_from(&session_token); db.add_session( - session_id.clone(), + session.token_id.clone(), &user_id, HawkKey(session.req_hmac_key.0), true, None, ) .await?; - (Some(session_token.0), Some(SessionID(session.token_id.0))) + (Some(session_token), Some(session.token_id)) } else { (None, None) }; @@ -426,7 +426,7 @@ async fn token_impl( Ok(Json(TokenResp { access_token, refresh_token, - session_token: session_token.map(hex::encode), + session_token, scope: scope.remove(&SESSION_SCOPE), token_type: TokenType::Bearer, expires_in: ttl, diff --git a/src/api/auth/password.rs b/src/api/auth/password.rs index 56ad2a2..ae5bd6d 100644 --- a/src/api/auth/password.rs +++ b/src/api/auth/password.rs @@ -9,11 +9,14 @@ use validator::Validate; use crate::{ api::auth, auth::{AuthSource, Authenticated}, - crypto::{AccountResetReq, AuthPW, KeyBundle, KeyFetchReq, PasswordChangeReq, SecretBytes}, + crypto::{ + AccountResetReq, AccountResetToken, AuthPW, KeyBundle, KeyFetchReq, KeyFetchToken, + PasswordChangeReq, PasswordChangeToken, SecretBytes, + }, db::{Db, DbConn}, mailer::Mailer, types::{ - AccountResetID, HawkKey, KeyFetchID, OauthToken, PasswordChangeID, SecretKey, UserID, + HawkKey, OauthToken, PasswordChangeID, SecretKey, UserID, VerifyHash, }, }; @@ -34,8 +37,8 @@ pub(crate) struct ChangeStartReq { #[derive(Debug, Serialize)] #[allow(non_snake_case)] pub(crate) struct ChangeStartResp { - keyFetchToken: SecretBytes<32>, - passwordChangeToken: SecretBytes<32>, + keyFetchToken: KeyFetchToken, + passwordChangeToken: PasswordChangeToken, } #[post("/password/change/start", data = "<data>")] @@ -59,19 +62,19 @@ pub(crate) async fn change_start( return Err(auth::Error::IncorrectPassword); } - let change_token = SecretBytes::generate(); - let change_req = PasswordChangeReq::from_change_token(&change_token); - let key_fetch_token = SecretBytes::generate(); - let key_req = KeyFetchReq::from_token(&key_fetch_token); + let change_token = PasswordChangeToken::generate(); + let change_req = PasswordChangeReq::derive_from_change_token(&change_token); + let key_fetch_token = KeyFetchToken::generate(); + let key_req = KeyFetchReq::derive_from(&key_fetch_token); let wrapped = key_req.derive_resp().wrap_keys(&KeyBundle { ka: SecretBytes(user.ka.0), wrap_kb: stretched.decrypt_wwkb(&SecretBytes(user.wrapwrap_kb.0)), }); - db.add_key_fetch(KeyFetchID(key_req.token_id.0), &HawkKey(key_req.req_hmac_key.0), &wrapped) + db.add_key_fetch(key_req.token_id, &HawkKey(key_req.req_hmac_key.0), &wrapped) .await?; db.add_password_change( &uid, - &PasswordChangeID(change_req.token_id.0), + &change_req.token_id, &HawkKey(change_req.req_hmac_key.0), None, ) @@ -183,7 +186,7 @@ pub(crate) struct ForgotStartReq { #[derive(Debug, Serialize)] #[allow(non_snake_case)] pub(crate) struct ForgotStartResp { - passwordForgotToken: SecretBytes<32>, + passwordForgotToken: PasswordChangeToken, ttl: u32, codeLength: u32, tries: u32, @@ -207,11 +210,11 @@ pub(crate) async fn forgot_start( } let forgot_code = hex::encode(SecretBytes::<16>::generate().0); - let forgot_token = SecretBytes::generate(); - let forgot_req = PasswordChangeReq::from_forgot_token(&forgot_token); + let forgot_token = PasswordChangeToken::generate(); + let forgot_req = PasswordChangeReq::derive_from_forgot_token(&forgot_token); db.add_password_change( &uid, - &PasswordChangeID(forgot_req.token_id.0), + &forgot_req.token_id, &HawkKey(forgot_req.req_hmac_key.0), Some(&forgot_code), ) @@ -238,7 +241,7 @@ pub(crate) struct ForgotFinishReq { #[derive(Debug, Serialize)] #[allow(non_snake_case)] pub(crate) struct ForgotFinishResp { - accountResetToken: SecretBytes<32>, + accountResetToken: AccountResetToken, } #[post("/password/forgot/verify_code", data = "<data>")] @@ -250,11 +253,11 @@ pub(crate) async fn forgot_finish( return Err(auth::Error::InvalidVerificationCode); } - let reset_token = SecretBytes::generate(); - let reset_req = AccountResetReq::from_token(&reset_token); + let reset_token = AccountResetToken::generate(); + let reset_req = AccountResetReq::derive_from(&reset_token); db.add_account_reset( &data.context.0, - &AccountResetID(reset_req.token_id.0), + &reset_req.token_id, &HawkKey(reset_req.req_hmac_key.0), ) .await?; diff --git a/src/crypto.rs b/src/crypto.rs index 617bd7a..7fba9cd 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -15,15 +15,23 @@ use scrypt::scrypt; use serde::{Deserialize, Serialize}; use sha2::Sha256; +use crate::{ + serde::as_hex, + types::{AccountResetID, KeyFetchID, PasswordChangeID, SessionID}, +}; + const NAMESPACE: &[u8] = b"identity.mozilla.com/picl/v1/"; +pub fn random_bytes<const N: usize>() -> [u8; N] { + let mut result = [0; N]; + rand::rngs::OsRng.fill_bytes(&mut result); + result +} + #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(try_from = "String", into = "String")] pub struct SecretBytes<const N: usize>(pub [u8; N]); -#[derive(Clone, PartialEq, Eq)] -pub struct TokenID(pub [u8; 32]); - impl<const N: usize> Debug for SecretBytes<N> { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { fmt.write_fmt(format_args!("SecretBytes {{ raw: {} }}", hex::encode(&self.0))) @@ -73,7 +81,7 @@ mod from_hkdf { mod private { pub trait Seal {} impl<const N: usize> Seal for super::super::SecretBytes<N> {} - impl Seal for super::super::TokenID {} + impl<const N: usize> Seal for [u8; N] {} impl<L: Seal, R: Seal> Seal for (L, R) {} } @@ -90,11 +98,11 @@ mod from_hkdf { } } - impl FromHkdf for super::TokenID { - const SIZE: usize = 32; + impl<const N: usize> FromHkdf for [u8; N] { + const SIZE: usize = N; fn from_hkdf(bytes: &[u8]) -> Self { #[allow(clippy::expect_used)] - Self(bytes.try_into().expect("hkdf failed")) + bytes.try_into().expect("hkdf failed") } } @@ -166,29 +174,59 @@ impl StretchedPW { } } -pub struct SessionCredentials { - pub token_id: TokenID, +#[derive(Clone, Serialize, Deserialize, PartialEq, Eq)] +pub(crate) struct SessionToken(#[serde(with = "as_hex")] [u8; 32]); + +impl Debug for SessionToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("SessionToken").field(&hex::encode(self.0)).finish() + } +} + +impl SessionToken { + pub fn generate() -> Self { + Self(random_bytes()) + } +} + +pub(crate) struct SessionCredentials { + pub token_id: SessionID, pub req_hmac_key: SecretBytes<32>, } impl SessionCredentials { - pub fn derive(seed: &SecretBytes<32>) -> Self { + pub fn derive_from(seed: &SessionToken) -> Self { let (token_id, req_hmac_key) = from_hkdf(&seed.0, &[NAMESPACE, b"sessionToken"]); - Self { token_id, req_hmac_key } + Self { token_id: SessionID(token_id), req_hmac_key } + } +} + +#[derive(Clone, Serialize, Deserialize, PartialEq, Eq)] +pub(crate) struct KeyFetchToken(#[serde(with = "as_hex")] [u8; 32]); + +impl Debug for KeyFetchToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("KeyFetchToken").field(&hex::encode(self.0)).finish() + } +} + +impl KeyFetchToken { + pub fn generate() -> Self { + Self(random_bytes()) } } -pub struct KeyFetchReq { - pub token_id: TokenID, +pub(crate) struct KeyFetchReq { + pub token_id: KeyFetchID, pub req_hmac_key: SecretBytes<32>, key_request_key: SecretBytes<32>, } impl KeyFetchReq { - pub fn from_token(key_fetch_token: &SecretBytes<32>) -> Self { + pub fn derive_from(key_fetch_token: &KeyFetchToken) -> Self { let (token_id, (req_hmac_key, key_request_key)) = from_hkdf(&key_fetch_token.0, &[NAMESPACE, b"keyFetchToken"]); - Self { token_id, req_hmac_key, key_request_key } + Self { token_id: KeyFetchID(token_id), req_hmac_key, key_request_key } } pub fn derive_resp(&self) -> KeyFetchResp { @@ -243,32 +281,62 @@ impl WrappedKeyBundle { } } -pub struct PasswordChangeReq { - pub token_id: TokenID, +#[derive(Clone, Serialize, Deserialize, PartialEq, Eq)] +pub(crate) struct PasswordChangeToken(#[serde(with = "as_hex")] [u8; 32]); + +impl Debug for PasswordChangeToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("PasswordChangeToken").field(&hex::encode(self.0)).finish() + } +} + +impl PasswordChangeToken { + pub fn generate() -> Self { + Self(random_bytes()) + } +} + +pub(crate) struct PasswordChangeReq { + pub token_id: PasswordChangeID, pub req_hmac_key: SecretBytes<32>, } impl PasswordChangeReq { - pub fn from_change_token(token: &SecretBytes<32>) -> Self { + pub fn derive_from_change_token(token: &PasswordChangeToken) -> Self { let (token_id, req_hmac_key) = from_hkdf(&token.0, &[NAMESPACE, b"passwordChangeToken"]); - Self { token_id, req_hmac_key } + Self { token_id: PasswordChangeID(token_id), req_hmac_key } } - pub fn from_forgot_token(token: &SecretBytes<32>) -> Self { + pub fn derive_from_forgot_token(token: &PasswordChangeToken) -> Self { let (token_id, req_hmac_key) = from_hkdf(&token.0, &[NAMESPACE, b"passwordForgotToken"]); - Self { token_id, req_hmac_key } + Self { token_id: PasswordChangeID(token_id), req_hmac_key } + } +} + +#[derive(Clone, Serialize, Deserialize, PartialEq, Eq)] +pub(crate) struct AccountResetToken(#[serde(with = "as_hex")] [u8; 32]); + +impl Debug for AccountResetToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("AccountResetToken").field(&hex::encode(self.0)).finish() + } +} + +impl AccountResetToken { + pub fn generate() -> Self { + Self(random_bytes()) } } -pub struct AccountResetReq { - pub token_id: TokenID, +pub(crate) struct AccountResetReq { + pub token_id: AccountResetID, pub req_hmac_key: SecretBytes<32>, } impl AccountResetReq { - pub fn from_token(token: &SecretBytes<32>) -> Self { + pub fn derive_from(token: &AccountResetToken) -> Self { let (token_id, req_hmac_key) = from_hkdf(&token.0, &[NAMESPACE, b"accountResetToken"]); - Self { token_id, req_hmac_key } + Self { token_id: AccountResetID(token_id), req_hmac_key } } } @@ -278,7 +346,8 @@ mod test { use password_hash::{Output, SaltString}; use crate::crypto::{ - AccountResetReq, KeyBundle, KeyFetchReq, PasswordChangeReq, SessionCredentials, + AccountResetReq, AccountResetToken, KeyBundle, KeyFetchReq, KeyFetchToken, + PasswordChangeReq, PasswordChangeToken, SessionCredentials, SessionToken, }; use super::{AuthPW, SecretBytes}; @@ -291,7 +360,7 @@ mod test { #[test] fn test_derive_session() { - let creds = SessionCredentials::derive(&SecretBytes(hex!( + let creds = SessionCredentials::derive_from(&SessionToken(hex!( "a0a1a2a3a4a5a6a7 a8a9aaabacadaeaf b0b1b2b3b4b5b6b7 b8b9babbbcbdbebf" ))); assert_eq!( @@ -306,9 +375,9 @@ mod test { #[test] fn test_key_fetch() { - let key_fetch = KeyFetchReq::from_token(&shex!( + let key_fetch = KeyFetchReq::derive_from(&KeyFetchToken(hex!( "8081828384858687 88898a8b8c8d8e8f 9091929394959697 98999a9b9c9d9e9f" - )); + ))); assert_eq!( key_fetch.token_id.0, hex!("3d0a7c02a15a62a2882f76e39b6494b500c022a8816e048625a495718998ba60") @@ -404,9 +473,9 @@ mod test { #[test] fn test_password_change() { - let req = PasswordChangeReq::from_change_token(&shex!( + let req = PasswordChangeReq::derive_from_change_token(&PasswordChangeToken(hex!( "0000000000000000 0000000000000000 0000000000000000 0000000000000000" - )); + ))); assert_eq!( req.token_id.0, hex!("5a9f93f66c26fd1c 1ea9826fafc422e9 4b9c9f833cd2bfa5 da18c8d3317224aa") @@ -419,9 +488,9 @@ mod test { #[test] fn test_password_forgot() { - let req = PasswordChangeReq::from_forgot_token(&shex!( + let req = PasswordChangeReq::derive_from_forgot_token(&PasswordChangeToken(hex!( "0000000000000000 0000000000000000 0000000000000000 0000000000000000" - )); + ))); assert_eq!( req.token_id.0, hex!("570e79050fd157a9 b8e7d7d6f88a3f67 e36207c5dfabe7d8 a80994502a624e07") @@ -434,9 +503,9 @@ mod test { #[test] fn test_account_reset() { - let req = AccountResetReq::from_token(&shex!( + let req = AccountResetReq::derive_from(&AccountResetToken(hex!( "0000000000000000 0000000000000000 0000000000000000 0000000000000000" - )); + ))); assert_eq!( req.token_id.0, hex!("8ade842449ab0285 e7b22de9d428cd5b 3c38ea0aa78e2956 a6a69ec66818d864") @@ -44,6 +44,7 @@ pub mod db; mod js; mod mailer; mod push; +pub(crate) mod serde; mod types; fn default_push_ttl() -> std::time::Duration { @@ -54,7 +55,7 @@ fn default_task_interval() -> std::time::Duration { std::time::Duration::from_secs(5 * 60) } -#[derive(Debug, serde::Deserialize)] +#[derive(Debug, ::serde::Deserialize)] struct Config { database_url: String, location: Absolute<'static>, diff --git a/src/serde.rs b/src/serde.rs new file mode 100644 index 0000000..76676d7 --- /dev/null +++ b/src/serde.rs @@ -0,0 +1,21 @@ +pub mod as_hex { + use serde::{de, Deserialize, Deserializer, Serializer}; + + pub fn serialize<const N: usize, S: Serializer>( + b: &[u8; N], + ser: S, + ) -> Result<S::Ok, S::Error> { + ser.serialize_str(&hex::encode(b)) + } + + pub fn deserialize<'de, const N: usize, D: Deserializer<'de>>( + des: D, + ) -> Result<[u8; N], D::Error> { + let raw = <String as Deserialize>::deserialize(des)?; + let mut result = [0; N]; + hex::decode_to_slice(&raw, &mut result).map_err(|_| { + de::Error::invalid_value(de::Unexpected::Other("non-hex string"), &"a hex string") + })?; + Ok(result) + } +} |