diff options
Diffstat (limited to 'src/types.rs')
-rw-r--r-- | src/types.rs | 436 |
1 files changed, 436 insertions, 0 deletions
diff --git a/src/types.rs b/src/types.rs new file mode 100644 index 0000000..c0c5dfe --- /dev/null +++ b/src/types.rs @@ -0,0 +1,436 @@ +use crate::crypto::SecretBytes; +use chrono::{DateTime, Utc}; +use password_hash::{rand_core::OsRng, Output, SaltString}; +use rand::RngCore; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use sha2::{Digest, Sha256}; +use sqlx::{ + postgres::{PgArgumentBuffer, PgTypeInfo, PgValueRef}, + Decode, Encode, Postgres, Type, +}; +use std::{ + collections::HashMap, + fmt::{Debug, Display}, + ops::Deref, + str::FromStr, +}; + +use self::oauth::ScopeSet; + +pub(crate) mod oauth; + +macro_rules! array_type { + ( + $( #[ $attr:meta ] )* + $name:ident($inner:ty) as $sql_name:ident { + $( $body:tt )* + } + ) => { + $( #[ $attr ] )* + pub(crate) struct $name(pub(crate) $inner); + + impl $name { + $( $body )* + } + + impl Type<Postgres> for $name { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name(stringify!($sql_name)) + } + } + + impl Encode<'_, Postgres> for $name { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> sqlx::encode::IsNull { + let raw = self.0.iter().map(Self::encode_elem).collect::<Vec<_>>(); + Encode::<'_, Postgres>::encode_by_ref(&raw, buf) + } + } + + impl Decode<'_, Postgres> for $name { + fn decode(value: PgValueRef) -> Result<Self, sqlx::error::BoxDynError> { + Ok(Self::decode_elems(Decode::<'_, Postgres>::decode(value)?)?) + } + } + } +} + +macro_rules! bytea_types { + () => {}; + ( + #[simple_array] + struct $name:ident($inner:ty) as $sql_name:ident; + + $( $rest:tt )* + ) => { + bytea_types!{ + #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] + #[serde(try_from = "String", into = "String")] + struct $name($inner) as $sql_name { + fn decode(v) -> _ { &v.0[..] } + fn encode(v) -> _ { v } + } + + impl FromStr for $name {} + impl ToString for $name {} + impl Debug for $name {} + + $( $rest )* + } + }; + ( + $( #[ $attr:meta ] )* + struct $name:ident($inner:ty) as $sql_name:ident { + $( fn arbitrary($a:ident) -> _ { $ae:expr } )? + fn decode($d:ident) -> _ { $de:expr } + fn encode($e:ident) -> _ { $ee:expr } + + $( $impls:tt )* + } + + $( $rest:tt )* + ) => { + $( #[ $attr ] )* + pub(crate) struct $name(pub(crate) $inner); + + impl $name { + $( $impls )* + } + + impl Type<Postgres> for $name { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name(stringify!($sql_name)) + } + } + + impl Encode<'_, Postgres> for $name { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> sqlx::encode::IsNull { + let $d = self; + <&[u8] as Encode<'_, Postgres>>::encode_by_ref(&$de, buf) + } + } + + impl Decode<'_, Postgres> for $name { + fn decode(value: PgValueRef) -> Result<Self, sqlx::error::BoxDynError> { + let $e = <&[u8] as Decode<'_, Postgres>>::decode(value)?.try_into()?; + Ok($name($ee)) + } + } + + bytea_types!{ $( $rest )* } + }; + ( impl ToString for $name:ident {} $( $rest:tt )* ) => { + impl Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + f.write_str(&hex::encode(&self.0)) + } + } + impl From<$name> for String { + fn from(s: $name) -> String { + format!("{}", s) + } + } + bytea_types!{ $( $rest )* } + }; + ( impl Debug for $name:ident {} $( $rest:tt )* ) => { + impl Debug for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple(stringify!($name)).field(&self.to_string()).finish() + } + } + bytea_types!{ $( $rest )* } + }; + ( impl FromStr for $name:ident {} $( $rest:tt )* ) => { + impl FromStr for $name { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + Ok(Self(hex::decode(s)?.as_slice().try_into()?)) + } + } + impl TryFrom<String> for $name { + type Error = anyhow::Error; + + fn try_from(s: String) -> Result<Self, Self::Error> { + s.parse() + } + } + bytea_types!{ $( $rest )* } + } +} + +// +// +// + +bytea_types! { + #[derive(Clone, Debug, PartialEq, Eq)] + struct HawkKey(SecretBytes<32>) as hawk_key { + fn arbitrary(a) -> _ { SecretBytes(a) } + fn decode(v) -> _ { v.0.0.as_ref() } + fn encode(v) -> _ { SecretBytes(v) } + } + + #[simple_array] + struct SessionID([u8; 32]) as session_id; + + #[simple_array] + struct DeviceID([u8; 16]) as device_id; + + #[simple_array] + struct UserID([u8; 16]) as user_id; + + #[simple_array] + struct KeyFetchID([u8; 32]) as key_fetch_id; + + #[simple_array] + struct OauthTokenID([u8; 32]) as oauth_token_id; + + #[simple_array] + struct OauthAuthorizationID([u8; 32]) as oauth_auth_id; + + #[simple_array] + struct PasswordChangeID([u8; 32]) as password_change_id; + + #[simple_array] + struct AccountResetID([u8; 32]) as account_reset_id; + + #[simple_array] + struct AvatarID([u8; 16]) as avatar_id; + + #[derive(Clone, Debug, PartialEq, Eq)] + struct SecretKey(SecretBytes<32>) as secret_key { + fn arbitrary(a) -> _ { SecretBytes(a) } + fn decode(v) -> _ { v.0.0.as_ref() } + fn encode(v) -> _ { SecretBytes(v) } + } + + #[derive(Clone, Debug, PartialEq, Eq)] + struct VerifyHash(Output) as verify_hash { + fn arbitrary(a) -> _ { Output::new(<[u8; 32]>::as_ref(&a)).unwrap() } + fn decode(v) -> _ { v.0.as_ref() } + fn encode(v) -> _ { v } + } +} + +impl DeviceID { + pub fn random() -> Self { + let mut result = Self([0; 16]); + OsRng.fill_bytes(&mut result.0); + result + } +} + +impl UserID { + pub fn random() -> Self { + let mut result = Self([0; 16]); + OsRng.fill_bytes(&mut result.0); + result + } +} + +impl OauthAuthorizationID { + pub fn random() -> Self { + let mut result = Self([0; 32]); + OsRng.fill_bytes(&mut result.0); + result + } +} + +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(try_from = "String", into = "String")] +pub(crate) struct OauthToken([u8; 32]); + +impl OauthToken { + pub fn random() -> Self { + let mut result = Self([0; 32]); + OsRng.fill_bytes(&mut result.0); + result + } + + pub fn hash(&self) -> OauthTokenID { + let mut sha = Sha256::new(); + sha.update(&self.0); + OauthTokenID(*sha.finalize().as_ref()) + } +} + +bytea_types! { + impl Debug for OauthToken {} + impl FromStr for OauthToken {} + impl ToString for OauthToken {} +} + +#[derive(Debug, Deserialize, PartialEq, Eq, Type)] +#[sqlx(type_name = "oauth_access_type", rename_all = "lowercase")] +#[serde(rename_all = "lowercase")] +pub enum OauthAccessType { + Online, + Offline, +} + +#[derive(Debug)] +pub(crate) struct UserSession { + pub(crate) uid: UserID, + pub(crate) req_hmac_key: HawkKey, + pub(crate) device_id: Option<DeviceID>, + pub(crate) created_at: DateTime<Utc>, + pub(crate) verified: bool, + pub(crate) verify_code: Option<String>, +} + +#[derive(Clone, Debug)] +pub(crate) struct DeviceCommand { + pub(crate) index: i64, + pub(crate) command: String, + pub(crate) payload: Value, + #[allow(dead_code)] + pub(crate) expires: DateTime<Utc>, + // NOTE this is a device ID, but we don't link it to the actual sender device + // because removing a device would also remove its queued commands. this mirrors + // what fxa does. + pub(crate) sender: Option<String>, +} + +#[derive(Clone, Debug, PartialEq, sqlx::Type)] +#[sqlx(type_name = "device_push_info")] +pub(crate) struct DevicePush { + pub(crate) callback: String, + pub(crate) public_key: String, + pub(crate) auth_key: String, +} + +#[derive(Clone, Debug, PartialEq, sqlx::Type)] +#[sqlx(type_name = "device_command")] +struct DeviceCommandsEntry { + name: String, + body: String, +} + +array_type! { + #[derive(Clone, Debug, PartialEq)] + DeviceCommands(HashMap<String, String>) as _device_command { + fn encode_elem(e: (&String, &String)) -> DeviceCommandsEntry { + DeviceCommandsEntry { name: e.0.clone(), body: e.1.clone() } + } + fn decode_elems(e: Vec<DeviceCommandsEntry>) -> anyhow::Result<Self> { + Ok(Self(e.into_iter().map(|e| (e.name, e.body)).collect())) + } + + pub(crate) fn into_map(self) -> HashMap<String, String> { + self.0 + } + } +} + +impl Deref for DeviceCommands { + type Target = HashMap<String, String>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Clone, Debug)] +pub(crate) struct Device { + pub(crate) device_id: DeviceID, + // taken from session, otherwise UNIX_EPOCH + pub(crate) last_active: DateTime<Utc>, + pub(crate) name: String, + pub(crate) type_: String, + pub(crate) push: Option<DevicePush>, + pub(crate) available_commands: DeviceCommands, + pub(crate) push_expired: bool, + // actually a str->str map, but we treat it as opaque for simplicity. + // writing a HashMap<String, String> to the db through sqlx is an immense pain, + // and we don't care about the value anyway—it only has to exist for fenix. + pub(crate) location: Value, +} + +#[derive(Clone, Debug)] +pub(crate) struct DeviceUpdate<'a> { + pub(crate) name: Option<&'a str>, + pub(crate) type_: Option<&'a str>, + pub(crate) push: Option<DevicePush>, + pub(crate) available_commands: Option<DeviceCommands>, + pub(crate) location: Option<Value>, +} + +#[derive(Debug, sqlx::Type)] +#[sqlx(type_name = "oauth_token_kind", rename_all = "lowercase")] +pub(crate) enum OauthTokenKind { + Access, + Refresh, +} + +#[derive(Debug)] +pub(crate) struct OauthAccessToken { + pub(crate) user_id: UserID, + pub(crate) client_id: String, + pub(crate) scope: ScopeSet, + pub(crate) parent_refresh: Option<OauthTokenID>, + pub(crate) parent_session: Option<SessionID>, + pub(crate) expires_at: DateTime<Utc>, +} + +#[derive(Debug)] +pub(crate) struct OauthRefreshToken { + pub(crate) user_id: UserID, + pub(crate) client_id: String, + pub(crate) scope: ScopeSet, + pub(crate) session_id: Option<SessionID>, +} + +#[derive(Debug)] +pub(crate) struct OauthAuthorization { + pub(crate) user_id: UserID, + pub(crate) client_id: String, + pub(crate) scope: ScopeSet, + pub(crate) access_type: OauthAccessType, + pub(crate) code_challenge: String, + pub(crate) keys_jwe: Option<String>, + pub(crate) auth_at: DateTime<Utc>, +} + +#[derive(Debug)] +#[cfg_attr(test, derive(Clone))] +pub(crate) struct User { + pub(crate) auth_salt: SaltString, + pub(crate) email: String, + pub(crate) display_name: Option<String>, + pub(crate) ka: SecretKey, + pub(crate) wrapwrap_kb: SecretKey, + pub(crate) verify_hash: VerifyHash, + pub(crate) verified: bool, +} + +// MISSING user secondary email addresses + +#[derive(Debug)] +pub(crate) struct Avatar { + pub(crate) id: AvatarID, + pub(crate) data: Vec<u8>, + pub(crate) content_type: String, +} + +#[derive(Debug)] +pub(crate) struct AttachedClient { + pub(crate) client_id: Option<String>, + pub(crate) device_id: Option<DeviceID>, + pub(crate) session_token_id: Option<SessionID>, + pub(crate) refresh_token_id: Option<OauthTokenID>, + pub(crate) device_type: Option<String>, + pub(crate) name: Option<String>, + pub(crate) created_time: Option<DateTime<Utc>>, + pub(crate) last_access_time: Option<DateTime<Utc>>, + pub(crate) scope: Option<String>, +} + +#[derive(Debug)] +pub(crate) struct VerifyCode { + #[allow(dead_code)] + pub(crate) user_id: UserID, + pub(crate) session_id: Option<SessionID>, + #[allow(dead_code)] + pub(crate) code: String, +} |