summaryrefslogtreecommitdiff
path: root/src/types/oauth.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/types/oauth.rs')
-rw-r--r--src/types/oauth.rs267
1 files changed, 267 insertions, 0 deletions
diff --git a/src/types/oauth.rs b/src/types/oauth.rs
new file mode 100644
index 0000000..222c567
--- /dev/null
+++ b/src/types/oauth.rs
@@ -0,0 +1,267 @@
+use std::{borrow::Cow, fmt::Display};
+
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(transparent)]
+pub(crate) struct Scope<'a>(pub Cow<'a, str>);
+
+impl<'a> Scope<'a> {
+ pub const fn borrowed(s: &'a str) -> Self {
+ Self(Cow::Borrowed(s))
+ }
+
+ pub fn into_owned(self) -> Scope<'static> {
+ Scope(Cow::Owned(self.0.into_owned()))
+ }
+
+ pub fn implies(&self, other: &Scope) -> bool {
+ let (a, b) = (&*self.0, &*other.0);
+ match (a.strip_prefix("https://"), b.strip_prefix("https://")) {
+ (Some(a), Some(b)) => {
+ let (a_origin, a_path) = a.split_once('/').unwrap_or((a, ""));
+ let (b_origin, b_path) = b.split_once('/').unwrap_or((b, ""));
+ if a_origin != b_origin {
+ false
+ } else {
+ let (a_path, a_frag) = match a_path.split_once('#') {
+ Some((p, f)) => (p, Some(f)),
+ None => (a_path, None),
+ };
+ let (b_path, b_frag) = match b_path.split_once('#') {
+ Some((p, f)) => (p, Some(f)),
+ None => (b_path, None),
+ };
+ if b_path
+ .strip_prefix(a_path)
+ .map_or(false, |br| br.is_empty() || br.starts_with('/'))
+ {
+ match (a_frag, b_frag) {
+ (Some(af), Some(bf)) => af == bf,
+ (Some(_), None) => false,
+ _ => true,
+ }
+ } else {
+ false
+ }
+ }
+ },
+ (None, None) => {
+ let (a, a_write) =
+ a.strip_suffix(":write").map(|s| (s, true)).unwrap_or((a, false));
+ let (b, b_write) =
+ b.strip_suffix(":write").map(|s| (s, true)).unwrap_or((b, false));
+ if b_write && !a_write {
+ false
+ } else {
+ b.strip_prefix(a).map_or(false, |br| br.is_empty() || br.starts_with(':'))
+ }
+ },
+ _ => false,
+ }
+ }
+}
+
+impl<'a> Display for Scope<'a> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, sqlx::Type)]
+#[serde(transparent)]
+#[sqlx(transparent)]
+pub(crate) struct ScopeSet(String);
+
+impl ScopeSet {
+ pub fn split(&self) -> impl Iterator<Item = Scope> {
+ // not using split_whitespace because the oauth spec explicitly says to split on SP
+ self.0.split(' ').filter(|s| !s.is_empty()).map(Scope::borrowed)
+ }
+
+ pub fn implies(&self, scope: &Scope) -> bool {
+ self.split().any(|a| a.implies(scope))
+ }
+
+ pub fn implies_all(&self, scopes: &ScopeSet) -> bool {
+ scopes.split().all(|b| self.implies(&b))
+ }
+
+ pub fn is_allowed_by(&self, allowed: &[Scope]) -> bool {
+ self.split().all(|scope| allowed.iter().any(|perm| perm.implies(&scope)))
+ }
+
+ pub fn remove(&self, remove: &Scope) -> ScopeSet {
+ let remaining = self.split().filter(|s| !remove.implies(s));
+ ScopeSet(remaining.map(|s| s.0).collect::<Vec<_>>().join(" "))
+ }
+}
+
+impl PartialEq for ScopeSet {
+ fn eq(&self, other: &Self) -> bool {
+ let (mut a, mut b) = (self.split().collect::<Vec<_>>(), other.split().collect::<Vec<_>>());
+ a.sort_by(|a, b| a.0.cmp(&b.0));
+ b.sort_by(|a, b| a.0.cmp(&b.0));
+ a.eq(&b)
+ }
+}
+
+impl Eq for ScopeSet {}
+
+impl Display for ScopeSet {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::{Scope, ScopeSet};
+
+ #[test]
+ fn test_scope_implies() {
+ assert!(ScopeSet("profile:write".to_string()).implies(&Scope::borrowed("profile")));
+ assert!(ScopeSet("profile".to_string()).implies(&Scope::borrowed("profile:email")));
+ assert!(ScopeSet("profile:write".to_string()).implies(&Scope::borrowed("profile:email")));
+ assert!(
+ ScopeSet("profile:write".to_string()).implies(&Scope::borrowed("profile:email:write"))
+ );
+ assert!(
+ ScopeSet("profile:email:write".to_string()).implies(&Scope::borrowed("profile:email"))
+ );
+ assert!(ScopeSet("profile profile:email:write".to_string())
+ .implies(&Scope::borrowed("profile:email")));
+ assert!(ScopeSet("profile profile:email:write".to_string())
+ .implies(&Scope::borrowed("profile:display_name")));
+ assert!(ScopeSet("profile https://identity.mozilla.com/apps/oldsync".to_string())
+ .implies(&Scope::borrowed("profile")));
+ assert!(ScopeSet("profile https://identity.mozilla.com/apps/oldsync".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync")));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync#read")));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks")));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks#read")));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync#read".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks#read")));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync#read profile".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks#read")));
+
+ assert!(!ScopeSet("profile:email:write".to_string()).implies(&Scope::borrowed("profile")));
+ assert!(
+ !ScopeSet("profile:email:write".to_string()).implies(&Scope::borrowed("profile:write"))
+ );
+ assert!(!ScopeSet("profile:email".to_string())
+ .implies(&Scope::borrowed("profile:display_name")));
+ assert!(!ScopeSet("profilebogey".to_string()).implies(&Scope::borrowed("profile")));
+ assert!(!ScopeSet("profile:write".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync")));
+ assert!(!ScopeSet("profile profile:email:write".to_string())
+ .implies(&Scope::borrowed("profile:write")));
+ assert!(!ScopeSet("https".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync")));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string())
+ .implies(&Scope::borrowed("profile")));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync#read".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks")));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync#write".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/bookmarks#read")));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync/bookmarks".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync")));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync/bookmarks".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync/passwords")));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsyncer".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync")));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsyncer")));
+ assert!(!ScopeSet("https://identity.mozilla.org/apps/oldsync".to_string())
+ .implies(&Scope::borrowed("https://identity.mozilla.com/apps/oldsync")));
+ }
+
+ #[test]
+ fn test_scopes_allowed_by() {
+ const ALLOWED: [Scope; 2] = [
+ Scope::borrowed("profile:write"),
+ Scope::borrowed("https://identity.mozilla.com/apps/oldsync"),
+ ];
+
+ assert!(ScopeSet("profile".to_string()).is_allowed_by(&ALLOWED));
+ assert!(ScopeSet("profile:write".to_string()).is_allowed_by(&ALLOWED));
+ assert!(ScopeSet("profile:email".to_string()).is_allowed_by(&ALLOWED));
+ assert!(ScopeSet("profile:email:write".to_string()).is_allowed_by(&ALLOWED));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync".to_string())
+ .is_allowed_by(&ALLOWED));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync#read".to_string())
+ .is_allowed_by(&ALLOWED));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync/bookmarks".to_string())
+ .is_allowed_by(&ALLOWED));
+ assert!(ScopeSet("https://identity.mozilla.com/apps/oldsync/bookmarks#read".to_string())
+ .is_allowed_by(&ALLOWED));
+ assert!(ScopeSet("profile https://identity.mozilla.com/apps/oldsync".to_string())
+ .is_allowed_by(&ALLOWED));
+
+ assert!(!ScopeSet("storage".to_string()).is_allowed_by(&ALLOWED));
+ assert!(!ScopeSet("storage:write".to_string()).is_allowed_by(&ALLOWED));
+ assert!(!ScopeSet("storage:email".to_string()).is_allowed_by(&ALLOWED));
+ assert!(!ScopeSet("storage:email:write".to_string()).is_allowed_by(&ALLOWED));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/newsync".to_string())
+ .is_allowed_by(&ALLOWED));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/newsync#read".to_string())
+ .is_allowed_by(&ALLOWED));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/newsync/bookmarks".to_string())
+ .is_allowed_by(&ALLOWED));
+ assert!(!ScopeSet("https://identity.mozilla.com/apps/newsync/bookmarks#read".to_string())
+ .is_allowed_by(&ALLOWED));
+ assert!(!ScopeSet("storage https://identity.mozilla.com/apps/newsync".to_string())
+ .is_allowed_by(&ALLOWED));
+ }
+
+ #[test]
+ fn test_scopes_remove() {
+ assert_eq!(
+ ScopeSet("profile foo".to_string()).remove(&Scope::borrowed("profile")),
+ ScopeSet("foo".to_string())
+ );
+ assert_ne!(
+ ScopeSet("profile:write foo".to_string()).remove(&Scope::borrowed("profile")),
+ ScopeSet("foo".to_string())
+ );
+ assert_eq!(
+ ScopeSet("profile:write foo".to_string()).remove(&Scope::borrowed("profile:write")),
+ ScopeSet("foo".to_string())
+ );
+ assert_eq!(
+ ScopeSet("profile:x foo".to_string()).remove(&Scope::borrowed("profile")),
+ ScopeSet("foo".to_string())
+ );
+ assert_ne!(
+ ScopeSet("profile:x:write foo".to_string()).remove(&Scope::borrowed("profile")),
+ ScopeSet("foo".to_string())
+ );
+ assert_eq!(
+ ScopeSet("profile:x:write foo".to_string()).remove(&Scope::borrowed("profile:write")),
+ ScopeSet("foo".to_string())
+ );
+
+ assert_eq!(
+ ScopeSet("https://foo/bar foo".to_string()).remove(&Scope::borrowed("https://foo/bar")),
+ ScopeSet("foo".to_string())
+ );
+ assert_eq!(
+ ScopeSet("https://foo/bar/baz foo".to_string())
+ .remove(&Scope::borrowed("https://foo/bar")),
+ ScopeSet("foo".to_string())
+ );
+ assert_eq!(
+ ScopeSet("https://foo/bar#read foo".to_string())
+ .remove(&Scope::borrowed("https://foo/bar")),
+ ScopeSet("foo".to_string())
+ );
+ assert_eq!(
+ ScopeSet("https://foo/bar/baz#read foo".to_string())
+ .remove(&Scope::borrowed("https://foo/bar")),
+ ScopeSet("foo".to_string())
+ );
+ }
+}