diff --git a/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTester.java b/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTester.java index 3c188360..2caf84c8 100644 --- a/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTester.java +++ b/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTester.java @@ -6,7 +6,6 @@ package com.linkedin.transport.test.trino; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.linkedin.transport.test.spi.Row; import com.linkedin.transport.test.spi.TestCase; @@ -23,7 +22,6 @@ import io.trino.spi.connector.ConnectorFactory; import io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.function.BoundSignature; -import io.trino.metadata.FunctionBinding; import io.trino.spi.function.FunctionId; import com.linkedin.transport.api.StdFactory; import com.linkedin.transport.api.udf.StdUDF; @@ -103,12 +101,9 @@ public Connector create(String catalogName, Map config, Connecto @Override public StdFactory getStdFactory() { if (_stdFactory == null) { - FunctionBinding functionBinding = new FunctionBinding( - new FunctionId("test"), + _stdFactory = new TrinoFactory( new BoundSignature("test", UNKNOWN, ImmutableList.of()), - ImmutableMap.of(), - ImmutableMap.of()); - _stdFactory = new TrinoFactory(functionBinding, new TrinoTestFunctionDependencies(InternalTypeManager.TESTING_TYPE_MANAGER, _runner)); + new TrinoTestFunctionDependencies(InternalTypeManager.TESTING_TYPE_MANAGER, _runner)); } return _stdFactory; } diff --git a/transportable-udfs-trino-plugin/build.gradle b/transportable-udfs-trino-plugin/build.gradle index 3788510a..d3d25d4f 100644 --- a/transportable-udfs-trino-plugin/build.gradle +++ b/transportable-udfs-trino-plugin/build.gradle @@ -15,7 +15,9 @@ dependencies { implementation (group:'io.airlift', name: 'log', version: '221') implementation (group:'com.google.guava', name: 'guava', version: '24.1-jre') implementation (group:'io.trino', name: 'trino-plugin-toolkit', version: project.ext.'trino-version') - runtimeOnly (group:'io.trino', name: 'trino-main', version: project.ext.'trino-version') + runtimeOnly (group:'io.trino', name: 'trino-main', version: project.ext.'trino-version') { + exclude 'group': 'io.trino', 'module': 'trino-spi' + } compileOnly(group:'io.trino', name: 'trino-spi', version: project.ext.'trino-version') testImplementation (group:'io.trino', name: 'trino-main', version: project.ext.'trino-version') } diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUDFUtils.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUDFUtils.java index bd7e2579..f4a834c7 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUDFUtils.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUDFUtils.java @@ -8,6 +8,9 @@ import com.google.common.collect.ImmutableSet; import com.linkedin.transport.typesystem.TypeSignature; import com.linkedin.transport.typesystem.TypeSignatureElement; +import io.trino.spi.TrinoException; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; import java.util.ArrayList; import java.util.List; import java.util.Locale; @@ -15,6 +18,7 @@ import java.util.stream.Collectors; import static com.linkedin.transport.typesystem.ConcreteTypeSignatureElement.*; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; /** @@ -49,6 +53,14 @@ static String quoteReservedKeywords(String signature) { return toTrinoTypeSignatureString(TypeSignature.parse(signature)); } + public static MethodHandle methodHandle(Class clazz, String name, Class... parameterTypes) { + try { + return MethodHandles.lookup().unreflect(clazz.getMethod(name, parameterTypes)); + } catch (IllegalAccessException | NoSuchMethodException e) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, e); + } + } + private static String toTrinoTypeSignatureString(TypeSignature typeSignature) { final TypeSignatureElement typeSignatureBase = typeSignature.getBase(); if (BOOLEAN.equals(typeSignatureBase)) { diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java index 0e55f46b..2305f7e3 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java @@ -23,17 +23,15 @@ import com.linkedin.transport.api.udf.StdUDF8; import com.linkedin.transport.api.udf.TopLevelStdUDF; import com.linkedin.transport.typesystem.GenericTypeSignatureElement; -import io.trino.metadata.FunctionBinding; -import io.trino.metadata.SignatureBinder; import io.trino.spi.function.BoundSignature; import io.trino.spi.function.FunctionDependencies; import io.trino.spi.function.FunctionDependencyDeclaration; import io.trino.spi.function.FunctionKind; import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.ScalarFunctionAdapter; import io.trino.spi.function.ScalarFunctionImplementation; import io.trino.spi.function.Signature; import io.trino.spi.function.TypeVariableConstraint; -import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; import io.trino.spi.classloader.ThreadContextClassLoader; import io.trino.spi.function.InvocationConvention; import io.trino.spi.type.ArrayType; @@ -54,13 +52,15 @@ import java.util.stream.IntStream; import org.apache.commons.lang3.ClassUtils; +import static com.linkedin.transport.trino.StdUDFUtils.methodHandle; import static com.linkedin.transport.trino.StdUDFUtils.quoteReservedKeywords; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.*; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; -import static io.trino.spi.function.OperatorType.*; -import static io.trino.spi.function.TypeVariableConstraint.*; +import static io.trino.spi.function.ScalarFunctionAdapter.NullAdaptationPolicy.RETURN_NULL_ON_NULL; +import static io.trino.spi.function.OperatorType.EQUAL; +import static io.trino.spi.function.TypeVariableConstraint.typeVariable; import static io.trino.sql.analyzer.TypeSignatureTranslator.parseTypeSignature; -import static io.trino.util.Reflection.*; // Suppressing argument naming convention for the evalInternal methods @SuppressWarnings({"checkstyle:regexpsinglelinejava"}) @@ -70,6 +70,7 @@ public abstract class StdUdfWrapper { private static final int JITTER_FACTOR = 50; // to calculate jitter from delay private final FunctionMetadata functionMetadata; + private final ScalarFunctionAdapter functionAdapter = new ScalarFunctionAdapter(RETURN_NULL_ON_NULL); public StdUdfWrapper(StdUDF stdUDF) { this.functionMetadata = FunctionMetadata.builder(FunctionKind.SCALAR) @@ -133,8 +134,7 @@ public FunctionDependencyDeclaration getFunctionDependencies(BoundSignature boun public ScalarFunctionImplementation getScalarFunctionImplementation(BoundSignature boundSignature, FunctionDependencies functionDependencies, InvocationConvention invocationConvention) { - FunctionBinding functionBinding = SignatureBinder.bindFunction(functionMetadata.getFunctionId(), functionMetadata.getSignature(), boundSignature); - StdFactory stdFactory = new TrinoFactory(functionBinding, functionDependencies); + StdFactory stdFactory = new TrinoFactory(boundSignature, functionDependencies); StdUDF stdUDF = getStdUDF(); stdUDF.init(stdFactory); // Subtract a small jitter value so that refresh is triggered on first call @@ -145,12 +145,25 @@ public ScalarFunctionImplementation getScalarFunctionImplementation(BoundSignatu - (new Random()).nextInt(initialJitterInt)); boolean[] nullableArguments = stdUDF.getAndCheckNullableArguments(); - ScalarFunctionImplementation res = new ChoicesSpecializedSqlScalarFunction( + return internalGetScalarFunctionImplementation( boundSignature, - NULLABLE_RETURN, + getMethodHandle(stdUDF, boundSignature, nullableArguments, requiredFilesNextRefreshTime), getNullConventionForArguments(nullableArguments), - getMethodHandle(stdUDF, boundSignature, nullableArguments, requiredFilesNextRefreshTime)).getScalarFunctionImplementation(invocationConvention); - return res; + invocationConvention + ); + } + + private ScalarFunctionImplementation internalGetScalarFunctionImplementation(BoundSignature boundSignature, MethodHandle methodHandle, + List nullableArguments, InvocationConvention invocationConvention) { + InvocationConvention actualConvention = new InvocationConvention(nullableArguments, NULLABLE_RETURN, false, false); + MethodHandle internalMethodHandle = functionAdapter.adapt( + methodHandle, + boundSignature.getArgumentTypes(), + actualConvention, + invocationConvention + ); + return ScalarFunctionImplementation.builder().methodHandle(internalMethodHandle) + .lambdaInterfaces(ImmutableList.of()).build(); } private MethodHandle getMethodHandle(StdUDF stdUDF, BoundSignature boundSignature, boolean[] nullableArguments, diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoFactory.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoFactory.java index 0fa4d89e..11d9555c 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoFactory.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoFactory.java @@ -6,7 +6,6 @@ package com.linkedin.transport.trino; import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableSet; import com.linkedin.transport.api.StdFactory; import com.linkedin.transport.api.data.StdArray; import com.linkedin.transport.api.data.StdBoolean; @@ -30,7 +29,7 @@ import com.linkedin.transport.trino.data.TrinoString; import com.linkedin.transport.trino.data.TrinoStruct; import io.airlift.slice.Slices; -import io.trino.metadata.FunctionBinding; +import io.trino.spi.function.BoundSignature; import io.trino.spi.function.FunctionDependencies; import io.trino.metadata.OperatorNotFoundException; import io.trino.spi.function.InvocationConvention; @@ -39,24 +38,19 @@ import io.trino.spi.type.MapType; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; -import io.trino.spi.type.TypeSignature; import java.lang.invoke.MethodHandle; import java.nio.ByteBuffer; import java.util.List; import java.util.stream.Collectors; -import static com.linkedin.transport.trino.StdUDFUtils.quoteReservedKeywords; -import static io.trino.metadata.SignatureBinder.*; -import static io.trino.sql.analyzer.TypeSignatureTranslator.*; - public class TrinoFactory implements StdFactory { - final FunctionBinding functionBinding; + final BoundSignature boundSignature; final FunctionDependencies functionDependencies; - public TrinoFactory(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { - this.functionBinding = functionBinding; + public TrinoFactory(BoundSignature boundSignature, FunctionDependencies functionDependencies) { + this.boundSignature = boundSignature; this.functionDependencies = functionDependencies; } @@ -130,8 +124,7 @@ public StdStruct createStruct(StdType stdType) { @Override public StdType createStdType(String typeSignatureStr) { - TypeSignature typeSignature = applyBoundVariables(parseTypeSignature(quoteReservedKeywords(typeSignatureStr), ImmutableSet.of()), functionBinding); - return TrinoWrapper.createStdType(functionDependencies.getType(typeSignature)); + return TrinoWrapper.createStdType(boundSignature.getReturnType()); } public MethodHandle getOperatorHandle( diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArray.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArray.java index 4d0dfa5d..ea2ec710 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArray.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArray.java @@ -16,8 +16,7 @@ import io.trino.spi.type.Type; import java.util.Iterator; -import static io.trino.spi.type.TypeUtils.*; - +import static io.trino.spi.type.TypeUtils.readNativeValue; public class TrinoArray extends TrinoData implements StdArray { diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoBinary.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoBinary.java index 9fa7914b..18e56d49 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoBinary.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoBinary.java @@ -10,7 +10,7 @@ import io.trino.spi.block.BlockBuilder; import java.nio.ByteBuffer; -import static io.trino.spi.type.VarbinaryType.*; +import static io.trino.spi.type.VarbinaryType.VARBINARY; public class TrinoBinary extends TrinoData implements StdBinary { diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoBoolean.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoBoolean.java index 9b6c9e23..ecee5492 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoBoolean.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoBoolean.java @@ -8,7 +8,7 @@ import com.linkedin.transport.api.data.StdBoolean; import io.trino.spi.block.BlockBuilder; -import static io.trino.spi.type.BooleanType.*; +import static io.trino.spi.type.BooleanType.BOOLEAN; public class TrinoBoolean extends TrinoData implements StdBoolean { diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoDouble.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoDouble.java index 6e3567ec..2afe379c 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoDouble.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoDouble.java @@ -8,7 +8,7 @@ import com.linkedin.transport.api.data.StdDouble; import io.trino.spi.block.BlockBuilder; -import static io.trino.spi.type.DoubleType.*; +import static io.trino.spi.type.DoubleType.DOUBLE; public class TrinoDouble extends TrinoData implements StdDouble { diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoFloat.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoFloat.java index 16893bcc..a7e97cc8 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoFloat.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoFloat.java @@ -8,7 +8,7 @@ import com.linkedin.transport.api.data.StdFloat; import io.trino.spi.block.BlockBuilder; -import static java.lang.Float.*; +import static java.lang.Float.floatToIntBits; public class TrinoFloat extends TrinoData implements StdFloat { diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoInteger.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoInteger.java index bc52ad62..fbe28a0e 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoInteger.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoInteger.java @@ -8,7 +8,7 @@ import com.linkedin.transport.api.data.StdInteger; import io.trino.spi.block.BlockBuilder; -import static io.trino.spi.type.IntegerType.*; +import static io.trino.spi.type.IntegerType.INTEGER; public class TrinoInteger extends TrinoData implements StdInteger { diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoLong.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoLong.java index 5f842938..60034ab0 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoLong.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoLong.java @@ -8,7 +8,7 @@ import com.linkedin.transport.api.data.StdLong; import io.trino.spi.block.BlockBuilder; -import static io.trino.spi.type.BigintType.*; +import static io.trino.spi.type.BigintType.BIGINT; public class TrinoLong extends TrinoData implements StdLong { diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMap.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMap.java index 73c74637..6dcf1ee4 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMap.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMap.java @@ -27,11 +27,11 @@ import java.util.Iterator; import java.util.Set; -import static io.trino.spi.StandardErrorCode.*; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.function.InvocationConvention.simpleConvention; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; -import static io.trino.spi.type.TypeUtils.*; +import static io.trino.spi.type.TypeUtils.readNativeValue; public class TrinoMap extends TrinoData implements StdMap { diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoString.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoString.java index 5fc9e7f7..f6ea7875 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoString.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoString.java @@ -9,7 +9,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.BlockBuilder; -import static io.trino.spi.type.VarcharType.*; +import static io.trino.spi.type.VarcharType.VARCHAR; public class TrinoString extends TrinoData implements StdString { diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoStruct.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoStruct.java index c94ae335..026cd536 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoStruct.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoStruct.java @@ -21,7 +21,7 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -import static io.trino.spi.type.TypeUtils.*; +import static io.trino.spi.type.TypeUtils.readNativeValue; public class TrinoStruct extends TrinoData implements StdStruct {