From 2f8dce44d3f2be74b5c6ec0a2e7f4ceced715328 Mon Sep 17 00:00:00 2001 From: pennae Date: Wed, 13 Jul 2022 10:33:30 +0200 Subject: initial import --- src/db/mod.rs | 1026 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1026 insertions(+) create mode 100644 src/db/mod.rs (limited to 'src/db') diff --git a/src/db/mod.rs b/src/db/mod.rs new file mode 100644 index 0000000..040507f --- /dev/null +++ b/src/db/mod.rs @@ -0,0 +1,1026 @@ +use std::{error::Error, mem::replace, sync::Arc}; + +use anyhow::Result; +use chrono::{DateTime, Duration, Utc}; +use password_hash::SaltString; +use rocket::{ + fairing::{self, Fairing}, + futures::lock::{MappedMutexGuard, Mutex, MutexGuard}, + http::Status, + request::{FromRequest, Outcome}, + Request, Response, Sentinel, +}; +use serde_json::Value; +use sqlx::{query, query_as, query_scalar, PgPool, Postgres, Transaction}; + +use crate::{ + crypto::WrappedKeyBundle, + types::{ + oauth::ScopeSet, AccountResetID, AttachedClient, Avatar, AvatarID, Device, DeviceCommand, + DeviceCommands, DeviceID, DevicePush, DeviceUpdate, HawkKey, KeyFetchID, OauthAccessToken, + OauthAccessType, OauthAuthorization, OauthAuthorizationID, OauthRefreshToken, OauthTokenID, + PasswordChangeID, SecretKey, SessionID, User, UserID, UserSession, VerifyCode, VerifyHash, + }, +}; + +// we implement a completely custom db type set instead of using rocket_db_pools for two reasons: +// 1. rocket_db_pools doesn't support sqlx 0.6 +// 2. we want one transaction per *request*, including all guards + +#[derive(Clone)] +pub struct Db { + db: Arc, +} + +impl Db { + pub async fn connect(db: &str) -> Result { + Ok(Self { db: Arc::new(PgPool::connect(db).await?) }) + } + + pub async fn begin(&self) -> Result { + Ok(DbConn(Mutex::new(ConnState::Capable(self.clone())))) + } + + pub async fn migrate(&self) -> Result<()> { + sqlx::migrate!().run(&*self.db).await?; + Ok(()) + } +} + +struct ActiveConn { + tx: Transaction<'static, Postgres>, + always_commit: bool, +} + +#[allow(clippy::large_enum_variant)] +enum ConnState { + None, + Capable(Db), + Active(ActiveConn), + Done, +} + +pub struct DbConn(Mutex); + +struct DbWrap(Db); +struct DbConnWrap(DbConn); +type DbLock<'c> = MappedMutexGuard<'c, ConnState, ActiveConn>; + +#[rocket::async_trait] +impl Fairing for Db { + fn info(&self) -> fairing::Info { + fairing::Info { + name: "db access", + kind: fairing::Kind::Ignite | fairing::Kind::Response | fairing::Kind::Singleton, + } + } + + async fn on_ignite(&self, rocket: rocket::Rocket) -> fairing::Result { + Ok(rocket.manage(DbWrap(self.clone())).manage(self.clone())) + } + + async fn on_response<'r>(&self, req: &'r Request<'_>, resp: &mut Response<'r>) { + let conn = req.local_cache(|| DbConnWrap(DbConn(Mutex::new(ConnState::None)))); + let s = replace(&mut *conn.0 .0.lock().await, ConnState::Done); + if let ConnState::Active(ActiveConn { tx, always_commit }) = s { + // don't commit if the request failed, unless explicitly asked to. + // this is used by key fetch to invalidate tokens even if the header + // signature of the request is incorrect. + if !always_commit && resp.status() != Status::Ok { + return; + } + if let Err(e) = tx.commit().await { + resp.set_status(Status::InternalServerError); + error!("commit failed: {}", e); + } + } + } +} + +impl DbConn { + pub async fn commit(self) -> Result<(), sqlx::Error> { + match self.0.into_inner() { + ConnState::None => Ok(()), + ConnState::Capable(_) => Ok(()), + ConnState::Active(ac) => ac.tx.commit().await, + ConnState::Done => Ok(()), + } + } + + fn register<'r>(req: &'r Request<'_>) -> &'r DbConn { + &req.local_cache(|| { + let db = req.rocket().state::().unwrap().0.clone(); + DbConnWrap(DbConn(Mutex::new(ConnState::Capable(db)))) + }) + .0 + } + + // defer opening a transaction and thus locking the mutex holding it. + // without deferral the order of arguments in route signatures is important, + // which may be surprising: placing a DbConn before a guard that also uses + // the database causes a concurrent transaction error. + // HACK maybe we should return errors instead of panicking, but there's no + // way an error here is not a severe bug + async fn get(&self) -> sqlx::Result> { + let mut m = match self.0.try_lock() { + Some(m) => m, + None => panic!("attempted to open concurrent transactions"), + }; + match &*m { + ConnState::Capable(db) => { + *m = ConnState::Active(ActiveConn { + tx: db.db.begin().await?, + always_commit: false, + }); + }, + ConnState::None | ConnState::Done => panic!("db connection requested after teardown"), + _ => (), + } + Ok(MutexGuard::map(m, |g| match g { + ConnState::Active(ref mut tx) => tx, + _ => unreachable!(), + })) + } +} + +impl<'r> Sentinel for &'r Db { + fn abort(rocket: &rocket::Rocket) -> bool { + rocket.state::().is_none() + } +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for &'r Db { + type Error = anyhow::Error; + + async fn from_request(req: &'r Request<'_>) -> Outcome { + Outcome::Success(&req.rocket().state::().unwrap().0) + } +} + +impl<'r> Sentinel for &'r DbConn { + fn abort(rocket: &rocket::Rocket) -> bool { + rocket.state::().is_none() + } +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for &'r DbConn { + type Error = anyhow::Error; + + async fn from_request(req: &'r Request<'_>) -> Outcome { + Outcome::Success(DbConn::register(req)) + } +} + +// +// +// + +// +// +// + +impl DbConn { + pub(crate) async fn always_commit(&self) -> Result<()> { + self.get().await?.always_commit = true; + Ok(()) + } + + // + // + // + + pub(crate) async fn add_session( + &self, + id: SessionID, + user_id: &UserID, + key: HawkKey, + verified: bool, + verify_code: Option<&str>, + ) -> sqlx::Result> { + query_scalar!( + r#"insert into user_session (session_id, user_id, req_hmac_key, device_id, verified, + verify_code) + values ($1, $2, $3, null, $4, $5) + returning created_at"#, + id as _, + user_id as _, + key as _, + verified, + verify_code, + ) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn use_session(&self, id: &SessionID) -> sqlx::Result { + query_as!( + UserSession, + r#"update user_session + set last_active = now() + where session_id = $1 + returning user_id as "uid: UserID", req_hmac_key as "req_hmac_key: HawkKey", + device_id as "device_id: DeviceID", created_at, verified, verify_code"#, + id as _ + ) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn use_session_from_refresh( + &self, + id: &OauthTokenID, + ) -> sqlx::Result<(SessionID, UserSession)> { + query!( + r#"update user_session + set last_active = now() + where session_id = ( + select session_id from oauth_token where kind = 'refresh' and id = $1 + ) + returning user_id as "uid: UserID", req_hmac_key as "req_hmac_key: HawkKey", + device_id as "device_id: DeviceID", session_id as "session_id: SessionID", + created_at, verified, verify_code"#, + id as _ + ) + .map(|r| { + ( + r.session_id.clone(), + UserSession { + uid: r.uid, + req_hmac_key: r.req_hmac_key, + device_id: r.device_id, + created_at: r.created_at, + verified: r.verified, + verify_code: r.verify_code, + }, + ) + }) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn delete_session(&self, user: &UserID, id: &SessionID) -> sqlx::Result<()> { + query_scalar!( + r#"delete from user_session + where user_id = $1 and session_id = $2 + returning 1"#, + user as _, + id as _ + ) + .fetch_one(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn set_session_device<'d>( + &self, + id: &SessionID, + dev: Option<&'d DeviceID>, + ) -> sqlx::Result<()> { + query!( + r#"update user_session set device_id = $1 where session_id = $2"#, + dev as _, + id as _ + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn set_session_verified(&self, id: &SessionID) -> sqlx::Result<()> { + query_scalar!( + r#"update user_session + set verified = true, verify_code = null + where session_id = $1 + returning 1"#, + id as _ + ) + .fetch_one(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + // + // + // + + pub(crate) async fn enqueue_command( + &self, + target: &DeviceID, + sender: &Option, + command: &str, + payload: &Value, + ttl: u32, + ) -> sqlx::Result { + let expires = Utc::now() + Duration::seconds(ttl as i64); + query!( + r#"insert into device_commands (device_id, command, payload, expires, sender) + values ($1, $2, $3, $4, $5) + returning index"#, + target as _, + command, + payload, + expires, + sender.as_ref().map(ToString::to_string) + ) + .map(|x| x.index) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn get_commands( + &self, + user: &UserID, + dev: &DeviceID, + min_index: i64, + limit: i64, + ) -> sqlx::Result<(bool, Vec)> { + // NOTE while fxa api docs state that command queries return only commands enqueued + // *after* index, what pushbox actually does is return commands *starting at* index! + let mut results = query_as!( + DeviceCommand, + r#"select index, command, payload, expires, sender + from device_commands join device using (device_id) + where index >= $1 and device_id = $2 and user_id = $3 + order by index + limit $4"#, + min_index, + dev as _, + user as _, + limit + 1 + ) + .fetch_all(&mut self.get().await?.tx) + .await?; + let more = results.len() > limit as usize; + results.truncate(limit as usize); + Ok((more, results)) + } + + // + // + // + + pub(crate) async fn get_devices(&self, user: &UserID) -> sqlx::Result> { + query_as!( + Device, + r#"select d.device_id as "device_id: DeviceID", d.name, d.type as type_, + d.push as "push: DevicePush", + d.available_commands as "available_commands: DeviceCommands", + d.push_expired, d.location, coalesce(us.last_active, to_timestamp(0)) as "last_active!" + from device d left join user_session us using (device_id) + where d.user_id = $1"#, + user as _ + ) + .fetch_all(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn get_device(&self, user: &UserID, dev: &DeviceID) -> sqlx::Result { + query_as!( + Device, + r#"select d.device_id as "device_id: DeviceID", d.name, d.type as type_, + d.push as "push: DevicePush", + d.available_commands as "available_commands: DeviceCommands", + d.push_expired, d.location, coalesce(us.last_active, to_timestamp(0)) as "last_active!" + from device d left join user_session us using (device_id) + where d.user_id = $1 and d.device_id = $2"#, + user as _, + dev as _ + ) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn change_device<'d>( + &self, + user: &UserID, + id: &DeviceID, + dev: DeviceUpdate<'d>, + ) -> sqlx::Result { + query_as!( + Device, + r#"select device_id as "device_id!: DeviceID", name as "name!", type as "type_!", + push as "push: DevicePush", + available_commands as "available_commands!: DeviceCommands", + push_expired as "push_expired!", location as "location!", + coalesce(last_active, to_timestamp(0)) as "last_active!" + from insert_or_update_device($1, $2, $3, $4, $5, $6, $7) as iud + left join user_session using (device_id)"#, + id as _, + user as _, + // these two are not optional but are Option anyway. the db will + // refuse insertions that don't have them set. + dev.name, + dev.type_, + dev.push as _, + dev.available_commands as _, + dev.location + ) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn set_push_expired(&self, dev: &DeviceID) -> sqlx::Result<()> { + query!( + r#"update device + set push_expired = true + where device_id = $1"#, + dev as _ + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn delete_device(&self, user: &UserID, dev: &DeviceID) -> sqlx::Result<()> { + query_scalar!( + r#"delete from device where user_id = $1 and device_id = $2 returning 1"#, + user as _, + dev as _ + ) + .fetch_one(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + // + // + // + + pub(crate) async fn add_key_fetch( + &self, + id: KeyFetchID, + hmac_key: &HawkKey, + keys: &WrappedKeyBundle, + ) -> sqlx::Result<()> { + query!( + r#"insert into key_fetch (id, hmac_key, keys) values ($1, $2, $3)"#, + id as _, + hmac_key as _, + &keys.to_bytes()[..] + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn finish_key_fetch( + &self, + id: &KeyFetchID, + ) -> sqlx::Result<(HawkKey, Vec)> { + query!( + r#"delete from key_fetch + where id = $1 and expires_at > now() + returning hmac_key as "hmac_key: HawkKey", keys"#, + id as _ + ) + .map(|r| (r.hmac_key, r.keys)) + .fetch_one(&mut self.get().await?.tx) + .await + } + + // + // + // + + pub(crate) async fn get_refresh_token( + &self, + id: &OauthTokenID, + ) -> sqlx::Result { + query_as!( + OauthRefreshToken, + r#"select user_id as "user_id: UserID", client_id, scope as "scope: ScopeSet", + session_id as "session_id: SessionID" + from oauth_token + where id = $1 and kind = 'refresh'"#, + id as _ + ) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn get_access_token( + &self, + id: &OauthTokenID, + ) -> sqlx::Result { + query_as!( + OauthAccessToken, + r#"select user_id as "user_id: UserID", client_id, scope as "scope: ScopeSet", + parent_refresh as "parent_refresh: OauthTokenID", + parent_session as "parent_session: SessionID", + expires_at as "expires_at!" + from oauth_token + where id = $1 and kind = 'access' and expires_at > now()"#, + id as _ + ) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn add_refresh_token( + &self, + id: &OauthTokenID, + token: OauthRefreshToken, + ) -> sqlx::Result<()> { + query!( + r#"insert into oauth_token (id, kind, user_id, client_id, scope, session_id) + values ($1, 'refresh', $2, $3, $4, $5)"#, + id as _, + token.user_id as _, + token.client_id, + token.scope as _, + token.session_id as _ + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn add_access_token( + &self, + id: &OauthTokenID, + token: OauthAccessToken, + ) -> sqlx::Result<()> { + query!( + r#"insert into oauth_token (id, kind, user_id, client_id, scope, session_id, + parent_refresh, parent_session, expires_at) + values ($1, 'access', $2, $3, $4, null, $5, $6, $7)"#, + id as _, + token.user_id as _, + token.client_id, + token.scope as _, + token.parent_refresh as _, + token.parent_session as _, + token.expires_at, + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn delete_oauth_token(&self, id: &OauthTokenID) -> sqlx::Result<()> { + query!(r#"delete from oauth_token where id = $1"#, id as _) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn delete_refresh_token(&self, id: &OauthTokenID) -> sqlx::Result<()> { + query!(r#"delete from oauth_token where id = $1 and kind = 'refresh'"#, id as _) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + // + // + // + + pub(crate) async fn add_oauth_authorization( + &self, + id: &OauthAuthorizationID, + auth: OauthAuthorization, + ) -> sqlx::Result<()> { + query!( + r#"insert into oauth_authorization (id, user_id, client_id, scope, access_type, + code_challenge, keys_jwe, auth_at) + values ($1, $2, $3, $4, $5, $6, $7, $8)"#, + id as _, + auth.user_id as _, + auth.client_id, + auth.scope as _, + auth.access_type as _, + auth.code_challenge, + auth.keys_jwe, + auth.auth_at, + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn take_oauth_authorization( + &self, + id: &OauthAuthorizationID, + ) -> sqlx::Result { + query_as!( + OauthAuthorization, + r#"delete from oauth_authorization + where id = $1 and expires_at > now() + returning user_id as "user_id: UserID", client_id, scope as "scope: ScopeSet", + access_type as "access_type: OauthAccessType", + code_challenge, keys_jwe, auth_at"#, + id as _ + ) + .fetch_one(&mut self.get().await?.tx) + .await + } + + // + // + // + + pub(crate) async fn user_email_exists(&self, email: &str) -> sqlx::Result { + Ok(query_scalar!(r#"select 1 from users where email = lower($1)"#, email) + .fetch_optional(&mut self.get().await?.tx) + .await? + .is_some()) + } + + pub(crate) async fn add_user(&self, user: User) -> sqlx::Result { + let id = UserID::random(); + query_scalar!( + r#"insert into users (user_id, auth_salt, email, ka, wrapwrap_kb, verify_hash, + display_name) + values ($1, $2, $3, $4, $5, $6, $7)"#, + id as _, + user.auth_salt.as_str(), + user.email, + user.ka as _, + user.wrapwrap_kb as _, + user.verify_hash as _, + user.display_name, + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(id) + } + + pub(crate) async fn get_user(&self, email: &str) -> sqlx::Result<(UserID, User)> { + query!( + r#"select user_id as "id: UserID", auth_salt as "auth_salt: String", email, + ka as "ka: SecretKey", wrapwrap_kb as "wrapwrap_kb: SecretKey", + verify_hash as "verify_hash: VerifyHash", display_name, verified + from users + where email = lower($1)"#, + email + ) + .try_map(|r| { + Ok(( + r.id, + User { + auth_salt: SaltString::new(&r.auth_salt).map_err(decode_err("auth_salt"))?, + email: r.email, + ka: r.ka, + wrapwrap_kb: r.wrapwrap_kb, + verify_hash: r.verify_hash, + display_name: r.display_name, + verified: r.verified, + }, + )) + }) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn get_user_by_id(&self, id: &UserID) -> sqlx::Result { + query!( + r#"select auth_salt as "auth_salt: String", email, + ka as "ka: SecretKey", wrapwrap_kb as "wrapwrap_kb: SecretKey", + verify_hash as "verify_hash: VerifyHash", display_name, verified + from users + where user_id = $1"#, + id as _ + ) + .try_map(|r| { + Ok(User { + auth_salt: SaltString::new(&r.auth_salt).map_err(decode_err("auth_salt"))?, + email: r.email, + ka: r.ka, + wrapwrap_kb: r.wrapwrap_kb, + verify_hash: r.verify_hash, + display_name: r.display_name, + verified: r.verified, + }) + }) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn set_user_name(&self, id: &UserID, name: &str) -> sqlx::Result<()> { + query!( + "update users + set display_name = $2 + where user_id = $1", + id as _, + name, + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn set_user_verified(&self, id: &UserID) -> sqlx::Result<()> { + query_scalar!("update users set verified = true where user_id = $1 returning 1", id as _) + .fetch_one(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn delete_user(&self, email: &str) -> sqlx::Result<()> { + query_scalar!(r#"delete from users where email = lower($1)"#, email) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn change_user_auth( + &self, + uid: &UserID, + salt: SaltString, + wwkb: SecretKey, + verify_hash: VerifyHash, + ) -> sqlx::Result<()> { + query!( + r#"update users + set auth_salt = $2, wrapwrap_kb = $3, verify_hash = $4 + where user_id = $1"#, + uid as _, + salt.to_string(), + wwkb as _, + verify_hash as _, + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn reset_user_auth( + &self, + uid: &UserID, + salt: SaltString, + wwkb: SecretKey, + verify_hash: VerifyHash, + ) -> sqlx::Result<()> { + query!( + r#"call reset_user_auth($1, $2, $3, $4)"#, + uid as _, + salt.to_string(), + wwkb as _, + verify_hash as _, + ) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + // + // + // + + pub(crate) async fn get_attached_clients( + &self, + id: &UserID, + ) -> sqlx::Result> { + query_as!( + AttachedClient, + r#"select + ot.client_id as "client_id?", + d.device_id as "device_id?: DeviceID", + us.session_id as "session_token_id?: SessionID", + ot.id as "refresh_token_id?: OauthTokenID", + d.type as "device_type?", + d.name as "name?", + coalesce(d.created_at, us.created_at, ot.created_at) as "created_time?", + us.last_active as "last_access_time?", + ot.scope as "scope?" + from device d + full outer join user_session us on (d.device_id = us.device_id) + full outer join oauth_token ot on (us.session_id = ot.session_id) + where + (ot.kind is null or ot.kind = 'refresh') + and $1 in (d.user_id, us.user_id, ot.user_id) + order by d.device_id"#, + id as _, + ) + .fetch_all(&mut self.get().await?.tx) + .await + } + + // + // + // + + pub(crate) async fn get_user_avatar_id(&self, id: &UserID) -> sqlx::Result> { + query!(r#"select id as "id: AvatarID" from user_avatars where user_id = $1"#, id as _,) + .map(|r| r.id) + .fetch_optional(&mut *self.get().await?.tx) + .await + } + + pub(crate) async fn get_user_avatar(&self, id: &AvatarID) -> sqlx::Result> { + query_as!( + Avatar, + r#"select id as "id: AvatarID", data, content_type + from user_avatars + where id = $1"#, + id as _, + ) + .fetch_optional(&mut *self.get().await?.tx) + .await + } + + pub(crate) async fn set_user_avatar(&self, id: &UserID, avatar: Avatar) -> sqlx::Result<()> { + query!( + r#"insert into user_avatars (user_id, id, data, content_type) + values ($1, $2, $3, $4) + on conflict (user_id) do update set + id = $2, data = $3, content_type = $4"#, + id as _, + avatar.id as _, + avatar.data, + avatar.content_type, + ) + .execute(&mut *self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn delete_user_avatar( + &self, + user: &UserID, + id: &AvatarID, + ) -> sqlx::Result<()> { + query!(r#"delete from user_avatars where user_id = $1 and id = $2"#, user as _, id as _,) + .execute(&mut *self.get().await?.tx) + .await?; + Ok(()) + } + + // + // + // + + pub(crate) async fn add_verify_code( + &self, + user: &UserID, + session: &SessionID, + code: &str, + ) -> sqlx::Result<()> { + query!( + r#"insert into verify_codes (user_id, session_id, code) + values ($1, $2, $3)"#, + user as _, + session as _, + code, + ) + .execute(&mut *self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn get_verify_code( + &self, + user: &UserID, + ) -> sqlx::Result<(String, VerifyCode)> { + query!( + r#"select user_id as "user_id: UserID", session_id as "session_id: SessionID", code, + email + from verify_codes join users using (user_id) + where user_id = $1"#, + user as _, + ) + .map(|r| { + (r.email, VerifyCode { user_id: r.user_id, session_id: r.session_id, code: r.code }) + }) + .fetch_one(&mut *self.get().await?.tx) + .await + } + + pub(crate) async fn try_use_verify_code( + &self, + user: &UserID, + code: &str, + ) -> sqlx::Result> { + query_as!( + VerifyCode, + r#"delete from verify_codes + where user_id = $1 and code = $2 + returning user_id as "user_id: UserID", session_id as "session_id: SessionID", + code"#, + user as _, + code, + ) + .fetch_optional(&mut *self.get().await?.tx) + .await + } + + // + // + // + + pub(crate) async fn add_password_change( + &self, + user: &UserID, + id: &PasswordChangeID, + key: &HawkKey, + forgot_code: Option<&str>, + ) -> sqlx::Result<()> { + query!( + r#"insert into password_change_tokens (id, user_id, hmac_key, forgot_code) + values ($1, $2, $3, $4) + on conflict (user_id) do update set id = $1, hmac_key = $3, forgot_code = $4, + expires_at = default"#, + id as _, + user as _, + key as _, + forgot_code, + ) + .execute(&mut *self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn finish_password_change( + &self, + id: &PasswordChangeID, + is_forgot: bool, + ) -> sqlx::Result<(HawkKey, (UserID, Option))> { + query!( + r#"delete from password_change_tokens + where id = $1 and expires_at > now() and (forgot_code is not null) = $2 + returning hmac_key as "hmac_key: HawkKey", user_id as "user_id: UserID", + forgot_code"#, + id as _, + is_forgot, + ) + .map(|r| (r.hmac_key, (r.user_id, r.forgot_code))) + .fetch_one(&mut self.get().await?.tx) + .await + } + + pub(crate) async fn add_account_reset( + &self, + user: &UserID, + id: &AccountResetID, + key: &HawkKey, + ) -> sqlx::Result<()> { + query!( + r#"insert into account_reset_tokens (id, user_id, hmac_key) + values ($1, $2, $3) + on conflict (user_id) do update set id = $1, hmac_key = $3, expires_at = default"#, + id as _, + user as _, + key as _, + ) + .execute(&mut *self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn finish_account_reset( + &self, + id: &AccountResetID, + ) -> sqlx::Result<(HawkKey, UserID)> { + query!( + r#"delete from account_reset_tokens + where id = $1 and expires_at > now() + returning hmac_key as "hmac_key: HawkKey", user_id as "user_id: UserID""#, + id as _, + ) + .map(|r| (r.hmac_key, r.user_id)) + .fetch_one(&mut self.get().await?.tx) + .await + } + + // + // + // + + pub async fn add_invite_code(&self, code: &str, expires: DateTime) -> sqlx::Result<()> { + query!(r#"insert into invite_codes (code, expires_at) values ($1, $2)"#, code, expires,) + .execute(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + pub(crate) async fn use_invite_code(&self, code: &str) -> sqlx::Result<()> { + query_scalar!( + r#"delete from invite_codes where code = $1 and expires_at > now() returning 1"#, + code, + ) + .fetch_one(&mut self.get().await?.tx) + .await?; + Ok(()) + } + + // + // + // + + pub(crate) async fn prune_expired_tokens(&self) -> sqlx::Result<()> { + query!("call prune_expired_tokens()").execute(&mut self.get().await?.tx).await?; + Ok(()) + } + + pub(crate) async fn prune_expired_verify_codes(&self) -> sqlx::Result<()> { + query!("call prune_expired_verify_codes()").execute(&mut self.get().await?.tx).await?; + Ok(()) + } +} + +fn decode_err(c: &str) -> impl FnOnce(E) -> sqlx::Error { + let index = c.to_string(); + move |e| sqlx::Error::ColumnDecode { index, source: Box::new(e) } +} -- cgit v1.2.3