Verified Commit e3e8e030 authored by Jakob Moser's avatar Jakob Moser
Browse files

Accept CA root cert when creating server

parent 0c053aa3
Loading
Loading
Loading
Loading
+11 −5
Original line number Diff line number Diff line
from pathlib import Path
import ssl

from ldap3 import Tls, Server, Connection
@@ -9,9 +10,6 @@ def is_valid_server_url(url: str) -> bool:
    return url.startswith("ldaps://") or url.startswith("ldap://")


TLS_CONFIGURATION = Tls(validate=ssl.CERT_REQUIRED)


class Directory:
    """
    A (user information) directory (= database), accessible via the Lightweight Directory Access Protocol (LDAP).
@@ -24,12 +22,15 @@ class Directory:
    https://gitlab.cl.uni-heidelberg.de/fachschaft/codex/-/blob/master/ldap/LdapDatabase.py?ref_type=heads
    """

    def __init__(self, url: str, base_dn: str):
    def __init__(
        self, url: str, base_dn: str, ca_certs_file_path: Path | str | None = None
    ) -> None:
        """
        Creates the LdapDatabase instance. Does not establish any connections whatsoever.

        :param url: A ldap server url (must use either ldaps:// or ldap:// scheme)
        :param base_dn: The distinguished name below which to look for users (e.g. "ou=employees,dc=example,dc=com")
        :param ca_certs_file_path: Path to a file containing a root CA, if e.g., a self-created one should be used
        """
        if not is_valid_server_url(url):
            raise ValueError("ldap_server_url must begin with ldaps:// or ldap://")
@@ -43,7 +44,12 @@ class Directory:
        # specified (e.g. if someone ever deletes the scheme-enforcing code above), the connection will fall back to
        # use SSL, and I believe that is a good way for things to be.
        self.__server = Server(
            self.ldap_server_url, use_ssl=True, tls=TLS_CONFIGURATION
            self.ldap_server_url,
            use_ssl=True,
            tls=Tls(
                validate=ssl.CERT_REQUIRED,
                ca_certs_file=str(ca_certs_file_path) if ca_certs_file_path else None,
            ),
        )

    def is_valid(self, username: str, password: str) -> bool: