From e5dda50a83f92ad40b0403239eb54531f5bf60fb Mon Sep 17 00:00:00 2001 From: Konrad Dziedzic Date: Wed, 26 Feb 2025 00:37:31 +0100 Subject: [PATCH] Pass users original role for impersonation Without this the way impersonation work is broken. Considering followin scenarion - there are 2 users: alice and bob - there is a role alice_role that allows for bob impersonation If alice sets alice_role and then calls SET SESSION AUTHORIZATION `bob` access control will make a check `canImpersonateUser`, because alice_role allows for that impersonation it will succeed. However every next query will also call this `canImpersonateUser` check and because alice_role is not assigned anymore it will fail every time. The idea behind the fix is to pass roles of the original user, alice from the example above, between server and clients using new headers X-Trino-Set-Original-Roles. Then roles obtained in this way can be used in the checkCanImpersonate to make sure it passes. --- .../src/main/java/io/trino/cli/Console.java | 8 + .../src/main/java/io/trino/cli/Query.java | 5 + .../java/io/trino/client/ClientSession.java | 17 ++ .../java/io/trino/client/ProtocolHeaders.java | 14 ++ .../java/io/trino/client/StatementClient.java | 2 + .../io/trino/client/StatementClientV1.java | 17 ++ .../java/io/trino/jdbc/TrinoConnection.java | 5 + .../io/trino/jdbc/BaseTrinoDriverTest.java | 2 +- .../trino/jdbc/TestAsyncResultIterator.java | 6 + .../io/trino/jdbc/TestJdbcConnection.java | 6 + .../src/main/java/io/trino/Session.java | 3 +- .../java/io/trino/SessionRepresentation.java | 10 ++ .../trino/dispatcher/FailedDispatchQuery.java | 1 + .../java/io/trino/event/QueryMonitor.java | 1 + .../java/io/trino/execution/QueryInfo.java | 11 ++ .../io/trino/execution/QueryStateMachine.java | 9 + .../SetSessionAuthorizationTask.java | 15 ++ .../HttpRequestSessionContextFactory.java | 26 +-- .../java/io/trino/server/ResultQueryInfo.java | 5 + .../protocol/ExecutingStatementResource.java | 4 + .../java/io/trino/server/protocol/Query.java | 7 + .../server/protocol/QueryResultsResponse.java | 2 + .../testing/TestingAccessControlManager.java | 10 ++ .../execution/MockManagedQueryExecution.java | 1 + .../io/trino/execution/TestQueryInfo.java | 1 + .../trino/metadata/TestMetadataManager.java | 15 +- .../io/trino/server/TestBasicQueryInfo.java | 1 + .../TestHttpRequestSessionContextFactory.java | 18 +- .../io/trino/server/TestQueryStateInfo.java | 1 + .../trino/spi/eventlistener/QueryContext.java | 9 + .../main/sphinx/develop/client-protocol.md | 3 + .../httpquery/TestHttpEventListener.java | 1 + .../TestHttpServerEventListener.java | 1 + .../plugin/eventlistener/kafka/TestUtils.java | 1 + .../mysql/TestMysqlEventListener.java | 2 + .../plugin/openlineage/TrinoEventData.java | 1 + pom.xml | 7 + testing/trino-tests/pom.xml | 13 ++ .../io/trino/security/TestImpersonation.java | 160 ++++++++++++++++++ .../security/TestSystemSecurityMetadata.java | 31 ++-- 40 files changed, 424 insertions(+), 28 deletions(-) create mode 100644 testing/trino-tests/src/test/java/io/trino/security/TestImpersonation.java diff --git a/client/trino-cli/src/main/java/io/trino/cli/Console.java b/client/trino-cli/src/main/java/io/trino/cli/Console.java index 9d1a5782d9cd..e5e82c932ce7 100644 --- a/client/trino-cli/src/main/java/io/trino/cli/Console.java +++ b/client/trino-cli/src/main/java/io/trino/cli/Console.java @@ -45,6 +45,7 @@ import java.nio.file.Paths; import java.util.AbstractMap; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -399,6 +400,13 @@ private static boolean process( builder = builder.roles(ImmutableMap.of()); } + // update session originalRoles + if (!query.getSetOriginalRoles().isEmpty()) { + Set originalRoles = new HashSet<>(session.getOriginalRoles()); + originalRoles.addAll(query.getSetOriginalRoles()); + builder = builder.originalRoles(originalRoles); + } + if (query.isResetAuthorizationUser()) { builder = builder.authorizationUser(Optional.empty()); builder = builder.roles(ImmutableMap.of()); diff --git a/client/trino-cli/src/main/java/io/trino/cli/Query.java b/client/trino-cli/src/main/java/io/trino/cli/Query.java index 593352ed0b7e..b46dab2daaa5 100644 --- a/client/trino-cli/src/main/java/io/trino/cli/Query.java +++ b/client/trino-cli/src/main/java/io/trino/cli/Query.java @@ -100,6 +100,11 @@ public boolean isResetAuthorizationUser() return client.isResetAuthorizationUser(); } + public Set getSetOriginalRoles() + { + return client.getSetOriginalRoles(); + } + public Map getSetSessionProperties() { return client.getSetSessionProperties(); diff --git a/client/trino-client/src/main/java/io/trino/client/ClientSession.java b/client/trino-client/src/main/java/io/trino/client/ClientSession.java index 91a7a539b097..d15236d605c9 100644 --- a/client/trino-client/src/main/java/io/trino/client/ClientSession.java +++ b/client/trino-client/src/main/java/io/trino/client/ClientSession.java @@ -39,6 +39,7 @@ public class ClientSession private final Optional user; private final Optional sessionUser; private final Optional authorizationUser; + private final Set originalRoles; private final String source; private final Optional traceToken; private final Set clientTags; @@ -80,6 +81,7 @@ private ClientSession( Optional user, Optional sessionUser, Optional authorizationUser, + Set originalRoles, String source, Optional traceToken, Set clientTags, @@ -103,6 +105,7 @@ private ClientSession( this.user = requireNonNull(user, "user is null"); this.sessionUser = requireNonNull(sessionUser, "sessionUser is null"); this.authorizationUser = requireNonNull(authorizationUser, "authorizationUser is null"); + this.originalRoles = ImmutableSet.copyOf(requireNonNull(originalRoles, "originalRoles is null")); this.source = requireNonNull(source, "source is null"); this.traceToken = requireNonNull(traceToken, "traceToken is null"); this.clientTags = ImmutableSet.copyOf(requireNonNull(clientTags, "clientTags is null")); @@ -171,6 +174,11 @@ public Optional getAuthorizationUser() return authorizationUser; } + public Set getOriginalRoles() + { + return originalRoles; + } + public String getSource() { return source; @@ -302,6 +310,7 @@ public static final class Builder private Optional user = Optional.empty(); private Optional sessionUser = Optional.empty(); private Optional authorizationUser = Optional.empty(); + private Set originalRoles = ImmutableSet.of(); private String source; private Optional traceToken = Optional.empty(); private Set clientTags = ImmutableSet.of(); @@ -330,6 +339,7 @@ private Builder(ClientSession clientSession) user = clientSession.getUser(); sessionUser = clientSession.getSessionUser(); authorizationUser = clientSession.getAuthorizationUser(); + originalRoles = clientSession.getOriginalRoles(); source = clientSession.getSource(); traceToken = clientSession.getTraceToken(); clientTags = clientSession.getClientTags(); @@ -374,6 +384,12 @@ public Builder authorizationUser(Optional authorizationUser) return this; } + public Builder originalRoles(Set originalRoles) + { + this.originalRoles = originalRoles; + return this; + } + public Builder source(String source) { this.source = source; @@ -489,6 +505,7 @@ public ClientSession build() user, sessionUser, authorizationUser, + originalRoles, source, traceToken, clientTags, diff --git a/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java b/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java index 3ef75fc91e7e..69e19da4184f 100644 --- a/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java +++ b/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java @@ -27,6 +27,7 @@ public final class ProtocolHeaders private final String name; private final String requestUser; private final String requestOriginalUser; + private final String requestOriginalRole; private final String requestSource; private final String requestCatalog; private final String requestSchema; @@ -57,6 +58,7 @@ public final class ProtocolHeaders private final String responseClearTransactionId; private final String responseSetAuthorizationUser; private final String responseResetAuthorizationUser; + private final String responseOriginalRole; public static ProtocolHeaders createProtocolHeaders(String name) { @@ -75,6 +77,7 @@ private ProtocolHeaders(String name) String prefix = "X-" + name + "-"; requestUser = prefix + "User"; requestOriginalUser = prefix + "Original-User"; + requestOriginalRole = prefix + "Original-Roles"; requestSource = prefix + "Source"; requestCatalog = prefix + "Catalog"; requestSchema = prefix + "Schema"; @@ -105,6 +108,7 @@ private ProtocolHeaders(String name) responseClearTransactionId = prefix + "Clear-Transaction-Id"; responseSetAuthorizationUser = prefix + "Set-Authorization-User"; responseResetAuthorizationUser = prefix + "Reset-Authorization-User"; + responseOriginalRole = prefix + "Set-Original-Roles"; } public String getProtocolName() @@ -122,6 +126,11 @@ public String requestOriginalUser() return requestOriginalUser; } + public String requestOriginalRole() + { + return requestOriginalRole; + } + public String requestSource() { return requestSource; @@ -272,6 +281,11 @@ public String responseResetAuthorizationUser() return responseResetAuthorizationUser; } + public String responseOriginalRole() + { + return responseOriginalRole; + } + public static ProtocolHeaders detectProtocol(Optional alternateHeaderName, Set headerNames) throws ProtocolDetectionException { diff --git a/client/trino-client/src/main/java/io/trino/client/StatementClient.java b/client/trino-client/src/main/java/io/trino/client/StatementClient.java index 816196f9f0b1..4838e10c95fd 100644 --- a/client/trino-client/src/main/java/io/trino/client/StatementClient.java +++ b/client/trino-client/src/main/java/io/trino/client/StatementClient.java @@ -63,6 +63,8 @@ default Optional getEncoding() boolean isResetAuthorizationUser(); + Set getSetOriginalRoles(); + Map getSetSessionProperties(); Set getResetSessionProperties(); diff --git a/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java b/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java index 2c14330f36db..98d97d2b3bdc 100644 --- a/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java +++ b/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java @@ -88,6 +88,7 @@ class StatementClientV1 private final AtomicReference> setPath = new AtomicReference<>(); private final AtomicReference setAuthorizationUser = new AtomicReference<>(); private final AtomicBoolean resetAuthorizationUser = new AtomicBoolean(); + private final Set setOriginalRoles = Sets.newConcurrentHashSet(); private final Map setSessionProperties = new ConcurrentHashMap<>(); private final Set resetSessionProperties = Sets.newConcurrentHashSet(); private final Map setRoles = new ConcurrentHashMap<>(); @@ -125,6 +126,7 @@ public StatementClientV1(Call.Factory httpCallFactory, Call.Factory segmentHttpC .filter(Optional::isPresent) .map(Optional::get) .findFirst(); + this.setOriginalRoles.addAll(session.getOriginalRoles()); this.clientCapabilities = Joiner.on(",").join(clientCapabilities.orElseGet(() -> stream(ClientCapabilities.values()) .map(Enum::name) .collect(toImmutableSet()))); @@ -181,6 +183,10 @@ private Request buildQueryRequest(ClientSession session, String query, Optional< builder.addHeader(TRINO_HEADERS.requestResourceEstimate(), entry.getKey() + "=" + urlEncode(entry.getValue())); } + for (ClientSelectedRole selectedRole : session.getOriginalRoles()) { + builder.addHeader(TRINO_HEADERS.requestOriginalRole(), selectedRole.toString()); + } + Map roles = session.getRoles(); for (Entry entry : roles.entrySet()) { builder.addHeader(TRINO_HEADERS.requestRole(), entry.getKey() + '=' + urlEncode(entry.getValue().toString())); @@ -316,6 +322,12 @@ public boolean isResetAuthorizationUser() return resetAuthorizationUser.get(); } + @Override + public Set getSetOriginalRoles() + { + return ImmutableSet.copyOf(setOriginalRoles); + } + @Override public Map getSetSessionProperties() { @@ -475,6 +487,11 @@ private void processResponse(Headers headers, QueryResults results) this.resetAuthorizationUser.set(Boolean.parseBoolean(resetAuthorizationUser)); } + setOriginalRoles.addAll(headers.values(TRINO_HEADERS.responseOriginalRole()) + .stream() + .map(role -> ClientSelectedRole.valueOf(urlDecode(role))) + .collect(toImmutableSet())); + for (String setSession : headers.values(TRINO_HEADERS.responseSetSession())) { List keyValue = COLLECTION_HEADER_SPLITTER.splitToList(setSession); if (keyValue.size() != 2) { diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java index 1b0b6730e7c9..2969a4b553d2 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java @@ -19,6 +19,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; import com.google.common.primitives.Ints; import io.airlift.units.Duration; import io.trino.client.ClientSelectedRole; @@ -107,6 +108,7 @@ public class TrinoConnection private final AtomicReference schema = new AtomicReference<>(); private final AtomicReference> path = new AtomicReference<>(ImmutableList.of()); private final AtomicReference authorizationUser = new AtomicReference<>(); + private final Set originalRoles = Sets.newConcurrentHashSet(); private final AtomicReference timeZoneId = new AtomicReference<>(); private final AtomicReference locale = new AtomicReference<>(); private final AtomicReference networkTimeoutMillis = new AtomicReference<>(Ints.saturatedCast(MINUTES.toMillis(2))); @@ -855,6 +857,7 @@ StatementClient startQuery(String sql, Map sessionPropertiesOver .user(user) .sessionUser(sessionUser.get()) .authorizationUser(Optional.ofNullable(authorizationUser.get())) + .originalRoles(ImmutableSet.copyOf(originalRoles)) .source(source) .traceToken(Optional.ofNullable(clientInfo.get(TRACE_TOKEN))) .clientTags(ImmutableSet.copyOf(clientTags)) @@ -891,10 +894,12 @@ void updateSession(StatementClient client) if (client.getSetAuthorizationUser().isPresent()) { authorizationUser.set(client.getSetAuthorizationUser().get()); + originalRoles.addAll(client.getSetOriginalRoles()); roles.clear(); } if (client.isResetAuthorizationUser()) { authorizationUser.set(null); + originalRoles.clear(); roles.clear(); } diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/BaseTrinoDriverTest.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/BaseTrinoDriverTest.java index 86966be9d5d5..871521ee4a71 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/BaseTrinoDriverTest.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/BaseTrinoDriverTest.java @@ -1241,7 +1241,7 @@ private static Properties toProperties(Map map) return properties; } - private static String getCurrentUser(Connection connection) + public static String getCurrentUser(Connection connection) throws SQLException { try (Statement statement = connection.createStatement(); diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestAsyncResultIterator.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestAsyncResultIterator.java index 3d20a40dc227..9825e221dee1 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestAsyncResultIterator.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestAsyncResultIterator.java @@ -215,6 +215,12 @@ public boolean isResetAuthorizationUser() throw new UnsupportedOperationException(); } + @Override + public Set getSetOriginalRoles() + { + throw new UnsupportedOperationException(); + } + @Override public Map getSetSessionProperties() { diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcConnection.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcConnection.java index 75ef5a13839a..28f370fcb5a9 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcConnection.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcConnection.java @@ -1005,6 +1005,12 @@ public boolean isResetAuthorizationUser() return true; } + @Override + public Set getSetOriginalRoles() + { + return Set.of(); + } + @Override public Map getSetSessionProperties() { diff --git a/core/trino-main/src/main/java/io/trino/Session.java b/core/trino-main/src/main/java/io/trino/Session.java index 7841c53cb5b6..860ff2779318 100644 --- a/core/trino-main/src/main/java/io/trino/Session.java +++ b/core/trino-main/src/main/java/io/trino/Session.java @@ -547,7 +547,8 @@ public SessionRepresentation toSessionRepresentation() clientTransactionSupport, identity.getUser(), originalIdentity.getUser(), - identity.getGroups(), + originalIdentity.getEnabledRoles(), + originalIdentity.getGroups(), originalIdentity.getGroups(), identity.getPrincipal().map(Principal::toString), identity.getEnabledRoles(), diff --git a/core/trino-main/src/main/java/io/trino/SessionRepresentation.java b/core/trino-main/src/main/java/io/trino/SessionRepresentation.java index caf538014874..793d76c667e3 100644 --- a/core/trino-main/src/main/java/io/trino/SessionRepresentation.java +++ b/core/trino-main/src/main/java/io/trino/SessionRepresentation.java @@ -48,6 +48,7 @@ public final class SessionRepresentation private final boolean clientTransactionSupport; private final String user; private final String originalUser; + private final Set originalRoles; private final Set groups; private final Set originalUserGroups; private final Optional principal; @@ -81,6 +82,7 @@ public SessionRepresentation( @JsonProperty("clientTransactionSupport") boolean clientTransactionSupport, @JsonProperty("user") String user, @JsonProperty("originalUser") String originalUser, + @JsonProperty("setOriginalRoles") Set originalRoles, @JsonProperty("groups") Set groups, @JsonProperty("originalUserGroups") Set originalUserGroups, @JsonProperty("principal") Optional principal, @@ -112,6 +114,7 @@ public SessionRepresentation( this.clientTransactionSupport = clientTransactionSupport; this.user = requireNonNull(user, "user is null"); this.originalUser = requireNonNull(originalUser, "originalUser is null"); + this.originalRoles = requireNonNull(originalRoles, "setOriginalRoles is null"); this.groups = requireNonNull(groups, "groups is null"); this.originalUserGroups = requireNonNull(originalUserGroups, "originalUserGroups is null"); this.principal = requireNonNull(principal, "principal is null"); @@ -179,6 +182,12 @@ public String getOriginalUser() return originalUser; } + @JsonProperty + public Set getOriginalRoles() + { + return originalRoles; + } + @JsonProperty public Set getGroups() { @@ -350,6 +359,7 @@ public Identity toOriginalIdentity(Map extraCredentials) return Identity.forUser(originalUser) .withGroups(originalUserGroups) .withPrincipal(principal.map(BasicPrincipal::new)) + .withEnabledRoles(originalRoles) .withExtraCredentials(extraCredentials) .build(); } diff --git a/core/trino-main/src/main/java/io/trino/dispatcher/FailedDispatchQuery.java b/core/trino-main/src/main/java/io/trino/dispatcher/FailedDispatchQuery.java index 1e1f00a645cb..255d7771cdc6 100644 --- a/core/trino-main/src/main/java/io/trino/dispatcher/FailedDispatchQuery.java +++ b/core/trino-main/src/main/java/io/trino/dispatcher/FailedDispatchQuery.java @@ -234,6 +234,7 @@ private static QueryInfo immediateFailureQueryInfo( Optional.empty(), Optional.empty(), false, + ImmutableSet.of(), ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of(), diff --git a/core/trino-main/src/main/java/io/trino/event/QueryMonitor.java b/core/trino-main/src/main/java/io/trino/event/QueryMonitor.java index 9e0bc49c53bc..50c2cfa77c0e 100644 --- a/core/trino-main/src/main/java/io/trino/event/QueryMonitor.java +++ b/core/trino-main/src/main/java/io/trino/event/QueryMonitor.java @@ -365,6 +365,7 @@ private QueryContext createQueryContext(SessionRepresentation session, Optional< return new QueryContext( session.getUser(), session.getOriginalUser(), + session.getOriginalRoles(), session.getPrincipal(), session.getEnabledRoles(), session.getGroups(), diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryInfo.java b/core/trino-main/src/main/java/io/trino/execution/QueryInfo.java index 18a4062fdbd6..9f7f54abde93 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryInfo.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryInfo.java @@ -62,6 +62,7 @@ public class QueryInfo private final Optional setPath; private final Optional setAuthorizationUser; private final boolean resetAuthorizationUser; + private final Set setOriginalRoles; private final Map setSessionProperties; private final Set resetSessionProperties; private final Map setRoles; @@ -101,6 +102,7 @@ public QueryInfo( @JsonProperty("setPath") Optional setPath, @JsonProperty("setAuthorizationUser") Optional setAuthorizationUser, @JsonProperty("resetAuthorizationUser") boolean resetAuthorizationUser, + @JsonProperty("setOriginalRoles") Set setOriginalRoles, @JsonProperty("setSessionProperties") Map setSessionProperties, @JsonProperty("resetSessionProperties") Set resetSessionProperties, @JsonProperty("setRoles") Map setRoles, @@ -134,6 +136,7 @@ public QueryInfo( requireNonNull(setSchema, "setSchema is null"); requireNonNull(setPath, "setPath is null"); requireNonNull(setAuthorizationUser, "setAuthorizationUser is null"); + requireNonNull(setOriginalRoles, "setOriginalRoles is null"); requireNonNull(setSessionProperties, "setSessionProperties is null"); requireNonNull(resetSessionProperties, "resetSessionProperties is null"); requireNonNull(addedPreparedStatements, "addedPreparedStatements is null"); @@ -165,6 +168,7 @@ public QueryInfo( this.setPath = setPath; this.setAuthorizationUser = setAuthorizationUser; this.resetAuthorizationUser = resetAuthorizationUser; + this.setOriginalRoles = setOriginalRoles; this.setSessionProperties = ImmutableMap.copyOf(setSessionProperties); this.resetSessionProperties = ImmutableSet.copyOf(resetSessionProperties); this.setRoles = ImmutableMap.copyOf(setRoles); @@ -287,6 +291,12 @@ public boolean isResetAuthorizationUser() return resetAuthorizationUser; } + @JsonProperty + public Set getSetOriginalRoles() + { + return setOriginalRoles; + } + @JsonProperty public Map getSetSessionProperties() { @@ -455,6 +465,7 @@ public QueryInfo pruneDigests() setPath, setAuthorizationUser, resetAuthorizationUser, + setOriginalRoles, setSessionProperties, resetSessionProperties, setRoles, diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java index 0e23c80b5ca3..2bea048b787f 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java @@ -156,6 +156,7 @@ public class QueryStateMachine private final AtomicReference setAuthorizationUser = new AtomicReference<>(); private final AtomicBoolean resetAuthorizationUser = new AtomicBoolean(); + private final Set setOriginalRoles = Sets.newConcurrentHashSet(); private final Map setSessionProperties = new ConcurrentHashMap<>(); private final Set resetSessionProperties = Sets.newConcurrentHashSet(); @@ -525,6 +526,7 @@ ResultQueryInfo getResultQueryInfo(Optional stageInfo) Optional.ofNullable(setPath.get()), Optional.ofNullable(setAuthorizationUser.get()), resetAuthorizationUser.get(), + setOriginalRoles, setSessionProperties, resetSessionProperties, setRoles, @@ -618,6 +620,7 @@ QueryInfo getQueryInfo(Optional rootStage) Optional.ofNullable(setPath.get()), Optional.ofNullable(setAuthorizationUser.get()), resetAuthorizationUser.get(), + setOriginalRoles, setSessionProperties, resetSessionProperties, setRoles, @@ -1024,6 +1027,11 @@ public void resetAuthorizationUser() resetAuthorizationUser.set(true); } + public void addSetOriginalRoles(SelectedRole role) + { + setOriginalRoles.add(requireNonNull(role, "role is null")); + } + public void addSetSessionProperties(String key, String value) { setSessionProperties.put(requireNonNull(key, "key is null"), requireNonNull(value, "value is null")); @@ -1417,6 +1425,7 @@ public static QueryInfo pruneQueryInfo(QueryInfo queryInfo, NodeVersion version) queryInfo.getSetPath(), queryInfo.getSetAuthorizationUser(), queryInfo.isResetAuthorizationUser(), + queryInfo.getSetOriginalRoles(), queryInfo.getSetSessionProperties(), queryInfo.getResetSessionProperties(), queryInfo.getSetRoles(), diff --git a/core/trino-main/src/main/java/io/trino/execution/SetSessionAuthorizationTask.java b/core/trino-main/src/main/java/io/trino/execution/SetSessionAuthorizationTask.java index b351549a23ee..20857add7c91 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SetSessionAuthorizationTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/SetSessionAuthorizationTask.java @@ -21,6 +21,7 @@ import io.trino.security.AccessControl; import io.trino.spi.TrinoException; import io.trino.spi.security.Identity; +import io.trino.spi.security.SelectedRole; import io.trino.sql.tree.Expression; import io.trino.sql.tree.Identifier; import io.trino.sql.tree.SetSessionAuthorization; @@ -28,6 +29,8 @@ import io.trino.transaction.TransactionManager; import java.util.List; +import java.util.Optional; +import java.util.Set; import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.Futures.immediateFuture; @@ -91,6 +94,18 @@ else if (userExpression instanceof StringLiteral stringLiteral) { accessControl.checkCanImpersonateUser(originalIdentity, user); } stateMachine.setSetAuthorizationUser(user); + SelectedRole selectedRole; + Set enabledRoles = originalIdentity.getEnabledRoles(); + if (enabledRoles.isEmpty()) { + selectedRole = new SelectedRole(SelectedRole.Type.NONE, Optional.empty()); + } + else if (enabledRoles.size() == 1) { + selectedRole = new SelectedRole(SelectedRole.Type.ROLE, Optional.of(enabledRoles.iterator().next())); + } + else { + selectedRole = new SelectedRole(SelectedRole.Type.ALL, Optional.empty()); + } + stateMachine.addSetOriginalRoles(selectedRole); return immediateFuture(null); } } diff --git a/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java b/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java index 603f90303598..bbe0855f5d9c 100644 --- a/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java +++ b/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java @@ -254,18 +254,14 @@ private Identity buildSessionIdentity(Optional authenticatedIdentity, String user = trinoUser != null ? trinoUser : authenticatedIdentity.map(Identity::getUser).orElse(null); assertRequest(user != null, "User must be set"); SelectedRole systemRole = parseSystemRoleHeaders(protocolHeaders, headers); - ImmutableSet.Builder systemEnabledRoles = ImmutableSet.builder(); - if (systemRole.getType() == Type.ROLE) { - systemEnabledRoles.add(systemRole.getRole().orElseThrow()); - } - return authenticatedIdentity + Identity newIdentity = authenticatedIdentity .map(identity -> Identity.from(identity).withUser(user)) .orElseGet(() -> Identity.forUser(user)) - .withEnabledRoles(systemEnabledRoles.build()) .withAdditionalConnectorRoles(parseConnectorRoleHeaders(protocolHeaders, headers)) .withAdditionalExtraCredentials(parseExtraCredentials(protocolHeaders, headers)) .withAdditionalGroups(groupProvider.getGroups(user)) .build(); + return addEnabledRoles(newIdentity, systemRole, metadata); } private Identity buildSessionOriginalIdentity(Identity identity, ProtocolHeaders protocolHeaders, MultivaluedMap headers) @@ -273,12 +269,18 @@ private Identity buildSessionOriginalIdentity(Identity identity, ProtocolHeaders // We derive original identity using this header, but older clients will not send it, so fall back to identity Optional optionalOriginalUser = Optional .ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestOriginalUser()))); - Identity originalIdentity = optionalOriginalUser.map(originalUser -> Identity.from(identity) - .withUser(originalUser) - .withExtraCredentials(new HashMap<>()) - .withGroups(groupProvider.getGroups(originalUser)) - .build()) - .orElse(identity); + Optional originalRoles = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestOriginalRole()))); + Identity originalIdentity = optionalOriginalUser.map(originalUser -> { + Identity newIdentity = Identity.from(identity) + .withUser(originalUser) + .withExtraCredentials(new HashMap<>()) + .withGroups(groupProvider.getGroups(originalUser)) + .build(); + if (originalRoles.isPresent()) { + newIdentity = addEnabledRoles(newIdentity, SelectedRole.valueOf(originalRoles.get()), metadata); + } + return newIdentity; + }).orElse(identity); return originalIdentity; } diff --git a/core/trino-main/src/main/java/io/trino/server/ResultQueryInfo.java b/core/trino-main/src/main/java/io/trino/server/ResultQueryInfo.java index 3cdb9dfa6184..059a43a3020c 100644 --- a/core/trino-main/src/main/java/io/trino/server/ResultQueryInfo.java +++ b/core/trino-main/src/main/java/io/trino/server/ResultQueryInfo.java @@ -63,6 +63,8 @@ public record ResultQueryInfo( @JsonProperty boolean resetAuthorizationUser, @JsonProperty + Set setOriginalRoles, + @JsonProperty Map setSessionProperties, @JsonProperty Set resetSessionProperties, @@ -95,6 +97,7 @@ public ResultQueryInfo( @JsonProperty("setPath") Optional setPath, @JsonProperty("setAuthorizationUser") Optional setAuthorizationUser, @JsonProperty("resetAuthorizationUser") boolean resetAuthorizationUser, + @JsonProperty("setOriginalRoles") Set setOriginalRoles, @JsonProperty("setSessionProperties") Map setSessionProperties, @JsonProperty("resetSessionProperties") Set resetSessionProperties, @JsonProperty("setRoles") Map setRoles, @@ -118,6 +121,7 @@ public ResultQueryInfo( this.setPath = requireNonNull(setPath, "setPath is null"); this.setAuthorizationUser = requireNonNull(setAuthorizationUser, "setAuthorizationUser is null"); this.resetAuthorizationUser = resetAuthorizationUser; + this.setOriginalRoles = requireNonNull(setOriginalRoles, "setOriginalRoles is null"); this.setSessionProperties = requireNonNull(setSessionProperties, "setSessionProperties is null"); this.resetSessionProperties = requireNonNull(resetSessionProperties, "resetSessionProperties is null"); this.addedPreparedStatements = requireNonNull(addedPreparedStatements, "addedPreparedStatements is null"); @@ -144,6 +148,7 @@ public ResultQueryInfo(QueryInfo queryInfo) queryInfo.getSetPath(), queryInfo.getSetAuthorizationUser(), queryInfo.isResetAuthorizationUser(), + queryInfo.getSetOriginalRoles(), queryInfo.getSetSessionProperties(), queryInfo.getResetSessionProperties(), queryInfo.getSetRoles(), diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java b/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java index 936751525f4a..e956038c31d9 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java @@ -224,6 +224,10 @@ private Response toResponse(QueryResultsResponse resultsResponse, Optional response.header(protocolHeaders.responseOriginalRole(), name)); + // add set session properties resultsResponse.setSessionProperties() .forEach((key, value) -> response.header(protocolHeaders.responseSetSession(), key + '=' + urlEncode(value))); diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/Query.java b/core/trino-main/src/main/java/io/trino/server/protocol/Query.java index c09d52ff98ae..0b0dbcfc8bd1 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/Query.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/Query.java @@ -160,6 +160,9 @@ class Query @GuardedBy("this") private Optional setAuthorizationUser = Optional.empty(); + @GuardedBy("this") + private Set setOriginalRoles = ImmutableSet.of(); + @GuardedBy("this") private boolean resetAuthorizationUser; @@ -505,6 +508,9 @@ private synchronized QueryResultsResponse getNextResult(long token, ExternalUriI setAuthorizationUser = queryInfo.setAuthorizationUser(); resetAuthorizationUser = queryInfo.resetAuthorizationUser(); + // update setOriginalRoles + setOriginalRoles = queryInfo.setOriginalRoles(); + // update setSessionProperties setSessionProperties = queryInfo.setSessionProperties(); resetSessionProperties = queryInfo.resetSessionProperties(); @@ -549,6 +555,7 @@ private synchronized QueryResultsResponse toResultsResponse(QueryResults queryRe setPath, setAuthorizationUser, resetAuthorizationUser, + setOriginalRoles, setSessionProperties, resetSessionProperties, setRoles, diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultsResponse.java b/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultsResponse.java index 4387fd9f405b..f1a74b09a4d2 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultsResponse.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultsResponse.java @@ -30,6 +30,7 @@ record QueryResultsResponse( Optional setPath, Optional setAuthorizationUser, boolean resetAuthorizationUser, + Set setOriginalRoles, Map setSessionProperties, Set resetSessionProperties, Map setRoles, @@ -45,6 +46,7 @@ record QueryResultsResponse( requireNonNull(setSchema, "setSchema is null"); requireNonNull(setPath, "setPath is null"); requireNonNull(setAuthorizationUser, "setAuthorizationUser is null"); + requireNonNull(setOriginalRoles, "setOriginalRoles is null"); requireNonNull(setSessionProperties, "setSessionProperties is null"); requireNonNull(resetSessionProperties, "resetSessionProperties is null"); requireNonNull(setRoles, "setRoles is null"); diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java b/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java index 72697b6b8ff8..85775ebf8513 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java @@ -144,6 +144,7 @@ public class TestingAccessControlManager private Predicate deniedTables = s -> true; private BiPredicate denyIdentityTable = IDENTITY_TABLE_TRUE; private BiPredicate denyIdentityFunction = IDENTITY_FUNCTION_TRUE; + private BiPredicate denyImpersonationFunction = IDENTITY_FUNCTION_TRUE; @Inject public TestingAccessControlManager( @@ -196,6 +197,7 @@ public void reset() denyIdentityTable = IDENTITY_TABLE_TRUE; rowFilters.clear(); columnMasks.clear(); + denyImpersonationFunction = IDENTITY_FUNCTION_TRUE; } public void denyCatalogs(Predicate deniedCatalogs) @@ -223,6 +225,11 @@ public void denyIdentityFunction(BiPredicate denyIdentityFunct this.denyIdentityFunction = requireNonNull(denyIdentityFunction, "denyIdentityFunction is null"); } + public void denyImpersonation(BiPredicate denyImpersonationFunction) + { + this.denyImpersonationFunction = requireNonNull(denyImpersonationFunction, "denyImpersonationFunction is null"); + } + @Override public Set filterCatalogs(SecurityContext securityContext, Set catalogs) { @@ -258,6 +265,9 @@ public Set filterTables(SecurityContext context, String catalog @Override public void checkCanImpersonateUser(Identity identity, String userName) { + if (!denyImpersonationFunction.test(identity, userName)) { + denyImpersonateUser(identity.getUser(), userName); + } if (shouldDenyPrivilege(userName, userName, IMPERSONATE_USER)) { denyImpersonateUser(identity.getUser(), userName); } diff --git a/core/trino-main/src/test/java/io/trino/execution/MockManagedQueryExecution.java b/core/trino-main/src/test/java/io/trino/execution/MockManagedQueryExecution.java index 4029f63085bc..30f342a0dcd8 100644 --- a/core/trino-main/src/test/java/io/trino/execution/MockManagedQueryExecution.java +++ b/core/trino-main/src/test/java/io/trino/execution/MockManagedQueryExecution.java @@ -268,6 +268,7 @@ public QueryInfo getFullQueryInfo() Optional.empty(), Optional.empty(), false, + ImmutableSet.of(), ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of(), diff --git a/core/trino-main/src/test/java/io/trino/execution/TestQueryInfo.java b/core/trino-main/src/test/java/io/trino/execution/TestQueryInfo.java index eb53c6e26a72..6e879db46d6d 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestQueryInfo.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestQueryInfo.java @@ -225,6 +225,7 @@ private static QueryInfo createQueryInfo(Optional stageInfo) Optional.of("set_path"), Optional.of("set_authorization_user"), false, + ImmutableSet.of(new SelectedRole(SelectedRole.Type.ROLE, Optional.of("original_role"))), ImmutableMap.of("set_property", "set_value"), ImmutableSet.of("reset_property"), ImmutableMap.of("set_roles", new SelectedRole(SelectedRole.Type.ROLE, Optional.of("role"))), diff --git a/core/trino-main/src/test/java/io/trino/metadata/TestMetadataManager.java b/core/trino-main/src/test/java/io/trino/metadata/TestMetadataManager.java index 82daab9362dd..44eb9a01a5f2 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/TestMetadataManager.java +++ b/core/trino-main/src/test/java/io/trino/metadata/TestMetadataManager.java @@ -18,6 +18,7 @@ import io.trino.FeaturesConfig; import io.trino.security.AllowAllAccessControl; import io.trino.spi.block.BlockEncodingSerde; +import io.trino.spi.security.Identity; import io.trino.spi.type.TypeManager; import io.trino.spi.type.TypeOperators; import io.trino.sql.parser.SqlParser; @@ -25,6 +26,8 @@ import io.trino.transaction.TransactionManager; import io.trino.type.BlockTypeOperators; +import java.util.Set; + import static io.trino.client.NodeVersion.UNKNOWN; import static io.trino.transaction.InMemoryTransactionManager.createTestTransactionManager; import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; @@ -104,7 +107,7 @@ public MetadataManager build() return new MetadataManager( new AllowAllAccessControl(), - new DisabledSystemSecurityMetadata(), + new SecurityMetadata(), transactionManager, globalFunctionCatalog, languageFunctionManager, @@ -113,4 +116,14 @@ public MetadataManager build() new NotImplementedQueryManager()); } } + + private static class SecurityMetadata + extends DisabledSystemSecurityMetadata + { + @Override + public Set listEnabledRoles(Identity identity) + { + return ImmutableSet.of("system-role"); + } + } } diff --git a/core/trino-main/src/test/java/io/trino/server/TestBasicQueryInfo.java b/core/trino-main/src/test/java/io/trino/server/TestBasicQueryInfo.java index f543cdaac64b..063200e96818 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestBasicQueryInfo.java +++ b/core/trino-main/src/test/java/io/trino/server/TestBasicQueryInfo.java @@ -144,6 +144,7 @@ public void testConstructor() Optional.empty(), Optional.empty(), false, + ImmutableSet.of(), ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of(), diff --git a/core/trino-main/src/test/java/io/trino/server/TestHttpRequestSessionContextFactory.java b/core/trino-main/src/test/java/io/trino/server/TestHttpRequestSessionContextFactory.java index ec7b01e6b02e..fddcaa755db0 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestHttpRequestSessionContextFactory.java +++ b/core/trino-main/src/test/java/io/trino/server/TestHttpRequestSessionContextFactory.java @@ -122,19 +122,31 @@ private static void assertMappedUser(ProtocolHeaders protocolHeaders) userHeaders, Optional.of("testRemote"), Optional.empty()); - assertThat(context.getIdentity()).isEqualTo(Identity.forUser("testUser").withGroups(ImmutableSet.of("testUser")).build()); + assertThat(context.getIdentity()) + .isEqualTo(Identity.forUser("testUser") + .withGroups(ImmutableSet.of("testUser")) + .withEnabledRoles(ImmutableSet.of("system-role")) + .build()); context = sessionContextFactory(protocolHeaders).createSessionContext( emptyHeaders, Optional.of("testRemote"), Optional.of(Identity.forUser("mappedUser").withGroups(ImmutableSet.of("test")).build())); - assertThat(context.getIdentity()).isEqualTo(Identity.forUser("mappedUser").withGroups(ImmutableSet.of("test", "mappedUser")).build()); + assertThat(context.getIdentity()) + .isEqualTo(Identity.forUser("mappedUser") + .withGroups(ImmutableSet.of("test", "mappedUser")) + .withEnabledRoles(ImmutableSet.of("system-role")) + .build()); context = sessionContextFactory(protocolHeaders).createSessionContext( userHeaders, Optional.of("testRemote"), Optional.of(Identity.ofUser("mappedUser"))); - assertThat(context.getIdentity()).isEqualTo(Identity.forUser("testUser").withGroups(ImmutableSet.of("testUser")).build()); + assertThat(context.getIdentity()) + .isEqualTo(Identity.forUser("testUser") + .withGroups(ImmutableSet.of("testUser")) + .withEnabledRoles(ImmutableSet.of("system-role")) + .build()); assertInvalidSession(protocolHeaders, emptyHeaders) .matches(e -> ((WebApplicationException) e).getResponse().getStatus() == 400); diff --git a/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfo.java b/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfo.java index 9e7ec6e246d7..e218e762314d 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfo.java +++ b/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfo.java @@ -189,6 +189,7 @@ private QueryInfo createQueryInfo(String queryId, QueryState state, String query Optional.empty(), Optional.empty(), false, + ImmutableSet.of(), ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of(), diff --git a/core/trino-spi/src/main/java/io/trino/spi/eventlistener/QueryContext.java b/core/trino-spi/src/main/java/io/trino/spi/eventlistener/QueryContext.java index a14094f5a9ce..47f917f51c5f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/eventlistener/QueryContext.java +++ b/core/trino-spi/src/main/java/io/trino/spi/eventlistener/QueryContext.java @@ -33,6 +33,7 @@ public class QueryContext { private final String user; private final String originalUser; + private final Set originalRoles; private final Optional principal; private final Set enabledRoles; private final Set groups; @@ -66,6 +67,7 @@ public class QueryContext public QueryContext( String user, String originalUser, + Set originalRoles, Optional principal, Set enabledRoles, Set groups, @@ -90,6 +92,7 @@ public QueryContext( { this.user = requireNonNull(user, "user is null"); this.originalUser = requireNonNull(originalUser, "originalUser is null"); + this.originalRoles = requireNonNull(originalRoles, "originalRoles is null"); this.principal = requireNonNull(principal, "principal is null"); this.enabledRoles = requireNonNull(enabledRoles, "enabledRoles is null"); this.groups = requireNonNull(groups, "groups is null"); @@ -125,6 +128,12 @@ public String getOriginalUser() return originalUser; } + @JsonProperty + public Set getOriginalRoles() + { + return originalRoles; + } + @JsonProperty public Optional getPrincipal() { diff --git a/docs/src/main/sphinx/develop/client-protocol.md b/docs/src/main/sphinx/develop/client-protocol.md index b7d353ef7d9f..d33fdf8e4b68 100644 --- a/docs/src/main/sphinx/develop/client-protocol.md +++ b/docs/src/main/sphinx/develop/client-protocol.md @@ -228,6 +228,9 @@ subsequent requests to be consistent with the response headers received. - Instructs the client to reset `X-Trino-User` request header to its original value in subsequent client requests and remove `X-Trino-Original-User` to reset the authorization user back to the original user. +* - `X-Trino-Set-Original-Roles` + - Instructs the client to set the roles of the original user in the + `X-Trino-Original-Roles` request header in subsequent client requests. * - `X-Trino-Set-Session` - The value of the `X-Trino-Set-Session` response header is a string of the form *property* = *value*. It instructs the client include session property diff --git a/plugin/trino-http-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpEventListener.java b/plugin/trino-http-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpEventListener.java index 8d5fe476e40f..b55952febe27 100644 --- a/plugin/trino-http-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpEventListener.java +++ b/plugin/trino-http-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpEventListener.java @@ -106,6 +106,7 @@ final class TestHttpEventListener queryContext = new QueryContext( "user", "originalUser", + Set.of(), Optional.of("principal"), Set.of(), // enabledRoles Set.of(), // groups diff --git a/plugin/trino-http-server-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpServerEventListener.java b/plugin/trino-http-server-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpServerEventListener.java index 8f2defd9c343..3a6668bbbdc4 100644 --- a/plugin/trino-http-server-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpServerEventListener.java +++ b/plugin/trino-http-server-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpServerEventListener.java @@ -72,6 +72,7 @@ final class TestHttpServerEventListener QueryContext queryContext = new QueryContext( "user", "originalUser", + Set.of(), Optional.of("principal"), Set.of(), // enabledRoles Set.of(), // groups diff --git a/plugin/trino-kafka-event-listener/src/test/java/io/trino/plugin/eventlistener/kafka/TestUtils.java b/plugin/trino-kafka-event-listener/src/test/java/io/trino/plugin/eventlistener/kafka/TestUtils.java index e9bfa85475f4..e5b5ef4d715d 100644 --- a/plugin/trino-kafka-event-listener/src/test/java/io/trino/plugin/eventlistener/kafka/TestUtils.java +++ b/plugin/trino-kafka-event-listener/src/test/java/io/trino/plugin/eventlistener/kafka/TestUtils.java @@ -68,6 +68,7 @@ private TestUtils() {} queryContext = new QueryContext( "user", "originalUser", + Set.of(), Optional.of("principal"), Set.of(), Set.of(), // groups diff --git a/plugin/trino-mysql-event-listener/src/test/java/io/trino/plugin/eventlistener/mysql/TestMysqlEventListener.java b/plugin/trino-mysql-event-listener/src/test/java/io/trino/plugin/eventlistener/mysql/TestMysqlEventListener.java index e5a06e514d25..54a1e5c53319 100644 --- a/plugin/trino-mysql-event-listener/src/test/java/io/trino/plugin/eventlistener/mysql/TestMysqlEventListener.java +++ b/plugin/trino-mysql-event-listener/src/test/java/io/trino/plugin/eventlistener/mysql/TestMysqlEventListener.java @@ -144,6 +144,7 @@ final class TestMysqlEventListener private static final QueryContext FULL_QUERY_CONTEXT = new QueryContext( "user", "originalUser", + Set.of("role1"), Optional.of("principal"), Set.of("role1", "role2"), Set.of("group1", "group2"), @@ -305,6 +306,7 @@ final class TestMysqlEventListener private static final QueryContext MINIMAL_QUERY_CONTEXT = new QueryContext( "user", "originalUser", + Set.of(), Optional.empty(), Set.of(), Set.of(), diff --git a/plugin/trino-openlineage/src/test/java/io/trino/plugin/openlineage/TrinoEventData.java b/plugin/trino-openlineage/src/test/java/io/trino/plugin/openlineage/TrinoEventData.java index 5a6adad0f390..bfcf0bfa432e 100644 --- a/plugin/trino-openlineage/src/test/java/io/trino/plugin/openlineage/TrinoEventData.java +++ b/plugin/trino-openlineage/src/test/java/io/trino/plugin/openlineage/TrinoEventData.java @@ -58,6 +58,7 @@ private TrinoEventData() queryContext = new QueryContext( "user", "originalUser", + Set.of(), Optional.of("principal"), Set.of(), // enabledRoles Set.of(), // groups diff --git a/pom.xml b/pom.xml index 1b1151d37bd2..170d8c8ba286 100644 --- a/pom.xml +++ b/pom.xml @@ -1171,6 +1171,13 @@ ${project.version} + + io.trino + trino-jdbc + ${project.version} + test-jar + + io.trino trino-jmx diff --git a/testing/trino-tests/pom.xml b/testing/trino-tests/pom.xml index 52c3510ba84a..53c097a6d35f 100644 --- a/testing/trino-tests/pom.xml +++ b/testing/trino-tests/pom.xml @@ -201,6 +201,19 @@ test + + io.trino + trino-jdbc + test + + + + io.trino + trino-jdbc + test-jar + test + + io.trino trino-main diff --git a/testing/trino-tests/src/test/java/io/trino/security/TestImpersonation.java b/testing/trino-tests/src/test/java/io/trino/security/TestImpersonation.java new file mode 100644 index 000000000000..3366eed50ce8 --- /dev/null +++ b/testing/trino-tests/src/test/java/io/trino/security/TestImpersonation.java @@ -0,0 +1,160 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.security; + +import com.google.common.collect.ImmutableSet; +import io.airlift.log.Logging; +import io.trino.jdbc.TrinoConnection; +import io.trino.metadata.SystemSecurityMetadata; +import io.trino.plugin.memory.MemoryPlugin; +import io.trino.server.testing.TestingTrinoServer; +import io.trino.spi.security.TrinoPrincipal; +import io.trino.testing.TestingAccessControlManager; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.parallel.Execution; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Optional; +import java.util.stream.Stream; + +import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; +import static io.trino.jdbc.BaseTrinoDriverTest.getCurrentUser; +import static io.trino.spi.security.PrincipalType.USER; +import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; + +@TestInstance(PER_CLASS) +@Execution(SAME_THREAD) +public class TestImpersonation +{ + private TestingTrinoServer server; + private final TestingSystemSecurityMetadata securityMetadata = new TestingSystemSecurityMetadata(); + private TestingAccessControlManager accessControl; + + @BeforeAll + public void setup() + throws Exception + { + Logging.initialize(); + server = TestingTrinoServer.builder() + .setAdditionalModule(binder -> { + newOptionalBinder(binder, SystemSecurityMetadata.class) + .setBinding() + .toInstance(securityMetadata); + }).build(); + server.installPlugin(new MemoryPlugin()); + server.createCatalog("memory", "memory"); + accessControl = server.getAccessControl(); + } + + @ParameterizedTest + @MethodSource("roles") + @Timeout(10) + public void testImpersonationAllowedByRole(String roleName) + throws Exception + { + securityMetadata.reset(); + accessControl.reset(); + try (TrinoConnection connection = createConnection("memory", "default", "alice").unwrap(TrinoConnection.class); + Statement statement = connection.createStatement()) { + assertThat(getCurrentUser(connection)).isEqualTo("alice"); + securityMetadata.createRole(null, "invalid_role", Optional.empty()); + securityMetadata.grantRoles( + null, + ImmutableSet.of("invalid_role"), + ImmutableSet.of(new TrinoPrincipal(USER, "alice")), + false, + Optional.empty()); + denyImpersonation(); + statement.execute("SET ROLE invalid_role"); + assertThatThrownBy(() -> statement.execute("SET SESSION AUTHORIZATION john")) + .hasMessageContaining("User alice cannot impersonate user john"); + + securityMetadata.createRole(null, "alice_role", Optional.empty()); + securityMetadata.grantRoles( + null, + ImmutableSet.of("alice_role"), + ImmutableSet.of(new TrinoPrincipal(USER, "alice")), + false, + Optional.empty()); + assertThatThrownBy(() -> statement.execute("SET SESSION AUTHORIZATION john")) + .hasMessageContaining("User alice cannot impersonate user john"); + + statement.execute("SET ROLE " + roleName); + statement.execute("SET SESSION AUTHORIZATION john"); + + // This would fail if roles were not correctly propagated + statement.execute("SHOW SCHEMAS IN memory"); + // Call more than once to make sure everything is propagated correctly + // to subsequent calls + statement.execute("SHOW SCHEMAS IN memory"); + } + } + + @Test + @Timeout(10) + public void testImpersonationDisallowedWhenRoleIsNone() + throws Exception + { + securityMetadata.reset(); + accessControl.reset(); + try (TrinoConnection connection = createConnection("memory", "default", "alice").unwrap(TrinoConnection.class); + Statement statement = connection.createStatement()) { + assertThat(getCurrentUser(connection)).isEqualTo("alice"); + securityMetadata.createRole(null, "alice_role", Optional.empty()); + denyImpersonation(); + securityMetadata.grantRoles( + null, + ImmutableSet.of("alice_role"), + ImmutableSet.of(new TrinoPrincipal(USER, "alice")), + false, + Optional.empty()); + statement.execute("SET ROLE NONE"); + + assertThatThrownBy(() -> statement.execute("SET SESSION AUTHORIZATION john")) + .hasMessageContaining("User alice cannot impersonate user john"); + } + } + + private Connection createConnection(String catalog, String schema, String user) + throws SQLException + { + String url = format("jdbc:trino://%s/%s/%s", server.getAddress(), catalog, schema); + return DriverManager.getConnection(url, user, null); + } + + private Stream roles() + { + return Stream.of("alice_role", "ALL"); + } + + private void denyImpersonation() + { + accessControl.denyImpersonation((identity, _) -> + identity.getEnabledRoles() + .stream() + .anyMatch(role -> role.equalsIgnoreCase("alice_role"))); + } +} diff --git a/testing/trino-tests/src/test/java/io/trino/security/TestSystemSecurityMetadata.java b/testing/trino-tests/src/test/java/io/trino/security/TestSystemSecurityMetadata.java index 4dadcbf64af0..249ee4ffccf3 100644 --- a/testing/trino-tests/src/test/java/io/trino/security/TestSystemSecurityMetadata.java +++ b/testing/trino-tests/src/test/java/io/trino/security/TestSystemSecurityMetadata.java @@ -25,6 +25,7 @@ import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; import static io.trino.testing.TestingSession.testSessionBuilder; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; @Execution(SAME_THREAD) // TestingSystemSecurityMetadata is shared mutable state @@ -96,10 +97,14 @@ public void testRoleGrant() assertQueryReturnsEmptyResult(alice, "SHOW CURRENT ROLES"); assertQueryReturnsEmptyResult(alice, "SHOW ROLE GRANTS"); assertQueryReturnsEmptyResult(alice, "SELECT * FROM system.information_schema.applicable_roles"); - assertQueryFails(aliceWithRole, "SHOW ROLES", "Access Denied: Cannot set role role1"); - assertQueryFails(aliceWithRole, "SHOW CURRENT ROLES", "Access Denied: Cannot set role role1"); - assertQueryFails(aliceWithRole, "SHOW ROLE GRANTS", "Access Denied: Cannot set role role1"); - assertQueryFails(aliceWithRole, "SELECT * FROM system.information_schema.applicable_roles", "Access Denied: Cannot set role role1"); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SHOW ROLES")) + .hasMessageContaining("Access Denied: Cannot set role role1"); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SHOW CURRENT ROLES")) + .hasMessageContaining("Access Denied: Cannot set role role1"); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SHOW ROLE GRANTS")) + .hasMessageContaining("Access Denied: Cannot set role role1"); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SELECT * FROM system.information_schema.applicable_roles")) + .hasMessageContaining("Access Denied: Cannot set role role1"); assertQuerySucceeds("GRANT role1 TO USER alice"); assertQuerySucceeds(alice, "SET ROLE role1"); @@ -117,10 +122,14 @@ public void testRoleGrant() assertQueryReturnsEmptyResult(alice, "SHOW CURRENT ROLES"); assertQueryReturnsEmptyResult(alice, "SHOW ROLE GRANTS"); assertQueryReturnsEmptyResult(alice, "SELECT * FROM system.information_schema.applicable_roles"); - assertQueryFails(aliceWithRole, "SHOW ROLES", "Access Denied: Cannot set role role1"); - assertQueryFails(aliceWithRole, "SHOW CURRENT ROLES", "Access Denied: Cannot set role role1"); - assertQueryFails(aliceWithRole, "SHOW ROLE GRANTS", "Access Denied: Cannot set role role1"); - assertQueryFails(aliceWithRole, "SELECT * FROM system.information_schema.applicable_roles", "Access Denied: Cannot set role role1"); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SHOW ROLES")) + .hasMessageContaining("Access Denied: Cannot set role role1"); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SHOW CURRENT ROLES")) + .hasMessageContaining("Access Denied: Cannot set role role1"); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SHOW ROLE GRANTS")) + .hasMessageContaining("Access Denied: Cannot set role role1"); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SELECT * FROM system.information_schema.applicable_roles")) + .hasMessageContaining("Access Denied: Cannot set role role1"); assertQuerySucceeds("DROP ROLE role1"); } @@ -138,7 +147,8 @@ public void testTransitiveRoleGrant() assertQuerySucceeds("GRANT role1 TO USER alice"); String roleNotApplicableErrorMessage = "Access Denied: Cannot set role role2"; - assertQueryFails(aliceWithRole, "SHOW ROLES", roleNotApplicableErrorMessage); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SHOW ROLES")) + .hasMessageContaining(roleNotApplicableErrorMessage); assertQuerySucceeds("GRANT role2 TO ROLE role1"); assertQuery(alice, "SHOW ROLES", "VALUES 'role1', 'role2'"); @@ -152,7 +162,8 @@ public void testTransitiveRoleGrant() + "('role1', 'ROLE', 'role2', 'NO')"); assertQuerySucceeds("REVOKE role2 FROM ROLE role1"); - assertQueryFails(aliceWithRole, "SHOW ROLES", roleNotApplicableErrorMessage); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SHOW ROLES")) + .hasMessageContaining(roleNotApplicableErrorMessage); assertQuerySucceeds("REVOKE role1 FROM USER alice"); assertQuerySucceeds("DROP ROLE role1");