diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d7d7d50..35476fe 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,7 +20,7 @@ To set up a development environment for this repository: 1. Do a development install with pip ```bash - pip install --editable ".[test]" + pip install --editable ".[dev]" ``` 1. Set up pre-commit hooks for automatic code formatting, etc. diff --git a/multiauthenticator/multiauthenticator.py b/multiauthenticator/multiauthenticator.py index 2ea5f39..61b881e 100644 --- a/multiauthenticator/multiauthenticator.py +++ b/multiauthenticator/multiauthenticator.py @@ -32,6 +32,8 @@ from jupyterhub.utils import url_path_join from traitlets import List +PREFIX_SEPARATOR = ":" + class URLScopeMixin: """Mixin class that adds the""" @@ -51,6 +53,14 @@ def get_handlers(self, app): ] +def removeprefix(self: str, prefix: str) -> str: + """PEP-0616 implementation to stay compatible with Python < 3.9""" + if self.startswith(prefix): + return self[len(prefix) :] + else: + return self[:] + + class MultiAuthenticator(Authenticator): """Wrapper class that allows to use more than one authentication provider for JupyterHub""" @@ -69,12 +79,46 @@ def __init__(self, *arg, **kwargs): class WrapperAuthenticator(URLScopeMixin, authenticator_klass): url_scope = url_scope_authenticator + @property + def username_prefix(self): + return f"{getattr(self, 'service_name', self.login_service)}{PREFIX_SEPARATOR}" + + async def authenticate(self, handler, data=None, **kwargs): + response = await super().authenticate(handler, data, **kwargs) + if response is None: + return None + elif type(response) == str: + return self.username_prefix + response + else: + response["name"] = self.username_prefix + response["name"] + return response + + def check_allowed(self, username, authentication=None): + if not username.startswith(self.username_prefix): + return False + + return super().check_allowed( + removeprefix(username, self.username_prefix), authentication + ) + + def check_blocked_users(self, username, authentication=None): + if not username.startswith(self.username_prefix): + return False + + return super().check_blocked_users( + removeprefix(username, self.username_prefix), authentication + ) + service_name = authenticator_configuration.pop("service_name", None) authenticator = WrapperAuthenticator(**authenticator_configuration) - if service_name: + if service_name is not None: + if PREFIX_SEPARATOR in service_name: + raise ValueError(f"Service name cannot contain {PREFIX_SEPARATOR}") authenticator.service_name = service_name + elif PREFIX_SEPARATOR in authenticator.login_service: + raise ValueError(f"Login service cannot contain {PREFIX_SEPARATOR}") self._authenticators.append(authenticator) diff --git a/multiauthenticator/tests/test_multiauthenticator.py b/multiauthenticator/tests/test_multiauthenticator.py index b607f4a..4b55155 100644 --- a/multiauthenticator/tests/test_multiauthenticator.py +++ b/multiauthenticator/tests/test_multiauthenticator.py @@ -2,11 +2,16 @@ # # SPDX-License-Identifier: BSD-3-Clause """Test module for the MultiAuthenticator class""" +import pytest + +from jupyterhub.auth import DummyAuthenticator from jupyterhub.auth import PAMAuthenticator +from oauthenticator import OAuthenticator from oauthenticator.github import GitHubOAuthenticator from oauthenticator.gitlab import GitLabOAuthenticator from oauthenticator.google import GoogleOAuthenticator +from ..multiauthenticator import PREFIX_SEPARATOR from ..multiauthenticator import MultiAuthenticator @@ -82,7 +87,7 @@ def test_same_authenticators(): GoogleOAuthenticator, "/mygoogle", { - "login_service": "My Google", + "service_name": "My Google", "client_id": "yyyyy", "client_secret": "yyyyy", "oauth_callback_url": "http://example.com/hub/mygoogle/oauth_callback", @@ -92,7 +97,7 @@ def test_same_authenticators(): GoogleOAuthenticator, "/othergoogle", { - "login_service": "Other Google", + "service_name": "Other Google", "client_id": "xxxx", "client_secret": "xxxx", "oauth_callback_url": "http://example.com/hub/othergoogle/oauth_callback", @@ -109,9 +114,9 @@ def test_same_authenticators(): for path, handler in handlers: assert isinstance(handler.authenticator, GoogleOAuthenticator) if "mygoogle" in path: - assert handler.authenticator.login_service == "My Google" + assert handler.authenticator.service_name == "My Google" elif "othergoogle" in path: - assert handler.authenticator.login_service == "Other Google" + assert handler.authenticator.service_name == "Other Google" else: raise ValueError(f"Unknown path: {path}") @@ -171,7 +176,6 @@ def test_extra_configuration(): { "service_name": "PAM", "allowed_users": allowed_users, - "not_existing": "boom", }, ), ] @@ -182,5 +186,114 @@ def test_extra_configuration(): for authenticator in multi_authenticator._authenticators: assert authenticator.allowed_users == allowed_users - if isinstance(authenticator, PAMAuthenticator): - assert not hasattr(authenticator, "not_existing") + +def test_username_prefix(): + MultiAuthenticator.authenticators = [ + ( + GitLabOAuthenticator, + "/gitlab", + { + "client_id": "xxxx", + "client_secret": "xxxx", + "oauth_callback_url": "http://example.com/hub/gitlab/oauth_callback", + }, + ), + (PAMAuthenticator, "/pam", {"service_name": "PAM"}), + ] + + multi_authenticator = MultiAuthenticator() + assert len(multi_authenticator._authenticators) == 2 + assert ( + multi_authenticator._authenticators[0].username_prefix + == f"GitLab{PREFIX_SEPARATOR}" + ) + assert ( + multi_authenticator._authenticators[1].username_prefix + == f"PAM{PREFIX_SEPARATOR}" + ) + + +@pytest.mark.asyncio +async def test_authenticated_username_prefix(): + MultiAuthenticator.authenticators = [ + (DummyAuthenticator, "/pam", {"service_name": "Dummy"}), + ] + + multi_authenticator = MultiAuthenticator() + assert len(multi_authenticator._authenticators) == 1 + username = await multi_authenticator._authenticators[0].authenticate( + None, {"username": "test"} + ) + assert username == f"Dummy{PREFIX_SEPARATOR}test" + + +def test_username_prefix_checks(): + MultiAuthenticator.authenticators = [ + (PAMAuthenticator, "/pam", {"service_name": "PAM", "allowed_users": {"test"}}), + ( + PAMAuthenticator, + "/pam", + {"service_name": "PAM2", "blocked_users": {"test2"}}, + ), + ] + + multi_authenticator = MultiAuthenticator() + assert len(multi_authenticator._authenticators) == 2 + authenticator = multi_authenticator._authenticators[0] + + assert authenticator.check_allowed("test") == False + assert authenticator.check_allowed("PAM:test") == True + assert ( + authenticator.check_blocked_users("test") == False + ) # Even if no block list, it does not have the correct prefix + assert authenticator.check_blocked_users("PAM:test") == True + + authenticator = multi_authenticator._authenticators[1] + assert authenticator.check_allowed("test2") == False + assert ( + authenticator.check_allowed("PAM2:test2") == True + ) # Because allowed_users is empty + assert authenticator.check_blocked_users("test2") == False + assert authenticator.check_blocked_users("PAM2:test2") == False + + +@pytest.fixture(params=[f"test me{PREFIX_SEPARATOR}", f"second{PREFIX_SEPARATOR} test"]) +def invalid_name(request): + yield request.param + + +def test_username_prefix_validation_with_service_name(invalid_name): + MultiAuthenticator.authenticators = [ + ( + PAMAuthenticator, + "/pam", + {"service_name": invalid_name, "allowed_users": {"test"}}, + ), + ] + + with pytest.raises(ValueError) as excinfo: + MultiAuthenticator() + + assert f"Service name cannot contain {PREFIX_SEPARATOR}" in str(excinfo.value) + + +def test_username_prefix_validation_with_login_service(invalid_name): + class MyAuthenticator(OAuthenticator): + login_service = invalid_name + + MultiAuthenticator.authenticators = [ + ( + MyAuthenticator, + "/myauth", + { + "client_id": "xxxx", + "client_secret": "xxxx", + "oauth_callback_url": "http://example.com/myauth/oauth_callback", + }, + ), + ] + + with pytest.raises(ValueError) as excinfo: + MultiAuthenticator() + + assert f"Login service cannot contain {PREFIX_SEPARATOR}" in str(excinfo.value) diff --git a/pyproject.toml b/pyproject.toml index 1740ff9..86542b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,8 @@ dependencies = [ ] [project.optional-dependencies] -test = ["pytest", "pytest-cov"] +test = ["pytest", "pytest-cov", "pytest-asyncio"] +dev = ["pre-commit", "jupyterhub-multiauthenticator[test]"] [tool.setuptools] packages = ["multiauthenticator"]