use std::str::FromStr; use std::time::Duration; use anyhow::Result; use hawk::{DigestAlgorithm, Header, Key, PayloadHasher, RequestBuilder}; use rocket::data::{self, FromData, ToByteUnit}; use rocket::http::Status; use rocket::outcome::{try_outcome, Outcome}; use rocket::request::{local_cache, FromRequest, Request}; use rocket::{request, Data, Ignite, Phase, Rocket, Sentinel}; use serde::Deserialize; use serde_json::error::Category; use crate::db::DbConn; use crate::types::oauth::ScopeSet; use crate::types::{HawkKey, OauthToken, UserID}; use crate::Config; #[rocket::async_trait] pub(crate) trait AuthSource { type ID: FromStr + Send + Sync + Clone; type Context: Send + Sync; async fn hawk(r: &Request<'_>, id: &Self::ID) -> Result<(HawkKey, Self::Context)>; async fn bearer_token(r: &Request<'_>, id: &OauthToken) -> Result<(Self::ID, Self::Context)>; } // marker trait and type to communicate that authentication has failed with invalid // tokens used. this is needed to properly translate these error for the profile api. pub(crate) trait AuthenticatedRequest { fn invalid_token_used(&self) -> bool; } struct InvalidTokenUsed; impl<'r> AuthenticatedRequest for Request<'r> { fn invalid_token_used(&self) -> bool { self.local_cache(|| None as Option).is_some() } } #[derive(Debug)] pub(crate) struct Authenticated { pub body: T, pub session: Src::ID, pub context: Src::Context, } enum AuthKind<'a> { Hawk { header: Header }, Token { token: &'a str }, } fn drop_auth_prefix<'a>(s: &'a str, prefix: &str) -> Option<&'a str> { if prefix.len() <= s.len() && s[..prefix.len()].eq_ignore_ascii_case(prefix) { Some(&s[prefix.len()..]) } else { None } } impl Sentinel for Authenticated { fn abort(rocket: &Rocket) -> bool { // NOTE data sentinels are broken in rocket 0.5-rc2 Self::try_get_state(rocket).is_none() || <&DbConn as Sentinel>::abort(rocket) } } impl Authenticated { fn try_get_state(r: &Rocket) -> Option<&Config> { r.state::() } fn state(r: &Rocket) -> &Config { Self::try_get_state(r).unwrap() } async fn parse_auth<'a>( request: &'a Request<'_>, ) -> Outcome, (Status, anyhow::Error), Status> { let auth = match request.headers().get("authorization").take(2).enumerate().last() { Some((0, h)) => h, Some((_, _)) => { return Outcome::Error(( Status::BadRequest, anyhow!("multiple authorization headers present"), )) }, None => return Outcome::Forward(Status::Unauthorized), }; if let Some(hawk) = drop_auth_prefix(auth, "hawk ") { match Header::from_str(hawk) { Ok(header) => Outcome::Success(AuthKind::Hawk { header }), Err(e) => Outcome::Error(( Status::Unauthorized, anyhow!(e).context("malformed hawk header"), )), } } else if let Some(token) = drop_auth_prefix(auth, "bearer ") { Outcome::Success(AuthKind::Token { token }) } else { Outcome::Forward(Status::Unauthorized) } } pub async fn get_conn<'r>(req: &'r Request<'_>) -> Result<&'r DbConn> { match <&'r DbConn as FromRequest<'r>>::from_request(req).await { Outcome::Success(db) => Ok(db), Outcome::Error((_, e)) => Err(e.context("get db connection")), _ => Err(anyhow!("could not get db connection")), } } async fn verify_hawk( request: &Request<'_>, hawk: Header, data: Option<&str>, ) -> Result<(Src::ID, Src::Context)> { let cfg = Self::state(request.rocket()); let url = format!("{}{}", cfg.location, request.uri()); let url = url::Url::parse(&url).unwrap(); let hash = data .map(|d| PayloadHasher::hash("application/json", DigestAlgorithm::Sha256, d)) .transpose()?; let hawk_req = RequestBuilder::from_url(request.method().as_str(), &url)?; let hawk_req = match hash.as_ref() { Some(h) => hawk_req.hash(Some(h.as_ref())).request(), _ => hawk_req.request(), }; let id: Src::ID = match hawk.id.clone().ok_or_else(|| anyhow!("missing hawk key id"))?.parse() { Ok(id) => id, Err(_) => bail!("malformed hawk key id"), }; let (key, context) = Src::hawk(request, &id).await?; let key = Key::new(&key.0, DigestAlgorithm::Sha256)?; // large skew was taken from fxa-auth-server, large clock skews seem to happen if !hawk_req.validate_header(&hawk, &key, Duration::from_secs(20 * 365 * 86400)) { bail!("bad hawk signature"); } Ok((id, context)) } async fn verify_bearer_token( request: &Request<'_>, token: &str, ) -> Result<(Src::ID, Src::Context)> { let token = match token.parse() { Ok(token) => token, Err(_) => bail!("malformed oauth token"), }; Src::bearer_token(request, &token).await } } #[rocket::async_trait] impl<'r, Src: AuthSource> FromRequest<'r> for Authenticated<(), Src> { type Error = anyhow::Error; async fn from_request(request: &'r Request<'_>) -> request::Outcome { let auth = try_outcome!(Self::parse_auth(request).await); let result = match auth { AuthKind::Hawk { header } => Self::verify_hawk(request, header, None).await, AuthKind::Token { token } => Self::verify_bearer_token(request, token).await, }; match result { Ok((session, context)) => { Outcome::Success(Authenticated { body: (), session, context }) }, Err(e) => { request.local_cache(|| Some(InvalidTokenUsed)); Outcome::Error((Status::Unauthorized, anyhow!(e))) }, } } } #[rocket::async_trait] impl<'r, T: Deserialize<'r>, Src: AuthSource> FromData<'r> for Authenticated { type Error = anyhow::Error; async fn from_data(request: &'r Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> { let auth = try_outcome_data!(data, Self::parse_auth(request).await); let limit = request.rocket().config().limits.get("json").unwrap_or_else(|| 1u32.mebibytes()); let raw_json = match data.open(limit).into_string().await { Ok(r) if r.is_complete() => local_cache!(request, r.into_inner()), Ok(_) => { return data::Outcome::Error(( Status::PayloadTooLarge, anyhow!("request too large"), )) }, Err(e) => return data::Outcome::Error((Status::InternalServerError, e.into())), }; let verify_result = match auth { AuthKind::Hawk { header } => Self::verify_hawk(request, header, Some(raw_json)).await, AuthKind::Token { token } => Self::verify_bearer_token(request, token).await, }; let result = match verify_result { Ok((session, context)) => { serde_json::from_str(raw_json).map(|body| Authenticated { body, session, context }) }, Err(e) => { request.local_cache(|| Some(InvalidTokenUsed)); return Outcome::Error((Status::Unauthorized, anyhow!(e))); }, }; match result { Ok(r) => Outcome::Success(r), Err(e) => { // match Json here to keep catchers generic let status = match e.classify() { Category::Data => Status::UnprocessableEntity, _ => Status::BadRequest, }; Outcome::Error((status, anyhow!(e))) }, } } } #[derive(Debug)] pub(crate) struct WithBearer; #[rocket::async_trait] impl crate::auth::AuthSource for WithBearer { type ID = UserID; type Context = ScopeSet; async fn hawk(_r: &Request<'_>, _id: &Self::ID) -> Result<(HawkKey, Self::Context)> { bail!("hawk signatures not allowed here") } async fn bearer_token( r: &Request<'_>, token: &OauthToken, ) -> Result<(Self::ID, Self::Context)> { let db = Authenticated::<(), Self>::get_conn(r).await?; let t = db.get_access_token(&token.hash()).await?; Ok((t.user_id, t.scope)) } }