diff --git a/.changes/unreleased/Features-add-gssapi-auth.yaml b/.changes/unreleased/Features-add-gssapi-auth.yaml new file mode 100644 index 00000000..3d0c03bc --- /dev/null +++ b/.changes/unreleased/Features-add-gssapi-auth.yaml @@ -0,0 +1,7 @@ +kind: Features +body: Add `gssapi` authentication method backed by trino-python-client's `GSSAPIAuthentication` class. Supports `kinit`-style credential cache auth in addition to keytab, and exposes mutual_authentication as a case-insensitive string ("REQUIRED" | "OPTIONAL" | "DISABLED"). +time: 2026-05-07T06:50:38.855Z +custom: + Author: hb1915 + Issue: "" + PR: "" diff --git a/dbt/adapters/trino/connections.py b/dbt/adapters/trino/connections.py index 3b89aa20..6c63491b 100644 --- a/dbt/adapters/trino/connections.py +++ b/dbt/adapters/trino/connections.py @@ -40,6 +40,8 @@ def _create_trino_profile(cls, profile): return TrinoCertificateCredentials elif method == "kerberos": return TrinoKerberosCredentials + elif method == "gssapi": + return TrinoGssapiCredentials elif method == "jwt": return TrinoJwtCredentials elif method == "oauth": @@ -223,6 +225,75 @@ def trino_auth(self): ) +# Mapping from human-readable mutual-authentication mode (used in dbt profiles) +# to trino-python-client's integer constants. Kept module-level so it's exposed +# for tests and for any future auth methods that need the same translation. +_GSSAPI_MUTUAL_AUTH_VALUES = { + "REQUIRED": trino.auth.GSSAPIAuthentication.MUTUAL_REQUIRED, + "OPTIONAL": trino.auth.GSSAPIAuthentication.MUTUAL_OPTIONAL, + "DISABLED": trino.auth.GSSAPIAuthentication.MUTUAL_DISABLED, +} + + +@dataclass +class TrinoGssapiCredentials(TrinoCredentials): + host: str + port: Port + user: str + client_tags: Optional[List[str]] = None + roles: Optional[Dict[str, str]] = None + principal: Optional[str] = None + krb5_config: Optional[str] = None + service_name: Optional[str] = None + # One of "REQUIRED", "OPTIONAL", "DISABLED" (case-insensitive). Defaults to + # "DISABLED" to match trino-python-client's GSSAPIAuthentication default. + mutual_authentication: Optional[str] = "DISABLED" + cert: Optional[Union[str, bool]] = None + http_headers: Optional[Dict[str, str]] = None + force_preemptive: Optional[bool] = False + hostname_override: Optional[str] = None + sanitize_mutual_error_response: Optional[bool] = True + delegate: Optional[bool] = False + session_properties: Dict[str, Any] = field(default_factory=dict) + prepared_statements_enabled: bool = PREPARED_STATEMENTS_ENABLED_DEFAULT + retries: Optional[int] = trino.constants.DEFAULT_MAX_ATTEMPTS + timezone: Optional[str] = None + suppress_cert_warning: Optional[bool] = None + + @property + def http_scheme(self): + return HttpScheme.HTTPS + + @property + def method(self): + return "gssapi" + + def trino_auth(self): + return trino.auth.GSSAPIAuthentication( + config=self.krb5_config, + service_name=self.service_name, + mutual_authentication=self._resolve_mutual_authentication(), + force_preemptive=self.force_preemptive, + hostname_override=self.hostname_override, + sanitize_mutual_error_response=self.sanitize_mutual_error_response, + principal=self.principal, + delegate=self.delegate, + ca_bundle=self.cert, + ) + + def _resolve_mutual_authentication(self) -> int: + value = (self.mutual_authentication or "DISABLED").upper() + try: + return _GSSAPI_MUTUAL_AUTH_VALUES[value] + except KeyError: + raise DbtRuntimeError( + "Invalid mutual_authentication value {!r}. " + "Expected one of: REQUIRED, OPTIONAL, DISABLED.".format( + self.mutual_authentication + ) + ) + + @dataclass class TrinoJwtCredentials(TrinoCredentials): host: str diff --git a/dbt/include/trino/sample_profiles.yml b/dbt/include/trino/sample_profiles.yml index 4f0639b4..c863dead 100644 --- a/dbt/include/trino/sample_profiles.yml +++ b/dbt/include/trino/sample_profiles.yml @@ -3,7 +3,7 @@ default: dev: type: trino - method: none # optional, one of {none | ldap | kerberos} + method: none # optional, one of {none | ldap | kerberos | gssapi | jwt | certificate | oauth | oauth_console} user: [dev_user] password: [password] # required if method is ldap or kerberos database: [database name] @@ -14,7 +14,7 @@ default: prod: type: trino - method: none # optional, one of {none | ldap | kerberos} + method: none # optional, one of {none | ldap | kerberos | gssapi | jwt | certificate | oauth | oauth_console} user: [prod_user] password: [prod_password] # required if method is ldap or kerberos database: [database name] diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index c389bc84..70baa994 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -16,6 +16,7 @@ from dbt.adapters.trino.connections import ( HttpScheme, TrinoCertificateCredentials, + TrinoGssapiCredentials, TrinoJwtCredentials, TrinoKerberosCredentials, TrinoLdapCredentials, @@ -320,6 +321,73 @@ def test_kerberos_authentication(self): self.assertEqual(credentials.timezone, "UTC") self.assertEqual(credentials.suppress_cert_warning, False) + + def test_gssapi_authentication(self): + connection = self.acquire_connection_with_profile( + { + "type": "trino", + "catalog": "trinodb", + "host": "database", + "port": 5439, + "method": "gssapi", + "schema": "dbt_test_schema", + "user": "trino_user", + "principal": "trino_user@EXAMPLE.COM", + "krb5_config": "/etc/krb5.conf", + "service_name": "trino", + "hostname_override": "database.example.com", + "mutual_authentication": "OPTIONAL", + "force_preemptive": True, + "delegate": True, + "cert": "/path/to/cert", + "client_tags": ["dev", "gssapi"], + "http_headers": {"X-Trino-Client-Info": "dbt-trino"}, + "session_properties": { + "query_max_run_time": "4h", + "exchange_compression": True, + }, + "timezone": "UTC", + "suppress_cert_warning": False, + } + ) + credentials = connection.credentials + self.assertIsInstance(credentials, TrinoGssapiCredentials) + self.assert_default_connection_credentials(credentials) + self.assertEqual(credentials.http_scheme, HttpScheme.HTTPS) + self.assertEqual(credentials.cert, "/path/to/cert") + self.assertEqual(credentials.client_tags, ["dev", "gssapi"]) + self.assertEqual(credentials.principal, "trino_user@EXAMPLE.COM") + self.assertEqual(credentials.service_name, "trino") + self.assertEqual(credentials.hostname_override, "database.example.com") + self.assertEqual(credentials.mutual_authentication, "OPTIONAL") + self.assertEqual(credentials.force_preemptive, True) + self.assertEqual(credentials.delegate, True) + self.assertEqual(credentials.timezone, "UTC") + self.assertEqual(credentials.suppress_cert_warning, False) + import trino.auth + self.assertEqual( + credentials._resolve_mutual_authentication(), + trino.auth.GSSAPIAuthentication.MUTUAL_OPTIONAL, + ) + + def test_gssapi_authentication_default_mutual_authentication(self): + credentials = TrinoGssapiCredentials( + host="h", port=443, user="u", database="db", schema="s" + ) + import trino.auth + self.assertEqual( + credentials._resolve_mutual_authentication(), + trino.auth.GSSAPIAuthentication.MUTUAL_DISABLED, + ) + + def test_gssapi_authentication_invalid_mutual_authentication(self): + credentials = TrinoGssapiCredentials( + host="h", port=443, user="u", database="db", schema="s", + mutual_authentication="bogus", + ) + with self.assertRaises(DbtRuntimeError): + credentials._resolve_mutual_authentication() + def test_certificate_authentication(self): connection = self.acquire_connection_with_profile( {