summaryrefslogtreecommitdiff
path: root/src/auth.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/auth.rs')
-rw-r--r--src/auth.rs241
1 files changed, 241 insertions, 0 deletions
diff --git a/src/auth.rs b/src/auth.rs
new file mode 100644
index 0000000..f56c5e2
--- /dev/null
+++ b/src/auth.rs
@@ -0,0 +1,241 @@
+use std::str::FromStr;
+use std::time::Duration;
+
+use anyhow::Result;
+use hawk::{DigestAlgorithm, Header, Key, PayloadHasher, RequestBuilder};
+use rocket::data::{self, FromData, ToByteUnit};
+use rocket::http::Status;
+use rocket::outcome::{try_outcome, Outcome};
+use rocket::request::{local_cache, FromRequest, Request};
+use rocket::{request, Data, Ignite, Phase, Rocket, Sentinel};
+use serde::Deserialize;
+use serde_json::error::Category;
+
+use crate::crypto::SecretBytes;
+use crate::db::DbConn;
+use crate::types::oauth::ScopeSet;
+use crate::types::{OauthToken, UserID};
+use crate::Config;
+
+#[rocket::async_trait]
+pub(crate) trait AuthSource {
+ type ID: FromStr + Send + Sync + Clone;
+ type Context: Send + Sync;
+ async fn hawk(r: &Request<'_>, id: &Self::ID) -> Result<(SecretBytes<32>, Self::Context)>;
+ async fn bearer_token(r: &Request<'_>, id: &OauthToken) -> Result<(Self::ID, Self::Context)>;
+}
+
+// marker trait and type to communicate that authentication has failed with invalid
+// tokens used. this is needed to properly translate these error for the profile api.
+pub(crate) trait AuthenticatedRequest {
+ fn invalid_token_used(&self) -> bool;
+}
+
+struct InvalidTokenUsed;
+
+impl<'r> AuthenticatedRequest for Request<'r> {
+ fn invalid_token_used(&self) -> bool {
+ self.local_cache(|| None as Option<InvalidTokenUsed>).is_some()
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct Authenticated<T, Src: AuthSource> {
+ pub body: T,
+ pub session: Src::ID,
+ pub context: Src::Context,
+}
+
+enum AuthKind<'a> {
+ Hawk { header: Header },
+ Token { token: &'a str },
+}
+
+fn drop_auth_prefix<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
+ if prefix.len() <= s.len() && s[..prefix.len()].eq_ignore_ascii_case(prefix) {
+ Some(&s[prefix.len()..])
+ } else {
+ None
+ }
+}
+
+impl<T, S: AuthSource> Sentinel for Authenticated<T, S> {
+ fn abort(rocket: &Rocket<Ignite>) -> bool {
+ // NOTE data sentinels are broken in rocket 0.5-rc2
+ Self::try_get_state(rocket).is_none() || <&DbConn as Sentinel>::abort(rocket)
+ }
+}
+
+impl<T, Src: AuthSource> Authenticated<T, Src> {
+ fn try_get_state<S: Phase>(r: &Rocket<S>) -> Option<&Config> {
+ r.state::<Config>()
+ }
+
+ fn state<S: Phase>(r: &Rocket<S>) -> &Config {
+ Self::try_get_state(r).unwrap()
+ }
+
+ async fn parse_auth<'a>(
+ request: &'a Request<'_>,
+ ) -> Outcome<AuthKind<'a>, (Status, anyhow::Error), ()> {
+ let auth = match request.headers().get("authorization").take(2).enumerate().last() {
+ Some((0, h)) => h,
+ Some((_, _)) => {
+ return Outcome::Failure((
+ Status::BadRequest,
+ anyhow!("multiple authorization headers present"),
+ ))
+ },
+ None => return Outcome::Forward(()),
+ };
+ if let Some(hawk) = drop_auth_prefix(auth, "hawk ") {
+ match Header::from_str(hawk) {
+ Ok(header) => Outcome::Success(AuthKind::Hawk { header }),
+ Err(e) => Outcome::Failure((
+ Status::Unauthorized,
+ anyhow!(e).context("malformed hawk header"),
+ )),
+ }
+ } else if let Some(token) = drop_auth_prefix(auth, "bearer ") {
+ Outcome::Success(AuthKind::Token { token })
+ } else {
+ Outcome::Forward(())
+ }
+ }
+
+ pub async fn get_conn<'r>(req: &'r Request<'_>) -> Result<&'r DbConn> {
+ match <&'r DbConn as FromRequest<'r>>::from_request(req).await {
+ Outcome::Success(db) => Ok(db),
+ Outcome::Failure((_, e)) => Err(e.context("get db connection")),
+ _ => Err(anyhow!("could not get db connection")),
+ }
+ }
+
+ async fn verify_hawk(
+ request: &Request<'_>,
+ hawk: Header,
+ data: Option<&str>,
+ ) -> Result<(Src::ID, Src::Context)> {
+ let cfg = Self::state(request.rocket());
+ let url = format!("{}{}", cfg.location, request.uri());
+ let url = url::Url::parse(&url).unwrap();
+ let hash = data
+ .map(|d| PayloadHasher::hash("application/json", DigestAlgorithm::Sha256, d))
+ .transpose()?;
+ let hawk_req = RequestBuilder::from_url(request.method().as_str(), &url)?;
+ let hawk_req = match hash.as_ref() {
+ Some(h) => hawk_req.hash(Some(h.as_ref())).request(),
+ _ => hawk_req.request(),
+ };
+ let id: Src::ID =
+ match hawk.id.clone().ok_or_else(|| anyhow!("missing hawk key id"))?.parse() {
+ Ok(id) => id,
+ Err(_) => bail!("malformed hawk key id"),
+ };
+ let (key, context) = Src::hawk(request, &id).await?;
+ let key = Key::new(&key.0, DigestAlgorithm::Sha256)?;
+ // large skew was taken from fxa-auth-server, large clock skews seem to happen
+ if !hawk_req.validate_header(&hawk, &key, Duration::from_secs(20 * 365 * 86400)) {
+ bail!("bad hawk signature");
+ }
+ Ok((id, context))
+ }
+
+ async fn verify_bearer_token(
+ request: &Request<'_>,
+ token: &str,
+ ) -> Result<(Src::ID, Src::Context)> {
+ let token = match token.parse() {
+ Ok(token) => token,
+ Err(_) => bail!("malformed oauth token"),
+ };
+ Src::bearer_token(request, &token).await
+ }
+}
+
+#[rocket::async_trait]
+impl<'r, Src: AuthSource> FromRequest<'r> for Authenticated<(), Src> {
+ type Error = anyhow::Error;
+
+ async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
+ let auth = try_outcome!(Self::parse_auth(request).await);
+ let result = match auth {
+ AuthKind::Hawk { header } => Self::verify_hawk(request, header, None).await,
+ AuthKind::Token { token } => Self::verify_bearer_token(request, token).await,
+ };
+ match result {
+ Ok((session, context)) => {
+ Outcome::Success(Authenticated { body: (), session, context })
+ },
+ Err(e) => {
+ request.local_cache(|| Some(InvalidTokenUsed));
+ Outcome::Failure((Status::Unauthorized, anyhow!(e)))
+ },
+ }
+ }
+}
+
+#[rocket::async_trait]
+impl<'r, T: Deserialize<'r>, Src: AuthSource> FromData<'r> for Authenticated<T, Src> {
+ type Error = anyhow::Error;
+
+ async fn from_data(request: &'r Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> {
+ let auth = try_outcome_data!(data, Self::parse_auth(request).await);
+ let limit =
+ request.rocket().config().limits.get("json").unwrap_or_else(|| 1u32.mebibytes());
+ let raw_json = match data.open(limit).into_string().await {
+ Ok(r) if r.is_complete() => local_cache!(request, r.into_inner()),
+ Ok(_) => {
+ return data::Outcome::Failure((
+ Status::PayloadTooLarge,
+ anyhow!("request too large"),
+ ))
+ },
+ Err(e) => return data::Outcome::Failure((Status::InternalServerError, e.into())),
+ };
+ let verify_result = match auth {
+ AuthKind::Hawk { header } => Self::verify_hawk(request, header, Some(raw_json)).await,
+ AuthKind::Token { token } => Self::verify_bearer_token(request, token).await,
+ };
+ let result = match verify_result {
+ Ok((session, context)) => {
+ serde_json::from_str(raw_json).map(|body| Authenticated { body, session, context })
+ },
+ Err(e) => {
+ request.local_cache(|| Some(InvalidTokenUsed));
+ return Outcome::Failure((Status::Unauthorized, anyhow!(e)));
+ },
+ };
+ match result {
+ Ok(r) => Outcome::Success(r),
+ Err(e) => {
+ // match Json<T> here to keep catchers generic
+ let status = match e.classify() {
+ Category::Data => Status::UnprocessableEntity,
+ _ => Status::BadRequest,
+ };
+ Outcome::Failure((status, anyhow!(e)))
+ },
+ }
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct WithBearer;
+
+#[rocket::async_trait]
+impl crate::auth::AuthSource for WithBearer {
+ type ID = UserID;
+ type Context = ScopeSet;
+ async fn hawk(_r: &Request<'_>, _id: &Self::ID) -> Result<(SecretBytes<32>, Self::Context)> {
+ bail!("hawk signatures not allowed here")
+ }
+ async fn bearer_token(
+ r: &Request<'_>,
+ token: &OauthToken,
+ ) -> Result<(Self::ID, Self::Context)> {
+ let db = Authenticated::<(), Self>::get_conn(r).await?;
+ let t = db.get_access_token(&token.hash()).await?;
+ Ok((t.user_id, t.scope))
+ }
+}