From 9df2b5d69002c9ce6784d29f86ad9b7c8b992ae8 Mon Sep 17 00:00:00 2001
From: Leander Karp <karp@cl.uni-heidelberg.de>
Date: Wed, 7 Aug 2024 16:18:00 +0200
Subject: [PATCH] Add missing type declarations

---
 portal/__init__.py          |  4 ++--
 portal/api/agenda.py        |  2 +-
 portal/api/auth.py          | 34 ++++++++++++++++++++++------------
 portal/api/rest.py          |  6 +++---
 portal/api/version.py       |  3 ++-
 portal/main.py              |  7 ++++---
 portal/model/Retrievable.py |  3 ++-
 portal/problem_details.py   |  8 +++++---
 8 files changed, 41 insertions(+), 26 deletions(-)

diff --git a/portal/__init__.py b/portal/__init__.py
index 642054e..04d8900 100644
--- a/portal/__init__.py
+++ b/portal/__init__.py
@@ -1,7 +1,7 @@
 # Taken and adapted from the flaskr tutorial. Database logic adapted from Flask SQLAlchemy Documentation
 # (https://flask-sqlalchemy.palletsprojects.com/en/2.x/contexts/). See LICENSE-3RD-PARTY.md for details.
 import secrets
-from typing import Optional
+from typing import Any, Optional
 from pathlib import Path
 
 from flask import Flask
@@ -9,7 +9,7 @@ from flask_sqlalchemy import SQLAlchemy
 
 from .db_config import get_database_uri, get_uri_for_sqlite
 
-db = SQLAlchemy()
+db: Any = SQLAlchemy()
 
 from .model import *
 
diff --git a/portal/api/agenda.py b/portal/api/agenda.py
index ea5ac56..d7717f6 100644
--- a/portal/api/agenda.py
+++ b/portal/api/agenda.py
@@ -7,7 +7,7 @@ from ..model import Key
 
 @bp.get("/agendas/-/items")
 @require_valid(Key)
-def the_agenda_items():
+def the_agenda_items() -> list[dict[str, str]]:
     tasks = sorted(
         TaigaUserStoryProvider.get_the_one().fetch_user_stories(),
         key=lambda task: task["kanban_order"],
diff --git a/portal/api/auth.py b/portal/api/auth.py
index acf7b0d..eae75b8 100644
--- a/portal/api/auth.py
+++ b/portal/api/auth.py
@@ -1,19 +1,25 @@
 # Login code adapted from the flaskr tutorial (see LICENSE-3RD-PARTY.md for details)
 # See https://flask.palletsprojects.com/en/2.0.x/tutorial/views/
 import functools
+from typing import Any, Callable
 
-from flask import request, g
+from flask import g, request
 
 from portal import db
 
+from ..model import Key, User
+from ..problem_details import ProblemResponse, not_found, unauthorized
 from .blueprint import bp
-from ..model import User, Key
-from ..problem_details import not_found, unauthorized
 
 
 @bp.before_app_request
-def set_user_if_valid():
-    if not request.authorization or request.authorization.type != "basic":
+def set_user_if_valid() -> None:
+    if (
+        not request.authorization
+        or request.authorization.username is None
+        or request.authorization.password is None
+        or request.authorization.type != "basic"
+    ):
         g.user = None
         return
 
@@ -23,15 +29,19 @@ def set_user_if_valid():
 
 
 @bp.before_app_request
-def set_key_if_valid():
-    if not request.authorization or request.authorization.type != "bearer":
+def set_key_if_valid() -> None:
+    if (
+        not request.authorization
+        or request.authorization.token is None
+        or request.authorization.type != "bearer"
+    ):
         g.key = None
         return
 
     g.key = Key.get_by_secret_unless_expired(request.authorization.token)
 
 
-def require_valid(credential_type):
+def require_valid(credential_type: type[User | Key]) -> Callable:
     """
     Requires that valid credentials are provided before allowing access to a route.
     If no valid credentials are provided, abort and send a 401.
@@ -41,9 +51,9 @@ def require_valid(credential_type):
     - `@require_valid(Key)` to require a valid key
     """
 
-    def decorator(route):
+    def decorator(route: Callable) -> Callable:
         @functools.wraps(route)
-        def protected_route(**kwargs):
+        def protected_route(**kwargs: Any) -> ProblemResponse:
             # Determine which credentials to use
             credentials = {User: g.user, Key: g.key}.get(credential_type)
 
@@ -60,14 +70,14 @@ def require_valid(credential_type):
 
 @bp.post("/keys")
 @require_valid(User)
-def post_key():
+def post_key() -> tuple[dict, int]:
     key, secret = Key.generate(g.user)
     return key.to_dict() | {"secret": secret}, 201
 
 
 @bp.delete("/keys/<string:uuid>")
 @require_valid(Key)
-def delete_key(uuid: str):
+def delete_key(uuid: str) -> tuple[str, int] | ProblemResponse:
     key_to_delete = Key.get_only(uuid)
 
     if not key_to_delete:
diff --git a/portal/api/rest.py b/portal/api/rest.py
index dc763fb..6474513 100644
--- a/portal/api/rest.py
+++ b/portal/api/rest.py
@@ -1,11 +1,11 @@
 from flask import render_template
 
+from ..problem_details import ProblemResponse, not_found
 from .blueprint import bp
-from ..problem_details import not_found
 
 
 @bp.get("/<path:_>")
-def fallback_not_found(_):
+def fallback_not_found(_: object) -> ProblemResponse:
     """
     If you can't find a matching API endpoint, fallback to returning a 404.
 
@@ -16,5 +16,5 @@ def fallback_not_found(_):
 
 
 @bp.get("/")
-def index():
+def index() -> str:
     return render_template("redocly.html")
diff --git a/portal/api/version.py b/portal/api/version.py
index 8674f68..b12b527 100644
--- a/portal/api/version.py
+++ b/portal/api/version.py
@@ -1,12 +1,13 @@
 import json
 from datetime import datetime
 from pathlib import Path
+from typing import Any
 
 from .blueprint import bp
 
 
 @bp.get("/version")
-def get_version():
+def get_version() -> dict[str, Any]:
     portal_version_path = Path("./PORTAL_VERSION")
     if portal_version_path.is_file():
         with open("PORTAL_VERSION") as f:
diff --git a/portal/main.py b/portal/main.py
index 327b6d8..bbf1a38 100644
--- a/portal/main.py
+++ b/portal/main.py
@@ -1,11 +1,12 @@
-from flask import Blueprint, render_template, redirect, url_for
+from flask import Blueprint, redirect, render_template, url_for
+from werkzeug import Response
 
 bp = Blueprint("main", __name__)
 
 
 @bp.get("/", defaults={"_": ""})
 @bp.get("/<path:_>")
-def index(_):
+def index(_: object) -> str:
     """
     Serve the index page at /*, i.e. at any path below / and at / itself.
 
@@ -16,5 +17,5 @@ def index(_):
 
 
 @bp.get("/api/")
-def redirect_to_api():
+def redirect_to_api() -> Response:
     return redirect(url_for("api.index"))
diff --git a/portal/model/Retrievable.py b/portal/model/Retrievable.py
index d754852..1b7c649 100644
--- a/portal/model/Retrievable.py
+++ b/portal/model/Retrievable.py
@@ -1,7 +1,8 @@
 from abc import abstractmethod
-from typing import List, Optional, Any, Self
+from typing import Any, List, Optional, Self
 
 from flask import g
+
 from portal import db
 
 
diff --git a/portal/problem_details.py b/portal/problem_details.py
index a397f0d..517ffd6 100644
--- a/portal/problem_details.py
+++ b/portal/problem_details.py
@@ -4,15 +4,17 @@ Provide RFC 7807-compliant problem details responses for the API.
 
 from typing import Optional
 
+ProblemResponse = tuple[dict, int, dict]
 
-def unauthorized():
+
+def unauthorized() -> ProblemResponse:
     """
     Return a 401 message and status code
     """
     return _problem_response(title="Unauthorized", status=401)
 
 
-def not_found():
+def not_found() -> ProblemResponse:
     """
     Return a 404 message and status code
     """
@@ -24,7 +26,7 @@ def _problem_response(
     status: int,
     type_uri: Optional[str] = None,
     extensions: Optional[dict] = None,
-) -> tuple[dict, int, dict]:
+) -> ProblemResponse:
     type_dict = {"type": type_uri} if type_uri is not None else {}
     title_dict = {"title": title}
 
-- 
GitLab