Review notes. Text before the first "diff --git" line is commentary and is not part of any hunk.

The submitted patch had its line breaks, indentation, blank context lines and generic type
arguments stripped. They are reconstructed here by convention, so context and removed lines may not
match the target tree byte for byte. Old-side hunk ranges are kept as submitted; new-side ranges are
recomputed. The blob hashes on the "index" lines describe the submitted patch, not this amended one.
Regenerate the patch against the target tree before applying it.

Amendments made to the added lines:
  1. TrinoWrapperGenerator: the generated getStateClassName() returned the UDF implementation class
     name, which StdUdfWrapper would then instantiate and cast to State. It now returns the name of
     the generated state class, emitted with JavaPoet's $S placeholder.
  2. TrinoWrapperGenerator: the state class is no longer added to the SqlScalarFunction service
     file, because it is not a SqlScalarFunction.
  3. StdUdfWrapper.getMethodHandle: bound arguments are inserted at index 2 instead of 1, following
     the layout in getMethodHandleArgumentTypes (0 = State, 1 = ConnectorSession, 2 = StdFactory).
     NOTE(review): the construction of the handle is in unchanged lines outside this patch - confirm.
  4. StdUdfWrapper: invokeMethodHandle and getStateClass now keep the original exception as cause.
  5. StdUdfWrapper: the state factory is built in a new getStateFactory(). It supports non-static
     inner state classes (TrinoTestStdUDFWrapper.TestState) by binding the enclosing wrapper, and
     widens the return type to State.
  6. StdUdfWrapper.getStateClass: loads the class with the concrete wrapper's class loader.
  7. StdUdfWrapper imports: java.util.* replaced by explicit imports. java.util.Random is kept
     because uses outside this patch cannot be ruled out.
  8. Formatting: lines over 120 columns rewrapped, stray ";" line removed, stale generator comments
     corrected, Javadoc added to the new members.

Open question, not changed here: the initial refresh jitter was removed and State starts its refresh
time at 0 - confirm that refreshing on every worker at the same interval boundary is acceptable.

diff --git a/transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/TrinoWrapperGenerator.java b/transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/TrinoWrapperGenerator.java
index 957b7741..bb21c8d2 100644
--- a/transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/TrinoWrapperGenerator.java
+++ b/transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/TrinoWrapperGenerator.java
@@ -22,9 +22,11 @@
 public class TrinoWrapperGenerator implements WrapperGenerator {
 
   private static final String TRINO_PACKAGE_SUFFIX = "trino";
-  private static final String GET_STD_UDF_METHOD = "getStdUDF";
+  private static final String GET_STATE_CLASS_NAME_METHOD = "getStateClassName";
   private static final ClassName TRINO_STD_UDF_WRAPPER_CLASS_NAME =
       ClassName.bestGuess("com.linkedin.transport.trino.StdUdfWrapper");
+  private static final ClassName TRINO_STD_UDF_WRAPPER_STATE_CLASS_NAME =
+      ClassName.bestGuess("com.linkedin.transport.trino.StdUdfWrapper.State");
   private static final String SERVICE_FILE = "META-INF/services/io.trino.metadata.SqlScalarFunction";
 
   @Override
@@ -33,7 +35,13 @@ public void generateWrappers(WrapperGeneratorContext context) {
     TransportUDFMetadata udfMetadata = context.getTransportUdfMetadata();
     for (String topLevelClass : context.getTransportUdfMetadata().getTopLevelClasses()) {
       for (String implementationClass : udfMetadata.getStdUDFImplementations(topLevelClass)) {
-        generateWrapper(implementationClass, context.getSourcesOutputDir(), services);
+        ClassName implementationClassName = ClassName.bestGuess(implementationClass);
+        String trinoPackageName = implementationClassName.packageName() + "." + TRINO_PACKAGE_SUFFIX;
+        ClassName wrapperClassName = ClassName.get(trinoPackageName, implementationClassName.simpleName());
+        ClassName stateClassName = ClassName.get(trinoPackageName, implementationClassName.simpleName() + "State");
+        generateWrapperClass(wrapperClassName, stateClassName, implementationClassName,
+            context.getSourcesOutputDir(), services);
+        generateStateClass(stateClassName, implementationClassName, context.getSourcesOutputDir());
       }
     }
     try {
@@ -43,12 +51,17 @@ public void generateWrappers(WrapperGeneratorContext context) {
     }
   }
 
-  private void generateWrapper(String implementationClass, File sourcesOutputDir, List<String> services) {
-    ClassName implementationClassName = ClassName.bestGuess(implementationClass);
-    ClassName wrapperClassName =
-        ClassName.get(implementationClassName.packageName() + "." + TRINO_PACKAGE_SUFFIX,
-            implementationClassName.simpleName());
-
+  /**
+   * Generates the Trino wrapper class for one UDF implementation and adds it to {@code services}.
+   *
+   * @param wrapperClassName name of the wrapper class to generate
+   * @param stateClassName name of the generated state class that the wrapper reports
+   * @param implementationClassName name of the UDF implementation being wrapped
+   * @param sourcesOutputDir directory the generated source file is written to
+   * @param services collects the generated wrapper class names
+   */
+  private void generateWrapperClass(ClassName wrapperClassName, ClassName stateClassName,
+      ClassName implementationClassName, File sourcesOutputDir, List<String> services) {
     /*
       Generates constructor ->
 
@@ -65,15 +78,15 @@ private void generateWrapper(String implementationClass, File sourcesOutputDir,
       Generates ->
 
       @Override
-      protected StdUDF getStdUDF() {
-        return new ${implementationClassName}();
+      protected String getStateClassName() {
+        return "${stateClassName}";
       }
      */
-    MethodSpec getStdUDFMethod = MethodSpec.methodBuilder(GET_STD_UDF_METHOD)
+    MethodSpec getStateClassNameMethod = MethodSpec.methodBuilder(GET_STATE_CLASS_NAME_METHOD)
         .addAnnotation(Override.class)
-        .returns(StdUDF.class)
+        .returns(String.class)
         .addModifiers(Modifier.PROTECTED)
-        .addStatement("return new $T()", implementationClassName)
+        .addStatement("return $S", stateClassName.reflectionName())
         .build();
 
     /*
@@ -91,7 +104,7 @@ public class ${wrapperClassName} extends StdUdfWrapper {
         .addModifiers(Modifier.PUBLIC)
         .superclass(TRINO_STD_UDF_WRAPPER_CLASS_NAME)
         .addMethod(constructor)
-        .addMethod(getStdUDFMethod)
+        .addMethod(getStateClassNameMethod)
         .build();
 
     services.add(wrapperClassName.toString());
@@ -105,4 +118,56 @@ public class ${wrapperClassName} extends StdUdfWrapper {
       throw new RuntimeException("Error writing wrapper to file: ", e);
     }
   }
+
+  /**
+   * Generates the state class for one UDF implementation. The state class is not added to the service
+   * file because that file lists {@code SqlScalarFunction} implementations only.
+   *
+   * @param stateClassName name of the state class to generate
+   * @param implementationClassName name of the UDF implementation the state instantiates
+   * @param sourcesOutputDir directory the generated source file is written to
+   */
+  private void generateStateClass(ClassName stateClassName, ClassName implementationClassName,
+      File sourcesOutputDir) {
+    /*
+      Generates constructor ->
+
+      public ${stateClassName}() {
+        super();
+        stdUDF = new ${implementationClassName}();
+      }
+     */
+    MethodSpec constructor = MethodSpec.constructorBuilder()
+        .addModifiers(Modifier.PUBLIC)
+        .addStatement("super()")
+        .addStatement("stdUDF = new $T()", implementationClassName)
+        .build();
+
+    /*
+      Generates ->
+
+      public class ${stateClassName} extends StdUdfWrapper.State {
+
+        .
+        .
+        .
+
+      }
+     */
+    TypeSpec stateClass = TypeSpec.classBuilder(stateClassName)
+        .addModifiers(Modifier.PUBLIC)
+        .superclass(TRINO_STD_UDF_WRAPPER_STATE_CLASS_NAME)
+        .addMethod(constructor)
+        .build();
+
+    JavaFile javaFile = JavaFile.builder(stateClassName.packageName(), stateClass)
+        .skipJavaLangImports(true)
+        .build();
+
+    try {
+      javaFile.writeTo(sourcesOutputDir);
+    } catch (Exception e) {
+      throw new RuntimeException("Error writing state class to file: ", e);
+    }
+  }
 }
diff --git a/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTestStdUDFWrapper.java b/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTestStdUDFWrapper.java
index 17f02eaf..feec3146 100644
--- a/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTestStdUDFWrapper.java
+++ b/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTestStdUDFWrapper.java
@@ -7,6 +7,7 @@
 
 import com.linkedin.transport.api.udf.StdUDF;
 import com.linkedin.transport.trino.StdUdfWrapper;
+
 import java.lang.reflect.InvocationTargetException;
 
 
@@ -26,8 +27,8 @@ public TrinoTestStdUDFWrapper(Class udfClass) {
   }
 
   @Override
-  protected StdUDF getStdUDF() {
-    return createInstance(_udfClass);
+  protected String getStateClassName() {
+    return TestState.class.getName();
   }
 
   private static <K extends StdUDF> K createInstance(Class<K> udfClass) {
@@ -37,4 +38,15 @@ private static K createInstance(Class udfClass) {
       throw new RuntimeException(e);
     }
   }
+
+  /**
+   * State for tests. This is an inner (non-static) class because it reads the enclosing wrapper's
+   * {@code _udfClass} to instantiate the UDF under test.
+   */
+  public class TestState extends State {
+    public TestState() {
+      super();
+      stdUDF = createInstance(_udfClass);
+    }
+  }
 }
diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java
index 0f2d57af..841d2433 100644
--- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java
+++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java
@@ -36,6 +36,7 @@
 import io.trino.operator.scalar.ChoicesScalarFunctionImplementation;
 import io.trino.operator.scalar.ScalarFunctionImplementation;
 import io.trino.spi.classloader.ThreadContextClassLoader;
+import io.trino.spi.connector.ConnectorSession;
 import io.trino.spi.function.InvocationConvention;
 import io.trino.spi.type.ArrayType;
 import io.trino.spi.type.IntegerType;
@@ -45,10 +46,11 @@
 import java.lang.invoke.MethodHandle;
 import java.lang.invoke.MethodHandles;
 import java.lang.invoke.MethodType;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Optional;
 import java.util.Random;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.stream.Collectors;
@@ -135,25 +137,23 @@ public FunctionDependencyDeclaration getFunctionDependencies(FunctionBinding fun
   @Override
   public ScalarFunctionImplementation specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies) {
     StdFactory stdFactory = new TrinoFactory(functionBinding, functionDependencies);
-    StdUDF stdUDF = getStdUDF();
-    stdUDF.init(stdFactory);
-    // Subtract a small jitter value so that refresh is triggered on first call
-    // while ensuring subsequent calls do not happen at the same time across workers
-    long initialJitter = getRefreshIntervalMillis() / JITTER_FACTOR;
-    int initialJitterInt = initialJitter > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialJitter;
-    AtomicLong requiredFilesNextRefreshTime = new AtomicLong(System.currentTimeMillis()
-        - (new Random()).nextInt(initialJitterInt));
+
+    MethodHandle instanceFactory = getStateFactory();
+    // This State instance is created only to read the UDF's argument nullability; the UDF itself is
+    // initialized later, on the first eval call for each State
+    StdUDF stdUDF = ((State) invokeMethodHandle(instanceFactory)).getStdUDF();
     boolean[] nullableArguments = stdUDF.getAndCheckNullableArguments();
 
     return new ChoicesScalarFunctionImplementation(
         functionBinding,
         NULLABLE_RETURN,
         getNullConventionForArguments(nullableArguments),
-        getMethodHandle(stdUDF, functionBinding, nullableArguments, requiredFilesNextRefreshTime));
+        getMethodHandle(stdFactory, functionBinding, nullableArguments),
+        Optional.of(instanceFactory));
   }
 
-  private MethodHandle getMethodHandle(StdUDF stdUDF, FunctionBinding functionBinding, boolean[] nullableArguments,
-      AtomicLong requiredFilesNextRefreshTime) {
+  private MethodHandle getMethodHandle(StdFactory stdFactory, FunctionBinding functionBinding,
+      boolean[] nullableArguments) {
     Type[] inputTypes = functionBinding.getBoundSignature().getArgumentTypes().toArray(new Type[0]);
     Type outputType = functionBinding.getBoundSignature().getReturnType();
 
@@ -167,10 +167,12 @@ private MethodHandle getMethodHandle(StdUDF stdUDF, FunctionBinding functionBind
     MethodType specificMethodType =
         MethodType.methodType(specificMethodHandleReturnType, specificMethodHandleArgumentTypes);
 
-    // Specific MethodHandle required by trino where argument types map to the type signature
+    // Specific MethodHandle required by Trino where argument types map to the type signature
     MethodHandle specificMethodHandle = MethodHandles.explicitCastArguments(genericMethodHandle, specificMethodType);
-    return MethodHandles.insertArguments(specificMethodHandle, 0, stdUDF, inputTypes,
-        outputType instanceof IntegerType, requiredFilesNextRefreshTime);
+    // Leave slot 0 (State) and slot 1 (ConnectorSession) open and bind the next three parameters,
+    // following the layout produced by getMethodHandleArgumentTypes
+    return MethodHandles.insertArguments(specificMethodHandle, 2, stdFactory, inputTypes,
+        outputType instanceof IntegerType);
   }
 
   private List<InvocationConvention.InvocationArgumentConvention> getNullConventionForArguments(
@@ -192,12 +194,30 @@ private StdData[] wrapArguments(StdUDF stdUDF, Type[] types, Object[] arguments)
     return stdData;
   }
 
-  protected Object eval(StdUDF stdUDF, Type[] types, boolean isIntegerReturnType,
-      AtomicLong requiredFilesNextRefreshTime, Object... arguments) {
+  /**
+   * Invokes a zero-argument MethodHandle, rethrowing any failure as a RuntimeException with its cause.
+   */
+  private Object invokeMethodHandle(MethodHandle methodHandle) {
+    try {
+      return methodHandle.invoke();
+    } catch (Throwable e) {
+      throw new RuntimeException("Could not invoke MethodHandle " + methodHandle, e);
+    }
+  }
+
+  protected Object eval(State state, ConnectorSession session, StdFactory stdFactory, Type[] types,
+      boolean isIntegerReturnType, Object... arguments) {
+    StdUDF stdUDF = state.getStdUDF();
+    // Initialize the UDF the first time this State is used
+    if (!state.isInitialized()) {
+      stdUDF.init(stdFactory);
+      state.setInitialized();
+    }
+    long requiredFilesNextRefreshTime = state.getRequiredFilesNextRefreshTime();
     StdData[] args = wrapArguments(stdUDF, types, arguments);
-    if (requiredFilesNextRefreshTime.get() <= System.currentTimeMillis()) {
+    if (requiredFilesNextRefreshTime <= System.currentTimeMillis()) {
       String[] requiredFiles = getRequiredFiles(stdUDF, args);
-      processRequiredFiles(stdUDF, requiredFiles, requiredFilesNextRefreshTime);
+      processRequiredFiles(state, requiredFiles);
     }
     StdData result;
     switch (args.length) {
@@ -278,9 +298,10 @@ private String[] getRequiredFiles(StdUDF stdUDF, StdData[] args) {
     return requiredFiles;
   }
 
-  private synchronized void processRequiredFiles(StdUDF stdUDF, String[] requiredFiles,
-      AtomicLong requiredFilesNextRefreshTime) {
-    if (requiredFilesNextRefreshTime.get() <= System.currentTimeMillis()) {
+  private synchronized void processRequiredFiles(State state, String[] requiredFiles) {
+    long requiredFilesNextRefreshTime = state.getRequiredFilesNextRefreshTime();
+    StdUDF stdUDF = state.getStdUDF();
+    if (requiredFilesNextRefreshTime <= System.currentTimeMillis()) {
       try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(getClass().getClassLoader())) {
         String[] copiedFiles = new String[requiredFiles.length];
         FileSystemClient client = new FileSystemClient();
@@ -291,8 +312,9 @@ private synchronized void processRequiredFiles(StdUDF stdUDF, String[] requiredF
         stdUDF.processRequiredFiles(copiedFiles);
         // Determine how many times _refreshIntervalMillis needs to be added to go above currentTimeMillis
         int refreshIntervalFactor = (int) Math.ceil(
-            (System.currentTimeMillis() - requiredFilesNextRefreshTime.get()) / (double) getRefreshIntervalMillis());
-        requiredFilesNextRefreshTime.getAndAdd(getRefreshIntervalMillis() * Math.max(1, refreshIntervalFactor));
+            (System.currentTimeMillis() - requiredFilesNextRefreshTime) / (double) getRefreshIntervalMillis());
+        state.setRequiredFilesNextRefreshTime(
+            requiredFilesNextRefreshTime + getRefreshIntervalMillis() * Math.max(1, refreshIntervalFactor));
       }
     }
   }
@@ -307,70 +329,140 @@ private Class getJavaTypeForNullability(Type trinoType, boolean nullableArgument
 
   private Class<?>[] getMethodHandleArgumentTypes(Type[] argTypes, boolean[] nullableArguments,
       boolean useObjectForArgumentType) {
-    Class<?>[] methodHandleArgumentTypes = new Class<?>[argTypes.length + 4];
-    methodHandleArgumentTypes[0] = StdUDF.class;
-    methodHandleArgumentTypes[1] = Type[].class;
-    methodHandleArgumentTypes[2] = boolean.class;
-    methodHandleArgumentTypes[3] = AtomicLong.class;
+    Class<?>[] methodHandleArgumentTypes = new Class<?>[argTypes.length + 5];
+    methodHandleArgumentTypes[0] = State.class;
+    methodHandleArgumentTypes[1] = ConnectorSession.class;
+    methodHandleArgumentTypes[2] = StdFactory.class;
+    methodHandleArgumentTypes[3] = Type[].class;
+    methodHandleArgumentTypes[4] = boolean.class;
     for (int i = 0; i < argTypes.length; i++) {
       if (useObjectForArgumentType) {
-        methodHandleArgumentTypes[i + 4] = Object.class;
+        methodHandleArgumentTypes[i + 5] = Object.class;
       } else {
-        methodHandleArgumentTypes[i + 4] = getJavaTypeForNullability(argTypes[i], nullableArguments[i]);
+        methodHandleArgumentTypes[i + 5] = getJavaTypeForNullability(argTypes[i], nullableArguments[i]);
       }
     }
     return methodHandleArgumentTypes;
   }
 
-  protected abstract StdUDF getStdUDF();
+  /**
+   * Builds the zero-argument factory that creates State instances of the class named by
+   * {@link #getStateClassName()}.
+   */
+  private MethodHandle getStateFactory() {
+    Class<?> stateClass = getStateClass();
+    MethodHandle constructor;
+    if (stateClass.isMemberClass() && !java.lang.reflect.Modifier.isStatic(stateClass.getModifiers())) {
+      // The constructor of an inner (non-static) class takes the enclosing instance as its first
+      // parameter, so look it up with that parameter and bind this wrapper to it
+      constructor = constructorMethodHandle(stateClass, stateClass.getEnclosingClass()).bindTo(this);
+    } else {
+      constructor = constructorMethodHandle(stateClass);
+    }
+    // Widen the return type to State, the type of the first parameter of the eval MethodHandle
+    return constructor.asType(MethodType.methodType(State.class));
+  }
+
+  /**
+   * Loads the class named by {@link #getStateClassName()} with the class loader of the concrete wrapper.
+   */
+  private Class<?> getStateClass() {
+    String stateClassName = getStateClassName();
+    try {
+      return Class.forName(stateClassName, true, getClass().getClassLoader());
+    } catch (ClassNotFoundException e) {
+      throw new RuntimeException("Could not find class " + stateClassName + " on classpath", e);
+    }
+  }
+
+  /**
+   * Returns the binary name of the {@link State} subclass to instantiate for this UDF.
+   */
+  protected abstract String getStateClassName();
 
-  public Object evalInternal(StdUDF stdUDF, Type[] types, boolean isIntegerReturnType,
-      AtomicLong requiredFilesNextRefreshTime) {
-    return eval(stdUDF, types, isIntegerReturnType, requiredFilesNextRefreshTime);
+  public Object evalInternal(State state, ConnectorSession session, StdFactory stdFactory, Type[] types,
+      boolean isIntegerReturnType) {
+    return eval(state, session, stdFactory, types, isIntegerReturnType);
   }
 
-  public Object evalInternal(StdUDF stdUDF, Type[] types, boolean isIntegerReturnType,
-      AtomicLong requiredFilesNextRefreshTime, Object arg1) {
-    return eval(stdUDF, types, isIntegerReturnType, requiredFilesNextRefreshTime, arg1);
+  public Object evalInternal(State state, ConnectorSession session, StdFactory stdFactory, Type[] types,
+      boolean isIntegerReturnType, Object arg1) {
+    return eval(state, session, stdFactory, types, isIntegerReturnType, arg1);
   }
 
-  public Object evalInternal(StdUDF stdUDF, Type[] types, boolean isIntegerReturnType,
-      AtomicLong requiredFilesNextRefreshTime, Object arg1, Object arg2) {
-    return eval(stdUDF, types, isIntegerReturnType, requiredFilesNextRefreshTime, arg1, arg2);
+  public Object evalInternal(State state, ConnectorSession session, StdFactory stdFactory, Type[] types,
+      boolean isIntegerReturnType, Object arg1, Object arg2) {
+    return eval(state, session, stdFactory, types, isIntegerReturnType, arg1, arg2);
   }
 
-  public Object evalInternal(StdUDF stdUDF, Type[] types, boolean isIntegerReturnType,
-      AtomicLong requiredFilesNextRefreshTime, Object arg1, Object arg2, Object arg3) {
-    return eval(stdUDF, types, isIntegerReturnType, requiredFilesNextRefreshTime, arg1, arg2, arg3);
+  public Object evalInternal(State state, ConnectorSession session, StdFactory stdFactory, Type[] types,
+      boolean isIntegerReturnType, Object arg1, Object arg2, Object arg3) {
+    return eval(state, session, stdFactory, types, isIntegerReturnType, arg1, arg2, arg3);
   }
 
-  public Object evalInternal(StdUDF stdUDF, Type[] types, boolean isIntegerReturnType,
-      AtomicLong requiredFilesNextRefreshTime, Object arg1, Object arg2, Object arg3, Object arg4) {
-    return eval(stdUDF, types, isIntegerReturnType, requiredFilesNextRefreshTime, arg1, arg2, arg3, arg4);
+  public Object evalInternal(State state, ConnectorSession session, StdFactory stdFactory, Type[] types,
+      boolean isIntegerReturnType, Object arg1, Object arg2, Object arg3, Object arg4) {
+    return eval(state, session, stdFactory, types, isIntegerReturnType, arg1, arg2, arg3, arg4);
   }
 
-  public Object evalInternal(StdUDF stdUDF, Type[] types, boolean isIntegerReturnType,
-      AtomicLong requiredFilesNextRefreshTime, Object arg1, Object arg2, Object arg3, Object arg4, Object arg5) {
-    return eval(stdUDF, types, isIntegerReturnType, requiredFilesNextRefreshTime, arg1, arg2, arg3, arg4, arg5);
+  public Object evalInternal(State state, ConnectorSession session, StdFactory stdFactory, Type[] types,
+      boolean isIntegerReturnType, Object arg1, Object arg2, Object arg3, Object arg4, Object arg5) {
+    return eval(state, session, stdFactory, types, isIntegerReturnType, arg1, arg2, arg3, arg4, arg5);
   }
 
-  public Object evalInternal(StdUDF stdUDF, Type[] types, boolean isIntegerReturnType,
-      AtomicLong requiredFilesNextRefreshTime, Object arg1, Object arg2, Object arg3, Object arg4, Object arg5,
+  public Object evalInternal(State state, ConnectorSession session, StdFactory stdFactory, Type[] types,
+      boolean isIntegerReturnType, Object arg1, Object arg2, Object arg3, Object arg4, Object arg5,
       Object arg6) {
-    return eval(stdUDF, types, isIntegerReturnType, requiredFilesNextRefreshTime, arg1, arg2, arg3, arg4, arg5, arg6);
+    return eval(state, session, stdFactory, types, isIntegerReturnType, arg1, arg2, arg3, arg4, arg5, arg6);
   }
 
-  public Object evalInternal(StdUDF stdUDF, Type[] types, boolean isIntegerReturnType,
-      AtomicLong requiredFilesNextRefreshTime, Object arg1, Object arg2, Object arg3, Object arg4, Object arg5,
+  public Object evalInternal(State state, ConnectorSession session, StdFactory stdFactory, Type[] types,
+      boolean isIntegerReturnType, Object arg1, Object arg2, Object arg3, Object arg4, Object arg5,
       Object arg6, Object arg7) {
-    return eval(stdUDF, types, isIntegerReturnType, requiredFilesNextRefreshTime, arg1, arg2, arg3, arg4, arg5, arg6,
+    return eval(state, session, stdFactory, types, isIntegerReturnType, arg1, arg2, arg3, arg4, arg5, arg6,
         arg7);
   }
 
-  public Object evalInternal(StdUDF stdUDF, Type[] types, boolean isIntegerReturnType,
-      AtomicLong requiredFilesNextRefreshTime, Object arg1, Object arg2, Object arg3, Object arg4, Object arg5,
+  public Object evalInternal(State state, ConnectorSession session, StdFactory stdFactory, Type[] types,
+      boolean isIntegerReturnType, Object arg1, Object arg2, Object arg3, Object arg4, Object arg5,
       Object arg6, Object arg7, Object arg8) {
-    return eval(stdUDF, types, isIntegerReturnType, requiredFilesNextRefreshTime, arg1, arg2, arg3, arg4, arg5, arg6,
+    return eval(state, session, stdFactory, types, isIntegerReturnType, arg1, arg2, arg3, arg4, arg5, arg6,
         arg7, arg8);
   }
+
+  /**
+   * Mutable state passed to eval together with a wrapped UDF: the UDF itself, whether it has been
+   * initialized, and the time at which its required files are next due for a refresh.
+   */
+  public abstract static class State {
+    private boolean initialized;
+    // Assigned by the subclass constructor
+    protected StdUDF stdUDF;
+    private long requiredFilesNextRefreshTime;
+
+    public State() {
+      initialized = false;
+      requiredFilesNextRefreshTime = 0;
+    }
+
+    public StdUDF getStdUDF() {
+      return stdUDF;
+    }
+
+    public boolean isInitialized() {
+      return initialized;
+    }
+
+    public void setInitialized() {
+      initialized = true;
+    }
+
+    public long getRequiredFilesNextRefreshTime() {
+      return requiredFilesNextRefreshTime;
+    }
+
+    public void setRequiredFilesNextRefreshTime(long requiredFilesNextRefreshTime) {
+      this.requiredFilesNextRefreshTime = requiredFilesNextRefreshTime;
+    }
+  }
 }