summaryrefslogtreecommitdiff
path: root/src/types.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/types.rs')
-rw-r--r--src/types.rs436
1 files changed, 436 insertions, 0 deletions
diff --git a/src/types.rs b/src/types.rs
new file mode 100644
index 0000000..c0c5dfe
--- /dev/null
+++ b/src/types.rs
@@ -0,0 +1,436 @@
+use crate::crypto::SecretBytes;
+use chrono::{DateTime, Utc};
+use password_hash::{rand_core::OsRng, Output, SaltString};
+use rand::RngCore;
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use sha2::{Digest, Sha256};
+use sqlx::{
+ postgres::{PgArgumentBuffer, PgTypeInfo, PgValueRef},
+ Decode, Encode, Postgres, Type,
+};
+use std::{
+ collections::HashMap,
+ fmt::{Debug, Display},
+ ops::Deref,
+ str::FromStr,
+};
+
+use self::oauth::ScopeSet;
+
+pub(crate) mod oauth;
+
+macro_rules! array_type {
+ (
+ $( #[ $attr:meta ] )*
+ $name:ident($inner:ty) as $sql_name:ident {
+ $( $body:tt )*
+ }
+ ) => {
+ $( #[ $attr ] )*
+ pub(crate) struct $name(pub(crate) $inner);
+
+ impl $name {
+ $( $body )*
+ }
+
+ impl Type<Postgres> for $name {
+ fn type_info() -> PgTypeInfo {
+ PgTypeInfo::with_name(stringify!($sql_name))
+ }
+ }
+
+ impl Encode<'_, Postgres> for $name {
+ fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> sqlx::encode::IsNull {
+ let raw = self.0.iter().map(Self::encode_elem).collect::<Vec<_>>();
+ Encode::<'_, Postgres>::encode_by_ref(&raw, buf)
+ }
+ }
+
+ impl Decode<'_, Postgres> for $name {
+ fn decode(value: PgValueRef) -> Result<Self, sqlx::error::BoxDynError> {
+ Ok(Self::decode_elems(Decode::<'_, Postgres>::decode(value)?)?)
+ }
+ }
+ }
+}
+
+macro_rules! bytea_types {
+ () => {};
+ (
+ #[simple_array]
+ struct $name:ident($inner:ty) as $sql_name:ident;
+
+ $( $rest:tt )*
+ ) => {
+ bytea_types!{
+ #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+ #[serde(try_from = "String", into = "String")]
+ struct $name($inner) as $sql_name {
+ fn decode(v) -> _ { &v.0[..] }
+ fn encode(v) -> _ { v }
+ }
+
+ impl FromStr for $name {}
+ impl ToString for $name {}
+ impl Debug for $name {}
+
+ $( $rest )*
+ }
+ };
+ (
+ $( #[ $attr:meta ] )*
+ struct $name:ident($inner:ty) as $sql_name:ident {
+ $( fn arbitrary($a:ident) -> _ { $ae:expr } )?
+ fn decode($d:ident) -> _ { $de:expr }
+ fn encode($e:ident) -> _ { $ee:expr }
+
+ $( $impls:tt )*
+ }
+
+ $( $rest:tt )*
+ ) => {
+ $( #[ $attr ] )*
+ pub(crate) struct $name(pub(crate) $inner);
+
+ impl $name {
+ $( $impls )*
+ }
+
+ impl Type<Postgres> for $name {
+ fn type_info() -> PgTypeInfo {
+ PgTypeInfo::with_name(stringify!($sql_name))
+ }
+ }
+
+ impl Encode<'_, Postgres> for $name {
+ fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> sqlx::encode::IsNull {
+ let $d = self;
+ <&[u8] as Encode<'_, Postgres>>::encode_by_ref(&$de, buf)
+ }
+ }
+
+ impl Decode<'_, Postgres> for $name {
+ fn decode(value: PgValueRef) -> Result<Self, sqlx::error::BoxDynError> {
+ let $e = <&[u8] as Decode<'_, Postgres>>::decode(value)?.try_into()?;
+ Ok($name($ee))
+ }
+ }
+
+ bytea_types!{ $( $rest )* }
+ };
+ ( impl ToString for $name:ident {} $( $rest:tt )* ) => {
+ impl Display for $name {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
+ f.write_str(&hex::encode(&self.0))
+ }
+ }
+ impl From<$name> for String {
+ fn from(s: $name) -> String {
+ format!("{}", s)
+ }
+ }
+ bytea_types!{ $( $rest )* }
+ };
+ ( impl Debug for $name:ident {} $( $rest:tt )* ) => {
+ impl Debug for $name {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_tuple(stringify!($name)).field(&self.to_string()).finish()
+ }
+ }
+ bytea_types!{ $( $rest )* }
+ };
+ ( impl FromStr for $name:ident {} $( $rest:tt )* ) => {
+ impl FromStr for $name {
+ type Err = anyhow::Error;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ Ok(Self(hex::decode(s)?.as_slice().try_into()?))
+ }
+ }
+ impl TryFrom<String> for $name {
+ type Error = anyhow::Error;
+
+ fn try_from(s: String) -> Result<Self, Self::Error> {
+ s.parse()
+ }
+ }
+ bytea_types!{ $( $rest )* }
+ }
+}
+
+//
+//
+//
+
+bytea_types! {
+ #[derive(Clone, Debug, PartialEq, Eq)]
+ struct HawkKey(SecretBytes<32>) as hawk_key {
+ fn arbitrary(a) -> _ { SecretBytes(a) }
+ fn decode(v) -> _ { v.0.0.as_ref() }
+ fn encode(v) -> _ { SecretBytes(v) }
+ }
+
+ #[simple_array]
+ struct SessionID([u8; 32]) as session_id;
+
+ #[simple_array]
+ struct DeviceID([u8; 16]) as device_id;
+
+ #[simple_array]
+ struct UserID([u8; 16]) as user_id;
+
+ #[simple_array]
+ struct KeyFetchID([u8; 32]) as key_fetch_id;
+
+ #[simple_array]
+ struct OauthTokenID([u8; 32]) as oauth_token_id;
+
+ #[simple_array]
+ struct OauthAuthorizationID([u8; 32]) as oauth_auth_id;
+
+ #[simple_array]
+ struct PasswordChangeID([u8; 32]) as password_change_id;
+
+ #[simple_array]
+ struct AccountResetID([u8; 32]) as account_reset_id;
+
+ #[simple_array]
+ struct AvatarID([u8; 16]) as avatar_id;
+
+ #[derive(Clone, Debug, PartialEq, Eq)]
+ struct SecretKey(SecretBytes<32>) as secret_key {
+ fn arbitrary(a) -> _ { SecretBytes(a) }
+ fn decode(v) -> _ { v.0.0.as_ref() }
+ fn encode(v) -> _ { SecretBytes(v) }
+ }
+
+ #[derive(Clone, Debug, PartialEq, Eq)]
+ struct VerifyHash(Output) as verify_hash {
+ fn arbitrary(a) -> _ { Output::new(<[u8; 32]>::as_ref(&a)).unwrap() }
+ fn decode(v) -> _ { v.0.as_ref() }
+ fn encode(v) -> _ { v }
+ }
+}
+
+impl DeviceID {
+ pub fn random() -> Self {
+ let mut result = Self([0; 16]);
+ OsRng.fill_bytes(&mut result.0);
+ result
+ }
+}
+
+impl UserID {
+ pub fn random() -> Self {
+ let mut result = Self([0; 16]);
+ OsRng.fill_bytes(&mut result.0);
+ result
+ }
+}
+
+impl OauthAuthorizationID {
+ pub fn random() -> Self {
+ let mut result = Self([0; 32]);
+ OsRng.fill_bytes(&mut result.0);
+ result
+ }
+}
+
+#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(try_from = "String", into = "String")]
+pub(crate) struct OauthToken([u8; 32]);
+
+impl OauthToken {
+ pub fn random() -> Self {
+ let mut result = Self([0; 32]);
+ OsRng.fill_bytes(&mut result.0);
+ result
+ }
+
+ pub fn hash(&self) -> OauthTokenID {
+ let mut sha = Sha256::new();
+ sha.update(&self.0);
+ OauthTokenID(*sha.finalize().as_ref())
+ }
+}
+
+bytea_types! {
+ impl Debug for OauthToken {}
+ impl FromStr for OauthToken {}
+ impl ToString for OauthToken {}
+}
+
+#[derive(Debug, Deserialize, PartialEq, Eq, Type)]
+#[sqlx(type_name = "oauth_access_type", rename_all = "lowercase")]
+#[serde(rename_all = "lowercase")]
+pub enum OauthAccessType {
+ Online,
+ Offline,
+}
+
+#[derive(Debug)]
+pub(crate) struct UserSession {
+ pub(crate) uid: UserID,
+ pub(crate) req_hmac_key: HawkKey,
+ pub(crate) device_id: Option<DeviceID>,
+ pub(crate) created_at: DateTime<Utc>,
+ pub(crate) verified: bool,
+ pub(crate) verify_code: Option<String>,
+}
+
+#[derive(Clone, Debug)]
+pub(crate) struct DeviceCommand {
+ pub(crate) index: i64,
+ pub(crate) command: String,
+ pub(crate) payload: Value,
+ #[allow(dead_code)]
+ pub(crate) expires: DateTime<Utc>,
+ // NOTE this is a device ID, but we don't link it to the actual sender device
+ // because removing a device would also remove its queued commands. this mirrors
+ // what fxa does.
+ pub(crate) sender: Option<String>,
+}
+
+#[derive(Clone, Debug, PartialEq, sqlx::Type)]
+#[sqlx(type_name = "device_push_info")]
+pub(crate) struct DevicePush {
+ pub(crate) callback: String,
+ pub(crate) public_key: String,
+ pub(crate) auth_key: String,
+}
+
+#[derive(Clone, Debug, PartialEq, sqlx::Type)]
+#[sqlx(type_name = "device_command")]
+struct DeviceCommandsEntry {
+ name: String,
+ body: String,
+}
+
+array_type! {
+ #[derive(Clone, Debug, PartialEq)]
+ DeviceCommands(HashMap<String, String>) as _device_command {
+ fn encode_elem(e: (&String, &String)) -> DeviceCommandsEntry {
+ DeviceCommandsEntry { name: e.0.clone(), body: e.1.clone() }
+ }
+ fn decode_elems(e: Vec<DeviceCommandsEntry>) -> anyhow::Result<Self> {
+ Ok(Self(e.into_iter().map(|e| (e.name, e.body)).collect()))
+ }
+
+ pub(crate) fn into_map(self) -> HashMap<String, String> {
+ self.0
+ }
+ }
+}
+
+impl Deref for DeviceCommands {
+ type Target = HashMap<String, String>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+#[derive(Clone, Debug)]
+pub(crate) struct Device {
+ pub(crate) device_id: DeviceID,
+ // taken from session, otherwise UNIX_EPOCH
+ pub(crate) last_active: DateTime<Utc>,
+ pub(crate) name: String,
+ pub(crate) type_: String,
+ pub(crate) push: Option<DevicePush>,
+ pub(crate) available_commands: DeviceCommands,
+ pub(crate) push_expired: bool,
+ // actually a str->str map, but we treat it as opaque for simplicity.
+ // writing a HashMap<String, String> to the db through sqlx is an immense pain,
+ // and we don't care about the value anyway—it only has to exist for fenix.
+ pub(crate) location: Value,
+}
+
+#[derive(Clone, Debug)]
+pub(crate) struct DeviceUpdate<'a> {
+ pub(crate) name: Option<&'a str>,
+ pub(crate) type_: Option<&'a str>,
+ pub(crate) push: Option<DevicePush>,
+ pub(crate) available_commands: Option<DeviceCommands>,
+ pub(crate) location: Option<Value>,
+}
+
+#[derive(Debug, sqlx::Type)]
+#[sqlx(type_name = "oauth_token_kind", rename_all = "lowercase")]
+pub(crate) enum OauthTokenKind {
+ Access,
+ Refresh,
+}
+
+#[derive(Debug)]
+pub(crate) struct OauthAccessToken {
+ pub(crate) user_id: UserID,
+ pub(crate) client_id: String,
+ pub(crate) scope: ScopeSet,
+ pub(crate) parent_refresh: Option<OauthTokenID>,
+ pub(crate) parent_session: Option<SessionID>,
+ pub(crate) expires_at: DateTime<Utc>,
+}
+
+#[derive(Debug)]
+pub(crate) struct OauthRefreshToken {
+ pub(crate) user_id: UserID,
+ pub(crate) client_id: String,
+ pub(crate) scope: ScopeSet,
+ pub(crate) session_id: Option<SessionID>,
+}
+
+#[derive(Debug)]
+pub(crate) struct OauthAuthorization {
+ pub(crate) user_id: UserID,
+ pub(crate) client_id: String,
+ pub(crate) scope: ScopeSet,
+ pub(crate) access_type: OauthAccessType,
+ pub(crate) code_challenge: String,
+ pub(crate) keys_jwe: Option<String>,
+ pub(crate) auth_at: DateTime<Utc>,
+}
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(Clone))]
+pub(crate) struct User {
+ pub(crate) auth_salt: SaltString,
+ pub(crate) email: String,
+ pub(crate) display_name: Option<String>,
+ pub(crate) ka: SecretKey,
+ pub(crate) wrapwrap_kb: SecretKey,
+ pub(crate) verify_hash: VerifyHash,
+ pub(crate) verified: bool,
+}
+
+// MISSING user secondary email addresses
+
+#[derive(Debug)]
+pub(crate) struct Avatar {
+ pub(crate) id: AvatarID,
+ pub(crate) data: Vec<u8>,
+ pub(crate) content_type: String,
+}
+
+#[derive(Debug)]
+pub(crate) struct AttachedClient {
+ pub(crate) client_id: Option<String>,
+ pub(crate) device_id: Option<DeviceID>,
+ pub(crate) session_token_id: Option<SessionID>,
+ pub(crate) refresh_token_id: Option<OauthTokenID>,
+ pub(crate) device_type: Option<String>,
+ pub(crate) name: Option<String>,
+ pub(crate) created_time: Option<DateTime<Utc>>,
+ pub(crate) last_access_time: Option<DateTime<Utc>>,
+ pub(crate) scope: Option<String>,
+}
+
+#[derive(Debug)]
+pub(crate) struct VerifyCode {
+ #[allow(dead_code)]
+ pub(crate) user_id: UserID,
+ pub(crate) session_id: Option<SessionID>,
+ #[allow(dead_code)]
+ pub(crate) code: String,
+}