summaryrefslogtreecommitdiff
path: root/src/db
diff options
context:
space:
mode:
Diffstat (limited to 'src/db')
-rw-r--r--src/db/mod.rs1026
1 files changed, 1026 insertions, 0 deletions
diff --git a/src/db/mod.rs b/src/db/mod.rs
new file mode 100644
index 0000000..040507f
--- /dev/null
+++ b/src/db/mod.rs
@@ -0,0 +1,1026 @@
+use std::{error::Error, mem::replace, sync::Arc};
+
+use anyhow::Result;
+use chrono::{DateTime, Duration, Utc};
+use password_hash::SaltString;
+use rocket::{
+ fairing::{self, Fairing},
+ futures::lock::{MappedMutexGuard, Mutex, MutexGuard},
+ http::Status,
+ request::{FromRequest, Outcome},
+ Request, Response, Sentinel,
+};
+use serde_json::Value;
+use sqlx::{query, query_as, query_scalar, PgPool, Postgres, Transaction};
+
+use crate::{
+ crypto::WrappedKeyBundle,
+ types::{
+ oauth::ScopeSet, AccountResetID, AttachedClient, Avatar, AvatarID, Device, DeviceCommand,
+ DeviceCommands, DeviceID, DevicePush, DeviceUpdate, HawkKey, KeyFetchID, OauthAccessToken,
+ OauthAccessType, OauthAuthorization, OauthAuthorizationID, OauthRefreshToken, OauthTokenID,
+ PasswordChangeID, SecretKey, SessionID, User, UserID, UserSession, VerifyCode, VerifyHash,
+ },
+};
+
+// we implement a completely custom db type set instead of using rocket_db_pools for two reasons:
+// 1. rocket_db_pools doesn't support sqlx 0.6
+// 2. we want one transaction per *request*, including all guards
+
+#[derive(Clone)]
+pub struct Db {
+ db: Arc<PgPool>,
+}
+
+impl Db {
+ pub async fn connect(db: &str) -> Result<Self> {
+ Ok(Self { db: Arc::new(PgPool::connect(db).await?) })
+ }
+
+ pub async fn begin(&self) -> Result<DbConn> {
+ Ok(DbConn(Mutex::new(ConnState::Capable(self.clone()))))
+ }
+
+ pub async fn migrate(&self) -> Result<()> {
+ sqlx::migrate!().run(&*self.db).await?;
+ Ok(())
+ }
+}
+
+struct ActiveConn {
+ tx: Transaction<'static, Postgres>,
+ always_commit: bool,
+}
+
+#[allow(clippy::large_enum_variant)]
+enum ConnState {
+ None,
+ Capable(Db),
+ Active(ActiveConn),
+ Done,
+}
+
+pub struct DbConn(Mutex<ConnState>);
+
+struct DbWrap(Db);
+struct DbConnWrap(DbConn);
+type DbLock<'c> = MappedMutexGuard<'c, ConnState, ActiveConn>;
+
+#[rocket::async_trait]
+impl Fairing for Db {
+ fn info(&self) -> fairing::Info {
+ fairing::Info {
+ name: "db access",
+ kind: fairing::Kind::Ignite | fairing::Kind::Response | fairing::Kind::Singleton,
+ }
+ }
+
+ async fn on_ignite(&self, rocket: rocket::Rocket<rocket::Build>) -> fairing::Result {
+ Ok(rocket.manage(DbWrap(self.clone())).manage(self.clone()))
+ }
+
+ async fn on_response<'r>(&self, req: &'r Request<'_>, resp: &mut Response<'r>) {
+ let conn = req.local_cache(|| DbConnWrap(DbConn(Mutex::new(ConnState::None))));
+ let s = replace(&mut *conn.0 .0.lock().await, ConnState::Done);
+ if let ConnState::Active(ActiveConn { tx, always_commit }) = s {
+ // don't commit if the request failed, unless explicitly asked to.
+ // this is used by key fetch to invalidate tokens even if the header
+ // signature of the request is incorrect.
+ if !always_commit && resp.status() != Status::Ok {
+ return;
+ }
+ if let Err(e) = tx.commit().await {
+ resp.set_status(Status::InternalServerError);
+ error!("commit failed: {}", e);
+ }
+ }
+ }
+}
+
+impl DbConn {
+ pub async fn commit(self) -> Result<(), sqlx::Error> {
+ match self.0.into_inner() {
+ ConnState::None => Ok(()),
+ ConnState::Capable(_) => Ok(()),
+ ConnState::Active(ac) => ac.tx.commit().await,
+ ConnState::Done => Ok(()),
+ }
+ }
+
+ fn register<'r>(req: &'r Request<'_>) -> &'r DbConn {
+ &req.local_cache(|| {
+ let db = req.rocket().state::<DbWrap>().unwrap().0.clone();
+ DbConnWrap(DbConn(Mutex::new(ConnState::Capable(db))))
+ })
+ .0
+ }
+
+ // defer opening a transaction and thus locking the mutex holding it.
+ // without deferral the order of arguments in route signatures is important,
+ // which may be surprising: placing a DbConn before a guard that also uses
+ // the database causes a concurrent transaction error.
+ // HACK maybe we should return errors instead of panicking, but there's no
+ // way an error here is not a severe bug
+ async fn get(&self) -> sqlx::Result<DbLock<'_>> {
+ let mut m = match self.0.try_lock() {
+ Some(m) => m,
+ None => panic!("attempted to open concurrent transactions"),
+ };
+ match &*m {
+ ConnState::Capable(db) => {
+ *m = ConnState::Active(ActiveConn {
+ tx: db.db.begin().await?,
+ always_commit: false,
+ });
+ },
+ ConnState::None | ConnState::Done => panic!("db connection requested after teardown"),
+ _ => (),
+ }
+ Ok(MutexGuard::map(m, |g| match g {
+ ConnState::Active(ref mut tx) => tx,
+ _ => unreachable!(),
+ }))
+ }
+}
+
+impl<'r> Sentinel for &'r Db {
+ fn abort(rocket: &rocket::Rocket<rocket::Ignite>) -> bool {
+ rocket.state::<DbWrap>().is_none()
+ }
+}
+
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for &'r Db {
+ type Error = anyhow::Error;
+
+ async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+ Outcome::Success(&req.rocket().state::<DbWrap>().unwrap().0)
+ }
+}
+
+impl<'r> Sentinel for &'r DbConn {
+ fn abort(rocket: &rocket::Rocket<rocket::Ignite>) -> bool {
+ rocket.state::<DbWrap>().is_none()
+ }
+}
+
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for &'r DbConn {
+ type Error = anyhow::Error;
+
+ async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+ Outcome::Success(DbConn::register(req))
+ }
+}
+
+//
+//
+//
+
+//
+//
+//
+
+impl DbConn {
+ pub(crate) async fn always_commit(&self) -> Result<()> {
+ self.get().await?.always_commit = true;
+ Ok(())
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn add_session(
+ &self,
+ id: SessionID,
+ user_id: &UserID,
+ key: HawkKey,
+ verified: bool,
+ verify_code: Option<&str>,
+ ) -> sqlx::Result<DateTime<Utc>> {
+ query_scalar!(
+ r#"insert into user_session (session_id, user_id, req_hmac_key, device_id, verified,
+ verify_code)
+ values ($1, $2, $3, null, $4, $5)
+ returning created_at"#,
+ id as _,
+ user_id as _,
+ key as _,
+ verified,
+ verify_code,
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn use_session(&self, id: &SessionID) -> sqlx::Result<UserSession> {
+ query_as!(
+ UserSession,
+ r#"update user_session
+ set last_active = now()
+ where session_id = $1
+ returning user_id as "uid: UserID", req_hmac_key as "req_hmac_key: HawkKey",
+ device_id as "device_id: DeviceID", created_at, verified, verify_code"#,
+ id as _
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn use_session_from_refresh(
+ &self,
+ id: &OauthTokenID,
+ ) -> sqlx::Result<(SessionID, UserSession)> {
+ query!(
+ r#"update user_session
+ set last_active = now()
+ where session_id = (
+ select session_id from oauth_token where kind = 'refresh' and id = $1
+ )
+ returning user_id as "uid: UserID", req_hmac_key as "req_hmac_key: HawkKey",
+ device_id as "device_id: DeviceID", session_id as "session_id: SessionID",
+ created_at, verified, verify_code"#,
+ id as _
+ )
+ .map(|r| {
+ (
+ r.session_id.clone(),
+ UserSession {
+ uid: r.uid,
+ req_hmac_key: r.req_hmac_key,
+ device_id: r.device_id,
+ created_at: r.created_at,
+ verified: r.verified,
+ verify_code: r.verify_code,
+ },
+ )
+ })
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn delete_session(&self, user: &UserID, id: &SessionID) -> sqlx::Result<()> {
+ query_scalar!(
+ r#"delete from user_session
+ where user_id = $1 and session_id = $2
+ returning 1"#,
+ user as _,
+ id as _
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn set_session_device<'d>(
+ &self,
+ id: &SessionID,
+ dev: Option<&'d DeviceID>,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"update user_session set device_id = $1 where session_id = $2"#,
+ dev as _,
+ id as _
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn set_session_verified(&self, id: &SessionID) -> sqlx::Result<()> {
+ query_scalar!(
+ r#"update user_session
+ set verified = true, verify_code = null
+ where session_id = $1
+ returning 1"#,
+ id as _
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn enqueue_command(
+ &self,
+ target: &DeviceID,
+ sender: &Option<DeviceID>,
+ command: &str,
+ payload: &Value,
+ ttl: u32,
+ ) -> sqlx::Result<i64> {
+ let expires = Utc::now() + Duration::seconds(ttl as i64);
+ query!(
+ r#"insert into device_commands (device_id, command, payload, expires, sender)
+ values ($1, $2, $3, $4, $5)
+ returning index"#,
+ target as _,
+ command,
+ payload,
+ expires,
+ sender.as_ref().map(ToString::to_string)
+ )
+ .map(|x| x.index)
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn get_commands(
+ &self,
+ user: &UserID,
+ dev: &DeviceID,
+ min_index: i64,
+ limit: i64,
+ ) -> sqlx::Result<(bool, Vec<DeviceCommand>)> {
+ // NOTE while fxa api docs state that command queries return only commands enqueued
+ // *after* index, what pushbox actually does is return commands *starting at* index!
+ let mut results = query_as!(
+ DeviceCommand,
+ r#"select index, command, payload, expires, sender
+ from device_commands join device using (device_id)
+ where index >= $1 and device_id = $2 and user_id = $3
+ order by index
+ limit $4"#,
+ min_index,
+ dev as _,
+ user as _,
+ limit + 1
+ )
+ .fetch_all(&mut self.get().await?.tx)
+ .await?;
+ let more = results.len() > limit as usize;
+ results.truncate(limit as usize);
+ Ok((more, results))
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn get_devices(&self, user: &UserID) -> sqlx::Result<Vec<Device>> {
+ query_as!(
+ Device,
+ r#"select d.device_id as "device_id: DeviceID", d.name, d.type as type_,
+ d.push as "push: DevicePush",
+ d.available_commands as "available_commands: DeviceCommands",
+ d.push_expired, d.location, coalesce(us.last_active, to_timestamp(0)) as "last_active!"
+ from device d left join user_session us using (device_id)
+ where d.user_id = $1"#,
+ user as _
+ )
+ .fetch_all(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn get_device(&self, user: &UserID, dev: &DeviceID) -> sqlx::Result<Device> {
+ query_as!(
+ Device,
+ r#"select d.device_id as "device_id: DeviceID", d.name, d.type as type_,
+ d.push as "push: DevicePush",
+ d.available_commands as "available_commands: DeviceCommands",
+ d.push_expired, d.location, coalesce(us.last_active, to_timestamp(0)) as "last_active!"
+ from device d left join user_session us using (device_id)
+ where d.user_id = $1 and d.device_id = $2"#,
+ user as _,
+ dev as _
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn change_device<'d>(
+ &self,
+ user: &UserID,
+ id: &DeviceID,
+ dev: DeviceUpdate<'d>,
+ ) -> sqlx::Result<Device> {
+ query_as!(
+ Device,
+ r#"select device_id as "device_id!: DeviceID", name as "name!", type as "type_!",
+ push as "push: DevicePush",
+ available_commands as "available_commands!: DeviceCommands",
+ push_expired as "push_expired!", location as "location!",
+ coalesce(last_active, to_timestamp(0)) as "last_active!"
+ from insert_or_update_device($1, $2, $3, $4, $5, $6, $7) as iud
+ left join user_session using (device_id)"#,
+ id as _,
+ user as _,
+ // these two are not optional but are Option anyway. the db will
+ // refuse insertions that don't have them set.
+ dev.name,
+ dev.type_,
+ dev.push as _,
+ dev.available_commands as _,
+ dev.location
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn set_push_expired(&self, dev: &DeviceID) -> sqlx::Result<()> {
+ query!(
+ r#"update device
+ set push_expired = true
+ where device_id = $1"#,
+ dev as _
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn delete_device(&self, user: &UserID, dev: &DeviceID) -> sqlx::Result<()> {
+ query_scalar!(
+ r#"delete from device where user_id = $1 and device_id = $2 returning 1"#,
+ user as _,
+ dev as _
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn add_key_fetch(
+ &self,
+ id: KeyFetchID,
+ hmac_key: &HawkKey,
+ keys: &WrappedKeyBundle,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"insert into key_fetch (id, hmac_key, keys) values ($1, $2, $3)"#,
+ id as _,
+ hmac_key as _,
+ &keys.to_bytes()[..]
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn finish_key_fetch(
+ &self,
+ id: &KeyFetchID,
+ ) -> sqlx::Result<(HawkKey, Vec<u8>)> {
+ query!(
+ r#"delete from key_fetch
+ where id = $1 and expires_at > now()
+ returning hmac_key as "hmac_key: HawkKey", keys"#,
+ id as _
+ )
+ .map(|r| (r.hmac_key, r.keys))
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn get_refresh_token(
+ &self,
+ id: &OauthTokenID,
+ ) -> sqlx::Result<OauthRefreshToken> {
+ query_as!(
+ OauthRefreshToken,
+ r#"select user_id as "user_id: UserID", client_id, scope as "scope: ScopeSet",
+ session_id as "session_id: SessionID"
+ from oauth_token
+ where id = $1 and kind = 'refresh'"#,
+ id as _
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn get_access_token(
+ &self,
+ id: &OauthTokenID,
+ ) -> sqlx::Result<OauthAccessToken> {
+ query_as!(
+ OauthAccessToken,
+ r#"select user_id as "user_id: UserID", client_id, scope as "scope: ScopeSet",
+ parent_refresh as "parent_refresh: OauthTokenID",
+ parent_session as "parent_session: SessionID",
+ expires_at as "expires_at!"
+ from oauth_token
+ where id = $1 and kind = 'access' and expires_at > now()"#,
+ id as _
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn add_refresh_token(
+ &self,
+ id: &OauthTokenID,
+ token: OauthRefreshToken,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"insert into oauth_token (id, kind, user_id, client_id, scope, session_id)
+ values ($1, 'refresh', $2, $3, $4, $5)"#,
+ id as _,
+ token.user_id as _,
+ token.client_id,
+ token.scope as _,
+ token.session_id as _
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn add_access_token(
+ &self,
+ id: &OauthTokenID,
+ token: OauthAccessToken,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"insert into oauth_token (id, kind, user_id, client_id, scope, session_id,
+ parent_refresh, parent_session, expires_at)
+ values ($1, 'access', $2, $3, $4, null, $5, $6, $7)"#,
+ id as _,
+ token.user_id as _,
+ token.client_id,
+ token.scope as _,
+ token.parent_refresh as _,
+ token.parent_session as _,
+ token.expires_at,
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn delete_oauth_token(&self, id: &OauthTokenID) -> sqlx::Result<()> {
+ query!(r#"delete from oauth_token where id = $1"#, id as _)
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn delete_refresh_token(&self, id: &OauthTokenID) -> sqlx::Result<()> {
+ query!(r#"delete from oauth_token where id = $1 and kind = 'refresh'"#, id as _)
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn add_oauth_authorization(
+ &self,
+ id: &OauthAuthorizationID,
+ auth: OauthAuthorization,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"insert into oauth_authorization (id, user_id, client_id, scope, access_type,
+ code_challenge, keys_jwe, auth_at)
+ values ($1, $2, $3, $4, $5, $6, $7, $8)"#,
+ id as _,
+ auth.user_id as _,
+ auth.client_id,
+ auth.scope as _,
+ auth.access_type as _,
+ auth.code_challenge,
+ auth.keys_jwe,
+ auth.auth_at,
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn take_oauth_authorization(
+ &self,
+ id: &OauthAuthorizationID,
+ ) -> sqlx::Result<OauthAuthorization> {
+ query_as!(
+ OauthAuthorization,
+ r#"delete from oauth_authorization
+ where id = $1 and expires_at > now()
+ returning user_id as "user_id: UserID", client_id, scope as "scope: ScopeSet",
+ access_type as "access_type: OauthAccessType",
+ code_challenge, keys_jwe, auth_at"#,
+ id as _
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn user_email_exists(&self, email: &str) -> sqlx::Result<bool> {
+ Ok(query_scalar!(r#"select 1 from users where email = lower($1)"#, email)
+ .fetch_optional(&mut self.get().await?.tx)
+ .await?
+ .is_some())
+ }
+
+ pub(crate) async fn add_user(&self, user: User) -> sqlx::Result<UserID> {
+ let id = UserID::random();
+ query_scalar!(
+ r#"insert into users (user_id, auth_salt, email, ka, wrapwrap_kb, verify_hash,
+ display_name)
+ values ($1, $2, $3, $4, $5, $6, $7)"#,
+ id as _,
+ user.auth_salt.as_str(),
+ user.email,
+ user.ka as _,
+ user.wrapwrap_kb as _,
+ user.verify_hash as _,
+ user.display_name,
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(id)
+ }
+
+ pub(crate) async fn get_user(&self, email: &str) -> sqlx::Result<(UserID, User)> {
+ query!(
+ r#"select user_id as "id: UserID", auth_salt as "auth_salt: String", email,
+ ka as "ka: SecretKey", wrapwrap_kb as "wrapwrap_kb: SecretKey",
+ verify_hash as "verify_hash: VerifyHash", display_name, verified
+ from users
+ where email = lower($1)"#,
+ email
+ )
+ .try_map(|r| {
+ Ok((
+ r.id,
+ User {
+ auth_salt: SaltString::new(&r.auth_salt).map_err(decode_err("auth_salt"))?,
+ email: r.email,
+ ka: r.ka,
+ wrapwrap_kb: r.wrapwrap_kb,
+ verify_hash: r.verify_hash,
+ display_name: r.display_name,
+ verified: r.verified,
+ },
+ ))
+ })
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn get_user_by_id(&self, id: &UserID) -> sqlx::Result<User> {
+ query!(
+ r#"select auth_salt as "auth_salt: String", email,
+ ka as "ka: SecretKey", wrapwrap_kb as "wrapwrap_kb: SecretKey",
+ verify_hash as "verify_hash: VerifyHash", display_name, verified
+ from users
+ where user_id = $1"#,
+ id as _
+ )
+ .try_map(|r| {
+ Ok(User {
+ auth_salt: SaltString::new(&r.auth_salt).map_err(decode_err("auth_salt"))?,
+ email: r.email,
+ ka: r.ka,
+ wrapwrap_kb: r.wrapwrap_kb,
+ verify_hash: r.verify_hash,
+ display_name: r.display_name,
+ verified: r.verified,
+ })
+ })
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn set_user_name(&self, id: &UserID, name: &str) -> sqlx::Result<()> {
+ query!(
+ "update users
+ set display_name = $2
+ where user_id = $1",
+ id as _,
+ name,
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn set_user_verified(&self, id: &UserID) -> sqlx::Result<()> {
+ query_scalar!("update users set verified = true where user_id = $1 returning 1", id as _)
+ .fetch_one(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn delete_user(&self, email: &str) -> sqlx::Result<()> {
+ query_scalar!(r#"delete from users where email = lower($1)"#, email)
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn change_user_auth(
+ &self,
+ uid: &UserID,
+ salt: SaltString,
+ wwkb: SecretKey,
+ verify_hash: VerifyHash,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"update users
+ set auth_salt = $2, wrapwrap_kb = $3, verify_hash = $4
+ where user_id = $1"#,
+ uid as _,
+ salt.to_string(),
+ wwkb as _,
+ verify_hash as _,
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn reset_user_auth(
+ &self,
+ uid: &UserID,
+ salt: SaltString,
+ wwkb: SecretKey,
+ verify_hash: VerifyHash,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"call reset_user_auth($1, $2, $3, $4)"#,
+ uid as _,
+ salt.to_string(),
+ wwkb as _,
+ verify_hash as _,
+ )
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn get_attached_clients(
+ &self,
+ id: &UserID,
+ ) -> sqlx::Result<Vec<AttachedClient>> {
+ query_as!(
+ AttachedClient,
+ r#"select
+ ot.client_id as "client_id?",
+ d.device_id as "device_id?: DeviceID",
+ us.session_id as "session_token_id?: SessionID",
+ ot.id as "refresh_token_id?: OauthTokenID",
+ d.type as "device_type?",
+ d.name as "name?",
+ coalesce(d.created_at, us.created_at, ot.created_at) as "created_time?",
+ us.last_active as "last_access_time?",
+ ot.scope as "scope?"
+ from device d
+ full outer join user_session us on (d.device_id = us.device_id)
+ full outer join oauth_token ot on (us.session_id = ot.session_id)
+ where
+ (ot.kind is null or ot.kind = 'refresh')
+ and $1 in (d.user_id, us.user_id, ot.user_id)
+ order by d.device_id"#,
+ id as _,
+ )
+ .fetch_all(&mut self.get().await?.tx)
+ .await
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn get_user_avatar_id(&self, id: &UserID) -> sqlx::Result<Option<AvatarID>> {
+ query!(r#"select id as "id: AvatarID" from user_avatars where user_id = $1"#, id as _,)
+ .map(|r| r.id)
+ .fetch_optional(&mut *self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn get_user_avatar(&self, id: &AvatarID) -> sqlx::Result<Option<Avatar>> {
+ query_as!(
+ Avatar,
+ r#"select id as "id: AvatarID", data, content_type
+ from user_avatars
+ where id = $1"#,
+ id as _,
+ )
+ .fetch_optional(&mut *self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn set_user_avatar(&self, id: &UserID, avatar: Avatar) -> sqlx::Result<()> {
+ query!(
+ r#"insert into user_avatars (user_id, id, data, content_type)
+ values ($1, $2, $3, $4)
+ on conflict (user_id) do update set
+ id = $2, data = $3, content_type = $4"#,
+ id as _,
+ avatar.id as _,
+ avatar.data,
+ avatar.content_type,
+ )
+ .execute(&mut *self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn delete_user_avatar(
+ &self,
+ user: &UserID,
+ id: &AvatarID,
+ ) -> sqlx::Result<()> {
+ query!(r#"delete from user_avatars where user_id = $1 and id = $2"#, user as _, id as _,)
+ .execute(&mut *self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn add_verify_code(
+ &self,
+ user: &UserID,
+ session: &SessionID,
+ code: &str,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"insert into verify_codes (user_id, session_id, code)
+ values ($1, $2, $3)"#,
+ user as _,
+ session as _,
+ code,
+ )
+ .execute(&mut *self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn get_verify_code(
+ &self,
+ user: &UserID,
+ ) -> sqlx::Result<(String, VerifyCode)> {
+ query!(
+ r#"select user_id as "user_id: UserID", session_id as "session_id: SessionID", code,
+ email
+ from verify_codes join users using (user_id)
+ where user_id = $1"#,
+ user as _,
+ )
+ .map(|r| {
+ (r.email, VerifyCode { user_id: r.user_id, session_id: r.session_id, code: r.code })
+ })
+ .fetch_one(&mut *self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn try_use_verify_code(
+ &self,
+ user: &UserID,
+ code: &str,
+ ) -> sqlx::Result<Option<VerifyCode>> {
+ query_as!(
+ VerifyCode,
+ r#"delete from verify_codes
+ where user_id = $1 and code = $2
+ returning user_id as "user_id: UserID", session_id as "session_id: SessionID",
+ code"#,
+ user as _,
+ code,
+ )
+ .fetch_optional(&mut *self.get().await?.tx)
+ .await
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn add_password_change(
+ &self,
+ user: &UserID,
+ id: &PasswordChangeID,
+ key: &HawkKey,
+ forgot_code: Option<&str>,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"insert into password_change_tokens (id, user_id, hmac_key, forgot_code)
+ values ($1, $2, $3, $4)
+ on conflict (user_id) do update set id = $1, hmac_key = $3, forgot_code = $4,
+ expires_at = default"#,
+ id as _,
+ user as _,
+ key as _,
+ forgot_code,
+ )
+ .execute(&mut *self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn finish_password_change(
+ &self,
+ id: &PasswordChangeID,
+ is_forgot: bool,
+ ) -> sqlx::Result<(HawkKey, (UserID, Option<String>))> {
+ query!(
+ r#"delete from password_change_tokens
+ where id = $1 and expires_at > now() and (forgot_code is not null) = $2
+ returning hmac_key as "hmac_key: HawkKey", user_id as "user_id: UserID",
+ forgot_code"#,
+ id as _,
+ is_forgot,
+ )
+ .map(|r| (r.hmac_key, (r.user_id, r.forgot_code)))
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ pub(crate) async fn add_account_reset(
+ &self,
+ user: &UserID,
+ id: &AccountResetID,
+ key: &HawkKey,
+ ) -> sqlx::Result<()> {
+ query!(
+ r#"insert into account_reset_tokens (id, user_id, hmac_key)
+ values ($1, $2, $3)
+ on conflict (user_id) do update set id = $1, hmac_key = $3, expires_at = default"#,
+ id as _,
+ user as _,
+ key as _,
+ )
+ .execute(&mut *self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn finish_account_reset(
+ &self,
+ id: &AccountResetID,
+ ) -> sqlx::Result<(HawkKey, UserID)> {
+ query!(
+ r#"delete from account_reset_tokens
+ where id = $1 and expires_at > now()
+ returning hmac_key as "hmac_key: HawkKey", user_id as "user_id: UserID""#,
+ id as _,
+ )
+ .map(|r| (r.hmac_key, r.user_id))
+ .fetch_one(&mut self.get().await?.tx)
+ .await
+ }
+
+ //
+ //
+ //
+
+ pub async fn add_invite_code(&self, code: &str, expires: DateTime<Utc>) -> sqlx::Result<()> {
+ query!(r#"insert into invite_codes (code, expires_at) values ($1, $2)"#, code, expires,)
+ .execute(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ pub(crate) async fn use_invite_code(&self, code: &str) -> sqlx::Result<()> {
+ query_scalar!(
+ r#"delete from invite_codes where code = $1 and expires_at > now() returning 1"#,
+ code,
+ )
+ .fetch_one(&mut self.get().await?.tx)
+ .await?;
+ Ok(())
+ }
+
+ //
+ //
+ //
+
+ pub(crate) async fn prune_expired_tokens(&self) -> sqlx::Result<()> {
+ query!("call prune_expired_tokens()").execute(&mut self.get().await?.tx).await?;
+ Ok(())
+ }
+
+ pub(crate) async fn prune_expired_verify_codes(&self) -> sqlx::Result<()> {
+ query!("call prune_expired_verify_codes()").execute(&mut self.get().await?.tx).await?;
+ Ok(())
+ }
+}
+
+fn decode_err<E: Error + Send + Sync + 'static>(c: &str) -> impl FnOnce(E) -> sqlx::Error {
+ let index = c.to_string();
+ move |e| sqlx::Error::ColumnDecode { index, source: Box::new(e) }
+}