diff --git a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 index b840a8fa633c..881b8ed0dff3 100644 --- a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 +++ b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 @@ -606,6 +606,7 @@ primaryExpression | processingMode? qualifiedName '(' (setQuantifier? expression (',' expression)*)? orderBy? ')' filter? (nullTreatment? over)? #functionCall | qualifiedName '::' identifier '(' (expression (',' expression)*)? ')' #staticMethodCall + | primaryExpression '.' identifier '(' (expression (',' expression)*)? ')' #methodCall | identifier over #measure | identifier '->' expression #lambda | '(' (identifier (',' identifier)*)? ')' '->' expression #lambda diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java index ac4afe12e528..f8c648635144 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java @@ -146,7 +146,8 @@ public ResolvedFunction resolveStaticMethod( parameterTypes, catalogSchemaFunctionName -> filterCandidates( metadata.getFunctions(session, catalogSchemaFunctionName), - candidate -> candidate.functionMetadata().getReceiverType() + candidate -> !candidate.functionMetadata().isInstanceMethod() + && candidate.functionMetadata().getReceiverType() .map(TypeSignature::getBase).equals(Optional.of(receiver))), accessControl); @@ -158,6 +159,33 @@ public ResolvedFunction resolveStaticMethod( return resolve(session, catalogFunctionBinding, accessControl); } + public ResolvedFunction resolveInstanceMethod( + Session session, + TypeSignature receiverType, + QualifiedName methodName, + List parameterTypes, + AccessControl accessControl) + { + String receiver = receiverType.getBase(); + CatalogFunctionBinding catalogFunctionBinding = bindFunction( + session, + methodName, + parameterTypes, + catalogSchemaFunctionName -> filterCandidates( + metadata.getFunctions(session, catalogSchemaFunctionName), + candidate -> candidate.functionMetadata().isInstanceMethod() + && candidate.functionMetadata().getReceiverType() + .map(TypeSignature::getBase).equals(Optional.of(receiver))), + accessControl); + + FunctionMetadata functionMetadata = catalogFunctionBinding.boundFunctionMetadata(); + if (functionMetadata.isDeprecated()) { + warningCollector.add(new TrinoWarning(DEPRECATED_FUNCTION, "Use of deprecated function: %s.%s: %s".formatted(receiverType, methodName, functionMetadata.getDescription()))); + } + + return resolve(session, catalogFunctionBinding, accessControl); + } + private static Collection filterCandidates( Collection candidates, Predicate predicate) diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ParametricScalar.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ParametricScalar.java index da6b20bc1464..0bba2c120db7 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ParametricScalar.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ParametricScalar.java @@ -81,7 +81,14 @@ private static FunctionMetadata createFunctionMetadata(Signature signature, Scal if (deprecated) { functionMetadata.deprecated(); } - details.getReceiverType().ifPresent(functionMetadata::receiverType); + if (details.isInstanceMethod()) { + checkCondition(!signature.getArgumentTypes().isEmpty(), FUNCTION_IMPLEMENTATION_ERROR, "Instance method %s must declare a self argument", details.getName()); + functionMetadata.receiverType(signature.getArgumentTypes().getFirst()); + functionMetadata.instanceMethod(); + } + else { + details.getReceiverType().ifPresent(functionMetadata::receiverType); + } if (functionNullability.isReturnNullable()) { functionMetadata.nullable(); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ScalarHeader.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ScalarHeader.java index 0b9fca4900b5..b00deb33d947 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ScalarHeader.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ScalarHeader.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import io.trino.spi.function.InstanceMethod; import io.trino.spi.function.OperatorType; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.ScalarOperator; @@ -45,13 +46,14 @@ public class ScalarHeader private final boolean deterministic; private final boolean neverFails; private final Optional receiverType; + private final boolean instanceMethod; public ScalarHeader(String name, Set aliases, Optional description, boolean hidden, boolean deterministic, boolean neverFails) { - this(name, aliases, description, hidden, deterministic, neverFails, Optional.empty()); + this(name, aliases, description, hidden, deterministic, neverFails, Optional.empty(), false); } - public ScalarHeader(String name, Set aliases, Optional description, boolean hidden, boolean deterministic, boolean neverFails, Optional receiverType) + public ScalarHeader(String name, Set aliases, Optional description, boolean hidden, boolean deterministic, boolean neverFails, Optional receiverType, boolean instanceMethod) { this.name = requireNonNull(name, "name is null"); checkArgument(!name.isEmpty()); @@ -63,6 +65,8 @@ public ScalarHeader(String name, Set aliases, Optional descripti this.deterministic = deterministic; this.neverFails = neverFails; this.receiverType = requireNonNull(receiverType, "receiverType is null"); + checkArgument(!instanceMethod || receiverType.isEmpty(), "instance method receiver type is inferred from the first argument"); + this.instanceMethod = instanceMethod; } public ScalarHeader(OperatorType operatorType, Optional description) @@ -75,6 +79,7 @@ public ScalarHeader(OperatorType operatorType, Optional description) this.deterministic = true; this.neverFails = false; this.receiverType = Optional.empty(); + this.instanceMethod = false; } public static List fromAnnotatedElement(AnnotatedElement annotated) @@ -82,11 +87,13 @@ public static List fromAnnotatedElement(AnnotatedElement annotated ScalarFunction scalarFunction = annotated.getAnnotation(ScalarFunction.class); ScalarOperator scalarOperator = annotated.getAnnotation(ScalarOperator.class); StaticMethod staticMethod = annotated.getAnnotation(StaticMethod.class); + InstanceMethod instanceMethod = annotated.getAnnotation(InstanceMethod.class); Optional description = parseDescription(annotated); ImmutableList.Builder builder = ImmutableList.builder(); if (scalarFunction != null) { + checkArgument(staticMethod == null || instanceMethod == null, "@StaticMethod and @InstanceMethod are mutually exclusive on %s", annotated); String baseName = scalarFunction.value().isEmpty() ? camelToSnake(annotatedName(annotated)) : scalarFunction.value(); Optional receiverType = Optional.empty(); if (staticMethod != null) { @@ -94,11 +101,14 @@ public static List fromAnnotatedElement(AnnotatedElement annotated checkArgument(parsed.getParameters().isEmpty(), "@StaticMethod receiver type must not have parameters: %s", staticMethod.value()); receiverType = Optional.of(parsed); } - builder.add(new ScalarHeader(baseName, ImmutableSet.copyOf(scalarFunction.alias()), description, scalarFunction.hidden(), scalarFunction.deterministic(), scalarFunction.neverFails(), receiverType)); + builder.add(new ScalarHeader(baseName, ImmutableSet.copyOf(scalarFunction.alias()), description, scalarFunction.hidden(), scalarFunction.deterministic(), scalarFunction.neverFails(), receiverType, instanceMethod != null)); } else if (staticMethod != null) { throw new IllegalArgumentException("@StaticMethod requires @ScalarFunction on " + annotated); } + else if (instanceMethod != null) { + throw new IllegalArgumentException("@InstanceMethod requires @ScalarFunction on " + annotated); + } if (scalarOperator != null) { builder.add(new ScalarHeader(scalarOperator.value(), description)); @@ -165,4 +175,9 @@ public Optional getReceiverType() { return receiverType; } + + public boolean isInstanceMethod() + { + return instanceMethod; + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index e40a05920abb..9a9d607380fd 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -225,6 +225,7 @@ public class Analysis private final Map, ResolvedFunction> frameBoundCalculations = new LinkedHashMap<>(); private final Map, List> relationCoercions = new LinkedHashMap<>(); private final Map, RoutineEntry> resolvedFunctions = new LinkedHashMap<>(); + private final Map, Identifier> methodCallReceivers = new LinkedHashMap<>(); private final Map, LambdaArgumentDeclaration> lambdaArgumentReferences = new LinkedHashMap<>(); private final Map columns = new LinkedHashMap<>(); @@ -720,6 +721,16 @@ public void addResolvedFunction(Node node, ResolvedFunction function, String aut resolvedFunctions.put(NodeRef.of(node), new RoutineEntry(function, authorization)); } + public void addMethodCallReceiver(FunctionCall node, Identifier receiver) + { + methodCallReceivers.put(NodeRef.of(node), receiver); + } + + public Optional getMethodCallReceiver(FunctionCall node) + { + return Optional.ofNullable(methodCallReceivers.get(NodeRef.of(node))); + } + public Set> getColumnReferences() { return unmodifiableSet(columnReferences.keySet()); diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java index a8ee095a9530..bf407dce7737 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java @@ -129,6 +129,7 @@ import io.trino.sql.tree.LogicalExpression; import io.trino.sql.tree.LongLiteral; import io.trino.sql.tree.MeasureDefinition; +import io.trino.sql.tree.MethodCall; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.NotExpression; @@ -329,6 +330,7 @@ public class ExpressionAnalyzer private final Cache varcharCastableTypeCache = buildNonEvictableCache(CacheBuilder.newBuilder().maximumSize(1000)); private final Map, ResolvedFunction> resolvedFunctions = new LinkedHashMap<>(); + private final Map, Identifier> methodCallReceivers = new LinkedHashMap<>(); private final Set> subqueries = new LinkedHashSet<>(); private final Set> existsSubqueries = new LinkedHashSet<>(); private final Map, Type> expressionCoercions = new LinkedHashMap<>(); @@ -437,6 +439,11 @@ public Map, ResolvedFunction> getResolvedFunctions() return unmodifiableMap(resolvedFunctions); } + public Map, Identifier> getMethodCallReceivers() + { + return unmodifiableMap(methodCallReceivers); + } + public Map, Type> getExpressionTypes() { return unmodifiableMap(expressionTypes); @@ -1309,6 +1316,14 @@ protected Type visitNullLiteral(NullLiteral node, Context context) @Override protected Type visitFunctionCall(FunctionCall node, Context context) { + // SQL:2023 6.3 Syntax Rule 2: a non-parenthesized value expression primary + // of the form A.B(args) is treated as a method invocation if it satisfies + // the rules for one; otherwise it is a routine invocation. + Optional asMethod = tryResolveAsInstanceMethod(node, context); + if (asMethod.isPresent()) { + return asMethod.get(); + } + boolean isAggregation = functionResolver.isAggregationFunction(session, node.getName(), accessControl); boolean isRowPatternCount = context.isPatternRecognition() && isAggregation && @@ -1460,6 +1475,123 @@ else if (isAggregation) { return setExpressionType(node, type); } + private Optional tryResolveAsInstanceMethod(FunctionCall node, Context context) + { + QualifiedName name = node.getName(); + if (name.getParts().size() != 2) { + return Optional.empty(); + } + if (node.isDistinct() + || node.getFilter().isPresent() + || node.getOrderBy().isPresent() + || node.getWindow().isPresent() + || node.getProcessingMode().isPresent() + || node.getNullTreatment().isPresent()) { + return Optional.empty(); + } + if (context.isPatternRecognition() || context.isInWindow()) { + return Optional.empty(); + } + + Identifier receiver = name.getOriginalParts().get(0); + Identifier method = name.getOriginalParts().get(1); + + // Method-call interpretation only applies when the receiver resolves + // as a field in the current scope. tryResolveField has no side effects + // so a non-match leaves the analyzer state untouched. + Optional resolvedReceiver = context.getScope() + .tryResolveField(receiver, QualifiedName.of(receiver.getValue())); + if (resolvedReceiver.isEmpty()) { + return Optional.empty(); + } + Type receiverType = resolvedReceiver.get().getField().getType(); + + MethodResolution resolution; + try { + resolution = resolveInstanceMethodCall(receiverType, method.getValue(), node.getArguments(), context); + } + catch (TrinoException e) { + return Optional.empty(); + } + + // Commit to method-call interpretation: record the receiver field reference. + process(receiver, context); + + Type result = analyzeInstanceMethodInvocation(node, receiver, receiverType, method.getValue(), node.getArguments(), resolution, context); + methodCallReceivers.put(NodeRef.of(node), receiver); + return Optional.of(result); + } + + @Override + protected Type visitMethodCall(MethodCall node, Context context) + { + Type receiverType = process(node.getReceiver(), context); + String methodName = node.getMethod().getValue(); + + MethodResolution resolution; + try { + resolution = resolveInstanceMethodCall(receiverType, methodName, node.getArguments(), context); + } + catch (TrinoException e) { + if (e.getLocation().isPresent()) { + throw e; + } + throw new TrinoException(e::getErrorCode, extractLocation(node), e.getMessage(), e); + } + + return analyzeInstanceMethodInvocation(node, node.getReceiver(), receiverType, methodName, node.getArguments(), resolution, context); + } + + private MethodResolution resolveInstanceMethodCall(Type receiverType, String methodName, List arguments, Context context) + { + List argumentTypes = ImmutableList.builder() + .add(new TypeSignatureProvider(receiverType.getTypeSignature())) + .addAll(getCallArgumentTypes(arguments, context)) + .build(); + ResolvedFunction function = functionResolver.resolveInstanceMethod( + session, + receiverType.getTypeSignature(), + QualifiedName.of(methodName), + argumentTypes, + accessControl); + return new MethodResolution(function, argumentTypes); + } + + private Type analyzeInstanceMethodInvocation( + Expression node, + Expression receiver, + Type receiverType, + String methodName, + List arguments, + MethodResolution resolution, + Context context) + { + if (arguments.size() + 1 > 127) { + throw semanticException(TOO_MANY_ARGUMENTS, node, "Too many arguments for method call .%s()", methodName); + } + + BoundSignature signature = resolution.function().signature(); + Type expectedReceiverType = signature.getArgumentTypes().getFirst(); + coerceType(receiver, receiverType, expectedReceiverType, format("Method .%s receiver", methodName)); + // Slot 0 of the signature is the receiver (self), so user-visible argument i maps to signature slot i + 1. + for (int i = 0; i < arguments.size(); i++) { + Expression expression = arguments.get(i); + Type expectedType = signature.getArgumentTypes().get(i + 1); + if (resolution.argumentTypes().get(i + 1).hasDependency()) { + FunctionType expectedFunctionType = (FunctionType) expectedType; + process(expression, context.expectingLambda(expectedFunctionType.getArgumentTypes())); + } + else { + Type actualType = plannerContext.getTypeManager().getType(resolution.argumentTypes().get(i + 1).getTypeSignature()); + coerceType(expression, actualType, expectedType, format("Method .%s argument %d", methodName, i)); + } + } + resolvedFunctions.put(NodeRef.of(node), resolution.function()); + return setExpressionType(node, signature.getReturnType()); + } + + private record MethodResolution(ResolvedFunction function, List argumentTypes) {} + @Override protected Type visitStaticMethodCall(StaticMethodCall node, Context context) { @@ -3948,6 +4080,7 @@ private static void updateAnalysis(Analysis analysis, ExpressionAnalyzer analyze analyzer.getSortKeyCoercionsForFrameBoundComparison()); analysis.addFrameBoundCalculations(analyzer.getFrameBoundCalculations()); analyzer.getResolvedFunctions().forEach((key, value) -> analysis.addResolvedFunction(key.getNode(), value, session.getUser())); + analyzer.getMethodCallReceivers().forEach((key, value) -> analysis.addMethodCallReceiver(key.getNode(), value)); analysis.addColumnReferences(analyzer.getColumnReferences()); analysis.addLambdaArgumentReferences(analyzer.getLambdaArgumentReferences()); analysis.addTableColumnReferences(accessControl, session.getIdentity(), analyzer.getTableColumnReferences()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java index 1cb3b0f32cbb..1de953de58e4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java @@ -104,6 +104,7 @@ import io.trino.sql.tree.LocalTimestamp; import io.trino.sql.tree.LogicalExpression; import io.trino.sql.tree.LongLiteral; +import io.trino.sql.tree.MethodCall; import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.NotExpression; import io.trino.sql.tree.NullIfExpression; @@ -322,6 +323,7 @@ private io.trino.sql.ir.Expression translate(Expression expr, boolean isRoot) case Identifier expression -> translate(expression); case FunctionCall expression -> translate(expression); case StaticMethodCall expression -> translate(expression); + case MethodCall expression -> translate(expression); case DereferenceExpression expression -> translate(expression); case Array expression -> translate(expression); case CurrentCatalog expression -> translate(expression); @@ -677,6 +679,18 @@ private io.trino.sql.ir.Expression translate(FunctionCall expression) Optional resolvedFunction = analysis.getResolvedFunction(expression); checkArgument(resolvedFunction.isPresent(), "Function has not been analyzed: %s", expression); + Optional methodReceiver = analysis.getMethodCallReceiver(expression); + if (methodReceiver.isPresent()) { + return new Call( + resolvedFunction.get(), + ImmutableList.builder() + .add(translateExpression(methodReceiver.get())) + .addAll(expression.getArguments().stream() + .map(this::translateExpression) + .collect(toImmutableList())) + .build()); + } + return new Call( resolvedFunction.get(), expression.getArguments().stream() @@ -696,6 +710,21 @@ private io.trino.sql.ir.Expression translate(StaticMethodCall expression) .collect(toImmutableList())); } + private io.trino.sql.ir.Expression translate(MethodCall expression) + { + Optional resolvedFunction = analysis.getResolvedFunction(expression); + checkArgument(resolvedFunction.isPresent(), "Method has not been analyzed: %s", expression); + + return new Call( + resolvedFunction.get(), + ImmutableList.builder() + .add(translateExpression(expression.getReceiver())) + .addAll(expression.getArguments().stream() + .map(this::translateExpression) + .collect(toImmutableList())) + .build()); + } + private io.trino.sql.ir.Expression translate(DereferenceExpression expression) { if (analysis.isColumnReference(expression)) { diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestMethodCall.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestMethodCall.java new file mode 100644 index 000000000000..3d6b06e44fc0 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestMethodCall.java @@ -0,0 +1,223 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.scalar; + +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.metadata.InternalFunctionBundle; +import io.trino.spi.function.InstanceMethod; +import io.trino.spi.function.ScalarFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.StaticMethod; +import io.trino.spi.type.StandardTypes; +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; +import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_FOUND; +import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; + +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) +public class TestMethodCall +{ + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + assertions.addFunctions(InternalFunctionBundle.builder() + .scalars(getClass()) + .build()); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } + + @ScalarFunction("char_length") + @InstanceMethod + @SqlType(StandardTypes.BIGINT) + public static long varcharCharLength(@SqlType(StandardTypes.VARCHAR) Slice self) + { + return self.toStringUtf8().length(); + } + + @ScalarFunction("repeat") + @InstanceMethod + @SqlType(StandardTypes.VARCHAR) + public static Slice varcharRepeat(@SqlType(StandardTypes.VARCHAR) Slice self, @SqlType(StandardTypes.BIGINT) long count) + { + return Slices.utf8Slice(self.toStringUtf8().repeat((int) count)); + } + + @ScalarFunction("from_string") + @StaticMethod(StandardTypes.BIGINT) + @SqlType(StandardTypes.BIGINT) + public static long bigintFromString(@SqlType(StandardTypes.VARCHAR) Slice value) + { + return Long.parseLong(value.toStringUtf8()); + } + + @Test + public void testReceiverInParens() + { + assertThat(assertions.expression("('hello').char_length()")) + .matches("BIGINT '5'"); + } + + @Test + public void testReceiverIsFunctionCall() + { + assertThat(assertions.expression("upper('ab').char_length()")) + .matches("BIGINT '2'"); + } + + @Test + public void testWithArguments() + { + assertThat(assertions.expression("('ab').repeat(3)")) + .matches("VARCHAR 'ababab'"); + } + + @Test + public void testBareReceiverResolvesAsMethod() + { + // SQL:2023 6.3 SR 2: A.B(args) is treated as a method invocation when applicable. + // Here `s` is a column of type VARCHAR, so s.char_length() resolves to the + // varchar method rather than a function named "s.char_length". + assertThat(assertions.query("SELECT s.char_length() FROM (VALUES VARCHAR 'hi') t(s)")) + .matches("VALUES BIGINT '2'"); + } + + @Test + public void testUnknownMethod() + { + assertTrinoExceptionThrownBy(() -> assertions.expression("('hello').nope()").evaluate()) + .hasErrorCode(FUNCTION_NOT_FOUND); + } + + @Test + public void testInstanceMethodNotResolvableAsFunction() + { + // The plain `char_length('hello')` form must NOT resolve to the instance method. + assertTrinoExceptionThrownBy(() -> assertions.expression("char_length('hello')").evaluate()) + .hasErrorCode(FUNCTION_NOT_FOUND); + } + + @Test + public void testInstanceMethodNotResolvableAsStaticMethod() + { + // `char_length` is an @InstanceMethod, so the static-method form must not find it. + assertTrinoExceptionThrownBy(() -> assertions.expression("varchar::char_length('hello')").evaluate()) + .hasErrorCode(FUNCTION_NOT_FOUND); + } + + @Test + public void testStaticMethodNotResolvableAsInstanceMethod() + { + // `from_string` is a @StaticMethod on bigint, so the instance-method form must not find it. + assertTrinoExceptionThrownBy(() -> assertions.expression("('42').from_string()").evaluate()) + .hasErrorCode(FUNCTION_NOT_FOUND); + } + + @Test + public void testReceiverCoercion() + { + // VARCHAR(2) coerces to the method's unbounded VARCHAR receiver type. + assertThat(assertions.expression("CAST('hi' AS VARCHAR(2)).char_length()")) + .matches("BIGINT '2'"); + } + + @Test + public void testReceiverNotCoercible() + { + // INTEGER has no implicit coercion to VARCHAR, so no instance method named `char_length` is found for it. + assertTrinoExceptionThrownBy(() -> assertions.expression("(42).char_length()").evaluate()) + .hasErrorCode(FUNCTION_NOT_FOUND); + } + + @Test + public void testArgumentCoercion() + { + // TINYINT coerces to BIGINT, matching the declared argument type. + assertThat(assertions.expression("('ab').repeat(TINYINT '3')")) + .matches("VARCHAR 'ababab'"); + } + + @Test + public void testArgumentNotCoercible() + { + // VARCHAR has no implicit coercion to BIGINT, so the call fails to resolve. + assertTrinoExceptionThrownBy(() -> assertions.expression("('ab').repeat('three')").evaluate()) + .hasErrorCode(FUNCTION_NOT_FOUND); + } + + @Test + public void testCaseInsensitiveMethodName() + { + assertThat(assertions.expression("('hello').CHAR_LENGTH()")) + .matches("BIGINT '5'"); + assertThat(assertions.expression("('hello').Char_Length()")) + .matches("BIGINT '5'"); + } + + @Test + public void testInstanceMethodRequiresScalarFunctionAnnotation() + { + assertTrinoExceptionThrownBy(() -> InternalFunctionBundle.builder().scalars(MissingScalarFunctionFixture.class).build()) + .hasErrorCode(FUNCTION_IMPLEMENTATION_ERROR) + .hasMessageContaining("missing @ScalarFunction or @ScalarOperator"); + } + + @Test + public void testInstanceMethodRequiresSelfArgument() + { + assertTrinoExceptionThrownBy(() -> InternalFunctionBundle.builder().scalars(MissingSelfFixture.class).build()) + .hasErrorCode(FUNCTION_IMPLEMENTATION_ERROR) + .hasMessageContaining("Instance method nothing must declare a self argument"); + } + + public static class MissingScalarFunctionFixture + { + @InstanceMethod + @SqlType(StandardTypes.BIGINT) + public static long length(@SqlType(StandardTypes.VARCHAR) Slice self) + { + return self.toStringUtf8().length(); + } + } + + public static class MissingSelfFixture + { + @ScalarFunction("nothing") + @InstanceMethod + @SqlType(StandardTypes.BIGINT) + public static long noSelf() + { + return 0; + } + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java index 8d6d5bea10c0..b8136ba6e69d 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java @@ -77,6 +77,7 @@ import io.trino.sql.tree.LocalTimestamp; import io.trino.sql.tree.LogicalExpression; import io.trino.sql.tree.LongLiteral; +import io.trino.sql.tree.MethodCall; import io.trino.sql.tree.Node; import io.trino.sql.tree.NotExpression; import io.trino.sql.tree.NullIfExpression; @@ -512,6 +513,12 @@ protected String visitStaticMethodCall(StaticMethodCall node, Void context) return formatName(node.getType()) + "::" + formatExpression(node.getMethod()) + "(" + joinExpressions(node.getArguments()) + ")"; } + @Override + protected String visitMethodCall(MethodCall node, Void context) + { + return "(" + formatExpression(node.getReceiver()) + ")." + formatExpression(node.getMethod()) + "(" + joinExpressions(node.getArguments()) + ")"; + } + @Override protected String visitWindowOperation(WindowOperation node, Void context) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index 4f7b2910f156..142e1006b8bc 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -185,6 +185,7 @@ import io.trino.sql.tree.MergeDelete; import io.trino.sql.tree.MergeInsert; import io.trino.sql.tree.MergeUpdate; +import io.trino.sql.tree.MethodCall; import io.trino.sql.tree.NaturalJoin; import io.trino.sql.tree.Nearest; import io.trino.sql.tree.NestedColumns; @@ -3172,6 +3173,16 @@ public Node visitStaticMethodCall(SqlBaseParser.StaticMethodCallContext context) visit(context.expression(), Expression.class)); } + @Override + public Node visitMethodCall(SqlBaseParser.MethodCallContext context) + { + return new MethodCall( + getLocation(context), + (Expression) visit(context.primaryExpression()), + (Identifier) visit(context.identifier()), + visit(context.expression(), Expression.class)); + } + @Override public Node visitMeasure(SqlBaseParser.MeasureContext context) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java index 4129c1070b96..8c9eef2ba7d2 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java @@ -342,6 +342,11 @@ protected R visitStaticMethodCall(StaticMethodCall node, C context) return visitExpression(node, context); } + protected R visitMethodCall(MethodCall node, C context) + { + return visitExpression(node, context); + } + protected R visitProcessingMode(ProcessingMode node, C context) { return visitNode(node, context); diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java index fefc08bd65f6..226354e5a26e 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java @@ -232,6 +232,16 @@ protected Void visitStaticMethodCall(StaticMethodCall node, C context) return null; } + @Override + protected Void visitMethodCall(MethodCall node, C context) + { + process(node.getReceiver(), context); + for (Expression argument : node.getArguments()) { + process(argument, context); + } + return null; + } + @Override protected Void visitWindowOperation(WindowOperation node, C context) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionRewriter.java b/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionRewriter.java index 5d5899f0879f..50af5c36f7e7 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionRewriter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionRewriter.java @@ -110,6 +110,11 @@ public Expression rewriteStaticMethodCall(StaticMethodCall node, C context, Expr return rewriteExpression(node, context, treeRewriter); } + public Expression rewriteMethodCall(MethodCall node, C context, ExpressionTreeRewriter treeRewriter) + { + return rewriteExpression(node, context, treeRewriter); + } + public Expression rewriteWindowOperation(WindowOperation node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionTreeRewriter.java b/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionTreeRewriter.java index cf3eb1b18bba..e44b0136c9f6 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionTreeRewriter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionTreeRewriter.java @@ -537,6 +537,24 @@ public Expression visitStaticMethodCall(StaticMethodCall node, Context contex return node; } + @Override + public Expression visitMethodCall(MethodCall node, Context context) + { + if (!context.isDefaultRewrite()) { + Expression result = rewriter.rewriteMethodCall(node, context.get(), ExpressionTreeRewriter.this); + if (result != null) { + return result; + } + } + + Expression receiver = rewrite(node.getReceiver(), context.get()); + List arguments = rewrite(node.getArguments(), context); + if (receiver != node.getReceiver() || !sameElements(node.getArguments(), arguments)) { + return new MethodCall(node.getLocation().orElseThrow(), receiver, node.getMethod(), arguments); + } + return node; + } + // Since OrderBy contains list of SortItems, we want to process each SortItem's key, which is an expression private OrderBy rewriteOrderBy(OrderBy orderBy, Context context) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/MethodCall.java b/core/trino-parser/src/main/java/io/trino/sql/tree/MethodCall.java new file mode 100644 index 000000000000..8d27055bd4d0 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/MethodCall.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public class MethodCall + extends Expression +{ + private final Expression receiver; + private final Identifier method; + private final List arguments; + + public MethodCall(NodeLocation location, Expression receiver, Identifier method, List arguments) + { + super(location); + this.receiver = requireNonNull(receiver, "receiver is null"); + this.method = requireNonNull(method, "method is null"); + this.arguments = ImmutableList.copyOf(requireNonNull(arguments, "arguments is null")); + } + + public Expression getReceiver() + { + return receiver; + } + + public Identifier getMethod() + { + return method; + } + + public List getArguments() + { + return arguments; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitMethodCall(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.builder() + .add(receiver) + .addAll(arguments) + .build(); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + MethodCall other = (MethodCall) obj; + return Objects.equals(receiver, other.receiver) && + Objects.equals(method, other.method) && + Objects.equals(arguments, other.arguments); + } + + @Override + public int hashCode() + { + return Objects.hash(receiver, method, arguments); + } + + @Override + public boolean shallowEquals(Node other) + { + if (!sameClass(this, other)) { + return false; + } + MethodCall otherInvocation = (MethodCall) other; + return method.equals(otherInvocation.method); + } +} diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index 5f0946c9f5c1..31fd8ed19b66 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -134,6 +134,7 @@ import io.trino.sql.tree.MergeDelete; import io.trino.sql.tree.MergeInsert; import io.trino.sql.tree.MergeUpdate; +import io.trino.sql.tree.MethodCall; import io.trino.sql.tree.NaturalJoin; import io.trino.sql.tree.Nearest; import io.trino.sql.tree.NestedColumns; @@ -371,6 +372,49 @@ public void testStaticMethodCall() assertInvalidExpression("varchar(5)::parse('42')", "mismatched input '::'.*"); } + @Test + public void testMethodCall() + { + // Direct invocation on a parenthesized expression. + assertThat(expression("('hello').length()")) + .isEqualTo(new MethodCall( + location(1, 1), + new StringLiteral(location(1, 2), "hello"), + new Identifier(location(1, 11), "length", false), + ImmutableList.of())); + + // Receiver is a function call. + assertThat(expression("upper('a').length()")) + .isEqualTo(new MethodCall( + location(1, 1), + new FunctionCall( + location(1, 1), + QualifiedName.of(ImmutableList.of(new Identifier(location(1, 1), "upper", false))), + ImmutableList.of(new StringLiteral(location(1, 7), "a"))), + new Identifier(location(1, 12), "length", false), + ImmutableList.of())); + + // Method with arguments. + assertThat(expression("(x).contains(1, 2)")) + .isEqualTo(new MethodCall( + location(1, 1), + new Identifier(location(1, 2), "x", false), + new Identifier(location(1, 5), "contains", false), + ImmutableList.of( + new LongLiteral(location(1, 14), "1"), + new LongLiteral(location(1, 17), "2")))); + + // Bare two-part name still parses as a function call; method-call + // interpretation happens at semantic time per SQL:2023 6.3 SR 2. + assertThat(expression("x.length()")) + .isEqualTo(new FunctionCall( + location(1, 1), + QualifiedName.of(ImmutableList.of( + new Identifier(location(1, 1), "x", false), + new Identifier(location(1, 3), "length", false))), + ImmutableList.of())); + } + @Test public void testPossibleExponentialBacktracking() { diff --git a/core/trino-spi/pom.xml b/core/trino-spi/pom.xml index a628f44bad3c..436361620f66 100644 --- a/core/trino-spi/pom.xml +++ b/core/trino-spi/pom.xml @@ -283,8 +283,8 @@ java.method.numberOfParametersChanged method io.trino.spi.function.FunctionMetadata io.trino.spi.function.FunctionMetadata::fromJson(io.trino.spi.function.FunctionId, io.trino.spi.function.Signature, java.lang.String, java.util.Set<java.lang.String>, io.trino.spi.function.FunctionNullability, boolean, boolean, boolean, java.lang.String, io.trino.spi.function.FunctionKind, boolean) - method io.trino.spi.function.FunctionMetadata io.trino.spi.function.FunctionMetadata::fromJson(io.trino.spi.function.FunctionId, io.trino.spi.function.Signature, java.lang.String, java.util.Set<java.lang.String>, io.trino.spi.function.FunctionNullability, boolean, boolean, boolean, java.lang.String, io.trino.spi.function.FunctionKind, boolean, java.util.Optional<io.trino.spi.type.TypeSignature>) - Add receiverType to support SQL:2023 static method invocation + method io.trino.spi.function.FunctionMetadata io.trino.spi.function.FunctionMetadata::fromJson(io.trino.spi.function.FunctionId, io.trino.spi.function.Signature, java.lang.String, java.util.Set<java.lang.String>, io.trino.spi.function.FunctionNullability, boolean, boolean, boolean, java.lang.String, io.trino.spi.function.FunctionKind, boolean, java.util.Optional<io.trino.spi.type.TypeSignature>, boolean) + Add receiverType and instanceMethod to support SQL:2023 method invocation diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/FunctionMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/function/FunctionMetadata.java index 4587cc892c00..4b417f5a978f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/FunctionMetadata.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/FunctionMetadata.java @@ -48,6 +48,7 @@ public class FunctionMetadata private final FunctionKind kind; private final boolean deprecated; private final Optional receiverType; + private final boolean instanceMethod; private FunctionMetadata( FunctionId functionId, @@ -61,7 +62,8 @@ private FunctionMetadata( String description, FunctionKind kind, boolean deprecated, - Optional receiverType) + Optional receiverType, + boolean instanceMethod) { this.functionId = requireNonNull(functionId, "functionId is null"); this.signature = requireNonNull(signature, "signature is null"); @@ -82,6 +84,10 @@ private FunctionMetadata( this.kind = requireNonNull(kind, "kind is null"); this.deprecated = deprecated; this.receiverType = requireNonNull(receiverType, "receiverType is null"); + if (instanceMethod && receiverType.isEmpty()) { + throw new IllegalArgumentException("instance method must have a receiver type"); + } + this.instanceMethod = instanceMethod; } /** @@ -166,9 +172,11 @@ public boolean isDeprecated() } /** - * The receiver type when this function is a static method invocable as - * {@code T::method(args)}. A non-empty value implies the function is a - * static method; an empty value implies a regular function. + * The receiver type when this function is a method. For a static method + * (invocable as {@code T::method(args)}) this is the named type. For an + * instance method (invocable as {@code receiver.method(args)}) this is + * the type of the {@code self} parameter (the first declared argument). + * Empty for regular functions. */ @JsonProperty public Optional getReceiverType() @@ -176,6 +184,17 @@ public Optional getReceiverType() return receiverType; } + /** + * Whether this is an instance method (receiver passed as the first + * argument) rather than a static method. Only meaningful when + * {@link #getReceiverType()} is present. + */ + @JsonProperty + public boolean isInstanceMethod() + { + return instanceMethod; + } + @JsonCreator @DoNotCall // For JSON deserialization only public static FunctionMetadata fromJson( @@ -190,7 +209,8 @@ public static FunctionMetadata fromJson( @JsonProperty String description, @JsonProperty FunctionKind kind, @JsonProperty boolean deprecated, - @JsonProperty Optional receiverType) + @JsonProperty Optional receiverType, + @JsonProperty boolean instanceMethod) { return new FunctionMetadata( functionId, @@ -204,7 +224,8 @@ public static FunctionMetadata fromJson( description, kind, deprecated, - receiverType == null ? Optional.empty() : receiverType); + receiverType == null ? Optional.empty() : receiverType, + instanceMethod); } @Override @@ -259,6 +280,7 @@ public static final class Builder private FunctionId functionId; private boolean deprecated; private Optional receiverType = Optional.empty(); + private boolean instanceMethod; private Builder(String canonicalName, FunctionKind kind) { @@ -363,6 +385,12 @@ public Builder receiverType(TypeSignature receiverType) return this; } + public Builder instanceMethod() + { + this.instanceMethod = true; + return this; + } + public FunctionMetadata build() { FunctionId functionId = this.functionId; @@ -384,7 +412,8 @@ public FunctionMetadata build() description, kind, deprecated, - receiverType); + receiverType, + instanceMethod); } } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/InstanceMethod.java b/core/trino-spi/src/main/java/io/trino/spi/function/InstanceMethod.java new file mode 100644 index 000000000000..e74a0625e3f1 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/function/InstanceMethod.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.function; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.TYPE; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +/** + * Marks a scalar function as a non-static method, invocable via SQL:2023 + * {@code } syntax: {@code receiver.method(args)}. + * The receiver type is taken from the first {@code @SqlType} argument of + * the implementation, which becomes the {@code self} parameter. + */ +@Retention(RUNTIME) +@Target({METHOD, TYPE}) +public @interface InstanceMethod {}