summaryrefslogtreecommitdiff
path: root/tests/api.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/api.py')
-rw-r--r--tests/api.py252
1 files changed, 252 insertions, 0 deletions
diff --git a/tests/api.py b/tests/api.py
new file mode 100644
index 0000000..9d2f70d
--- /dev/null
+++ b/tests/api.py
@@ -0,0 +1,252 @@
+import asyncio
+import base64
+import binascii
+import http.server
+import http_ece
+import json
+import os
+import queue
+import quopri
+import threading
+from _utils import HawkTokenAuth, APIClient, hexstr
+from aiosmtpd.controller import Controller
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import ec
+from fxa.crypto import quick_stretch_password, derive_key, xor
+
+AUTH_URL = "http://localhost:8000/auth"
+PROFILE_URL = "http://localhost:8000/profile"
+OAUTH_URL = "http://localhost:8000/oauth"
+INVITE_URL = "http://localhost:8000/_invite"
+PUSH_PORT = 10264
+SMTP_PORT = 2525
+
+def auth_pw(email, pw):
+ return derive_key(quick_stretch_password(email, pw), "authPW").hex()
+
+class AuthClient:
+ def __init__(self, /, email=None, session=None, bearer=None, props=None):
+ self.password = ""
+ self.client = APIClient(f"{AUTH_URL}/v1")
+ self.email = email
+ assert int(session is not None) + int(bearer is not None) < 2
+ self.session = session
+ self.auth = HawkTokenAuth(session, "sessionToken", self.client) if session else None
+ self.bearer = bearer
+ self.headers = { 'authorization': f'bearer {bearer}' } if bearer else {}
+ self.props = props
+
+ def post(self, url, json=None, **kwds):
+ return self.client.post(url, json, **kwds)
+ def post_a(self, url, json=None, **kwds):
+ kwds.setdefault('headers', {})
+ kwds['headers'] |= self.headers
+ return self.client.post(url, json, auth=self.auth, **kwds)
+
+ def get(self, url, **kwds):
+ return self.client.get(url, **kwds)
+ def get_a(self, url, **kwds):
+ kwds.setdefault('headers', {})
+ kwds['headers'] |= self.headers
+ return self.client.get(url, auth=self.auth, **kwds)
+
+ def delete(self, url, **kwds):
+ return self.client.delete(url, **kwds)
+ def delete_a(self, url, **kwds):
+ kwds.setdefault('headers', {})
+ kwds['headers'] |= self.headers
+ return self.client.delete(url, auth=self.auth, **kwds)
+
+ def create_account(self, email, pw, keys=None, invite=None, **kwds):
+ body = {
+ "email": email,
+ "authPW": hexstr(derive_key(quick_stretch_password(email, pw), "authPW")),
+ "style": invite,
+ }
+ params = { 'keys': str(keys).lower() } if keys is not None else {}
+ resp = self.client.post("/account/create", body, params=params, **kwds)
+ return AuthClient(email=email, session=resp['sessionToken'], props=resp)
+ def destroy_account(self, email, pw, **kwds):
+ body = { "email": email, "authPW": hexstr(derive_key(quick_stretch_password(email, pw), "authPW")) }
+ return self.client.post("/account/destroy", body)
+
+ def fetch_keys(self, key_fetch_token, pw):
+ pw = quick_stretch_password(self.email, pw)
+ auth = HawkTokenAuth(key_fetch_token, "keyFetchToken", self.client)
+ resp = self.client.get("/account/keys", auth=auth)
+ bundle = binascii.unhexlify(resp["bundle"])
+ keys = auth.unbundle("account/keys", bundle)
+ unwrap_key = derive_key(pw, "unwrapBkey")
+ return (keys[:32], xor(keys[32:], unwrap_key))
+
+ def login(self, email, pw, keys=None, **kwds):
+ body = { "email": email, "authPW": hexstr(derive_key(quick_stretch_password(email, pw), "authPW")) }
+ params = { "keys": str(keys).lower() } if keys is not None else {}
+ resp = self.client.post("/account/login", body, params=params)
+ return AuthClient(email=email, session=resp['sessionToken'], props=resp)
+ def destroy_session(self, **kwds):
+ return self.post_a("/session/destroy", kwds)
+
+ def profile(self):
+ token = self.post_a("/oauth/token", {
+ "client_id": "5882386c6d801776",
+ "ttl": 60,
+ "grant_type": "fxa-credentials",
+ "access_type": "online",
+ "scope": "profile:write",
+ })
+ return Profile(token['access_token'])
+
+class Invite:
+ def __init__(self, token):
+ self.client = APIClient(f"{INVITE_URL}/v1")
+ self.auth = HawkTokenAuth(token, "sessionToken", self.client)
+
+ def post(self, url, json=None, **kwds):
+ return self.client.post(url, json, **kwds)
+ def post_a(self, url, json=None, **kwds):
+ return self.client.post(url, json, auth=self.auth, **kwds)
+
+class PasswordChange:
+ def __init__(self, client, token, hkdf='passwordChangeToken'):
+ self.client = client
+ self.auth = HawkTokenAuth(token, hkdf, self.client)
+
+ def post(self, url, json=None, **kwds):
+ return self.client.post(url, json, **kwds)
+ def post_a(self, url, json=None, **kwds):
+ return self.client.post(url, json, auth=self.auth, **kwds)
+
+class AccountReset:
+ def __init__(self, client, token):
+ self.client = client
+ self.auth = HawkTokenAuth(token, 'accountResetToken', self.client)
+
+ def post(self, url, json=None, **kwds):
+ return self.client.post(url, json, **kwds)
+ def post_a(self, url, json=None, **kwds):
+ return self.client.post(url, json, auth=self.auth, **kwds)
+
+class Profile:
+ def __init__(self, token):
+ self.client = APIClient(f"{PROFILE_URL}/v1")
+ self.token = token
+
+ def get(self, url, **kwds):
+ return self.client.get(url, **kwds)
+ def get_a(self, url, **kwds):
+ kwds.setdefault('headers', {})
+ kwds['headers']['authorization'] = f'bearer {self.token}'
+ return self.client.get(url, **kwds)
+
+ def post(self, url, json=None, **kwds):
+ return self.client.post(url, json, **kwds)
+ def post_a(self, url, json=None, **kwds):
+ kwds.setdefault('headers', {})
+ kwds['headers']['authorization'] = f'bearer {self.token}'
+ return self.client.post(url, json, **kwds)
+
+ def delete(self, url, **kwds):
+ return self.client.delete(url, **kwds)
+ def delete_a(self, url, **kwds):
+ kwds.setdefault('headers', {})
+ kwds['headers']['authorization'] = f'bearer {self.token}'
+ return self.client.delete(url, **kwds)
+
+class Oauth:
+ def __init__(self):
+ self.client = APIClient(f"{OAUTH_URL}/v1")
+
+ def post(self, url, json=None, **kwds):
+ return self.client.post(url, json, **kwds)
+
+class Device:
+ def __init__(self, auth, name, type="desktop", commands={}, pcb=None):
+ self.auth = auth
+ dev = auth.post_a("/account/device", {
+ "name": name,
+ "type": type,
+ "availableCommands": commands,
+ } | self._mk_push(pcb))
+ self.id = dev['id']
+ self.props = dev
+
+ def _mk_push(self, pcb):
+ if not pcb:
+ return {}
+
+ self.priv = ec.generate_private_key(ec.SECP256R1, default_backend())
+ self.public = self.priv.public_key().public_bytes(
+ encoding=serialization.Encoding.X962,
+ format=serialization.PublicFormat.UncompressedPoint)
+ self.authkey = os.urandom(16)
+ return {
+ "pushCallback": pcb,
+ "pushPublicKey": base64.urlsafe_b64encode(self.public).decode('utf8'),
+ "pushAuthKey": base64.urlsafe_b64encode(self.authkey).decode('utf8'),
+ }
+
+ def update_pcb(self, pcb):
+ self.props = self.auth.post_a("/account/device", { "id": self.id } | self._mk_push(pcb))
+
+ def decrypt(self, data):
+ raw = http_ece.decrypt(data, private_key=self.priv, auth_secret=self.authkey)
+ return json.loads(raw.decode('utf8'))
+
+class PushServer:
+ def __init__(self):
+ q = self.q = queue.Queue()
+
+ class Handler(http.server.BaseHTTPRequestHandler):
+ def do_POST(self):
+ if self.path.startswith("/err/"):
+ self.send_response(410)
+ self.end_headers()
+ else:
+ self.send_response(200)
+ self.end_headers()
+
+ q.put((self.path, self.headers, self.rfile.read(int(self.headers['content-length']))))
+
+ server = self.server = http.server.ThreadingHTTPServer(("localhost", PUSH_PORT), Handler)
+ threading.Thread(target=server.serve_forever).start()
+
+ def wait(self, timeout=2):
+ return self.q.get(timeout=timeout)
+ def done(self, timeout=2):
+ try:
+ self.q.get(timeout=timeout)
+ return False
+ except queue.Empty:
+ return True
+
+ def good(self, id):
+ return f"http://localhost:{PUSH_PORT}/{id}"
+ def bad(self, id):
+ return f"http://localhost:{PUSH_PORT}/err/{id}"
+
+class MailServer:
+ def __init__(self):
+ q = self.q = queue.Queue()
+
+ class Handler:
+ async def handle_RCPT(self, server, session, envelope, address, rcpt_options):
+ envelope.rcpt_tos.append(address)
+ return '250 OK'
+
+ async def handle_DATA(self, server, session, envelope):
+ headers, body = envelope.content.decode('utf8').split("\r\n\r\n", maxsplit=1)
+ if "Content-Transfer-Encoding: quoted-printable" in headers:
+ body = quopri.decodestring(body).decode('utf8')
+ q.put((envelope.rcpt_tos, body))
+ return '250 Message accepted for delivery'
+
+ self.controller = Controller(Handler(), hostname="localhost", port=SMTP_PORT)
+ self.controller.start()
+
+ def stop(self):
+ self.controller.stop()
+
+ def wait(self, timeout=2):
+ return self.q.get(timeout=timeout)