diff options
Diffstat (limited to 'src/types')
-rw-r--r-- | src/types/oauth.rs | 267 |
1 files changed, 267 insertions, 0 deletions
diff --git a/src/types/oauth.rs b/src/types/oauth.rs new file mode 100644 index 0000000..222c567 --- /dev/null +++ b/src/types/oauth.rs @@ -0,0 +1,267 @@ +use std::{borrow::Cow, fmt::Display}; + +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(transparent)] +pub(crate) struct Scope<'a>(pub Cow<'a, str>); + +impl<'a> Scope<'a> { + pub const fn borrowed(s: &'a str) -> Self { + Self(Cow::Borrowed(s)) + } + + pub fn into_owned(self) -> Scope<'static> { + Scope(Cow::Owned(self.0.into_owned())) + } + + pub fn implies(&self, other: &Scope) -> bool { + let (a, b) = (&*self.0, &*other.0); + match (a.strip_prefix("https://"), b.strip_prefix("https://")) { + (Some(a), Some(b)) => { + let (a_origin, a_path) = a.split_once('/').unwrap_or((a, "")); + let (b_origin, b_path) = b.split_once('/').unwrap_or((b, "")); + if a_origin != b_origin { + false + } else { + let (a_path, a_frag) = match a_path.split_once('#') { + Some((p, f)) => (p, Some(f)), + None => (a_path, None), + }; + let (b_path, b_frag) = match b_path.split_once('#') { + Some((p, f)) => (p, Some(f)), + None => (b_path, None), + }; + if b_path + .strip_prefix(a_path) + .map_or(false, |br| br.is_empty() || br.starts_with('/')) + { + match (a_frag, b_frag) { + (Some(af), Some(bf)) => af == bf, + (Some(_), None) => false, + _ => true, + } + } else { + false + } + } + }, + (None, None) => { + let (a, a_write) = + a.strip_suffix(":write").map(|s| (s, true)).unwrap_or((a, false)); + let (b, b_write) = + b.strip_suffix(":write").map(|s| (s, true)).unwrap_or((b, false)); + if b_write && !a_write { + false + } else { + b.strip_prefix(a).map_or(false, |br| br.is_empty() || br.starts_with(':')) + } + }, + _ => false, + } + } +} + +impl<'a> Display for Scope<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, sqlx::Type)] +#[serde(transparent)] +#[sqlx(transparent)] +pub(crate) struct ScopeSet(String); + +impl ScopeSet { + pub fn split(&self) -> impl Iterator<Item = Scope> { + // not using split_whitespace because the oauth spec explicitly says to split on SP + self.0.split(' ').filter(|s| !s.is_empty()).map(Scope::borrowed) + } + + pub fn implies(&self, scope: &Scope) -> bool { + self.split().any(|a| a.implies(scope)) + } + + pub fn implies_all(&self, scopes: &ScopeSet) -> bool { + scopes.split().all(|b| self.implies(&b)) + } + + pub fn is_allowed_by(&self, allowed: &[Scope]) -> bool { + self.split().all(|scope| allowed.iter().any(|perm| perm.implies(&scope))) + } + + pub fn remove(&self, remove: &Scope) -> ScopeSet { + let remaining = self.split().filter(|s| !remove.implies(s)); + ScopeSet(remaining.map(|s| s.0).collect::<Vec<_>>().join(" ")) + } +} + +impl PartialEq for ScopeSet { + fn eq(&self, other: &Self) -> bool { + let (mut a, mut b) = (self.split().collect::<Vec<_>>(), other.split().collect::<Vec<_>>()); + a.sort_by(|a, b| a.0.cmp(&b.0)); + b.sort_by(|a, b| a.0.cmp(&b.0)); + a.eq(&b) + } +} + +impl Eq for ScopeSet {} + +impl Display for ScopeSet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[cfg(test)] +mod test { + use super::{Scope, ScopeSet}; + + #[test] + fn test_scope_implies() { + assert!(ScopeSet("profile:write".to_string()).implies(&Scope::borrowed("profile"))); + assert!(ScopeSet("profile".to_string()).implies(&Scope::borrowed("profile:email"))); + assert!(ScopeSet("profile:write".to_string()).implies(&Scope::borrowed("profile:email"))); + assert!( + ScopeSet("profile:write".to_string()).implies(&Scope::borrowed("profile:email:write")) + ); + assert!( + ScopeSet("profile:email:write".to_string()).implies(&Scope::borrowed("profile:email")) + ); + assert!(ScopeSet("profile profile:email:write".to_string()) + .implies(&Scope::borrowed("profile:email"))); + assert!(ScopeSet("profile profile:email:write".to_string()) + .implies(&Scope::borrowed("profile:display_name"))); + assert!(ScopeSet("profile https://identity.mozilla.com/apps/oldsync".to_string()) + .implies(&Scope::borrowed("profile"))); + assert!(ScopeSet("profile https://identity.mozilla.com/apps/oldsync".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync"))); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync#read"))); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks"))); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks#read"))); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync#read".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks#read"))); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync#read profile".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks#read"))); + + assert!(!ScopeSet("profile:email:write".to_string()).implies(&Scope::borrowed("profile"))); + assert!( + !ScopeSet("profile:email:write".to_string()).implies(&Scope::borrowed("profile:write")) + ); + assert!(!ScopeSet("profile:email".to_string()) + .implies(&Scope::borrowed("profile:display_name"))); + assert!(!ScopeSet("profilebogey".to_string()).implies(&Scope::borrowed("profile"))); + assert!(!ScopeSet("profile:write".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync"))); + assert!(!ScopeSet("profile profile:email:write".to_string()) + .implies(&Scope::borrowed("profile:write"))); + assert!(!ScopeSet("https".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync"))); + assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string()) + .implies(&Scope::borrowed("profile"))); + assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync#read".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks"))); + assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync#write".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks#read"))); + assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync/bookmarks".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync"))); + assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync/bookmarks".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/passwords"))); + assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsyncer".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync"))); + assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsyncer"))); + assert!(!ScopeSet("https://identity.mozilla.org/apps/oldsync".to_string()) + .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync"))); + } + + #[test] + fn test_scopes_allowed_by() { + const ALLOWED: [Scope; 2] = [ + Scope::borrowed("profile:write"), + Scope::borrowed("https://identity.mozilla.com/apps/oldsync"), + ]; + + assert!(ScopeSet("profile".to_string()).is_allowed_by(&ALLOWED)); + assert!(ScopeSet("profile:write".to_string()).is_allowed_by(&ALLOWED)); + assert!(ScopeSet("profile:email".to_string()).is_allowed_by(&ALLOWED)); + assert!(ScopeSet("profile:email:write".to_string()).is_allowed_by(&ALLOWED)); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string()) + .is_allowed_by(&ALLOWED)); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync#read".to_string()) + .is_allowed_by(&ALLOWED)); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync/bookmarks".to_string()) + .is_allowed_by(&ALLOWED)); + assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync/bookmarks#read".to_string()) + .is_allowed_by(&ALLOWED)); + assert!(ScopeSet("profile https://identity.mozilla.com/apps/oldsync".to_string()) + .is_allowed_by(&ALLOWED)); + + assert!(!ScopeSet("storage".to_string()).is_allowed_by(&ALLOWED)); + assert!(!ScopeSet("storage:write".to_string()).is_allowed_by(&ALLOWED)); + assert!(!ScopeSet("storage:email".to_string()).is_allowed_by(&ALLOWED)); + assert!(!ScopeSet("storage:email:write".to_string()).is_allowed_by(&ALLOWED)); + assert!(!ScopeSet("https://identity.mozilla.com/apps/newsync".to_string()) + .is_allowed_by(&ALLOWED)); + assert!(!ScopeSet("https://identity.mozilla.com/apps/newsync#read".to_string()) + .is_allowed_by(&ALLOWED)); + assert!(!ScopeSet("https://identity.mozilla.com/apps/newsync/bookmarks".to_string()) + .is_allowed_by(&ALLOWED)); + assert!(!ScopeSet("https://identity.mozilla.com/apps/newsync/bookmarks#read".to_string()) + .is_allowed_by(&ALLOWED)); + assert!(!ScopeSet("storage https://identity.mozilla.com/apps/newsync".to_string()) + .is_allowed_by(&ALLOWED)); + } + + #[test] + fn test_scopes_remove() { + assert_eq!( + ScopeSet("profile foo".to_string()).remove(&Scope::borrowed("profile")), + ScopeSet("foo".to_string()) + ); + assert_ne!( + ScopeSet("profile:write foo".to_string()).remove(&Scope::borrowed("profile")), + ScopeSet("foo".to_string()) + ); + assert_eq!( + ScopeSet("profile:write foo".to_string()).remove(&Scope::borrowed("profile:write")), + ScopeSet("foo".to_string()) + ); + assert_eq!( + ScopeSet("profile:x foo".to_string()).remove(&Scope::borrowed("profile")), + ScopeSet("foo".to_string()) + ); + assert_ne!( + ScopeSet("profile:x:write foo".to_string()).remove(&Scope::borrowed("profile")), + ScopeSet("foo".to_string()) + ); + assert_eq!( + ScopeSet("profile:x:write foo".to_string()).remove(&Scope::borrowed("profile:write")), + ScopeSet("foo".to_string()) + ); + + assert_eq!( + ScopeSet("https://foo/bar foo".to_string()).remove(&Scope::borrowed("https://foo/bar")), + ScopeSet("foo".to_string()) + ); + assert_eq!( + ScopeSet("https://foo/bar/baz foo".to_string()) + .remove(&Scope::borrowed("https://foo/bar")), + ScopeSet("foo".to_string()) + ); + assert_eq!( + ScopeSet("https://foo/bar#read foo".to_string()) + .remove(&Scope::borrowed("https://foo/bar")), + ScopeSet("foo".to_string()) + ); + assert_eq!( + ScopeSet("https://foo/bar/baz#read foo".to_string()) + .remove(&Scope::borrowed("https://foo/bar")), + ScopeSet("foo".to_string()) + ); + } +} |