summaryrefslogtreecommitdiff
path: root/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs319
1 files changed, 319 insertions, 0 deletions
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..1e6fa31
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,319 @@
+use std::{
+ path::PathBuf,
+ sync::Arc,
+ time::{Duration as StdDuration, SystemTime, UNIX_EPOCH},
+};
+
+use anyhow::Context;
+use chrono::Duration;
+use db::Db;
+use futures::Future;
+use lettre::message::Mailbox;
+use mailer::Mailer;
+use push::PushClient;
+use rocket::{
+ fairing::AdHoc,
+ http::{uri::Absolute, ContentType, Header},
+ request::{self, FromRequest},
+ response::Redirect,
+ tokio::{
+ spawn,
+ time::{interval_at, Instant, MissedTickBehavior},
+ },
+ Request, State,
+};
+use serde_json::{json, Value};
+use utils::DeferredActions;
+
+use crate::api::auth::invite::generate_invite_link;
+
+#[macro_use]
+extern crate rocket;
+#[macro_use]
+extern crate anyhow;
+#[macro_use]
+extern crate lazy_static;
+
+#[macro_use]
+pub(crate) mod utils;
+pub(crate) mod api;
+mod auth;
+mod cache;
+mod crypto;
+pub mod db;
+mod js;
+mod mailer;
+mod push;
+mod types;
+
+fn default_push_ttl() -> std::time::Duration {
+ std::time::Duration::from_secs(2 * 86400)
+}
+
+fn default_task_interval() -> std::time::Duration {
+ std::time::Duration::from_secs(5 * 60)
+}
+
+#[derive(serde::Deserialize)]
+struct Config {
+ database_url: String,
+ location: Absolute<'static>,
+ token_server_location: Absolute<'static>,
+ vapid_key: PathBuf,
+ vapid_subject: String,
+ #[serde(default = "default_push_ttl", with = "humantime_serde")]
+ default_push_ttl: std::time::Duration,
+ #[serde(default = "default_task_interval", with = "humantime_serde")]
+ prune_expired_interval: std::time::Duration,
+
+ mail_from: Mailbox,
+ mail_host: Option<String>,
+ mail_port: Option<u16>,
+
+ #[serde(default)]
+ invite_only: bool,
+ #[serde(default)]
+ invite_admin_address: String,
+}
+
+impl Config {
+ pub fn avatars_prefix(&self) -> Absolute<'static> {
+ Absolute::parse_owned(format!("{}/avatars", self.location)).unwrap()
+ }
+}
+
+#[get("/")]
+async fn root() -> (ContentType, &'static str) {
+ (ContentType::HTML, include_str!("../web/index.html"))
+}
+
+#[get("/settings/<_..>")]
+async fn settings() -> Redirect {
+ Redirect::to(uri!("/#/settings"))
+}
+
+#[get("/auth/v1/authorization")]
+async fn auth_auth() -> (ContentType, &'static str) {
+ root().await
+}
+
+#[get("/force_auth")]
+async fn force_auth() -> Redirect {
+ Redirect::to(uri!("/#/force_auth"))
+}
+
+#[derive(Debug)]
+struct IsFenix(bool);
+
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for IsFenix {
+ type Error = std::convert::Infallible;
+
+ async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
+ let ua = request.headers().get_one("user-agent");
+ request::Outcome::Success(IsFenix(
+ ua.map(|ua| ua.contains("Firefox") && ua.contains("Android")).unwrap_or(false),
+ ))
+ }
+}
+
+#[get("/.well-known/fxa-client-configuration")]
+async fn fxa_client_configuration(cfg: &State<Config>, is_fenix: IsFenix) -> Value {
+ let base = &cfg.location;
+ json!({
+ "auth_server_base_url": format!("{base}/auth"),
+ "oauth_server_base_url": format!("{base}/oauth"),
+ "pairing_server_base_uri": format!("{base}/pairing"),
+ "profile_server_base_url": format!("{base}/profile"),
+ // NOTE trailing slash is *essential*, otherwise fenix will refuse to sync.
+ // likewise firefox desktop seems to misbehave if there *is* a trailing slash.
+ "sync_tokenserver_base_url": format!("{}{}", cfg.token_server_location, if is_fenix.0 { "/" } else { "" })
+ })
+}
+
+// NOTE it looks like firefox does not implement refresh token rotation.
+// since it also looks like it doesn't implement MTLS we can't secure
+// refresh tokens against being stolen as advised by
+// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics
+// section 2.2.2
+
+// NOTE firefox "oldsync" scope is the current version?
+// https://github.com/mozilla/fxa/blob/main/packages/fxa-auth-server/docs/oauth/scopes.md
+// https://mozilla.github.io/ecosystem-platform/explanation/onepw-protocol
+// https://mozilla.github.io/ecosystem-platform/api
+// https://github.com/mozilla/fxa/blob/main/packages/fxa-auth-server/docs/device_registration.md
+// -> push for everything
+// https://mozilla.github.io/ecosystem-platform/explanation/scoped-keys
+
+#[get("/.well-known/openid-configuration")]
+fn oid(cfg: &State<Config>) -> Value {
+ let base = &cfg.location;
+ json!({
+ "authorization_endpoint": format!("{base}/auth/v1/authorization"),
+ "introspection_endpoint": format!("{base}/oauth/v1/introspect"),
+ "issuer": base.to_string(),
+ "jwks_uri": format!("{base}/oauth/v1/jwks"),
+ "revocation_endpoint": format!("{base}/oauth/v1/destroy"),
+ "token_endpoint": format!("{base}/auth/v1/oauth/token"),
+ "userinfo_endpoint": format!("{base}/profile/v1/profile"),
+ "claims_supported": ["aud","exp","iat","iss","sub"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ "response_types_supported": ["code","token"],
+ "scopes_supported": ["openid","profile","email"],
+ "subject_types_supported": ["public"],
+ "token_endpoint_auth_methods_supported": ["client_secret_post"],
+ })
+}
+
+fn spawn_periodic<A, P, F>(context: &'static str, t: StdDuration, p: P, f: A)
+where
+ A: Fn(P) -> F + Send + Sync + Sized + 'static,
+ P: Clone + Send + Sync + 'static,
+ F: Future<Output = anyhow::Result<()>> + Send + Sized,
+{
+ let mut interval = interval_at(Instant::now() + t, t);
+ interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
+
+ spawn(async move {
+ loop {
+ interval.tick().await;
+ info!("starting periodic {context}");
+ if let Err(e) = f(p.clone()).await {
+ error!("periodic {context} failed: {e}");
+ }
+ }
+ });
+}
+
+async fn ensure_invite_admin(db: &Db, cfg: &Config) -> anyhow::Result<()> {
+ if !cfg.invite_only {
+ return Ok(());
+ }
+
+ let tx = db.begin().await?;
+ match tx.get_user(&cfg.invite_admin_address).await {
+ Err(sqlx::Error::RowNotFound) => {
+ let url = generate_invite_link(&tx, cfg, Duration::hours(1)).await?;
+ tx.commit().await?;
+ warn!("admin user {} does not exist, register at {url}", cfg.invite_admin_address);
+ Ok(())
+ },
+ Err(e) => Err(anyhow!(e)),
+ Ok(_) => Ok(()),
+ }
+}
+
+pub async fn build() -> anyhow::Result<rocket::Rocket<rocket::Build>> {
+ let rocket = rocket::build();
+ let config = rocket.figment().extract::<Config>().context("reading config")?;
+ let db = Arc::new(Db::connect(&config.database_url).await.unwrap());
+
+ db.migrate().await.context("running db migrations")?;
+
+ ensure_invite_admin(&db, &config).await?;
+ let push = Arc::new(
+ PushClient::new(
+ &config.vapid_key,
+ &config.vapid_subject,
+ config.location.clone(),
+ config.default_push_ttl,
+ )
+ .context("setting up push notifications")?,
+ );
+ let mailer = Arc::new(
+ Mailer::new(
+ config.mail_from.clone(),
+ config.mail_host.as_deref().unwrap_or("localhost"),
+ config.mail_port.unwrap_or(25),
+ config.location.clone(),
+ )
+ .context("setting up mail notifications")?,
+ );
+ spawn_periodic("verify code prune", StdDuration::from_secs(5 * 60), Arc::clone(&db), {
+ |db| async move {
+ let tx = db.begin().await?;
+ tx.prune_expired_verify_codes().await?;
+ tx.commit().await?;
+ Ok(())
+ }
+ });
+ spawn_periodic("expired token prune", config.prune_expired_interval, Arc::clone(&db), {
+ |db| async move {
+ let tx = db.begin().await?;
+ tx.prune_expired_tokens().await?;
+ tx.commit().await?;
+ Ok(())
+ }
+ });
+ let rocket = rocket
+ .manage(config)
+ .manage(push)
+ .manage(mailer)
+ .attach(db)
+ .attach(DeferredActions)
+ .mount("/", routes![root, settings, oid, auth_auth, force_auth, fxa_client_configuration,])
+ .register("/auth/v1", catchers![api::auth::catch_all,])
+ .mount(
+ "/auth/v1",
+ routes![
+ api::auth::account::create,
+ api::auth::account::login,
+ api::auth::account::destroy,
+ api::auth::account::keys,
+ api::auth::account::reset,
+ api::auth::oauth::token_authenticated,
+ api::auth::oauth::token_unauthenticated,
+ api::auth::oauth::destroy,
+ api::auth::oauth::scoped_key_data,
+ api::auth::device::devices,
+ api::auth::device::device,
+ api::auth::device::invoke,
+ api::auth::device::commands,
+ api::auth::session::status,
+ api::auth::session::resend_code,
+ api::auth::session::verify_code,
+ api::auth::session::destroy,
+ api::auth::oauth::authorization,
+ api::auth::device::destroy,
+ api::auth::device::notify,
+ api::auth::device::attached_clients,
+ api::auth::device::destroy_attached_client,
+ api::auth::email::status,
+ api::auth::email::verify_code,
+ api::auth::email::resend_code,
+ api::auth::password::change_start,
+ api::auth::password::change_finish,
+ api::auth::password::forgot_start,
+ api::auth::password::forgot_finish,
+ ],
+ )
+ // slight hack to allow the js auth client to "just work"
+ .register("/_invite/v1", catchers![api::auth::catch_all,])
+ .mount("/_invite/v1", routes![api::auth::invite::generate,])
+ .attach(AdHoc::on_response("/auth Timestamp", |req, resp| {
+ Box::pin(async move {
+ if req.uri().path().as_str().starts_with("/auth/v1/") {
+ if let Ok(ts) = SystemTime::now().duration_since(UNIX_EPOCH) {
+ resp.set_header(Header::new("timestamp", ts.as_secs().to_string()));
+ }
+ }
+ })
+ }))
+ .register("/profile", catchers![api::profile::catch_all,])
+ .mount(
+ "/profile/v1",
+ routes![
+ api::profile::profile,
+ api::profile::display_name_post,
+ api::profile::avatar_get,
+ api::profile::avatar_upload,
+ api::profile::avatar_delete,
+ ],
+ )
+ .register("/avatars", catchers![api::profile::catch_all,])
+ .mount("/avatars", routes![api::profile::avatar_get_img])
+ .register("/oauth/v1", catchers![api::oauth::catch_all,])
+ .mount("/oauth/v1", routes![api::oauth::destroy, api::oauth::jwks, api::oauth::verify,])
+ .mount("/js", routes![js::static_js]);
+ Ok(rocket)
+}