Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,8 @@ primaryExpression
filter? over? #functionCall
| processingMode? qualifiedName '(' (setQuantifier? expression (',' expression)*)?
orderBy? ')' filter? (nullTreatment? over)? #functionCall
| qualifiedName '::' identifier '(' (expression (',' expression)*)? ')' #staticMethodCall
| primaryExpression '.' identifier '(' (expression (',' expression)*)? ')' #methodCall
| identifier over #measure
| identifier '->' expression #lambda
| '(' (identifier (',' identifier)*)? ')' '->' expression #lambda
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
Expand Down Expand Up @@ -118,7 +119,9 @@ public ResolvedFunction resolveFunction(Session session, QualifiedName name, Lis
session,
name,
parameterTypes,
catalogSchemaFunctionName -> metadata.getFunctions(session, catalogSchemaFunctionName),
catalogSchemaFunctionName -> filterCandidates(
metadata.getFunctions(session, catalogSchemaFunctionName),
candidate -> candidate.functionMetadata().getReceiverType().isEmpty()),
accessControl);

FunctionMetadata functionMetadata = catalogFunctionBinding.boundFunctionMetadata();
Expand All @@ -129,6 +132,69 @@ public ResolvedFunction resolveFunction(Session session, QualifiedName name, Lis
return resolve(session, catalogFunctionBinding, accessControl);
}

public ResolvedFunction resolveStaticMethod(
Session session,
TypeSignature receiverType,
QualifiedName methodName,
List<TypeSignatureProvider> parameterTypes,
AccessControl accessControl)
{
String receiver = receiverType.getBase();
CatalogFunctionBinding catalogFunctionBinding = bindFunction(
session,
methodName,
parameterTypes,
catalogSchemaFunctionName -> filterCandidates(
metadata.getFunctions(session, catalogSchemaFunctionName),
candidate -> !candidate.functionMetadata().isInstanceMethod()
&& candidate.functionMetadata().getReceiverType()
.map(TypeSignature::getBase).equals(Optional.of(receiver))),
accessControl);

FunctionMetadata functionMetadata = catalogFunctionBinding.boundFunctionMetadata();
if (functionMetadata.isDeprecated()) {
warningCollector.add(new TrinoWarning(DEPRECATED_FUNCTION, "Use of deprecated function: %s::%s: %s".formatted(receiverType, methodName, functionMetadata.getDescription())));
}

return resolve(session, catalogFunctionBinding, accessControl);
}

public ResolvedFunction resolveInstanceMethod(
Session session,
TypeSignature receiverType,
QualifiedName methodName,
List<TypeSignatureProvider> parameterTypes,
AccessControl accessControl)
{
String receiver = receiverType.getBase();
CatalogFunctionBinding catalogFunctionBinding = bindFunction(
session,
methodName,
parameterTypes,
catalogSchemaFunctionName -> filterCandidates(
metadata.getFunctions(session, catalogSchemaFunctionName),
candidate -> candidate.functionMetadata().isInstanceMethod()
&& candidate.functionMetadata().getReceiverType()
.map(TypeSignature::getBase).equals(Optional.of(receiver))),
accessControl);

FunctionMetadata functionMetadata = catalogFunctionBinding.boundFunctionMetadata();
if (functionMetadata.isDeprecated()) {
warningCollector.add(new TrinoWarning(DEPRECATED_FUNCTION, "Use of deprecated function: %s.%s: %s".formatted(receiverType, methodName, functionMetadata.getDescription())));
}

return resolve(session, catalogFunctionBinding, accessControl);
}

private static Collection<CatalogFunctionMetadata> filterCandidates(
Collection<CatalogFunctionMetadata> candidates,
Predicate<CatalogFunctionMetadata> predicate)
{
return candidates.stream()
.filter(predicate)
.collect(toImmutableList());
}

private ResolvedFunction resolve(Session session, CatalogFunctionBinding functionBinding, AccessControl accessControl)
{
if (isTrinoSqlLanguageFunction(functionBinding.functionBinding().getFunctionId())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ private static FunctionMetadata createFunctionMetadata(Signature signature, Scal
if (deprecated) {
functionMetadata.deprecated();
}
if (details.isInstanceMethod()) {
checkCondition(!signature.getArgumentTypes().isEmpty(), FUNCTION_IMPLEMENTATION_ERROR, "Instance method %s must declare a self argument", details.getName());
functionMetadata.receiverType(signature.getArgumentTypes().getFirst());
functionMetadata.instanceMethod();
}
else {
details.getReceiverType().ifPresent(functionMetadata::receiverType);
}

if (functionNullability.isReturnNullable()) {
functionMetadata.nullable();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.spi.function.InstanceMethod;
import io.trino.spi.function.OperatorType;
import io.trino.spi.function.ScalarFunction;
import io.trino.spi.function.ScalarOperator;
import io.trino.spi.function.StaticMethod;
import io.trino.spi.type.TypeSignature;

import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
Expand All @@ -30,6 +33,7 @@
import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.metadata.OperatorNameUtil.mangleOperatorName;
import static io.trino.operator.annotations.FunctionsParserHelper.parseDescription;
import static io.trino.sql.analyzer.TypeSignatureTranslator.parseTypeSignature;
import static java.util.Objects.requireNonNull;

public class ScalarHeader
Expand All @@ -41,8 +45,15 @@ public class ScalarHeader
private final boolean hidden;
private final boolean deterministic;
private final boolean neverFails;
private final Optional<TypeSignature> receiverType;
private final boolean instanceMethod;

public ScalarHeader(String name, Set<String> aliases, Optional<String> description, boolean hidden, boolean deterministic, boolean neverFails)
{
this(name, aliases, description, hidden, deterministic, neverFails, Optional.empty(), false);
}

public ScalarHeader(String name, Set<String> aliases, Optional<String> description, boolean hidden, boolean deterministic, boolean neverFails, Optional<TypeSignature> receiverType, boolean instanceMethod)
{
this.name = requireNonNull(name, "name is null");
checkArgument(!name.isEmpty());
Expand All @@ -53,6 +64,9 @@ public ScalarHeader(String name, Set<String> aliases, Optional<String> descripti
this.hidden = hidden;
this.deterministic = deterministic;
this.neverFails = neverFails;
this.receiverType = requireNonNull(receiverType, "receiverType is null");
checkArgument(!instanceMethod || receiverType.isEmpty(), "instance method receiver type is inferred from the first argument");
this.instanceMethod = instanceMethod;
}

public ScalarHeader(OperatorType operatorType, Optional<String> description)
Expand All @@ -64,19 +78,36 @@ public ScalarHeader(OperatorType operatorType, Optional<String> description)
this.hidden = true;
this.deterministic = true;
this.neverFails = false;
this.receiverType = Optional.empty();
this.instanceMethod = false;
}

public static List<ScalarHeader> fromAnnotatedElement(AnnotatedElement annotated)
{
ScalarFunction scalarFunction = annotated.getAnnotation(ScalarFunction.class);
ScalarOperator scalarOperator = annotated.getAnnotation(ScalarOperator.class);
StaticMethod staticMethod = annotated.getAnnotation(StaticMethod.class);
InstanceMethod instanceMethod = annotated.getAnnotation(InstanceMethod.class);
Optional<String> description = parseDescription(annotated);

ImmutableList.Builder<ScalarHeader> builder = ImmutableList.builder();

if (scalarFunction != null) {
checkArgument(staticMethod == null || instanceMethod == null, "@StaticMethod and @InstanceMethod are mutually exclusive on %s", annotated);
String baseName = scalarFunction.value().isEmpty() ? camelToSnake(annotatedName(annotated)) : scalarFunction.value();
builder.add(new ScalarHeader(baseName, ImmutableSet.copyOf(scalarFunction.alias()), description, scalarFunction.hidden(), scalarFunction.deterministic(), scalarFunction.neverFails()));
Optional<TypeSignature> receiverType = Optional.empty();
if (staticMethod != null) {
TypeSignature parsed = parseTypeSignature(staticMethod.value(), ImmutableSet.of());
checkArgument(parsed.getParameters().isEmpty(), "@StaticMethod receiver type must not have parameters: %s", staticMethod.value());
receiverType = Optional.of(parsed);
}
builder.add(new ScalarHeader(baseName, ImmutableSet.copyOf(scalarFunction.alias()), description, scalarFunction.hidden(), scalarFunction.deterministic(), scalarFunction.neverFails(), receiverType, instanceMethod != null));
}
else if (staticMethod != null) {
throw new IllegalArgumentException("@StaticMethod requires @ScalarFunction on " + annotated);
}
else if (instanceMethod != null) {
throw new IllegalArgumentException("@InstanceMethod requires @ScalarFunction on " + annotated);
}

if (scalarOperator != null) {
Expand Down Expand Up @@ -139,4 +170,14 @@ public boolean neverFails()
{
return neverFails;
}

public Optional<TypeSignature> getReceiverType()
{
return receiverType;
}

public boolean isInstanceMethod()
{
return instanceMethod;
}
}
11 changes: 11 additions & 0 deletions core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ public class Analysis
private final Map<NodeRef<Expression>, ResolvedFunction> frameBoundCalculations = new LinkedHashMap<>();
private final Map<NodeRef<Relation>, List<Type>> relationCoercions = new LinkedHashMap<>();
private final Map<NodeRef<Node>, RoutineEntry> resolvedFunctions = new LinkedHashMap<>();
private final Map<NodeRef<FunctionCall>, Identifier> methodCallReceivers = new LinkedHashMap<>();
private final Map<NodeRef<Identifier>, LambdaArgumentDeclaration> lambdaArgumentReferences = new LinkedHashMap<>();

private final Map<Field, ColumnHandle> columns = new LinkedHashMap<>();
Expand Down Expand Up @@ -720,6 +721,16 @@ public void addResolvedFunction(Node node, ResolvedFunction function, String aut
resolvedFunctions.put(NodeRef.of(node), new RoutineEntry(function, authorization));
}

public void addMethodCallReceiver(FunctionCall node, Identifier receiver)
{
methodCallReceivers.put(NodeRef.of(node), receiver);
}

public Optional<Identifier> getMethodCallReceiver(FunctionCall node)
{
return Optional.ofNullable(methodCallReceivers.get(NodeRef.of(node)));
}

public Set<NodeRef<Expression>> getColumnReferences()
{
return unmodifiableSet(columnReferences.keySet());
Expand Down
Loading