diff options
Diffstat (limited to 'src/auth.rs')
-rw-r--r-- | src/auth.rs | 241 |
1 files changed, 241 insertions, 0 deletions
diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..f56c5e2 --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,241 @@ +use std::str::FromStr; +use std::time::Duration; + +use anyhow::Result; +use hawk::{DigestAlgorithm, Header, Key, PayloadHasher, RequestBuilder}; +use rocket::data::{self, FromData, ToByteUnit}; +use rocket::http::Status; +use rocket::outcome::{try_outcome, Outcome}; +use rocket::request::{local_cache, FromRequest, Request}; +use rocket::{request, Data, Ignite, Phase, Rocket, Sentinel}; +use serde::Deserialize; +use serde_json::error::Category; + +use crate::crypto::SecretBytes; +use crate::db::DbConn; +use crate::types::oauth::ScopeSet; +use crate::types::{OauthToken, UserID}; +use crate::Config; + +#[rocket::async_trait] +pub(crate) trait AuthSource { + type ID: FromStr + Send + Sync + Clone; + type Context: Send + Sync; + async fn hawk(r: &Request<'_>, id: &Self::ID) -> Result<(SecretBytes<32>, Self::Context)>; + async fn bearer_token(r: &Request<'_>, id: &OauthToken) -> Result<(Self::ID, Self::Context)>; +} + +// marker trait and type to communicate that authentication has failed with invalid +// tokens used. this is needed to properly translate these error for the profile api. +pub(crate) trait AuthenticatedRequest { + fn invalid_token_used(&self) -> bool; +} + +struct InvalidTokenUsed; + +impl<'r> AuthenticatedRequest for Request<'r> { + fn invalid_token_used(&self) -> bool { + self.local_cache(|| None as Option<InvalidTokenUsed>).is_some() + } +} + +#[derive(Debug)] +pub(crate) struct Authenticated<T, Src: AuthSource> { + pub body: T, + pub session: Src::ID, + pub context: Src::Context, +} + +enum AuthKind<'a> { + Hawk { header: Header }, + Token { token: &'a str }, +} + +fn drop_auth_prefix<'a>(s: &'a str, prefix: &str) -> Option<&'a str> { + if prefix.len() <= s.len() && s[..prefix.len()].eq_ignore_ascii_case(prefix) { + Some(&s[prefix.len()..]) + } else { + None + } +} + +impl<T, S: AuthSource> Sentinel for Authenticated<T, S> { + fn abort(rocket: &Rocket<Ignite>) -> bool { + // NOTE data sentinels are broken in rocket 0.5-rc2 + Self::try_get_state(rocket).is_none() || <&DbConn as Sentinel>::abort(rocket) + } +} + +impl<T, Src: AuthSource> Authenticated<T, Src> { + fn try_get_state<S: Phase>(r: &Rocket<S>) -> Option<&Config> { + r.state::<Config>() + } + + fn state<S: Phase>(r: &Rocket<S>) -> &Config { + Self::try_get_state(r).unwrap() + } + + async fn parse_auth<'a>( + request: &'a Request<'_>, + ) -> Outcome<AuthKind<'a>, (Status, anyhow::Error), ()> { + let auth = match request.headers().get("authorization").take(2).enumerate().last() { + Some((0, h)) => h, + Some((_, _)) => { + return Outcome::Failure(( + Status::BadRequest, + anyhow!("multiple authorization headers present"), + )) + }, + None => return Outcome::Forward(()), + }; + if let Some(hawk) = drop_auth_prefix(auth, "hawk ") { + match Header::from_str(hawk) { + Ok(header) => Outcome::Success(AuthKind::Hawk { header }), + Err(e) => Outcome::Failure(( + Status::Unauthorized, + anyhow!(e).context("malformed hawk header"), + )), + } + } else if let Some(token) = drop_auth_prefix(auth, "bearer ") { + Outcome::Success(AuthKind::Token { token }) + } else { + Outcome::Forward(()) + } + } + + pub async fn get_conn<'r>(req: &'r Request<'_>) -> Result<&'r DbConn> { + match <&'r DbConn as FromRequest<'r>>::from_request(req).await { + Outcome::Success(db) => Ok(db), + Outcome::Failure((_, e)) => Err(e.context("get db connection")), + _ => Err(anyhow!("could not get db connection")), + } + } + + async fn verify_hawk( + request: &Request<'_>, + hawk: Header, + data: Option<&str>, + ) -> Result<(Src::ID, Src::Context)> { + let cfg = Self::state(request.rocket()); + let url = format!("{}{}", cfg.location, request.uri()); + let url = url::Url::parse(&url).unwrap(); + let hash = data + .map(|d| PayloadHasher::hash("application/json", DigestAlgorithm::Sha256, d)) + .transpose()?; + let hawk_req = RequestBuilder::from_url(request.method().as_str(), &url)?; + let hawk_req = match hash.as_ref() { + Some(h) => hawk_req.hash(Some(h.as_ref())).request(), + _ => hawk_req.request(), + }; + let id: Src::ID = + match hawk.id.clone().ok_or_else(|| anyhow!("missing hawk key id"))?.parse() { + Ok(id) => id, + Err(_) => bail!("malformed hawk key id"), + }; + let (key, context) = Src::hawk(request, &id).await?; + let key = Key::new(&key.0, DigestAlgorithm::Sha256)?; + // large skew was taken from fxa-auth-server, large clock skews seem to happen + if !hawk_req.validate_header(&hawk, &key, Duration::from_secs(20 * 365 * 86400)) { + bail!("bad hawk signature"); + } + Ok((id, context)) + } + + async fn verify_bearer_token( + request: &Request<'_>, + token: &str, + ) -> Result<(Src::ID, Src::Context)> { + let token = match token.parse() { + Ok(token) => token, + Err(_) => bail!("malformed oauth token"), + }; + Src::bearer_token(request, &token).await + } +} + +#[rocket::async_trait] +impl<'r, Src: AuthSource> FromRequest<'r> for Authenticated<(), Src> { + type Error = anyhow::Error; + + async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> { + let auth = try_outcome!(Self::parse_auth(request).await); + let result = match auth { + AuthKind::Hawk { header } => Self::verify_hawk(request, header, None).await, + AuthKind::Token { token } => Self::verify_bearer_token(request, token).await, + }; + match result { + Ok((session, context)) => { + Outcome::Success(Authenticated { body: (), session, context }) + }, + Err(e) => { + request.local_cache(|| Some(InvalidTokenUsed)); + Outcome::Failure((Status::Unauthorized, anyhow!(e))) + }, + } + } +} + +#[rocket::async_trait] +impl<'r, T: Deserialize<'r>, Src: AuthSource> FromData<'r> for Authenticated<T, Src> { + type Error = anyhow::Error; + + async fn from_data(request: &'r Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> { + let auth = try_outcome_data!(data, Self::parse_auth(request).await); + let limit = + request.rocket().config().limits.get("json").unwrap_or_else(|| 1u32.mebibytes()); + let raw_json = match data.open(limit).into_string().await { + Ok(r) if r.is_complete() => local_cache!(request, r.into_inner()), + Ok(_) => { + return data::Outcome::Failure(( + Status::PayloadTooLarge, + anyhow!("request too large"), + )) + }, + Err(e) => return data::Outcome::Failure((Status::InternalServerError, e.into())), + }; + let verify_result = match auth { + AuthKind::Hawk { header } => Self::verify_hawk(request, header, Some(raw_json)).await, + AuthKind::Token { token } => Self::verify_bearer_token(request, token).await, + }; + let result = match verify_result { + Ok((session, context)) => { + serde_json::from_str(raw_json).map(|body| Authenticated { body, session, context }) + }, + Err(e) => { + request.local_cache(|| Some(InvalidTokenUsed)); + return Outcome::Failure((Status::Unauthorized, anyhow!(e))); + }, + }; + match result { + Ok(r) => Outcome::Success(r), + Err(e) => { + // match Json<T> here to keep catchers generic + let status = match e.classify() { + Category::Data => Status::UnprocessableEntity, + _ => Status::BadRequest, + }; + Outcome::Failure((status, anyhow!(e))) + }, + } + } +} + +#[derive(Debug)] +pub(crate) struct WithBearer; + +#[rocket::async_trait] +impl crate::auth::AuthSource for WithBearer { + type ID = UserID; + type Context = ScopeSet; + async fn hawk(_r: &Request<'_>, _id: &Self::ID) -> Result<(SecretBytes<32>, Self::Context)> { + bail!("hawk signatures not allowed here") + } + async fn bearer_token( + r: &Request<'_>, + token: &OauthToken, + ) -> Result<(Self::ID, Self::Context)> { + let db = Authenticated::<(), Self>::get_conn(r).await?; + let t = db.get_access_token(&token.hash()).await?; + Ok((t.user_id, t.scope)) + } +} |