From 9df2b5d69002c9ce6784d29f86ad9b7c8b992ae8 Mon Sep 17 00:00:00 2001 From: Leander Karp <karp@cl.uni-heidelberg.de> Date: Wed, 7 Aug 2024 16:18:00 +0200 Subject: [PATCH] Add missing type declarations --- portal/__init__.py | 4 ++-- portal/api/agenda.py | 2 +- portal/api/auth.py | 34 ++++++++++++++++++++++------------ portal/api/rest.py | 6 +++--- portal/api/version.py | 3 ++- portal/main.py | 7 ++++--- portal/model/Retrievable.py | 3 ++- portal/problem_details.py | 8 +++++--- 8 files changed, 41 insertions(+), 26 deletions(-) diff --git a/portal/__init__.py b/portal/__init__.py index 642054e..04d8900 100644 --- a/portal/__init__.py +++ b/portal/__init__.py @@ -1,7 +1,7 @@ # Taken and adapted from the flaskr tutorial. Database logic adapted from Flask SQLAlchemy Documentation # (https://flask-sqlalchemy.palletsprojects.com/en/2.x/contexts/). See LICENSE-3RD-PARTY.md for details. import secrets -from typing import Optional +from typing import Any, Optional from pathlib import Path from flask import Flask @@ -9,7 +9,7 @@ from flask_sqlalchemy import SQLAlchemy from .db_config import get_database_uri, get_uri_for_sqlite -db = SQLAlchemy() +db: Any = SQLAlchemy() from .model import * diff --git a/portal/api/agenda.py b/portal/api/agenda.py index ea5ac56..d7717f6 100644 --- a/portal/api/agenda.py +++ b/portal/api/agenda.py @@ -7,7 +7,7 @@ from ..model import Key @bp.get("/agendas/-/items") @require_valid(Key) -def the_agenda_items(): +def the_agenda_items() -> list[dict[str, str]]: tasks = sorted( TaigaUserStoryProvider.get_the_one().fetch_user_stories(), key=lambda task: task["kanban_order"], diff --git a/portal/api/auth.py b/portal/api/auth.py index acf7b0d..eae75b8 100644 --- a/portal/api/auth.py +++ b/portal/api/auth.py @@ -1,19 +1,25 @@ # Login code adapted from the flaskr tutorial (see LICENSE-3RD-PARTY.md for details) # See https://flask.palletsprojects.com/en/2.0.x/tutorial/views/ import functools +from typing import Any, Callable -from flask import request, g +from flask import g, request from portal import db +from ..model import Key, User +from ..problem_details import ProblemResponse, not_found, unauthorized from .blueprint import bp -from ..model import User, Key -from ..problem_details import not_found, unauthorized @bp.before_app_request -def set_user_if_valid(): - if not request.authorization or request.authorization.type != "basic": +def set_user_if_valid() -> None: + if ( + not request.authorization + or request.authorization.username is None + or request.authorization.password is None + or request.authorization.type != "basic" + ): g.user = None return @@ -23,15 +29,19 @@ def set_user_if_valid(): @bp.before_app_request -def set_key_if_valid(): - if not request.authorization or request.authorization.type != "bearer": +def set_key_if_valid() -> None: + if ( + not request.authorization + or request.authorization.token is None + or request.authorization.type != "bearer" + ): g.key = None return g.key = Key.get_by_secret_unless_expired(request.authorization.token) -def require_valid(credential_type): +def require_valid(credential_type: type[User | Key]) -> Callable: """ Requires that valid credentials are provided before allowing access to a route. If no valid credentials are provided, abort and send a 401. @@ -41,9 +51,9 @@ def require_valid(credential_type): - `@require_valid(Key)` to require a valid key """ - def decorator(route): + def decorator(route: Callable) -> Callable: @functools.wraps(route) - def protected_route(**kwargs): + def protected_route(**kwargs: Any) -> ProblemResponse: # Determine which credentials to use credentials = {User: g.user, Key: g.key}.get(credential_type) @@ -60,14 +70,14 @@ def require_valid(credential_type): @bp.post("/keys") @require_valid(User) -def post_key(): +def post_key() -> tuple[dict, int]: key, secret = Key.generate(g.user) return key.to_dict() | {"secret": secret}, 201 @bp.delete("/keys/<string:uuid>") @require_valid(Key) -def delete_key(uuid: str): +def delete_key(uuid: str) -> tuple[str, int] | ProblemResponse: key_to_delete = Key.get_only(uuid) if not key_to_delete: diff --git a/portal/api/rest.py b/portal/api/rest.py index dc763fb..6474513 100644 --- a/portal/api/rest.py +++ b/portal/api/rest.py @@ -1,11 +1,11 @@ from flask import render_template +from ..problem_details import ProblemResponse, not_found from .blueprint import bp -from ..problem_details import not_found @bp.get("/<path:_>") -def fallback_not_found(_): +def fallback_not_found(_: object) -> ProblemResponse: """ If you can't find a matching API endpoint, fallback to returning a 404. @@ -16,5 +16,5 @@ def fallback_not_found(_): @bp.get("/") -def index(): +def index() -> str: return render_template("redocly.html") diff --git a/portal/api/version.py b/portal/api/version.py index 8674f68..b12b527 100644 --- a/portal/api/version.py +++ b/portal/api/version.py @@ -1,12 +1,13 @@ import json from datetime import datetime from pathlib import Path +from typing import Any from .blueprint import bp @bp.get("/version") -def get_version(): +def get_version() -> dict[str, Any]: portal_version_path = Path("./PORTAL_VERSION") if portal_version_path.is_file(): with open("PORTAL_VERSION") as f: diff --git a/portal/main.py b/portal/main.py index 327b6d8..bbf1a38 100644 --- a/portal/main.py +++ b/portal/main.py @@ -1,11 +1,12 @@ -from flask import Blueprint, render_template, redirect, url_for +from flask import Blueprint, redirect, render_template, url_for +from werkzeug import Response bp = Blueprint("main", __name__) @bp.get("/", defaults={"_": ""}) @bp.get("/<path:_>") -def index(_): +def index(_: object) -> str: """ Serve the index page at /*, i.e. at any path below / and at / itself. @@ -16,5 +17,5 @@ def index(_): @bp.get("/api/") -def redirect_to_api(): +def redirect_to_api() -> Response: return redirect(url_for("api.index")) diff --git a/portal/model/Retrievable.py b/portal/model/Retrievable.py index d754852..1b7c649 100644 --- a/portal/model/Retrievable.py +++ b/portal/model/Retrievable.py @@ -1,7 +1,8 @@ from abc import abstractmethod -from typing import List, Optional, Any, Self +from typing import Any, List, Optional, Self from flask import g + from portal import db diff --git a/portal/problem_details.py b/portal/problem_details.py index a397f0d..517ffd6 100644 --- a/portal/problem_details.py +++ b/portal/problem_details.py @@ -4,15 +4,17 @@ Provide RFC 7807-compliant problem details responses for the API. from typing import Optional +ProblemResponse = tuple[dict, int, dict] -def unauthorized(): + +def unauthorized() -> ProblemResponse: """ Return a 401 message and status code """ return _problem_response(title="Unauthorized", status=401) -def not_found(): +def not_found() -> ProblemResponse: """ Return a 404 message and status code """ @@ -24,7 +26,7 @@ def _problem_response( status: int, type_uri: Optional[str] = None, extensions: Optional[dict] = None, -) -> tuple[dict, int, dict]: +) -> ProblemResponse: type_dict = {"type": type_uri} if type_uri is not None else {} title_dict = {"title": title} -- GitLab