Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
public class TrinoWrapperGenerator implements WrapperGenerator {

private static final String TRINO_PACKAGE_SUFFIX = "trino";
private static final String GET_STD_UDF_METHOD = "getStdUDF";
private static final String GET_STATE_CLASS_NAME_METHOD = "getStateClassName";
private static final ClassName TRINO_STD_UDF_WRAPPER_CLASS_NAME =
ClassName.bestGuess("com.linkedin.transport.trino.StdUdfWrapper");
private static final ClassName TRINO_STD_UDF_WRAPPER_STATE_CLASS_NAME =
ClassName.bestGuess("com.linkedin.transport.trino.StdUdfWrapper.State");
private static final String SERVICE_FILE = "META-INF/services/io.trino.metadata.SqlScalarFunction";

@Override
Expand All @@ -33,7 +35,15 @@ public void generateWrappers(WrapperGeneratorContext context) {
TransportUDFMetadata udfMetadata = context.getTransportUdfMetadata();
for (String topLevelClass : context.getTransportUdfMetadata().getTopLevelClasses()) {
for (String implementationClass : udfMetadata.getStdUDFImplementations(topLevelClass)) {
generateWrapper(implementationClass, context.getSourcesOutputDir(), services);
ClassName implementationClassName = ClassName.bestGuess(implementationClass);
ClassName stateClassName =
ClassName.get(implementationClassName.packageName() + "." + TRINO_PACKAGE_SUFFIX,
implementationClassName.simpleName() + "State");
ClassName wrapperClassName =
ClassName.get(implementationClassName.packageName() + "." + TRINO_PACKAGE_SUFFIX,
implementationClassName.simpleName());
generateWrapperClass(wrapperClassName, implementationClassName, context.getSourcesOutputDir(), services);
generateStateClass(stateClassName, implementationClassName, context.getSourcesOutputDir(), services);
}
}
try {
Expand All @@ -43,12 +53,8 @@ public void generateWrappers(WrapperGeneratorContext context) {
}
}

private void generateWrapper(String implementationClass, File sourcesOutputDir, List<String> services) {
ClassName implementationClassName = ClassName.bestGuess(implementationClass);
ClassName wrapperClassName =
ClassName.get(implementationClassName.packageName() + "." + TRINO_PACKAGE_SUFFIX,
implementationClassName.simpleName());

private void generateWrapperClass(ClassName wrapperClassName, ClassName implementationClassName, File sourcesOutputDir,
List<String> services) {
/*
Generates constructor ->

Expand All @@ -65,15 +71,15 @@ private void generateWrapper(String implementationClass, File sourcesOutputDir,
Generates ->

@Override
protected StdUDF getStdUDF() {
protected String getStateClassName() {
return new ${implementationClassName}();
}
*/
MethodSpec getStdUDFMethod = MethodSpec.methodBuilder(GET_STD_UDF_METHOD)
MethodSpec getStateClassNameMethod = MethodSpec.methodBuilder(GET_STATE_CLASS_NAME_METHOD)
.addAnnotation(Override.class)
.returns(StdUDF.class)
.returns(String.class)
.addModifiers(Modifier.PROTECTED)
.addStatement("return new $T()", implementationClassName)
.addStatement("return \"" + implementationClassName.reflectionName() + "\"")
.build();

/*
Expand All @@ -91,7 +97,7 @@ public class ${wrapperClassName} extends StdUdfWrapper {
.addModifiers(Modifier.PUBLIC)
.superclass(TRINO_STD_UDF_WRAPPER_CLASS_NAME)
.addMethod(constructor)
.addMethod(getStdUDFMethod)
.addMethod(getStateClassNameMethod)
.build();

services.add(wrapperClassName.toString());
Expand All @@ -105,4 +111,49 @@ public class ${wrapperClassName} extends StdUdfWrapper {
throw new RuntimeException("Error writing wrapper to file: ", e);
}
}

private void generateStateClass(ClassName stateClassName, ClassName implementationClassName, File sourcesOutputDir,
List<String> services) {
/*
Generates constructor ->

public ${stateClassName}() {
super();
stdUDF = new ${implementationClassName};
}
*/
MethodSpec constructor = MethodSpec.constructorBuilder()
.addModifiers(Modifier.PUBLIC)
.addStatement("super()")
.addStatement("stdUDF = new $T()", implementationClassName)
.build();

/*
Generates ->

public class ${stateClassName} extends State {

.
.
.

}
*/
TypeSpec wrapperClass = TypeSpec.classBuilder(stateClassName)
.addModifiers(Modifier.PUBLIC)
.superclass(TRINO_STD_UDF_WRAPPER_STATE_CLASS_NAME)
.addMethod(constructor)
.build();

services.add(stateClassName.toString());
JavaFile javaFile = JavaFile.builder(stateClassName.packageName(), wrapperClass)
.skipJavaLangImports(true)
.build();

try {
javaFile.writeTo(sourcesOutputDir);
} catch (Exception e) {
throw new RuntimeException("Error writing wrapper to file: ", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import com.linkedin.transport.api.udf.StdUDF;
import com.linkedin.transport.trino.StdUdfWrapper;

import java.lang.reflect.InvocationTargetException;


Expand All @@ -26,8 +27,8 @@ public TrinoTestStdUDFWrapper(Class<? extends StdUDF> udfClass) {
}

@Override
protected StdUDF getStdUDF() {
return createInstance(_udfClass);
protected String getStateClassName() {
return TestState.class.getName();
}

private static <K extends StdUDF> K createInstance(Class<K> udfClass) {
Expand All @@ -37,4 +38,11 @@ private static <K extends StdUDF> K createInstance(Class<K> udfClass) {
throw new RuntimeException(e);
}
}

public class TestState extends State {
public TestState () {
super();
stdUDF = createInstance(_udfClass);
}
}
}
Loading