diff --git a/client/trino-cli/src/main/java/io/trino/cli/Console.java b/client/trino-cli/src/main/java/io/trino/cli/Console.java index 9d1a5782d9cd..e5e82c932ce7 100644 --- a/client/trino-cli/src/main/java/io/trino/cli/Console.java +++ b/client/trino-cli/src/main/java/io/trino/cli/Console.java @@ -45,6 +45,7 @@ import java.nio.file.Paths; import java.util.AbstractMap; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -399,6 +400,13 @@ private static boolean process( builder = builder.roles(ImmutableMap.of()); } + // update session originalRoles + if (!query.getSetOriginalRoles().isEmpty()) { + Set originalRoles = new HashSet<>(session.getOriginalRoles()); + originalRoles.addAll(query.getSetOriginalRoles()); + builder = builder.originalRoles(originalRoles); + } + if (query.isResetAuthorizationUser()) { builder = builder.authorizationUser(Optional.empty()); builder = builder.roles(ImmutableMap.of()); diff --git a/client/trino-cli/src/main/java/io/trino/cli/Query.java b/client/trino-cli/src/main/java/io/trino/cli/Query.java index 593352ed0b7e..b46dab2daaa5 100644 --- a/client/trino-cli/src/main/java/io/trino/cli/Query.java +++ b/client/trino-cli/src/main/java/io/trino/cli/Query.java @@ -100,6 +100,11 @@ public boolean isResetAuthorizationUser() return client.isResetAuthorizationUser(); } + public Set getSetOriginalRoles() + { + return client.getSetOriginalRoles(); + } + public Map getSetSessionProperties() { return client.getSetSessionProperties(); diff --git a/client/trino-client/src/main/java/io/trino/client/ClientSession.java b/client/trino-client/src/main/java/io/trino/client/ClientSession.java index 91a7a539b097..d15236d605c9 100644 --- a/client/trino-client/src/main/java/io/trino/client/ClientSession.java +++ b/client/trino-client/src/main/java/io/trino/client/ClientSession.java @@ -39,6 +39,7 @@ public class ClientSession private final Optional user; private final Optional sessionUser; private final Optional authorizationUser; + private final Set originalRoles; private final String source; private final Optional traceToken; private final Set clientTags; @@ -80,6 +81,7 @@ private ClientSession( Optional user, Optional sessionUser, Optional authorizationUser, + Set originalRoles, String source, Optional traceToken, Set clientTags, @@ -103,6 +105,7 @@ private ClientSession( this.user = requireNonNull(user, "user is null"); this.sessionUser = requireNonNull(sessionUser, "sessionUser is null"); this.authorizationUser = requireNonNull(authorizationUser, "authorizationUser is null"); + this.originalRoles = ImmutableSet.copyOf(requireNonNull(originalRoles, "originalRoles is null")); this.source = requireNonNull(source, "source is null"); this.traceToken = requireNonNull(traceToken, "traceToken is null"); this.clientTags = ImmutableSet.copyOf(requireNonNull(clientTags, "clientTags is null")); @@ -171,6 +174,11 @@ public Optional getAuthorizationUser() return authorizationUser; } + public Set getOriginalRoles() + { + return originalRoles; + } + public String getSource() { return source; @@ -302,6 +310,7 @@ public static final class Builder private Optional user = Optional.empty(); private Optional sessionUser = Optional.empty(); private Optional authorizationUser = Optional.empty(); + private Set originalRoles = ImmutableSet.of(); private String source; private Optional traceToken = Optional.empty(); private Set clientTags = ImmutableSet.of(); @@ -330,6 +339,7 @@ private Builder(ClientSession clientSession) user = clientSession.getUser(); sessionUser = clientSession.getSessionUser(); authorizationUser = clientSession.getAuthorizationUser(); + originalRoles = clientSession.getOriginalRoles(); source = clientSession.getSource(); traceToken = clientSession.getTraceToken(); clientTags = clientSession.getClientTags(); @@ -374,6 +384,12 @@ public Builder authorizationUser(Optional authorizationUser) return this; } + public Builder originalRoles(Set originalRoles) + { + this.originalRoles = originalRoles; + return this; + } + public Builder source(String source) { this.source = source; @@ -489,6 +505,7 @@ public ClientSession build() user, sessionUser, authorizationUser, + originalRoles, source, traceToken, clientTags, diff --git a/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java b/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java index 3ef75fc91e7e..69e19da4184f 100644 --- a/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java +++ b/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java @@ -27,6 +27,7 @@ public final class ProtocolHeaders private final String name; private final String requestUser; private final String requestOriginalUser; + private final String requestOriginalRole; private final String requestSource; private final String requestCatalog; private final String requestSchema; @@ -57,6 +58,7 @@ public final class ProtocolHeaders private final String responseClearTransactionId; private final String responseSetAuthorizationUser; private final String responseResetAuthorizationUser; + private final String responseOriginalRole; public static ProtocolHeaders createProtocolHeaders(String name) { @@ -75,6 +77,7 @@ private ProtocolHeaders(String name) String prefix = "X-" + name + "-"; requestUser = prefix + "User"; requestOriginalUser = prefix + "Original-User"; + requestOriginalRole = prefix + "Original-Roles"; requestSource = prefix + "Source"; requestCatalog = prefix + "Catalog"; requestSchema = prefix + "Schema"; @@ -105,6 +108,7 @@ private ProtocolHeaders(String name) responseClearTransactionId = prefix + "Clear-Transaction-Id"; responseSetAuthorizationUser = prefix + "Set-Authorization-User"; responseResetAuthorizationUser = prefix + "Reset-Authorization-User"; + responseOriginalRole = prefix + "Set-Original-Roles"; } public String getProtocolName() @@ -122,6 +126,11 @@ public String requestOriginalUser() return requestOriginalUser; } + public String requestOriginalRole() + { + return requestOriginalRole; + } + public String requestSource() { return requestSource; @@ -272,6 +281,11 @@ public String responseResetAuthorizationUser() return responseResetAuthorizationUser; } + public String responseOriginalRole() + { + return responseOriginalRole; + } + public static ProtocolHeaders detectProtocol(Optional alternateHeaderName, Set headerNames) throws ProtocolDetectionException { diff --git a/client/trino-client/src/main/java/io/trino/client/StatementClient.java b/client/trino-client/src/main/java/io/trino/client/StatementClient.java index 816196f9f0b1..4838e10c95fd 100644 --- a/client/trino-client/src/main/java/io/trino/client/StatementClient.java +++ b/client/trino-client/src/main/java/io/trino/client/StatementClient.java @@ -63,6 +63,8 @@ default Optional getEncoding() boolean isResetAuthorizationUser(); + Set getSetOriginalRoles(); + Map getSetSessionProperties(); Set getResetSessionProperties(); diff --git a/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java b/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java index 2c14330f36db..98d97d2b3bdc 100644 --- a/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java +++ b/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java @@ -88,6 +88,7 @@ class StatementClientV1 private final AtomicReference> setPath = new AtomicReference<>(); private final AtomicReference setAuthorizationUser = new AtomicReference<>(); private final AtomicBoolean resetAuthorizationUser = new AtomicBoolean(); + private final Set setOriginalRoles = Sets.newConcurrentHashSet(); private final Map setSessionProperties = new ConcurrentHashMap<>(); private final Set resetSessionProperties = Sets.newConcurrentHashSet(); private final Map setRoles = new ConcurrentHashMap<>(); @@ -125,6 +126,7 @@ public StatementClientV1(Call.Factory httpCallFactory, Call.Factory segmentHttpC .filter(Optional::isPresent) .map(Optional::get) .findFirst(); + this.setOriginalRoles.addAll(session.getOriginalRoles()); this.clientCapabilities = Joiner.on(",").join(clientCapabilities.orElseGet(() -> stream(ClientCapabilities.values()) .map(Enum::name) .collect(toImmutableSet()))); @@ -181,6 +183,10 @@ private Request buildQueryRequest(ClientSession session, String query, Optional< builder.addHeader(TRINO_HEADERS.requestResourceEstimate(), entry.getKey() + "=" + urlEncode(entry.getValue())); } + for (ClientSelectedRole selectedRole : session.getOriginalRoles()) { + builder.addHeader(TRINO_HEADERS.requestOriginalRole(), selectedRole.toString()); + } + Map roles = session.getRoles(); for (Entry entry : roles.entrySet()) { builder.addHeader(TRINO_HEADERS.requestRole(), entry.getKey() + '=' + urlEncode(entry.getValue().toString())); @@ -316,6 +322,12 @@ public boolean isResetAuthorizationUser() return resetAuthorizationUser.get(); } + @Override + public Set getSetOriginalRoles() + { + return ImmutableSet.copyOf(setOriginalRoles); + } + @Override public Map getSetSessionProperties() { @@ -475,6 +487,11 @@ private void processResponse(Headers headers, QueryResults results) this.resetAuthorizationUser.set(Boolean.parseBoolean(resetAuthorizationUser)); } + setOriginalRoles.addAll(headers.values(TRINO_HEADERS.responseOriginalRole()) + .stream() + .map(role -> ClientSelectedRole.valueOf(urlDecode(role))) + .collect(toImmutableSet())); + for (String setSession : headers.values(TRINO_HEADERS.responseSetSession())) { List keyValue = COLLECTION_HEADER_SPLITTER.splitToList(setSession); if (keyValue.size() != 2) { diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java index 1b0b6730e7c9..2969a4b553d2 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java @@ -19,6 +19,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; import com.google.common.primitives.Ints; import io.airlift.units.Duration; import io.trino.client.ClientSelectedRole; @@ -107,6 +108,7 @@ public class TrinoConnection private final AtomicReference schema = new AtomicReference<>(); private final AtomicReference> path = new AtomicReference<>(ImmutableList.of()); private final AtomicReference authorizationUser = new AtomicReference<>(); + private final Set originalRoles = Sets.newConcurrentHashSet(); private final AtomicReference timeZoneId = new AtomicReference<>(); private final AtomicReference locale = new AtomicReference<>(); private final AtomicReference networkTimeoutMillis = new AtomicReference<>(Ints.saturatedCast(MINUTES.toMillis(2))); @@ -855,6 +857,7 @@ StatementClient startQuery(String sql, Map sessionPropertiesOver .user(user) .sessionUser(sessionUser.get()) .authorizationUser(Optional.ofNullable(authorizationUser.get())) + .originalRoles(ImmutableSet.copyOf(originalRoles)) .source(source) .traceToken(Optional.ofNullable(clientInfo.get(TRACE_TOKEN))) .clientTags(ImmutableSet.copyOf(clientTags)) @@ -891,10 +894,12 @@ void updateSession(StatementClient client) if (client.getSetAuthorizationUser().isPresent()) { authorizationUser.set(client.getSetAuthorizationUser().get()); + originalRoles.addAll(client.getSetOriginalRoles()); roles.clear(); } if (client.isResetAuthorizationUser()) { authorizationUser.set(null); + originalRoles.clear(); roles.clear(); } diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/BaseTrinoDriverTest.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/BaseTrinoDriverTest.java index 86966be9d5d5..871521ee4a71 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/BaseTrinoDriverTest.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/BaseTrinoDriverTest.java @@ -1241,7 +1241,7 @@ private static Properties toProperties(Map map) return properties; } - private static String getCurrentUser(Connection connection) + public static String getCurrentUser(Connection connection) throws SQLException { try (Statement statement = connection.createStatement(); diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestAsyncResultIterator.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestAsyncResultIterator.java index 3d20a40dc227..9825e221dee1 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestAsyncResultIterator.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestAsyncResultIterator.java @@ -215,6 +215,12 @@ public boolean isResetAuthorizationUser() throw new UnsupportedOperationException(); } + @Override + public Set getSetOriginalRoles() + { + throw new UnsupportedOperationException(); + } + @Override public Map getSetSessionProperties() { diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcConnection.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcConnection.java index 75ef5a13839a..28f370fcb5a9 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcConnection.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestJdbcConnection.java @@ -1005,6 +1005,12 @@ public boolean isResetAuthorizationUser() return true; } + @Override + public Set getSetOriginalRoles() + { + return Set.of(); + } + @Override public Map getSetSessionProperties() { diff --git a/core/trino-main/src/main/java/io/trino/Session.java b/core/trino-main/src/main/java/io/trino/Session.java index 7841c53cb5b6..860ff2779318 100644 --- a/core/trino-main/src/main/java/io/trino/Session.java +++ b/core/trino-main/src/main/java/io/trino/Session.java @@ -547,7 +547,8 @@ public SessionRepresentation toSessionRepresentation() clientTransactionSupport, identity.getUser(), originalIdentity.getUser(), - identity.getGroups(), + originalIdentity.getEnabledRoles(), + originalIdentity.getGroups(), originalIdentity.getGroups(), identity.getPrincipal().map(Principal::toString), identity.getEnabledRoles(), diff --git a/core/trino-main/src/main/java/io/trino/SessionRepresentation.java b/core/trino-main/src/main/java/io/trino/SessionRepresentation.java index caf538014874..793d76c667e3 100644 --- a/core/trino-main/src/main/java/io/trino/SessionRepresentation.java +++ b/core/trino-main/src/main/java/io/trino/SessionRepresentation.java @@ -48,6 +48,7 @@ public final class SessionRepresentation private final boolean clientTransactionSupport; private final String user; private final String originalUser; + private final Set originalRoles; private final Set groups; private final Set originalUserGroups; private final Optional principal; @@ -81,6 +82,7 @@ public SessionRepresentation( @JsonProperty("clientTransactionSupport") boolean clientTransactionSupport, @JsonProperty("user") String user, @JsonProperty("originalUser") String originalUser, + @JsonProperty("setOriginalRoles") Set originalRoles, @JsonProperty("groups") Set groups, @JsonProperty("originalUserGroups") Set originalUserGroups, @JsonProperty("principal") Optional principal, @@ -112,6 +114,7 @@ public SessionRepresentation( this.clientTransactionSupport = clientTransactionSupport; this.user = requireNonNull(user, "user is null"); this.originalUser = requireNonNull(originalUser, "originalUser is null"); + this.originalRoles = requireNonNull(originalRoles, "setOriginalRoles is null"); this.groups = requireNonNull(groups, "groups is null"); this.originalUserGroups = requireNonNull(originalUserGroups, "originalUserGroups is null"); this.principal = requireNonNull(principal, "principal is null"); @@ -179,6 +182,12 @@ public String getOriginalUser() return originalUser; } + @JsonProperty + public Set getOriginalRoles() + { + return originalRoles; + } + @JsonProperty public Set getGroups() { @@ -350,6 +359,7 @@ public Identity toOriginalIdentity(Map extraCredentials) return Identity.forUser(originalUser) .withGroups(originalUserGroups) .withPrincipal(principal.map(BasicPrincipal::new)) + .withEnabledRoles(originalRoles) .withExtraCredentials(extraCredentials) .build(); } diff --git a/core/trino-main/src/main/java/io/trino/dispatcher/FailedDispatchQuery.java b/core/trino-main/src/main/java/io/trino/dispatcher/FailedDispatchQuery.java index 1e1f00a645cb..255d7771cdc6 100644 --- a/core/trino-main/src/main/java/io/trino/dispatcher/FailedDispatchQuery.java +++ b/core/trino-main/src/main/java/io/trino/dispatcher/FailedDispatchQuery.java @@ -234,6 +234,7 @@ private static QueryInfo immediateFailureQueryInfo( Optional.empty(), Optional.empty(), false, + ImmutableSet.of(), ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of(), diff --git a/core/trino-main/src/main/java/io/trino/event/QueryMonitor.java b/core/trino-main/src/main/java/io/trino/event/QueryMonitor.java index 9e0bc49c53bc..50c2cfa77c0e 100644 --- a/core/trino-main/src/main/java/io/trino/event/QueryMonitor.java +++ b/core/trino-main/src/main/java/io/trino/event/QueryMonitor.java @@ -365,6 +365,7 @@ private QueryContext createQueryContext(SessionRepresentation session, Optional< return new QueryContext( session.getUser(), session.getOriginalUser(), + session.getOriginalRoles(), session.getPrincipal(), session.getEnabledRoles(), session.getGroups(), diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryInfo.java b/core/trino-main/src/main/java/io/trino/execution/QueryInfo.java index 18a4062fdbd6..9f7f54abde93 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryInfo.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryInfo.java @@ -62,6 +62,7 @@ public class QueryInfo private final Optional setPath; private final Optional setAuthorizationUser; private final boolean resetAuthorizationUser; + private final Set setOriginalRoles; private final Map setSessionProperties; private final Set resetSessionProperties; private final Map setRoles; @@ -101,6 +102,7 @@ public QueryInfo( @JsonProperty("setPath") Optional setPath, @JsonProperty("setAuthorizationUser") Optional setAuthorizationUser, @JsonProperty("resetAuthorizationUser") boolean resetAuthorizationUser, + @JsonProperty("setOriginalRoles") Set setOriginalRoles, @JsonProperty("setSessionProperties") Map setSessionProperties, @JsonProperty("resetSessionProperties") Set resetSessionProperties, @JsonProperty("setRoles") Map setRoles, @@ -134,6 +136,7 @@ public QueryInfo( requireNonNull(setSchema, "setSchema is null"); requireNonNull(setPath, "setPath is null"); requireNonNull(setAuthorizationUser, "setAuthorizationUser is null"); + requireNonNull(setOriginalRoles, "setOriginalRoles is null"); requireNonNull(setSessionProperties, "setSessionProperties is null"); requireNonNull(resetSessionProperties, "resetSessionProperties is null"); requireNonNull(addedPreparedStatements, "addedPreparedStatements is null"); @@ -165,6 +168,7 @@ public QueryInfo( this.setPath = setPath; this.setAuthorizationUser = setAuthorizationUser; this.resetAuthorizationUser = resetAuthorizationUser; + this.setOriginalRoles = setOriginalRoles; this.setSessionProperties = ImmutableMap.copyOf(setSessionProperties); this.resetSessionProperties = ImmutableSet.copyOf(resetSessionProperties); this.setRoles = ImmutableMap.copyOf(setRoles); @@ -287,6 +291,12 @@ public boolean isResetAuthorizationUser() return resetAuthorizationUser; } + @JsonProperty + public Set getSetOriginalRoles() + { + return setOriginalRoles; + } + @JsonProperty public Map getSetSessionProperties() { @@ -455,6 +465,7 @@ public QueryInfo pruneDigests() setPath, setAuthorizationUser, resetAuthorizationUser, + setOriginalRoles, setSessionProperties, resetSessionProperties, setRoles, diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java index 0e23c80b5ca3..2bea048b787f 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java @@ -156,6 +156,7 @@ public class QueryStateMachine private final AtomicReference setAuthorizationUser = new AtomicReference<>(); private final AtomicBoolean resetAuthorizationUser = new AtomicBoolean(); + private final Set setOriginalRoles = Sets.newConcurrentHashSet(); private final Map setSessionProperties = new ConcurrentHashMap<>(); private final Set resetSessionProperties = Sets.newConcurrentHashSet(); @@ -525,6 +526,7 @@ ResultQueryInfo getResultQueryInfo(Optional stageInfo) Optional.ofNullable(setPath.get()), Optional.ofNullable(setAuthorizationUser.get()), resetAuthorizationUser.get(), + setOriginalRoles, setSessionProperties, resetSessionProperties, setRoles, @@ -618,6 +620,7 @@ QueryInfo getQueryInfo(Optional rootStage) Optional.ofNullable(setPath.get()), Optional.ofNullable(setAuthorizationUser.get()), resetAuthorizationUser.get(), + setOriginalRoles, setSessionProperties, resetSessionProperties, setRoles, @@ -1024,6 +1027,11 @@ public void resetAuthorizationUser() resetAuthorizationUser.set(true); } + public void addSetOriginalRoles(SelectedRole role) + { + setOriginalRoles.add(requireNonNull(role, "role is null")); + } + public void addSetSessionProperties(String key, String value) { setSessionProperties.put(requireNonNull(key, "key is null"), requireNonNull(value, "value is null")); @@ -1417,6 +1425,7 @@ public static QueryInfo pruneQueryInfo(QueryInfo queryInfo, NodeVersion version) queryInfo.getSetPath(), queryInfo.getSetAuthorizationUser(), queryInfo.isResetAuthorizationUser(), + queryInfo.getSetOriginalRoles(), queryInfo.getSetSessionProperties(), queryInfo.getResetSessionProperties(), queryInfo.getSetRoles(), diff --git a/core/trino-main/src/main/java/io/trino/execution/SetSessionAuthorizationTask.java b/core/trino-main/src/main/java/io/trino/execution/SetSessionAuthorizationTask.java index b351549a23ee..20857add7c91 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SetSessionAuthorizationTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/SetSessionAuthorizationTask.java @@ -21,6 +21,7 @@ import io.trino.security.AccessControl; import io.trino.spi.TrinoException; import io.trino.spi.security.Identity; +import io.trino.spi.security.SelectedRole; import io.trino.sql.tree.Expression; import io.trino.sql.tree.Identifier; import io.trino.sql.tree.SetSessionAuthorization; @@ -28,6 +29,8 @@ import io.trino.transaction.TransactionManager; import java.util.List; +import java.util.Optional; +import java.util.Set; import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.Futures.immediateFuture; @@ -91,6 +94,18 @@ else if (userExpression instanceof StringLiteral stringLiteral) { accessControl.checkCanImpersonateUser(originalIdentity, user); } stateMachine.setSetAuthorizationUser(user); + SelectedRole selectedRole; + Set enabledRoles = originalIdentity.getEnabledRoles(); + if (enabledRoles.isEmpty()) { + selectedRole = new SelectedRole(SelectedRole.Type.NONE, Optional.empty()); + } + else if (enabledRoles.size() == 1) { + selectedRole = new SelectedRole(SelectedRole.Type.ROLE, Optional.of(enabledRoles.iterator().next())); + } + else { + selectedRole = new SelectedRole(SelectedRole.Type.ALL, Optional.empty()); + } + stateMachine.addSetOriginalRoles(selectedRole); return immediateFuture(null); } } diff --git a/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java b/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java index 603f90303598..bbe0855f5d9c 100644 --- a/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java +++ b/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java @@ -254,18 +254,14 @@ private Identity buildSessionIdentity(Optional authenticatedIdentity, String user = trinoUser != null ? trinoUser : authenticatedIdentity.map(Identity::getUser).orElse(null); assertRequest(user != null, "User must be set"); SelectedRole systemRole = parseSystemRoleHeaders(protocolHeaders, headers); - ImmutableSet.Builder systemEnabledRoles = ImmutableSet.builder(); - if (systemRole.getType() == Type.ROLE) { - systemEnabledRoles.add(systemRole.getRole().orElseThrow()); - } - return authenticatedIdentity + Identity newIdentity = authenticatedIdentity .map(identity -> Identity.from(identity).withUser(user)) .orElseGet(() -> Identity.forUser(user)) - .withEnabledRoles(systemEnabledRoles.build()) .withAdditionalConnectorRoles(parseConnectorRoleHeaders(protocolHeaders, headers)) .withAdditionalExtraCredentials(parseExtraCredentials(protocolHeaders, headers)) .withAdditionalGroups(groupProvider.getGroups(user)) .build(); + return addEnabledRoles(newIdentity, systemRole, metadata); } private Identity buildSessionOriginalIdentity(Identity identity, ProtocolHeaders protocolHeaders, MultivaluedMap headers) @@ -273,12 +269,18 @@ private Identity buildSessionOriginalIdentity(Identity identity, ProtocolHeaders // We derive original identity using this header, but older clients will not send it, so fall back to identity Optional optionalOriginalUser = Optional .ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestOriginalUser()))); - Identity originalIdentity = optionalOriginalUser.map(originalUser -> Identity.from(identity) - .withUser(originalUser) - .withExtraCredentials(new HashMap<>()) - .withGroups(groupProvider.getGroups(originalUser)) - .build()) - .orElse(identity); + Optional originalRoles = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestOriginalRole()))); + Identity originalIdentity = optionalOriginalUser.map(originalUser -> { + Identity newIdentity = Identity.from(identity) + .withUser(originalUser) + .withExtraCredentials(new HashMap<>()) + .withGroups(groupProvider.getGroups(originalUser)) + .build(); + if (originalRoles.isPresent()) { + newIdentity = addEnabledRoles(newIdentity, SelectedRole.valueOf(originalRoles.get()), metadata); + } + return newIdentity; + }).orElse(identity); return originalIdentity; } diff --git a/core/trino-main/src/main/java/io/trino/server/ResultQueryInfo.java b/core/trino-main/src/main/java/io/trino/server/ResultQueryInfo.java index 3cdb9dfa6184..059a43a3020c 100644 --- a/core/trino-main/src/main/java/io/trino/server/ResultQueryInfo.java +++ b/core/trino-main/src/main/java/io/trino/server/ResultQueryInfo.java @@ -63,6 +63,8 @@ public record ResultQueryInfo( @JsonProperty boolean resetAuthorizationUser, @JsonProperty + Set setOriginalRoles, + @JsonProperty Map setSessionProperties, @JsonProperty Set resetSessionProperties, @@ -95,6 +97,7 @@ public ResultQueryInfo( @JsonProperty("setPath") Optional setPath, @JsonProperty("setAuthorizationUser") Optional setAuthorizationUser, @JsonProperty("resetAuthorizationUser") boolean resetAuthorizationUser, + @JsonProperty("setOriginalRoles") Set setOriginalRoles, @JsonProperty("setSessionProperties") Map setSessionProperties, @JsonProperty("resetSessionProperties") Set resetSessionProperties, @JsonProperty("setRoles") Map setRoles, @@ -118,6 +121,7 @@ public ResultQueryInfo( this.setPath = requireNonNull(setPath, "setPath is null"); this.setAuthorizationUser = requireNonNull(setAuthorizationUser, "setAuthorizationUser is null"); this.resetAuthorizationUser = resetAuthorizationUser; + this.setOriginalRoles = requireNonNull(setOriginalRoles, "setOriginalRoles is null"); this.setSessionProperties = requireNonNull(setSessionProperties, "setSessionProperties is null"); this.resetSessionProperties = requireNonNull(resetSessionProperties, "resetSessionProperties is null"); this.addedPreparedStatements = requireNonNull(addedPreparedStatements, "addedPreparedStatements is null"); @@ -144,6 +148,7 @@ public ResultQueryInfo(QueryInfo queryInfo) queryInfo.getSetPath(), queryInfo.getSetAuthorizationUser(), queryInfo.isResetAuthorizationUser(), + queryInfo.getSetOriginalRoles(), queryInfo.getSetSessionProperties(), queryInfo.getResetSessionProperties(), queryInfo.getSetRoles(), diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java b/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java index 936751525f4a..e956038c31d9 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java @@ -224,6 +224,10 @@ private Response toResponse(QueryResultsResponse resultsResponse, Optional response.header(protocolHeaders.responseOriginalRole(), name)); + // add set session properties resultsResponse.setSessionProperties() .forEach((key, value) -> response.header(protocolHeaders.responseSetSession(), key + '=' + urlEncode(value))); diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/Query.java b/core/trino-main/src/main/java/io/trino/server/protocol/Query.java index c09d52ff98ae..0b0dbcfc8bd1 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/Query.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/Query.java @@ -160,6 +160,9 @@ class Query @GuardedBy("this") private Optional setAuthorizationUser = Optional.empty(); + @GuardedBy("this") + private Set setOriginalRoles = ImmutableSet.of(); + @GuardedBy("this") private boolean resetAuthorizationUser; @@ -505,6 +508,9 @@ private synchronized QueryResultsResponse getNextResult(long token, ExternalUriI setAuthorizationUser = queryInfo.setAuthorizationUser(); resetAuthorizationUser = queryInfo.resetAuthorizationUser(); + // update setOriginalRoles + setOriginalRoles = queryInfo.setOriginalRoles(); + // update setSessionProperties setSessionProperties = queryInfo.setSessionProperties(); resetSessionProperties = queryInfo.resetSessionProperties(); @@ -549,6 +555,7 @@ private synchronized QueryResultsResponse toResultsResponse(QueryResults queryRe setPath, setAuthorizationUser, resetAuthorizationUser, + setOriginalRoles, setSessionProperties, resetSessionProperties, setRoles, diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultsResponse.java b/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultsResponse.java index 4387fd9f405b..f1a74b09a4d2 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultsResponse.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultsResponse.java @@ -30,6 +30,7 @@ record QueryResultsResponse( Optional setPath, Optional setAuthorizationUser, boolean resetAuthorizationUser, + Set setOriginalRoles, Map setSessionProperties, Set resetSessionProperties, Map setRoles, @@ -45,6 +46,7 @@ record QueryResultsResponse( requireNonNull(setSchema, "setSchema is null"); requireNonNull(setPath, "setPath is null"); requireNonNull(setAuthorizationUser, "setAuthorizationUser is null"); + requireNonNull(setOriginalRoles, "setOriginalRoles is null"); requireNonNull(setSessionProperties, "setSessionProperties is null"); requireNonNull(resetSessionProperties, "resetSessionProperties is null"); requireNonNull(setRoles, "setRoles is null"); diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java b/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java index 72697b6b8ff8..85775ebf8513 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingAccessControlManager.java @@ -144,6 +144,7 @@ public class TestingAccessControlManager private Predicate deniedTables = s -> true; private BiPredicate denyIdentityTable = IDENTITY_TABLE_TRUE; private BiPredicate denyIdentityFunction = IDENTITY_FUNCTION_TRUE; + private BiPredicate denyImpersonationFunction = IDENTITY_FUNCTION_TRUE; @Inject public TestingAccessControlManager( @@ -196,6 +197,7 @@ public void reset() denyIdentityTable = IDENTITY_TABLE_TRUE; rowFilters.clear(); columnMasks.clear(); + denyImpersonationFunction = IDENTITY_FUNCTION_TRUE; } public void denyCatalogs(Predicate deniedCatalogs) @@ -223,6 +225,11 @@ public void denyIdentityFunction(BiPredicate denyIdentityFunct this.denyIdentityFunction = requireNonNull(denyIdentityFunction, "denyIdentityFunction is null"); } + public void denyImpersonation(BiPredicate denyImpersonationFunction) + { + this.denyImpersonationFunction = requireNonNull(denyImpersonationFunction, "denyImpersonationFunction is null"); + } + @Override public Set filterCatalogs(SecurityContext securityContext, Set catalogs) { @@ -258,6 +265,9 @@ public Set filterTables(SecurityContext context, String catalog @Override public void checkCanImpersonateUser(Identity identity, String userName) { + if (!denyImpersonationFunction.test(identity, userName)) { + denyImpersonateUser(identity.getUser(), userName); + } if (shouldDenyPrivilege(userName, userName, IMPERSONATE_USER)) { denyImpersonateUser(identity.getUser(), userName); } diff --git a/core/trino-main/src/test/java/io/trino/execution/MockManagedQueryExecution.java b/core/trino-main/src/test/java/io/trino/execution/MockManagedQueryExecution.java index 4029f63085bc..30f342a0dcd8 100644 --- a/core/trino-main/src/test/java/io/trino/execution/MockManagedQueryExecution.java +++ b/core/trino-main/src/test/java/io/trino/execution/MockManagedQueryExecution.java @@ -268,6 +268,7 @@ public QueryInfo getFullQueryInfo() Optional.empty(), Optional.empty(), false, + ImmutableSet.of(), ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of(), diff --git a/core/trino-main/src/test/java/io/trino/execution/TestQueryInfo.java b/core/trino-main/src/test/java/io/trino/execution/TestQueryInfo.java index eb53c6e26a72..6e879db46d6d 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestQueryInfo.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestQueryInfo.java @@ -225,6 +225,7 @@ private static QueryInfo createQueryInfo(Optional stageInfo) Optional.of("set_path"), Optional.of("set_authorization_user"), false, + ImmutableSet.of(new SelectedRole(SelectedRole.Type.ROLE, Optional.of("original_role"))), ImmutableMap.of("set_property", "set_value"), ImmutableSet.of("reset_property"), ImmutableMap.of("set_roles", new SelectedRole(SelectedRole.Type.ROLE, Optional.of("role"))), diff --git a/core/trino-main/src/test/java/io/trino/metadata/TestMetadataManager.java b/core/trino-main/src/test/java/io/trino/metadata/TestMetadataManager.java index 82daab9362dd..44eb9a01a5f2 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/TestMetadataManager.java +++ b/core/trino-main/src/test/java/io/trino/metadata/TestMetadataManager.java @@ -18,6 +18,7 @@ import io.trino.FeaturesConfig; import io.trino.security.AllowAllAccessControl; import io.trino.spi.block.BlockEncodingSerde; +import io.trino.spi.security.Identity; import io.trino.spi.type.TypeManager; import io.trino.spi.type.TypeOperators; import io.trino.sql.parser.SqlParser; @@ -25,6 +26,8 @@ import io.trino.transaction.TransactionManager; import io.trino.type.BlockTypeOperators; +import java.util.Set; + import static io.trino.client.NodeVersion.UNKNOWN; import static io.trino.transaction.InMemoryTransactionManager.createTestTransactionManager; import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; @@ -104,7 +107,7 @@ public MetadataManager build() return new MetadataManager( new AllowAllAccessControl(), - new DisabledSystemSecurityMetadata(), + new SecurityMetadata(), transactionManager, globalFunctionCatalog, languageFunctionManager, @@ -113,4 +116,14 @@ public MetadataManager build() new NotImplementedQueryManager()); } } + + private static class SecurityMetadata + extends DisabledSystemSecurityMetadata + { + @Override + public Set listEnabledRoles(Identity identity) + { + return ImmutableSet.of("system-role"); + } + } } diff --git a/core/trino-main/src/test/java/io/trino/server/TestBasicQueryInfo.java b/core/trino-main/src/test/java/io/trino/server/TestBasicQueryInfo.java index f543cdaac64b..063200e96818 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestBasicQueryInfo.java +++ b/core/trino-main/src/test/java/io/trino/server/TestBasicQueryInfo.java @@ -144,6 +144,7 @@ public void testConstructor() Optional.empty(), Optional.empty(), false, + ImmutableSet.of(), ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of(), diff --git a/core/trino-main/src/test/java/io/trino/server/TestHttpRequestSessionContextFactory.java b/core/trino-main/src/test/java/io/trino/server/TestHttpRequestSessionContextFactory.java index ec7b01e6b02e..fddcaa755db0 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestHttpRequestSessionContextFactory.java +++ b/core/trino-main/src/test/java/io/trino/server/TestHttpRequestSessionContextFactory.java @@ -122,19 +122,31 @@ private static void assertMappedUser(ProtocolHeaders protocolHeaders) userHeaders, Optional.of("testRemote"), Optional.empty()); - assertThat(context.getIdentity()).isEqualTo(Identity.forUser("testUser").withGroups(ImmutableSet.of("testUser")).build()); + assertThat(context.getIdentity()) + .isEqualTo(Identity.forUser("testUser") + .withGroups(ImmutableSet.of("testUser")) + .withEnabledRoles(ImmutableSet.of("system-role")) + .build()); context = sessionContextFactory(protocolHeaders).createSessionContext( emptyHeaders, Optional.of("testRemote"), Optional.of(Identity.forUser("mappedUser").withGroups(ImmutableSet.of("test")).build())); - assertThat(context.getIdentity()).isEqualTo(Identity.forUser("mappedUser").withGroups(ImmutableSet.of("test", "mappedUser")).build()); + assertThat(context.getIdentity()) + .isEqualTo(Identity.forUser("mappedUser") + .withGroups(ImmutableSet.of("test", "mappedUser")) + .withEnabledRoles(ImmutableSet.of("system-role")) + .build()); context = sessionContextFactory(protocolHeaders).createSessionContext( userHeaders, Optional.of("testRemote"), Optional.of(Identity.ofUser("mappedUser"))); - assertThat(context.getIdentity()).isEqualTo(Identity.forUser("testUser").withGroups(ImmutableSet.of("testUser")).build()); + assertThat(context.getIdentity()) + .isEqualTo(Identity.forUser("testUser") + .withGroups(ImmutableSet.of("testUser")) + .withEnabledRoles(ImmutableSet.of("system-role")) + .build()); assertInvalidSession(protocolHeaders, emptyHeaders) .matches(e -> ((WebApplicationException) e).getResponse().getStatus() == 400); diff --git a/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfo.java b/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfo.java index 9e7ec6e246d7..e218e762314d 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfo.java +++ b/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfo.java @@ -189,6 +189,7 @@ private QueryInfo createQueryInfo(String queryId, QueryState state, String query Optional.empty(), Optional.empty(), false, + ImmutableSet.of(), ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of(), diff --git a/core/trino-spi/src/main/java/io/trino/spi/eventlistener/QueryContext.java b/core/trino-spi/src/main/java/io/trino/spi/eventlistener/QueryContext.java index a14094f5a9ce..47f917f51c5f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/eventlistener/QueryContext.java +++ b/core/trino-spi/src/main/java/io/trino/spi/eventlistener/QueryContext.java @@ -33,6 +33,7 @@ public class QueryContext { private final String user; private final String originalUser; + private final Set originalRoles; private final Optional principal; private final Set enabledRoles; private final Set groups; @@ -66,6 +67,7 @@ public class QueryContext public QueryContext( String user, String originalUser, + Set originalRoles, Optional principal, Set enabledRoles, Set groups, @@ -90,6 +92,7 @@ public QueryContext( { this.user = requireNonNull(user, "user is null"); this.originalUser = requireNonNull(originalUser, "originalUser is null"); + this.originalRoles = requireNonNull(originalRoles, "originalRoles is null"); this.principal = requireNonNull(principal, "principal is null"); this.enabledRoles = requireNonNull(enabledRoles, "enabledRoles is null"); this.groups = requireNonNull(groups, "groups is null"); @@ -125,6 +128,12 @@ public String getOriginalUser() return originalUser; } + @JsonProperty + public Set getOriginalRoles() + { + return originalRoles; + } + @JsonProperty public Optional getPrincipal() { diff --git a/docs/src/main/sphinx/develop/client-protocol.md b/docs/src/main/sphinx/develop/client-protocol.md index b7d353ef7d9f..d33fdf8e4b68 100644 --- a/docs/src/main/sphinx/develop/client-protocol.md +++ b/docs/src/main/sphinx/develop/client-protocol.md @@ -228,6 +228,9 @@ subsequent requests to be consistent with the response headers received. - Instructs the client to reset `X-Trino-User` request header to its original value in subsequent client requests and remove `X-Trino-Original-User` to reset the authorization user back to the original user. +* - `X-Trino-Set-Original-Roles` + - Instructs the client to set the roles of the original user in the + `X-Trino-Original-Roles` request header in subsequent client requests. * - `X-Trino-Set-Session` - The value of the `X-Trino-Set-Session` response header is a string of the form *property* = *value*. It instructs the client include session property diff --git a/plugin/trino-http-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpEventListener.java b/plugin/trino-http-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpEventListener.java index 8d5fe476e40f..b55952febe27 100644 --- a/plugin/trino-http-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpEventListener.java +++ b/plugin/trino-http-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpEventListener.java @@ -106,6 +106,7 @@ final class TestHttpEventListener queryContext = new QueryContext( "user", "originalUser", + Set.of(), Optional.of("principal"), Set.of(), // enabledRoles Set.of(), // groups diff --git a/plugin/trino-http-server-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpServerEventListener.java b/plugin/trino-http-server-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpServerEventListener.java index 8f2defd9c343..3a6668bbbdc4 100644 --- a/plugin/trino-http-server-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpServerEventListener.java +++ b/plugin/trino-http-server-event-listener/src/test/java/io/trino/plugin/httpquery/TestHttpServerEventListener.java @@ -72,6 +72,7 @@ final class TestHttpServerEventListener QueryContext queryContext = new QueryContext( "user", "originalUser", + Set.of(), Optional.of("principal"), Set.of(), // enabledRoles Set.of(), // groups diff --git a/plugin/trino-kafka-event-listener/src/test/java/io/trino/plugin/eventlistener/kafka/TestUtils.java b/plugin/trino-kafka-event-listener/src/test/java/io/trino/plugin/eventlistener/kafka/TestUtils.java index e9bfa85475f4..e5b5ef4d715d 100644 --- a/plugin/trino-kafka-event-listener/src/test/java/io/trino/plugin/eventlistener/kafka/TestUtils.java +++ b/plugin/trino-kafka-event-listener/src/test/java/io/trino/plugin/eventlistener/kafka/TestUtils.java @@ -68,6 +68,7 @@ private TestUtils() {} queryContext = new QueryContext( "user", "originalUser", + Set.of(), Optional.of("principal"), Set.of(), Set.of(), // groups diff --git a/plugin/trino-mysql-event-listener/src/test/java/io/trino/plugin/eventlistener/mysql/TestMysqlEventListener.java b/plugin/trino-mysql-event-listener/src/test/java/io/trino/plugin/eventlistener/mysql/TestMysqlEventListener.java index e5a06e514d25..54a1e5c53319 100644 --- a/plugin/trino-mysql-event-listener/src/test/java/io/trino/plugin/eventlistener/mysql/TestMysqlEventListener.java +++ b/plugin/trino-mysql-event-listener/src/test/java/io/trino/plugin/eventlistener/mysql/TestMysqlEventListener.java @@ -144,6 +144,7 @@ final class TestMysqlEventListener private static final QueryContext FULL_QUERY_CONTEXT = new QueryContext( "user", "originalUser", + Set.of("role1"), Optional.of("principal"), Set.of("role1", "role2"), Set.of("group1", "group2"), @@ -305,6 +306,7 @@ final class TestMysqlEventListener private static final QueryContext MINIMAL_QUERY_CONTEXT = new QueryContext( "user", "originalUser", + Set.of(), Optional.empty(), Set.of(), Set.of(), diff --git a/plugin/trino-openlineage/src/test/java/io/trino/plugin/openlineage/TrinoEventData.java b/plugin/trino-openlineage/src/test/java/io/trino/plugin/openlineage/TrinoEventData.java index 5a6adad0f390..bfcf0bfa432e 100644 --- a/plugin/trino-openlineage/src/test/java/io/trino/plugin/openlineage/TrinoEventData.java +++ b/plugin/trino-openlineage/src/test/java/io/trino/plugin/openlineage/TrinoEventData.java @@ -58,6 +58,7 @@ private TrinoEventData() queryContext = new QueryContext( "user", "originalUser", + Set.of(), Optional.of("principal"), Set.of(), // enabledRoles Set.of(), // groups diff --git a/pom.xml b/pom.xml index 1b1151d37bd2..170d8c8ba286 100644 --- a/pom.xml +++ b/pom.xml @@ -1171,6 +1171,13 @@ ${project.version} + + io.trino + trino-jdbc + ${project.version} + test-jar + + io.trino trino-jmx diff --git a/testing/trino-tests/pom.xml b/testing/trino-tests/pom.xml index 52c3510ba84a..53c097a6d35f 100644 --- a/testing/trino-tests/pom.xml +++ b/testing/trino-tests/pom.xml @@ -201,6 +201,19 @@ test + + io.trino + trino-jdbc + test + + + + io.trino + trino-jdbc + test-jar + test + + io.trino trino-main diff --git a/testing/trino-tests/src/test/java/io/trino/security/TestImpersonation.java b/testing/trino-tests/src/test/java/io/trino/security/TestImpersonation.java new file mode 100644 index 000000000000..3366eed50ce8 --- /dev/null +++ b/testing/trino-tests/src/test/java/io/trino/security/TestImpersonation.java @@ -0,0 +1,160 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.security; + +import com.google.common.collect.ImmutableSet; +import io.airlift.log.Logging; +import io.trino.jdbc.TrinoConnection; +import io.trino.metadata.SystemSecurityMetadata; +import io.trino.plugin.memory.MemoryPlugin; +import io.trino.server.testing.TestingTrinoServer; +import io.trino.spi.security.TrinoPrincipal; +import io.trino.testing.TestingAccessControlManager; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.parallel.Execution; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Optional; +import java.util.stream.Stream; + +import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; +import static io.trino.jdbc.BaseTrinoDriverTest.getCurrentUser; +import static io.trino.spi.security.PrincipalType.USER; +import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; + +@TestInstance(PER_CLASS) +@Execution(SAME_THREAD) +public class TestImpersonation +{ + private TestingTrinoServer server; + private final TestingSystemSecurityMetadata securityMetadata = new TestingSystemSecurityMetadata(); + private TestingAccessControlManager accessControl; + + @BeforeAll + public void setup() + throws Exception + { + Logging.initialize(); + server = TestingTrinoServer.builder() + .setAdditionalModule(binder -> { + newOptionalBinder(binder, SystemSecurityMetadata.class) + .setBinding() + .toInstance(securityMetadata); + }).build(); + server.installPlugin(new MemoryPlugin()); + server.createCatalog("memory", "memory"); + accessControl = server.getAccessControl(); + } + + @ParameterizedTest + @MethodSource("roles") + @Timeout(10) + public void testImpersonationAllowedByRole(String roleName) + throws Exception + { + securityMetadata.reset(); + accessControl.reset(); + try (TrinoConnection connection = createConnection("memory", "default", "alice").unwrap(TrinoConnection.class); + Statement statement = connection.createStatement()) { + assertThat(getCurrentUser(connection)).isEqualTo("alice"); + securityMetadata.createRole(null, "invalid_role", Optional.empty()); + securityMetadata.grantRoles( + null, + ImmutableSet.of("invalid_role"), + ImmutableSet.of(new TrinoPrincipal(USER, "alice")), + false, + Optional.empty()); + denyImpersonation(); + statement.execute("SET ROLE invalid_role"); + assertThatThrownBy(() -> statement.execute("SET SESSION AUTHORIZATION john")) + .hasMessageContaining("User alice cannot impersonate user john"); + + securityMetadata.createRole(null, "alice_role", Optional.empty()); + securityMetadata.grantRoles( + null, + ImmutableSet.of("alice_role"), + ImmutableSet.of(new TrinoPrincipal(USER, "alice")), + false, + Optional.empty()); + assertThatThrownBy(() -> statement.execute("SET SESSION AUTHORIZATION john")) + .hasMessageContaining("User alice cannot impersonate user john"); + + statement.execute("SET ROLE " + roleName); + statement.execute("SET SESSION AUTHORIZATION john"); + + // This would fail if roles were not correctly propagated + statement.execute("SHOW SCHEMAS IN memory"); + // Call more than once to make sure everything is propagated correctly + // to subsequent calls + statement.execute("SHOW SCHEMAS IN memory"); + } + } + + @Test + @Timeout(10) + public void testImpersonationDisallowedWhenRoleIsNone() + throws Exception + { + securityMetadata.reset(); + accessControl.reset(); + try (TrinoConnection connection = createConnection("memory", "default", "alice").unwrap(TrinoConnection.class); + Statement statement = connection.createStatement()) { + assertThat(getCurrentUser(connection)).isEqualTo("alice"); + securityMetadata.createRole(null, "alice_role", Optional.empty()); + denyImpersonation(); + securityMetadata.grantRoles( + null, + ImmutableSet.of("alice_role"), + ImmutableSet.of(new TrinoPrincipal(USER, "alice")), + false, + Optional.empty()); + statement.execute("SET ROLE NONE"); + + assertThatThrownBy(() -> statement.execute("SET SESSION AUTHORIZATION john")) + .hasMessageContaining("User alice cannot impersonate user john"); + } + } + + private Connection createConnection(String catalog, String schema, String user) + throws SQLException + { + String url = format("jdbc:trino://%s/%s/%s", server.getAddress(), catalog, schema); + return DriverManager.getConnection(url, user, null); + } + + private Stream roles() + { + return Stream.of("alice_role", "ALL"); + } + + private void denyImpersonation() + { + accessControl.denyImpersonation((identity, _) -> + identity.getEnabledRoles() + .stream() + .anyMatch(role -> role.equalsIgnoreCase("alice_role"))); + } +} diff --git a/testing/trino-tests/src/test/java/io/trino/security/TestSystemSecurityMetadata.java b/testing/trino-tests/src/test/java/io/trino/security/TestSystemSecurityMetadata.java index 4dadcbf64af0..249ee4ffccf3 100644 --- a/testing/trino-tests/src/test/java/io/trino/security/TestSystemSecurityMetadata.java +++ b/testing/trino-tests/src/test/java/io/trino/security/TestSystemSecurityMetadata.java @@ -25,6 +25,7 @@ import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; import static io.trino.testing.TestingSession.testSessionBuilder; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; @Execution(SAME_THREAD) // TestingSystemSecurityMetadata is shared mutable state @@ -96,10 +97,14 @@ public void testRoleGrant() assertQueryReturnsEmptyResult(alice, "SHOW CURRENT ROLES"); assertQueryReturnsEmptyResult(alice, "SHOW ROLE GRANTS"); assertQueryReturnsEmptyResult(alice, "SELECT * FROM system.information_schema.applicable_roles"); - assertQueryFails(aliceWithRole, "SHOW ROLES", "Access Denied: Cannot set role role1"); - assertQueryFails(aliceWithRole, "SHOW CURRENT ROLES", "Access Denied: Cannot set role role1"); - assertQueryFails(aliceWithRole, "SHOW ROLE GRANTS", "Access Denied: Cannot set role role1"); - assertQueryFails(aliceWithRole, "SELECT * FROM system.information_schema.applicable_roles", "Access Denied: Cannot set role role1"); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SHOW ROLES")) + .hasMessageContaining("Access Denied: Cannot set role role1"); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SHOW CURRENT ROLES")) + .hasMessageContaining("Access Denied: Cannot set role role1"); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SHOW ROLE GRANTS")) + .hasMessageContaining("Access Denied: Cannot set role role1"); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SELECT * FROM system.information_schema.applicable_roles")) + .hasMessageContaining("Access Denied: Cannot set role role1"); assertQuerySucceeds("GRANT role1 TO USER alice"); assertQuerySucceeds(alice, "SET ROLE role1"); @@ -117,10 +122,14 @@ public void testRoleGrant() assertQueryReturnsEmptyResult(alice, "SHOW CURRENT ROLES"); assertQueryReturnsEmptyResult(alice, "SHOW ROLE GRANTS"); assertQueryReturnsEmptyResult(alice, "SELECT * FROM system.information_schema.applicable_roles"); - assertQueryFails(aliceWithRole, "SHOW ROLES", "Access Denied: Cannot set role role1"); - assertQueryFails(aliceWithRole, "SHOW CURRENT ROLES", "Access Denied: Cannot set role role1"); - assertQueryFails(aliceWithRole, "SHOW ROLE GRANTS", "Access Denied: Cannot set role role1"); - assertQueryFails(aliceWithRole, "SELECT * FROM system.information_schema.applicable_roles", "Access Denied: Cannot set role role1"); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SHOW ROLES")) + .hasMessageContaining("Access Denied: Cannot set role role1"); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SHOW CURRENT ROLES")) + .hasMessageContaining("Access Denied: Cannot set role role1"); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SHOW ROLE GRANTS")) + .hasMessageContaining("Access Denied: Cannot set role role1"); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SELECT * FROM system.information_schema.applicable_roles")) + .hasMessageContaining("Access Denied: Cannot set role role1"); assertQuerySucceeds("DROP ROLE role1"); } @@ -138,7 +147,8 @@ public void testTransitiveRoleGrant() assertQuerySucceeds("GRANT role1 TO USER alice"); String roleNotApplicableErrorMessage = "Access Denied: Cannot set role role2"; - assertQueryFails(aliceWithRole, "SHOW ROLES", roleNotApplicableErrorMessage); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SHOW ROLES")) + .hasMessageContaining(roleNotApplicableErrorMessage); assertQuerySucceeds("GRANT role2 TO ROLE role1"); assertQuery(alice, "SHOW ROLES", "VALUES 'role1', 'role2'"); @@ -152,7 +162,8 @@ public void testTransitiveRoleGrant() + "('role1', 'ROLE', 'role2', 'NO')"); assertQuerySucceeds("REVOKE role2 FROM ROLE role1"); - assertQueryFails(aliceWithRole, "SHOW ROLES", roleNotApplicableErrorMessage); + assertThatThrownBy(() -> getQueryRunner().execute(aliceWithRole, "SHOW ROLES")) + .hasMessageContaining(roleNotApplicableErrorMessage); assertQuerySucceeds("REVOKE role1 FROM USER alice"); assertQuerySucceeds("DROP ROLE role1");