diff options
Diffstat (limited to 'tests/api.py')
-rw-r--r-- | tests/api.py | 252 |
1 files changed, 252 insertions, 0 deletions
diff --git a/tests/api.py b/tests/api.py new file mode 100644 index 0000000..9d2f70d --- /dev/null +++ b/tests/api.py @@ -0,0 +1,252 @@ +import asyncio +import base64 +import binascii +import http.server +import http_ece +import json +import os +import queue +import quopri +import threading +from _utils import HawkTokenAuth, APIClient, hexstr +from aiosmtpd.controller import Controller +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ec +from fxa.crypto import quick_stretch_password, derive_key, xor + +AUTH_URL = "http://localhost:8000/auth" +PROFILE_URL = "http://localhost:8000/profile" +OAUTH_URL = "http://localhost:8000/oauth" +INVITE_URL = "http://localhost:8000/_invite" +PUSH_PORT = 10264 +SMTP_PORT = 2525 + +def auth_pw(email, pw): + return derive_key(quick_stretch_password(email, pw), "authPW").hex() + +class AuthClient: + def __init__(self, /, email=None, session=None, bearer=None, props=None): + self.password = "" + self.client = APIClient(f"{AUTH_URL}/v1") + self.email = email + assert int(session is not None) + int(bearer is not None) < 2 + self.session = session + self.auth = HawkTokenAuth(session, "sessionToken", self.client) if session else None + self.bearer = bearer + self.headers = { 'authorization': f'bearer {bearer}' } if bearer else {} + self.props = props + + def post(self, url, json=None, **kwds): + return self.client.post(url, json, **kwds) + def post_a(self, url, json=None, **kwds): + kwds.setdefault('headers', {}) + kwds['headers'] |= self.headers + return self.client.post(url, json, auth=self.auth, **kwds) + + def get(self, url, **kwds): + return self.client.get(url, **kwds) + def get_a(self, url, **kwds): + kwds.setdefault('headers', {}) + kwds['headers'] |= self.headers + return self.client.get(url, auth=self.auth, **kwds) + + def delete(self, url, **kwds): + return self.client.delete(url, **kwds) + def delete_a(self, url, **kwds): + kwds.setdefault('headers', {}) + kwds['headers'] |= self.headers + return self.client.delete(url, auth=self.auth, **kwds) + + def create_account(self, email, pw, keys=None, invite=None, **kwds): + body = { + "email": email, + "authPW": hexstr(derive_key(quick_stretch_password(email, pw), "authPW")), + "style": invite, + } + params = { 'keys': str(keys).lower() } if keys is not None else {} + resp = self.client.post("/account/create", body, params=params, **kwds) + return AuthClient(email=email, session=resp['sessionToken'], props=resp) + def destroy_account(self, email, pw, **kwds): + body = { "email": email, "authPW": hexstr(derive_key(quick_stretch_password(email, pw), "authPW")) } + return self.client.post("/account/destroy", body) + + def fetch_keys(self, key_fetch_token, pw): + pw = quick_stretch_password(self.email, pw) + auth = HawkTokenAuth(key_fetch_token, "keyFetchToken", self.client) + resp = self.client.get("/account/keys", auth=auth) + bundle = binascii.unhexlify(resp["bundle"]) + keys = auth.unbundle("account/keys", bundle) + unwrap_key = derive_key(pw, "unwrapBkey") + return (keys[:32], xor(keys[32:], unwrap_key)) + + def login(self, email, pw, keys=None, **kwds): + body = { "email": email, "authPW": hexstr(derive_key(quick_stretch_password(email, pw), "authPW")) } + params = { "keys": str(keys).lower() } if keys is not None else {} + resp = self.client.post("/account/login", body, params=params) + return AuthClient(email=email, session=resp['sessionToken'], props=resp) + def destroy_session(self, **kwds): + return self.post_a("/session/destroy", kwds) + + def profile(self): + token = self.post_a("/oauth/token", { + "client_id": "5882386c6d801776", + "ttl": 60, + "grant_type": "fxa-credentials", + "access_type": "online", + "scope": "profile:write", + }) + return Profile(token['access_token']) + +class Invite: + def __init__(self, token): + self.client = APIClient(f"{INVITE_URL}/v1") + self.auth = HawkTokenAuth(token, "sessionToken", self.client) + + def post(self, url, json=None, **kwds): + return self.client.post(url, json, **kwds) + def post_a(self, url, json=None, **kwds): + return self.client.post(url, json, auth=self.auth, **kwds) + +class PasswordChange: + def __init__(self, client, token, hkdf='passwordChangeToken'): + self.client = client + self.auth = HawkTokenAuth(token, hkdf, self.client) + + def post(self, url, json=None, **kwds): + return self.client.post(url, json, **kwds) + def post_a(self, url, json=None, **kwds): + return self.client.post(url, json, auth=self.auth, **kwds) + +class AccountReset: + def __init__(self, client, token): + self.client = client + self.auth = HawkTokenAuth(token, 'accountResetToken', self.client) + + def post(self, url, json=None, **kwds): + return self.client.post(url, json, **kwds) + def post_a(self, url, json=None, **kwds): + return self.client.post(url, json, auth=self.auth, **kwds) + +class Profile: + def __init__(self, token): + self.client = APIClient(f"{PROFILE_URL}/v1") + self.token = token + + def get(self, url, **kwds): + return self.client.get(url, **kwds) + def get_a(self, url, **kwds): + kwds.setdefault('headers', {}) + kwds['headers']['authorization'] = f'bearer {self.token}' + return self.client.get(url, **kwds) + + def post(self, url, json=None, **kwds): + return self.client.post(url, json, **kwds) + def post_a(self, url, json=None, **kwds): + kwds.setdefault('headers', {}) + kwds['headers']['authorization'] = f'bearer {self.token}' + return self.client.post(url, json, **kwds) + + def delete(self, url, **kwds): + return self.client.delete(url, **kwds) + def delete_a(self, url, **kwds): + kwds.setdefault('headers', {}) + kwds['headers']['authorization'] = f'bearer {self.token}' + return self.client.delete(url, **kwds) + +class Oauth: + def __init__(self): + self.client = APIClient(f"{OAUTH_URL}/v1") + + def post(self, url, json=None, **kwds): + return self.client.post(url, json, **kwds) + +class Device: + def __init__(self, auth, name, type="desktop", commands={}, pcb=None): + self.auth = auth + dev = auth.post_a("/account/device", { + "name": name, + "type": type, + "availableCommands": commands, + } | self._mk_push(pcb)) + self.id = dev['id'] + self.props = dev + + def _mk_push(self, pcb): + if not pcb: + return {} + + self.priv = ec.generate_private_key(ec.SECP256R1, default_backend()) + self.public = self.priv.public_key().public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.UncompressedPoint) + self.authkey = os.urandom(16) + return { + "pushCallback": pcb, + "pushPublicKey": base64.urlsafe_b64encode(self.public).decode('utf8'), + "pushAuthKey": base64.urlsafe_b64encode(self.authkey).decode('utf8'), + } + + def update_pcb(self, pcb): + self.props = self.auth.post_a("/account/device", { "id": self.id } | self._mk_push(pcb)) + + def decrypt(self, data): + raw = http_ece.decrypt(data, private_key=self.priv, auth_secret=self.authkey) + return json.loads(raw.decode('utf8')) + +class PushServer: + def __init__(self): + q = self.q = queue.Queue() + + class Handler(http.server.BaseHTTPRequestHandler): + def do_POST(self): + if self.path.startswith("/err/"): + self.send_response(410) + self.end_headers() + else: + self.send_response(200) + self.end_headers() + + q.put((self.path, self.headers, self.rfile.read(int(self.headers['content-length'])))) + + server = self.server = http.server.ThreadingHTTPServer(("localhost", PUSH_PORT), Handler) + threading.Thread(target=server.serve_forever).start() + + def wait(self, timeout=2): + return self.q.get(timeout=timeout) + def done(self, timeout=2): + try: + self.q.get(timeout=timeout) + return False + except queue.Empty: + return True + + def good(self, id): + return f"http://localhost:{PUSH_PORT}/{id}" + def bad(self, id): + return f"http://localhost:{PUSH_PORT}/err/{id}" + +class MailServer: + def __init__(self): + q = self.q = queue.Queue() + + class Handler: + async def handle_RCPT(self, server, session, envelope, address, rcpt_options): + envelope.rcpt_tos.append(address) + return '250 OK' + + async def handle_DATA(self, server, session, envelope): + headers, body = envelope.content.decode('utf8').split("\r\n\r\n", maxsplit=1) + if "Content-Transfer-Encoding: quoted-printable" in headers: + body = quopri.decodestring(body).decode('utf8') + q.put((envelope.rcpt_tos, body)) + return '250 Message accepted for delivery' + + self.controller = Controller(Handler(), hostname="localhost", port=SMTP_PORT) + self.controller.start() + + def stop(self): + self.controller.stop() + + def wait(self, timeout=2): + return self.q.get(timeout=timeout) |