Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions client/trino-cli/src/main/java/io/trino/cli/Console.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import java.nio.file.Paths;
import java.util.AbstractMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -399,6 +400,13 @@ private static boolean process(
builder = builder.roles(ImmutableMap.of());
}

// update session originalRoles
if (!query.getSetOriginalRoles().isEmpty()) {
Set<ClientSelectedRole> originalRoles = new HashSet<>(session.getOriginalRoles());
originalRoles.addAll(query.getSetOriginalRoles());
builder = builder.originalRoles(originalRoles);
}

if (query.isResetAuthorizationUser()) {
builder = builder.authorizationUser(Optional.empty());
builder = builder.roles(ImmutableMap.of());
Expand Down
5 changes: 5 additions & 0 deletions client/trino-cli/src/main/java/io/trino/cli/Query.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ public boolean isResetAuthorizationUser()
return client.isResetAuthorizationUser();
}

public Set<ClientSelectedRole> getSetOriginalRoles()
{
return client.getSetOriginalRoles();
}

public Map<String, String> getSetSessionProperties()
{
return client.getSetSessionProperties();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class ClientSession
private final Optional<String> user;
private final Optional<String> sessionUser;
private final Optional<String> authorizationUser;
private final Set<ClientSelectedRole> originalRoles;
private final String source;
private final Optional<String> traceToken;
private final Set<String> clientTags;
Expand Down Expand Up @@ -80,6 +81,7 @@ private ClientSession(
Optional<String> user,
Optional<String> sessionUser,
Optional<String> authorizationUser,
Set<ClientSelectedRole> originalRoles,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we pass original user info?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Via sessionUser AFAIK, see

this.originalUser = Stream.of(session.getSessionUser(), session.getUser())

String source,
Optional<String> traceToken,
Set<String> clientTags,
Expand All @@ -103,6 +105,7 @@ private ClientSession(
this.user = requireNonNull(user, "user is null");
this.sessionUser = requireNonNull(sessionUser, "sessionUser is null");
this.authorizationUser = requireNonNull(authorizationUser, "authorizationUser is null");
this.originalRoles = ImmutableSet.copyOf(requireNonNull(originalRoles, "originalRoles is null"));
this.source = requireNonNull(source, "source is null");
this.traceToken = requireNonNull(traceToken, "traceToken is null");
this.clientTags = ImmutableSet.copyOf(requireNonNull(clientTags, "clientTags is null"));
Expand Down Expand Up @@ -171,6 +174,11 @@ public Optional<String> getAuthorizationUser()
return authorizationUser;
}

public Set<ClientSelectedRole> getOriginalRoles()
{
return originalRoles;
}

public String getSource()
{
return source;
Expand Down Expand Up @@ -302,6 +310,7 @@ public static final class Builder
private Optional<String> user = Optional.empty();
private Optional<String> sessionUser = Optional.empty();
private Optional<String> authorizationUser = Optional.empty();
private Set<ClientSelectedRole> originalRoles = ImmutableSet.of();
private String source;
private Optional<String> traceToken = Optional.empty();
private Set<String> clientTags = ImmutableSet.of();
Expand Down Expand Up @@ -330,6 +339,7 @@ private Builder(ClientSession clientSession)
user = clientSession.getUser();
sessionUser = clientSession.getSessionUser();
authorizationUser = clientSession.getAuthorizationUser();
originalRoles = clientSession.getOriginalRoles();
source = clientSession.getSource();
traceToken = clientSession.getTraceToken();
clientTags = clientSession.getClientTags();
Expand Down Expand Up @@ -374,6 +384,12 @@ public Builder authorizationUser(Optional<String> authorizationUser)
return this;
}

public Builder originalRoles(Set<ClientSelectedRole> originalRoles)
{
this.originalRoles = originalRoles;
return this;
}

public Builder source(String source)
{
this.source = source;
Expand Down Expand Up @@ -489,6 +505,7 @@ public ClientSession build()
user,
sessionUser,
authorizationUser,
originalRoles,
source,
traceToken,
clientTags,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public final class ProtocolHeaders
private final String name;
private final String requestUser;
private final String requestOriginalUser;
private final String requestOriginalRole;
private final String requestSource;
private final String requestCatalog;
private final String requestSchema;
Expand Down Expand Up @@ -57,6 +58,7 @@ public final class ProtocolHeaders
private final String responseClearTransactionId;
private final String responseSetAuthorizationUser;
private final String responseResetAuthorizationUser;
private final String responseOriginalRole;

public static ProtocolHeaders createProtocolHeaders(String name)
{
Expand All @@ -75,6 +77,7 @@ private ProtocolHeaders(String name)
String prefix = "X-" + name + "-";
requestUser = prefix + "User";
requestOriginalUser = prefix + "Original-User";
requestOriginalRole = prefix + "Original-Roles";
requestSource = prefix + "Source";
requestCatalog = prefix + "Catalog";
requestSchema = prefix + "Schema";
Expand Down Expand Up @@ -105,6 +108,7 @@ private ProtocolHeaders(String name)
responseClearTransactionId = prefix + "Clear-Transaction-Id";
responseSetAuthorizationUser = prefix + "Set-Authorization-User";
responseResetAuthorizationUser = prefix + "Reset-Authorization-User";
responseOriginalRole = prefix + "Set-Original-Roles";
}

public String getProtocolName()
Expand All @@ -122,6 +126,11 @@ public String requestOriginalUser()
return requestOriginalUser;
}

public String requestOriginalRole()
{
return requestOriginalRole;
}

public String requestSource()
{
return requestSource;
Expand Down Expand Up @@ -272,6 +281,11 @@ public String responseResetAuthorizationUser()
return responseResetAuthorizationUser;
}

public String responseOriginalRole()
{
return responseOriginalRole;
}

public static ProtocolHeaders detectProtocol(Optional<String> alternateHeaderName, Set<String> headerNames)
throws ProtocolDetectionException
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ default Optional<String> getEncoding()

boolean isResetAuthorizationUser();

Set<ClientSelectedRole> getSetOriginalRoles();

Map<String, String> getSetSessionProperties();

Set<String> getResetSessionProperties();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class StatementClientV1
private final AtomicReference<List<String>> setPath = new AtomicReference<>();
private final AtomicReference<String> setAuthorizationUser = new AtomicReference<>();
private final AtomicBoolean resetAuthorizationUser = new AtomicBoolean();
private final Set<ClientSelectedRole> setOriginalRoles = Sets.newConcurrentHashSet();
private final Map<String, String> setSessionProperties = new ConcurrentHashMap<>();
private final Set<String> resetSessionProperties = Sets.newConcurrentHashSet();
private final Map<String, ClientSelectedRole> setRoles = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -125,6 +126,7 @@ public StatementClientV1(Call.Factory httpCallFactory, Call.Factory segmentHttpC
.filter(Optional::isPresent)
.map(Optional::get)
.findFirst();
this.setOriginalRoles.addAll(session.getOriginalRoles());
this.clientCapabilities = Joiner.on(",").join(clientCapabilities.orElseGet(() -> stream(ClientCapabilities.values())
.map(Enum::name)
.collect(toImmutableSet())));
Expand Down Expand Up @@ -181,6 +183,10 @@ private Request buildQueryRequest(ClientSession session, String query, Optional<
builder.addHeader(TRINO_HEADERS.requestResourceEstimate(), entry.getKey() + "=" + urlEncode(entry.getValue()));
}

for (ClientSelectedRole selectedRole : session.getOriginalRoles()) {
builder.addHeader(TRINO_HEADERS.requestOriginalRole(), selectedRole.toString());
}

Map<String, ClientSelectedRole> roles = session.getRoles();
for (Entry<String, ClientSelectedRole> entry : roles.entrySet()) {
builder.addHeader(TRINO_HEADERS.requestRole(), entry.getKey() + '=' + urlEncode(entry.getValue().toString()));
Expand Down Expand Up @@ -316,6 +322,12 @@ public boolean isResetAuthorizationUser()
return resetAuthorizationUser.get();
}

@Override
public Set<ClientSelectedRole> getSetOriginalRoles()
{
return ImmutableSet.copyOf(setOriginalRoles);
}

@Override
public Map<String, String> getSetSessionProperties()
{
Expand Down Expand Up @@ -475,6 +487,11 @@ private void processResponse(Headers headers, QueryResults results)
this.resetAuthorizationUser.set(Boolean.parseBoolean(resetAuthorizationUser));
}

setOriginalRoles.addAll(headers.values(TRINO_HEADERS.responseOriginalRole())
.stream()
.map(role -> ClientSelectedRole.valueOf(urlDecode(role)))
.collect(toImmutableSet()));

for (String setSession : headers.values(TRINO_HEADERS.responseSetSession())) {
List<String> keyValue = COLLECTION_HEADER_SPLITTER.splitToList(setSession);
if (keyValue.size() != 2) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.google.common.primitives.Ints;
import io.airlift.units.Duration;
import io.trino.client.ClientSelectedRole;
Expand Down Expand Up @@ -107,6 +108,7 @@ public class TrinoConnection
private final AtomicReference<String> schema = new AtomicReference<>();
private final AtomicReference<List<String>> path = new AtomicReference<>(ImmutableList.of());
private final AtomicReference<String> authorizationUser = new AtomicReference<>();
private final Set<ClientSelectedRole> originalRoles = Sets.newConcurrentHashSet();
private final AtomicReference<ZoneId> timeZoneId = new AtomicReference<>();
private final AtomicReference<Locale> locale = new AtomicReference<>();
private final AtomicReference<Integer> networkTimeoutMillis = new AtomicReference<>(Ints.saturatedCast(MINUTES.toMillis(2)));
Expand Down Expand Up @@ -855,6 +857,7 @@ StatementClient startQuery(String sql, Map<String, String> sessionPropertiesOver
.user(user)
.sessionUser(sessionUser.get())
.authorizationUser(Optional.ofNullable(authorizationUser.get()))
.originalRoles(ImmutableSet.copyOf(originalRoles))
.source(source)
.traceToken(Optional.ofNullable(clientInfo.get(TRACE_TOKEN)))
.clientTags(ImmutableSet.copyOf(clientTags))
Expand Down Expand Up @@ -891,10 +894,12 @@ void updateSession(StatementClient client)

if (client.getSetAuthorizationUser().isPresent()) {
authorizationUser.set(client.getSetAuthorizationUser().get());
originalRoles.addAll(client.getSetOriginalRoles());
roles.clear();
}
if (client.isResetAuthorizationUser()) {
authorizationUser.set(null);
originalRoles.clear();
roles.clear();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1241,7 +1241,7 @@ private static Properties toProperties(Map<String, String> map)
return properties;
}

private static String getCurrentUser(Connection connection)
public static String getCurrentUser(Connection connection)
throws SQLException
{
try (Statement statement = connection.createStatement();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,12 @@ public boolean isResetAuthorizationUser()
throw new UnsupportedOperationException();
}

@Override
public Set<ClientSelectedRole> getSetOriginalRoles()
{
throw new UnsupportedOperationException();
}

@Override
public Map<String, String> getSetSessionProperties()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,12 @@ public boolean isResetAuthorizationUser()
return true;
}

@Override
public Set<ClientSelectedRole> getSetOriginalRoles()
{
return Set.of();
}

@Override
public Map<String, String> getSetSessionProperties()
{
Expand Down
3 changes: 2 additions & 1 deletion core/trino-main/src/main/java/io/trino/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,8 @@ public SessionRepresentation toSessionRepresentation()
clientTransactionSupport,
identity.getUser(),
originalIdentity.getUser(),
identity.getGroups(),
originalIdentity.getEnabledRoles(),
originalIdentity.getGroups(),
originalIdentity.getGroups(),
identity.getPrincipal().map(Principal::toString),
identity.getEnabledRoles(),
Expand Down
10 changes: 10 additions & 0 deletions core/trino-main/src/main/java/io/trino/SessionRepresentation.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public final class SessionRepresentation
private final boolean clientTransactionSupport;
private final String user;
private final String originalUser;
private final Set<String> originalRoles;
private final Set<String> groups;
private final Set<String> originalUserGroups;
private final Optional<String> principal;
Expand Down Expand Up @@ -81,6 +82,7 @@ public SessionRepresentation(
@JsonProperty("clientTransactionSupport") boolean clientTransactionSupport,
@JsonProperty("user") String user,
@JsonProperty("originalUser") String originalUser,
@JsonProperty("setOriginalRoles") Set<String> originalRoles,
@JsonProperty("groups") Set<String> groups,
@JsonProperty("originalUserGroups") Set<String> originalUserGroups,
@JsonProperty("principal") Optional<String> principal,
Expand Down Expand Up @@ -112,6 +114,7 @@ public SessionRepresentation(
this.clientTransactionSupport = clientTransactionSupport;
this.user = requireNonNull(user, "user is null");
this.originalUser = requireNonNull(originalUser, "originalUser is null");
this.originalRoles = requireNonNull(originalRoles, "setOriginalRoles is null");
this.groups = requireNonNull(groups, "groups is null");
this.originalUserGroups = requireNonNull(originalUserGroups, "originalUserGroups is null");
this.principal = requireNonNull(principal, "principal is null");
Expand Down Expand Up @@ -179,6 +182,12 @@ public String getOriginalUser()
return originalUser;
}

@JsonProperty
public Set<String> getOriginalRoles()
{
return originalRoles;
}

@JsonProperty
public Set<String> getGroups()
{
Expand Down Expand Up @@ -350,6 +359,7 @@ public Identity toOriginalIdentity(Map<String, String> extraCredentials)
return Identity.forUser(originalUser)
.withGroups(originalUserGroups)
.withPrincipal(principal.map(BasicPrincipal::new))
.withEnabledRoles(originalRoles)
.withExtraCredentials(extraCredentials)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ private static QueryInfo immediateFailureQueryInfo(
Optional.empty(),
Optional.empty(),
false,
ImmutableSet.of(),
ImmutableMap.of(),
ImmutableSet.of(),
ImmutableMap.of(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ private QueryContext createQueryContext(SessionRepresentation session, Optional<
return new QueryContext(
session.getUser(),
session.getOriginalUser(),
session.getOriginalRoles(),
session.getPrincipal(),
session.getEnabledRoles(),
session.getGroups(),
Expand Down
Loading
Loading