diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..6eadc1d0 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,54 @@ +# +# CI build that assembles artifacts and runs tests. +# If validation is successful this workflow releases from the main dev branch. +# +# - skipping CI: add [skip ci] to the commit message +# - skipping release: add [skip release] to the commit message +# +name: CI + +on: + push: + branches: ['master'] + tags-ignore: [v*] # release tags are autogenerated after a successful CI, no need to run CI against them + pull_request: + branches: ['**'] + +jobs: + + build: + runs-on: ubuntu-latest + if: "! contains(toJSON(github.event.commits.*.message), '[skip ci]')" + + steps: + + - name: 1. Check out code + uses: actions/checkout@v2 # https://github.com/actions/checkout + with: + fetch-depth: '0' # https://github.com/shipkit/shipkit-changelog#fetch-depth-on-ci + + - name: 2. Setup Java JDK + uses: actions/setup-java@v2 + with: + distribution: 'adopt' + java-version: '8' + + - name: 3. Perform build + run: | + ./gradlew build + ./gradlew -p transportable-udfs-examples clean build -s + + - name: 4. 
Perform release + # Release job, only for pushes to the main development branch + if: github.event_name == 'push' + && github.ref == 'refs/heads/master' + && github.repository == 'linkedin/transport' + && !contains(toJSON(github.event.commits.*.message), '[skip release]') + + run: ./gradlew githubRelease publishToSonatype closeAndReleaseStagingRepository + env: + GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} + SONATYPE_USER: ${{secrets.SONATYPE_USER}} + SONATYPE_PWD: ${{secrets.SONATYPE_PWD}} + PGP_KEY: ${{secrets.PGP_KEY}} + PGP_PWD: ${{secrets.PGP_PWD}} diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 19e25c84..00000000 --- a/.travis.yml +++ /dev/null @@ -1,26 +0,0 @@ -# More details on how to configure the Travis build -# https://docs.travis-ci.com/user/customizing-the-build/ - -language: java - -jdk: - - openjdk8 - -#Skipping install step to avoid having Travis run arbitrary './gradlew assemble' task -# https://docs.travis-ci.com/user/customizing-the-build/#Skipping-the-Installation-Step -install: - - true - -#Don't build tags -branches: - except: - - /^v\d/ - -#Build and perform release (if needed) -script: - # Print output every minute to avoid travis timeout - - while sleep 1m; do echo "=====[ $SECONDS seconds elapsed -- still running ]====="; done & - # With the exception of release commands, all build logic goes in travis-build.sh - - ./travis-build.sh && ./gradlew ciPerformRelease -s - # Killing background sleep loop - - kill %1 diff --git a/README.md b/README.md index cba8da0b..7ec24da3 100644 --- a/README.md +++ b/README.md @@ -4,13 +4,13 @@ **Transport** is a framework for writing performant user-defined functions (UDFs) that are portable across a variety of engines including [Apache Spark](https://spark.apache.org/), [Apache Hive](https://hive.apache.org/), and -[Presto](https://prestodb.io/). Transport UDFs are also +[Trino](https://trino.io/). 
Transport UDFs are also capable of directly processing data stored in serialization formats such as Apache Avro. With Transport, developers only need to implement their UDF logic once using the Transport API. Transport then takes care of translating the UDF to native UDF version targeted at various engines or formats. Currently, Transport is capable of generating -engine-artifacts for Spark, Hive, and Presto, and format-artifacts for +engine-artifacts for Spark, Hive, and Trino, and format-artifacts for Avro. Further details on Transport can be found in this [LinkedIn Engineering blog post](https://engineering.linkedin.com/blog/2018/11/using-translatable-portable-UDFs). ## Documentation @@ -127,7 +127,7 @@ to familiarize yourself with the API, and how to write new UDFs. to find out how to write UDF tests in a unified testing API, but have the framework test them on multiple platforms. * Root [`build.gradle`](transportable-udfs-examples/build.gradle) file -to find out how to apply the `transport` plugin, which enables generating Hive, Spark, and Presto UDFs out of +to find out how to apply the `transport` plugin, which enables generating Hive, Spark, and Trino UDFs out of the transportable UDFs you define once you build your project. To see that in action: Change directory to `transportable-udfs-examples`: @@ -153,7 +153,7 @@ The results should be like: ``` transportable-udfs-example-udfs-hive.jar -transportable-udfs-example-udfs-presto.jar +transportable-udfs-example-udfs-trino.jar transportable-udfs-example-udfs-spark.jar transportable-udfs-example-udfs.jar ``` @@ -162,13 +162,13 @@ That is it! While only one version of the UDFs is implemented, multiple jars are Each of those jars uses native platform APIs and data models to implement the UDFs. So from an execution engine's perspective, there is no data transformation needed for interoperability or portability. Only suitable classes are used for each engine. 
-To call those jars from your SQL engine (i.e., Hive, Spark, or Presto), the standard process for deploying UDF jars is followed +To call those jars from your SQL engine (i.e., Hive, Spark, or Trino), the standard process for deploying UDF jars is followed for each engine. For example, in Hive, you add the jar to the classpath using the `ADD JAR` statement, and register the UDF using `CREATE FUNCTION` statement. -In Presto, the jar is deployed to the `plugin` directory. However, a small patch is required for the Presto -engine to recognize the jar as a plugin, since the generated Presto UDFs implement the `SqlScalarFunction` API, -which is currently not part of Presto's SPI architecture. You can find the patch [here](transportable-udfs-documentation/transport-udfs-presto.patch) and apply it - before deploying your UDFs jar to the Presto engine. +In Trino, the jar is deployed to the `plugin` directory. However, a small patch is required for the Trino +engine to recognize the jar as a plugin, since the generated Trino UDFs implement the `SqlScalarFunction` API, +which is currently not part of Trino's SPI architecture. You can find the patch [here](docs/transport-udfs-trino.patch) and apply it + before deploying your UDFs jar to the Trino engine. 
## Contributing The project is under active development and we welcome contributions of different forms: diff --git a/build.gradle b/build.gradle index eb5d6ea7..fbe649ee 100644 --- a/build.gradle +++ b/build.gradle @@ -14,14 +14,18 @@ buildscript { classpath 'com.github.jengelman.gradle.plugins:shadow:2.0.4' classpath 'org.github.ngbinh.scalastyle:gradle-scalastyle-plugin_2.11:1.0.1' classpath 'gradle.plugin.nl.javadude.gradle.plugins:license-gradle-plugin:0.14.0' + classpath "io.github.gradle-nexus:publish-plugin:1.0.0" + classpath "org.shipkit:shipkit-auto-version:1.1.1" + classpath "org.shipkit:shipkit-changelog:1.1.10" } } plugins { - id "org.shipkit.java" version "2.3.4" id "checkstyle" } +apply from: "gradle/shipkit.gradle" + allprojects { group = 'com.linkedin.transport' apply plugin: 'idea' @@ -74,8 +78,11 @@ subprojects { } checkstyle { - configFile = file("${rootDir}/gradle/checkstyle/checkstyle.xml") - configProperties = ['config_loc' : "${rootDir}/gradle/checkstyle/"] + configFile = rootProject.file('gradle/checkstyle/checkstyle.xml') + configProperties = [ + 'configDir': rootProject.file('gradle/checkstyle'), + 'baseDir': rootDir + ] toolVersion '8.23' } } diff --git a/defaultEnvironment.gradle b/defaultEnvironment.gradle index c6b83602..3480cf21 100644 --- a/defaultEnvironment.gradle +++ b/defaultEnvironment.gradle @@ -10,8 +10,9 @@ subprojects { url "https://conjars.org/repo" } } - project.ext.setProperty('presto-version', '333') - project.ext.setProperty('airlift-slice-version', '0.38') + project.ext.setProperty('trino-version', '352') + project.ext.setProperty('airlift-slice-version', '0.39') project.ext.setProperty('spark-group', 'org.apache.spark') - project.ext.setProperty('spark-version', '2.3.0') + project.ext.setProperty('spark2-version', '2.3.0') + project.ext.setProperty('spark3-version', '3.1.1') } diff --git a/docs/required-trino-apis.md b/docs/required-trino-apis.md new file mode 100644 index 00000000..675c6854 --- /dev/null +++ 
b/docs/required-trino-apis.md @@ -0,0 +1,42 @@ +# Why is modifying the Trino SPI interface necessary for Transport to work? +Transport requires applying this [patch](transport-udfs-trino.patch) before being able to use Transport with Trino. +This patch makes some of the internal UDF classes visible at the SPI layer. +Below we explain why some Transport APIs cannot leverage the APIs offered by the [public SPI UDF model](https://trino.io/docs/current/develop/functions.html). + +## [init() method](https://github.com/linkedin/transport/blob/09a89508296a2491f43cc8866d47952c911313ab/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF.java#L45) is hard to implement on top of Trino-SPI +The `init()` method allows users to perform necessary initializations for their Transport UDFs. +Conceptually, it is called once at the UDF initialization time before processing any records. It sets the [StdFactory](https://github.com/linkedin/transport/blob/d919f96dc1485ccb8b58e4faed3a5589a5966236/transportable-udfs-api/src/main/java/com/linkedin/transport/api/StdFactory.java#L36) to be used by the +`StdUDF`, and can be used to create Java types that correspond to the type signatures provided by the user. +Due to the lack of a similar API in the SPI UDF model, in the current approach, `init()` is called inside the +overridden [specialize()](https://github.com/linkedin/transport/blob/d919f96dc1485ccb8b58e4faed3a5589a5966236/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java#L136) method in [StdUdfWrapper](https://github.com/linkedin/transport/blob/d919f96dc1485ccb8b58e4faed3a5589a5966236/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java#L72) +which extends [SqlScalarFunction](https://github.com/trinodb/trino/blob/54d8154037dfe5f6f65709dbafeb92f5506af2ac/core/trino-main/src/main/java/io/trino/metadata/SqlScalarFunction.java#L18). 
+That way, we can implement the + semantics of `init()`. + +## [TrinoFactory](https://github.com/linkedin/transport/blob/92dfbbfd989367418bdd14f9ac4cc2bcf1e7c777/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoFactory.java#L52) requires `FunctionBinding` and `FunctionDependencies` which are not provided by the Trino-SPI +[TrinoFactory](https://github.com/linkedin/transport/blob/92dfbbfd989367418bdd14f9ac4cc2bcf1e7c777/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoFactory.java#L52) +is designed to convert Transport data types and their required operators (e.g., the equals function of map keys) +to Trino native data types and operators. This serves to implement the + [createStdType()](https://github.com/linkedin/transport/blob/92dfbbfd989367418bdd14f9ac4cc2bcf1e7c777/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoFactory.java#L139) +in [StdFactory](https://github.com/linkedin/transport/blob/d919f96dc1485ccb8b58e4faed3a5589a5966236/transportable-udfs-api/src/main/java/com/linkedin/transport/api/StdFactory.java#L36), which is a standard +API across all engines. +The TrinoFactory implementation of StdFactory requires the Trino classes [FunctionBinding](https://github.com/trinodb/trino/blob/54d8154037dfe5f6f65709dbafeb92f5506af2ac/core/trino-main/src/main/java/io/trino/metadata/FunctionBinding.java#L26) +and [FunctionDependencies](https://github.com/trinodb/trino/blob/0b1a1b9fa036bac132c80c990166096abc1b2552/core/trino-main/src/main/java/io/trino/metadata/FunctionDependencies.java#L47) +to implement its basic functionality; however, those classes are not provided by the Trino SPI UDF model. 
+In the current integration approach, TrinoFactory is initialized inside the overridden [specialize()](https://github.com/linkedin/transport/blob/d919f96dc1485ccb8b58e4faed3a5589a5966236/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java#L136) method +in [StdUdfWrapper](https://github.com/linkedin/transport/blob/d919f96dc1485ccb8b58e4faed3a5589a5966236/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java#L72) +which extends [SqlScalarFunction](https://github.com/trinodb/trino/blob/54d8154037dfe5f6f65709dbafeb92f5506af2ac/core/trino-main/src/main/java/io/trino/metadata/SqlScalarFunction.java#L18), +and gets access to those two classes from there. + +The snippet below shows how the Transport Trino implementation uses the `SqlScalarFunction#specialize()` method +to call `StdUDF#init()` and pass the `FunctionDependencies` and `FunctionBinding` objects to the TrinoFactory. +```java +@Override +public ScalarFunctionImplementation specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { + StdFactory stdFactory = new TrinoFactory(functionBinding, functionDependencies); + StdUDF stdUDF = getStdUDF(); + stdUDF.init(stdFactory); + ... 
+} +``` + diff --git a/docs/transport-udfs-presto.patch b/docs/transport-udfs-trino.patch similarity index 58% rename from docs/transport-udfs-presto.patch rename to docs/transport-udfs-trino.patch index c29b1dd6..5ed54f59 100644 --- a/docs/transport-udfs-presto.patch +++ b/docs/transport-udfs-trino.patch @@ -1,24 +1,24 @@ -diff --git a/presto-main/src/main/java/io/prestosql/server/PluginManager.java b/presto-main/src/main/java/io/prestosql/server/PluginManager.java -index abcd001031..053c17aeed 100644 ---- a/presto-main/src/main/java/io/prestosql/server/PluginManager.java -+++ b/presto-main/src/main/java/io/prestosql/server/PluginManager.java -@@ -23,6 +23,7 @@ import io.prestosql.connector.ConnectorManager; - import io.prestosql.eventlistener.EventListenerManager; - import io.prestosql.execution.resourcegroups.ResourceGroupManager; - import io.prestosql.metadata.MetadataManager; -+import io.prestosql.metadata.SqlScalarFunction; - import io.prestosql.security.AccessControlManager; - import io.prestosql.security.GroupProviderManager; - import io.prestosql.server.security.PasswordAuthenticatorManager; -@@ -54,6 +55,7 @@ import java.util.ServiceLoader; +diff --git a/core/trino-main/src/main/java/io/trino/server/PluginManager.java b/core/trino-main/src/main/java/io/trino/server/PluginManager.java +index 76cc04ca9d..483e609c86 100644 +--- a/core/trino-main/src/main/java/io/trino/server/PluginManager.java ++++ b/core/trino-main/src/main/java/io/trino/server/PluginManager.java +@@ -23,6 +23,7 @@ import io.trino.connector.ConnectorManager; + import io.trino.eventlistener.EventListenerManager; + import io.trino.execution.resourcegroups.ResourceGroupManager; + import io.trino.metadata.MetadataManager; ++import io.trino.metadata.SqlScalarFunction; + import io.trino.security.AccessControlManager; + import io.trino.security.GroupProviderManager; + import io.trino.server.security.CertificateAuthenticatorManager; +@@ -55,6 +56,7 @@ import java.util.ServiceLoader; import 
java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; +import java.util.stream.Collectors; import static com.google.common.base.Preconditions.checkState; - import static io.prestosql.metadata.FunctionExtractor.extractFunctions; -@@ -64,8 +66,22 @@ import static java.util.Objects.requireNonNull; + import static io.trino.metadata.FunctionExtractor.extractFunctions; +@@ -65,8 +67,27 @@ import static java.util.Objects.requireNonNull; @ThreadSafe public class PluginManager { @@ -29,19 +29,24 @@ index abcd001031..053c17aeed 100644 + // as it is the case with vanilla plugins. + // JIRA: https://jira01.corp.linkedin.com:8443/browse/LIHADOOP-34269 private static final ImmutableList SPI_PACKAGES = ImmutableList.builder() -+ // io.prestosql.metadata is required for SqlScalarFunction and FunctionRegistry classes -+ .add("io.prestosql.metadata.") -+ // io.prestosql.operator. is required for ScalarFunctionImplementation and TypeSignatureParser -+ .add("io.prestosql.operator.") - .add("io.prestosql.spi.") -+ // io.prestosql.type is required for TypeManager, and all supported types -+ .add("io.prestosql.type.") -+ // io.prestosql.util is required for Reflection -+ .add("io.prestosql.util.") ++ // io.trino.metadata is required for SqlScalarFunction, Metadata, MetadataManager, FunctionBinding, ++ // FunctionDependencies, TypeVariableConstraint, FunctionArgumentDefinition, FunctionKind, FunctionMetadata, ++ // Signature and SignatureBinder classes ++ .add("io.trino.metadata.") ++ // io.trino.operator. is required for AbstractTestFunctions, ScalarFunctionImplementation ++ // & ChoicesScalarFunctionImplementation ++ .add("io.trino.operator.") ++ // io.trino.sql.analyzer.TypeSignatureTranslator. 
is required for parseTypeSignature ++ .add("io.trino.sql.analyzer.TypeSignatureTranslator.") + .add("io.trino.spi.") ++ // io.trino.type is required for UnknownType ++ .add("io.trino.type.") ++ // io.trino.util is required for Reflection ++ .add("io.trino.util.") .add("com.fasterxml.jackson.annotation.") .add("io.airlift.slice.") .add("org.openjdk.jol.") -@@ -159,11 +175,22 @@ public class PluginManager +@@ -163,11 +184,26 @@ public class PluginManager { ServiceLoader serviceLoader = ServiceLoader.load(Plugin.class, pluginClassLoader); List plugins = ImmutableList.copyOf(serviceLoader); diff --git a/docs/using-transport-udfs.md b/docs/using-transport-udfs.md index 9b4e5f97..687ece14 100644 --- a/docs/using-transport-udfs.md +++ b/docs/using-transport-udfs.md @@ -11,19 +11,19 @@ The Transport framework automatically generates UDF artifacts for each supported - [Using the UDF artifacts](#using-the-udf-artifacts) - [Hive](#hive) - [Spark](#spark) - - [Presto](#presto) + - [Trino](#trino) ## Identifying platform-specific UDF artifacts ### Platform-specific artifact file -As mentioned above, the Transport Plugin will automatically generate artifacts for each platform. Once these artifacts are published to a ivy repository, you can consume them using the corresponding ivy coordinates using the platform name as a maven classifier. E.g. if the UDF has an ivy coordinate `com.linkedin.transport-example:example-udf:1.0.0`, then the coordinate for the platform-specific UDF would be `com.linkedin.transport-example:example-udf:1.0.0?classifier=PLATFORM-NAME` where `PLATFORM-NAME` is `hive`, `presto` or `spark`. +As mentioned above, the Transport Plugin will automatically generate artifacts for each platform. Once these artifacts are published to a ivy repository, you can consume them using the corresponding ivy coordinates using the platform name as a maven classifier. E.g. 
if the UDF has an ivy coordinate `com.linkedin.transport-example:example-udf:1.0.0`, then the coordinate for the platform-specific UDF would be `com.linkedin.transport-example:example-udf:1.0.0?classifier=PLATFORM-NAME` where `PLATFORM-NAME` is `hive`, `trino` or `spark`. -If you are building the UDF project locally, the platform-specific artifacts are built alongside the UDF artifact in the output directory with the platform name as a file suffix. If the built UDF is located at `/path/to/example-udf.ext` then the platform-specific artifact is located at `/path/to/example-udf-PLATFORM-NAME.ext` where `PLATFORM-NAME` is `hive`, `presto` or `spark`. +If you are building the UDF project locally, the platform-specific artifacts are built alongside the UDF artifact in the output directory with the platform name as a file suffix. If the built UDF is located at `/path/to/example-udf.ext` then the platform-specific artifact is located at `/path/to/example-udf-PLATFORM-NAME.ext` where `PLATFORM-NAME` is `hive`, `trino` or `spark`. ### Platform-specific UDF class -If the UDF class is `com.linkedin.transport.example.ExampleUDF` then the platform-specific UDF class will be `com.linkedin.transport.example.PLATFORM-NAME.ExampleUDF` where `PLATFORM-NAME` is `hive`, `presto` or `spark`. +If the UDF class is `com.linkedin.transport.example.ExampleUDF` then the platform-specific UDF class will be `com.linkedin.transport.example.PLATFORM-NAME.ExampleUDF` where `PLATFORM-NAME` is `hive`, `trino` or `spark`. ## Using the UDF artifacts @@ -80,16 +80,16 @@ If the UDF class is `com.linkedin.transport.example.ExampleUDF` then the platfor ) ``` -### Presto +### Trino -1. Add the UDF to the Presto installation -Unlike Hive and Spark, Presto currently does not allow dynamically loading jar files once the Presto server has started. -In Presto, the jar is deployed to the `plugin` directory. 
-However, a small patch is required for the Presto engine to recognize the jar as a plugin, since the generated Presto UDFs implement the `SqlScalarFunction` API, which is currently not part of Presto's SPI architecture. -You can find the patch [here](transport-udfs-presto.patch) and apply it before deploying your UDFs jar to the Presto engine. +1. Add the UDF to the Trino installation +Unlike Hive and Spark, Trino currently does not allow dynamically loading jar files once the Trino server has started. +In Trino, the jar is deployed to the `plugin` directory. +However, a small patch is required for the Trino engine to recognize the jar as a plugin, since the generated Trino UDFs implement the `SqlScalarFunction` API, which is currently not part of Trino's SPI architecture. +You can find the patch [here](transport-udfs-trino.patch) and apply it before deploying your UDFs jar to the Trino engine ([Why is this patch needed?](required-trino-apis.md)). 2. Call the UDF in a query To call the UDF, you will need to use the function name defined in the Transport UDF definition. ``` - presto-cli> SELECT example_udf(some_column, 'some_constant'); + trino-cli> SELECT example_udf(some_column, 'some_constant'); ``` diff --git a/gradle/checkstyle/checkstyle.xml b/gradle/checkstyle/checkstyle.xml index a205c87e..5260d332 100644 --- a/gradle/checkstyle/checkstyle.xml +++ b/gradle/checkstyle/checkstyle.xml @@ -191,7 +191,8 @@ LinkedIn Java style. Before uncommenting this please read the "Suppression File" section of http://go/checkstyle to prevent error events in IntelliJ IDEA. 
--> - + + diff --git a/gradle/java-publication.gradle b/gradle/java-publication.gradle new file mode 100644 index 00000000..ae68d9ac --- /dev/null +++ b/gradle/java-publication.gradle @@ -0,0 +1,84 @@ +def licenseSpec = copySpec { + from project.rootDir + include "LICENSE" +} + +task sourcesJar(type: Jar, dependsOn: classes) { + classifier 'sources' + from sourceSets.main.allSource + with licenseSpec +} + +task javadocJar(type: Jar, dependsOn: javadoc) { + classifier 'javadoc' + from tasks.javadoc + with licenseSpec +} + +jar { + with licenseSpec +} + +artifacts { + archives sourcesJar + archives javadocJar +} + +apply plugin: "maven-publish" //https://docs.gradle.org/current/userguide/publishing_maven.html +publishing { + publications { + javaLibrary(MavenPublication) { + from components.java + artifact sourcesJar + artifact javadocJar + + artifactId = project.archivesBaseName + + pom { + name = artifactId + description = "A framework for writing performant user-defined functions (UDFs) that are portable across a variety of engines including Apache Spark, Apache Hive, and Trino" + + url = "https://github.com/linkedin/transport" + licenses { + license { + name = 'BSD 2-CLAUSE LICENSE' + url = 'https://github.com/linkedin/transport/blob/master/LICENSE' + distribution = 'repo' + } + } + developers { + developer { + id = 'wmoustafa' + name = 'Walaa Eldin Moustafa' + } + developer { + id = 'shardulm94' + name = 'Shardul Mahadik' + } + } + scm { + url = 'https://github.com/linkedin/transport.git' + } + issueManagement { + url = 'https://github.com/linkedin/transport/issues' + system = 'GitHub issues' + } + ciManagement { + url = 'https://github.com/linkedin/transport/actions' + system = 'GitHub Actions' + } + } + } + } + + //useful for testing - running "publish" will create artifacts/pom in a local dir + repositories { maven { url = "$rootProject.buildDir/repo" } } +} + +apply plugin: 'signing' //https://docs.gradle.org/current/userguide/signing_plugin.html +signing { + if 
(System.getenv("PGP_KEY")) { + useInMemoryPgpKeys(System.getenv("PGP_KEY"), System.getenv("PGP_PWD")) + sign publishing.publications.javaLibrary + } +} \ No newline at end of file diff --git a/gradle/shipkit.gradle b/gradle/shipkit.gradle index a5d979d8..6f3e828d 100644 --- a/gradle/shipkit.gradle +++ b/gradle/shipkit.gradle @@ -1,34 +1,38 @@ -shipkit { - gitHub.repository = "linkedin/transport" +//Plugin jars are added to the buildscript classpath in the root build.gradle file +apply plugin: "org.shipkit.shipkit-auto-version" //https://github.com/shipkit/shipkit-auto-version - gitHub.readOnlyAuthToken = "361a43a2b351e61e2243c5ea15792f33a3c9b467" - - // The GitHub write token is required for committing release notes and bumping up project version - // Ensure that the release machine or Travis CI has this env variable exported - gitHub.writeAuthToken = System.getenv("GH_WRITE_TOKEN") - - git.releasableBranchRegex = "master|release/.+" +apply plugin: "org.shipkit.shipkit-changelog" //https://github.com/shipkit/shipkit-changelog +tasks.named("generateChangelog") { + previousRevision = project.ext.'shipkit-auto-version.previous-tag' + githubToken = System.getenv("GITHUB_TOKEN") + repository = "linkedin/transport" } -allprojects { - plugins.withId("org.shipkit.bintray") { - - //Bintray configuration is handled by JFrog Bintray Gradle Plugin - //For reference see the official documentation: https://github.com/bintray/gradle-bintray-plugin - bintray { - - // The Bintray API token is required to publish artifacts to Bintray - // Ensure that the release machine or Travis CI has this env variable exported - key = System.getenv("BINTRAY_API_KEY") +apply plugin: "org.shipkit.shipkit-github-release" //https://github.com/shipkit/shipkit-changelog +tasks.named("githubRelease") { + def genTask = tasks.named("generateChangelog").get() + dependsOn genTask + repository = genTask.repository + changelog = genTask.outputFile + githubToken = System.getenv("GITHUB_TOKEN") + newTagRevision 
= System.getenv("GITHUB_SHA") +} - pkg { - repo = 'maven' - user = 'smahadik' - userOrg = 'linkedin-transport' - name = 'transport' - licenses = ['BSD 2-Clause'] - labels = ['transport', 'UDF', 'user defined functions', 'portable'] +apply plugin: "io.github.gradle-nexus.publish-plugin" //https://github.com/gradle-nexus/publish-plugin/ +nexusPublishing { + repositories { + if (System.getenv("SONATYPE_PWD")) { + sonatype { + username = System.getenv("SONATYPE_USER") + password = System.getenv("SONATYPE_PWD") } } } } + +// we need to exclude the plugin module for its specific gradle configuration +configure(allprojects - project(':transportable-udfs-plugin')) { p -> + plugins.withId('java') { + p.apply from: "$rootDir/gradle/java-publication.gradle" + } +} diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 75b8c7c8..14e30f74 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-5.0-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-6.7-all.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/settings.gradle b/settings.gradle index 65480879..a775c86e 100644 --- a/settings.gradle +++ b/settings.gradle @@ -12,14 +12,16 @@ def modules = [ 'transportable-udfs-compile-utils', 'transportable-udfs-hive', 'transportable-udfs-plugin', - 'transportable-udfs-presto', - 'transportable-udfs-spark', + 'transportable-udfs-spark_2.11', + 'transportable-udfs-spark_2.12', + 'transportable-udfs-trino', 'transportable-udfs-test:transportable-udfs-test-api', 'transportable-udfs-test:transportable-udfs-test-generic', 'transportable-udfs-test:transportable-udfs-test-hive', - 'transportable-udfs-test:transportable-udfs-test-presto', - 'transportable-udfs-test:transportable-udfs-test-spark', + 
'transportable-udfs-test:transportable-udfs-test-spark_2.11', + 'transportable-udfs-test:transportable-udfs-test-spark_2.12', 'transportable-udfs-test:transportable-udfs-test-spi', + 'transportable-udfs-test:transportable-udfs-test-trino', 'transportable-udfs-type-system', 'transportable-udfs-utils' ] diff --git a/transportable-udfs-annotation-processor/src/main/java/com/linkedin/transport/processor/TransportProcessor.java b/transportable-udfs-annotation-processor/src/main/java/com/linkedin/transport/processor/TransportProcessor.java index 12e25002..30a110a8 100644 --- a/transportable-udfs-annotation-processor/src/main/java/com/linkedin/transport/processor/TransportProcessor.java +++ b/transportable-udfs-annotation-processor/src/main/java/com/linkedin/transport/processor/TransportProcessor.java @@ -130,10 +130,19 @@ private void processUDFClass(TypeElement udfClassElement) { udfClassElement ); } else { - String topLevelStdUdfClassName = - elementsOverridingTopLevelStdUDFMethods.iterator().next().getQualifiedName().toString(); + TypeElement topLevelStdUdfTypeElement = elementsOverridingTopLevelStdUDFMethods.iterator().next(); + String topLevelStdUdfClassName = topLevelStdUdfTypeElement.getQualifiedName().toString(); debug(String.format("TopLevelStdUDF class found: %s", topLevelStdUdfClassName)); + String udfClassName = udfClassElement.getQualifiedName().toString(); _transportUdfMetadata.addUDF(topLevelStdUdfClassName, udfClassElement.getQualifiedName().toString()); + _transportUdfMetadata.setClassNumberOfTypeParameters( + topLevelStdUdfClassName, + topLevelStdUdfTypeElement.getTypeParameters().size() + ); + _transportUdfMetadata.setClassNumberOfTypeParameters( + udfClassName, + udfClassElement.getTypeParameters().size() + ); } } diff --git a/transportable-udfs-annotation-processor/src/test/java/com/linkedin/transport/processor/TransportProcessorTest.java 
b/transportable-udfs-annotation-processor/src/test/java/com/linkedin/transport/processor/TransportProcessorTest.java index 0322f306..dd3c25c5 100644 --- a/transportable-udfs-annotation-processor/src/test/java/com/linkedin/transport/processor/TransportProcessorTest.java +++ b/transportable-udfs-annotation-processor/src/test/java/com/linkedin/transport/processor/TransportProcessorTest.java @@ -81,7 +81,7 @@ public void shouldNotContainMultipleOverridingsOfTopLevelStdUDFMethods1() throws .withErrorCount(1) .withErrorContaining(Constants.MORE_THAN_ONE_TYPE_OVERRIDING_ERROR) .in(forResource("udfs/UDFWithMultipleInterfaces1.java")) - .onLine(14) + .onLine(13) .atColumn(8); } @@ -96,7 +96,7 @@ public void shouldNotContainMultipleOverridingsOfTopLevelStdUDFMethods2() throws .withErrorCount(1) .withErrorContaining(Constants.MORE_THAN_ONE_TYPE_OVERRIDING_ERROR) .in(forResource("udfs/UDFWithMultipleInterfaces2.java")) - .onLine(13) + .onLine(12) .atColumn(8); } @@ -110,7 +110,7 @@ public void udfShouldNotOverrideInterfaceMethods() throws IOException { .withErrorCount(1) .withErrorContaining(Constants.MORE_THAN_ONE_TYPE_OVERRIDING_ERROR) .in(forResource("udfs/UDFOverridingInterfaceMethod.java")) - .onLine(14) + .onLine(13) .atColumn(8); } @@ -123,7 +123,7 @@ public void udfShouldImplementTopLevelStdUDF() throws IOException { .withErrorCount(1) .withErrorContaining(Constants.INTERFACE_NOT_IMPLEMENTED_ERROR) .in(forResource("udfs/UDFNotImplementingTopLevelStdUDF.java")) - .onLine(14) + .onLine(13) .atColumn(8); } diff --git a/transportable-udfs-annotation-processor/src/test/resources/outputs/empty.json b/transportable-udfs-annotation-processor/src/test/resources/outputs/empty.json index 7836d3b6..4e9cfcb3 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/outputs/empty.json +++ b/transportable-udfs-annotation-processor/src/test/resources/outputs/empty.json @@ -1,3 +1,4 @@ { - "udfs": [] + "udfs": {}, + "classToNumberOfTypeParameters": {} } \ No newline at 
end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/outputs/overloadedUDF.json b/transportable-udfs-annotation-processor/src/test/resources/outputs/overloadedUDF.json index 9f6a2450..4f0526f0 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/outputs/overloadedUDF.json +++ b/transportable-udfs-annotation-processor/src/test/resources/outputs/overloadedUDF.json @@ -1,11 +1,13 @@ { - "udfs": [ - { - "topLevelClass": "udfs.OverloadedUDF1", - "stdUDFImplementations": [ - "udfs.OverloadedUDFInt", - "udfs.OverloadedUDFString" - ] - } - ] + "udfs": { + "udfs.OverloadedUDF1": [ + "udfs.OverloadedUDFInt", + "udfs.OverloadedUDFString" + ] + }, + "classToNumberOfTypeParameters": { + "udfs.OverloadedUDFString": 0, + "udfs.OverloadedUDF1": 0, + "udfs.OverloadedUDFInt": 0 + } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/outputs/simpleUDF.json b/transportable-udfs-annotation-processor/src/test/resources/outputs/simpleUDF.json index 9323ddbd..34c7cee2 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/outputs/simpleUDF.json +++ b/transportable-udfs-annotation-processor/src/test/resources/outputs/simpleUDF.json @@ -1,10 +1,10 @@ { - "udfs": [ - { - "topLevelClass": "udfs.SimpleUDF", - "stdUDFImplementations": [ - "udfs.SimpleUDF" - ] - } - ] + "udfs": { + "udfs.SimpleUDF": [ + "udfs.SimpleUDF" + ] + }, + "classToNumberOfTypeParameters": { + "udfs.SimpleUDF": 0 + } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/outputs/udfExtendingAbstractUDF.json b/transportable-udfs-annotation-processor/src/test/resources/outputs/udfExtendingAbstractUDF.json index ab58d4d8..5b72a274 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/outputs/udfExtendingAbstractUDF.json +++ b/transportable-udfs-annotation-processor/src/test/resources/outputs/udfExtendingAbstractUDF.json @@ -1,10 +1,10 @@ { - 
"udfs": [ - { - "topLevelClass": "udfs.UDFExtendingAbstractUDF", - "stdUDFImplementations": [ - "udfs.UDFExtendingAbstractUDF" - ] - } - ] + "udfs": { + "udfs.UDFExtendingAbstractUDF": [ + "udfs.UDFExtendingAbstractUDF" + ] + }, + "classToNumberOfTypeParameters": { + "udfs.UDFExtendingAbstractUDF": 0 + } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/outputs/udfExtendingAbstractUDFImplementingInterface.json b/transportable-udfs-annotation-processor/src/test/resources/outputs/udfExtendingAbstractUDFImplementingInterface.json index d2551a77..b75531e1 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/outputs/udfExtendingAbstractUDFImplementingInterface.json +++ b/transportable-udfs-annotation-processor/src/test/resources/outputs/udfExtendingAbstractUDFImplementingInterface.json @@ -1,10 +1,11 @@ { - "udfs": [ - { - "topLevelClass": "udfs.AbstractUDFImplementingInterface", - "stdUDFImplementations": [ - "udfs.UDFExtendingAbstractUDFImplementingInterface" - ] - } - ] + "udfs": { + "udfs.AbstractUDFImplementingInterface": [ + "udfs.UDFExtendingAbstractUDFImplementingInterface" + ] + }, + "classToNumberOfTypeParameters": { + "udfs.UDFExtendingAbstractUDFImplementingInterface": 0, + "udfs.AbstractUDFImplementingInterface": 0 + } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/AbstractUDF.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/AbstractUDF.java index 4a482115..06536aa2 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/AbstractUDF.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/AbstractUDF.java @@ -5,11 +5,10 @@ */ package udfs; -import com.linkedin.transport.api.data.StdString; import com.linkedin.transport.api.udf.StdUDF0; import com.linkedin.transport.api.udf.TopLevelStdUDF; -public abstract class AbstractUDF extends StdUDF0 implements TopLevelStdUDF { 
+public abstract class AbstractUDF extends StdUDF0 implements TopLevelStdUDF { } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/AbstractUDFImplementingInterface.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/AbstractUDFImplementingInterface.java index 4078d7bc..7d85fb36 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/AbstractUDFImplementingInterface.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/AbstractUDFImplementingInterface.java @@ -5,12 +5,11 @@ */ package udfs; -import com.linkedin.transport.api.data.StdString; import com.linkedin.transport.api.udf.StdUDF0; import com.linkedin.transport.api.udf.TopLevelStdUDF; -public abstract class AbstractUDFImplementingInterface extends StdUDF0 implements TopLevelStdUDF { +public abstract class AbstractUDFImplementingInterface extends StdUDF0 implements TopLevelStdUDF { @Override public String getFunctionName() { diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/OuterClassForInnerUDF.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/OuterClassForInnerUDF.java index d5d2551d..a84ee746 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/OuterClassForInnerUDF.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/OuterClassForInnerUDF.java @@ -6,14 +6,13 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdString; import com.linkedin.transport.api.udf.StdUDF0; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; public class OuterClassForInnerUDF { - public class InnerUDF extends StdUDF0 implements TopLevelStdUDF { + public class InnerUDF extends StdUDF0 implements TopLevelStdUDF { @Override public String getFunctionName() { @@ -36,7 +35,7 @@ public String getOutputParameterSignature() { } @Override - public StdString 
eval() { + public String eval() { return null; } } diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/OverloadedUDFInt.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/OverloadedUDFInt.java index 3f130d9d..292c8606 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/OverloadedUDFInt.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/OverloadedUDFInt.java @@ -6,12 +6,11 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdInteger; import com.linkedin.transport.api.udf.StdUDF0; import java.util.List; -public class OverloadedUDFInt extends StdUDF0 implements OverloadedUDF1 { +public class OverloadedUDFInt extends StdUDF0 implements OverloadedUDF1 { @Override public List getInputParameterSignatures() { @@ -24,7 +23,7 @@ public String getOutputParameterSignature() { } @Override - public StdInteger eval() { + public Integer eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/OverloadedUDFString.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/OverloadedUDFString.java index 9782d683..d0855f55 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/OverloadedUDFString.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/OverloadedUDFString.java @@ -6,12 +6,11 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdString; import com.linkedin.transport.api.udf.StdUDF0; import java.util.List; -public class OverloadedUDFString extends StdUDF0 implements OverloadedUDF1 { +public class OverloadedUDFString extends StdUDF0 implements OverloadedUDF1 { @Override public List getInputParameterSignatures() { @@ -24,7 +23,7 @@ public String getOutputParameterSignature() { } @Override - public StdString eval() { + public String eval() { return 
null; } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/SimpleUDF.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/SimpleUDF.java index 15e749ec..46231c63 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/SimpleUDF.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/SimpleUDF.java @@ -6,13 +6,12 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdString; import com.linkedin.transport.api.udf.StdUDF0; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public class SimpleUDF extends StdUDF0 implements TopLevelStdUDF { +public class SimpleUDF extends StdUDF0 implements TopLevelStdUDF { @Override public String getFunctionName() { @@ -35,7 +34,7 @@ public String getOutputParameterSignature() { } @Override - public StdString eval() { + public String eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFExtendingAbstractUDF.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFExtendingAbstractUDF.java index 99fb068c..564fcd7e 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFExtendingAbstractUDF.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFExtendingAbstractUDF.java @@ -6,7 +6,6 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdString; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; @@ -34,7 +33,7 @@ public String getOutputParameterSignature() { } @Override - public StdString eval() { + public String eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFExtendingAbstractUDFImplementingInterface.java 
b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFExtendingAbstractUDFImplementingInterface.java index 2fc6d5ce..db8e3bc7 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFExtendingAbstractUDFImplementingInterface.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFExtendingAbstractUDFImplementingInterface.java @@ -6,7 +6,6 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdString; import java.util.List; @@ -23,7 +22,7 @@ public String getOutputParameterSignature() { } @Override - public StdString eval() { + public String eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFNotImplementingTopLevelStdUDF.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFNotImplementingTopLevelStdUDF.java index 43e8ffa9..862403e4 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFNotImplementingTopLevelStdUDF.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFNotImplementingTopLevelStdUDF.java @@ -6,12 +6,11 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdString; import com.linkedin.transport.api.udf.StdUDF0; import java.util.List; -public class UDFNotImplementingTopLevelStdUDF extends StdUDF0 { +public class UDFNotImplementingTopLevelStdUDF extends StdUDF0 { @Override public List getInputParameterSignatures() { @@ -24,7 +23,7 @@ public String getOutputParameterSignature() { } @Override - public StdString eval() { + public String eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFOverridingInterfaceMethod.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFOverridingInterfaceMethod.java index 346ff553..2d97547e 100644 --- 
a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFOverridingInterfaceMethod.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFOverridingInterfaceMethod.java @@ -6,12 +6,11 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdBoolean; import com.linkedin.transport.api.udf.StdUDF0; import java.util.List; -public class UDFOverridingInterfaceMethod extends StdUDF0 implements OverloadedUDF1 { +public class UDFOverridingInterfaceMethod extends StdUDF0 implements OverloadedUDF1 { @Override public String getFunctionName() { @@ -29,7 +28,7 @@ public String getOutputParameterSignature() { } @Override - public StdBoolean eval() { + public Boolean eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFWithMultipleInterfaces1.java b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFWithMultipleInterfaces1.java index 54e57c16..83a38c4d 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFWithMultipleInterfaces1.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFWithMultipleInterfaces1.java @@ -6,12 +6,11 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdBoolean; import com.linkedin.transport.api.udf.StdUDF0; import java.util.List; -public class UDFWithMultipleInterfaces1 extends StdUDF0 implements OverloadedUDF1, OverloadedUDF2 { +public class UDFWithMultipleInterfaces1 extends StdUDF0 implements OverloadedUDF1, OverloadedUDF2 { @Override public String getFunctionName() { @@ -34,7 +33,7 @@ public String getOutputParameterSignature() { } @Override - public StdBoolean eval() { + public Boolean eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFWithMultipleInterfaces2.java 
b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFWithMultipleInterfaces2.java index e8e62bb1..f2fbe270 100644 --- a/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFWithMultipleInterfaces2.java +++ b/transportable-udfs-annotation-processor/src/test/resources/udfs/UDFWithMultipleInterfaces2.java @@ -6,7 +6,6 @@ package udfs; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdString; import java.util.List; @@ -33,7 +32,7 @@ public String getOutputParameterSignature() { } @Override - public StdString eval() { + public String eval() { return null; } } \ No newline at end of file diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/StdFactory.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/StdFactory.java index 3e28b64a..76b9e239 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/StdFactory.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/StdFactory.java @@ -5,166 +5,101 @@ */ package com.linkedin.transport.api; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdBoolean; -import com.linkedin.transport.api.data.StdBinary; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdDouble; -import com.linkedin.transport.api.data.StdFloat; -import com.linkedin.transport.api.data.StdInteger; -import com.linkedin.transport.api.data.StdLong; -import com.linkedin.transport.api.data.StdMap; -import com.linkedin.transport.api.data.StdString; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.api.types.StdArrayType; import com.linkedin.transport.api.types.StdMapType; -import com.linkedin.transport.api.types.StdStructType; import 
com.linkedin.transport.api.types.StdType; import com.linkedin.transport.api.udf.StdUDF; import java.io.Serializable; -import java.nio.ByteBuffer; import java.util.List; /** - * {@link StdFactory} is used to create {@link StdData} and {@link StdType} objects inside Standard UDFs. + * {@link StdFactory} is used to create container types (e.g., {@link ArrayData}, {@link MapData}, {@link RowData}) + * and {@link StdType} objects inside Standard UDFs. * - * Specific APIs of {@link StdFactory} are implemented by each target platform (e.g., Spark, Presto, Hive) individually. + * Specific APIs of {@link StdFactory} are implemented by each target platform (e.g., Spark, Trino, Hive) individually. * A {@link StdFactory} object is available inside Standard UDFs using {@link StdUDF#getStdFactory()}. * The Standard UDF framework is responsible for providing the correct platform specific implementation at runtime. */ public interface StdFactory extends Serializable { /** - * Creates a {@link StdInteger} representing a given integer value. - * - * @param value the input integer value - * @return {@link StdInteger} with the given integer value - */ - StdInteger createInteger(int value); - - /** - * Creates a {@link StdLong} representing a given long value. - * - * @param value the input long value - * @return {@link StdLong} with the given long value - */ - StdLong createLong(long value); - - /** - * Creates a {@link StdBoolean} representing a given boolean value. - * - * @param value the input boolean value - * @return {@link StdBoolean} with the given boolean value - */ - StdBoolean createBoolean(boolean value); - - /** - * Creates a {@link StdString} representing a given {@link String} value. - * - * @param value the input {@link String} value - * @return {@link StdString} with the given {@link String} value - */ - StdString createString(String value); - - /** - * Creates a {@link StdFloat} representing a given float value.
- * - * @param value the input float value - * @return {@link StdFloat} with the given float value - */ - StdFloat createFloat(float value); - - /** - * Creates a {@link StdDouble} representing a given double value. - * - * @param value the input double value - * @return {@link StdDouble} with the given double value - */ - StdDouble createDouble(double value); - - /** - * Creates a {@link StdBinary} representing a given {@link ByteBuffer} value. - * - * @param value the input {@link ByteBuffer} value - * @return {@link StdBinary} with the given {@link ByteBuffer} value - */ - StdBinary createBinary(ByteBuffer value); - - /** - * Creates an empty {@link StdArray} whose type is given by the given {@link StdType}. + * Creates an empty {@link ArrayData} whose type is given by the given {@link StdType}. * * It is expected that the top-level {@link StdType} is a {@link StdArrayType}. * * @param stdType type of the array to be created * @param expectedSize expected number of entries in the array - * @return an empty {@link StdArray} + * @return an empty {@link ArrayData} */ - StdArray createArray(StdType stdType, int expectedSize); + ArrayData createArray(StdType stdType, int expectedSize); /** - * Creates an empty {@link StdArray} whose type is given by the given {@link StdType}. + * Creates an empty {@link ArrayData} whose type is given by the given {@link StdType}. * * It is expected that the top-level {@link StdType} is a {@link StdArrayType}. * * @param stdType type of the array to be created - * @return an empty {@link StdArray} + * @return an empty {@link ArrayData} */ - StdArray createArray(StdType stdType); + ArrayData createArray(StdType stdType); /** - * Creates an empty {@link StdMap} whose type is given by the given {@link StdType}. + * Creates an empty {@link MapData} whose type is given by the given {@link StdType}. * * It is expected that the top-level {@link StdType} is a {@link StdMapType}. 
* * @param stdType type of the map to be created - * @return an empty {@link StdMap} + * @return an empty {@link MapData} */ - StdMap createMap(StdType stdType); + MapData createMap(StdType stdType); /** - * Creates a {@link StdStruct} with the given field names and types. + * Creates a {@link RowData} with the given field names and types. * * @param fieldNames names of the struct fields * @param fieldTypes types of the struct fields - * @return a {@link StdStruct} with all fields initialized to null + * @return a {@link RowData} with all fields initialized to null */ - StdStruct createStruct(List fieldNames, List fieldTypes); + RowData createStruct(List fieldNames, List fieldTypes); /** - * Creates a {@link StdStruct} with the given field types. Field names will be field0, field1, field2... + * Creates a {@link RowData} with the given field types. Field names will be field0, field1, field2... * * @param fieldTypes types of the struct fields - * @return a {@link StdStruct} with all fields initialized to null + * @return a {@link RowData} with all fields initialized to null */ - StdStruct createStruct(List fieldTypes); + RowData createStruct(List fieldTypes); /** - * Creates a {@link StdStruct} whose type is given by the given {@link StdType}. + * Creates a {@link RowData} whose type is given by the given {@link StdType}. * - * It is expected that the top-level {@link StdType} is a {@link StdStructType}. + * It is expected that the top-level {@link StdType} is a {@link com.linkedin.transport.api.types.RowType}. * * @param stdType type of the struct to be created - * @return a {@link StdStruct} with all fields initialized to null + * @return a {@link RowData} with all fields initialized to null */ - StdStruct createStruct(StdType stdType); + RowData createStruct(StdType stdType); /** * Creates a {@link StdType} representing the given type signature. * * The following are considered valid type signatures: *
    - *
  • {@code "varchar"} - Represents SQL varchar type. Corresponding standard type is {@link StdString}
  • - *
  • {@code "integer"} - Represents SQL int type. Corresponding standard type is {@link StdInteger}
  • - *
  • {@code "bigint"} - Represents SQL bigint/long type. Corresponding standard type is {@link StdLong}
  • - *
  • {@code "boolean"} - Represents SQL boolean type. Corresponding standard type is {@link StdBoolean}
  • + *
  • {@code "varchar"} - Represents SQL varchar type. Corresponding Transport type is {@link String}
  • + *
  • {@code "integer"} - Represents SQL int type. Corresponding Transport type is {@link Integer}
  • + *
  • {@code "bigint"} - Represents SQL bigint/long type. Corresponding Transport type is {@link Long}
  • + *
  • {@code "boolean"} - Represents SQL boolean type. Corresponding Transport type is {@link Boolean}
  • *
  • {@code "array(T)"} - Represents SQL array type, where {@code T} is type signature of array element. - * Corresponding standard type is {@link StdArray}
  • + * Corresponding Transport type is {@link ArrayData} *
  • {@code "map(K,V)"} - Represents SQL map type, where {@code K} and {@code V} are type signatures of the map - * keys and values respectively. array element. Corresponding standard type is {@link StdMap}
  • + * keys and values respectively. Corresponding Transport type is {@link MapData} *
  • {@code "row(f0 T0, f1 T1,... fn Tn)"} - Represents SQL struct type, where {@code f0}...{@code fn} are field * names and {@code T0}...{@code Tn} are type signatures for the fields. Field names are optional; if not - * specified they default to {@code field0}...{@code fieldn}. Corresponding standard type is {@link StdStruct}
  • + * specified they default to {@code field0}...{@code fieldn}. Corresponding Transport type is {@link RowData} *
* * Generic type parameters can also be used as part of the type signatures; e.g., The type signature {@code "map(K,V)"} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdArray.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/ArrayData.java similarity index 76% rename from transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdArray.java rename to transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/ArrayData.java index ac698ae8..65a6b9a6 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdArray.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/ArrayData.java @@ -5,8 +5,8 @@ */ package com.linkedin.transport.api.data; -/** A Standard UDF data type for representing arrays. */ -public interface StdArray extends StdData, Iterable { +/** A Transport UDF data type for representing arrays. */ +public interface ArrayData extends Iterable { /** Returns the number of elements in the array. */ int size(); @@ -16,12 +16,12 @@ public interface StdArray extends StdData, Iterable { * * @param idx the index of the element to be retrieved */ - StdData get(int idx); + E get(int idx); /** * Adds an element to the end of the array. 
* * @param e the element to append to the array */ - void add(StdData e); + void add(E e); } diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdMap.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/MapData.java similarity index 78% rename from transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdMap.java rename to transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/MapData.java index 8e67500e..39bd6965 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdMap.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/MapData.java @@ -9,8 +9,8 @@ import java.util.Set; -/** A Standard UDF data type for representing maps. */ -public interface StdMap extends StdData { +/** A Transport UDF data type for representing maps. */ +public interface MapData { /** Returns the number of key-value pairs in the map. */ int size(); @@ -20,7 +20,7 @@ public interface StdMap extends StdData { * * @param key the key whose value is to be returned */ - StdData get(StdData key); + V get(K key); /** * Adds the given value to the map against the given key. @@ -28,18 +28,18 @@ public interface StdMap extends StdData { * @param key the key to which the value is to be associated * @param value the value to be associated with the key */ - void put(StdData key, StdData value); + void put(K key, V value); /** Returns a {@link Set} of all the keys in the map. */ - Set keySet(); + Set keySet(); /** Returns a {@link Collection} of all the values in the map. */ - Collection values(); + Collection values(); /** * Returns true if the map contains the given key, false otherwise. 
* * @param key the key to be checked */ - boolean containsKey(StdData key); + boolean containsKey(K key); } diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/PlatformData.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/PlatformData.java index 75df0518..41fc7bae 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/PlatformData.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/PlatformData.java @@ -5,7 +5,7 @@ */ package com.linkedin.transport.api.data; -/** An interface for all platform-specific implementations of {@link StdData}. */ +/** An interface to handle platform-specific container types. */ public interface PlatformData { /** Returns the underlying platform-specific object holding the data. */ diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdStruct.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/RowData.java similarity index 50% rename from transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdStruct.java rename to transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/RowData.java index 14ccff80..2d8f1ce0 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdStruct.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/RowData.java @@ -8,39 +8,39 @@ import java.util.List; -/** A Standard UDF data type for representing structs. */ -public interface StdStruct extends StdData { +/** A Transport UDF data type for representing SQL ROW/STRUCT data type. */ +public interface RowData { /** - * Returns the value of the field at the given position in the struct. + * Returns the value of the field at the given position in the row. 
* - * @param index the position of the field in the struct + * @param index the position of the field in the row */ - StdData getField(int index); + Object getField(int index); /** - * Returns the value of the given field from the struct. + * Returns the value of the given field from the row. * * @param name the name of the field */ - StdData getField(String name); + Object getField(String name); /** - * Sets the value of the field at the given position in the struct. + * Sets the value of the field at the given position in the row. * - * @param index the position of the field in the struct + * @param index the position of the field in the row * @param value the value to be set */ - void setField(int index, StdData value); + void setField(int index, Object value); /** - * Sets the value of the given field in the struct. + * Sets the value of the given field in the row. * * @param name the name of the field * @param value the value to be set */ - void setField(String name, StdData value); + void setField(String name, Object value); /** Returns a {@link List} of all fields in the struct. */ - List fields(); + List fields(); } diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdBinary.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdBinary.java deleted file mode 100644 index d1fc4acb..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdBinary.java +++ /dev/null @@ -1,15 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -import java.nio.ByteBuffer; - -/** A Standard UDF data type for representing binary objects. */ -public interface StdBinary extends StdData { - - /** Returns the underlying {@link ByteBuffer} value. 
*/ - ByteBuffer get(); -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdBoolean.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdBoolean.java deleted file mode 100644 index ff230bc1..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdBoolean.java +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -/** A Standard UDF data type for representing booleans. */ -public interface StdBoolean extends StdData { - - /** Returns the underlying boolean value. */ - boolean get(); -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdData.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdData.java deleted file mode 100644 index 77b3d1d7..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdData.java +++ /dev/null @@ -1,19 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -import com.linkedin.transport.api.StdFactory; - - -/** - * An interface for all data types in Standard UDFs. - * - * {@link StdData} is the main interface through which StdUDFs receive input data and return output data. All Standard - * UDF data types (e.g., {@link StdInteger}, {@link StdArray}, {@link StdMap}) must extend {@link StdData}. Methods - * inside {@link StdFactory} can be used to create {@link StdData} objects. 
- */ -public interface StdData { -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdDouble.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdDouble.java deleted file mode 100644 index a96fcc0e..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdDouble.java +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -/** A Standard UDF data type for representing doubles. */ -public interface StdDouble extends StdData { - - /** Returns the underlying double value. */ - double get(); -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdFloat.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdFloat.java deleted file mode 100644 index da76dd28..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdFloat.java +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -/** A Standard UDF data type for representing floats. */ -public interface StdFloat extends StdData { - - /** Returns the underlying float value. */ - float get(); -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdInteger.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdInteger.java deleted file mode 100644 index c74a92dd..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdInteger.java +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. 
- * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -/** A Standard UDF data type for representing integers. */ -public interface StdInteger extends StdData { - - /** Returns the underlying int value. */ - int get(); -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdLong.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdLong.java deleted file mode 100644 index 84f322f7..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdLong.java +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -/** A Standard UDF data type for representing longs. */ -public interface StdLong extends StdData { - - /** Returns the underlying long value. */ - long get(); -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdString.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdString.java deleted file mode 100644 index 7ccd8385..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdString.java +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -/** A Standard UDF data type for representing strings. */ -public interface StdString extends StdData { - - /** Returns the underlying {@link String} value. 
*/ - String get(); -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdTimestamp.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdTimestamp.java deleted file mode 100644 index 18ff9bc1..00000000 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdTimestamp.java +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.api.data; - -/** A Standard UDF data type for representing timestamps. */ -public interface StdTimestamp extends StdData { - - /** Returns the number of milliseconds elapsed from epoch for the {@link StdTimestamp}. */ - long toEpoch(); -} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdStructType.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/RowType.java similarity index 89% rename from transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdStructType.java rename to transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/RowType.java index 521ec2d7..59938208 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdStructType.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/RowType.java @@ -9,7 +9,7 @@ /** A {@link StdType} representing a struct type. */ -public interface StdStructType extends StdType { +public interface RowType extends StdType { /** Returns a {@link List} of the types of all the struct fields. 
*/ List fieldTypes(); diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF.java index a4e29220..f90b91ca 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF.java @@ -6,8 +6,6 @@ package com.linkedin.transport.api.udf; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.types.StdType; import java.util.List; @@ -19,8 +17,7 @@ * abstract class for UDFs expecting {@code i} arguments. Similar to lambda expressions, StdUDF(i) abstract classes are * type-parameterized by the input types and output type of the eval function. Each class is type-parameterized by * {@code (i+1)} type parameters; {@code i} type parameters for the UDF input types, and one type parameter for the - * output type. All types (both input and output types) must extend the {@link StdData} - * interface. + * output type. */ public abstract class StdUDF { private StdFactory _stdFactory; @@ -40,7 +37,7 @@ public abstract class StdUDF { * of contained UDF. 
* * @param stdFactory a {@link StdFactory} object which can be used to create - * {@link StdData} and {@link StdType} objects + * data and type objects */ public void init(StdFactory stdFactory) { _stdFactory = stdFactory; @@ -85,8 +82,8 @@ public final boolean[] getAndCheckNullableArguments() { protected abstract int numberOfArguments(); /** - * Returns a {@link StdFactory} object which can be used to create {@link StdData} and - * {@link StdType} objects + * Returns a {@link StdFactory} object which can be used to create data and + * type objects */ public StdFactory getStdFactory() { return _stdFactory; diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF0.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF0.java index b62fe95f..d3558fc3 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF0.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF0.java @@ -5,15 +5,13 @@ */ package com.linkedin.transport.api.udf; -import com.linkedin.transport.api.data.StdData; - /** * A Standard UDF with zero input arguments. * * @param the type of the return value of the {@link StdUDF} */ -public abstract class StdUDF0 extends StdUDF { +public abstract class StdUDF0 extends StdUDF { /** * Returns the output of the {@link StdUDF} given the input arguments. diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF1.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF1.java index 28d0ad71..18e8e769 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF1.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF1.java @@ -5,8 +5,6 @@ */ package com.linkedin.transport.api.udf; -import com.linkedin.transport.api.data.StdData; - /** * A Standard UDF with one input argument. 
@@ -17,7 +15,7 @@ // Suppressing class parameter type parameter name and arg naming style checks since this naming convention is more // suitable to Standard UDFs, and the code is more readable this way. @SuppressWarnings({"checkstyle:classtypeparametername", "checkstyle:regexpsinglelinejava"}) -public abstract class StdUDF1 extends StdUDF { +public abstract class StdUDF1 extends StdUDF { /** * Returns the output of the {@link StdUDF} given the input arguments. @@ -38,7 +36,7 @@ public abstract class StdUDF1 extends Std * hence obtaining the most recent version of a file. * Example: 'hdfs:///data/derived/dwh/prop/testMemberId/#LATEST/testMemberId.txt' * - * The arguments passed to {@link #eval(StdData)} are passed to this method as well to allow users to construct + * The arguments passed to {@link #eval(Object)} are passed to this method as well to allow users to construct * required file paths from arguments passed to the UDF. Since this method is called before any rows are processed, * only constant UDF arguments should be used to construct the file paths. Values of non-constant arguments are not * deterministic, and are null for most platforms. (Constant arguments are arguments whose literal values are given diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF2.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF2.java index 3e020ae1..3eb293ba 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF2.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/udf/StdUDF2.java @@ -5,8 +5,6 @@ */ package com.linkedin.transport.api.udf; -import com.linkedin.transport.api.data.StdData; - /** * A Standard UDF with three input arguments. @@ -18,7 +16,7 @@ // Suppressing class parameter type parameter name and arg naming style checks since this naming convention is more // suitable to Standard UDFs, and the code is more readable this way. 
@SuppressWarnings({"checkstyle:classtypeparametername", "checkstyle:regexpsinglelinejava"}) -public abstract class StdUDF2 extends StdUDF { +public abstract class StdUDF2 extends StdUDF { /** * Returns the output of the {@link StdUDF} given the input arguments. @@ -40,7 +38,7 @@ public abstract class StdUDF2 +public abstract class StdUDF3 extends StdUDF { /** @@ -43,7 +41,7 @@ public abstract class StdUDF3 +public abstract class StdUDF4 extends StdUDF { /** @@ -45,7 +43,7 @@ public abstract class StdUDF4 extends StdUDF { +public abstract class StdUDF5 extends StdUDF { /** * Returns the output of the {@link StdUDF} given the input arguments. @@ -47,7 +45,7 @@ public abstract class StdUDF5 extends StdUDF { +public abstract class StdUDF6 extends StdUDF { /** * Returns the output of the {@link StdUDF} given the input arguments. @@ -49,7 +47,7 @@ public abstract class StdUDF6 extends StdUDF { +public abstract class StdUDF7 extends StdUDF { /** * Returns the output of the {@link StdUDF} given the input arguments. @@ -51,7 +49,7 @@ public abstract class StdUDF7 extends StdUDF { +public abstract class StdUDF8 extends StdUDF { /** * Returns the output of the {@link StdUDF} given the input arguments. 
@@ -53,7 +51,7 @@ public abstract class StdUDF8 boundVariables) { } @Override - public StdInteger createInteger(int value) { - return new AvroInteger(value); + public ArrayData createArray(StdType stdType, int size) { + return new AvroArrayData((Schema) stdType.underlyingType(), size); } @Override - public StdLong createLong(long value) { - return new AvroLong(value); - } - - @Override - public StdBoolean createBoolean(boolean value) { - return new AvroBoolean(value); - } - - @Override - public StdString createString(String value) { - return new AvroString(new Utf8(value)); - } - - @Override - public StdFloat createFloat(float value) { - return new AvroFloat(value); - } - - @Override - public StdDouble createDouble(double value) { - return new AvroDouble(value); - } - - @Override - public StdBinary createBinary(ByteBuffer value) { - return new AvroBinary(value); - } - - @Override - public StdArray createArray(StdType stdType, int size) { - return new AvroArray((Schema) stdType.underlyingType(), size); - } - - @Override - public StdArray createArray(StdType stdType) { + public ArrayData createArray(StdType stdType) { return createArray(stdType, 0); } @Override - public StdMap createMap(StdType stdType) { - return new AvroMap((Schema) stdType.underlyingType()); + public MapData createMap(StdType stdType) { + return new AvroMapData((Schema) stdType.underlyingType()); } @Override - public StdStruct createStruct(List fieldNames, List fieldTypes) { + public RowData createStruct(List fieldNames, List fieldTypes) { if (fieldNames.size() != fieldTypes.size()) { throw new RuntimeException( "Field names and types are of different lengths: " + "Field names length is " + fieldNames.size() + ". 
" @@ -112,18 +61,18 @@ public StdStruct createStruct(List fieldNames, List fieldTypes) for (int i = 0; i < fieldTypes.size(); i++) { fields.add(new Field(fieldNames.get(i), (Schema) fieldTypes.get(i).underlyingType(), null, null)); } - return new AvroStruct(Schema.createRecord(fields)); + return new AvroRowData(Schema.createRecord(fields)); } @Override - public StdStruct createStruct(List fieldTypes) { + public RowData createStruct(List fieldTypes) { return createStruct(IntStream.range(0, fieldTypes.size()).mapToObj(i -> "field" + i).collect(Collectors.toList()), fieldTypes); } @Override - public StdStruct createStruct(StdType stdType) { - return new AvroStruct((Schema) stdType.underlyingType()); + public RowData createStruct(StdType stdType) { + return new AvroRowData((Schema) stdType.underlyingType()); } @Override diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/AvroWrapper.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/AvroWrapper.java index 1f7657e5..d8149724 100644 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/AvroWrapper.java +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/AvroWrapper.java @@ -5,18 +5,11 @@ */ package com.linkedin.transport.avro; -import com.linkedin.transport.api.data.StdData; +import com.linkedin.transport.api.data.PlatformData; import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.avro.data.AvroArray; -import com.linkedin.transport.avro.data.AvroBinary; -import com.linkedin.transport.avro.data.AvroBoolean; -import com.linkedin.transport.avro.data.AvroDouble; -import com.linkedin.transport.avro.data.AvroFloat; -import com.linkedin.transport.avro.data.AvroInteger; -import com.linkedin.transport.avro.data.AvroLong; -import com.linkedin.transport.avro.data.AvroMap; -import com.linkedin.transport.avro.data.AvroString; -import com.linkedin.transport.avro.data.AvroStruct; +import 
com.linkedin.transport.avro.data.AvroArrayData; +import com.linkedin.transport.avro.data.AvroMapData; +import com.linkedin.transport.avro.data.AvroRowData; import com.linkedin.transport.avro.types.AvroArrayType; import com.linkedin.transport.avro.types.AvroBinaryType; import com.linkedin.transport.avro.types.AvroBooleanType; @@ -26,13 +19,12 @@ import com.linkedin.transport.avro.types.AvroLongType; import com.linkedin.transport.avro.types.AvroMapType; import com.linkedin.transport.avro.types.AvroStringType; -import com.linkedin.transport.avro.types.AvroStructType; +import com.linkedin.transport.avro.types.AvroRowType; import java.nio.ByteBuffer; import java.util.List; import java.util.Map; import org.apache.avro.Schema; import org.apache.avro.generic.GenericArray; -import org.apache.avro.generic.GenericEnumSymbol; import org.apache.avro.generic.GenericRecord; import org.apache.avro.util.Utf8; @@ -42,50 +34,28 @@ public class AvroWrapper { private AvroWrapper() { } - public static StdData createStdData(Object avroData, Schema avroSchema) { + public static Object createStdData(Object avroData, Schema avroSchema) { switch (avroSchema.getType()) { case INT: - return new AvroInteger((Integer) avroData); case LONG: - return new AvroLong((Long) avroData); case BOOLEAN: - return new AvroBoolean((Boolean) avroData); - case ENUM: { - if (avroData == null) { - return new AvroString(null); - } - - if (avroData instanceof String) { - return new AvroString(new Utf8((String) avroData)); - } else if (avroData instanceof GenericEnumSymbol) { - return new AvroString(new Utf8(((GenericEnumSymbol) avroData).toString())); - } - throw new IllegalArgumentException("Unsupported type for Avro enum: " + avroData.getClass()); - } - case STRING: { - if (avroData == null) { - return new AvroString(null); - } - - if (avroData instanceof Utf8) { - return new AvroString((Utf8) avroData); - } else if (avroData instanceof String) { - return new AvroString(new Utf8((String) avroData)); - } - throw 
new IllegalArgumentException("Unsupported type for Avro string: " + avroData.getClass()); - } case FLOAT: - return new AvroFloat((Float) avroData); case DOUBLE: - return new AvroDouble((Double) avroData); case BYTES: - return new AvroBinary((ByteBuffer) avroData); + return avroData; + case STRING: + case ENUM: + if (avroData == null) { + return null; + } else { + return avroData.toString(); + } case ARRAY: - return new AvroArray((GenericArray) avroData, avroSchema); + return new AvroArrayData((GenericArray) avroData, avroSchema); case MAP: - return new AvroMap((Map) avroData, avroSchema); + return new AvroMapData((Map) avroData, avroSchema); case RECORD: - return new AvroStruct((GenericRecord) avroData, avroSchema); + return new AvroRowData((GenericRecord) avroData, avroSchema); case UNION: { Schema nonNullableType = getNonNullComponent(avroSchema); if (avroData == null) { @@ -100,6 +70,18 @@ public static StdData createStdData(Object avroData, Schema avroSchema) { } } + public static Object getPlatformData(Object transportData) { + if (transportData instanceof Integer || transportData instanceof Long || transportData instanceof Double + || transportData instanceof Boolean || transportData instanceof ByteBuffer) { + return transportData; + } else if (transportData instanceof String) { + return transportData == null ? null : new Utf8((String) transportData); + } else { + return transportData == null ? null : ((PlatformData) transportData).getUnderlyingData(); + } + } + + /** * Returns a non null component of a simple union schema. The supported union schema must have * only two fields where one of them is null type, the other is returned. 
@@ -139,7 +121,7 @@ public static StdType createStdType(Schema avroSchema) { case MAP: return new AvroMapType(avroSchema); case RECORD: - return new AvroStructType(avroSchema); + return new AvroRowType(avroSchema); case UNION: { Schema nonNullableType = getNonNullComponent(avroSchema); return createStdType(nonNullableType); diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/StdUdfWrapper.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/StdUdfWrapper.java index a1ea3e0d..a75a8e6a 100644 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/StdUdfWrapper.java +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/StdUdfWrapper.java @@ -7,7 +7,6 @@ import com.linkedin.transport.api.StdFactory; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; import com.linkedin.transport.api.udf.StdUDF; import com.linkedin.transport.api.udf.StdUDF0; import com.linkedin.transport.api.udf.StdUDF1; @@ -36,7 +35,7 @@ public abstract class StdUdfWrapper { protected boolean _requiredFilesProcessed; protected StdFactory _stdFactory; private boolean[] _nullableArguments; - private StdData[] _args; + private Object[] _args; /** * Given input schemas, this method matches them to the expected type signatures, and finds bindings to the @@ -68,12 +67,27 @@ protected boolean containsNullValuedNonNullableArgument(Object[] arguments) { return false; } - protected StdData wrap(Object avroObject, StdData stdData) { - if (avroObject != null) { - ((PlatformData) stdData).setUnderlyingData(avroObject); - return stdData; - } else { - return null; + protected Object wrap(Object avroObject, Schema inputSchema, Object stdData) { + switch (inputSchema.getType()) { + case INT: + case LONG: + case BOOLEAN: + return avroObject; + case STRING: + return avroObject == null ? 
null : avroObject.toString(); + case ARRAY: + case MAP: + case RECORD: + if (avroObject != null) { + ((PlatformData) stdData).setUnderlyingData(avroObject); + return stdData; + } else { + return null; + } + case NULL: + return null; + default: + throw new RuntimeException("Unrecognized Avro Schema: " + inputSchema.getClass()); } } @@ -82,22 +96,24 @@ protected StdData wrap(Object avroObject, StdData stdData) { protected abstract Class getTopLevelUdfClass(); protected void createStdData() { - _args = new StdData[_inputSchemas.length]; + _args = new Object[_inputSchemas.length]; for (int i = 0; i < _inputSchemas.length; i++) { _args[i] = AvroWrapper.createStdData(null, _inputSchemas[i]); } } - private StdData[] wrapArguments(Object[] arguments) { - return IntStream.range(0, _args.length).mapToObj(i -> wrap(arguments[i], _args[i])).toArray(StdData[]::new); + private Object[] wrapArguments(Object[] arguments) { + return IntStream.range(0, _args.length).mapToObj( + i -> wrap(arguments[i], _inputSchemas[i], _args[i]) + ).toArray(Object[]::new); } public Object evaluate(Object[] arguments) { if (containsNullValuedNonNullableArgument(arguments)) { return null; } - StdData[] args = wrapArguments(arguments); - StdData result; + Object[] args = wrapArguments(arguments); + Object result; switch (args.length) { case 0: result = ((StdUDF0) _stdUdf).eval(); @@ -129,6 +145,6 @@ public Object evaluate(Object[] arguments) { default: throw new UnsupportedOperationException("eval not yet supported for StdUDF" + args.length); } - return result == null ? 
null : ((PlatformData) result).getUnderlyingData(); + return AvroWrapper.getPlatformData(result); } } diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroArray.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroArrayData.java similarity index 65% rename from transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroArray.java rename to transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroArrayData.java index 1557ed6c..3cb2e6f6 100644 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroArray.java +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroArrayData.java @@ -6,8 +6,7 @@ package com.linkedin.transport.avro.data; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdData; +import com.linkedin.transport.api.data.ArrayData; import com.linkedin.transport.avro.AvroWrapper; import java.util.Iterator; import org.apache.avro.Schema; @@ -15,16 +14,16 @@ import org.apache.avro.generic.GenericData; -public class AvroArray implements StdArray, PlatformData { +public class AvroArrayData implements ArrayData, PlatformData { private final Schema _elementSchema; private GenericArray _genericArray; - public AvroArray(GenericArray genericArray, Schema arraySchema) { + public AvroArrayData(GenericArray genericArray, Schema arraySchema) { _genericArray = genericArray; _elementSchema = arraySchema.getElementType(); } - public AvroArray(Schema arraySchema, int size) { + public AvroArrayData(Schema arraySchema, int size) { _elementSchema = arraySchema.getElementType(); _genericArray = new GenericData.Array(size, arraySchema); } @@ -35,18 +34,18 @@ public int size() { } @Override - public StdData get(int idx) { - return AvroWrapper.createStdData(_genericArray.get(idx), _elementSchema); + public E get(int idx) { + return (E) 
AvroWrapper.createStdData(_genericArray.get(idx), _elementSchema); } @Override - public void add(StdData e) { - _genericArray.add(((PlatformData) e).getUnderlyingData()); + public void add(E e) { + _genericArray.add(AvroWrapper.getPlatformData(e)); } @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { private final Iterator _iterator = _genericArray.iterator(); @Override @@ -55,8 +54,8 @@ public boolean hasNext() { } @Override - public StdData next() { - return AvroWrapper.createStdData(_iterator.next(), _elementSchema); + public E next() { + return (E) AvroWrapper.createStdData(_iterator.next(), _elementSchema); } }; } diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroBinary.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroBinary.java deleted file mode 100644 index 902e610d..00000000 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroBinary.java +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.avro.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdBinary; -import java.nio.ByteBuffer; - - -public class AvroBinary implements StdBinary, PlatformData { - private ByteBuffer _byteBuffer; - - public AvroBinary(ByteBuffer aByteBuffer) { - _byteBuffer = aByteBuffer; - } - - @Override - public Object getUnderlyingData() { - return _byteBuffer; - } - - @Override - public void setUnderlyingData(Object value) { - _byteBuffer = (ByteBuffer) value; - } - - @Override - public ByteBuffer get() { - return _byteBuffer; - } -} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroBoolean.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroBoolean.java deleted file mode 100644 index 99f83738..00000000 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroBoolean.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.avro.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdBoolean; - - -public class AvroBoolean implements StdBoolean, PlatformData { - private Boolean _boolean; - - public AvroBoolean(Boolean aBoolean) { - _boolean = aBoolean; - } - - @Override - public boolean get() { - return _boolean; - } - - @Override - public Object getUnderlyingData() { - return _boolean; - } - - @Override - public void setUnderlyingData(Object value) { - _boolean = (Boolean) value; - } -} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroDouble.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroDouble.java deleted file mode 100644 index 214443ae..00000000 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroDouble.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.avro.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdDouble; - - -public class AvroDouble implements StdDouble, PlatformData { - private Double _double; - - public AvroDouble(Double aDouble) { - _double = aDouble; - } - - @Override - public Object getUnderlyingData() { - return _double; - } - - @Override - public void setUnderlyingData(Object value) { - _double = (Double) value; - } - - @Override - public double get() { - return _double; - } -} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroFloat.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroFloat.java deleted file mode 100644 index c4547d81..00000000 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroFloat.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.avro.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdFloat; - - -public class AvroFloat implements StdFloat, PlatformData { - private Float _float; - - public AvroFloat(Float aFloat) { - _float = aFloat; - } - - @Override - public Object getUnderlyingData() { - return _float; - } - - @Override - public void setUnderlyingData(Object value) { - _float = (Float) value; - } - - @Override - public float get() { - return _float; - } -} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroInteger.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroInteger.java deleted file mode 100644 index 5a170f3b..00000000 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroInteger.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.avro.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdInteger; - - -public class AvroInteger implements StdInteger, PlatformData { - private Integer _integer; - - public AvroInteger(Integer integer) { - _integer = integer; - } - - @Override - public int get() { - return _integer; - } - - @Override - public Object getUnderlyingData() { - return _integer; - } - - @Override - public void setUnderlyingData(Object value) { - _integer = (Integer) value; - } -} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroLong.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroLong.java deleted file mode 100644 index a56af06c..00000000 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroLong.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.avro.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdLong; - - -public class AvroLong implements StdLong, PlatformData { - private Long _long; - - public AvroLong(Long aLong) { - _long = aLong; - } - - @Override - public long get() { - return _long; - } - - @Override - public Object getUnderlyingData() { - return _long; - } - - @Override - public void setUnderlyingData(Object value) { - _long = (Long) value; - } -} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroMap.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroMapData.java similarity index 59% rename from transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroMap.java rename to transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroMapData.java index d0913d53..2b95796c 100644 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroMap.java +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroMapData.java @@ -6,8 +6,7 @@ package com.linkedin.transport.avro.data; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdMap; +import com.linkedin.transport.api.data.MapData; import com.linkedin.transport.avro.AvroWrapper; import java.util.AbstractSet; import java.util.Collection; @@ -21,18 +20,18 @@ import static org.apache.avro.Schema.Type.*; -public class AvroMap implements StdMap, PlatformData { +public class AvroMapData implements MapData, PlatformData { private Map _map; private final Schema _keySchema; private final Schema _valueSchema; - public AvroMap(Map map, Schema mapSchema) { + public AvroMapData(Map map, Schema mapSchema) { _map = map; _keySchema = Schema.create(STRING); _valueSchema = mapSchema.getValueType(); } - public AvroMap(Schema mapSchema) { + public 
AvroMapData(Schema mapSchema) { _map = new LinkedHashMap<>(); _keySchema = Schema.create(STRING); _valueSchema = mapSchema.getValueType(); @@ -54,21 +53,21 @@ public int size() { } @Override - public StdData get(StdData key) { - return AvroWrapper.createStdData(_map.get(((PlatformData) key).getUnderlyingData()), _valueSchema); + public V get(K key) { + return (V) AvroWrapper.createStdData(_map.get(AvroWrapper.getPlatformData(key)), _valueSchema); } @Override - public void put(StdData key, StdData value) { - _map.put(((PlatformData) key).getUnderlyingData(), ((PlatformData) value).getUnderlyingData()); + public void put(K key, V value) { + _map.put(AvroWrapper.getPlatformData(key), AvroWrapper.getPlatformData(value)); } @Override - public Set keySet() { - return new AbstractSet() { + public Set keySet() { + return new AbstractSet() { @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { Iterator keySet = _map.keySet().iterator(); @Override public boolean hasNext() { @@ -76,8 +75,8 @@ public boolean hasNext() { } @Override - public StdData next() { - return AvroWrapper.createStdData(keySet.next(), _keySchema); + public K next() { + return (K) AvroWrapper.createStdData(keySet.next(), _keySchema); } }; } @@ -90,12 +89,12 @@ public int size() { } @Override - public Collection values() { - return _map.values().stream().map(v -> AvroWrapper.createStdData(v, _valueSchema)).collect(Collectors.toList()); + public Collection values() { + return _map.values().stream().map(v -> (V) AvroWrapper.createStdData(v, _valueSchema)).collect(Collectors.toList()); } @Override - public boolean containsKey(StdData key) { - return _map.containsKey(((PlatformData) key).getUnderlyingData()); + public boolean containsKey(K key) { + return _map.containsKey(AvroWrapper.getPlatformData(key)); } } diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroStruct.java 
b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroRowData.java similarity index 68% rename from transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroStruct.java rename to transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroRowData.java index f018d5bc..64fa5e4c 100644 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroStruct.java +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroRowData.java @@ -6,8 +6,7 @@ package com.linkedin.transport.avro.data; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.avro.AvroWrapper; import java.util.List; import java.util.stream.Collectors; @@ -17,17 +16,17 @@ import org.apache.avro.generic.GenericRecord; -public class AvroStruct implements StdStruct, PlatformData { +public class AvroRowData implements RowData, PlatformData { private final Schema _recordSchema; private GenericRecord _genericRecord; - public AvroStruct(GenericRecord genericRecord, Schema recordSchema) { + public AvroRowData(GenericRecord genericRecord, Schema recordSchema) { _genericRecord = genericRecord; _recordSchema = recordSchema; } - public AvroStruct(Schema recordSchema) { + public AvroRowData(Schema recordSchema) { _genericRecord = new Record(recordSchema); _recordSchema = recordSchema; } @@ -43,27 +42,27 @@ public void setUnderlyingData(Object value) { } @Override - public StdData getField(int index) { + public Object getField(int index) { return AvroWrapper.createStdData(_genericRecord.get(index), _recordSchema.getFields().get(index).schema()); } @Override - public StdData getField(String name) { + public Object getField(String name) { return AvroWrapper.createStdData(_genericRecord.get(name), _recordSchema.getField(name).schema()); } @Override - 
public void setField(int index, StdData value) { - _genericRecord.put(index, ((PlatformData) value).getUnderlyingData()); + public void setField(int index, Object value) { + _genericRecord.put(index, AvroWrapper.getPlatformData(value)); } @Override - public void setField(String name, StdData value) { - _genericRecord.put(name, ((PlatformData) value).getUnderlyingData()); + public void setField(String name, Object value) { + _genericRecord.put(name, AvroWrapper.getPlatformData(value)); } @Override - public List fields() { + public List fields() { return IntStream.range(0, _recordSchema.getFields().size()).mapToObj(i -> getField(i)).collect(Collectors.toList()); } } diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroString.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroString.java deleted file mode 100644 index 745df05e..00000000 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroString.java +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.avro.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdString; -import org.apache.avro.util.Utf8; - - -public class AvroString implements StdString, PlatformData { - private Utf8 _string; - - public AvroString(Utf8 string) { - _string = string; - } - - @Override - public String get() { - return _string.toString(); - } - - @Override - public Object getUnderlyingData() { - return _string; - } - - @Override - public void setUnderlyingData(Object value) { - _string = (Utf8) value; - } -} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroStructType.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroRowType.java similarity index 82% rename from transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroStructType.java rename to transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroRowType.java index 2c97b39c..3923b3f5 100644 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroStructType.java +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroRowType.java @@ -5,7 +5,7 @@ */ package com.linkedin.transport.avro.types; -import com.linkedin.transport.api.types.StdStructType; +import com.linkedin.transport.api.types.RowType; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.avro.AvroWrapper; import java.util.List; @@ -13,10 +13,10 @@ import org.apache.avro.Schema; -public class AvroStructType implements StdStructType { +public class AvroRowType implements RowType { final private Schema _schema; - public AvroStructType(Schema schema) { + public AvroRowType(Schema schema) { _schema = schema; } diff --git a/transportable-udfs-avro/src/test/java/com/linkedin/transport/avro/TestAvroWrapper.java b/transportable-udfs-avro/src/test/java/com/linkedin/transport/avro/TestAvroWrapper.java index 
d81a452a..2257e977 100644 --- a/transportable-udfs-avro/src/test/java/com/linkedin/transport/avro/TestAvroWrapper.java +++ b/transportable-udfs-avro/src/test/java/com/linkedin/transport/avro/TestAvroWrapper.java @@ -7,37 +7,20 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.avro.data.AvroArray; -import com.linkedin.transport.avro.data.AvroBinary; -import com.linkedin.transport.avro.data.AvroBoolean; -import com.linkedin.transport.avro.data.AvroDouble; -import com.linkedin.transport.avro.data.AvroFloat; -import com.linkedin.transport.avro.data.AvroInteger; -import com.linkedin.transport.avro.data.AvroLong; -import com.linkedin.transport.avro.data.AvroMap; -import com.linkedin.transport.avro.data.AvroString; -import com.linkedin.transport.avro.data.AvroStruct; +import com.linkedin.transport.avro.data.AvroArrayData; +import com.linkedin.transport.avro.data.AvroMapData; +import com.linkedin.transport.avro.data.AvroRowData; import com.linkedin.transport.avro.types.AvroArrayType; -import com.linkedin.transport.avro.types.AvroBinaryType; -import com.linkedin.transport.avro.types.AvroBooleanType; -import com.linkedin.transport.avro.types.AvroDoubleType; -import com.linkedin.transport.avro.types.AvroFloatType; -import com.linkedin.transport.avro.types.AvroIntegerType; import com.linkedin.transport.avro.types.AvroLongType; import com.linkedin.transport.avro.types.AvroMapType; -import com.linkedin.transport.avro.types.AvroStringType; -import com.linkedin.transport.avro.types.AvroStructType; -import java.nio.ByteBuffer; +import com.linkedin.transport.avro.types.AvroRowType; import java.util.Arrays; import java.util.Map; import org.apache.avro.Schema; import org.apache.avro.generic.GenericArray; import org.apache.avro.generic.GenericData; import 
org.apache.avro.generic.GenericRecord; -import org.apache.avro.util.Utf8; import org.testng.annotations.Test; import static org.testng.Assert.*; @@ -54,61 +37,6 @@ private Schema createSchema(String fieldName, String typeName) { String.format("{\"name\": \"%s\",\"type\": %s}", fieldName, typeName)); } - private void testSimpleType(String typeName, Class expectedAvroTypeClass, - Object testData, Class expectedDataClass) { - Schema avroSchema = createSchema(String.format("\"%s\"", typeName)); - - StdType stdType = AvroWrapper.createStdType(avroSchema); - assertTrue(expectedAvroTypeClass.isAssignableFrom(stdType.getClass())); - assertEquals(avroSchema, stdType.underlyingType()); - - StdData stdData = AvroWrapper.createStdData(testData, avroSchema); - assertNotNull(stdData); - assertTrue(expectedDataClass.isAssignableFrom(stdData.getClass())); - if ("string".equals(typeName)) { - // Use String values for equality assertion as we support both Utf8 and String input types - assertEquals(testData.toString(), ((PlatformData) stdData).getUnderlyingData().toString()); - } else { - assertEquals(testData, ((PlatformData) stdData).getUnderlyingData()); - } - } - - @Test - public void testBooleanType() { - testSimpleType("boolean", AvroBooleanType.class, true, AvroBoolean.class); - } - - @Test - public void testIntegerType() { - testSimpleType("int", AvroIntegerType.class, 1, AvroInteger.class); - } - - @Test - public void testLongType() { - testSimpleType("long", AvroLongType.class, 1L, AvroLong.class); - } - - @Test - public void testFloatType() { - testSimpleType("float", AvroFloatType.class, 1.0f, AvroFloat.class); - } - - @Test - public void testDoubleType() { - testSimpleType("double", AvroDoubleType.class, 1.0, AvroDouble.class); - } - - @Test - public void testStringType() { - testSimpleType("string", AvroStringType.class, new Utf8("foo"), AvroString.class); - testSimpleType("string", AvroStringType.class, "foo", AvroString.class); - } - - @Test - public void 
testBinaryType() { - testSimpleType("bytes", AvroBinaryType.class, ByteBuffer.wrap("bar".getBytes()), AvroBinary.class); - } - @Test public void testEnumType() { Schema field1 = createSchema("field1", "" @@ -122,17 +50,17 @@ public void testEnumType() { GenericRecord record1 = new GenericData.Record(structSchema); record1.put("field1", "A"); - StdData stdEnumData1 = AvroWrapper.createStdData(record1.get("field1"), + Object stdEnumData1 = AvroWrapper.createStdData(record1.get("field1"), Schema.createEnum("SampleEnum", "", "", Arrays.asList("A", "B"))); - assertTrue(stdEnumData1 instanceof AvroString); - assertEquals("A", ((AvroString) stdEnumData1).get()); + assertTrue(stdEnumData1 instanceof String); + assertEquals("A", ((String) stdEnumData1)); GenericRecord record2 = new GenericData.Record(structSchema); record1.put("field1", new GenericData.EnumSymbol(field1, "A")); - StdData stdEnumData2 = AvroWrapper.createStdData(record1.get("field1"), + Object stdEnumData2 = AvroWrapper.createStdData(record1.get("field1"), Schema.createEnum("SampleEnum", "", "", Arrays.asList("A", "B"))); - assertTrue(stdEnumData2 instanceof AvroString); - assertEquals("A", ((AvroString) stdEnumData2).get()); + assertTrue(stdEnumData2 instanceof String); + assertEquals("A", ((String) stdEnumData2)); } @Test @@ -142,14 +70,14 @@ public void testArrayType() { StdType stdArrayType = AvroWrapper.createStdType(arraySchema); assertTrue(stdArrayType instanceof AvroArrayType); - assertEquals(arraySchema, stdArrayType.underlyingType()); + assertEquals(arraySchema, ((AvroArrayType) stdArrayType).underlyingType()); assertEquals(elementType, ((AvroArrayType) stdArrayType).elementType().underlyingType()); GenericArray value = new GenericData.Array<>(arraySchema, Arrays.asList(1, 2)); - StdData stdArrayData = AvroWrapper.createStdData(value, arraySchema); - assertTrue(stdArrayData instanceof AvroArray); - assertEquals(2, ((AvroArray) stdArrayData).size()); - assertEquals(value, ((AvroArray) 
stdArrayData).getUnderlyingData()); + Object stdArrayData = AvroWrapper.createStdData(value, arraySchema); + assertTrue(stdArrayData instanceof AvroArrayData); + assertEquals(2, ((AvroArrayData) stdArrayData).size()); + assertEquals(value, ((AvroArrayData) stdArrayData).getUnderlyingData()); } @Test @@ -163,10 +91,10 @@ public void testMapType() { assertEquals(valueType, ((AvroMapType) stdMapType).valueType().underlyingType()); Map value = ImmutableMap.of("foo", 1L, "bar", 2L); - StdData stdMapData = AvroWrapper.createStdData(value, mapSchema); - assertTrue(stdMapData instanceof AvroMap); - assertEquals(2, ((AvroMap) stdMapData).size()); - assertEquals(value, ((AvroMap) stdMapData).getUnderlyingData()); + Object stdMapData = AvroWrapper.createStdData(value, mapSchema); + assertTrue(stdMapData instanceof AvroMapData); + assertEquals(2, ((AvroMapData) stdMapData).size()); + assertEquals(value, ((AvroMapData) stdMapData).getUnderlyingData()); } @Test @@ -179,21 +107,21 @@ public void testRecordType() { )); StdType stdStructType = AvroWrapper.createStdType(structSchema); - assertTrue(stdStructType instanceof AvroStructType); + assertTrue(stdStructType instanceof AvroRowType); assertEquals(structSchema, stdStructType.underlyingType()); - assertEquals(field1, ((AvroStructType) stdStructType).fieldTypes().get(0).underlyingType()); - assertEquals(field2, ((AvroStructType) stdStructType).fieldTypes().get(1).underlyingType()); + assertEquals(field1, ((AvroRowType) stdStructType).fieldTypes().get(0).underlyingType()); + assertEquals(field2, ((AvroRowType) stdStructType).fieldTypes().get(1).underlyingType()); GenericRecord value = new GenericData.Record(structSchema); value.put("field1", 1); value.put("field2", 2.0); - StdData stdStructData = AvroWrapper.createStdData(value, structSchema); - assertTrue(stdStructData instanceof AvroStruct); - AvroStruct avroStruct = (AvroStruct) stdStructData; + Object stdStructData = AvroWrapper.createStdData(value, structSchema); + 
assertTrue(stdStructData instanceof AvroRowData); + AvroRowData avroStruct = (AvroRowData) stdStructData; assertEquals(2, avroStruct.fields().size()); assertEquals(value, avroStruct.getUnderlyingData()); - assertEquals(1, ((PlatformData) avroStruct.getField("field1")).getUnderlyingData()); - assertEquals(2.0, ((PlatformData) avroStruct.getField("field2")).getUnderlyingData()); + assertEquals(1, avroStruct.getField("field1")); + assertEquals(2.0, avroStruct.getField("field2")); } @Test @@ -205,11 +133,11 @@ public void testValidUnionType() { assertTrue(stdLongType instanceof AvroLongType); assertEquals(nonNullType, stdLongType.underlyingType()); - StdData stdLongData = AvroWrapper.createStdData(1L, unionSchema); - assertTrue(stdLongData instanceof AvroLong); - assertEquals(1L, ((AvroLong) stdLongData).get()); + Object stdLongData = AvroWrapper.createStdData(1L, unionSchema); + assertTrue(stdLongData instanceof Long); + assertEquals(1L, stdLongData); - StdData stdNullData = AvroWrapper.createStdData(null, unionSchema); + Object stdNullData = AvroWrapper.createStdData(null, unionSchema); assertNull(stdNullData); } @@ -242,21 +170,21 @@ public void testStructWithSimpleUnionField() { GenericRecord record1 = new GenericData.Record(structSchema); record1.put("field1", 1); record1.put("field2", 3.0); - AvroStruct avroStruct1 = (AvroStruct) AvroWrapper.createStdData(record1, structSchema); + AvroRowData avroStruct1 = (AvroRowData) AvroWrapper.createStdData(record1, structSchema); assertEquals(2, avroStruct1.fields().size()); - assertEquals(3.0, ((PlatformData) avroStruct1.getField("field2")).getUnderlyingData()); + assertEquals(3.0, avroStruct1.getField("field2")); GenericRecord record2 = new GenericData.Record(structSchema); record2.put("field1", 1); record2.put("field2", null); - AvroStruct avroStruct2 = (AvroStruct) AvroWrapper.createStdData(record2, structSchema); + AvroRowData avroStruct2 = (AvroRowData) AvroWrapper.createStdData(record2, structSchema); assertEquals(2, 
avroStruct2.fields().size()); assertNull(avroStruct2.getField("field2")); assertNull(avroStruct2.fields().get(1)); GenericRecord record3 = new GenericData.Record(structSchema); record3.put("field1", 1); - AvroStruct avroStruct3 = (AvroStruct) AvroWrapper.createStdData(record3, structSchema); + AvroRowData avroStruct3 = (AvroRowData) AvroWrapper.createStdData(record3, structSchema); assertEquals(2, avroStruct3.fields().size()); assertNull(avroStruct3.getField("field2")); assertNull(avroStruct3.fields().get(1)); diff --git a/transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/SparkWrapperGenerator.java b/transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/SparkWrapperGenerator.java index bec08d8e..1bc19ded 100644 --- a/transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/SparkWrapperGenerator.java +++ b/transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/SparkWrapperGenerator.java @@ -11,7 +11,9 @@ import java.io.File; import java.io.IOException; import java.io.InputStream; +import java.util.Arrays; import java.util.Collection; +import java.util.Map; import java.util.stream.Collectors; import org.apache.commons.io.IOUtils; import org.apache.commons.text.StringSubstitutor; @@ -23,6 +25,7 @@ public class SparkWrapperGenerator implements WrapperGenerator { private static final String SPARK_WRAPPER_TEMPLATE_RESOURCE_PATH = "wrapper-templates/spark"; private static final String SUBSTITUTOR_KEY_WRAPPER_PACKAGE = "wrapperPackage"; private static final String SUBSTITUTOR_KEY_WRAPPER_CLASS = "wrapperClass"; + private static final String SUBSTITUTOR_KEY_WRAPPER_CLASS_PARAMERTERS = "wrapperClassParameters"; private static final String SUBSTITUTOR_KEY_UDF_TOP_LEVEL_CLASS = "udfTopLevelClass"; private static final String SUBSTITUTOR_KEY_UDF_IMPLEMENTATIONS = "udfImplementations"; @@ -30,12 +33,16 @@ public class SparkWrapperGenerator implements WrapperGenerator { public void 
generateWrappers(WrapperGeneratorContext context) { TransportUDFMetadata udfMetadata = context.getTransportUdfMetadata(); for (String topLevelClass : udfMetadata.getTopLevelClasses()) { - generateWrapper(topLevelClass, udfMetadata.getStdUDFImplementations(topLevelClass), + generateWrapper( + topLevelClass, + udfMetadata.getStdUDFImplementations(topLevelClass), + udfMetadata.getClassToNumberOfTypeParameters(), context.getSourcesOutputDir()); } } - private void generateWrapper(String topLevelClass, Collection implementationClasses, File outputDir) { + private void generateWrapper(String topLevelClass, Collection implementationClasses, + Map classToNumberOfTypeParameters, File outputDir) { final String wrapperTemplate; try (InputStream wrapperTemplateStream = Thread.currentThread() .getContextClassLoader() @@ -49,13 +56,15 @@ private void generateWrapper(String topLevelClass, Collection implementa ClassName wrapperClass = ClassName.get(topLevelClassName.packageName() + "." + SPARK_PACKAGE_SUFFIX, topLevelClassName.simpleName()); String udfImplementationInstantiations = implementationClasses.stream() - .map(clazz -> "new " + clazz + "()") + .map(clazz -> "new " + clazz + parameters(clazz, classToNumberOfTypeParameters) + "()") .collect(Collectors.joining(", ")); + String topLevelClassNameString = topLevelClassName.toString(); ImmutableMap substitutionMap = ImmutableMap.of( SUBSTITUTOR_KEY_WRAPPER_PACKAGE, wrapperClass.packageName(), SUBSTITUTOR_KEY_WRAPPER_CLASS, wrapperClass.simpleName(), - SUBSTITUTOR_KEY_UDF_TOP_LEVEL_CLASS, topLevelClassName.toString(), + SUBSTITUTOR_KEY_UDF_TOP_LEVEL_CLASS, topLevelClassNameString + + parameters(topLevelClassNameString, classToNumberOfTypeParameters), SUBSTITUTOR_KEY_UDF_IMPLEMENTATIONS, udfImplementationInstantiations ); @@ -69,4 +78,11 @@ private void generateWrapper(String topLevelClass, Collection implementa throw new RuntimeException("Error writing wrapper to file", e); } } + + private static String parameters(String clazz, 
Map classToNumberOfTypeParameters) { + int numberOfTypeParameters = classToNumberOfTypeParameters.get(clazz); + String[] objectTypes = new String[numberOfTypeParameters]; + Arrays.fill(objectTypes, "Object"); + return numberOfTypeParameters > 0 ? "[" + String.join(", ", objectTypes) + "]" : ""; + } } diff --git a/transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/PrestoWrapperGenerator.java b/transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/TrinoWrapperGenerator.java similarity index 89% rename from transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/PrestoWrapperGenerator.java rename to transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/TrinoWrapperGenerator.java index 4f0ce5c5..957b7741 100644 --- a/transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/PrestoWrapperGenerator.java +++ b/transportable-udfs-codegen/src/main/java/com/linkedin/transport/codegen/TrinoWrapperGenerator.java @@ -19,13 +19,13 @@ import javax.lang.model.element.Modifier; -public class PrestoWrapperGenerator implements WrapperGenerator { +public class TrinoWrapperGenerator implements WrapperGenerator { - private static final String PRESTO_PACKAGE_SUFFIX = "presto"; + private static final String TRINO_PACKAGE_SUFFIX = "trino"; private static final String GET_STD_UDF_METHOD = "getStdUDF"; - private static final ClassName PRESTO_STD_UDF_WRAPPER_CLASS_NAME = - ClassName.bestGuess("com.linkedin.transport.presto.StdUdfWrapper"); - private static final String SERVICE_FILE = "META-INF/services/io.prestosql.metadata.SqlScalarFunction"; + private static final ClassName TRINO_STD_UDF_WRAPPER_CLASS_NAME = + ClassName.bestGuess("com.linkedin.transport.trino.StdUdfWrapper"); + private static final String SERVICE_FILE = "META-INF/services/io.trino.metadata.SqlScalarFunction"; @Override public void generateWrappers(WrapperGeneratorContext context) { @@ -46,7 +46,7 @@ public void 
generateWrappers(WrapperGeneratorContext context) { private void generateWrapper(String implementationClass, File sourcesOutputDir, List services) { ClassName implementationClassName = ClassName.bestGuess(implementationClass); ClassName wrapperClassName = - ClassName.get(implementationClassName.packageName() + "." + PRESTO_PACKAGE_SUFFIX, + ClassName.get(implementationClassName.packageName() + "." + TRINO_PACKAGE_SUFFIX, implementationClassName.simpleName()); /* @@ -89,7 +89,7 @@ public class ${wrapperClassName} extends StdUdfWrapper { */ TypeSpec wrapperClass = TypeSpec.classBuilder(wrapperClassName) .addModifiers(Modifier.PUBLIC) - .superclass(PRESTO_STD_UDF_WRAPPER_CLASS_NAME) + .superclass(TRINO_STD_UDF_WRAPPER_CLASS_NAME) .addMethod(constructor) .addMethod(getStdUDFMethod) .build(); diff --git a/transportable-udfs-codegen/src/test/java/com/linkedin/transport/codegen/TestPrestoWrapperGenerator.java b/transportable-udfs-codegen/src/test/java/com/linkedin/transport/codegen/TestTrinoWrapperGenerator.java similarity index 59% rename from transportable-udfs-codegen/src/test/java/com/linkedin/transport/codegen/TestPrestoWrapperGenerator.java rename to transportable-udfs-codegen/src/test/java/com/linkedin/transport/codegen/TestTrinoWrapperGenerator.java index 3c2fafbf..2815de69 100644 --- a/transportable-udfs-codegen/src/test/java/com/linkedin/transport/codegen/TestPrestoWrapperGenerator.java +++ b/transportable-udfs-codegen/src/test/java/com/linkedin/transport/codegen/TestTrinoWrapperGenerator.java @@ -8,16 +8,16 @@ import org.testng.annotations.Test; -public class TestPrestoWrapperGenerator extends AbstractTestWrapperGenerator { +public class TestTrinoWrapperGenerator extends AbstractTestWrapperGenerator { @Override WrapperGenerator getWrapperGenerator() { - return new PrestoWrapperGenerator(); + return new TrinoWrapperGenerator(); } @Test - public void testPrestoWrapperGenerator() { - testWrapperGenerator("inputs/sample-udf-metadata.json", 
"outputs/sample-udf-metadata/presto/sources", - "outputs/sample-udf-metadata/presto/resources"); + public void testTrinoWrapperGenerator() { + testWrapperGenerator("inputs/sample-udf-metadata.json", "outputs/sample-udf-metadata/trino/sources", + "outputs/sample-udf-metadata/trino/resources"); } } diff --git a/transportable-udfs-codegen/src/test/resources/inputs/sample-udf-metadata.json b/transportable-udfs-codegen/src/test/resources/inputs/sample-udf-metadata.json index 4d6fb8ae..6da584f2 100644 --- a/transportable-udfs-codegen/src/test/resources/inputs/sample-udf-metadata.json +++ b/transportable-udfs-codegen/src/test/resources/inputs/sample-udf-metadata.json @@ -1,17 +1,17 @@ { - "udfs": [ - { - "topLevelClass": "udfs.OverloadedUDF", - "stdUDFImplementations": [ - "udfs.OverloadedUDFInt", - "udfs.OverloadedUDFString" - ] - }, - { - "topLevelClass": "udfs.SimpleUDF", - "stdUDFImplementations": [ - "udfs.SimpleUDF" - ] - } - ] -} \ No newline at end of file + "udfs": { + "udfs.OverloadedUDF": [ + "udfs.OverloadedUDFInt", + "udfs.OverloadedUDFString" + ], + "udfs.SimpleUDF": [ + "udfs.SimpleUDF" + ] + }, + "classToNumberOfTypeParameters": { + "udfs.OverloadedUDFString": 0, + "udfs.OverloadedUDF": 0, + "udfs.OverloadedUDFInt": 0, + "udfs.SimpleUDF": 0 + } +} diff --git a/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/presto/resources/META-INF/services/io.prestosql.metadata.SqlScalarFunction b/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/presto/resources/META-INF/services/io.prestosql.metadata.SqlScalarFunction deleted file mode 100644 index b7fd5cdf..00000000 --- a/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/presto/resources/META-INF/services/io.prestosql.metadata.SqlScalarFunction +++ /dev/null @@ -1,3 +0,0 @@ -udfs.presto.OverloadedUDFInt -udfs.presto.OverloadedUDFString -udfs.presto.SimpleUDF diff --git 
a/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/trino/resources/META-INF/services/io.trino.metadata.SqlScalarFunction b/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/trino/resources/META-INF/services/io.trino.metadata.SqlScalarFunction new file mode 100644 index 00000000..8e1bf706 --- /dev/null +++ b/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/trino/resources/META-INF/services/io.trino.metadata.SqlScalarFunction @@ -0,0 +1,3 @@ +udfs.trino.OverloadedUDFInt +udfs.trino.OverloadedUDFString +udfs.trino.SimpleUDF diff --git a/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/presto/sources/udfs/presto/OverloadedUDFInt.java b/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/trino/sources/udfs/trino/OverloadedUDFInt.java similarity index 78% rename from transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/presto/sources/udfs/presto/OverloadedUDFInt.java rename to transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/trino/sources/udfs/trino/OverloadedUDFInt.java index f534f7d2..0b042d38 100644 --- a/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/presto/sources/udfs/presto/OverloadedUDFInt.java +++ b/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/trino/sources/udfs/trino/OverloadedUDFInt.java @@ -1,7 +1,7 @@ -package udfs.presto; +package udfs.trino; import com.linkedin.transport.api.udf.StdUDF; -import com.linkedin.transport.presto.StdUdfWrapper; +import com.linkedin.transport.trino.StdUdfWrapper; public class OverloadedUDFInt extends StdUdfWrapper { public OverloadedUDFInt() { diff --git a/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/presto/sources/udfs/presto/OverloadedUDFString.java b/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/trino/sources/udfs/trino/OverloadedUDFString.java 
similarity index 79% rename from transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/presto/sources/udfs/presto/OverloadedUDFString.java rename to transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/trino/sources/udfs/trino/OverloadedUDFString.java index 6295a5e0..6bb81781 100644 --- a/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/presto/sources/udfs/presto/OverloadedUDFString.java +++ b/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/trino/sources/udfs/trino/OverloadedUDFString.java @@ -1,7 +1,7 @@ -package udfs.presto; +package udfs.trino; import com.linkedin.transport.api.udf.StdUDF; -import com.linkedin.transport.presto.StdUdfWrapper; +import com.linkedin.transport.trino.StdUdfWrapper; public class OverloadedUDFString extends StdUdfWrapper { public OverloadedUDFString() { diff --git a/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/presto/sources/udfs/presto/SimpleUDF.java b/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/trino/sources/udfs/trino/SimpleUDF.java similarity index 76% rename from transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/presto/sources/udfs/presto/SimpleUDF.java rename to transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/trino/sources/udfs/trino/SimpleUDF.java index 67ea1c7e..eda7c528 100644 --- a/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/presto/sources/udfs/presto/SimpleUDF.java +++ b/transportable-udfs-codegen/src/test/resources/outputs/sample-udf-metadata/trino/sources/udfs/trino/SimpleUDF.java @@ -1,7 +1,7 @@ -package udfs.presto; +package udfs.trino; import com.linkedin.transport.api.udf.StdUDF; -import com.linkedin.transport.presto.StdUdfWrapper; +import com.linkedin.transport.trino.StdUdfWrapper; public class SimpleUDF extends StdUdfWrapper { public SimpleUDF() { diff --git 
a/transportable-udfs-compile-utils/src/main/java/com/linkedin/transport/compile/TransportUDFMetadata.java b/transportable-udfs-compile-utils/src/main/java/com/linkedin/transport/compile/TransportUDFMetadata.java index 48db80f5..c10cc44b 100644 --- a/transportable-udfs-compile-utils/src/main/java/com/linkedin/transport/compile/TransportUDFMetadata.java +++ b/transportable-udfs-compile-utils/src/main/java/com/linkedin/transport/compile/TransportUDFMetadata.java @@ -9,14 +9,18 @@ import com.google.common.collect.Multimap; import com.google.gson.Gson; import com.google.gson.GsonBuilder; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.io.Reader; import java.io.Writer; import java.util.Collection; -import java.util.LinkedList; -import java.util.List; +import java.util.HashMap; +import java.util.Map; import java.util.Set; @@ -26,6 +30,7 @@ public class TransportUDFMetadata { private static final Gson GSON; private Multimap _udfs; + private Map _classToNumberOfTypeParameters; static { GSON = new GsonBuilder().setPrettyPrinting().create(); @@ -33,14 +38,15 @@ public class TransportUDFMetadata { public TransportUDFMetadata() { _udfs = LinkedHashMultimap.create(); + _classToNumberOfTypeParameters = new HashMap<>(); } public void addUDF(String topLevelClass, String stdUDFImplementation) { _udfs.put(topLevelClass, stdUDFImplementation); } - public void addUDF(String topLevelClass, Collection stdUDFImplementations) { - _udfs.putAll(topLevelClass, stdUDFImplementations); + public void setClassNumberOfTypeParameters(String clazz, int numberOfTypeParameters) { + _classToNumberOfTypeParameters.put(clazz, numberOfTypeParameters); } public Set getTopLevelClasses() { @@ -51,8 +57,12 @@ public Collection getStdUDFImplementations(String topLevelClass) { return _udfs.get(topLevelClass); } + public Map 
getClassToNumberOfTypeParameters() { + return _classToNumberOfTypeParameters; + } + public void toJson(Writer writer) { - GSON.toJson(TransportUDFMetadataSerDe.fromUDFMetadata(this), writer); + GSON.toJson(TransportUDFMetadataSerDe.serialize(this), writer); } public static TransportUDFMetadata fromJsonFile(File jsonFile) { @@ -64,50 +74,49 @@ public static TransportUDFMetadata fromJsonFile(File jsonFile) { } public static TransportUDFMetadata fromJson(Reader reader) { - return TransportUDFMetadataSerDe.toUDFMetadata(GSON.fromJson(reader, TransportUDFMetadataJson.class)); + return TransportUDFMetadataSerDe.deserialize(new JsonParser().parse(reader)); } - /** - * Represents the JSON object structure of the Transport UDF metadata resource file - */ - private static class TransportUDFMetadataJson { - private List udfs; + private static class TransportUDFMetadataSerDe { - TransportUDFMetadataJson() { - this.udfs = new LinkedList<>(); + public static TransportUDFMetadata deserialize(JsonElement json) { + TransportUDFMetadata metadata = new TransportUDFMetadata(); + JsonObject root = json.getAsJsonObject(); + + // Deserialize udfs + JsonObject udfs = root.getAsJsonObject("udfs"); + udfs.keySet().forEach(topLevelClass -> { + JsonArray stdUdfImplementations = udfs.getAsJsonArray(topLevelClass); + for (int i = 0; i < stdUdfImplementations.size(); i++) { + metadata.addUDF(topLevelClass, stdUdfImplementations.get(i).getAsString()); + } + }); + + // Deserialize classToNumberOfTypeParameters + JsonObject classToNumberOfTypeParameters = root.getAsJsonObject("classToNumberOfTypeParameters"); + classToNumberOfTypeParameters.entrySet().forEach( + e -> metadata.setClassNumberOfTypeParameters(e.getKey(), e.getValue().getAsInt()) + ); + return metadata; } - static class UDFInfo { - private String topLevelClass; - private Collection stdUDFImplementations; - - UDFInfo(String topLevelClass, Collection stdUDFImplementations) { - this.topLevelClass = topLevelClass; - 
this.stdUDFImplementations = stdUDFImplementations; + public static JsonElement serialize(TransportUDFMetadata metadata) { + // Serialize _udfs + JsonObject udfs = new JsonObject(); + for (Map.Entry> entry : metadata._udfs.asMap().entrySet()) { + JsonArray stdUdfImplementations = new JsonArray(); + entry.getValue().forEach(f -> stdUdfImplementations.add(f)); + udfs.add(entry.getKey(), stdUdfImplementations); } - } - } - /** - * Converts objects between {@link TransportUDFMetadata} and {@link TransportUDFMetadataJson} - */ - private static class TransportUDFMetadataSerDe { - - private static TransportUDFMetadataJson fromUDFMetadata(TransportUDFMetadata metadata) { - TransportUDFMetadataJson metadataJson = new TransportUDFMetadataJson(); - for (String topLevelClass : metadata.getTopLevelClasses()) { - metadataJson.udfs.add( - new TransportUDFMetadataJson.UDFInfo(topLevelClass, metadata.getStdUDFImplementations(topLevelClass))); - } - return metadataJson; - } + // Serialize _classToNumberOfTypeParameters + JsonObject classToNumberOfTypeParameters = new JsonObject(); + metadata._classToNumberOfTypeParameters.forEach((clazz, n) -> classToNumberOfTypeParameters.addProperty(clazz, n)); - private static TransportUDFMetadata toUDFMetadata(TransportUDFMetadataJson metadataJson) { - TransportUDFMetadata metadata = new TransportUDFMetadata(); - for (TransportUDFMetadataJson.UDFInfo udf : metadataJson.udfs) { - metadata.addUDF(udf.topLevelClass, udf.stdUDFImplementations); - } - return metadata; + JsonObject root = new JsonObject(); + root.add("udfs", udfs); + root.add("classToNumberOfTypeParameters", classToNumberOfTypeParameters); + return root; } } } diff --git a/transportable-udfs-examples/build.gradle b/transportable-udfs-examples/build.gradle index 8714ba39..93be10ba 100644 --- a/transportable-udfs-examples/build.gradle +++ b/transportable-udfs-examples/build.gradle @@ -33,10 +33,6 @@ subprojects { url "https://conjars.org/repo" } } -
project.ext.setProperty('presto-version', '333') - project.ext.setProperty('airlift-slice-version', '0.38') - project.ext.setProperty('spark-group', 'org.apache.spark') - project.ext.setProperty('spark-version', '2.3.0') } subprojects { @@ -62,8 +58,11 @@ subprojects { } checkstyle { - configFile = file("${rootDir}/../gradle/checkstyle/checkstyle.xml") - configProperties = ['config_loc' : "${rootDir}/../gradle/checkstyle/"] + configFile = rootProject.file('../gradle/checkstyle/checkstyle.xml') + configProperties = [ + 'configDir': rootProject.file('../gradle/checkstyle'), + 'baseDir': "${rootDir}/.." + ] toolVersion '8.23' } } diff --git a/transportable-udfs-examples/gradle/wrapper/gradle-wrapper.properties b/transportable-udfs-examples/gradle/wrapper/gradle-wrapper.properties index a8cec85d..4167e4da 100644 --- a/transportable-udfs-examples/gradle/wrapper/gradle-wrapper.properties +++ b/transportable-udfs-examples/gradle/wrapper/gradle-wrapper.properties @@ -3,4 +3,4 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-5.0-all.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-6.7-all.zip diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/build.gradle b/transportable-udfs-examples/transportable-udfs-example-udfs/build.gradle index 40d7f387..31e488f5 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/build.gradle +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/build.gradle @@ -15,9 +15,10 @@ dependencies { // If the license plugin is applied, disable license checks for the autogenerated source sets plugins.withId('com.github.hierynomus.license') { - licenseHive.enabled = false - licensePresto.enabled = false - licenseSpark.enabled = false + tasks.getByName('licenseTrino').enabled = false + tasks.getByName('licenseHive').enabled = false + 
tasks.getByName('licenseSpark_2.11').enabled = false + tasks.getByName('licenseSpark_2.12').enabled = false } // TODO: Add a debugPlatform flag to allow debugging specific test methods in IntelliJ diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/ArrayElementAtFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/ArrayElementAtFunction.java index 2697f8db..e9ee2b22 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/ArrayElementAtFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/ArrayElementAtFunction.java @@ -6,15 +6,26 @@ package com.linkedin.transport.examples; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdInteger; +import com.linkedin.transport.api.data.ArrayData; import com.linkedin.transport.api.udf.StdUDF2; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public class ArrayElementAtFunction extends StdUDF2 implements TopLevelStdUDF { +/** + * Another way to define this class using generics can look like this + * + * public class ArrayElementAtFunction extends StdUDF2, Integer, K> implements TopLevelStdUDF { + * + * @Override + * public K eval(ArrayData a1, Integer idx) { + * return a1.get(idx); + * } + * + * } + * + */ +public class ArrayElementAtFunction extends StdUDF2 implements TopLevelStdUDF { @Override public String getFunctionName() { @@ -40,7 +51,7 @@ public String getOutputParameterSignature() { } @Override - public StdData eval(StdArray a1, StdInteger idx) { - return a1.get(idx.get()); + public Object eval(ArrayData a1, Integer idx) { + return a1.get(idx); } } diff --git 
a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/ArrayFillFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/ArrayFillFunction.java index ae5a9ac1..a9cff404 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/ArrayFillFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/ArrayFillFunction.java @@ -7,16 +7,14 @@ import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdLong; +import com.linkedin.transport.api.data.ArrayData; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.api.udf.StdUDF2; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public class ArrayFillFunction extends StdUDF2 implements TopLevelStdUDF { +public class ArrayFillFunction extends StdUDF2> implements TopLevelStdUDF { private StdType _arrayType; @@ -40,9 +38,9 @@ public void init(StdFactory stdFactory) { } @Override - public StdArray eval(StdData a, StdLong length) { - StdArray array = getStdFactory().createArray(_arrayType); - for (int i = 0; i < length.get(); i++) { + public ArrayData eval(K a, Long length) { + ArrayData array = getStdFactory().createArray(_arrayType); + for (int i = 0; i < length; i++) { array.add(a); } return array; diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryDuplicateFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryDuplicateFunction.java index 26a63111..8252b816 100644 --- 
a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryDuplicateFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryDuplicateFunction.java @@ -6,24 +6,22 @@ package com.linkedin.transport.examples; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdBinary; import com.linkedin.transport.api.udf.StdUDF1; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.nio.ByteBuffer; import java.util.List; -public class BinaryDuplicateFunction extends StdUDF1 implements TopLevelStdUDF { +public class BinaryDuplicateFunction extends StdUDF1 implements TopLevelStdUDF { @Override - public StdBinary eval(StdBinary binaryObject) { - ByteBuffer byteBuffer = binaryObject.get(); + public ByteBuffer eval(ByteBuffer byteBuffer) { ByteBuffer results = ByteBuffer.allocate(2 * byteBuffer.array().length); for (int i = 0; i < 2; i++) { for (byte b : byteBuffer.array()) { results.put(b); } } - return getStdFactory().createBinary(results); + return results; } @Override diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryObjectSizeFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryObjectSizeFunction.java index 0f4b538a..39b56cd4 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryObjectSizeFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryObjectSizeFunction.java @@ -6,17 +6,16 @@ package com.linkedin.transport.examples; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdBinary; -import com.linkedin.transport.api.data.StdInteger; import 
com.linkedin.transport.api.udf.StdUDF1; import com.linkedin.transport.api.udf.TopLevelStdUDF; +import java.nio.ByteBuffer; import java.util.List; -public class BinaryObjectSizeFunction extends StdUDF1 implements TopLevelStdUDF { +public class BinaryObjectSizeFunction extends StdUDF1 implements TopLevelStdUDF { @Override - public StdInteger eval(StdBinary binaryObject) { - return getStdFactory().createInteger(binaryObject.get().array().length); + public Integer eval(ByteBuffer byteBuffer) { + return byteBuffer.array().length; } @Override diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/FileLookupFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/FileLookupFunction.java index 8112e443..e9ed378f 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/FileLookupFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/FileLookupFunction.java @@ -7,9 +7,6 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdBoolean; -import com.linkedin.transport.api.data.StdInteger; -import com.linkedin.transport.api.data.StdString; import com.linkedin.transport.api.udf.StdUDF2; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.io.BufferedReader; @@ -21,14 +18,14 @@ import org.apache.commons.io.IOUtils; -public class FileLookupFunction extends StdUDF2 implements TopLevelStdUDF { +public class FileLookupFunction extends StdUDF2 implements TopLevelStdUDF { private Set ids; @Override - public StdBoolean eval(StdString filename, StdInteger intToCheck) { + public Boolean eval(String filename, Integer intToCheck) { Preconditions.checkNotNull(intToCheck, "Integer to check should not be null"); - return 
getStdFactory().createBoolean(ids.contains(intToCheck.get())); + return ids.contains(intToCheck); } @Override @@ -57,8 +54,8 @@ public String getFunctionDescription() { } @Override - public String[] getRequiredFiles(StdString filename, StdInteger intToCheck) { - return new String[]{filename.get()}; + public String[] getRequiredFiles(String filename, Integer intToCheck) { + return new String[]{filename}; } public void processRequiredFiles(String[] localPaths) { diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapFromTwoArraysFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapFromTwoArraysFunction.java index 6fd99981..d6002a50 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapFromTwoArraysFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapFromTwoArraysFunction.java @@ -7,15 +7,16 @@ import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdMap; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.api.udf.StdUDF2; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public class MapFromTwoArraysFunction extends StdUDF2 implements TopLevelStdUDF { +public class MapFromTwoArraysFunction extends StdUDF2, ArrayData, MapData> + implements TopLevelStdUDF { private StdType _mapType; @@ -35,16 +36,16 @@ public String getOutputParameterSignature() { @Override public void init(StdFactory stdFactory) { super.init(stdFactory); - // Note: we create the _mapType once in init() and then reuse it to create StdMap objects 
+ // Note: we create the _mapType once in init() and then reuse it to create MapData objects _mapType = getStdFactory().createStdType(getOutputParameterSignature()); } @Override - public StdMap eval(StdArray a1, StdArray a2) { + public MapData eval(ArrayData a1, ArrayData a2) { if (a1.size() != a2.size()) { return null; } - StdMap map = getStdFactory().createMap(_mapType); + MapData map = getStdFactory().createMap(_mapType); for (int i = 0; i < a1.size(); i++) { map.put(a1.get(i), a2.get(i)); } diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapKeySetFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapKeySetFunction.java index a76c4403..af81e024 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapKeySetFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapKeySetFunction.java @@ -7,16 +7,15 @@ import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdMap; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.api.udf.StdUDF1; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public class MapKeySetFunction extends StdUDF1 implements TopLevelStdUDF { +public class MapKeySetFunction extends StdUDF1, ArrayData> implements TopLevelStdUDF { private StdType _mapType; @@ -39,9 +38,9 @@ public void init(StdFactory stdFactory) { } @Override - public StdArray eval(StdMap map) { - StdArray result = getStdFactory().createArray(_mapType); - for (StdData key : map.keySet()) 
{ + public ArrayData eval(MapData map) { + ArrayData result = getStdFactory().createArray(_mapType); + for (K key : map.keySet()) { result.add(key); } return result; diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapValuesFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapValuesFunction.java index f22ff7f7..82b34ef1 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapValuesFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/MapValuesFunction.java @@ -7,16 +7,15 @@ import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdMap; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.api.udf.StdUDF1; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public class MapValuesFunction extends StdUDF1 implements TopLevelStdUDF { +public class MapValuesFunction extends StdUDF1, ArrayData> implements TopLevelStdUDF { private StdType _mapType; @@ -39,9 +38,9 @@ public void init(StdFactory stdFactory) { } @Override - public StdArray eval(StdMap map) { - StdArray result = getStdFactory().createArray(_mapType); - for (StdData value : map.values()) { + public ArrayData eval(MapData map) { + ArrayData result = getStdFactory().createArray(_mapType); + for (V value : map.values()) { result.add(value); } return result; diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NestedMapFromTwoArraysFunction.java 
b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NestedMapFromTwoArraysFunction.java new file mode 100644 index 00000000..5e244a11 --- /dev/null +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NestedMapFromTwoArraysFunction.java @@ -0,0 +1,88 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.examples; + +import com.google.common.collect.ImmutableList; +import com.linkedin.transport.api.StdFactory; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; +import com.linkedin.transport.api.data.RowData; +import com.linkedin.transport.api.types.StdType; +import com.linkedin.transport.api.udf.StdUDF1; +import com.linkedin.transport.api.udf.TopLevelStdUDF; +import java.util.List; + + +public class NestedMapFromTwoArraysFunction extends StdUDF1 implements TopLevelStdUDF { + + private StdType _arrayType; + private StdType _mapType; + private StdType _rowType; + + @Override + public List getInputParameterSignatures() { + return ImmutableList.of( + "array(row(array(K),array(V)))" + ); + } + + @Override + public String getOutputParameterSignature() { + return "array(row(map(K,V)))"; + } + + @Override + public void init(StdFactory stdFactory) { + super.init(stdFactory); + _arrayType = getStdFactory().createStdType(getOutputParameterSignature()); + _rowType = getStdFactory().createStdType("row(map(K,V))"); + _mapType = getStdFactory().createStdType("map(K,V)"); + } + + @Override + public ArrayData eval(ArrayData a1) { + ArrayData result = getStdFactory().createArray(_arrayType); + + for (int i = 0; i < a1.size(); i++) { + if (a1.get(i) == null) { + return null; + } + RowData inputRow = (RowData) a1.get(i); + + if (inputRow.getField(0) == null || 
inputRow.getField(1) == null) { + return null; + } + ArrayData kValues = (ArrayData) inputRow.getField(0); + ArrayData vValues = (ArrayData) inputRow.getField(1); + + if (kValues.size() != vValues.size()) { + return null; + } + + MapData map = getStdFactory().createMap(_mapType); + for (int j = 0; j < kValues.size(); j++) { + map.put(kValues.get(j), vValues.get(j)); + } + + RowData outputRow = getStdFactory().createStruct(_rowType); + outputRow.setField(0, map); + + result.add(outputRow); + } + + return result; + } + + @Override + public String getFunctionName() { + return "nested_map_from_two_arrays"; + } + + @Override + public String getFunctionDescription() { + return "Create a nested map from the 2 nested arrays"; + } +} diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddDoubleFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddDoubleFunction.java index 6ee9c918..80b0fb2b 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddDoubleFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddDoubleFunction.java @@ -6,15 +6,14 @@ package com.linkedin.transport.examples; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdDouble; import com.linkedin.transport.api.udf.StdUDF2; import java.util.List; -public class NumericAddDoubleFunction extends StdUDF2 implements NumericAddFunction { +public class NumericAddDoubleFunction extends StdUDF2 implements NumericAddFunction { @Override - public StdDouble eval(StdDouble first, StdDouble second) { - return getStdFactory().createDouble(first.get() + second.get()); + public Double eval(Double first, Double second) { + return first + second; } @Override diff --git 
a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddFloatFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddFloatFunction.java index 643b558b..a2a0ab47 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddFloatFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddFloatFunction.java @@ -6,15 +6,14 @@ package com.linkedin.transport.examples; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdFloat; import com.linkedin.transport.api.udf.StdUDF2; import java.util.List; -public class NumericAddFloatFunction extends StdUDF2 implements NumericAddFunction { +public class NumericAddFloatFunction extends StdUDF2 implements NumericAddFunction { @Override - public StdFloat eval(StdFloat first, StdFloat second) { - return getStdFactory().createFloat(first.get() + second.get()); + public Float eval(Float first, Float second) { + return first + second; } @Override diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddIntFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddIntFunction.java index cc5fb900..bcdb3696 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddIntFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddIntFunction.java @@ -6,16 +6,15 @@ package com.linkedin.transport.examples; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdInteger; import 
com.linkedin.transport.api.udf.StdUDF2; import java.util.List; -public class NumericAddIntFunction extends StdUDF2 +public class NumericAddIntFunction extends StdUDF2 implements NumericAddFunction { @Override - public StdInteger eval(StdInteger first, StdInteger second) { - return getStdFactory().createInteger(first.get() + second.get()); + public Integer eval(Integer first, Integer second) { + return first + second; } @Override diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddLongFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddLongFunction.java index a530e586..c24d2148 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddLongFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddLongFunction.java @@ -6,15 +6,14 @@ package com.linkedin.transport.examples; import com.google.common.collect.ImmutableList; -import com.linkedin.transport.api.data.StdLong; import com.linkedin.transport.api.udf.StdUDF2; import java.util.List; -public class NumericAddLongFunction extends StdUDF2 implements NumericAddFunction { +public class NumericAddLongFunction extends StdUDF2 implements NumericAddFunction { @Override - public StdLong eval(StdLong first, StdLong second) { - return getStdFactory().createLong(first.get() + second.get()); + public Long eval(Long first, Long second) { + return first + second; } @Override diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/StructCreateByIndexFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/StructCreateByIndexFunction.java index 5a23283d..ffa78ba7 100644 --- 
a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/StructCreateByIndexFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/StructCreateByIndexFunction.java @@ -7,15 +7,14 @@ import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.api.udf.StdUDF2; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public class StructCreateByIndexFunction extends StdUDF2 implements TopLevelStdUDF { +public class StructCreateByIndexFunction extends StdUDF2 implements TopLevelStdUDF { private StdType _field1Type; private StdType _field2Type; @@ -41,11 +40,11 @@ public void init(StdFactory stdFactory) { } @Override - public StdStruct eval(StdData field1Value, StdData field2Value) { - StdStruct struct = getStdFactory().createStruct(ImmutableList.of(_field1Type, _field2Type)); - struct.setField(0, field1Value); - struct.setField(1, field2Value); - return struct; + public RowData eval(Object field1Value, Object field2Value) { + RowData row = getStdFactory().createStruct(ImmutableList.of(_field1Type, _field2Type)); + row.setField(0, field1Value); + row.setField(1, field2Value); + return row; } @Override diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/StructCreateByNameFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/StructCreateByNameFunction.java index 36ca3472..b4f2a0c0 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/StructCreateByNameFunction.java +++ 
b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/StructCreateByNameFunction.java @@ -7,16 +7,14 @@ import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdString; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.api.udf.StdUDF4; import com.linkedin.transport.api.udf.TopLevelStdUDF; import java.util.List; -public class StructCreateByNameFunction extends StdUDF4 implements TopLevelStdUDF { +public class StructCreateByNameFunction extends StdUDF4 implements TopLevelStdUDF { private StdType _field1Type; private StdType _field2Type; @@ -44,13 +42,13 @@ public void init(StdFactory stdFactory) { } @Override - public StdStruct eval(StdString field1Name, StdData field1Value, StdString field2Name, StdData field2Value) { - StdStruct struct = getStdFactory().createStruct( - ImmutableList.of(field1Name.get(), field2Name.get()), + public RowData eval(String field1Name, Object field1Value, String field2Name, Object field2Value) { + RowData struct = getStdFactory().createStruct( + ImmutableList.of(field1Name, field2Name), ImmutableList.of(_field1Type, _field2Type) ); - struct.setField(field1Name.get(), field1Value); - struct.setField(field2Name.get(), field2Value); + struct.setField(field1Name, field1Value); + struct.setField(field2Name, field2Value); return struct; } diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestNestedMapFromTwoArraysFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestNestedMapFromTwoArraysFunction.java new file mode 100644 index 00000000..da8e75ae --- /dev/null +++ 
b/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestNestedMapFromTwoArraysFunction.java @@ -0,0 +1,49 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.examples; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.linkedin.transport.api.udf.StdUDF; +import com.linkedin.transport.api.udf.TopLevelStdUDF; +import com.linkedin.transport.test.AbstractStdUDFTest; +import com.linkedin.transport.test.spi.StdTester; +import java.util.List; +import java.util.Map; +import org.testng.annotations.Test; + + +public class TestNestedMapFromTwoArraysFunction extends AbstractStdUDFTest { + + @Override + protected Map, List>> getTopLevelStdUDFClassesAndImplementations() { + return ImmutableMap.of(NestedMapFromTwoArraysFunction.class, ImmutableList.of(NestedMapFromTwoArraysFunction.class)); + } + + @Test + public void testNestedMapUnionFunction() { + StdTester tester = getTester(); + tester.check( + functionCall("nested_map_from_two_arrays", array(row(array(1, 2), array("a", "b")))), + array(row(map(1, "a", 2, "b"))), + "array(row(map(integer,varchar)))"); + tester.check( + functionCall("nested_map_from_two_arrays", array(row(array(1, 2), array("a", "b")), row(array(11, 12), array("aa", "bb")))), + array(row(map(1, "a", 2, "b")), row(map(11, "aa", 12, "bb"))), + "array(row(map(integer,varchar)))"); + tester.check( + functionCall("nested_map_from_two_arrays", + array(row(array(array(1), array(2)), array(array("a"), array("b"))))), + array(row(map(array(1), array("a"), array(2), array("b")))), + "array(row(map(array(integer),array(varchar))))"); + tester.check( + functionCall("nested_map_from_two_arrays", array(row(array(1), array("a", "b")))), + null, "array(row(map(integer,varchar)))"); + tester.check( + 
functionCall("nested_map_from_two_arrays", array(row(null, array("a", "b")))), + null, "array(row(map(unknown,varchar)))"); + } +} diff --git a/transportable-udfs-hive/build.gradle b/transportable-udfs-hive/build.gradle index ce458dc5..7e193eaf 100644 --- a/transportable-udfs-hive/build.gradle +++ b/transportable-udfs-hive/build.gradle @@ -7,10 +7,12 @@ dependencies { compile('org.apache.hadoop:hadoop-common:2.7.4') compileOnly('org.apache.hive:hive-exec:1.2.2') { exclude group: 'org.apache.avro' + exclude group: 'org.apache.calcite' } testCompile project(path: ':transportable-udfs-type-system', configuration: 'tests') testCompile('org.apache.hive:hive-exec:1.2.2') { exclude group: 'org.apache.avro' + exclude group: 'org.apache.calcite' } } diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveFactory.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveFactory.java index e0373b63..c058e56b 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveFactory.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveFactory.java @@ -5,34 +5,18 @@ */ package com.linkedin.transport.hive; -import com.google.common.base.Preconditions; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdBoolean; -import com.linkedin.transport.api.data.StdBinary; -import com.linkedin.transport.api.data.StdDouble; -import com.linkedin.transport.api.data.StdFloat; -import com.linkedin.transport.api.data.StdInteger; -import com.linkedin.transport.api.data.StdLong; -import com.linkedin.transport.api.data.StdMap; -import com.linkedin.transport.api.data.StdString; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.api.types.StdType; 
-import com.linkedin.transport.hive.data.HiveArray; -import com.linkedin.transport.hive.data.HiveBoolean; -import com.linkedin.transport.hive.data.HiveBinary; -import com.linkedin.transport.hive.data.HiveDouble; -import com.linkedin.transport.hive.data.HiveFloat; -import com.linkedin.transport.hive.data.HiveInteger; -import com.linkedin.transport.hive.data.HiveLong; -import com.linkedin.transport.hive.data.HiveMap; -import com.linkedin.transport.hive.data.HiveString; -import com.linkedin.transport.hive.data.HiveStruct; +import com.linkedin.transport.hive.data.HiveArrayData; +import com.linkedin.transport.hive.data.HiveMapData; +import com.linkedin.transport.hive.data.HiveRowData; import com.linkedin.transport.hive.types.objectinspector.CacheableObjectInspectorConverters; import com.linkedin.transport.hive.typesystem.HiveTypeFactory; import com.linkedin.transport.typesystem.AbstractBoundVariables; import com.linkedin.transport.typesystem.TypeSignature; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -45,7 +29,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; public class HiveFactory implements StdFactory { @@ -61,59 +44,23 @@ public HiveFactory(AbstractBoundVariables boundVariables) { } @Override - public StdInteger createInteger(int value) { - return new HiveInteger(value, PrimitiveObjectInspectorFactory.javaIntObjectInspector, this); - } - - @Override - public StdLong createLong(long value) { - return new HiveLong(value, PrimitiveObjectInspectorFactory.javaLongObjectInspector, this); - } - - @Override - public StdBoolean createBoolean(boolean value) { - return new HiveBoolean(value, 
PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, this); - } - - @Override - public StdString createString(String value) { - Preconditions.checkNotNull(value, "Cannot create a null StdString"); - return new HiveString(value, PrimitiveObjectInspectorFactory.javaStringObjectInspector, this); - } - - @Override - public StdFloat createFloat(float value) { - return new HiveFloat(value, PrimitiveObjectInspectorFactory.javaFloatObjectInspector, this); - } - - @Override - public StdDouble createDouble(double value) { - return new HiveDouble(value, PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, this); - } - - @Override - public StdBinary createBinary(ByteBuffer value) { - return new HiveBinary(value.array(), PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector, this); - } - - @Override - public StdArray createArray(StdType stdType, int expectedSize) { + public ArrayData createArray(StdType stdType, int expectedSize) { ListObjectInspector listObjectInspector = (ListObjectInspector) stdType.underlyingType(); - return new HiveArray( + return new HiveArrayData( new ArrayList(expectedSize), ObjectInspectorFactory.getStandardListObjectInspector(listObjectInspector.getListElementObjectInspector()), this); } @Override - public StdArray createArray(StdType stdType) { + public ArrayData createArray(StdType stdType) { return createArray(stdType, 0); } @Override - public StdMap createMap(StdType stdType) { + public MapData createMap(StdType stdType) { MapObjectInspector mapObjectInspector = (MapObjectInspector) stdType.underlyingType(); - return new HiveMap( + return new HiveMapData( new HashMap(), ObjectInspectorFactory.getStandardMapObjectInspector( mapObjectInspector.getMapKeyObjectInspector(), @@ -122,8 +69,8 @@ public StdMap createMap(StdType stdType) { } @Override - public StdStruct createStruct(List fieldNames, List fieldTypes) { - return new HiveStruct( + public RowData createStruct(List fieldNames, List fieldTypes) { + return new HiveRowData( new 
ArrayList(Arrays.asList(new Object[fieldTypes.size()])), ObjectInspectorFactory.getStandardStructObjectInspector( fieldNames, @@ -133,16 +80,16 @@ public StdStruct createStruct(List fieldNames, List fieldTypes) } @Override - public StdStruct createStruct(List fieldTypes) { + public RowData createStruct(List fieldTypes) { List fieldNames = IntStream.range(0, fieldTypes.size()).mapToObj(i -> "field" + i).collect(Collectors.toList()); return createStruct(fieldNames, fieldTypes); } @Override - public StdStruct createStruct(StdType stdType) { + public RowData createStruct(StdType stdType) { StructObjectInspector structObjectInspector = (StructObjectInspector) stdType.underlyingType(); - return new HiveStruct( + return new HiveRowData( new ArrayList(Arrays.asList(new Object[structObjectInspector.getAllStructFieldRefs().size()])), ObjectInspectorFactory.getStandardStructObjectInspector( structObjectInspector.getAllStructFieldRefs() diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveWrapper.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveWrapper.java index 3b9daa43..2b06d7a4 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveWrapper.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveWrapper.java @@ -6,18 +6,11 @@ package com.linkedin.transport.hive; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdData; import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.hive.data.HiveArray; -import com.linkedin.transport.hive.data.HiveBoolean; -import com.linkedin.transport.hive.data.HiveBinary; -import com.linkedin.transport.hive.data.HiveDouble; -import com.linkedin.transport.hive.data.HiveFloat; -import com.linkedin.transport.hive.data.HiveInteger; -import com.linkedin.transport.hive.data.HiveLong; -import com.linkedin.transport.hive.data.HiveMap; -import 
com.linkedin.transport.hive.data.HiveString; -import com.linkedin.transport.hive.data.HiveStruct; +import com.linkedin.transport.hive.data.HiveArrayData; +import com.linkedin.transport.hive.data.HiveData; +import com.linkedin.transport.hive.data.HiveMapData; +import com.linkedin.transport.hive.data.HiveRowData; import com.linkedin.transport.hive.types.HiveArrayType; import com.linkedin.transport.hive.types.HiveBooleanType; import com.linkedin.transport.hive.types.HiveBinaryType; @@ -27,11 +20,13 @@ import com.linkedin.transport.hive.types.HiveLongType; import com.linkedin.transport.hive.types.HiveMapType; import com.linkedin.transport.hive.types.HiveStringType; -import com.linkedin.transport.hive.types.HiveStructType; +import com.linkedin.transport.hive.types.HiveRowType; import com.linkedin.transport.hive.types.HiveUnknownType; +import java.nio.ByteBuffer; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; @@ -39,6 +34,14 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.SettableBinaryObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.SettableBooleanObjectInspector; +import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.SettableDoubleObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.SettableFloatObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.SettableIntObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.SettableLongObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.SettableStringObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.VoidObjectInspector; @@ -48,28 +51,23 @@ public final class HiveWrapper { private HiveWrapper() { } - public static StdData createStdData(Object hiveData, ObjectInspector hiveObjectInspector, StdFactory stdFactory) { - if (hiveObjectInspector instanceof IntObjectInspector) { - return new HiveInteger(hiveData, (IntObjectInspector) hiveObjectInspector, stdFactory); - } else if (hiveObjectInspector instanceof LongObjectInspector) { - return new HiveLong(hiveData, (LongObjectInspector) hiveObjectInspector, stdFactory); - } else if (hiveObjectInspector instanceof BooleanObjectInspector) { - return new HiveBoolean(hiveData, (BooleanObjectInspector) hiveObjectInspector, stdFactory); - } else if (hiveObjectInspector instanceof StringObjectInspector) { - return new HiveString(hiveData, (StringObjectInspector) hiveObjectInspector, stdFactory); - } else if (hiveObjectInspector instanceof FloatObjectInspector) { - return new HiveFloat(hiveData, (FloatObjectInspector) hiveObjectInspector, stdFactory); - } else if (hiveObjectInspector instanceof DoubleObjectInspector) { - return new HiveDouble(hiveData, (DoubleObjectInspector) hiveObjectInspector, stdFactory); + public static Object createStdData(Object hiveData, ObjectInspector hiveObjectInspector, StdFactory stdFactory) { + if (hiveObjectInspector instanceof IntObjectInspector || hiveObjectInspector instanceof 
LongObjectInspector + || hiveObjectInspector instanceof FloatObjectInspector || hiveObjectInspector instanceof DoubleObjectInspector + || hiveObjectInspector instanceof BooleanObjectInspector + || hiveObjectInspector instanceof StringObjectInspector) { + return ((PrimitiveObjectInspector) hiveObjectInspector).getPrimitiveJavaObject(hiveData); } else if (hiveObjectInspector instanceof BinaryObjectInspector) { - return new HiveBinary(hiveData, (BinaryObjectInspector) hiveObjectInspector, stdFactory); + BinaryObjectInspector binaryObjectInspector = (BinaryObjectInspector) hiveObjectInspector; + return hiveData == null ? null : ByteBuffer.wrap(binaryObjectInspector.getPrimitiveJavaObject(hiveData)); } else if (hiveObjectInspector instanceof ListObjectInspector) { ListObjectInspector listObjectInspector = (ListObjectInspector) hiveObjectInspector; - return new HiveArray(hiveData, listObjectInspector, stdFactory); + return new HiveArrayData(hiveData, listObjectInspector, stdFactory); } else if (hiveObjectInspector instanceof MapObjectInspector) { - return new HiveMap(hiveData, hiveObjectInspector, stdFactory); + return new HiveMapData(hiveData, hiveObjectInspector, stdFactory); } else if (hiveObjectInspector instanceof StructObjectInspector) { - return new HiveStruct(hiveData, hiveObjectInspector, stdFactory); + return new HiveRowData(((StructObjectInspector) hiveObjectInspector).getStructFieldsDataAsList(hiveData).toArray(), + hiveObjectInspector, stdFactory); } else if (hiveObjectInspector instanceof VoidObjectInspector) { return null; } @@ -97,11 +95,57 @@ public static StdType createStdType(ObjectInspector hiveObjectInspector) { } else if (hiveObjectInspector instanceof MapObjectInspector) { return new HiveMapType((MapObjectInspector) hiveObjectInspector); } else if (hiveObjectInspector instanceof StructObjectInspector) { - return new HiveStructType((StructObjectInspector) hiveObjectInspector); + return new HiveRowType((StructObjectInspector) hiveObjectInspector); } 
else if (hiveObjectInspector instanceof VoidObjectInspector) { return new HiveUnknownType((VoidObjectInspector) hiveObjectInspector); } assert false : "Unrecognized Hive ObjectInspector: " + hiveObjectInspector.getClass(); return null; } + + public static Object getPlatformDataForObjectInspector(Object transportData, ObjectInspector oi) { + if (transportData == null) { + return null; + } else if (oi instanceof IntObjectInspector) { + return ((SettableIntObjectInspector) oi).create((Integer) transportData); + } else if (oi instanceof LongObjectInspector) { + return ((SettableLongObjectInspector) oi).create((Long) transportData); + } else if (oi instanceof FloatObjectInspector) { + return ((SettableFloatObjectInspector) oi).create((Float) transportData); + } else if (oi instanceof DoubleObjectInspector) { + return ((SettableDoubleObjectInspector) oi).create((Double) transportData); + } else if (oi instanceof BooleanObjectInspector) { + return ((SettableBooleanObjectInspector) oi).create((Boolean) transportData); + } else if (oi instanceof StringObjectInspector) { + return ((SettableStringObjectInspector) oi).create((String) transportData); + } else if (oi instanceof BinaryObjectInspector) { + return ((SettableBinaryObjectInspector) oi).create(((ByteBuffer) transportData).array()); + } else { + return ((HiveData) transportData).getUnderlyingDataForObjectInspector(oi); + } + } + + public static Object getStandardObject(Object transportData) { + if (transportData == null) { + return null; + } else if (transportData instanceof Integer) { + return PrimitiveObjectInspectorFactory.writableIntObjectInspector.create((Integer) transportData); + } else if (transportData instanceof Long) { + return PrimitiveObjectInspectorFactory.writableLongObjectInspector.create((Long) transportData); + } else if (transportData instanceof Float) { + return PrimitiveObjectInspectorFactory.writableFloatObjectInspector.create((Float) transportData); + } else if (transportData instanceof Double) { 
+ return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.create((Double) transportData); + } else if (transportData instanceof Boolean) { + return PrimitiveObjectInspectorFactory.writableBooleanObjectInspector.create((Boolean) transportData); + } else if (transportData instanceof String) { + return PrimitiveObjectInspectorFactory.writableStringObjectInspector.create((String) transportData); + } else if (transportData instanceof ByteBuffer) { + return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector.create(((ByteBuffer) transportData).array()); + } else { + return ((HiveData) transportData).getUnderlyingDataForObjectInspector( + ((HiveData) transportData).getUnderlyingObjectInspector() + ); + } + } } diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/StdUdfWrapper.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/StdUdfWrapper.java index bfa5cb6d..b932bf77 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/StdUdfWrapper.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/StdUdfWrapper.java @@ -7,7 +7,6 @@ import com.linkedin.transport.api.StdFactory; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; import com.linkedin.transport.api.udf.StdUDF; import com.linkedin.transport.api.udf.StdUDF0; import com.linkedin.transport.api.udf.StdUDF1; @@ -23,6 +22,7 @@ import com.linkedin.transport.utils.FileSystemUtils; import java.io.FileNotFoundException; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.List; import java.util.stream.IntStream; @@ -35,7 +35,8 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; - +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; 
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector; /** * Base class for all Hive Standard UDFs. It provides a standard way of type validation, binding, and output type @@ -49,7 +50,8 @@ public abstract class StdUdfWrapper extends GenericUDF { protected StdFactory _stdFactory; private boolean[] _nullableArguments; private String[] _distributedCacheFiles; - private StdData[] _args; + private Object[] _args; + private ObjectInspector _outputObjectInspector; /** * Given input object inspectors, this method matches them to the expected type signatures, and finds bindings to the @@ -70,7 +72,8 @@ public ObjectInspector initialize(ObjectInspector[] arguments) { _stdUdf.init(_stdFactory); _requiredFilesProcessed = false; createStdData(); - return hiveTypeInference.getOutputDataType(); + _outputObjectInspector = hiveTypeInference.getOutputDataType(); + return _outputObjectInspector; } @Override @@ -108,14 +111,23 @@ protected boolean containsNullValuedNonNullableConstants() { return false; } - protected StdData wrap(DeferredObject hiveDeferredObject, StdData stdData) { + protected Object wrap(DeferredObject hiveDeferredObject, ObjectInspector inputObjectInspector, Object stdData) { try { Object hiveObject = hiveDeferredObject.get(); - if (hiveObject != null) { - ((PlatformData) stdData).setUnderlyingData(hiveObject); - return stdData; + if (inputObjectInspector instanceof BinaryObjectInspector) { + return hiveObject == null ? 
null : ByteBuffer.wrap( + ((BinaryObjectInspector) inputObjectInspector).getPrimitiveJavaObject(hiveObject) + ); + } + if (inputObjectInspector instanceof PrimitiveObjectInspector) { + return ((PrimitiveObjectInspector) inputObjectInspector).getPrimitiveJavaObject(hiveObject); } else { - return null; + if (hiveObject != null) { + ((PlatformData) stdData).setUnderlyingData(hiveObject); + return stdData; + } else { + return null; + } } } catch (HiveException e) { throw new RuntimeException("Cannot extract Hive Object from Deferred Object"); @@ -127,21 +139,35 @@ protected StdData wrap(DeferredObject hiveDeferredObject, StdData stdData) { protected abstract Class getTopLevelUdfClass(); protected void createStdData() { - _args = new StdData[_inputObjectInspectors.length]; + _args = new Object[_inputObjectInspectors.length]; for (int i = 0; i < _inputObjectInspectors.length; i++) { _args[i] = HiveWrapper.createStdData(null, _inputObjectInspectors[i], _stdFactory); } } - private StdData[] wrapArguments(DeferredObject[] deferredObjects) { - return IntStream.range(0, _args.length).mapToObj(i -> wrap(deferredObjects[i], _args[i])).toArray(StdData[]::new); + private Object getPlatformData(Object transportData) { + if (transportData == null) { + return null; + } else if (transportData instanceof Integer || transportData instanceof Long || transportData instanceof Boolean + || transportData instanceof String || transportData instanceof Float || transportData instanceof Double + || transportData instanceof ByteBuffer) { + return HiveWrapper.getPlatformDataForObjectInspector(transportData, _outputObjectInspector); + } else { + return ((PlatformData) transportData).getUnderlyingData(); + } + } + + private Object[] wrapArguments(DeferredObject[] deferredObjects) { + return IntStream.range(0, _args.length).mapToObj( + i -> wrap(deferredObjects[i], _inputObjectInspectors[i], _args[i]) + ).toArray(Object[]::new); } - private StdData[] wrapConstants() { + private Object[] 
wrapConstants() { return Arrays.stream(_inputObjectInspectors) .map(oi -> (oi instanceof ConstantObjectInspector) ? HiveWrapper.createStdData( ((ConstantObjectInspector) oi).getWritableConstantValue(), oi, _stdFactory) : null) - .toArray(StdData[]::new); + .toArray(Object[]::new); } @Override @@ -152,8 +178,8 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { if (!_requiredFilesProcessed) { processRequiredFiles(); } - StdData[] args = wrapArguments(arguments); - StdData result; + Object[] args = wrapArguments(arguments); + Object result; switch (args.length) { case 0: result = ((StdUDF0) _stdUdf).eval(); @@ -185,7 +211,7 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { default: throw new UnsupportedOperationException("eval not yet supported for StdUDF" + args.length); } - return result == null ? null : ((PlatformData) result).getUnderlyingData(); + return getPlatformData(result); } @Override @@ -193,7 +219,7 @@ public String[] getRequiredFiles() { if (containsNullValuedNonNullableConstants()) { return new String[]{}; } - StdData[] args = wrapConstants(); + Object[] args = wrapConstants(); String[] requiredFiles; switch (args.length) { case 0: diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveArray.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveArrayData.java similarity index 72% rename from transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveArray.java rename to transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveArrayData.java index 57cb0e8c..d0bf8ec4 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveArray.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveArrayData.java @@ -6,8 +6,7 @@ package com.linkedin.transport.hive.data; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; 
-import com.linkedin.transport.api.data.StdData; +import com.linkedin.transport.api.data.ArrayData; import com.linkedin.transport.hive.HiveWrapper; import java.util.Iterator; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; @@ -15,12 +14,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.SettableListObjectInspector; -public class HiveArray extends HiveData implements StdArray { +public class HiveArrayData extends HiveData implements ArrayData { final ListObjectInspector _listObjectInspector; final ObjectInspector _elementObjectInspector; - public HiveArray(Object object, ObjectInspector objectInspector, StdFactory stdFactory) { + public HiveArrayData(Object object, ObjectInspector objectInspector, StdFactory stdFactory) { super(stdFactory); _object = object; _listObjectInspector = (ListObjectInspector) objectInspector; @@ -33,19 +32,21 @@ public int size() { } @Override - public StdData get(int idx) { - return HiveWrapper.createStdData(_listObjectInspector.getListElement(_object, idx), _elementObjectInspector, + public E get(int idx) { + return (E) HiveWrapper.createStdData( + _listObjectInspector.getListElement(_object, idx), + _elementObjectInspector, _stdFactory); } @Override - public void add(StdData e) { + public void add(E e) { if (_listObjectInspector instanceof SettableListObjectInspector) { SettableListObjectInspector settableListObjectInspector = (SettableListObjectInspector) _listObjectInspector; int originalSize = size(); settableListObjectInspector.resize(_object, originalSize + 1); settableListObjectInspector.set(_object, originalSize, - ((HiveData) e).getUnderlyingDataForObjectInspector(_elementObjectInspector)); + HiveWrapper.getPlatformDataForObjectInspector(e, _elementObjectInspector)); _isObjectModified = true; } else { throw new RuntimeException("Attempt to modify an immutable Hive object of type: " @@ -59,8 +60,8 @@ public ObjectInspector getUnderlyingObjectInspector() { } @Override - public Iterator iterator() 
{ - return new Iterator() { + public Iterator iterator() { + return new Iterator() { int size = size(); int currentIndex = 0; @@ -70,8 +71,8 @@ public boolean hasNext() { } @Override - public StdData next() { - StdData element = HiveWrapper.createStdData(_listObjectInspector.getListElement(_object, currentIndex), + public E next() { + E element = (E) HiveWrapper.createStdData(_listObjectInspector.getListElement(_object, currentIndex), _elementObjectInspector, _stdFactory); currentIndex++; return element; diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveBinary.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveBinary.java deleted file mode 100644 index c5c14e40..00000000 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveBinary.java +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.hive.data; - -import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdBinary; -import java.nio.ByteBuffer; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector; - - -public class HiveBinary extends HiveData implements StdBinary { - - private final BinaryObjectInspector _binaryObjectInspector; - - public HiveBinary(Object object, BinaryObjectInspector binaryObjectInspector, StdFactory stdFactory) { - super(stdFactory); - _object = object; - _binaryObjectInspector = binaryObjectInspector; - } - - @Override - public ByteBuffer get() { - return ByteBuffer.wrap(_binaryObjectInspector.getPrimitiveJavaObject(_object)); - } - - @Override - public ObjectInspector getUnderlyingObjectInspector() { - return _binaryObjectInspector; - } -} diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveBoolean.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveBoolean.java deleted file mode 100644 index b4537170..00000000 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveBoolean.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.hive.data; - -import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdBoolean; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; - - -public class HiveBoolean extends HiveData implements StdBoolean { - - final BooleanObjectInspector _booleanObjectInspector; - - public HiveBoolean(Object object, BooleanObjectInspector booleanObjectInspector, StdFactory stdFactory) { - super(stdFactory); - _object = object; - _booleanObjectInspector = booleanObjectInspector; - } - - @Override - public boolean get() { - return _booleanObjectInspector.get(_object); - } - - @Override - public ObjectInspector getUnderlyingObjectInspector() { - return _booleanObjectInspector; - } -} diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveData.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveData.java index 51beb456..94337266 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveData.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveData.java @@ -59,10 +59,6 @@ public ObjectInspector getStandardObjectInspector() { getUnderlyingObjectInspector(), ObjectInspectorUtils.ObjectInspectorCopyOption.WRITABLE); } - public Object getStandardObject() { - return getUnderlyingDataForObjectInspector(getStandardObjectInspector()); - } - private Object getObjectFromCache(ObjectInspector oi) { if (_isObjectModified) { _cachedObjectsForObjectInspectors.clear(); diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveDouble.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveDouble.java deleted file mode 100644 index e5447f00..00000000 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveDouble.java +++ 
/dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.hive.data; - -import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdDouble; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector; - - -public class HiveDouble extends HiveData implements StdDouble { - - private final DoubleObjectInspector _doubleObjectInspector; - - public HiveDouble(Object object, DoubleObjectInspector floatObjectInspector, StdFactory stdFactory) { - super(stdFactory); - _object = object; - _doubleObjectInspector = floatObjectInspector; - } - - @Override - public double get() { - return _doubleObjectInspector.get(_object); - } - - @Override - public ObjectInspector getUnderlyingObjectInspector() { - return _doubleObjectInspector; - } -} diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveFloat.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveFloat.java deleted file mode 100644 index a630d73b..00000000 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveFloat.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.hive.data; - -import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdFloat; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector; - - -public class HiveFloat extends HiveData implements StdFloat { - - private final FloatObjectInspector _floatObjectInspector; - - public HiveFloat(Object object, FloatObjectInspector floatObjectInspector, StdFactory stdFactory) { - super(stdFactory); - _object = object; - _floatObjectInspector = floatObjectInspector; - } - - @Override - public float get() { - return _floatObjectInspector.get(_object); - } - - @Override - public ObjectInspector getUnderlyingObjectInspector() { - return _floatObjectInspector; - } -} diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveInteger.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveInteger.java deleted file mode 100644 index a1d2e38f..00000000 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveInteger.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.hive.data; - -import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdInteger; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector; - - -public class HiveInteger extends HiveData implements StdInteger { - - final IntObjectInspector _intObjectInspector; - - public HiveInteger(Object object, IntObjectInspector intObjectInspector, StdFactory stdFactory) { - super(stdFactory); - _object = object; - _intObjectInspector = intObjectInspector; - } - - @Override - public int get() { - return _intObjectInspector.get(_object); - } - - @Override - public ObjectInspector getUnderlyingObjectInspector() { - return _intObjectInspector; - } -} diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveLong.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveLong.java deleted file mode 100644 index 0b662b59..00000000 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveLong.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.hive.data; - -import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdLong; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector; - - -public class HiveLong extends HiveData implements StdLong { - - final LongObjectInspector _longObjectInspector; - - public HiveLong(Object object, LongObjectInspector longObjectInspector, StdFactory stdFactory) { - super(stdFactory); - _object = object; - _longObjectInspector = longObjectInspector; - } - - @Override - public long get() { - return _longObjectInspector.get(_object); - } - - @Override - public ObjectInspector getUnderlyingObjectInspector() { - return _longObjectInspector; - } -} diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveMap.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveMapData.java similarity index 66% rename from transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveMap.java rename to transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveMapData.java index 70f5132b..54da6042 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveMap.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveMapData.java @@ -6,8 +6,7 @@ package com.linkedin.transport.hive.data; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdMap; +import com.linkedin.transport.api.data.MapData; import com.linkedin.transport.hive.HiveWrapper; import java.util.AbstractCollection; import java.util.AbstractSet; @@ -20,13 +19,13 @@ import org.apache.hadoop.hive.serde2.objectinspector.SettableMapObjectInspector; -public class HiveMap extends HiveData implements StdMap { +public class HiveMapData extends HiveData 
implements MapData { final MapObjectInspector _mapObjectInspector; final ObjectInspector _keyObjectInspector; final ObjectInspector _valueObjectInspector; - public HiveMap(Object object, ObjectInspector objectInspector, StdFactory stdFactory) { + public HiveMapData(Object object, ObjectInspector objectInspector, StdFactory stdFactory) { super(stdFactory); _object = object; _mapObjectInspector = (MapObjectInspector) objectInspector; @@ -40,30 +39,30 @@ public int size() { } @Override - public StdData get(StdData key) { + public V get(K key) { MapObjectInspector mapOI = _mapObjectInspector; Object mapObj = _object; Object keyObj; try { - keyObj = ((HiveData) key).getUnderlyingDataForObjectInspector(_keyObjectInspector); + keyObj = HiveWrapper.getPlatformDataForObjectInspector(key, _keyObjectInspector); } catch (RuntimeException e) { // Cannot convert key argument to Map's KeyOI. So convert both the map and the key arg to // objects having standard OIs mapOI = (MapObjectInspector) getStandardObjectInspector(); - mapObj = getStandardObject(); - keyObj = ((HiveData) key).getStandardObject(); + mapObj = HiveWrapper.getStandardObject(this); + keyObj = HiveWrapper.getStandardObject(key); } - return HiveWrapper.createStdData( + return (V) HiveWrapper.createStdData( mapOI.getMapValueElement(mapObj, keyObj), mapOI.getMapValueObjectInspector(), _stdFactory); } @Override - public void put(StdData key, StdData value) { + public void put(K key, V value) { if (_mapObjectInspector instanceof SettableMapObjectInspector) { - Object keyObj = ((HiveData) key).getUnderlyingDataForObjectInspector(_keyObjectInspector); - Object valueObj = ((HiveData) value).getUnderlyingDataForObjectInspector(_valueObjectInspector); + Object keyObj = HiveWrapper.getPlatformDataForObjectInspector(key, _keyObjectInspector); + Object valueObj = HiveWrapper.getPlatformDataForObjectInspector(value, _valueObjectInspector); ((SettableMapObjectInspector) _mapObjectInspector).put( _object, @@ -79,11 +78,11 @@ 
public void put(StdData key, StdData value) { //TODO: Cache the result of .getMap(_object) below for subsequent calls. @Override - public Set keySet() { - return new AbstractSet() { + public Set keySet() { + return new AbstractSet() { @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { Iterator mapKeyIterator = _mapObjectInspector.getMap(_object).keySet().iterator(); @Override @@ -92,26 +91,26 @@ public boolean hasNext() { } @Override - public StdData next() { - return HiveWrapper.createStdData(mapKeyIterator.next(), _keyObjectInspector, _stdFactory); + public K next() { + return (K) HiveWrapper.createStdData(mapKeyIterator.next(), _keyObjectInspector, _stdFactory); } }; } @Override public int size() { - return HiveMap.this.size(); + return HiveMapData.this.size(); } }; } //TODO: Cache the result of .getMap(_object) below for subsequent calls. @Override - public Collection values() { - return new AbstractCollection() { + public Collection values() { + return new AbstractCollection() { @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { Iterator mapValueIterator = _mapObjectInspector.getMap(_object).values().iterator(); @Override @@ -120,30 +119,30 @@ public boolean hasNext() { } @Override - public StdData next() { - return HiveWrapper.createStdData(mapValueIterator.next(), _valueObjectInspector, _stdFactory); + public V next() { + return (V) HiveWrapper.createStdData(mapValueIterator.next(), _valueObjectInspector, _stdFactory); } }; } @Override public int size() { - return HiveMap.this.size(); + return HiveMapData.this.size(); } }; } @Override - public boolean containsKey(StdData key) { + public boolean containsKey(K key) { Object mapObj = _object; Object keyObj; try { - keyObj = ((HiveData) key).getUnderlyingDataForObjectInspector(_keyObjectInspector); + keyObj = HiveWrapper.getPlatformDataForObjectInspector(key, 
_keyObjectInspector); } catch (RuntimeException e) { // Cannot convert key argument to Map's KeyOI. So convertboth the map and the key arg to // objects having standard OIs - mapObj = getStandardObject(); - keyObj = ((HiveData) key).getStandardObject(); + mapObj = HiveWrapper.getStandardObject(this); + keyObj = HiveWrapper.getStandardObject(key); } return ((Map) mapObj).containsKey(keyObj); diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveStruct.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveRowData.java similarity index 79% rename from transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveStruct.java rename to transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveRowData.java index 80872eff..5704374e 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveStruct.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveRowData.java @@ -6,8 +6,7 @@ package com.linkedin.transport.hive.data; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.hive.HiveWrapper; import java.util.List; import java.util.stream.Collectors; @@ -18,18 +17,18 @@ import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -public class HiveStruct extends HiveData implements StdStruct { +public class HiveRowData extends HiveData implements RowData { StructObjectInspector _structObjectInspector; - public HiveStruct(Object object, ObjectInspector objectInspector, StdFactory stdFactory) { + public HiveRowData(Object object, ObjectInspector objectInspector, StdFactory stdFactory) { super(stdFactory); _object = object; _structObjectInspector = (StructObjectInspector) objectInspector; } @Override - public StdData getField(int index) { + 
public Object getField(int index) { StructField structField = _structObjectInspector.getAllStructFieldRefs().get(index); return HiveWrapper.createStdData( _structObjectInspector.getStructFieldData(_object, structField), @@ -38,7 +37,7 @@ public StdData getField(int index) { } @Override - public StdData getField(String name) { + public Object getField(String name) { StructField structField = _structObjectInspector.getStructFieldRef(name); return HiveWrapper.createStdData( _structObjectInspector.getStructFieldData(_object, structField), @@ -47,11 +46,11 @@ public StdData getField(String name) { } @Override - public void setField(int index, StdData value) { + public void setField(int index, Object value) { if (_structObjectInspector instanceof SettableStructObjectInspector) { StructField field = _structObjectInspector.getAllStructFieldRefs().get(index); ((SettableStructObjectInspector) _structObjectInspector).setStructFieldData(_object, - field, ((HiveData) value).getUnderlyingDataForObjectInspector(field.getFieldObjectInspector()) + field, HiveWrapper.getPlatformDataForObjectInspector(value, field.getFieldObjectInspector()) ); _isObjectModified = true; } else { @@ -61,11 +60,11 @@ public void setField(int index, StdData value) { } @Override - public void setField(String name, StdData value) { + public void setField(String name, Object value) { if (_structObjectInspector instanceof SettableStructObjectInspector) { StructField field = _structObjectInspector.getStructFieldRef(name); ((SettableStructObjectInspector) _structObjectInspector).setStructFieldData(_object, - field, ((HiveData) value).getUnderlyingDataForObjectInspector(field.getFieldObjectInspector())); + field, HiveWrapper.getPlatformDataForObjectInspector(value, field.getFieldObjectInspector())); _isObjectModified = true; } else { throw new RuntimeException("Attempt to modify an immutable Hive object of type: " @@ -74,7 +73,7 @@ public void setField(String name, StdData value) { } @Override - public List 
fields() { + public List fields() { return IntStream.range(0, _structObjectInspector.getAllStructFieldRefs().size()).mapToObj(i -> getField(i)) .collect(Collectors.toList()); } diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveString.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveString.java deleted file mode 100644 index 83310309..00000000 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveString.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.hive.data; - -import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdString; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; - - -public class HiveString extends HiveData implements StdString { - - final StringObjectInspector _stringObjectInspector; - - public HiveString(Object object, StringObjectInspector stringObjectInspector, StdFactory stdFactory) { - super(stdFactory); - _object = object; - _stringObjectInspector = stringObjectInspector; - } - - @Override - public String get() { - return _stringObjectInspector.getPrimitiveJavaObject(_object); - } - - @Override - public ObjectInspector getUnderlyingObjectInspector() { - return _stringObjectInspector; - } -} diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveStructType.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveRowType.java similarity index 83% rename from transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveStructType.java rename to transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveRowType.java index 
f4393776..c9ceef43 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveStructType.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveRowType.java @@ -5,7 +5,7 @@ */ package com.linkedin.transport.hive.types; -import com.linkedin.transport.api.types.StdStructType; +import com.linkedin.transport.api.types.RowType; import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.hive.HiveWrapper; import java.util.List; @@ -13,11 +13,11 @@ import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -public class HiveStructType implements StdStructType { +public class HiveRowType implements RowType { final StructObjectInspector _structObjectInspector; - public HiveStructType(StructObjectInspector structObjectInspector) { + public HiveRowType(StructObjectInspector structObjectInspector) { _structObjectInspector = structObjectInspector; } diff --git a/transportable-udfs-plugin/build.gradle b/transportable-udfs-plugin/build.gradle index 6f4ade36..f241efbe 100644 --- a/transportable-udfs-plugin/build.gradle +++ b/transportable-udfs-plugin/build.gradle @@ -1,15 +1,7 @@ plugins { - id 'java' id 'java-gradle-plugin' -} - -gradlePlugin { - plugins { - simplePlugin { - id = 'com.linkedin.transport.plugin' - implementationClass = 'com.linkedin.transport.plugin.TransportPlugin' - } - } + id 'maven-publish' + id 'signing' } dependencies { @@ -27,8 +19,9 @@ def writeVersionInfo = { file -> ant.propertyfile(file: file) { entry(key: "transport-version", value: version) entry(key: "hive-version", value: '1.2.2') - entry(key: "presto-version", value: '333') - entry(key: "spark-version", value: '2.3.0') + entry(key: "trino-version", value: '352') + entry(key: "spark_2.11-version", value: '2.3.0') + entry(key: "spark_2.12-version", value: '3.1.1') entry(key: "scala-version", value: '2.11.8') } } @@ -36,3 +29,96 @@ def writeVersionInfo = { file -> processResources.doLast { 
writeVersionInfo(new File(sourceSets.main.output.resourcesDir, "version-info.properties")) } + +def licenseSpec = copySpec { + from project.rootDir + include "LICENSE" +} + +task sourcesJar(type: Jar, dependsOn: classes) { + classifier 'sources' + from sourceSets.main.allSource + with licenseSpec +} + +task javadocJar(type: Jar, dependsOn: javadoc) { + classifier 'javadoc' + from tasks.javadoc + with licenseSpec +} + +jar { + with licenseSpec +} + +artifacts { + archives sourcesJar + archives javadocJar +} + +signing { + if (System.getenv("PGP_KEY")) { + useInMemoryPgpKeys(System.getenv("PGP_KEY"), System.getenv("PGP_PWD")) + sign publishing.publications + } +} + +gradlePlugin { + plugins { + simplePlugin { + id = 'com.linkedin.transport.plugin' + implementationClass = 'com.linkedin.transport.plugin.TransportPlugin' + } + } +} + +publishing { + // afterEvaluate is necessary because java-gradle-plugin + // creates its publications in an afterEvaluate callback + afterEvaluate { + publications { + withType(MavenPublication) { + artifact sourcesJar + artifact javadocJar + + pom { + name = artifactId + description = "A library for analyzing, processing, and rewriting views defined in the Hive Metastore, and sharing them across multiple execution engines" + + url = "https://github.com/linkedin/transport" + licenses { + license { + name = 'BSD 2-CLAUSE LICENSE' + url = 'https://github.com/linkedin/transport/blob/master/LICENSE' + distribution = 'repo' + } + } + developers { + developer { + id = 'wmoustafa' + name = 'Walaa Eldin Moustafa' + } + developer { + id = 'shardulm94' + name = 'Shardul Mahadik' + } + } + scm { + url = 'https://github.com/linkedin/transport.git' + } + issueManagement { + url = 'https://github.com/linkedin/transport/issues' + system = 'GitHub issues' + } + ciManagement { + url = 'https://travis-ci.com/linkedin/transport' + system = 'Travis CI' + } + } + } + } + } + + //useful for testing - running "publish" will create artifacts/pom in a local dir + 
repositories { maven { url = "$rootProject.buildDir/repo" } } +} diff --git a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/Defaults.java b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/Defaults.java index 64a967f4..b4269b1b 100644 --- a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/Defaults.java +++ b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/Defaults.java @@ -7,8 +7,8 @@ import com.google.common.collect.ImmutableList; import com.linkedin.transport.codegen.HiveWrapperGenerator; -import com.linkedin.transport.codegen.PrestoWrapperGenerator; import com.linkedin.transport.codegen.SparkWrapperGenerator; +import com.linkedin.transport.codegen.TrinoWrapperGenerator; import com.linkedin.transport.plugin.packaging.DistributionPackaging; import com.linkedin.transport.plugin.packaging.ShadedJarPackaging; import com.linkedin.transport.plugin.packaging.ThinJarPackaging; @@ -16,6 +16,7 @@ import java.io.InputStream; import java.util.List; import java.util.Properties; +import org.gradle.jvm.toolchain.JavaLanguageVersion; import static com.linkedin.transport.plugin.ConfigurationType.*; @@ -42,79 +43,99 @@ private static Properties loadDefaultVersions() { } } + private static final String getVersion(final String platform) { + return DEFAULT_VERSIONS.getProperty(platform + "-version"); + } + private static final String HIVE = "hive"; + private static final String SPARK_2_11 = "spark_2.11"; + private static final String SPARK_2_12 = "spark_2.12"; + private static final String TRINO = "trino"; + + private static final String TRANSPORT_VERSION = getVersion("transport"); + private static final String SCALA_VERSION = getVersion("scala"); + private static final String HIVE_VERSION = getVersion(HIVE); + private static final String SPARK_2_11_VERSION = getVersion(SPARK_2_11); + private static final String SPARK_2_12_VERSION = getVersion(SPARK_2_12); + private static final String 
TRINO_VERSION = getVersion(TRINO); + static final List MAIN_SOURCE_SET_DEPENDENCY_CONFIGURATIONS = ImmutableList.of( - getDependencyConfiguration(IMPLEMENTATION, "com.linkedin.transport:transportable-udfs-api", "transport"), - getDependencyConfiguration(ANNOTATION_PROCESSOR, "com.linkedin.transport:transportable-udfs-annotation-processor", - "transport"), + DependencyConfiguration.builder(IMPLEMENTATION, "com.linkedin.transport:transportable-udfs-api", TRANSPORT_VERSION).build(), + DependencyConfiguration.builder(ANNOTATION_PROCESSOR, "com.linkedin.transport:transportable-udfs-annotation-processor", TRANSPORT_VERSION).build(), // the idea plugin needs a scala-library on the classpath when the scala plugin is applied even when there are no // scala sources - getDependencyConfiguration(COMPILE_ONLY, "org.scala-lang:scala-library", "scala") + DependencyConfiguration.builder(COMPILE_ONLY, "org.scala-lang:scala-library", SCALA_VERSION).build() ); static final List TEST_SOURCE_SET_DEPENDENCY_CONFIGURATIONS = ImmutableList.of( - getDependencyConfiguration(IMPLEMENTATION, "com.linkedin.transport:transportable-udfs-test-api", "transport"), - getDependencyConfiguration(RUNTIME_ONLY, "com.linkedin.transport:transportable-udfs-test-generic", "transport") + DependencyConfiguration.builder(IMPLEMENTATION, "com.linkedin.transport:transportable-udfs-test-api", TRANSPORT_VERSION).build(), + DependencyConfiguration.builder(RUNTIME_ONLY, "com.linkedin.transport:transportable-udfs-test-generic", TRANSPORT_VERSION).build() ); static final List DEFAULT_PLATFORMS = ImmutableList.of( - new Platform( - "presto", + new Platform(TRINO, Language.JAVA, - PrestoWrapperGenerator.class, + TrinoWrapperGenerator.class, + JavaLanguageVersion.of(11), ImmutableList.of( - getDependencyConfiguration(IMPLEMENTATION, "com.linkedin.transport:transportable-udfs-presto", - "transport"), - getDependencyConfiguration(COMPILE_ONLY, "io.prestosql:presto-main", "presto") + 
DependencyConfiguration.builder(IMPLEMENTATION, "com.linkedin.transport:transportable-udfs-trino", TRANSPORT_VERSION).build(), + DependencyConfiguration.builder(COMPILE_ONLY, "io.trino:trino-main", TRINO_VERSION).build() ), ImmutableList.of( - getDependencyConfiguration(RUNTIME_ONLY, "com.linkedin.transport:transportable-udfs-test-presto", - "transport"), - // presto-main:tests is a transitive dependency of transportable-udfs-test-presto, but some POM -> IVY + DependencyConfiguration.builder(RUNTIME_ONLY, "com.linkedin.transport:transportable-udfs-test-trino", TRANSPORT_VERSION).build(), + // trino-main:tests is a transitive dependency of transportable-udfs-test-trino, but some POM -> IVY // converters drop dependencies with classifiers, so we apply this dependency explicitly - getDependencyConfiguration(RUNTIME_ONLY, "io.prestosql:presto-main", "presto", "tests") + DependencyConfiguration.builder(RUNTIME_ONLY, "io.trino:trino-main", TRINO_VERSION).classifier("tests").build() ), ImmutableList.of(new ThinJarPackaging(), new DistributionPackaging())), - new Platform( - "hive", + new Platform(HIVE, Language.JAVA, HiveWrapperGenerator.class, + JavaLanguageVersion.of(8), ImmutableList.of( - getDependencyConfiguration(IMPLEMENTATION, "com.linkedin.transport:transportable-udfs-hive", "transport"), - getDependencyConfiguration(COMPILE_ONLY, "org.apache.hive:hive-exec", "hive") + DependencyConfiguration.builder(IMPLEMENTATION, "com.linkedin.transport:transportable-udfs-hive", TRANSPORT_VERSION).build(), + DependencyConfiguration.builder(COMPILE_ONLY, "org.apache.hive:hive-exec", HIVE_VERSION).exclude("org.apache.calcite").build() ), ImmutableList.of( - getDependencyConfiguration(RUNTIME_ONLY, "com.linkedin.transport:transportable-udfs-test-hive", - "transport") + DependencyConfiguration.builder(RUNTIME_ONLY, "com.linkedin.transport:transportable-udfs-test-hive", TRANSPORT_VERSION).build() ), ImmutableList.of(new ShadedJarPackaging(ImmutableList.of("org.apache.hadoop", 
"org.apache.hive"), null))), - new Platform( - "spark", + new Platform(SPARK_2_11, Language.SCALA, SparkWrapperGenerator.class, + JavaLanguageVersion.of(8), ImmutableList.of( - getDependencyConfiguration(IMPLEMENTATION, "com.linkedin.transport:transportable-udfs-spark", - "transport"), - getDependencyConfiguration(COMPILE_ONLY, "org.apache.spark:spark-sql_2.11", "spark") + DependencyConfiguration.builder(IMPLEMENTATION, "com.linkedin.transport:transportable-udfs-spark_2.11", TRANSPORT_VERSION).build(), + DependencyConfiguration.builder(COMPILE_ONLY, "org.apache.spark:spark-sql_2.11", SPARK_2_11_VERSION).build() ), ImmutableList.of( - getDependencyConfiguration(RUNTIME_ONLY, "com.linkedin.transport:transportable-udfs-test-spark", - "transport") + DependencyConfiguration.builder(RUNTIME_ONLY, "com.linkedin.transport:transportable-udfs-test-spark_2.11", TRANSPORT_VERSION).build() ), ImmutableList.of(new ShadedJarPackaging( ImmutableList.of("org.apache.hadoop", "org.apache.spark"), - ImmutableList.of("com.linkedin.transport.spark.**"))) + ImmutableList.of( + "com.linkedin.transport.spark.stdUDFRegistration", + "com.linkedin.transport.spark.SparkStdUDF" + ) + )) + ), + new Platform(SPARK_2_12, + Language.SCALA, + SparkWrapperGenerator.class, + JavaLanguageVersion.of(8), + ImmutableList.of( + DependencyConfiguration.builder(IMPLEMENTATION, "com.linkedin.transport:transportable-udfs-spark_2.12", TRANSPORT_VERSION).build(), + DependencyConfiguration.builder(COMPILE_ONLY, "org.apache.spark:spark-sql_2.12", SPARK_2_12_VERSION).build() + ), + ImmutableList.of( + DependencyConfiguration.builder(RUNTIME_ONLY, "com.linkedin.transport:transportable-udfs-test-spark_2.12", TRANSPORT_VERSION).build() + ), + ImmutableList.of(new ShadedJarPackaging( + ImmutableList.of("org.apache.hadoop", "org.apache.spark"), + ImmutableList.of( + "com.linkedin.transport.spark.stdUDFRegistration", + "com.linkedin.transport.spark.SparkStdUDF" + ) + )) ) ); - - private static DependencyConfiguration 
getDependencyConfiguration(ConfigurationType configurationType, - String module, String platform) { - return getDependencyConfiguration(configurationType, module, platform, null); - } - - private static DependencyConfiguration getDependencyConfiguration(ConfigurationType configurationType, - String module, String platform, String classifier) { - return new DependencyConfiguration(configurationType, module - + ":" + DEFAULT_VERSIONS.getProperty(platform + "-version") - + (classifier != null ? (":" + classifier) : "")); - } } diff --git a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/DependencyConfiguration.java b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/DependencyConfiguration.java index 078ca0cf..be8f0a0e 100644 --- a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/DependencyConfiguration.java +++ b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/DependencyConfiguration.java @@ -5,17 +5,34 @@ */ package com.linkedin.transport.plugin; +import com.google.common.collect.ImmutableMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + + /** * Represents a dependency to be applied to a certain sourceset configuration (e.g. implementation, compileOnly, etc.) * In the future can expand to incorporate exclude rules, dependency substitutions, etc. 
*/ public class DependencyConfiguration { - private ConfigurationType _configurationType; - private String _dependencyString; + private static final String GROUP_KEY = "group"; + private static final String MODULE_KEY = "module"; - public DependencyConfiguration(ConfigurationType configurationType, String dependencyString) { - _configurationType = configurationType; - _dependencyString = dependencyString; + private final ConfigurationType _configurationType; + private final String _module; + private final String _version; + private final String _classifier; + private final Set> _excludedProperties; + + private DependencyConfiguration(Builder builder) { + this._configurationType = builder._configurationType; + this._module = builder._module; + this._version = builder._version; + this._classifier = builder._classifier; + this._excludedProperties = builder._excludedProperties; } public ConfigurationType getConfigurationType() { @@ -23,6 +40,57 @@ public ConfigurationType getConfigurationType() { } public String getDependencyString() { - return _dependencyString; + return _module + ":" + _version + Optional.ofNullable(_classifier).map(v -> ":" + v).orElse(""); + } + + public Set> getExcludedProperties() { + return _excludedProperties; + } + + public static Builder builder(final ConfigurationType configurationType, final String module, final String version) { + return new Builder(configurationType, module, version); + } + + public static class Builder { + private final ConfigurationType _configurationType; + private final String _module; + private String _version; + private String _classifier; + private Set> _excludedProperties; + + public Builder(final ConfigurationType configurationType, final String module, final String version) { + Objects.requireNonNull(configurationType); + Objects.requireNonNull(module); + Objects.requireNonNull(version); + this._configurationType = configurationType; + this._module = module; + this._version = version; + } + + public Builder 
classifier(final String classifier) { + Objects.requireNonNull(classifier); + _classifier = classifier; + return this; + } + + public Builder exclude(final String group) { + return exclude(group, null); + } + + public Builder exclude(final String group, final String module) { + Objects.requireNonNull(group); + if (_excludedProperties == null) { + _excludedProperties = new HashSet<>(); + } + _excludedProperties.add((module == null) + ? ImmutableMap.of(GROUP_KEY, group) + : ImmutableMap.of(GROUP_KEY, group, MODULE_KEY, module) + ); + return this; + } + + public DependencyConfiguration build() { + return new DependencyConfiguration(this); + } } } \ No newline at end of file diff --git a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/Platform.java b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/Platform.java index b3d87679..7acf6847 100644 --- a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/Platform.java +++ b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/Platform.java @@ -8,6 +8,7 @@ import com.linkedin.transport.codegen.WrapperGenerator; import com.linkedin.transport.plugin.packaging.Packaging; import java.util.List; +import org.gradle.jvm.toolchain.JavaLanguageVersion; /** @@ -21,12 +22,14 @@ public class Platform { private final List _defaultWrapperDependencyConfigurations; private final List _defaultTestDependencyConfigurations; private final List _packaging; + private final JavaLanguageVersion _javaLanguageVersion; public Platform(String name, Language language, Class wrapperGeneratorClass, - List defaultWrapperDependencyConfigurations, + JavaLanguageVersion javaLanguageVersion, List defaultWrapperDependencyConfigurations, List defaultTestDependencyConfigurations, List packaging) { _name = name; _language = language; + _javaLanguageVersion = javaLanguageVersion; _wrapperGeneratorClass = wrapperGeneratorClass; _defaultWrapperDependencyConfigurations = 
defaultWrapperDependencyConfigurations; _defaultTestDependencyConfigurations = defaultTestDependencyConfigurations; @@ -56,4 +59,8 @@ public List getDefaultTestDependencyConfigurations() { public List getPackaging() { return _packaging; } + + public JavaLanguageVersion getJavaLanguageVersion() { + return _javaLanguageVersion; + } } diff --git a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/SourceSetUtils.java b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/SourceSetUtils.java index 8a87fefe..aee37c28 100644 --- a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/SourceSetUtils.java +++ b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/SourceSetUtils.java @@ -6,9 +6,15 @@ package com.linkedin.transport.plugin; import java.util.Collection; +import java.util.Map; +import java.util.Set; +import javax.annotation.Nullable; import org.codehaus.groovy.runtime.InvokerHelper; import org.gradle.api.Project; import org.gradle.api.artifacts.Configuration; +import org.gradle.api.artifacts.Dependency; +import org.gradle.api.artifacts.ModuleDependency; +import org.gradle.api.artifacts.dsl.DependencyHandler; import org.gradle.api.file.SourceDirectorySet; import org.gradle.api.plugins.Convention; import org.gradle.api.tasks.ScalaSourceSet; @@ -75,7 +81,30 @@ private static String getConfigurationNameForSourceSet(SourceSet sourceSet, Conf * Adds the provided dependency to the given {@link Configuration} */ static void addDependencyToConfiguration(Project project, Configuration configuration, Object dependency) { - configuration.withDependencies(dependencySet -> dependencySet.add(project.getDependencies().create(dependency))); + addDependencyToConfiguration(configuration, createDependency(project, dependency), null); + } + + /** + * Adds the provided dependency {@link Dependency} to the given {@link Configuration}, + * excluding the elements in the excludeProperties + */ + static void 
addDependencyToConfiguration(final Configuration configuration, final Dependency dependency, + final @Nullable Set> excludeProperties) { + configuration.withDependencies(dependencySet -> { + if (excludeProperties != null) { + if (dependency instanceof ModuleDependency) { + excludeProperties.stream().forEach(((ModuleDependency) dependency)::exclude); + } + } + dependencySet.add(dependency); + }); + } + + /** + * Create {@link Dependency} by {@link Project}'s {@link DependencyHandler} + */ + static Dependency createDependency(final Project project, Object dependency) { + return project.getDependencies().create(dependency); } /** @@ -83,9 +112,11 @@ static void addDependencyToConfiguration(Project project, Configuration configur */ static void addDependencyConfigurationToSourceSet(Project project, SourceSet sourceSet, DependencyConfiguration dependencyConfiguration) { - addDependencyToConfiguration(project, - SourceSetUtils.getConfigurationForSourceSet(project, sourceSet, dependencyConfiguration.getConfigurationType()), - dependencyConfiguration.getDependencyString()); + addDependencyToConfiguration( + getConfigurationForSourceSet(project, sourceSet, dependencyConfiguration.getConfigurationType()), + createDependency(project, dependencyConfiguration.getDependencyString()), + dependencyConfiguration.getExcludedProperties() + ); } /** diff --git a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/TransportPlugin.java b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/TransportPlugin.java index 2f8e2984..ec911c5f 100644 --- a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/TransportPlugin.java +++ b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/TransportPlugin.java @@ -24,12 +24,15 @@ import org.gradle.api.plugins.scala.ScalaPlugin; import org.gradle.api.tasks.SourceSet; import org.gradle.api.tasks.TaskProvider; +import org.gradle.api.tasks.compile.JavaCompile; import 
org.gradle.api.tasks.testing.Test; +import org.gradle.jvm.toolchain.JavaToolchainService; import org.gradle.language.base.plugins.LifecycleBasePlugin; import org.gradle.testing.jacoco.plugins.JacocoPlugin; import org.gradle.testing.jacoco.plugins.JacocoTaskExtension; import static com.linkedin.transport.plugin.ConfigurationType.*; +import static com.linkedin.transport.plugin.Language.*; import static com.linkedin.transport.plugin.SourceSetUtils.*; @@ -64,7 +67,7 @@ public void apply(Project project) { Defaults.DEFAULT_PLATFORMS.forEach( platform -> configurePlatform(project, platform, mainSourceSet, testSourceSet, extension.outputDirFile)); }); - // Disable Jacoco for platform test tasks as it is known to cause issues with Presto and Hive tests + // Disable Jacoco for platform test tasks as it is known to cause issues with Trino and Hive tests project.getPlugins().withType(JacocoPlugin.class, (jacocoPlugin) -> { Defaults.DEFAULT_PLATFORMS.forEach(platform -> { project.getTasksByName(testTaskName(platform), true).forEach(task -> { @@ -120,9 +123,9 @@ private SourceSet configureSourceSet(Project project, Platform platform, SourceS return javaConvention.getSourceSets().create(platform.getName(), sourceSet -> { /* - Creates a SourceSet and set the source directories for a given platform. E.g. For the Presto platform, + Creates a SourceSet and set the source directories for a given platform. E.g. For the Trino platform, - presto { + trino { java.srcDirs = ["${buildDir}/generatedWrappers/sources"] resources.srcDirs = ["${buildDir}/generatedWrappers/resources"] } @@ -131,11 +134,11 @@ private SourceSet configureSourceSet(Project project, Platform platform, SourceS sourceSet.getResources().setSrcDirs(ImmutableList.of(wrapperResourceOutputDir)); /* - Sets up the configuration for the platform's wrapper SourceSet. E.g. For the Presto platform, + Sets up the configuration for the platform's wrapper SourceSet. E.g. 
For the Trino platform, configurations { - prestoImplementation.extendsFrom mainImplementation - prestoRuntimeOnly.extendsFrom mainRuntimeOnly + trinoImplementation.extendsFrom mainImplementation + trinoRuntimeOnly.extendsFrom mainRuntimeOnly } */ getConfigurationForSourceSet(project, sourceSet, IMPLEMENTATION).extendsFrom( @@ -144,12 +147,12 @@ private SourceSet configureSourceSet(Project project, Platform platform, SourceS getConfigurationForSourceSet(project, mainSourceSet, RUNTIME_ONLY)); /* - Adds the default dependencies for the platform. E.g For the Presto platform, + Adds the default dependencies for the platform. E.g For the Trino platform, dependencies { - prestoImplementation project.files(project.tasks.jar) - prestoImplementation 'com.linkedin.transport:transportable-udfs-presto:$version' - prestoCompileOnly 'io.prestosql:presto-main:$version' + trinoImplementation project.files(project.tasks.jar) + trinoImplementation 'com.linkedin.transport:transportable-udfs-trino:$version' + trinoCompileOnly 'io.trino:trino-main:$version' } */ addDependencyToConfiguration(project, getConfigurationForSourceSet(project, sourceSet, IMPLEMENTATION), @@ -165,17 +168,17 @@ private TaskProvider configureGenerateWrappersTask(Project SourceSet inputSourceSet, SourceSet outputSourceSet) { /* - Creates a generateWrapper task for a given platform. E.g For the Presto platform, + Creates a generateWrapper task for a given platform. 
E.g For the Trino platform, - task generatePrestoWrappers { - generatorClass = 'com.linkedin.transport.codegen.PrestoWrapperGenerator' + task generateTrinoWrappers { + generatorClass = 'com.linkedin.transport.codegen.TrinoWrapperGenerator' inputClassesDirs = sourceSets.main.output.classesDirs - sourcesOutputDir = sourceSets.presto.java.srcDirs[0] - resourcesOutputDir = sourceSets.presto.resources.srcDirs[0] + sourcesOutputDir = sourceSets.trino.java.srcDirs[0] + resourcesOutputDir = sourceSets.trino.resources.srcDirs[0] dependsOn classes } - prestoClasses.dependsOn(generatePrestoWrappers) + trinoClasses.dependsOn(generateTrinoWrappers) */ String taskName = outputSourceSet.getTaskName("generate", "Wrappers"); File sourcesOutputDir = @@ -192,6 +195,18 @@ private TaskProvider configureGenerateWrappersTask(Project task.dependsOn(project.getTasks().named(inputSourceSet.getClassesTaskName())); }); + // Configure Java compile tasks to run with platform specific jdk + // TODO: set platform specific jdks/toolchain for scala tasks when support is available + if (platform.getLanguage() == JAVA) { + project.getTasks() + .named(outputSourceSet.getCompileTaskName(platform.getLanguage().toString()), JavaCompile.class, task -> { + JavaToolchainService javaToolchains = project.getExtensions().getByType(JavaToolchainService.class); + task.getJavaCompiler().set(javaToolchains.compilerFor(toolChainSpec -> { + toolChainSpec.getLanguageVersion().set(platform.getJavaLanguageVersion()); + })); + }); + } + project.getTasks() .named(outputSourceSet.getCompileTaskName(platform.getLanguage().toString())) .configure(task -> task.dependsOn(generateWrappersTask)); @@ -216,17 +231,17 @@ private TaskProvider configureTestTask(Project project, Platform platform, SourceSet testSourceSet) { /* - Configures the classpath configuration to run platform-specific tests. E.g. For the Presto platform, + Configures the classpath configuration to run platform-specific tests. E.g. 
For the Trino platform, configurations { - prestoTestClasspath { + trinoTestClasspath { extendsFrom testImplementation } } dependencies { - prestoTestClasspath sourceSets.main.output, sourceSets.test.output - prestoTestClasspath 'com.linkedin.transport:transportable-udfs-test-presto' + trinoTestClasspath sourceSets.main.output, sourceSets.test.output + trinoTestClasspath 'com.linkedin.transport:transportable-udfs-test-trino' } */ Configuration testClasspath = project.getConfigurations() @@ -239,13 +254,13 @@ private TaskProvider configureTestTask(Project project, Platform platform, dependencyConfiguration.getDependencyString())); /* - Creates the test task for a given platform. E.g. For the Presto platform, + Creates the test task for a given platform. E.g. For the Trino platform, - task prestoTest(type: Test, dependsOn: test) { + task trinoTest(type: Test, dependsOn: test) { group 'Verification' - description 'Runs the Presto tests.' + description 'Runs the Trino tests.' testClassesDirs = sourceSets.test.output.classesDirs - classpath = configurations.prestoTestClasspath + classpath = configurations.trinoTestClasspath useTestNG() } */ @@ -257,6 +272,12 @@ task prestoTest(type: Test, dependsOn: test) { task.setClasspath(testClasspath); task.useTestNG(); task.mustRunAfter(project.getTasks().named("test")); + + // configure test task to run with platform specific jdk + JavaToolchainService javaToolchains = project.getExtensions().getByType(JavaToolchainService.class); + task.getJavaLauncher().set(javaToolchains.launcherFor(toolChainSpec -> { + toolChainSpec.getLanguageVersion().set(platform.getJavaLanguageVersion()); + })); }); } diff --git a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/DistributionPackaging.java b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/DistributionPackaging.java index 13de71db..26c43fc5 100644 --- 
a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/DistributionPackaging.java +++ b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/DistributionPackaging.java @@ -66,7 +66,7 @@ public List> configurePackagingTasks(Project projec */ private TaskProvider createThinJarTask(Project project, SourceSet sourceSet, String platformName) { /* - task DistThinJar(type: Jar, dependsOn: prestoClasses) { + task DistThinJar(type: Jar, dependsOn: trinoClasses) { classifier '-dist-thin' from sourceSets..output from sourceSets..resources diff --git a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/ThinJarPackaging.java b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/ThinJarPackaging.java index 7367733c..87bd8e12 100644 --- a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/ThinJarPackaging.java +++ b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/ThinJarPackaging.java @@ -25,7 +25,7 @@ public class ThinJarPackaging implements Packaging { public List> configurePackagingTasks(Project project, Platform platform, SourceSet platformSourceSet, SourceSet mainSourceSet) { /* - task ThinJar(type: Jar, dependsOn: prestoClasses) { + task ThinJar(type: Jar, dependsOn: Classes) { classifier '-thin' from sourceSets..output from sourceSets..resources diff --git a/transportable-udfs-presto/build.gradle b/transportable-udfs-presto/build.gradle deleted file mode 100644 index 4141cb79..00000000 --- a/transportable-udfs-presto/build.gradle +++ /dev/null @@ -1,46 +0,0 @@ -apply plugin: 'java' - -buildscript { - repositories { - mavenCentral() - } - dependencies { - classpath group:'io.prestosql', name: 'presto-main', version: project.ext.'presto-version' - } -} - -import com.google.common.base.StandardSystemProperty -import io.prestosql.server.JavaVersion -task verifyPrestoJvmRequirements(type:Exec) { - String 
javaVersion = StandardSystemProperty.JAVA_VERSION.value() - if (javaVersion == null) { - throw new GradleException("Java version not defined") - } - JavaVersion version = JavaVersion.parse(javaVersion) - if (!(version.getMajor() == 8 && version.getUpdate().isPresent() && version.getUpdate().getAsInt() >= 151) - || (version.getMajor() >= 9)) { - throw new GradleException(String.format("Presto requires Java 8u151+ (found %s)", version)) - } -} - -dependencies { - compile project(':transportable-udfs-api') - compile project(':transportable-udfs-type-system') - compile project(':transportable-udfs-utils') - compileOnly(group:'io.prestosql', name: 'presto-main', version: project.ext.'presto-version') { - exclude 'group': 'com.google.collections', 'module': 'google-collections' - } - testCompile(group:'io.prestosql', name: 'presto-main', version: project.ext.'presto-version') { - exclude 'group': 'com.google.collections', 'module': 'google-collections' - } - testCompile(group:'io.prestosql', name: 'presto-main', version: project.ext.'presto-version', classifier: 'tests') { - exclude 'group': 'com.google.collections', 'module': 'google-collections' - } - compileOnly(group:'io.prestosql', name: 'presto-spi', version: project.ext.'presto-version') - compile('org.apache.hadoop:hadoop-hdfs:2.7.4') - compile('org.apache.hadoop:hadoop-common:2.7.4') - testCompile('io.airlift:testing:0.142') - // The io.airlift.slice dependency below has to match its counterpart in presto-root's pom.xml file - // If not specified, an older version is picked up transitively from another dependency - testCompile(group: 'io.airlift', name: 'slice', version: project.ext.'airlift-slice-version') -} diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/PrestoFactory.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/PrestoFactory.java deleted file mode 100644 index ae3605cb..00000000 --- 
a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/PrestoFactory.java +++ /dev/null @@ -1,141 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.presto; - -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableSet; -import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdBoolean; -import com.linkedin.transport.api.data.StdBinary; -import com.linkedin.transport.api.data.StdDouble; -import com.linkedin.transport.api.data.StdFloat; -import com.linkedin.transport.api.data.StdInteger; -import com.linkedin.transport.api.data.StdLong; -import com.linkedin.transport.api.data.StdMap; -import com.linkedin.transport.api.data.StdString; -import com.linkedin.transport.api.data.StdStruct; -import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.presto.data.PrestoArray; -import com.linkedin.transport.presto.data.PrestoBoolean; -import com.linkedin.transport.presto.data.PrestoBinary; -import com.linkedin.transport.presto.data.PrestoDouble; -import com.linkedin.transport.presto.data.PrestoFloat; -import com.linkedin.transport.presto.data.PrestoInteger; -import com.linkedin.transport.presto.data.PrestoLong; -import com.linkedin.transport.presto.data.PrestoMap; -import com.linkedin.transport.presto.data.PrestoString; -import com.linkedin.transport.presto.data.PrestoStruct; -import io.airlift.slice.Slices; -import io.prestosql.metadata.BoundVariables; -import io.prestosql.metadata.Metadata; -import io.prestosql.metadata.OperatorNotFoundException; -import io.prestosql.metadata.ResolvedFunction; -import io.prestosql.operator.scalar.ScalarFunctionImplementation; -import io.prestosql.spi.function.OperatorType; -import io.prestosql.spi.type.ArrayType; -import 
io.prestosql.spi.type.MapType; -import io.prestosql.spi.type.RowType; -import io.prestosql.spi.type.Type; -import java.nio.ByteBuffer; -import java.util.List; -import java.util.stream.Collectors; - -import static io.prestosql.metadata.SignatureBinder.*; -import static io.prestosql.operator.TypeSignatureParser.*; - -public class PrestoFactory implements StdFactory { - - final BoundVariables boundVariables; - final Metadata metadata; - - public PrestoFactory(BoundVariables boundVariables, Metadata metadata) { - this.boundVariables = boundVariables; - this.metadata = metadata; - } - - @Override - public StdInteger createInteger(int value) { - return new PrestoInteger(value); - } - - @Override - public StdLong createLong(long value) { - return new PrestoLong(value); - } - - @Override - public StdBoolean createBoolean(boolean value) { - return new PrestoBoolean(value); - } - - @Override - public StdString createString(String value) { - Preconditions.checkNotNull(value, "Cannot create a null StdString"); - return new PrestoString(Slices.utf8Slice(value)); - } - - @Override - public StdFloat createFloat(float value) { - return new PrestoFloat(value); - } - - @Override - public StdDouble createDouble(double value) { - return new PrestoDouble(value); - } - - @Override - public StdBinary createBinary(ByteBuffer value) { - return new PrestoBinary(Slices.wrappedBuffer(value.array())); - } - - @Override - public StdArray createArray(StdType stdType, int expectedSize) { - return new PrestoArray((ArrayType) stdType.underlyingType(), expectedSize, this); - } - - @Override - public StdArray createArray(StdType stdType) { - return createArray(stdType, 0); - } - - @Override - public StdMap createMap(StdType stdType) { - return new PrestoMap((MapType) stdType.underlyingType(), this); - } - - @Override - public PrestoStruct createStruct(List fieldNames, List fieldTypes) { - return new PrestoStruct(fieldNames, - fieldTypes.stream().map(stdType -> (Type) 
stdType.underlyingType()).collect(Collectors.toList()), this); - } - - @Override - public PrestoStruct createStruct(List fieldTypes) { - return new PrestoStruct( - fieldTypes.stream().map(stdType -> (Type) stdType.underlyingType()).collect(Collectors.toList()), this); - } - - @Override - public StdStruct createStruct(StdType stdType) { - return new PrestoStruct((RowType) stdType.underlyingType(), this); - } - - @Override - public StdType createStdType(String typeSignature) { - return PrestoWrapper.createStdType( - metadata.getType(applyBoundVariables(parseTypeSignature(typeSignature, ImmutableSet.of()), boundVariables))); - } - - public ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunction resolvedFunction) { - return metadata.getScalarFunctionImplementation(resolvedFunction); - } - - public ResolvedFunction resolveOperator(OperatorType operatorType, List argumentTypes) throws OperatorNotFoundException { - return metadata.resolveOperator(operatorType, argumentTypes); - } -} diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/PrestoWrapper.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/PrestoWrapper.java deleted file mode 100644 index 7f561b96..00000000 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/PrestoWrapper.java +++ /dev/null @@ -1,140 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.presto; - -import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.presto.data.PrestoArray; -import com.linkedin.transport.presto.data.PrestoBoolean; -import com.linkedin.transport.presto.data.PrestoBinary; -import com.linkedin.transport.presto.data.PrestoDouble; -import com.linkedin.transport.presto.data.PrestoFloat; -import com.linkedin.transport.presto.data.PrestoInteger; -import com.linkedin.transport.presto.data.PrestoLong; -import com.linkedin.transport.presto.data.PrestoMap; -import com.linkedin.transport.presto.data.PrestoString; -import com.linkedin.transport.presto.data.PrestoStruct; -import com.linkedin.transport.presto.types.PrestoArrayType; -import com.linkedin.transport.presto.types.PrestoBooleanType; -import com.linkedin.transport.presto.types.PrestoBinaryType; -import com.linkedin.transport.presto.types.PrestoDoubleType; -import com.linkedin.transport.presto.types.PrestoFloatType; -import com.linkedin.transport.presto.types.PrestoIntegerType; -import com.linkedin.transport.presto.types.PrestoLongType; -import com.linkedin.transport.presto.types.PrestoMapType; -import com.linkedin.transport.presto.types.PrestoStringType; -import com.linkedin.transport.presto.types.PrestoStructType; -import com.linkedin.transport.presto.types.PrestoUnknownType; -import io.airlift.slice.Slice; -import io.prestosql.spi.PrestoException; -import io.prestosql.spi.block.Block; -import io.prestosql.spi.type.ArrayType; -import io.prestosql.spi.type.BigintType; -import io.prestosql.spi.type.BooleanType; -import io.prestosql.spi.type.DoubleType; -import io.prestosql.spi.type.IntegerType; -import io.prestosql.spi.type.MapType; -import io.prestosql.spi.type.RealType; -import io.prestosql.spi.type.RowType; -import io.prestosql.spi.type.Type; -import io.prestosql.spi.type.VarbinaryType; -import 
io.prestosql.spi.type.VarcharType; -import io.prestosql.type.UnknownType; - -import static io.prestosql.spi.StandardErrorCode.*; -import static java.lang.Float.*; -import static java.lang.Math.*; -import static java.lang.String.*; - - -public final class PrestoWrapper { - - private PrestoWrapper() { - } - - public static StdData createStdData(Object prestoData, Type prestoType, StdFactory stdFactory) { - if (prestoData == null) { - return null; - } - if (prestoType instanceof IntegerType) { - // Presto represents SQL Integers (i.e., corresponding to IntegerType above) as long or Long - // Therefore, to pass it to the PrestoInteger class, we first cast it to Long, then extract - // the int value. - return new PrestoInteger(((Long) prestoData).intValue()); - } else if (prestoType instanceof BigintType) { - return new PrestoLong((long) prestoData); - } else if (prestoType.getJavaType() == boolean.class) { - return new PrestoBoolean((boolean) prestoData); - } else if (prestoType instanceof VarcharType) { - return new PrestoString((Slice) prestoData); - } else if (prestoType instanceof RealType) { - // Presto represents SQL Reals (i.e., corresponding to RealType above) as long or Long - // Therefore, to pass it to the PrestoFloat class, we first cast it to Long, extract - // the int value and convert it the int bits to float. 
- long value = (long) prestoData; - int floatValue; - try { - floatValue = toIntExact(value); - } catch (ArithmeticException e) { - throw new PrestoException(GENERIC_INTERNAL_ERROR, - format("Value (%sb) is not a valid single-precision float", Long.toBinaryString(value))); - } - return new PrestoFloat(intBitsToFloat(floatValue)); - } else if (prestoType instanceof DoubleType) { - return new PrestoDouble((double) prestoData); - } else if (prestoType instanceof VarbinaryType) { - return new PrestoBinary((Slice) prestoData); - } else if (prestoType instanceof ArrayType) { - return new PrestoArray((Block) prestoData, (ArrayType) prestoType, stdFactory); - } else if (prestoType instanceof MapType) { - return new PrestoMap((Block) prestoData, prestoType, stdFactory); - } else if (prestoType instanceof RowType) { - return new PrestoStruct((Block) prestoData, prestoType, stdFactory); - } - assert false : "Unrecognized Presto Type: " + prestoType.getClass(); - return null; - } - - public static StdType createStdType(Object prestoType) { - if (prestoType instanceof IntegerType) { - return new PrestoIntegerType((IntegerType) prestoType); - } else if (prestoType instanceof BigintType) { - return new PrestoLongType((BigintType) prestoType); - } else if (prestoType instanceof BooleanType) { - return new PrestoBooleanType((BooleanType) prestoType); - } else if (prestoType instanceof VarcharType) { - return new PrestoStringType((VarcharType) prestoType); - } else if (prestoType instanceof RealType) { - return new PrestoFloatType((RealType) prestoType); - } else if (prestoType instanceof DoubleType) { - return new PrestoDoubleType((DoubleType) prestoType); - } else if (prestoType instanceof VarbinaryType) { - return new PrestoBinaryType((VarbinaryType) prestoType); - } else if (prestoType instanceof ArrayType) { - return new PrestoArrayType((ArrayType) prestoType); - } else if (prestoType instanceof MapType) { - return new PrestoMapType((MapType) prestoType); - } else if 
(prestoType instanceof RowType) { - return new PrestoStructType(((RowType) prestoType)); - } else if (prestoType instanceof UnknownType) { - return new PrestoUnknownType(((UnknownType) prestoType)); - } - assert false : "Unrecognized Presto Type: " + prestoType.getClass(); - return null; - } - - /** - * @return index if the index is in range, -1 otherwise. - */ - public static int checkedIndexToBlockPosition(Block block, long index) { - int blockLength = block.getPositionCount(); - if (index >= 0 && index < blockLength) { - return toIntExact(index); - } - return -1; // -1 indicates that the element is out of range and the calling function should return null - } -} diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoBinary.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoBinary.java deleted file mode 100644 index bc201cde..00000000 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoBinary.java +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.presto.data; - -import com.linkedin.transport.api.data.StdBinary; -import io.airlift.slice.Slice; -import io.prestosql.spi.block.BlockBuilder; -import java.nio.ByteBuffer; - -import static io.prestosql.spi.type.VarbinaryType.*; - -public class PrestoBinary extends PrestoData implements StdBinary { - - private Slice _slice; - - public PrestoBinary(Slice slice) { - _slice = slice; - } - - @Override - public ByteBuffer get() { - return _slice.toByteBuffer(); - } - - @Override - public Object getUnderlyingData() { - return _slice; - } - - @Override - public void setUnderlyingData(Object value) { - _slice = (Slice) value; - } - - @Override - public void writeToBlock(BlockBuilder blockBuilder) { - VARBINARY.writeSlice(blockBuilder, _slice); - } -} diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoBoolean.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoBoolean.java deleted file mode 100644 index 408fc9be..00000000 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoBoolean.java +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.presto.data; - -import com.linkedin.transport.api.data.StdBoolean; -import io.prestosql.spi.block.BlockBuilder; - -import static io.prestosql.spi.type.BooleanType.*; - - -public class PrestoBoolean extends PrestoData implements StdBoolean { - - boolean _value; - - public PrestoBoolean(boolean value) { - _value = value; - } - - @Override - public boolean get() { - return _value; - } - - @Override - public Object getUnderlyingData() { - return _value; - } - - @Override - public void setUnderlyingData(Object value) { - _value = (boolean) value; - } - - @Override - public void writeToBlock(BlockBuilder blockBuilder) { - BOOLEAN.writeBoolean(blockBuilder, _value); - } -} diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoDouble.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoDouble.java deleted file mode 100644 index 0ab9fe6f..00000000 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoDouble.java +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.presto.data; - -import com.linkedin.transport.api.data.StdDouble; -import io.prestosql.spi.block.BlockBuilder; - -import static io.prestosql.spi.type.DoubleType.*; - - -public class PrestoDouble extends PrestoData implements StdDouble { - - private double _double; - - public PrestoDouble(double aDouble) { - _double = aDouble; - } - - @Override - public double get() { - return _double; - } - - @Override - public Object getUnderlyingData() { - return _double; - } - - @Override - public void setUnderlyingData(Object value) { - _double = (double) value; - } - - @Override - public void writeToBlock(BlockBuilder blockBuilder) { - DOUBLE.writeDouble(blockBuilder, _double); - } -} \ No newline at end of file diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoFloat.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoFloat.java deleted file mode 100644 index 11328cef..00000000 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoFloat.java +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.presto.data; - -import com.linkedin.transport.api.data.StdFloat; -import io.prestosql.spi.block.BlockBuilder; - -import static java.lang.Float.*; - - -public class PrestoFloat extends PrestoData implements StdFloat { - - private float _float; - - public PrestoFloat(float aFloat) { - _float = aFloat; - } - - @Override - public float get() { - return _float; - } - - @Override - public Object getUnderlyingData() { - return (long) floatToIntBits(_float); - } - - @Override - public void setUnderlyingData(Object value) { - _float = (float) value; - } - - @Override - public void writeToBlock(BlockBuilder blockBuilder) { - blockBuilder.writeInt(floatToIntBits(_float)); - } -} diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoInteger.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoInteger.java deleted file mode 100644 index 06ef9a3b..00000000 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoInteger.java +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.presto.data; - -import com.linkedin.transport.api.data.StdInteger; -import io.prestosql.spi.block.BlockBuilder; - -import static io.prestosql.spi.type.IntegerType.*; - - -public class PrestoInteger extends PrestoData implements StdInteger { - - int _integer; - - public PrestoInteger(int integer) { - _integer = integer; - } - - @Override - public int get() { - return _integer; - } - - @Override - public Object getUnderlyingData() { - return _integer; - } - - @Override - public void setUnderlyingData(Object value) { - _integer = ((Long) value).intValue(); - } - - @Override - public void writeToBlock(BlockBuilder blockBuilder) { - // It looks a bit strange, but the call to writeLong is correct here. INTEGER does not have a writeInt method for - // some reason. It uses BlockBuilder.writeInt internally. - INTEGER.writeLong(blockBuilder, _integer); - } -} diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoLong.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoLong.java deleted file mode 100644 index 29832b4a..00000000 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoLong.java +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.presto.data; - -import com.linkedin.transport.api.data.StdLong; -import io.prestosql.spi.block.BlockBuilder; - -import static io.prestosql.spi.type.BigintType.*; - - -public class PrestoLong extends PrestoData implements StdLong { - - long _value; - - public PrestoLong(long value) { - _value = value; - } - - @Override - public long get() { - return _value; - } - - @Override - public Object getUnderlyingData() { - return _value; - } - - @Override - public void setUnderlyingData(Object value) { - _value = (long) value; - } - - @Override - public void writeToBlock(BlockBuilder blockBuilder) { - BIGINT.writeLong(blockBuilder, _value); - } -} diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoString.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoString.java deleted file mode 100644 index 6691da3f..00000000 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoString.java +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.presto.data; - -import com.linkedin.transport.api.data.StdString; -import io.airlift.slice.Slice; -import io.prestosql.spi.block.BlockBuilder; - -import static io.prestosql.spi.type.VarcharType.*; - - -public class PrestoString extends PrestoData implements StdString { - - Slice _slice; - - public PrestoString(Slice slice) { - _slice = slice; - } - - @Override - public String get() { - return _slice.toStringUtf8(); - } - - @Override - public Object getUnderlyingData() { - return _slice; - } - - @Override - public void setUnderlyingData(Object value) { - _slice = (Slice) value; - } - - @Override - public void writeToBlock(BlockBuilder blockBuilder) { - VARCHAR.writeSlice(blockBuilder, _slice); - } -} diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/SparkWrapper.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/SparkWrapper.scala deleted file mode 100644 index 29e935db..00000000 --- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/SparkWrapper.scala +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.spark - -import java.nio.ByteBuffer - -import com.linkedin.transport.api.data.StdData -import com.linkedin.transport.api.types.StdType -import com.linkedin.transport.spark.data._ -import com.linkedin.transport.spark.types._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -object SparkWrapper { - - def createStdData(data: Any, dataType: DataType): StdData = { // scalastyle:ignore cyclomatic.complexity - if (data == null) { - null - } else { - dataType match { - case _: IntegerType => SparkInteger(data.asInstanceOf[Integer]) - case _: LongType => SparkLong(data.asInstanceOf[java.lang.Long]) - case _: BooleanType => SparkBoolean(data.asInstanceOf[java.lang.Boolean]) - case _: StringType => SparkString(data.asInstanceOf[UTF8String]) - case _: FloatType => SparkFloat(data.asInstanceOf[java.lang.Float]) - case _: DoubleType => SparkDouble(data.asInstanceOf[java.lang.Double]) - case _: BinaryType => SparkBinary(data.asInstanceOf[Array[Byte]]) - case _: ArrayType => SparkArray(data.asInstanceOf[ArrayData], dataType.asInstanceOf[ArrayType]) - case _: MapType => SparkMap(data.asInstanceOf[MapData], dataType.asInstanceOf[MapType]) - case _: StructType => SparkStruct(data.asInstanceOf[InternalRow], dataType.asInstanceOf[StructType]) - case _: NullType => null - case _ => throw new UnsupportedOperationException("Unrecognized Spark Type: " + dataType.getClass) - } - } - } - - def createStdType(dataType: DataType): StdType = dataType match { - case _: IntegerType => SparkIntegerType(dataType.asInstanceOf[IntegerType]) - case _: LongType => SparkLongType(dataType.asInstanceOf[LongType]) - case _: BooleanType => SparkBooleanType(dataType.asInstanceOf[BooleanType]) - case _: StringType => SparkStringType(dataType.asInstanceOf[StringType]) - case _: FloatType => 
SparkFloatType(dataType.asInstanceOf[FloatType]) - case _: DoubleType => SparkDoubleType(dataType.asInstanceOf[DoubleType]) - case _: BinaryType => SparkBinaryType(dataType.asInstanceOf[BinaryType]) - case _: ArrayType => SparkArrayType(dataType.asInstanceOf[ArrayType]) - case _: MapType => SparkMapType(dataType.asInstanceOf[MapType]) - case _: StructType => SparkStructType(dataType.asInstanceOf[StructType]) - case _: NullType => SparkUnknownType(dataType.asInstanceOf[NullType]) - case _ => throw new UnsupportedOperationException("Unrecognized Spark Type: " + dataType.getClass) - } -} diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkBinary.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkBinary.scala deleted file mode 100644 index bd402530..00000000 --- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkBinary.scala +++ /dev/null @@ -1,19 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.spark.data - -import java.nio.ByteBuffer - -import com.linkedin.transport.api.data.{PlatformData, StdBinary} - -case class SparkBinary(private var _bytes: Array[Byte]) extends StdBinary with PlatformData { - - override def get(): ByteBuffer = ByteBuffer.wrap(_bytes) - - override def getUnderlyingData: AnyRef = _bytes - - override def setUnderlyingData(value: scala.Any): Unit = _bytes = value.asInstanceOf[ByteBuffer].array() -} diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkBoolean.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkBoolean.scala deleted file mode 100644 index 2477eef2..00000000 --- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkBoolean.scala +++ /dev/null @@ -1,17 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.spark.data - -import com.linkedin.transport.api.data.{PlatformData, StdBoolean} - -case class SparkBoolean(private var _bool: java.lang.Boolean) extends StdBoolean with PlatformData { - - override def get(): Boolean = _bool.booleanValue() - - override def getUnderlyingData: AnyRef = _bool - - override def setUnderlyingData(value: scala.Any): Unit = _bool = value.asInstanceOf[java.lang.Boolean] -} diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkDouble.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkDouble.scala deleted file mode 100644 index 6a4820e3..00000000 --- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkDouble.scala +++ /dev/null @@ -1,18 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. 
- * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.spark.data - -import com.linkedin.transport.api.data.{PlatformData, StdDouble} - -case class SparkDouble(private var _double: java.lang.Double) extends StdDouble with PlatformData { - - override def get(): Double = _double.doubleValue() - - override def getUnderlyingData: AnyRef = _double - - override def setUnderlyingData(value: scala.Any): Unit = _double = value.asInstanceOf[java.lang.Double] -} - diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkFloat.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkFloat.scala deleted file mode 100644 index d9842b51..00000000 --- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkFloat.scala +++ /dev/null @@ -1,17 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.spark.data - -import com.linkedin.transport.api.data.{PlatformData, StdFloat} - -case class SparkFloat(private var _float: java.lang.Float) extends StdFloat with PlatformData { - - override def get(): Float = _float.floatValue() - - override def getUnderlyingData: AnyRef = _float - - override def setUnderlyingData(value: scala.Any): Unit = _float = value.asInstanceOf[java.lang.Float] -} diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkInteger.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkInteger.scala deleted file mode 100644 index b7c0db9e..00000000 --- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkInteger.scala +++ /dev/null @@ -1,17 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. 
- * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.spark.data - -import com.linkedin.transport.api.data.{PlatformData, StdInteger} - -case class SparkInteger(private var _int: Integer) extends StdInteger with PlatformData { - - override def get(): Int = _int.intValue() - - override def getUnderlyingData: AnyRef = _int - - override def setUnderlyingData(value: scala.Any): Unit = _int = value.asInstanceOf[Integer] -} diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkLong.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkLong.scala deleted file mode 100644 index 5a534290..00000000 --- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkLong.scala +++ /dev/null @@ -1,18 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.spark.data - -import com.linkedin.transport.api.data.{PlatformData, StdLong} - -case class SparkLong(private var _long: java.lang.Long) extends StdLong with PlatformData { - - override def get(): Long = _long.longValue() - - override def getUnderlyingData: AnyRef = _long - - override def setUnderlyingData(value: scala.Any): Unit = _long = value.asInstanceOf[java.lang.Long] - -} diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkMap.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkMap.scala deleted file mode 100644 index d200be8c..00000000 --- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkMap.scala +++ /dev/null @@ -1,134 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.spark.data - -import java.util - -import com.linkedin.transport.api.data.{PlatformData, StdData, StdMap} -import com.linkedin.transport.spark.SparkWrapper -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, MapData} -import org.apache.spark.sql.types.MapType - -import scala.collection.mutable.Map - - -case class SparkMap(private var _mapData: MapData, - private val _mapType: MapType) extends StdMap with PlatformData { - - private val _keyType = _mapType.keyType - private val _valueType = _mapType.valueType - private var _mutableMap: Map[Any, Any] = if (_mapData == null) createMutableMap() else null - - override def put(key: StdData, value: StdData): Unit = { - // TODO: Does not support inserting nulls. Should we? - if (_mutableMap == null) { - _mutableMap = createMutableMap() - } - _mutableMap.put(key.asInstanceOf[PlatformData].getUnderlyingData, value.asInstanceOf[PlatformData].getUnderlyingData) - } - - override def keySet(): util.Set[StdData] = { - val keysIterator: Iterator[Any] = if (_mutableMap == null) { - new Iterator[Any] { - var offset : Int = 0 - - override def next(): Any = { - offset += 1 - _mapData.keyArray().get(offset - 1, _keyType) - } - - override def hasNext: Boolean = { - offset < SparkMap.this.size() - } - } - } else { - _mutableMap.keysIterator - } - - new util.AbstractSet[StdData] { - - override def iterator(): util.Iterator[StdData] = new util.Iterator[StdData] { - - override def next(): StdData = SparkWrapper.createStdData(keysIterator.next(), _keyType) - - override def hasNext: Boolean = keysIterator.hasNext - } - - override def size(): Int = SparkMap.this.size() - } - } - - override def size(): Int = { - if (_mutableMap == null) { - _mapData.numElements() - } else { - _mutableMap.size - } - } - - override def values(): util.Collection[StdData] = { - val valueIterator: Iterator[Any] = if (_mutableMap == null) { - new Iterator[Any] { - var offset : Int = 0 - - override def next(): Any = { - 
offset += 1 - _mapData.valueArray().get(offset - 1, _valueType) - } - - override def hasNext: Boolean = { - offset < SparkMap.this.size() - } - } - } else { - _mutableMap.valuesIterator - } - - new util.AbstractCollection[StdData] { - - override def iterator(): util.Iterator[StdData] = new util.Iterator[StdData] { - - override def next(): StdData = SparkWrapper.createStdData(valueIterator.next(), _valueType) - - override def hasNext: Boolean = valueIterator.hasNext - } - - override def size(): Int = SparkMap.this.size() - } - } - - override def containsKey(key: StdData): Boolean = get(key) != null - - override def get(key: StdData): StdData = { - // Spark's complex data types (MapData, ArrayData, InternalRow) do not implement equals/hashcode - // If the key is of the above complex data types, get() will return null - if (_mutableMap == null) { - _mutableMap = createMutableMap() - } - SparkWrapper.createStdData(_mutableMap.get(key.asInstanceOf[PlatformData].getUnderlyingData).orNull, _valueType) - } - - private def createMutableMap(): Map[Any, Any] = { - val mutableMap = Map.empty[Any, Any] - if (_mapData != null) { - _mapData.foreach(_keyType, _valueType, (k, v) => mutableMap.put(k, v)) - } - mutableMap - } - - override def getUnderlyingData: AnyRef = { - if (_mutableMap == null) { - _mapData - } else { - ArrayBasedMapData(_mutableMap) - } - } - - override def setUnderlyingData(value: scala.Any): Unit = { - _mapData = value.asInstanceOf[MapData] - _mutableMap = null - } -} diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkString.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkString.scala deleted file mode 100644 index bd089dd5..00000000 --- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkString.scala +++ /dev/null @@ -1,18 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. 
- * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.spark.data - -import com.linkedin.transport.api.data.{PlatformData, StdString} -import org.apache.spark.unsafe.types.UTF8String - -case class SparkString(private var _str: UTF8String) extends StdString with PlatformData { - - override def get(): String = _str.toString - - override def getUnderlyingData: AnyRef = _str - - override def setUnderlyingData(value: scala.Any): Unit = _str = value.asInstanceOf[UTF8String] -} diff --git a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/data/TestSparkPrimitives.scala b/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/data/TestSparkPrimitives.scala deleted file mode 100644 index 21b88c8e..00000000 --- a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/data/TestSparkPrimitives.scala +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.spark.data - -import java.lang -import java.nio.ByteBuffer -import java.nio.charset.Charset - -import com.linkedin.transport.api.data._ -import com.linkedin.transport.spark.{SparkFactory, SparkWrapper} -import org.apache.spark.sql.types.DataTypes -import org.apache.spark.unsafe.types.UTF8String -import org.testng.Assert.{assertEquals, assertSame} -import org.testng.annotations.Test - - -class TestSparkPrimitives { - - val stdFactory = new SparkFactory(null) - - @Test - def testCreateSparkInteger(): Unit = { - val intData = 123 - val stdInteger = SparkWrapper.createStdData(intData, DataTypes.IntegerType).asInstanceOf[StdInteger] - assertEquals(stdInteger.get(), intData) - assertSame(stdInteger.asInstanceOf[PlatformData].getUnderlyingData, intData) - } - - @Test - def testCreateSparkLong(): Unit = { - val longData = new lang.Long(1234L) // scalastyle:ignore magic.number - val stdLong = SparkWrapper.createStdData(longData, DataTypes.LongType).asInstanceOf[StdLong] - assertEquals(stdLong.get(), longData) - assertSame(stdLong.asInstanceOf[PlatformData].getUnderlyingData, longData) - } - - @Test - def testCreateSparkBoolean(): Unit = { - val booleanData = new lang.Boolean(true) - val stdBoolean = SparkWrapper.createStdData(booleanData, DataTypes.BooleanType).asInstanceOf[StdBoolean] - assertEquals(stdBoolean.get(), true) - assertSame(stdBoolean.asInstanceOf[PlatformData].getUnderlyingData, booleanData) - } - - @Test - def testCreateSparkString(): Unit = { - val stringData = UTF8String.fromString("test") - val stdString = SparkWrapper.createStdData(stringData, DataTypes.StringType).asInstanceOf[StdString] - assertEquals(stdString.get(), "test") - assertSame(stdString.asInstanceOf[PlatformData].getUnderlyingData, stringData) - } - - @Test - def testCreateSparkFloat(): Unit = { - val floatData = new lang.Float(1.0f) - val stdFloat = SparkWrapper.createStdData(floatData, DataTypes.FloatType).asInstanceOf[StdFloat] - 
assertEquals(stdFloat.get(), 1.0f) - assertSame(stdFloat.asInstanceOf[PlatformData].getUnderlyingData, floatData) - } - - @Test - def testCreateSparkDouble(): Unit = { - val doubleData = new lang.Double(2.0) - val stdDouble = SparkWrapper.createStdData(doubleData, DataTypes.DoubleType).asInstanceOf[StdDouble] - assertEquals(stdDouble.get(), 2.0) - assertSame(stdDouble.asInstanceOf[PlatformData].getUnderlyingData, doubleData) - } - - @Test - def testCreateSparkBinary(): Unit = { - val bytesData = ByteBuffer.wrap("foo".getBytes(Charset.forName("UTF-8"))) - val stdByte = SparkWrapper.createStdData(bytesData.array(), DataTypes.BinaryType).asInstanceOf[StdBinary] - assertEquals(stdByte.get(), bytesData) - assertSame(stdByte.asInstanceOf[PlatformData].getUnderlyingData, bytesData.array()) - } - -} diff --git a/transportable-udfs-spark/build.gradle b/transportable-udfs-spark_2.11/build.gradle similarity index 86% rename from transportable-udfs-spark/build.gradle rename to transportable-udfs-spark_2.11/build.gradle index 1f00d4f5..048a3dac 100644 --- a/transportable-udfs-spark/build.gradle +++ b/transportable-udfs-spark_2.11/build.gradle @@ -6,17 +6,17 @@ dependencies { compile project(':transportable-udfs-utils') // For spark-core and spark-sql dependencies, we exclude transitive dependency on 'jackson-module-paranamer', // since this is required for the LinkedIn version of spark-core and spark-sql. 
- compileOnly(group: project.ext.'spark-group', name: 'spark-core_2.11', version: project.ext.'spark-version') { + compileOnly(group: project.ext.'spark-group', name: 'spark-core_2.11', version: project.ext.'spark2-version') { exclude module: 'jackson-module-paranamer' } - compileOnly(group: project.ext.'spark-group', name: 'spark-sql_2.11', version: project.ext.'spark-version') { + compileOnly(group: project.ext.'spark-group', name: 'spark-sql_2.11', version: project.ext.'spark2-version') { exclude module: 'jackson-module-paranamer' } compileOnly('com.fasterxml.jackson.module:jackson-module-paranamer:2.6.7') - testCompile(group: project.ext.'spark-group', name: 'spark-core_2.11', version: project.ext.'spark-version') { + testCompile(group: project.ext.'spark-group', name: 'spark-core_2.11', version: project.ext.'spark2-version') { exclude module: 'jackson-module-paranamer' } - testCompile(group: project.ext.'spark-group', name: 'spark-sql_2.11', version: project.ext.'spark-version') { + testCompile(group: project.ext.'spark-group', name: 'spark-sql_2.11', version: project.ext.'spark2-version') { exclude module: 'jackson-module-paranamer' } testCompile('com.fasterxml.jackson.module:jackson-module-paranamer:2.6.7') diff --git a/transportable-udfs-spark/config/scalastyle/scalastyle-config.xml b/transportable-udfs-spark_2.11/config/scalastyle/scalastyle-config.xml similarity index 100% rename from transportable-udfs-spark/config/scalastyle/scalastyle-config.xml rename to transportable-udfs-spark_2.11/config/scalastyle/scalastyle-config.xml diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/SparkFactory.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/SparkFactory.scala similarity index 55% rename from transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/SparkFactory.scala rename to transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/SparkFactory.scala index 
07b61ba2..87d5625d 100644 --- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/SparkFactory.scala +++ b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/SparkFactory.scala @@ -8,58 +8,36 @@ package com.linkedin.transport.spark import java.nio.ByteBuffer import java.util.{List => JavaList} -import com.google.common.base.Preconditions import com.linkedin.transport.api.StdFactory import com.linkedin.transport.api.data._ import com.linkedin.transport.api.types.StdType import com.linkedin.transport.spark.data._ import com.linkedin.transport.spark.typesystem.SparkTypeFactory import com.linkedin.transport.typesystem.{AbstractBoundVariables, TypeSignature} -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType} class SparkFactory(private val _boundVariables: AbstractBoundVariables[DataType]) extends StdFactory { private val _sparkTypeFactory: SparkTypeFactory = new SparkTypeFactory - override def createInteger(value: Int): StdInteger = SparkInteger(value) - - override def createLong(value: Long): StdLong = SparkLong(value) - - override def createBoolean(value: Boolean): StdBoolean = SparkBoolean(value) - - override def createString(value: String): StdString = { - Preconditions.checkNotNull(value, "Cannot create a null StdString".asInstanceOf[Any]) - SparkString(UTF8String.fromString(value)) - } - - override def createFloat(value: Float): StdFloat = SparkFloat(value) - - override def createDouble(value: Double): StdDouble = SparkDouble(value) - - override def createBinary(value: ByteBuffer): StdBinary = { - Preconditions.checkNotNull(value, "Cannot create a null StdBinary".asInstanceOf[Any]) - SparkBinary(value.array()) - } - - override def createArray(stdType: StdType): StdArray = createArray(stdType, 0) + override def createArray(stdType: StdType): ArrayData[_] = createArray(stdType, 0) // we do not pass 
size to `new Array()` as the size argument of createArray is supposed to be just a hint about - // the expected number of entries in the StdArray. `new Array(size)` will create an array with null entries - override def createArray(stdType: StdType, size: Int): StdArray = SparkArray( + // the expected number of entries in the ArrayData. `new Array(size)` will create an array with null entries + override def createArray(stdType: StdType, size: Int): ArrayData[_] = SparkArrayData( null, stdType.underlyingType().asInstanceOf[ArrayType] ) - override def createMap(stdType: StdType): StdMap = SparkMap( + override def createMap(stdType: StdType): MapData[_, _] = SparkMapData( //TODO: make these as separate mutable standard spark types null, stdType.underlyingType().asInstanceOf[MapType] ) - override def createStruct(fieldTypes: JavaList[StdType]): StdStruct = { + override def createStruct(fieldTypes: JavaList[StdType]): RowData = { createStruct(null, fieldTypes) } - override def createStruct(fieldNames: JavaList[String], fieldTypes: JavaList[StdType]): StdStruct = { + override def createStruct(fieldNames: JavaList[String], fieldTypes: JavaList[StdType]): RowData = { val structFields = new Array[StructField](fieldTypes.size()) (0 until fieldTypes.size()).foreach({ idx => { @@ -69,13 +47,13 @@ class SparkFactory(private val _boundVariables: AbstractBoundVariables[DataType] ) } }) - SparkStruct(null, StructType(structFields)) + SparkRowData(null, StructType(structFields)) } - override def createStruct(stdType: StdType): StdStruct = { + override def createStruct(stdType: StdType): RowData = { //TODO: make these as separate mutable standard spark types val structType: StructType = stdType.underlyingType().asInstanceOf[StructType] - SparkStruct(null, structType) + SparkRowData(null, structType) } override def createStdType(typeSignature: String): StdType = SparkWrapper.createStdType( diff --git 
a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/SparkWrapper.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/SparkWrapper.scala new file mode 100644 index 00000000..b365f716 --- /dev/null +++ b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/SparkWrapper.scala @@ -0,0 +1,76 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.spark + +import java.nio.ByteBuffer + +import com.linkedin.transport.api.data.PlatformData +import com.linkedin.transport.api.types.StdType +import com.linkedin.transport.spark.data._ +import com.linkedin.transport.spark.types._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +object SparkWrapper { + + def createStdData(data: Any, dataType: DataType): Object = { // scalastyle:ignore cyclomatic.complexity + if (data == null) { + null + } else { + dataType match { + case _: IntegerType => data.asInstanceOf[Object] + case _: LongType => data.asInstanceOf[Object] + case _: BooleanType => data.asInstanceOf[Object] + case _: StringType => data.asInstanceOf[UTF8String].toString + case _: FloatType => data.asInstanceOf[Object] + case _: DoubleType => data.asInstanceOf[Object] + case _: BinaryType => ByteBuffer.wrap(data.asInstanceOf[Array[Byte]]) + case _: ArrayType => SparkArrayData( + data.asInstanceOf[org.apache.spark.sql.catalyst.util.ArrayData], dataType.asInstanceOf[ArrayType] + ) + case _: MapType => SparkMapData( + data.asInstanceOf[org.apache.spark.sql.catalyst.util.MapData], dataType.asInstanceOf[MapType] + ) + case _: StructType => SparkRowData(data.asInstanceOf[InternalRow], dataType.asInstanceOf[StructType]) + case _: NullType => null + case _ => throw new UnsupportedOperationException("Unrecognized Spark 
Type: " + dataType.getClass) + } + } + } + + def getPlatformData(transportData: Object): Object = { + if (transportData == null) { + null + } else { + transportData match { + case _: java.lang.Integer => transportData + case _: java.lang.Long => transportData + case _: java.lang.Float => transportData + case _: java.lang.Double => transportData + case _: java.lang.Boolean => transportData + case _: java.lang.String => UTF8String.fromString(transportData.asInstanceOf[String]) + case _: ByteBuffer => transportData.asInstanceOf[ByteBuffer].array() + case _ => transportData.asInstanceOf[PlatformData].getUnderlyingData + } + } + } + + def createStdType(dataType: DataType): StdType = dataType match { + case _: IntegerType => SparkIntegerType(dataType.asInstanceOf[IntegerType]) + case _: LongType => SparkLongType(dataType.asInstanceOf[LongType]) + case _: BooleanType => SparkBooleanType(dataType.asInstanceOf[BooleanType]) + case _: StringType => SparkStringType(dataType.asInstanceOf[StringType]) + case _: FloatType => SparkFloatType(dataType.asInstanceOf[FloatType]) + case _: DoubleType => SparkDoubleType(dataType.asInstanceOf[DoubleType]) + case _: BinaryType => SparkBinaryType(dataType.asInstanceOf[BinaryType]) + case _: ArrayType => SparkArrayType(dataType.asInstanceOf[ArrayType]) + case _: MapType => SparkMapType(dataType.asInstanceOf[MapType]) + case _: StructType => SparkRowType(dataType.asInstanceOf[StructType]) + case _: NullType => SparkUnknownType(dataType.asInstanceOf[NullType]) + case _ => throw new UnsupportedOperationException("Unrecognized Spark Type: " + dataType.getClass) + } +} diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/StdUDFRegistration.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/StdUDFRegistration.scala similarity index 100% rename from transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/StdUDFRegistration.scala rename to 
transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/StdUDFRegistration.scala diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/StdUdfWrapper.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/StdUdfWrapper.scala similarity index 77% rename from transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/StdUdfWrapper.scala rename to transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/StdUdfWrapper.scala index f1cca7d4..5eca65a1 100644 --- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/StdUdfWrapper.scala +++ b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/StdUdfWrapper.scala @@ -10,7 +10,6 @@ import java.nio.file.Paths import java.util.List import com.linkedin.transport.api.StdFactory -import com.linkedin.transport.api.data.{PlatformData, StdData} import com.linkedin.transport.api.udf._ import com.linkedin.transport.spark.typesystem.SparkTypeInference import com.linkedin.transport.utils.FileSystemUtils @@ -64,29 +63,29 @@ abstract class StdUdfWrapper(_expressions: Seq[Expression]) extends Expression if (wrappedConstants != null) { val requiredFiles = wrappedConstants.length match { case 0 => - _stdUdf.asInstanceOf[StdUDF0[StdData]].getRequiredFiles() + _stdUdf.asInstanceOf[StdUDF0[Object]].getRequiredFiles() case 1 => - _stdUdf.asInstanceOf[StdUDF1[StdData, StdData]].getRequiredFiles(wrappedConstants(0)) + _stdUdf.asInstanceOf[StdUDF1[Object, Object]].getRequiredFiles(wrappedConstants(0)) case 2 => - _stdUdf.asInstanceOf[StdUDF2[StdData, StdData, StdData]].getRequiredFiles(wrappedConstants(0), + _stdUdf.asInstanceOf[StdUDF2[Object, Object, Object]].getRequiredFiles(wrappedConstants(0), wrappedConstants(1)) case 3 => - _stdUdf.asInstanceOf[StdUDF3[StdData, StdData, StdData, StdData]].getRequiredFiles(wrappedConstants(0), + _stdUdf.asInstanceOf[StdUDF3[Object, Object, Object, 
Object]].getRequiredFiles(wrappedConstants(0), wrappedConstants(1), wrappedConstants(2)) case 4 => - _stdUdf.asInstanceOf[StdUDF4[StdData, StdData, StdData, StdData, StdData]].getRequiredFiles(wrappedConstants(0), + _stdUdf.asInstanceOf[StdUDF4[Object, Object, Object, Object, Object]].getRequiredFiles(wrappedConstants(0), wrappedConstants(1), wrappedConstants(2), wrappedConstants(3)) case 5 => - _stdUdf.asInstanceOf[StdUDF5[StdData, StdData, StdData, StdData, StdData, StdData]].getRequiredFiles(wrappedConstants(0), + _stdUdf.asInstanceOf[StdUDF5[Object, Object, Object, Object, Object, Object]].getRequiredFiles(wrappedConstants(0), wrappedConstants(1), wrappedConstants(2), wrappedConstants(3), wrappedConstants(4)) case 6 => - _stdUdf.asInstanceOf[StdUDF6[StdData, StdData, StdData, StdData, StdData, StdData, StdData]].getRequiredFiles(wrappedConstants(0), + _stdUdf.asInstanceOf[StdUDF6[Object, Object, Object, Object, Object, Object, Object]].getRequiredFiles(wrappedConstants(0), wrappedConstants(1), wrappedConstants(2), wrappedConstants(3), wrappedConstants(4), wrappedConstants(5)) case 7 => - _stdUdf.asInstanceOf[StdUDF7[StdData, StdData, StdData, StdData, StdData, StdData, StdData, StdData]].getRequiredFiles(wrappedConstants(0), + _stdUdf.asInstanceOf[StdUDF7[Object, Object, Object, Object, Object, Object, Object, Object]].getRequiredFiles(wrappedConstants(0), wrappedConstants(1), wrappedConstants(2), wrappedConstants(3), wrappedConstants(4), wrappedConstants(5), wrappedConstants(6)) case 8 => - _stdUdf.asInstanceOf[StdUDF8[StdData, StdData, StdData, StdData, StdData, StdData, StdData, StdData, StdData]].getRequiredFiles(wrappedConstants(0), + _stdUdf.asInstanceOf[StdUDF8[Object, Object, Object, Object, Object, Object, Object, Object, Object]].getRequiredFiles(wrappedConstants(0), wrappedConstants(1), wrappedConstants(2), wrappedConstants(3), wrappedConstants(4), wrappedConstants(5), wrappedConstants(6), wrappedConstants(7)) case _ => throw new 
UnsupportedOperationException("getRequiredFiles not yet supported for StdUDF" + _expressions.length) @@ -108,8 +107,8 @@ abstract class StdUdfWrapper(_expressions: Seq[Expression]) extends Expression } } // scalastyle:on magic.number - private final def checkNullsAndWrapConstants(): Array[StdData] = { - val wrappedConstants = new Array[StdData](_expressions.length) + private final def checkNullsAndWrapConstants(): Array[Object] = { + val wrappedConstants = new Array[Object](_expressions.length) for (i <- _expressions.indices) { val constantValue = if (_expressions(i).foldable) _expressions(i).eval() else null if (!_nullableArguments(i) && _expressions(i).foldable && constantValue == null) { @@ -135,42 +134,41 @@ abstract class StdUdfWrapper(_expressions: Seq[Expression]) extends Expression } val stdResult = wrappedArguments.length match { case 0 => - _stdUdf.asInstanceOf[StdUDF0[StdData]].eval() + _stdUdf.asInstanceOf[StdUDF0[Object]].eval() case 1 => - _stdUdf.asInstanceOf[StdUDF1[StdData, StdData]].eval(wrappedArguments(0)) + _stdUdf.asInstanceOf[StdUDF1[Object, Object]].eval(wrappedArguments(0)) case 2 => - _stdUdf.asInstanceOf[StdUDF2[StdData, StdData, StdData]].eval(wrappedArguments(0), wrappedArguments(1)) + _stdUdf.asInstanceOf[StdUDF2[Object, Object, Object]].eval(wrappedArguments(0), wrappedArguments(1)) case 3 => - _stdUdf.asInstanceOf[StdUDF3[StdData, StdData, StdData, StdData]].eval(wrappedArguments(0), wrappedArguments(1), + _stdUdf.asInstanceOf[StdUDF3[Object, Object, Object, Object]].eval(wrappedArguments(0), wrappedArguments(1), wrappedArguments(2)) case 4 => - _stdUdf.asInstanceOf[StdUDF4[StdData, StdData, StdData, StdData, StdData]].eval(wrappedArguments(0), + _stdUdf.asInstanceOf[StdUDF4[Object, Object, Object, Object, Object]].eval(wrappedArguments(0), wrappedArguments(1), wrappedArguments(2), wrappedArguments(3)) case 5 => - _stdUdf.asInstanceOf[StdUDF5[StdData, StdData, StdData, StdData, StdData, StdData]].eval(wrappedArguments(0), + 
_stdUdf.asInstanceOf[StdUDF5[Object, Object, Object, Object, Object, Object]].eval(wrappedArguments(0), wrappedArguments(1), wrappedArguments(2), wrappedArguments(3), wrappedArguments(4)) case 6 => - _stdUdf.asInstanceOf[StdUDF6[StdData, StdData, StdData, StdData, StdData, StdData, StdData]].eval(wrappedArguments(0), + _stdUdf.asInstanceOf[StdUDF6[Object, Object, Object, Object, Object, Object, Object]].eval(wrappedArguments(0), wrappedArguments(1), wrappedArguments(2), wrappedArguments(3), wrappedArguments(4), wrappedArguments(5)) case 7 => - _stdUdf.asInstanceOf[StdUDF7[StdData, StdData, StdData, StdData, StdData, StdData, StdData, StdData]].eval(wrappedArguments(0), + _stdUdf.asInstanceOf[StdUDF7[Object, Object, Object, Object, Object, Object, Object, Object]].eval(wrappedArguments(0), wrappedArguments(1), wrappedArguments(2), wrappedArguments(3), wrappedArguments(4), wrappedArguments(5), wrappedArguments(6)) case 8 => - _stdUdf.asInstanceOf[StdUDF8[StdData, StdData, StdData, StdData, StdData, StdData, StdData, StdData, StdData]].eval(wrappedArguments(0), + _stdUdf.asInstanceOf[StdUDF8[Object, Object, Object, Object, Object, Object, Object, Object, Object]].eval(wrappedArguments(0), wrappedArguments(1), wrappedArguments(2), wrappedArguments(3), wrappedArguments(4), wrappedArguments(5), wrappedArguments(6), wrappedArguments(7)) case _ => throw new UnsupportedOperationException("eval not yet supported for StdUDF" + _expressions.length) } - if (stdResult == null) null else stdResult.asInstanceOf[PlatformData].getUnderlyingData + SparkWrapper.getPlatformData(stdResult) } } // scalastyle:on magic.number - - private final def checkNullsAndWrapArguments(input: InternalRow): Array[StdData] = { - val wrappedArguments = new Array[StdData](_expressions.length) + private final def checkNullsAndWrapArguments(input: InternalRow): Array[Object] = { + val wrappedArguments = new Array[Object](_expressions.length) for (i <- _expressions.indices) { val evaluatedExpression = 
_expressions(i).eval(input) if(!_nullableArguments(i) && evaluatedExpression == null) { diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkArray.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkArrayData.scala similarity index 75% rename from transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkArray.scala rename to transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkArrayData.scala index 9fe91cab..e98ef069 100644 --- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkArray.scala +++ b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkArrayData.scala @@ -7,20 +7,19 @@ package com.linkedin.transport.spark.data import java.util -import com.linkedin.transport.api.data.{PlatformData, StdArray, StdData} +import com.linkedin.transport.api.data.{ArrayData, PlatformData} import com.linkedin.transport.spark.SparkWrapper -import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.{ArrayType, DataType} import scala.collection.mutable.ArrayBuffer -case class SparkArray(private var _arrayData: ArrayData, - private val _arrayType: DataType) extends StdArray with PlatformData { +case class SparkArrayData[E](private var _arrayData: org.apache.spark.sql.catalyst.util.ArrayData, + private val _arrayType: DataType) extends ArrayData[E] with PlatformData { private val _elementType = _arrayType.asInstanceOf[ArrayType].elementType private var _mutableBuffer: ArrayBuffer[Any] = if (_arrayData == null) createMutableArray() else null - override def add(e: StdData): Unit = { + override def add(e: E): Unit = { // Once add is called, we cannot use Spark's readonly ArrayData API // we have to add elements to a mutable buffer and start using that // always instead of the readonly stdType @@ -29,7 +28,7 @@ case class SparkArray(private var _arrayData: ArrayData, 
_mutableBuffer = createMutableArray() } // TODO: Does not support inserting nulls. Should we? - _mutableBuffer.append(e.asInstanceOf[PlatformData].getUnderlyingData) + _mutableBuffer.append(SparkWrapper.getPlatformData(e.asInstanceOf[Object])) } private def createMutableArray(): ArrayBuffer[Any] = { @@ -47,20 +46,20 @@ case class SparkArray(private var _arrayData: ArrayData, if (_mutableBuffer == null) { _arrayData } else { - ArrayData.toArrayData(_mutableBuffer) + org.apache.spark.sql.catalyst.util.ArrayData.toArrayData(_mutableBuffer) } } override def setUnderlyingData(value: scala.Any): Unit = { - _arrayData = value.asInstanceOf[ArrayData] + _arrayData = value.asInstanceOf[org.apache.spark.sql.catalyst.util.ArrayData] _mutableBuffer = null } - override def iterator(): util.Iterator[StdData] = { - new util.Iterator[StdData] { + override def iterator(): util.Iterator[E] = { + new util.Iterator[E] { private var idx = 0 - override def next(): StdData = { + override def next(): E = { val e = get(idx) idx += 1 e @@ -78,11 +77,11 @@ case class SparkArray(private var _arrayData: ArrayData, } } - override def get(idx: Int): StdData = { + override def get(idx: Int): E = { if (_mutableBuffer == null) { - SparkWrapper.createStdData(_arrayData.get(idx, _elementType), _elementType) + SparkWrapper.createStdData(_arrayData.get(idx, _elementType), _elementType).asInstanceOf[E] } else { - SparkWrapper.createStdData(_mutableBuffer(idx), _elementType) + SparkWrapper.createStdData(_mutableBuffer(idx), _elementType).asInstanceOf[E] } } } diff --git a/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkMapData.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkMapData.scala new file mode 100644 index 00000000..cd9679c8 --- /dev/null +++ b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkMapData.scala @@ -0,0 +1,106 @@ +/** + * Copyright 2018 LinkedIn Corporation. 
All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.spark.data + +import java.util + +import com.linkedin.transport.api.data.{MapData, PlatformData} +import com.linkedin.transport.spark.SparkWrapper +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData +import org.apache.spark.sql.types.MapType + +import scala.collection.mutable.Map + + +case class SparkMapData[K, V](private var _mapData: org.apache.spark.sql.catalyst.util.MapData, + private val _mapType: MapType) extends MapData[K, V] with PlatformData { + + private val _keyType = _mapType.keyType + private val _valueType = _mapType.valueType + private var _mutableMap: Map[Any, Any] = if (_mapData == null) createMutableMap() else null + + override def put(key: K, value: V): Unit = { + // TODO: Does not support inserting nulls. Should we? + if (_mutableMap == null) { + _mutableMap = createMutableMap() + } + _mutableMap.put( + SparkWrapper.getPlatformData(key.asInstanceOf[Object]), + SparkWrapper.getPlatformData(value.asInstanceOf[Object]) + ) + } + + override def keySet(): util.Set[K] = { + new util.AbstractSet[K] { + + override def iterator(): util.Iterator[K] = new util.Iterator[K] { + private val keysIterator = if (_mutableMap == null) _mapData.keyArray().array.iterator else _mutableMap.keysIterator + + override def next(): K = SparkWrapper.createStdData(keysIterator.next(), _keyType).asInstanceOf[K] + + override def hasNext: Boolean = keysIterator.hasNext + } + + override def size(): Int = SparkMapData.this.size() + } + } + + override def size(): Int = { + if (_mutableMap == null) { + _mapData.numElements() + } else { + _mutableMap.size + } + } + + override def values(): util.Collection[V] = { + new util.AbstractCollection[V] { + + override def iterator(): util.Iterator[V] = new util.Iterator[V] { + private val valueIterator = if (_mutableMap == null) _mapData.valueArray().array.iterator else 
_mutableMap.valuesIterator + + override def next(): V = SparkWrapper.createStdData(valueIterator.next(), _valueType).asInstanceOf[V] + + override def hasNext: Boolean = valueIterator.hasNext + } + + override def size(): Int = SparkMapData.this.size() + } + } + + override def containsKey(key: K): Boolean = get(key) != null + + override def get(key: K): V = { + // Spark's complex data types (MapData, ArrayData, InternalRow) do not implement equals/hashcode + // If the key is of the above complex data types, get() will return null + if (_mutableMap == null) { + _mutableMap = createMutableMap() + } + SparkWrapper.createStdData(_mutableMap.get(SparkWrapper.getPlatformData(key.asInstanceOf[Object])).orNull, _valueType) + .asInstanceOf[V] + } + + private def createMutableMap(): Map[Any, Any] = { + val mutableMap = Map.empty[Any, Any] + if (_mapData != null) { + _mapData.foreach(_keyType, _valueType, (k, v) => mutableMap.put(k, v)) + } + mutableMap + } + + override def getUnderlyingData: AnyRef = { + if (_mutableMap == null) { + _mapData + } else { + ArrayBasedMapData(_mutableMap) + } + } + + override def setUnderlyingData(value: scala.Any): Unit = { + _mapData = value.asInstanceOf[org.apache.spark.sql.catalyst.util.MapData] + _mutableMap = null + } +} diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkStruct.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkRowData.scala similarity index 74% rename from transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkStruct.scala rename to transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkRowData.scala index ba432905..9cbc883e 100644 --- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkStruct.scala +++ b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/data/SparkRowData.scala @@ -7,7 +7,7 @@ package com.linkedin.transport.spark.data 
import java.util.{List => JavaList} -import com.linkedin.transport.api.data.{PlatformData, StdData, StdStruct} +import com.linkedin.transport.api.data.{PlatformData, RowData} import com.linkedin.transport.spark.SparkWrapper import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.StructType @@ -16,14 +16,14 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -case class SparkStruct(private var _row: InternalRow, - private val _structType: StructType) extends StdStruct with PlatformData { +case class SparkRowData(private var _row: InternalRow, + private val _structType: StructType) extends RowData with PlatformData { private var _mutableBuffer: ArrayBuffer[Any] = if (_row == null) createMutableStruct() else null - override def getField(name: String): StdData = getField(_structType.fieldIndex(name)) + override def getField(name: String): Object = getField(_structType.fieldIndex(name)) - override def getField(index: Int): StdData = { + override def getField(index: Int): Object = { val fieldDataType = _structType(index).dataType if (_mutableBuffer == null) { SparkWrapper.createStdData(_row.get(index, fieldDataType), fieldDataType) @@ -32,15 +32,15 @@ case class SparkStruct(private var _row: InternalRow, } } - override def setField(name: String, value: StdData): Unit = { + override def setField(name: String, value: Object): Unit = { setField(_structType.fieldIndex(name), value) } - override def setField(index: Int, value: StdData): Unit = { + override def setField(index: Int, value: Object): Unit = { if (_mutableBuffer == null) { _mutableBuffer = createMutableStruct() } - _mutableBuffer(index) = value.asInstanceOf[PlatformData].getUnderlyingData + _mutableBuffer(index) = SparkWrapper.getPlatformData(value) } private def createMutableStruct() = { @@ -51,7 +51,7 @@ case class SparkStruct(private var _row: InternalRow, } } - override def fields(): JavaList[StdData] = { + override def fields(): JavaList[Object] = 
{ _structType.indices.map(getField).asJava } diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/types/SparkTypes.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/types/SparkTypes.scala similarity index 96% rename from transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/types/SparkTypes.scala rename to transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/types/SparkTypes.scala index 45fdc5c5..554a282d 100644 --- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/types/SparkTypes.scala +++ b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/types/SparkTypes.scala @@ -70,7 +70,7 @@ case class SparkMapType(mapType: MapType) extends StdMapType { override def valueType(): StdType = SparkWrapper.createStdType(mapType.valueType) } -case class SparkStructType(structType: StructType) extends StdStructType { +case class SparkRowType(structType: StructType) extends RowType { override def underlyingType(): DataType = structType diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/typesystem/SparkBoundVariables.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/typesystem/SparkBoundVariables.scala similarity index 100% rename from transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/typesystem/SparkBoundVariables.scala rename to transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/typesystem/SparkBoundVariables.scala diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/typesystem/SparkTypeFactory.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/typesystem/SparkTypeFactory.scala similarity index 100% rename from transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/typesystem/SparkTypeFactory.scala rename to 
transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/typesystem/SparkTypeFactory.scala diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/typesystem/SparkTypeInference.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/typesystem/SparkTypeInference.scala similarity index 100% rename from transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/typesystem/SparkTypeInference.scala rename to transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/typesystem/SparkTypeInference.scala diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/typesystem/SparkTypeSystem.scala b/transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/typesystem/SparkTypeSystem.scala similarity index 100% rename from transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/typesystem/SparkTypeSystem.scala rename to transportable-udfs-spark_2.11/src/main/scala/com/linkedin/transport/spark/typesystem/SparkTypeSystem.scala diff --git a/transportable-udfs-spark/src/main/scala/org/apache/spark/sql/StdUDFUtils.scala b/transportable-udfs-spark_2.11/src/main/scala/org/apache/spark/sql/StdUDFUtils.scala similarity index 100% rename from transportable-udfs-spark/src/main/scala/org/apache/spark/sql/StdUDFUtils.scala rename to transportable-udfs-spark_2.11/src/main/scala/org/apache/spark/sql/StdUDFUtils.scala diff --git a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/TestSparkFactory.scala b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/TestSparkFactory.scala similarity index 88% rename from transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/TestSparkFactory.scala rename to transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/TestSparkFactory.scala index e9c6304a..20221e47 100644 --- 
a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/TestSparkFactory.scala +++ b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/TestSparkFactory.scala @@ -23,18 +23,6 @@ class TestSparkFactory { val typeFactory: SparkTypeFactory = new SparkTypeFactory val stdFactory = new SparkFactory(new SparkBoundVariables) - @Test - def testCreatePrimitives(): Unit = { - assertEquals(stdFactory.createInteger(1).get(), 1) - assertEquals(stdFactory.createLong(1L).get(), 1L) - assertEquals(stdFactory.createBoolean(true).get(), true) - assertEquals(stdFactory.createString("").get(), "") - assertEquals(stdFactory.createFloat(2.0f).get(), 2.0f) - assertEquals(stdFactory.createDouble(3.0).get(), 3.0) - val byteArray = "foo".getBytes(Charset.forName("UTF-8")) - assertEquals(stdFactory.createBinary(ByteBuffer.wrap(byteArray)).get().array(), byteArray) - } - @Test def testCreateArray(): Unit = { var stdArray = stdFactory.createArray(stdFactory.createStdType("array(integer)")) diff --git a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/common/AssertSparkExpression.scala b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/common/AssertSparkExpression.scala similarity index 100% rename from transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/common/AssertSparkExpression.scala rename to transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/common/AssertSparkExpression.scala diff --git a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/data/TestSparkArray.scala b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkArray.scala similarity index 78% rename from transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/data/TestSparkArray.scala rename to transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkArray.scala index 00d70d88..dfc024ac 100644 --- 
a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/data/TestSparkArray.scala +++ b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkArray.scala @@ -5,7 +5,8 @@ */ package com.linkedin.transport.spark.data -import com.linkedin.transport.api.data.{PlatformData, StdArray} +import com.linkedin.transport.api.data +import com.linkedin.transport.api.data.{PlatformData} import com.linkedin.transport.spark.{SparkFactory, SparkWrapper} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.{ArrayType, DataTypes} @@ -20,35 +21,33 @@ class TestSparkArray { @Test def testCreateSparkArray(): Unit = { - val stdArray = SparkWrapper.createStdData(arrayData, arrayType).asInstanceOf[StdArray] + val stdArray = SparkWrapper.createStdData(arrayData, arrayType).asInstanceOf[data.ArrayData[Integer]] assertEquals(stdArray.size(), arrayData.numElements()) assertSame(stdArray.asInstanceOf[PlatformData].getUnderlyingData, arrayData) } @Test def testSparkArrayGet(): Unit = { - val stdArray = SparkWrapper.createStdData(arrayData, arrayType).asInstanceOf[StdArray] + val stdArray = SparkWrapper.createStdData(arrayData, arrayType).asInstanceOf[data.ArrayData[Integer]] (0 until stdArray.size).foreach(idx => { - assertEquals(stdArray.get(idx).asInstanceOf[SparkInteger].get(), idx) + assertEquals(stdArray.get(idx), idx) }) } @Test def testSparkArrayAdd(): Unit = { - val stdArray = SparkWrapper.createStdData(arrayData, arrayType).asInstanceOf[StdArray] - val insert = stdFactory.createInteger(5) // scalastyle:ignore magic.number - stdArray.add(insert) + val stdArray = SparkWrapper.createStdData(arrayData, arrayType).asInstanceOf[data.ArrayData[Integer]] + stdArray.add(5) // Since original ArrayData is immutable, a mutable ArrayBuffer should be created and set as the underlying object assertNotSame(stdArray.asInstanceOf[PlatformData].getUnderlyingData, arrayData) assertEquals(stdArray.size(), 
arrayData.numElements() + 1) - assertEquals(stdArray.get(stdArray.size() - 1), insert) + assertEquals(stdArray.get(stdArray.size() - 1), 5) } @Test def testSparkArrayMutabilityReset(): Unit = { - val stdArray = SparkWrapper.createStdData(arrayData, arrayType).asInstanceOf[StdArray] - val insert = stdFactory.createInteger(5) // scalastyle:ignore magic.number - stdArray.add(insert) + val stdArray = SparkWrapper.createStdData(arrayData, arrayType).asInstanceOf[data.ArrayData[Integer]] + stdArray.add(5) stdArray.asInstanceOf[PlatformData].setUnderlyingData(arrayData) // After underlying data is explicitly set, mutuable buffer should be removed assertSame(stdArray.asInstanceOf[PlatformData].getUnderlyingData, arrayData) diff --git a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/data/TestSparkMap.scala b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkMap.scala similarity index 67% rename from transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/data/TestSparkMap.scala rename to transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkMap.scala index 608c027f..12675eb1 100644 --- a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/data/TestSparkMap.scala +++ b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkMap.scala @@ -5,7 +5,7 @@ */ package com.linkedin.transport.spark.data -import com.linkedin.transport.api.data.{PlatformData, StdMap, StdString} +import com.linkedin.transport.api.data.{MapData, PlatformData} import com.linkedin.transport.spark.{SparkFactory, SparkWrapper} import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.types.{DataTypes, MapType} @@ -23,58 +23,54 @@ class TestSparkMap { @Test def testCreateSparkMap(): Unit = { - val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[StdMap] + val stdMap = SparkWrapper.createStdData(mapData, 
mapType).asInstanceOf[MapData[String, String]] assertEquals(stdMap.size(), mapData.numElements()) assertSame(stdMap.asInstanceOf[PlatformData].getUnderlyingData, mapData) } @Test def testSparkMapKeySet(): Unit = { - val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[StdMap] - assertEqualsNoOrder(stdMap.keySet().toArray, mapData.keyArray.array.map(s => stdFactory.createString(s.toString))) + val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[MapData[String, String]] + assertEqualsNoOrder(stdMap.keySet().toArray, mapData.keyArray.array.map(s => s.toString)) } @Test def testSparkMapValues(): Unit = { - val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[StdMap] - assertEqualsNoOrder(stdMap.values().toArray, mapData.valueArray.array.map(s => stdFactory.createString(s.toString))) + val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[MapData[String, String]] + assertEqualsNoOrder(stdMap.values().toArray, mapData.valueArray.array.map(s => s.toString)) } @Test def testSparkMapGet(): Unit = { - val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[StdMap] + val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[MapData[String, String]] mapData.keyArray.foreach(mapType.keyType, (idx, key) => { - assertEquals(stdMap.get(stdFactory.createString(key.toString)).asInstanceOf[StdString].get, + assertEquals(stdMap.get(key.toString), mapData.valueArray.array(idx).toString) }) - assertEquals(stdMap.containsKey(stdFactory.createString("nonExistentKey")), false) - // Even for a get in SparkMap we create mutable Map since Spark's Impl is based of arrays. So underlying object should change + assertEquals(stdMap.containsKey("nonExistentKey"), false) + // Even for a get in SparkMapData we create mutable Map since Spark's Impl is based of arrays. 
So underlying object should change assertNotSame(stdMap.asInstanceOf[PlatformData].getUnderlyingData, mapData) } @Test def testSparkMapContainsKey(): Unit = { - val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[StdMap] - assertEquals(stdMap.containsKey(stdFactory.createString("k3")), true) - assertEquals(stdMap.containsKey(stdFactory.createString("k4")), false) + val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[MapData[String, String]] + assertEquals(stdMap.containsKey("k3"), true) + assertEquals(stdMap.containsKey("k4"), false) } @Test def testSparkMapPut(): Unit = { - val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[StdMap] - val insertKey = stdFactory.createString("k4") - val insertVal = stdFactory.createString("v4") - stdMap.put(insertKey, insertVal) + val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[MapData[String, String]] + stdMap.put("k4", "v4") assertEquals(stdMap.size(), mapData.numElements() + 1) - assertEquals(stdMap.get(stdFactory.createString("k4")), insertVal) + assertEquals(stdMap.get("k4"), "v4") } @Test def testSparkMapMutabilityReset(): Unit = { - val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[StdMap] - val insertKey = stdFactory.createString("k4") - val insertVal = stdFactory.createString("v4") - stdMap.put(insertKey, insertVal) + val stdMap = SparkWrapper.createStdData(mapData, mapType).asInstanceOf[MapData[String, String]] + stdMap.put("k4", "v4") stdMap.asInstanceOf[PlatformData].setUnderlyingData(mapData) // After underlying data is explicitly set, mutuable map should be removed assertSame(stdMap.asInstanceOf[PlatformData].getUnderlyingData, mapData) diff --git a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/data/TestSparkStruct.scala b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkRowData.scala similarity index 70% rename from 
transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/data/TestSparkStruct.scala rename to transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkRowData.scala index 9a911af2..df7def17 100644 --- a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/data/TestSparkStruct.scala +++ b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/data/TestSparkRowData.scala @@ -5,7 +5,7 @@ */ package com.linkedin.transport.spark.data -import com.linkedin.transport.api.data.{PlatformData, StdStruct} +import com.linkedin.transport.api.data.{PlatformData, RowData} import com.linkedin.transport.spark.{SparkFactory, SparkWrapper} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.ArrayData @@ -14,7 +14,7 @@ import org.apache.spark.unsafe.types.UTF8String import org.testng.Assert.{assertEquals, assertNotSame, assertSame} import org.testng.annotations.Test -class TestSparkStruct { +class TestSparkRowData { val stdFactory = new SparkFactory(null) val dataArray = Array(UTF8String.fromString("str1"), 0, 2L, false, ArrayData.toArrayData(Array.range(0, 5))) // scalastyle:ignore magic.number val fieldNames = Array("strField", "intField", "longField", "boolField", "arrField") @@ -25,41 +25,41 @@ class TestSparkStruct { @Test def testCreateSparkStruct(): Unit = { - val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[StdStruct] + val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[RowData] assertSame(stdStruct.asInstanceOf[PlatformData].getUnderlyingData, structData) } @Test def testSparkStructGetField(): Unit = { - val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[StdStruct] + val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[RowData] dataArray.indices.foreach(idx => { - assertEquals(stdStruct.getField(idx).asInstanceOf[PlatformData].getUnderlyingData, 
dataArray(idx)) - assertEquals(stdStruct.getField(fieldNames(idx)).asInstanceOf[PlatformData].getUnderlyingData, dataArray(idx)) + assertEquals(SparkWrapper.getPlatformData(stdStruct.getField(idx)), dataArray(idx)) + assertEquals(SparkWrapper.getPlatformData(stdStruct.getField(fieldNames(idx))), dataArray(idx)) }) } @Test def testSparkStructFields(): Unit = { - val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[StdStruct] + val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[RowData] assertEquals(stdStruct.fields().size(), structData.numFields) - assertEquals(stdStruct.fields().toArray.map(f => f.asInstanceOf[PlatformData].getUnderlyingData), dataArray) + assertEquals(stdStruct.fields().toArray.map(f => SparkWrapper.getPlatformData(f)), dataArray) } @Test def testSparkStructSetField(): Unit = { - val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[StdStruct] - stdStruct.setField(1, stdFactory.createInteger(1)) - assertEquals(stdStruct.getField(1).asInstanceOf[PlatformData].getUnderlyingData, 1) - stdStruct.setField(fieldNames(2), stdFactory.createLong(5)) // scalastyle:ignore magic.number - assertEquals(stdStruct.getField(fieldNames(2)).asInstanceOf[PlatformData].getUnderlyingData, 5L) // scalastyle:ignore magic.number + val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[RowData] + stdStruct.setField(1, 1) + assertEquals(stdStruct.getField(1), 1) + stdStruct.setField(fieldNames(2), 5L) // scalastyle:ignore magic.number + assertEquals(stdStruct.getField(fieldNames(2)), 5L) // scalastyle:ignore magic.number // Since original InternalRow is immutable, a mutable ArrayBuffer should be created and set as the underlying object assertNotSame(stdStruct.asInstanceOf[PlatformData].getUnderlyingData, structData) } @Test def testSparkStructMutabilityReset(): Unit = { - val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[StdStruct] - 
stdStruct.setField(1, stdFactory.createInteger(1)) + val stdStruct = SparkWrapper.createStdData(structData, structType).asInstanceOf[RowData] + stdStruct.setField(1, 1) stdStruct.asInstanceOf[PlatformData].setUnderlyingData(structData) // After underlying data is explicitly set, mutable buffer should be removed assertSame(stdStruct.asInstanceOf[PlatformData].getUnderlyingData, structData) diff --git a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/typesystem/TestSparkBoundVariables.scala b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/typesystem/TestSparkBoundVariables.scala similarity index 100% rename from transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/typesystem/TestSparkBoundVariables.scala rename to transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/typesystem/TestSparkBoundVariables.scala diff --git a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/typesystem/TestSparkTypeFactory.scala b/transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/typesystem/TestSparkTypeFactory.scala similarity index 100% rename from transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/typesystem/TestSparkTypeFactory.scala rename to transportable-udfs-spark_2.11/src/test/scala/com/linkedin/transport/spark/typesystem/TestSparkTypeFactory.scala diff --git a/transportable-udfs-spark_2.12/build.gradle b/transportable-udfs-spark_2.12/build.gradle new file mode 100644 index 00000000..d2ea86ea --- /dev/null +++ b/transportable-udfs-spark_2.12/build.gradle @@ -0,0 +1,52 @@ +apply plugin: 'scala' + +sourceSets { + main { + scala { + srcDirs = project(':transportable-udfs-spark_2.11').sourceSets.main.scala.srcDirs + } + } + test { + scala { + srcDirs = project(':transportable-udfs-spark_2.11').sourceSets.test.scala.srcDirs + } + } +} + +dependencies { + compile project(':transportable-udfs-api') + compile 
project(':transportable-udfs-type-system') + compile project(':transportable-udfs-utils') + // For spark-core and spark-sql dependencies, we exclude transitive dependency on 'jackson-module-paranamer', + // since this is required for the LinkedIn version of spark-core and spark-sql. + compileOnly(group: project.ext.'spark-group', name: 'spark-core_2.12', version: project.ext.'spark3-version') { + exclude module: 'jackson-module-paranamer' + } + compileOnly(group: project.ext.'spark-group', name: 'spark-sql_2.12', version: project.ext.'spark3-version') { + exclude module: 'jackson-module-paranamer' + } + compileOnly('com.fasterxml.jackson.module:jackson-module-paranamer:2.6.7') + testCompile(group: project.ext.'spark-group', name: 'spark-core_2.12', version: project.ext.'spark3-version') { + exclude module: 'jackson-module-paranamer' + } + testCompile(group: project.ext.'spark-group', name: 'spark-sql_2.12', version: project.ext.'spark3-version') { + exclude module: 'jackson-module-paranamer' + } + testCompile('com.fasterxml.jackson.module:jackson-module-paranamer:2.6.7') + testCompile project(path: ':transportable-udfs-type-system', configuration: 'tests') +} + +task jarTests(type: Jar, dependsOn: testClasses) { + classifier = 'tests' + from sourceSets.test.output +} + +configurations { + tests { + extendsFrom testRuntime + } +} + +artifacts { + tests jarTests +} diff --git a/transportable-udfs-test/transportable-udfs-test-api/src/main/java/com/linkedin/transport/test/AbstractStdUDFTest.java b/transportable-udfs-test/transportable-udfs-test-api/src/main/java/com/linkedin/transport/test/AbstractStdUDFTest.java index 7dbc3d34..4e85393e 100644 --- a/transportable-udfs-test/transportable-udfs-test-api/src/main/java/com/linkedin/transport/test/AbstractStdUDFTest.java +++ b/transportable-udfs-test/transportable-udfs-test-api/src/main/java/com/linkedin/transport/test/AbstractStdUDFTest.java @@ -6,7 +6,9 @@ package com.linkedin.transport.test; import 
com.google.common.base.Preconditions; -import com.linkedin.transport.api.data.StdData; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.api.udf.StdUDF; import com.linkedin.transport.api.udf.TopLevelStdUDF; import com.linkedin.transport.test.spi.FunctionCall; @@ -26,15 +28,12 @@ * An abstract class to be extended by all test classes. This class contains helper methods to initialize the * {@link StdTester} and create input and output data for the test cases. * - * The mapping between a {@link StdData} to the corresponding Java type is given below: + * Primitive data is represented by primitive types when passed to the test cases. + * The mapping between container types to the corresponding Java type is given below: *
    - *
  • {@link com.linkedin.transport.api.data.StdInteger} = {@link Integer}
  • - *
  • {@link com.linkedin.transport.api.data.StdLong} = {@link Long}
  • - *
  • {@link com.linkedin.transport.api.data.StdBoolean} = {@link Boolean}
  • - *
  • {@link com.linkedin.transport.api.data.StdString} = {@link String}
  • - *
  • {@link com.linkedin.transport.api.data.StdArray} = Use {@link #array(Object...)} to create arrays
  • - *
  • {@link com.linkedin.transport.api.data.StdMap} = Use {@link #map(Object...)} to create maps
  • - *
  • {@link com.linkedin.transport.api.data.StdStruct} = Use {@link #row(Object...)} to create structs
  • + *
  • {@link ArrayData} = Use {@link #array(Object...)} to create arrays
  • + *
  • {@link MapData} = Use {@link #map(Object...)} to create maps
  • + *
  • {@link RowData} = Use {@link #row(Object...)} to create structs
  • *
* * diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericFactory.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericFactory.java index 58d6a921..d3d26338 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericFactory.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericFactory.java @@ -5,35 +5,19 @@ */ package com.linkedin.transport.test.generic; -import com.google.common.base.Preconditions; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdBoolean; -import com.linkedin.transport.api.data.StdBinary; -import com.linkedin.transport.api.data.StdDouble; -import com.linkedin.transport.api.data.StdFloat; -import com.linkedin.transport.api.data.StdInteger; -import com.linkedin.transport.api.data.StdLong; -import com.linkedin.transport.api.data.StdMap; -import com.linkedin.transport.api.data.StdString; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.test.generic.data.GenericArray; -import com.linkedin.transport.test.generic.data.GenericBoolean; -import com.linkedin.transport.test.generic.data.GenericBinary; -import com.linkedin.transport.test.generic.data.GenericDouble; -import com.linkedin.transport.test.generic.data.GenericFloat; -import com.linkedin.transport.test.generic.data.GenericInteger; -import com.linkedin.transport.test.generic.data.GenericLong; -import com.linkedin.transport.test.generic.data.GenericMap; -import com.linkedin.transport.test.generic.data.GenericString; 
+import com.linkedin.transport.test.generic.data.GenericArrayData; +import com.linkedin.transport.test.generic.data.GenericMapData; import com.linkedin.transport.test.generic.data.GenericStruct; import com.linkedin.transport.test.generic.typesystem.GenericTypeFactory; import com.linkedin.transport.test.spi.types.TestType; import com.linkedin.transport.test.spi.types.TestTypeFactory; import com.linkedin.transport.typesystem.AbstractBoundVariables; import com.linkedin.transport.typesystem.TypeSignature; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; @@ -50,69 +34,33 @@ public GenericFactory(AbstractBoundVariables boundVariables) { } @Override - public StdInteger createInteger(int value) { - return new GenericInteger(value); + public ArrayData createArray(StdType stdType, int expectedSize) { + return new GenericArrayData(new ArrayList<>(expectedSize), (TestType) stdType.underlyingType()); } @Override - public StdLong createLong(long value) { - return new GenericLong(value); - } - - @Override - public StdBoolean createBoolean(boolean value) { - return new GenericBoolean(value); - } - - @Override - public StdString createString(String value) { - Preconditions.checkNotNull(value, "Cannot create a null StdString"); - return new GenericString(value); - } - - @Override - public StdFloat createFloat(float value) { - return new GenericFloat(value); - } - - @Override - public StdDouble createDouble(double value) { - return new GenericDouble(value); - } - - @Override - public StdBinary createBinary(ByteBuffer value) { - return new GenericBinary(value); - } - - @Override - public StdArray createArray(StdType stdType, int expectedSize) { - return new GenericArray(new ArrayList<>(expectedSize), (TestType) stdType.underlyingType()); - } - - @Override - public StdArray createArray(StdType stdType) { + public ArrayData createArray(StdType stdType) { return createArray(stdType, 0); } @Override - public StdMap 
createMap(StdType stdType) { - return new GenericMap((TestType) stdType.underlyingType()); + public MapData createMap(StdType stdType) { + return new GenericMapData((TestType) stdType.underlyingType()); } @Override - public StdStruct createStruct(List fieldNames, List fieldTypes) { + public RowData createStruct(List fieldNames, List fieldTypes) { return new GenericStruct(TestTypeFactory.struct(fieldNames, fieldTypes.stream().map(x -> (TestType) x.underlyingType()).collect(Collectors.toList()))); } @Override - public StdStruct createStruct(List fieldTypes) { + public RowData createStruct(List fieldTypes) { return createStruct(null, fieldTypes); } @Override - public StdStruct createStruct(StdType stdType) { + public RowData createStruct(StdType stdType) { return new GenericStruct((TestType) stdType.underlyingType()); } diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericStdUDFWrapper.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericStdUDFWrapper.java index 2a345e3c..081e92a4 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericStdUDFWrapper.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericStdUDFWrapper.java @@ -7,7 +7,6 @@ import com.linkedin.transport.api.StdFactory; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; import com.linkedin.transport.api.udf.StdUDF; import com.linkedin.transport.api.udf.StdUDF0; import com.linkedin.transport.api.udf.StdUDF1; @@ -24,6 +23,7 @@ import com.linkedin.transport.utils.FileSystemUtils; import java.io.IOException; import java.lang.reflect.InvocationTargetException; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; @@ -42,7 +42,7 
@@ public class GenericStdUDFWrapper { protected boolean _requiredFilesProcessed; protected StdFactory _stdFactory; private boolean[] _nullableArguments; - private StdData[] _args; + private Object[] _args; private Class _topLevelUdfClass; private List> _stdUdfImplementations; private String[] _localFiles; @@ -83,12 +83,18 @@ protected boolean containsNullValuedNonNullableArgument(Object[] arguments) { return false; } - protected StdData wrap(Object argument, StdData stdData) { - if (argument != null) { - ((PlatformData) stdData).setUnderlyingData(argument); - return stdData; - } else { + protected Object wrap(Object argument, Object stdData) { + if (argument == null) { return null; + } else { + if (argument instanceof Integer || argument instanceof Long || argument instanceof Boolean + || argument instanceof String || argument instanceof Double || argument instanceof Float + || argument instanceof ByteBuffer) { + return argument; + } else { + ((PlatformData) stdData).setUnderlyingData(argument); + return stdData; + } } } @@ -107,26 +113,26 @@ protected Class getTopLevelUdfClass() { } protected void createStdData() { - _args = new StdData[_inputTypes.length]; + _args = new Object[_inputTypes.length]; for (int i = 0; i < _inputTypes.length; i++) { _args[i] = GenericWrapper.createStdData(null, _inputTypes[i]); } } - private StdData[] wrapArguments(Object[] arguments) { - return IntStream.range(0, _args.length).mapToObj(i -> wrap(arguments[i], _args[i])).toArray(StdData[]::new); + private Object[] wrapArguments(Object[] arguments) { + return IntStream.range(0, _args.length).mapToObj(i -> wrap(arguments[i], _args[i])).toArray(Object[]::new); } public Object evaluate(Object[] arguments) { if (containsNullValuedNonNullableArgument(arguments)) { return null; } - StdData[] args = wrapArguments(arguments); + Object[] args = wrapArguments(arguments); if (!_requiredFilesProcessed) { String[] requiredFiles = getRequiredFiles(args); processRequiredFiles(requiredFiles); } - 
StdData result; + Object result; switch (args.length) { case 0: result = ((StdUDF0) _stdUdf).eval(); @@ -158,10 +164,10 @@ public Object evaluate(Object[] arguments) { default: throw new UnsupportedOperationException("eval not yet supported for StdUDF" + args.length); } - return result == null ? null : ((PlatformData) result).getUnderlyingData(); + return GenericWrapper.getPlatformData(result); } - public String[] getRequiredFiles(StdData[] args) { + public String[] getRequiredFiles(Object[] args) { String[] requiredFiles; switch (args.length) { case 0: diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericWrapper.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericWrapper.java index 8754f0a8..e8707568 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericWrapper.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericWrapper.java @@ -5,17 +5,10 @@ */ package com.linkedin.transport.test.generic; -import com.linkedin.transport.api.data.StdData; +import com.linkedin.transport.api.data.PlatformData; import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.test.generic.data.GenericArray; -import com.linkedin.transport.test.generic.data.GenericBoolean; -import com.linkedin.transport.test.generic.data.GenericBinary; -import com.linkedin.transport.test.generic.data.GenericDouble; -import com.linkedin.transport.test.generic.data.GenericFloat; -import com.linkedin.transport.test.generic.data.GenericInteger; -import com.linkedin.transport.test.generic.data.GenericLong; -import com.linkedin.transport.test.generic.data.GenericMap; -import com.linkedin.transport.test.generic.data.GenericString; +import com.linkedin.transport.test.generic.data.GenericArrayData; +import 
com.linkedin.transport.test.generic.data.GenericMapData; import com.linkedin.transport.test.generic.data.GenericStruct; import com.linkedin.transport.test.spi.Row; import com.linkedin.transport.test.spi.types.ArrayTestType; @@ -40,27 +33,17 @@ public class GenericWrapper { private GenericWrapper() { } - public static StdData createStdData(Object data, TestType dataType) { + public static Object createStdData(Object data, TestType dataType) { if (dataType instanceof UnknownTestType) { return null; - } else if (dataType instanceof IntegerTestType) { - return new GenericInteger((Integer) data); - } else if (dataType instanceof LongTestType) { - return new GenericLong((Long) data); - } else if (dataType instanceof BooleanTestType) { - return new GenericBoolean((Boolean) data); - } else if (dataType instanceof StringTestType) { - return new GenericString((String) data); - } else if (dataType instanceof FloatTestType) { - return new GenericFloat((Float) data); - } else if (dataType instanceof DoubleTestType) { - return new GenericDouble((Double) data); - } else if (dataType instanceof BinaryTestType) { - return new GenericBinary((ByteBuffer) data); + } else if (dataType instanceof IntegerTestType || dataType instanceof LongTestType + || dataType instanceof FloatTestType || dataType instanceof DoubleTestType + || dataType instanceof BooleanTestType || dataType instanceof StringTestType || dataType instanceof BinaryTestType) { + return data; } else if (dataType instanceof ArrayTestType) { - return new GenericArray((List) data, dataType); + return new GenericArrayData((List) data, dataType); } else if (dataType instanceof MapTestType) { - return new GenericMap((Map) data, dataType); + return new GenericMapData((Map) data, dataType); } else if (dataType instanceof StructTestType) { return new GenericStruct((Row) data, dataType); } else { @@ -68,6 +51,20 @@ public static StdData createStdData(Object data, TestType dataType) { } } + public static Object getPlatformData(Object 
transportData) { + if (transportData == null) { + return null; + } else { + if (transportData instanceof Integer || transportData instanceof Long || transportData instanceof Float + || transportData instanceof Double || transportData instanceof Boolean || transportData instanceof ByteBuffer + || transportData instanceof String) { + return transportData; + } else { + return ((PlatformData) transportData).getUnderlyingData(); + } + } + } + public static StdType createStdType(TestType dataType) { return () -> dataType; } diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericArray.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericArrayData.java similarity index 65% rename from transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericArray.java rename to transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericArrayData.java index b2152c93..2aa85cb9 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericArray.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericArrayData.java @@ -5,9 +5,8 @@ */ package com.linkedin.transport.test.generic.data; +import com.linkedin.transport.api.data.ArrayData; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdData; import com.linkedin.transport.test.generic.GenericWrapper; import com.linkedin.transport.test.spi.types.ArrayTestType; import com.linkedin.transport.test.spi.types.TestType; @@ -15,12 +14,12 @@ import java.util.List; -public class GenericArray implements StdArray, PlatformData { +public class GenericArrayData 
implements ArrayData, PlatformData { private List _array; private TestType _elementType; - public GenericArray(List data, TestType type) { + public GenericArrayData(List data, TestType type) { _array = data; _elementType = ((ArrayTestType) type).getElementType(); } @@ -31,18 +30,18 @@ public int size() { } @Override - public StdData get(int idx) { - return GenericWrapper.createStdData(_array.get(idx), _elementType); + public E get(int idx) { + return (E) GenericWrapper.createStdData(_array.get(idx), _elementType); } @Override - public void add(StdData e) { - _array.add(((PlatformData) e).getUnderlyingData()); + public void add(E e) { + _array.add(GenericWrapper.getPlatformData(e)); } @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { private final Iterator _iterator = _array.iterator(); @Override @@ -51,8 +50,8 @@ public boolean hasNext() { } @Override - public StdData next() { - return GenericWrapper.createStdData(_iterator.next(), _elementType); + public E next() { + return (E) GenericWrapper.createStdData(_iterator.next(), _elementType); } }; } diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericBinary.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericBinary.java deleted file mode 100644 index 391a6752..00000000 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericBinary.java +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.test.generic.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdBinary; -import java.nio.ByteBuffer; - - -public class GenericBinary implements StdBinary, PlatformData { - - private ByteBuffer _byteBuffer; - - public GenericBinary(ByteBuffer aByteBuffer) { - _byteBuffer = aByteBuffer; - } - - @Override - public ByteBuffer get() { - return _byteBuffer; - } - - @Override - public Object getUnderlyingData() { - return _byteBuffer; - } - - @Override - public void setUnderlyingData(Object value) { - _byteBuffer = (ByteBuffer) value; - } -} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericBoolean.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericBoolean.java deleted file mode 100644 index e731a1e3..00000000 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericBoolean.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.test.generic.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdBoolean; - - -public class GenericBoolean implements StdBoolean, PlatformData { - private Boolean _boolean; - - public GenericBoolean(Boolean aBoolean) { - _boolean = aBoolean; - } - - @Override - public boolean get() { - return _boolean; - } - - @Override - public Object getUnderlyingData() { - return _boolean; - } - - @Override - public void setUnderlyingData(Object value) { - _boolean = (Boolean) value; - } -} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericDouble.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericDouble.java deleted file mode 100644 index 05ac39bf..00000000 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericDouble.java +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.test.generic.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdDouble; - - -public class GenericDouble implements StdDouble, PlatformData { - - private Double _double; - - public GenericDouble(Double aDouble) { - _double = aDouble; - } - - @Override - public double get() { - return _double; - } - - @Override - public Object getUnderlyingData() { - return _double; - } - - @Override - public void setUnderlyingData(Object value) { - _double = (Double) value; - } -} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericFloat.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericFloat.java deleted file mode 100644 index 806787de..00000000 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericFloat.java +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.test.generic.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdFloat; - - -public class GenericFloat implements StdFloat, PlatformData { - - private Float _float; - - public GenericFloat(Float aFloat) { - _float = aFloat; - } - - @Override - public float get() { - return _float; - } - - @Override - public Object getUnderlyingData() { - return _float; - } - - @Override - public void setUnderlyingData(Object value) { - _float = (Float) value; - } -} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericInteger.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericInteger.java deleted file mode 100644 index bcb1905c..00000000 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericInteger.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.test.generic.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdInteger; - - -public class GenericInteger implements StdInteger, PlatformData { - private Integer _integer; - - public GenericInteger(Integer integer) { - _integer = integer; - } - - @Override - public int get() { - return _integer; - } - - @Override - public Object getUnderlyingData() { - return _integer; - } - - @Override - public void setUnderlyingData(Object value) { - _integer = (Integer) value; - } -} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericLong.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericLong.java deleted file mode 100644 index 85e9dac6..00000000 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericLong.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.transport.test.generic.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdLong; - - -public class GenericLong implements StdLong, PlatformData { - private Long _long; - - public GenericLong(Long aLong) { - _long = aLong; - } - - @Override - public long get() { - return _long; - } - - @Override - public Object getUnderlyingData() { - return _long; - } - - @Override - public void setUnderlyingData(Object value) { - _long = (Long) value; - } -} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericMap.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericMapData.java similarity index 59% rename from transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericMap.java rename to transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericMapData.java index beeeb684..343fbac1 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericMap.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericMapData.java @@ -5,9 +5,8 @@ */ package com.linkedin.transport.test.generic.data; +import com.linkedin.transport.api.data.MapData; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdMap; import com.linkedin.transport.test.generic.GenericWrapper; import com.linkedin.transport.test.spi.types.MapTestType; import com.linkedin.transport.test.spi.types.TestType; @@ -20,19 +19,19 @@ import java.util.stream.Collectors; -public class GenericMap implements StdMap, PlatformData { +public class GenericMapData implements 
MapData, PlatformData { private Map _map; private final TestType _keyType; private final TestType _valueType; - public GenericMap(Map map, TestType type) { + public GenericMapData(Map map, TestType type) { _map = map; _keyType = ((MapTestType) type).getKeyType(); _valueType = ((MapTestType) type).getValueType(); } - public GenericMap(TestType type) { + public GenericMapData(TestType type) { this(new LinkedHashMap<>(), type); } @@ -52,21 +51,21 @@ public int size() { } @Override - public StdData get(StdData key) { - return GenericWrapper.createStdData(_map.get(((PlatformData) key).getUnderlyingData()), _valueType); + public V get(K key) { + return (V) GenericWrapper.createStdData(_map.get(GenericWrapper.getPlatformData(key)), _valueType); } @Override - public void put(StdData key, StdData value) { - _map.put(((PlatformData) key).getUnderlyingData(), ((PlatformData) value).getUnderlyingData()); + public void put(K key, V value) { + _map.put(GenericWrapper.getPlatformData(key), GenericWrapper.getPlatformData(value)); } @Override - public Set keySet() { - return new AbstractSet() { + public Set keySet() { + return new AbstractSet() { @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { Iterator keySet = _map.keySet().iterator(); @Override @@ -75,8 +74,8 @@ public boolean hasNext() { } @Override - public StdData next() { - return GenericWrapper.createStdData(keySet.next(), _keyType); + public K next() { + return (K) GenericWrapper.createStdData(keySet.next(), _keyType); } }; } @@ -89,12 +88,12 @@ public int size() { } @Override - public Collection values() { - return _map.values().stream().map(v -> GenericWrapper.createStdData(v, _valueType)).collect(Collectors.toList()); + public Collection values() { + return _map.values().stream().map(v -> (V) GenericWrapper.createStdData(v, _valueType)).collect(Collectors.toList()); } @Override - public boolean containsKey(StdData key) { - return 
_map.containsKey(((PlatformData) key).getUnderlyingData()); + public boolean containsKey(K key) { + return _map.containsKey(GenericWrapper.getPlatformData(key)); } } diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericString.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericString.java deleted file mode 100644 index 4bb1babb..00000000 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericString.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.transport.test.generic.data; - -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdString; - - -public class GenericString implements StdString, PlatformData { - private String _string; - - public GenericString(String string) { - _string = string; - } - - @Override - public String get() { - return _string; - } - - @Override - public Object getUnderlyingData() { - return _string; - } - - @Override - public void setUnderlyingData(Object value) { - _string = (String) value; - } -} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericStruct.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericStruct.java index e92b6043..333ddb32 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericStruct.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericStruct.java @@ -6,8 +6,7 @@ package 
com.linkedin.transport.test.generic.data; import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.test.generic.GenericWrapper; import com.linkedin.transport.test.spi.Row; import com.linkedin.transport.test.spi.types.StructTestType; @@ -19,7 +18,7 @@ import java.util.stream.IntStream; -public class GenericStruct implements StdStruct, PlatformData { +public class GenericStruct implements RowData, PlatformData { private Row _struct; private final List _fieldNames; @@ -46,27 +45,27 @@ public void setUnderlyingData(Object value) { } @Override - public StdData getField(int index) { + public Object getField(int index) { return GenericWrapper.createStdData(_struct.getFields().get(index), _fieldTypes.get(index)); } @Override - public StdData getField(String name) { + public Object getField(String name) { return getField(_fieldNames.indexOf(name)); } @Override - public void setField(int index, StdData value) { - _struct.getFields().set(index, ((PlatformData) value).getUnderlyingData()); + public void setField(int index, Object value) { + _struct.getFields().set(index, GenericWrapper.getPlatformData(value)); } @Override - public void setField(String name, StdData value) { + public void setField(String name, Object value) { setField(_fieldNames.indexOf(name), value); } @Override - public List fields() { + public List fields() { return IntStream.range(0, _struct.getFields().size()).mapToObj(this::getField).collect(Collectors.toList()); } } diff --git a/transportable-udfs-test/transportable-udfs-test-hive/build.gradle b/transportable-udfs-test/transportable-udfs-test-hive/build.gradle index 42006a0e..3c8db382 100644 --- a/transportable-udfs-test/transportable-udfs-test-hive/build.gradle +++ b/transportable-udfs-test/transportable-udfs-test-hive/build.gradle @@ -4,6 +4,13 @@ dependencies { compile 
project(':transportable-udfs-api') compile project(':transportable-udfs-hive') compile project(':transportable-udfs-test:transportable-udfs-test-api') - compile('org.apache.hive:hive-exec:1.2.2') - compile('org.apache.hive:hive-service:1.2.2') + compile ('org.apache.calcite:calcite-core:1.2.0-incubating') { + exclude group: 'org.pentaho', module: 'pentaho-aggdesigner-algorithm' + } + compile ('org.apache.hive:hive-exec:1.2.2') { + exclude group: 'org.apache.calcite' + } + compile ('org.apache.hive:hive-service:1.2.2') { + exclude group: 'org.apache.hive', module: 'hive-exec' + } } \ No newline at end of file diff --git a/transportable-udfs-test/transportable-udfs-test-hive/src/main/java/com/linkedin/transport/test/hive/udf/MapFromEntries.java b/transportable-udfs-test/transportable-udfs-test-hive/src/main/java/com/linkedin/transport/test/hive/udf/MapFromEntries.java index 4a415fd8..4867fcb4 100644 --- a/transportable-udfs-test/transportable-udfs-test-hive/src/main/java/com/linkedin/transport/test/hive/udf/MapFromEntries.java +++ b/transportable-udfs-test/transportable-udfs-test-hive/src/main/java/com/linkedin/transport/test/hive/udf/MapFromEntries.java @@ -7,10 +7,9 @@ import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdMap; -import com.linkedin.transport.api.data.StdStruct; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; +import com.linkedin.transport.api.data.RowData; import com.linkedin.transport.api.types.StdMapType; import com.linkedin.transport.api.udf.StdUDF1; import com.linkedin.transport.api.udf.TopLevelStdUDF; @@ -21,7 +20,7 @@ * Hive's built-in map() UDF cannot be used to create maps with complex key types. This UDF allows you to do so. 
* This is used inside {@link com.linkedin.transport.test.hive.HiveTester} to create arbitrary map objects */ -public class MapFromEntries extends StdUDF1 implements TopLevelStdUDF { +public class MapFromEntries extends StdUDF1 implements TopLevelStdUDF { private StdMapType _mapType; @@ -32,10 +31,10 @@ public void init(StdFactory stdFactory) { } @Override - public StdMap eval(StdArray entryArray) { - StdMap result = getStdFactory().createMap(_mapType); - for (StdData element : entryArray) { - StdStruct elementStruct = (StdStruct) element; + public MapData eval(ArrayData entryArray) { + MapData result = getStdFactory().createMap(_mapType); + for (Object element : entryArray) { + RowData elementStruct = (RowData) element; result.put(elementStruct.getField(0), elementStruct.getField(1)); } return result; diff --git a/transportable-udfs-test/transportable-udfs-test-presto/src/main/resources/META-INF/services/com.linkedin.transport.test.spi.StdTester b/transportable-udfs-test/transportable-udfs-test-presto/src/main/resources/META-INF/services/com.linkedin.transport.test.spi.StdTester deleted file mode 100644 index df711780..00000000 --- a/transportable-udfs-test/transportable-udfs-test-presto/src/main/resources/META-INF/services/com.linkedin.transport.test.spi.StdTester +++ /dev/null @@ -1 +0,0 @@ -com.linkedin.transport.test.presto.PrestoTester \ No newline at end of file diff --git a/transportable-udfs-test/transportable-udfs-test-spark/build.gradle b/transportable-udfs-test/transportable-udfs-test-spark_2.11/build.gradle similarity index 82% rename from transportable-udfs-test/transportable-udfs-test-spark/build.gradle rename to transportable-udfs-test/transportable-udfs-test-spark_2.11/build.gradle index d01ea53e..ebf44fa7 100644 --- a/transportable-udfs-test/transportable-udfs-test-spark/build.gradle +++ b/transportable-udfs-test/transportable-udfs-test-spark_2.11/build.gradle @@ -2,14 +2,14 @@ apply plugin: 'scala' dependencies { compile 
project(":transportable-udfs-api") - compile project(":transportable-udfs-spark") + compile project(":transportable-udfs-spark_2.11") compile project(":transportable-udfs-test:transportable-udfs-test-api") compile project(":transportable-udfs-test:transportable-udfs-test-spi") compile('com.databricks:spark-avro_2.11:4.0.0') - compile(group: project.ext.'spark-group', name: 'spark-core_2.11', version: project.ext.'spark-version') { + compile(group: project.ext.'spark-group', name: 'spark-core_2.11', version: project.ext.'spark2-version') { exclude module: 'jackson-module-paranamer' } - compile(group: project.ext.'spark-group', name: 'spark-sql_2.11', version: project.ext.'spark-version') { + compile(group: project.ext.'spark-group', name: 'spark-sql_2.11', version: project.ext.'spark2-version') { exclude module: 'jackson-module-paranamer' } compile('com.fasterxml.jackson.module:jackson-module-scala_2.11:2.7.9') diff --git a/transportable-udfs-test/transportable-udfs-test-spark/src/main/resources/META-INF/services/com.linkedin.transport.test.spi.StdTester b/transportable-udfs-test/transportable-udfs-test-spark_2.11/src/main/resources/META-INF/services/com.linkedin.transport.test.spi.StdTester similarity index 100% rename from transportable-udfs-test/transportable-udfs-test-spark/src/main/resources/META-INF/services/com.linkedin.transport.test.spi.StdTester rename to transportable-udfs-test/transportable-udfs-test-spark_2.11/src/main/resources/META-INF/services/com.linkedin.transport.test.spi.StdTester diff --git a/transportable-udfs-test/transportable-udfs-test-spark/src/main/scala/com/linkedin/transport/test/spark/SparkSqlFunctionCallGenerator.scala b/transportable-udfs-test/transportable-udfs-test-spark_2.11/src/main/scala/com/linkedin/transport/test/spark/SparkSqlFunctionCallGenerator.scala similarity index 100% rename from transportable-udfs-test/transportable-udfs-test-spark/src/main/scala/com/linkedin/transport/test/spark/SparkSqlFunctionCallGenerator.scala 
rename to transportable-udfs-test/transportable-udfs-test-spark_2.11/src/main/scala/com/linkedin/transport/test/spark/SparkSqlFunctionCallGenerator.scala diff --git a/transportable-udfs-test/transportable-udfs-test-spark/src/main/scala/com/linkedin/transport/test/spark/SparkTestStdUDFWrapper.scala b/transportable-udfs-test/transportable-udfs-test-spark_2.11/src/main/scala/com/linkedin/transport/test/spark/SparkTestStdUDFWrapper.scala similarity index 100% rename from transportable-udfs-test/transportable-udfs-test-spark/src/main/scala/com/linkedin/transport/test/spark/SparkTestStdUDFWrapper.scala rename to transportable-udfs-test/transportable-udfs-test-spark_2.11/src/main/scala/com/linkedin/transport/test/spark/SparkTestStdUDFWrapper.scala diff --git a/transportable-udfs-test/transportable-udfs-test-spark/src/main/scala/com/linkedin/transport/test/spark/SparkTester.scala b/transportable-udfs-test/transportable-udfs-test-spark_2.11/src/main/scala/com/linkedin/transport/test/spark/SparkTester.scala similarity index 100% rename from transportable-udfs-test/transportable-udfs-test-spark/src/main/scala/com/linkedin/transport/test/spark/SparkTester.scala rename to transportable-udfs-test/transportable-udfs-test-spark_2.11/src/main/scala/com/linkedin/transport/test/spark/SparkTester.scala diff --git a/transportable-udfs-test/transportable-udfs-test-spark/src/main/scala/com/linkedin/transport/test/spark/ToSparkTestOutputConverter.scala b/transportable-udfs-test/transportable-udfs-test-spark_2.11/src/main/scala/com/linkedin/transport/test/spark/ToSparkTestOutputConverter.scala similarity index 100% rename from transportable-udfs-test/transportable-udfs-test-spark/src/main/scala/com/linkedin/transport/test/spark/ToSparkTestOutputConverter.scala rename to transportable-udfs-test/transportable-udfs-test-spark_2.11/src/main/scala/com/linkedin/transport/test/spark/ToSparkTestOutputConverter.scala diff --git 
a/transportable-udfs-test/transportable-udfs-test-spark/src/main/scala/org/apache/spark/sql/StdUDFTestUtils.scala b/transportable-udfs-test/transportable-udfs-test-spark_2.11/src/main/scala/org/apache/spark/sql/StdUDFTestUtils.scala similarity index 100% rename from transportable-udfs-test/transportable-udfs-test-spark/src/main/scala/org/apache/spark/sql/StdUDFTestUtils.scala rename to transportable-udfs-test/transportable-udfs-test-spark_2.11/src/main/scala/org/apache/spark/sql/StdUDFTestUtils.scala diff --git a/transportable-udfs-test/transportable-udfs-test-spark_2.12/build.gradle b/transportable-udfs-test/transportable-udfs-test-spark_2.12/build.gradle new file mode 100644 index 00000000..c49372ad --- /dev/null +++ b/transportable-udfs-test/transportable-udfs-test-spark_2.12/build.gradle @@ -0,0 +1,31 @@ +apply plugin: 'scala' + +sourceSets { + main { + scala { + srcDirs = project(':transportable-udfs-test:transportable-udfs-test-spark_2.11').sourceSets.main.scala.srcDirs + } + resources { + srcDirs = project(':transportable-udfs-test:transportable-udfs-test-spark_2.11').sourceSets.main.resources.srcDirs + } + } +} + +dependencies { + compile project(":transportable-udfs-api") + compile project(":transportable-udfs-spark_2.12") + compile project(":transportable-udfs-test:transportable-udfs-test-api") + compile project(":transportable-udfs-test:transportable-udfs-test-spi") + compile(group: project.ext.'spark-group', name: 'spark-avro_2.12', version: project.ext.'spark3-version') { + exclude module: 'jackson-module-paranamer' + } + compile(group: project.ext.'spark-group', name: 'spark-core_2.12', version: project.ext.'spark3-version') { + exclude module: 'jackson-module-paranamer' + } + compile(group: project.ext.'spark-group', name: 'spark-sql_2.12', version: project.ext.'spark3-version') { + exclude module: 'jackson-module-paranamer' + } + compile('com.fasterxml.jackson.module:jackson-module-scala_2.12:2.7.9') + compile 'org.testng:testng:6.11' + compile 
'org.slf4j:slf4j-simple:1.7.25' +} diff --git a/transportable-udfs-test/transportable-udfs-test-presto/build.gradle b/transportable-udfs-test/transportable-udfs-test-trino/build.gradle similarity index 63% rename from transportable-udfs-test/transportable-udfs-test-presto/build.gradle rename to transportable-udfs-test/transportable-udfs-test-trino/build.gradle index 0e7a6615..54a4e101 100644 --- a/transportable-udfs-test/transportable-udfs-test-presto/build.gradle +++ b/transportable-udfs-test/transportable-udfs-test-trino/build.gradle @@ -1,19 +1,23 @@ apply plugin: 'java' +java { + toolchain.languageVersion.set(JavaLanguageVersion.of(11)) +} + dependencies { compile project(":transportable-udfs-api") compile project(":transportable-udfs-test:transportable-udfs-test-api") compile project(":transportable-udfs-test:transportable-udfs-test-spi") - compile project(":transportable-udfs-presto") + compile project(":transportable-udfs-trino") compile('com.google.guava:guava:24.1-jre') - compile(group:'io.prestosql', name: 'presto-main', version: project.ext.'presto-version') { + compile(group:'io.trino', name: 'trino-main', version: project.ext.'trino-version') { exclude 'group': 'com.google.collections', 'module': 'google-collections' } - compile(group:'io.prestosql', name: 'presto-main', version: project.ext.'presto-version', classifier: 'tests') { + compile(group:'io.trino', name: 'trino-main', version: project.ext.'trino-version', classifier: 'tests') { exclude 'group': 'com.google.collections', 'module': 'google-collections' } - compile('io.airlift:testing:0.142') - // The io.airlift.slice dependency below has to match its counterpart in presto-root's pom.xml file + compile('io.airlift:testing:202') + // The io.airlift.slice dependency below has to match its counterpart in trino-root's pom.xml file // If not specified, an older version is picked up transitively from another dependency compile(group: 'io.airlift', name: 'slice', version: 
project.ext.'airlift-slice-version') } \ No newline at end of file diff --git a/transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/ToPrestoTestOutputConverter.java b/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/ToTrinoTestOutputConverter.java similarity index 91% rename from transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/ToPrestoTestOutputConverter.java rename to transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/ToTrinoTestOutputConverter.java index 204168d6..2c6b63bd 100644 --- a/transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/ToPrestoTestOutputConverter.java +++ b/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/ToTrinoTestOutputConverter.java @@ -3,12 +3,12 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.test.presto; +package com.linkedin.transport.test.trino; import com.linkedin.transport.test.spi.Row; import com.linkedin.transport.test.spi.ToPlatformTestOutputConverter; import com.linkedin.transport.test.spi.types.TestType; -import io.prestosql.spi.type.SqlVarbinary; +import io.trino.spi.type.SqlVarbinary; import java.nio.ByteBuffer; import java.util.HashMap; import java.util.List; @@ -17,7 +17,7 @@ import java.util.stream.IntStream; -public class ToPrestoTestOutputConverter implements ToPlatformTestOutputConverter { +public class ToTrinoTestOutputConverter implements ToPlatformTestOutputConverter { /** * Returns a {@link List} for the given array while also converting nested elements diff --git a/transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/PrestoSqlFunctionCallGenerator.java b/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoSqlFunctionCallGenerator.java similarity index 94% rename from transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/PrestoSqlFunctionCallGenerator.java rename to transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoSqlFunctionCallGenerator.java index 01b26920..f6f7b582 100644 --- a/transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/PrestoSqlFunctionCallGenerator.java +++ b/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoSqlFunctionCallGenerator.java @@ -3,7 +3,7 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.test.presto; +package com.linkedin.transport.test.trino; import com.linkedin.transport.test.spi.Row; import com.linkedin.transport.test.spi.SqlFunctionCallGenerator; @@ -15,7 +15,7 @@ import java.util.stream.IntStream; -public class PrestoSqlFunctionCallGenerator implements SqlFunctionCallGenerator { +public class TrinoSqlFunctionCallGenerator implements SqlFunctionCallGenerator { @Override public String getFloatArgumentString(Float value) { diff --git a/transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/PrestoTestStdUDFWrapper.java b/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTestStdUDFWrapper.java similarity index 82% rename from transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/PrestoTestStdUDFWrapper.java rename to transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTestStdUDFWrapper.java index 7fa945b7..17f02eaf 100644 --- a/transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/PrestoTestStdUDFWrapper.java +++ b/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTestStdUDFWrapper.java @@ -3,10 +3,10 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.test.presto; +package com.linkedin.transport.test.trino; import com.linkedin.transport.api.udf.StdUDF; -import com.linkedin.transport.presto.StdUdfWrapper; +import com.linkedin.transport.trino.StdUdfWrapper; import java.lang.reflect.InvocationTargetException; @@ -16,11 +16,11 @@ * The wrapper's constructor here is parameterized so that the same wrapper can be used for all UDFs throughout the * test framework rather than generating UDF specific wrappers */ -public class PrestoTestStdUDFWrapper extends StdUdfWrapper { +public class TrinoTestStdUDFWrapper extends StdUdfWrapper { private final Class _udfClass; - public PrestoTestStdUDFWrapper(Class udfClass) { + public TrinoTestStdUDFWrapper(Class udfClass) { super(createInstance(udfClass)); _udfClass = udfClass; } diff --git a/transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/PrestoTester.java b/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTester.java similarity index 63% rename from transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/PrestoTester.java rename to transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTester.java index c03107e5..2abc0619 100644 --- a/transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/PrestoTester.java +++ b/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTester.java @@ -3,43 +3,48 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.test.presto; +package com.linkedin.transport.test.trino; -import io.prestosql.metadata.BoundVariables; -import io.prestosql.operator.scalar.AbstractTestFunctions; -import io.prestosql.spi.type.Type; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.trino.metadata.BoundSignature; +import io.trino.metadata.FunctionBinding; +import io.trino.metadata.FunctionId; +import io.trino.operator.scalar.AbstractTestFunctions; +import io.trino.spi.type.Type; import com.linkedin.transport.api.StdFactory; import com.linkedin.transport.api.udf.StdUDF; import com.linkedin.transport.api.udf.TopLevelStdUDF; -import com.linkedin.transport.presto.PrestoFactory; +import com.linkedin.transport.trino.TrinoFactory; import com.linkedin.transport.test.spi.SqlFunctionCallGenerator; import com.linkedin.transport.test.spi.SqlStdTester; import com.linkedin.transport.test.spi.ToPlatformTestOutputConverter; import java.util.List; import java.util.Map; +import static io.trino.type.UnknownType.UNKNOWN; -public class PrestoTester extends AbstractTestFunctions implements SqlStdTester { + +public class TrinoTester extends AbstractTestFunctions implements SqlStdTester { private StdFactory _stdFactory; private SqlFunctionCallGenerator _sqlFunctionCallGenerator; private ToPlatformTestOutputConverter _toPlatformTestOutputConverter; - public PrestoTester() { + public TrinoTester() { _stdFactory = null; - _sqlFunctionCallGenerator = new PrestoSqlFunctionCallGenerator(); - _toPlatformTestOutputConverter = new ToPrestoTestOutputConverter(); + _sqlFunctionCallGenerator = new TrinoSqlFunctionCallGenerator(); + _toPlatformTestOutputConverter = new ToTrinoTestOutputConverter(); } @Override public void setup( Map, List>> topLevelStdUDFClassesAndImplementations) { - // Refresh Presto state during every setup call + // Refresh Trino state during every setup call initTestFunctions(); for (List> stdUDFImplementations : 
topLevelStdUDFClassesAndImplementations.values()) { for (Class stdUDF : stdUDFImplementations) { - registerScalarFunction(new PrestoTestStdUDFWrapper(stdUDF)); + registerScalarFunction(new TrinoTestStdUDFWrapper(stdUDF)); } } } @@ -47,7 +52,13 @@ public void setup( @Override public StdFactory getStdFactory() { if (_stdFactory == null) { - _stdFactory = new PrestoFactory(new BoundVariables(ImmutableMap.of(), ImmutableMap.of()), + FunctionBinding functionBinding = new FunctionBinding( + new FunctionId("test"), + new BoundSignature("test", UNKNOWN, ImmutableList.of()), + ImmutableMap.of(), + ImmutableMap.of()); + _stdFactory = new TrinoFactory( + functionBinding, this.functionAssertions.getMetadata()); } return _stdFactory; diff --git a/transportable-udfs-test/transportable-udfs-test-trino/src/main/resources/META-INF/services/com.linkedin.transport.test.spi.StdTester b/transportable-udfs-test/transportable-udfs-test-trino/src/main/resources/META-INF/services/com.linkedin.transport.test.spi.StdTester new file mode 100644 index 00000000..62b71d68 --- /dev/null +++ b/transportable-udfs-test/transportable-udfs-test-trino/src/main/resources/META-INF/services/com.linkedin.transport.test.spi.StdTester @@ -0,0 +1 @@ +com.linkedin.transport.test.trino.TrinoTester \ No newline at end of file diff --git a/transportable-udfs-trino/build.gradle b/transportable-udfs-trino/build.gradle new file mode 100644 index 00000000..69b6e5c2 --- /dev/null +++ b/transportable-udfs-trino/build.gradle @@ -0,0 +1,27 @@ +apply plugin: 'java' + +java { + toolchain.languageVersion.set(JavaLanguageVersion.of(11)) +} + +dependencies { + compile project(':transportable-udfs-api') + compile project(':transportable-udfs-type-system') + compile project(':transportable-udfs-utils') + compileOnly(group:'io.trino', name: 'trino-main', version: project.ext.'trino-version') { + exclude 'group': 'com.google.collections', 'module': 'google-collections' + } + testCompile(group:'io.trino', name: 'trino-main', 
version: project.ext.'trino-version') { + exclude 'group': 'com.google.collections', 'module': 'google-collections' + } + testCompile(group:'io.trino', name: 'trino-main', version: project.ext.'trino-version', classifier: 'tests') { + exclude 'group': 'com.google.collections', 'module': 'google-collections' + } + compileOnly(group:'io.trino', name: 'trino-spi', version: project.ext.'trino-version') + compile('org.apache.hadoop:hadoop-hdfs:2.7.4') + compile('org.apache.hadoop:hadoop-common:2.7.4') + testCompile('io.airlift:testing:0.142') + // The io.airlift.slice dependency below has to match its counterpart in trino-root's pom.xml file + // If not specified, an older version is picked up transitively from another dependency + testCompile(group: 'io.airlift', name: 'slice', version: project.ext.'airlift-slice-version') +} diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/FileSystemClient.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/FileSystemClient.java similarity index 97% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/FileSystemClient.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/FileSystemClient.java index f964433d..b62abe35 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/FileSystemClient.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/FileSystemClient.java @@ -3,7 +3,7 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.presto; +package com.linkedin.transport.trino; import com.linkedin.transport.utils.FileSystemUtils; import java.io.File; @@ -54,7 +54,7 @@ public String copyToLocalFile(String remoteFilename) { Path localPath = new Path(Paths.get(getAndCreateLocalDir(), new File(remoteFilename).getName()).toString()); FileSystem fs = remotePath.getFileSystem(conf); // It is important to pass the custom configuration object to FileSystemUtils since we load some extra - // properties from etc/**.xml in getConfiguration() for Presto + // properties from etc/**.xml in getConfiguration() for Trino String resolvedRemoteFilename = FileSystemUtils.resolveLatest(remoteFilename, conf); Path resolvedRemotePath = new Path(resolvedRemoteFilename); fs.copyToLocalFile(resolvedRemotePath, localPath); diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/StdUdfWrapper.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java similarity index 75% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/StdUdfWrapper.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java index 14dd68b6..230edb6e 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/StdUdfWrapper.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java @@ -3,15 +3,13 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.presto; +package com.linkedin.transport.trino; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.primitives.Booleans; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; import com.linkedin.transport.api.udf.StdUDF; import com.linkedin.transport.api.udf.StdUDF0; import com.linkedin.transport.api.udf.StdUDF1; @@ -24,18 +22,25 @@ import com.linkedin.transport.api.udf.StdUDF8; import com.linkedin.transport.api.udf.TopLevelStdUDF; import com.linkedin.transport.typesystem.GenericTypeSignatureElement; -import io.prestosql.metadata.BoundVariables; -import io.prestosql.metadata.FunctionArgumentDefinition; -import io.prestosql.metadata.FunctionKind; -import io.prestosql.metadata.FunctionMetadata; -import io.prestosql.metadata.Metadata; -import io.prestosql.metadata.Signature; -import io.prestosql.metadata.SqlScalarFunction; -import io.prestosql.metadata.TypeVariableConstraint; -import io.prestosql.operator.scalar.ScalarFunctionImplementation; -import io.prestosql.spi.classloader.ThreadContextClassLoader; -import io.prestosql.spi.type.IntegerType; -import io.prestosql.spi.type.Type; +import io.trino.metadata.FunctionArgumentDefinition; +import io.trino.metadata.FunctionBinding; +import io.trino.metadata.FunctionDependencies; +import io.trino.metadata.FunctionDependencyDeclaration; +import io.trino.metadata.FunctionKind; +import io.trino.metadata.FunctionMetadata; +import io.trino.metadata.Signature; +import io.trino.metadata.SqlScalarFunction; +import io.trino.metadata.TypeVariableConstraint; +import io.trino.operator.scalar.ChoicesScalarFunctionImplementation; +import io.trino.operator.scalar.ScalarFunctionImplementation; +import io.trino.spi.classloader.ThreadContextClassLoader; +import 
io.trino.spi.function.InvocationConvention; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.IntegerType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; + import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; @@ -49,10 +54,12 @@ import java.util.stream.IntStream; import org.apache.commons.lang3.ClassUtils; -import static io.prestosql.metadata.Signature.*; -import static io.prestosql.metadata.SignatureBinder.*; -import static io.prestosql.operator.TypeSignatureParser.parseTypeSignature; -import static io.prestosql.util.Reflection.*; +import static io.trino.metadata.Signature.*; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.*; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; +import static io.trino.spi.function.OperatorType.*; +import static io.trino.sql.analyzer.TypeSignatureTranslator.parseTypeSignature; +import static io.trino.util.Reflection.*; // Suppressing argument naming convention for the evalInternal methods @SuppressWarnings({"checkstyle:regexpsinglelinejava"}) @@ -97,9 +104,36 @@ protected long getRefreshIntervalMillis() { return TimeUnit.DAYS.toMillis(DEFAULT_REFRESH_INTERVAL_DAYS); } + private void registerNestedDependencies(Type nestedType, FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder builder) { + builder.addType(nestedType.getTypeSignature()); + + if (nestedType instanceof RowType) { + nestedType.getTypeParameters().forEach(type -> registerNestedDependencies(type, builder)); + } else if (nestedType instanceof ArrayType) { + registerNestedDependencies(((ArrayType) nestedType).getElementType(), builder); + } else if (nestedType instanceof MapType) { + Type keyType = ((MapType) nestedType).getKeyType(); + Type valueType = ((MapType) nestedType).getValueType(); + builder.addOperator(EQUAL, ImmutableList.of(keyType, 
keyType)); + registerNestedDependencies(keyType, builder); + registerNestedDependencies(valueType, builder); + } + } + + @Override + public FunctionDependencyDeclaration getFunctionDependencies(FunctionBinding functionBinding) { + FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder builder = FunctionDependencyDeclaration.builder(); + + registerNestedDependencies(functionBinding.getBoundSignature().getReturnType(), builder); + List argumentTypes = functionBinding.getBoundSignature().getArgumentTypes(); + argumentTypes.forEach(type -> registerNestedDependencies(type, builder)); + + return builder.build(); + } + @Override - public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, Metadata metadata) { - StdFactory stdFactory = new PrestoFactory(boundVariables, metadata); + public ScalarFunctionImplementation specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { + StdFactory stdFactory = new TrinoFactory(functionBinding, functionDependencies); StdUDF stdUDF = getStdUDF(); stdUDF.init(stdFactory); // Subtract a small jitter value so that refresh is triggered on first call @@ -110,14 +144,17 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in - (new Random()).nextInt(initialJitterInt)); boolean[] nullableArguments = stdUDF.getAndCheckNullableArguments(); - return new ScalarFunctionImplementation(true, getNullConventionForArguments(nullableArguments), - getMethodHandle(stdUDF, metadata, boundVariables, nullableArguments, requiredFilesNextRefreshTime)); + return new ChoicesScalarFunctionImplementation( + functionBinding, + NULLABLE_RETURN, + getNullConventionForArguments(nullableArguments), + getMethodHandle(stdUDF, functionBinding, nullableArguments, requiredFilesNextRefreshTime)); } - private MethodHandle getMethodHandle(StdUDF stdUDF, Metadata metadata, BoundVariables boundVariables, - boolean[] nullableArguments, AtomicLong requiredFilesNextRefreshTime) { - 
Type[] inputTypes = getPrestoTypes(stdUDF.getInputParameterSignatures(), metadata, boundVariables); - Type outputType = getPrestoType(stdUDF.getOutputParameterSignature(), metadata, boundVariables); + private MethodHandle getMethodHandle(StdUDF stdUDF, FunctionBinding functionBinding, boolean[] nullableArguments, + AtomicLong requiredFilesNextRefreshTime) { + Type[] inputTypes = functionBinding.getBoundSignature().getArgumentTypes().toArray(new Type[0]); + Type outputType = functionBinding.getBoundSignature().getReturnType(); // Generic MethodHandle for eval where all arguments are of type Object Class[] genericMethodHandleArgumentTypes = getMethodHandleArgumentTypes(inputTypes, nullableArguments, true); @@ -129,41 +166,39 @@ private MethodHandle getMethodHandle(StdUDF stdUDF, Metadata metadata, BoundVari MethodType specificMethodType = MethodType.methodType(specificMethodHandleReturnType, specificMethodHandleArgumentTypes); - // Specific MethodHandle required by presto where argument types map to the type signature + // Specific MethodHandle required by trino where argument types map to the type signature MethodHandle specificMethodHandle = MethodHandles.explicitCastArguments(genericMethodHandle, specificMethodType); return MethodHandles.insertArguments(specificMethodHandle, 0, stdUDF, inputTypes, outputType instanceof IntegerType, requiredFilesNextRefreshTime); } - private List getNullConventionForArguments( + private List getNullConventionForArguments( boolean[] nullableArguments) { return IntStream.range(0, nullableArguments.length) - .mapToObj(idx -> ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty( - nullableArguments[idx] ? ScalarFunctionImplementation.NullConvention.USE_BOXED_TYPE - : ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL)) + .mapToObj(idx -> nullableArguments[idx] ? 
BOXED_NULLABLE : NEVER_NULL) .collect(Collectors.toList()); } - private StdData[] wrapArguments(StdUDF stdUDF, Type[] types, Object[] arguments) { + private Object[] wrapArguments(StdUDF stdUDF, Type[] types, Object[] arguments) { StdFactory stdFactory = stdUDF.getStdFactory(); - StdData[] stdData = new StdData[arguments.length]; + Object[] stdData = new Object[arguments.length]; // TODO: Reuse wrapper objects by creating them once upon initialization and reuse them here // along the same lines of what we do in Hive implementation. // JIRA: https://jira01.corp.linkedin.com:8443/browse/LIHADOOP-34894 for (int i = 0; i < stdData.length; i++) { - stdData[i] = PrestoWrapper.createStdData(arguments[i], types[i], stdFactory); + stdData[i] = TrinoWrapper.createStdData(arguments[i], types[i], stdFactory); } return stdData; } protected Object eval(StdUDF stdUDF, Type[] types, boolean isIntegerReturnType, AtomicLong requiredFilesNextRefreshTime, Object... arguments) { - StdData[] args = wrapArguments(stdUDF, types, arguments); + Object[] args = wrapArguments(stdUDF, types, arguments); if (requiredFilesNextRefreshTime.get() <= System.currentTimeMillis()) { String[] requiredFiles = getRequiredFiles(stdUDF, args); processRequiredFiles(stdUDF, requiredFiles, requiredFilesNextRefreshTime); } - StdData result; + Object result; switch (args.length) { case 0: result = ((StdUDF0) stdUDF).eval(); @@ -195,16 +230,11 @@ protected Object eval(StdUDF stdUDF, Type[] types, boolean isIntegerReturnType, default: throw new RuntimeException("eval not supported yet for StdUDF" + args.length); } - if (result == null) { - return null; - } else if (isIntegerReturnType) { - return ((Number) ((PlatformData) result).getUnderlyingData()).longValue(); - } else { - return ((PlatformData) result).getUnderlyingData(); - } + + return TrinoWrapper.getPlatformData(result); } - private String[] getRequiredFiles(StdUDF stdUDF, StdData[] args) { + private String[] getRequiredFiles(StdUDF stdUDF, Object[] args) 
{ String[] requiredFiles; switch (args.length) { case 0: @@ -261,22 +291,14 @@ private synchronized void processRequiredFiles(StdUDF stdUDF, String[] requiredF } } - private Class getJavaTypeForNullability(Type prestoType, boolean nullableArgument) { + private Class getJavaTypeForNullability(Type trinoType, boolean nullableArgument) { if (nullableArgument) { - return ClassUtils.primitiveToWrapper(prestoType.getJavaType()); + return ClassUtils.primitiveToWrapper(trinoType.getJavaType()); } else { - return prestoType.getJavaType(); + return trinoType.getJavaType(); } } - private Type[] getPrestoTypes(List parameterSignatures, Metadata metadata, BoundVariables boundVariables) { - return parameterSignatures.stream().map(p -> getPrestoType(p, metadata, boundVariables)).toArray(Type[]::new); - } - - private Type getPrestoType(String parameterSignature, Metadata metadata, BoundVariables boundVariables) { - return metadata.getType(applyBoundVariables(parseTypeSignature(parameterSignature, ImmutableSet.of()), boundVariables)); - } - private Class[] getMethodHandleArgumentTypes(Type[] argTypes, boolean[] nullableArguments, boolean useObjectForArgumentType) { Class[] methodHandleArgumentTypes = new Class[argTypes.length + 4]; diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoFactory.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoFactory.java new file mode 100644 index 00000000..46a56ba3 --- /dev/null +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoFactory.java @@ -0,0 +1,106 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */ +package com.linkedin.transport.trino; + +import com.google.common.collect.ImmutableSet; +import com.linkedin.transport.api.StdFactory; + +import com.linkedin.transport.api.types.StdType; +import com.linkedin.transport.api.data.ArrayData; +import com.linkedin.transport.api.data.MapData; +import com.linkedin.transport.api.data.RowData; +import com.linkedin.transport.trino.data.TrinoArrayData; +import com.linkedin.transport.trino.data.TrinoMapData; +import com.linkedin.transport.trino.data.TrinoRowData; +import io.trino.metadata.FunctionBinding; +import io.trino.metadata.FunctionDependencies; +import io.trino.metadata.Metadata; +import io.trino.metadata.OperatorNotFoundException; +import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.OperatorType; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; +import java.lang.invoke.MethodHandle; +import java.util.List; +import java.util.stream.Collectors; + +import static io.trino.metadata.SignatureBinder.*; +import static io.trino.sql.analyzer.TypeSignatureTranslator.*; + + +public class TrinoFactory implements StdFactory { + + final FunctionBinding functionBinding; + final FunctionDependencies functionDependencies; + final Metadata metadata; + + public TrinoFactory(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { + this.functionBinding = functionBinding; + this.functionDependencies = functionDependencies; + this.metadata = null; + } + + public TrinoFactory(FunctionBinding functionBinding, Metadata metadata) { + this.functionBinding = functionBinding; + this.functionDependencies = null; + this.metadata = metadata; + } + + @Override + public ArrayData createArray(StdType stdType, int expectedSize) { + return new TrinoArrayData((ArrayType) stdType.underlyingType(), expectedSize, this); + } + + @Override + public ArrayData createArray(StdType stdType) { + return createArray(stdType, 
0); + } + + @Override + public MapData createMap(StdType stdType) { + return new TrinoMapData((MapType) stdType.underlyingType(), this); + } + + @Override + public TrinoRowData createStruct(List fieldNames, List fieldTypes) { + return new TrinoRowData(fieldNames, + fieldTypes.stream().map(stdType -> (Type) stdType.underlyingType()).collect(Collectors.toList()), this); + } + + @Override + public TrinoRowData createStruct(List fieldTypes) { + return new TrinoRowData( + fieldTypes.stream().map(stdType -> (Type) stdType.underlyingType()).collect(Collectors.toList()), this); + } + + @Override + public RowData createStruct(StdType stdType) { + return new TrinoRowData((RowType) stdType.underlyingType(), this); + } + + @Override + public StdType createStdType(String typeSignature) { + if (metadata != null) { + return TrinoWrapper.createStdType( + metadata.getType(applyBoundVariables(parseTypeSignature(typeSignature, ImmutableSet.of()), functionBinding))); + } + return TrinoWrapper.createStdType( + functionDependencies.getType(applyBoundVariables(parseTypeSignature(typeSignature, ImmutableSet.of()), functionBinding))); + } + + public MethodHandle getOperatorHandle( + OperatorType operatorType, + List argumentTypes, + InvocationConvention invocationConvention) throws OperatorNotFoundException { + if (metadata != null) { + return metadata.getScalarFunctionInvoker(metadata.resolveOperator(operatorType, argumentTypes), + invocationConvention).getMethodHandle(); + } + return functionDependencies.getOperatorInvoker(operatorType, argumentTypes, invocationConvention).getMethodHandle(); + } +} diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoWrapper.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoWrapper.java new file mode 100644 index 00000000..2bde058c --- /dev/null +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoWrapper.java @@ -0,0 +1,187 @@ +/** + * Copyright 2018 LinkedIn 
Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.trino; + +import com.linkedin.transport.api.StdFactory; +import com.linkedin.transport.api.types.StdType; +import com.linkedin.transport.api.data.PlatformData; +import com.linkedin.transport.trino.data.TrinoData; +import com.linkedin.transport.trino.data.TrinoArrayData; +import com.linkedin.transport.trino.data.TrinoRowData; +import com.linkedin.transport.trino.data.TrinoMapData; +import com.linkedin.transport.trino.types.TrinoArrayType; +import com.linkedin.transport.trino.types.TrinoBooleanType; +import com.linkedin.transport.trino.types.TrinoBinaryType; +import com.linkedin.transport.trino.types.TrinoDoubleType; +import com.linkedin.transport.trino.types.TrinoFloatType; +import com.linkedin.transport.trino.types.TrinoIntegerType; +import com.linkedin.transport.trino.types.TrinoLongType; +import com.linkedin.transport.trino.types.TrinoMapType; +import com.linkedin.transport.trino.types.TrinoStringType; +import com.linkedin.transport.trino.types.TrinoRowType; +import com.linkedin.transport.trino.types.TrinoUnknownType; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.IntegerType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RealType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; +import io.trino.spi.type.VarbinaryType; +import io.trino.spi.type.VarcharType; +import io.trino.type.UnknownType; + +import static io.trino.spi.type.BigintType.*; +import static io.trino.spi.type.BooleanType.*; +import static io.trino.spi.type.DoubleType.*; +import static 
io.trino.spi.type.IntegerType.*; +import static io.trino.spi.type.VarbinaryType.*; +import static io.trino.spi.type.VarcharType.*; +import static io.trino.spi.StandardErrorCode.*; +import static java.lang.Float.*; +import static java.lang.Math.*; +import static java.lang.String.*; +import java.nio.ByteBuffer; + +public final class TrinoWrapper { + + private TrinoWrapper() { + } + + public static Object createStdData(Object trinoData, Type trinoType, StdFactory stdFactory) { + if (trinoData == null) { + return null; + } + if (trinoType instanceof IntegerType) { + // Trino represents SQL Integers (i.e., corresponding to IntegerType above) as long or Long + // Therefore, we first cast trinoData to Long, then extract the int value. + return ((Long) trinoData).intValue(); + } else if (trinoType instanceof BigintType || trinoType.getJavaType() == boolean.class + || trinoType instanceof DoubleType) { + return trinoData; + } else if (trinoType instanceof VarcharType) { + return ((Slice) trinoData).toStringUtf8(); + } else if (trinoType instanceof RealType) { + // Trino represents SQL Reals (i.e., corresponding to RealType above) as long or Long + // Therefore, to pass it to the TrinoFloat class, we first cast it to Long, extract + // the int value and convert it the int bits to float. 
+ long value = (long) trinoData; + int floatValue; + try { + floatValue = toIntExact(value); + } catch (ArithmeticException e) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, + format("Value (%sb) is not a valid single-precision float", Long.toBinaryString(value))); + } + return intBitsToFloat(floatValue); + } else if (trinoType instanceof VarbinaryType) { + return ((Slice) trinoData).toByteBuffer(); + } else if (trinoType instanceof ArrayType) { + return new TrinoArrayData((Block) trinoData, (ArrayType) trinoType, stdFactory); + } else if (trinoType instanceof MapType) { + return new TrinoMapData((Block) trinoData, trinoType, stdFactory); + } else if (trinoType instanceof RowType) { + return new TrinoRowData((Block) trinoData, trinoType, stdFactory); + } + assert false : "Unrecognized Trino Type: " + trinoType.getClass(); + return null; + } + + public static Object getPlatformData(Object transportData) { + if (transportData == null) { + return null; + } + if (transportData instanceof Integer) { + return ((Number) transportData).longValue(); + } else if (transportData instanceof Long) { + return ((Long) transportData).longValue(); + } else if (transportData instanceof Float) { + return (long) floatToIntBits((Float) transportData); + } else if (transportData instanceof Double) { + return ((Double) transportData).doubleValue(); + } else if (transportData instanceof Boolean) { + return ((Boolean) transportData).booleanValue(); + } else if (transportData instanceof String) { + return Slices.utf8Slice((String) transportData); + } else if (transportData instanceof ByteBuffer) { + return Slices.wrappedBuffer(((ByteBuffer) transportData).array()); + } else { + return ((PlatformData) transportData).getUnderlyingData(); + } + } + + public static void writeToBlock(Object transportData, BlockBuilder blockBuilder) { + if (transportData == null) { + blockBuilder.appendNull(); + } else { + if (transportData instanceof Integer) { + // This looks a bit strange, but the call to 
writeLong is correct here. INTEGER does not have a writeInt method for + // some reason. It uses BlockBuilder.writeInt internally. + INTEGER.writeLong(blockBuilder, (Integer) transportData); + } else if (transportData instanceof Long) { + BIGINT.writeLong(blockBuilder, (Long) transportData); + } else if (transportData instanceof Float) { + INTEGER.writeLong(blockBuilder, floatToIntBits((Float) transportData)); + } else if (transportData instanceof Double) { + DOUBLE.writeDouble(blockBuilder, (Double) transportData); + } else if (transportData instanceof Boolean) { + BOOLEAN.writeBoolean(blockBuilder, (Boolean) transportData); + } else if (transportData instanceof String) { + VARCHAR.writeSlice(blockBuilder, Slices.utf8Slice((String) transportData)); + } else if (transportData instanceof ByteBuffer) { + VARBINARY.writeSlice(blockBuilder, Slices.wrappedBuffer((ByteBuffer) transportData)); + } else { + ((TrinoData) transportData).writeToBlock(blockBuilder); + } + } + } + + public static StdType createStdType(Object trinoType) { + if (trinoType instanceof IntegerType) { + return new TrinoIntegerType((IntegerType) trinoType); + } else if (trinoType instanceof BigintType) { + return new TrinoLongType((BigintType) trinoType); + } else if (trinoType instanceof BooleanType) { + return new TrinoBooleanType((BooleanType) trinoType); + } else if (trinoType instanceof VarcharType) { + return new TrinoStringType((VarcharType) trinoType); + } else if (trinoType instanceof RealType) { + return new TrinoFloatType((RealType) trinoType); + } else if (trinoType instanceof DoubleType) { + return new TrinoDoubleType((DoubleType) trinoType); + } else if (trinoType instanceof VarbinaryType) { + return new TrinoBinaryType((VarbinaryType) trinoType); + } else if (trinoType instanceof ArrayType) { + return new TrinoArrayType((ArrayType) trinoType); + } else if (trinoType instanceof MapType) { + return new TrinoMapType((MapType) trinoType); + } else if (trinoType instanceof RowType) { + 
return new TrinoRowType(((RowType) trinoType)); + } else if (trinoType instanceof UnknownType) { + return new TrinoUnknownType(((UnknownType) trinoType)); + } + assert false : "Unrecognized Trino Type: " + trinoType.getClass(); + return null; + } + + /** + * @return index if the index is in range, -1 otherwise. + */ + public static int checkedIndexToBlockPosition(Block block, long index) { + int blockLength = block.getPositionCount(); + if (index >= 0 && index < blockLength) { + return toIntExact(index); + } + return -1; // -1 indicates that the element is out of range and the calling function should return null + } +} \ No newline at end of file diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoArray.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArrayData.java similarity index 61% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoArray.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArrayData.java index 41759716..3fe21ffe 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoArray.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArrayData.java @@ -3,23 +3,22 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.presto.data; +package com.linkedin.transport.trino.data; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdArray; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.presto.PrestoWrapper; -import io.prestosql.spi.block.Block; -import io.prestosql.spi.block.BlockBuilder; -import io.prestosql.spi.block.PageBuilderStatus; -import io.prestosql.spi.type.ArrayType; -import io.prestosql.spi.type.Type; +import com.linkedin.transport.trino.TrinoWrapper; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.PageBuilderStatus; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.Type; +import com.linkedin.transport.api.data.ArrayData; import java.util.Iterator; -import static io.prestosql.spi.type.TypeUtils.*; +import static io.trino.spi.type.TypeUtils.*; -public class PrestoArray extends PrestoData implements StdArray { +public class TrinoArrayData extends TrinoData implements ArrayData { private final StdFactory _stdFactory; private final ArrayType _arrayType; @@ -28,14 +27,14 @@ public class PrestoArray extends PrestoData implements StdArray { private Block _block; private BlockBuilder _mutable; - public PrestoArray(Block block, ArrayType arrayType, StdFactory stdFactory) { + public TrinoArrayData(Block block, ArrayType arrayType, StdFactory stdFactory) { _block = block; _arrayType = arrayType; _elementType = arrayType.getElementType(); _stdFactory = stdFactory; } - public PrestoArray(ArrayType arrayType, int expectedEntries, StdFactory stdFactory) { + public TrinoArrayData(ArrayType arrayType, int expectedEntries, StdFactory stdFactory) { _block = null; _elementType = arrayType.getElementType(); _mutable = _elementType.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), expectedEntries); @@ -49,19 +48,19 @@ public int size() { } @Override - public StdData get(int idx) { + public E get(int idx) { 
Block sourceBlock = _mutable == null ? _block : _mutable; - int position = PrestoWrapper.checkedIndexToBlockPosition(sourceBlock, idx); + int position = TrinoWrapper.checkedIndexToBlockPosition(sourceBlock, idx); Object element = readNativeValue(_elementType, sourceBlock, position); - return PrestoWrapper.createStdData(element, _elementType, _stdFactory); + return (E) TrinoWrapper.createStdData(element, _elementType, _stdFactory); } @Override - public void add(StdData e) { + public void add(E e) { if (_mutable == null) { _mutable = _elementType.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), 1); } - ((PrestoData) e).writeToBlock(_mutable); + TrinoWrapper.writeToBlock(e, _mutable); } @Override @@ -75,10 +74,10 @@ public void setUnderlyingData(Object value) { } @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { Block sourceBlock = _mutable == null ? _block : _mutable; - int size = PrestoArray.this.size(); + int size = TrinoArrayData.this.size(); int position = 0; @Override @@ -87,10 +86,10 @@ public boolean hasNext() { } @Override - public StdData next() { + public E next() { Object element = readNativeValue(_elementType, sourceBlock, position); position++; - return PrestoWrapper.createStdData(element, _elementType, _stdFactory); + return (E) TrinoWrapper.createStdData(element, _elementType, _stdFactory); } }; } diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoData.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoData.java similarity index 64% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoData.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoData.java index ecfd41d8..37c4d49d 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoData.java +++ 
b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoData.java @@ -3,16 +3,16 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ -package com.linkedin.transport.presto.data; +package com.linkedin.transport.trino.data; import com.linkedin.transport.api.data.PlatformData; -import io.prestosql.spi.block.BlockBuilder; +import io.trino.spi.block.BlockBuilder; /** - * A common super class for all Presto specific implementations of StdData types + * A common super class for all Trino specific implementations of StdData types */ -public abstract class PrestoData implements PlatformData { +public abstract class TrinoData implements PlatformData { /** * Writes this data object into the give BlockBuilder * @param blockBuilder the builder to write into diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoMap.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMapData.java similarity index 57% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoMap.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMapData.java index 2cc78700..0bd38ad0 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoMap.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMapData.java @@ -3,23 +3,21 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.presto.data; +package com.linkedin.transport.trino.data; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.PlatformData; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdMap; -import com.linkedin.transport.presto.PrestoFactory; -import com.linkedin.transport.presto.PrestoWrapper; -import io.prestosql.spi.PrestoException; -import io.prestosql.spi.block.Block; -import io.prestosql.spi.block.BlockBuilder; -import io.prestosql.spi.block.PageBuilderStatus; -import io.prestosql.spi.function.OperatorType; -import io.prestosql.spi.type.MapType; -import io.prestosql.spi.type.Type; +import com.linkedin.transport.trino.TrinoFactory; +import com.linkedin.transport.trino.TrinoWrapper; +import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.PageBuilderStatus; +import io.trino.spi.function.OperatorType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.Type; +import com.linkedin.transport.api.data.MapData; import java.lang.invoke.MethodHandle; import java.util.AbstractCollection; import java.util.AbstractSet; @@ -27,11 +25,14 @@ import java.util.Iterator; import java.util.Set; -import static io.prestosql.spi.StandardErrorCode.*; -import static io.prestosql.spi.type.TypeUtils.*; +import static io.trino.spi.StandardErrorCode.*; +import static io.trino.spi.function.InvocationConvention.simpleConvention; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; +import static io.trino.spi.type.TypeUtils.*; -public class PrestoMap extends PrestoData implements StdMap { +public class TrinoMapData extends TrinoData implements MapData { final Type 
_keyType; final Type _valueType; @@ -40,7 +41,7 @@ public class PrestoMap extends PrestoData implements StdMap { final StdFactory _stdFactory; Block _block; - public PrestoMap(Type mapType, StdFactory stdFactory) { + public TrinoMapData(Type mapType, StdFactory stdFactory) { BlockBuilder mutable = mapType.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), 1); mutable.beginBlockEntry(); mutable.closeEntry(); @@ -51,12 +52,11 @@ public PrestoMap(Type mapType, StdFactory stdFactory) { _mapType = mapType; _stdFactory = stdFactory; - _keyEqualsMethod = ((PrestoFactory) stdFactory).getScalarFunctionImplementation( - ((PrestoFactory) stdFactory).resolveOperator(OperatorType.EQUAL, ImmutableList.of(_keyType, _keyType))) - .getMethodHandle(); + _keyEqualsMethod = ((TrinoFactory) stdFactory).getOperatorHandle( + OperatorType.EQUAL, ImmutableList.of(_keyType, _keyType), simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL)); } - public PrestoMap(Block block, Type mapType, StdFactory stdFactory) { + public TrinoMapData(Block block, Type mapType, StdFactory stdFactory) { this(mapType, stdFactory); _block = block; } @@ -67,13 +67,12 @@ public int size() { } @Override - public StdData get(StdData key) { - Object prestoKey = ((PlatformData) key).getUnderlyingData(); + public V get(K key) { + Object prestoKey = TrinoWrapper.getPlatformData(key); int i = seekKey(prestoKey); if (i != -1) { Object value = readNativeValue(_valueType, _block, i); - StdData stdValue = PrestoWrapper.createStdData(value, _valueType, _stdFactory); - return stdValue; + return (V) TrinoWrapper.createStdData(value, _valueType, _stdFactory); } else { return null; } @@ -82,37 +81,37 @@ public StdData get(StdData key) { // TODO: Do not copy the _mutable BlockBuilder on every update. As long as updates are append-only or for fixed-size // types, we can skip copying. 
@Override - public void put(StdData key, StdData value) { + public void put(K key, V value) { BlockBuilder mutable = _mapType.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), 1); BlockBuilder entryBuilder = mutable.beginBlockEntry(); - Object prestoKey = ((PlatformData) key).getUnderlyingData(); - int valuePosition = seekKey(prestoKey); + Object trinoKey = TrinoWrapper.getPlatformData(key); + int valuePosition = seekKey(trinoKey); for (int i = 0; i < _block.getPositionCount(); i += 2) { // Write the current key to the map _keyType.appendTo(_block, i, entryBuilder); // Find out if we need to change the corresponding value if (i == valuePosition - 1) { // Use the user-supplied value - ((PrestoData) value).writeToBlock(entryBuilder); + TrinoWrapper.writeToBlock(value, entryBuilder); } else { // Use the existing value in original _block _valueType.appendTo(_block, i + 1, entryBuilder); } } if (valuePosition == -1) { - ((PrestoData) key).writeToBlock(entryBuilder); - ((PrestoData) value).writeToBlock(entryBuilder); + TrinoWrapper.writeToBlock(key, entryBuilder); + TrinoWrapper.writeToBlock(value, entryBuilder); } mutable.closeEntry(); _block = ((MapType) _mapType).getObject(mutable.build(), 0); } - public Set keySet() { - return new AbstractSet() { + public Set keySet() { + return new AbstractSet() { @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { int i = -2; @Override @@ -121,27 +120,27 @@ public boolean hasNext() { } @Override - public StdData next() { + public K next() { i += 2; - return PrestoWrapper.createStdData(readNativeValue(_keyType, _block, i), _keyType, _stdFactory); + return (K) TrinoWrapper.createStdData(readNativeValue(_keyType, _block, i), _keyType, _stdFactory); } }; } @Override public int size() { - return PrestoMap.this.size(); + return TrinoMapData.this.size(); } }; } @Override - public Collection values() { - return new AbstractCollection() { + public 
Collection values() { + return new AbstractCollection() { @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { int i = -2; @Override @@ -150,22 +149,25 @@ public boolean hasNext() { } @Override - public StdData next() { + public V next() { i += 2; - return PrestoWrapper.createStdData(readNativeValue(_valueType, _block, i + 1), _valueType, _stdFactory); + return + (V) TrinoWrapper.createStdData( + readNativeValue(_valueType, _block, i + 1), _valueType, _stdFactory + ); } }; } @Override public int size() { - return PrestoMap.this.size(); + return TrinoMapData.this.size(); } }; } @Override - public boolean containsKey(StdData key) { + public boolean containsKey(K key) { return get(key) != null; } @@ -187,8 +189,8 @@ private int seekKey(Object key) { } } catch (Throwable t) { Throwables.propagateIfInstanceOf(t, Error.class); - Throwables.propagateIfInstanceOf(t, PrestoException.class); - throw new PrestoException(GENERIC_INTERNAL_ERROR, t); + Throwables.propagateIfInstanceOf(t, TrinoException.class); + throw new TrinoException(GENERIC_INTERNAL_ERROR, t); } } return -1; diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoStruct.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoRowData.java similarity index 70% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoStruct.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoRowData.java index e48a94c4..74d724fa 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoStruct.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoRowData.java @@ -3,49 +3,48 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.presto.data; +package com.linkedin.transport.trino.data; import com.linkedin.transport.api.StdFactory; -import com.linkedin.transport.api.data.StdData; -import com.linkedin.transport.api.data.StdStruct; -import com.linkedin.transport.presto.PrestoWrapper; -import io.prestosql.spi.block.Block; -import io.prestosql.spi.block.BlockBuilder; -import io.prestosql.spi.block.BlockBuilderStatus; -import io.prestosql.spi.block.PageBuilderStatus; -import io.prestosql.spi.type.RowType; -import io.prestosql.spi.type.Type; +import com.linkedin.transport.trino.TrinoWrapper; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.PageBuilderStatus; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; +import com.linkedin.transport.api.data.RowData; import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.IntStream; -import static io.prestosql.spi.type.TypeUtils.*; +import static io.trino.spi.type.TypeUtils.*; -public class PrestoStruct extends PrestoData implements StdStruct { +public class TrinoRowData extends TrinoData implements RowData { final RowType _rowType; final StdFactory _stdFactory; Block _block; - public PrestoStruct(Type rowType, StdFactory stdFactory) { + public TrinoRowData(Type rowType, StdFactory stdFactory) { _rowType = (RowType) rowType; _stdFactory = stdFactory; } - public PrestoStruct(Block block, Type rowType, StdFactory stdFactory) { + public TrinoRowData(Block block, Type rowType, StdFactory stdFactory) { this(rowType, stdFactory); _block = block; } - public PrestoStruct(List fieldTypes, StdFactory stdFactory) { + public TrinoRowData(List fieldTypes, StdFactory stdFactory) { _stdFactory = stdFactory; _rowType = RowType.anonymous(fieldTypes); } - public PrestoStruct(List fieldNames, List fieldTypes, StdFactory stdFactory) { + public 
TrinoRowData(List fieldNames, List fieldTypes, StdFactory stdFactory) { _stdFactory = stdFactory; List fields = IntStream.range(0, fieldNames.size()) .mapToObj(i -> new RowType.Field(Optional.ofNullable(fieldNames.get(i)), fieldTypes.get(i))) @@ -54,18 +53,18 @@ public PrestoStruct(List fieldNames, List fieldTypes, StdFactory s } @Override - public StdData getField(int index) { - int position = PrestoWrapper.checkedIndexToBlockPosition(_block, index); + public Object getField(int index) { + int position = TrinoWrapper.checkedIndexToBlockPosition(_block, index); if (position == -1) { return null; } Type elementType = _rowType.getFields().get(position).getType(); Object element = readNativeValue(elementType, _block, position); - return PrestoWrapper.createStdData(element, elementType, _stdFactory); + return TrinoWrapper.createStdData(element, elementType, _stdFactory); } @Override - public StdData getField(String name) { + public Object getField(String name) { int index = -1; Type elementType = null; int i = 0; @@ -81,11 +80,11 @@ public StdData getField(String name) { return null; } Object element = readNativeValue(elementType, _block, index); - return PrestoWrapper.createStdData(element, elementType, _stdFactory); + return TrinoWrapper.createStdData(element, elementType, _stdFactory); } @Override - public void setField(int index, StdData value) { + public void setField(int index, Object value) { // TODO: This is not the right way to get this object. The status should be passed in from the invocation of the // function and propagated to here. See PRESTO-1359 for more details. 
BlockBuilderStatus blockBuilderStatus = new PageBuilderStatus().createBlockBuilderStatus(); @@ -94,7 +93,7 @@ public void setField(int index, StdData value) { int i = 0; for (RowType.Field field : _rowType.getFields()) { if (i == index) { - ((PrestoData) value).writeToBlock(rowBlockBuilder); + TrinoWrapper.writeToBlock(value, rowBlockBuilder); } else { if (_block == null) { rowBlockBuilder.appendNull(); @@ -109,13 +108,13 @@ public void setField(int index, StdData value) { } @Override - public void setField(String name, StdData value) { + public void setField(String name, Object value) { BlockBuilder mutable = _rowType.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), 1); BlockBuilder rowBlockBuilder = mutable.beginBlockEntry(); int i = 0; for (RowType.Field field : _rowType.getFields()) { if (field.getName().isPresent() && name.equals(field.getName().get())) { - ((PrestoData) value).writeToBlock(rowBlockBuilder); + TrinoWrapper.writeToBlock(value, rowBlockBuilder); } else { if (_block == null) { rowBlockBuilder.appendNull(); @@ -130,12 +129,12 @@ public void setField(String name, StdData value) { } @Override - public List fields() { - ArrayList fields = new ArrayList<>(); + public List fields() { + ArrayList fields = new ArrayList<>(); for (int i = 0; i < _block.getPositionCount(); i++) { Type elementType = _rowType.getFields().get(i).getType(); Object element = readNativeValue(elementType, _block, i); - fields.add(PrestoWrapper.createStdData(element, elementType, _stdFactory)); + fields.add(TrinoWrapper.createStdData(element, elementType, _stdFactory)); } return fields; } diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoArrayType.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoArrayType.java similarity index 60% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoArrayType.java rename to 
transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoArrayType.java index e63f2344..9d5b8d32 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoArrayType.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoArrayType.java @@ -3,25 +3,25 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ -package com.linkedin.transport.presto.types; +package com.linkedin.transport.trino.types; import com.linkedin.transport.api.types.StdArrayType; import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.presto.PrestoWrapper; -import io.prestosql.spi.type.ArrayType; +import com.linkedin.transport.trino.TrinoWrapper; +import io.trino.spi.type.ArrayType; -public class PrestoArrayType implements StdArrayType { +public class TrinoArrayType implements StdArrayType { final ArrayType arrayType; - public PrestoArrayType(ArrayType arrayType) { + public TrinoArrayType(ArrayType arrayType) { this.arrayType = arrayType; } @Override public StdType elementType() { - return PrestoWrapper.createStdType(arrayType.getElementType()); + return TrinoWrapper.createStdType(arrayType.getElementType()); } @Override diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoBinaryType.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoBinaryType.java similarity index 66% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoBinaryType.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoBinaryType.java index 1be446f1..cf096175 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoBinaryType.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoBinaryType.java @@ -3,17 +3,17 @@ * Licensed under the 
BSD-2 Clause license. * See LICENSE in the project root for license information. */ -package com.linkedin.transport.presto.types; +package com.linkedin.transport.trino.types; import com.linkedin.transport.api.types.StdBinaryType; -import io.prestosql.spi.type.VarbinaryType; +import io.trino.spi.type.VarbinaryType; -public class PrestoBinaryType implements StdBinaryType { +public class TrinoBinaryType implements StdBinaryType { private final VarbinaryType varbinaryType; - public PrestoBinaryType(VarbinaryType varbinaryType) { + public TrinoBinaryType(VarbinaryType varbinaryType) { this.varbinaryType = varbinaryType; } diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoBooleanType.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoBooleanType.java similarity index 65% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoBooleanType.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoBooleanType.java index 538655fb..543ea4da 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoBooleanType.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoBooleanType.java @@ -3,17 +3,17 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.presto.types; +package com.linkedin.transport.trino.types; import com.linkedin.transport.api.types.StdBooleanType; -import io.prestosql.spi.type.BooleanType; +import io.trino.spi.type.BooleanType; -public class PrestoBooleanType implements StdBooleanType { +public class TrinoBooleanType implements StdBooleanType { final BooleanType booleanType; - public PrestoBooleanType(BooleanType booleanType) { + public TrinoBooleanType(BooleanType booleanType) { this.booleanType = booleanType; } diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoDoubleType.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoDoubleType.java similarity index 66% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoDoubleType.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoDoubleType.java index a9a6394e..db7cab6d 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoDoubleType.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoDoubleType.java @@ -3,17 +3,17 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.presto.types; +package com.linkedin.transport.trino.types; import com.linkedin.transport.api.types.StdDoubleType; -import io.prestosql.spi.type.DoubleType; +import io.trino.spi.type.DoubleType; -public class PrestoDoubleType implements StdDoubleType { +public class TrinoDoubleType implements StdDoubleType { private final DoubleType doubleType; - public PrestoDoubleType(DoubleType doubleType) { + public TrinoDoubleType(DoubleType doubleType) { this.doubleType = doubleType; } diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoFloatType.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoFloatType.java similarity index 67% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoFloatType.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoFloatType.java index 2b481c64..e12bf57e 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoFloatType.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoFloatType.java @@ -3,17 +3,17 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.presto.types; +package com.linkedin.transport.trino.types; import com.linkedin.transport.api.types.StdFloatType; -import io.prestosql.spi.type.RealType; +import io.trino.spi.type.RealType; -public class PrestoFloatType implements StdFloatType { +public class TrinoFloatType implements StdFloatType { private final RealType floatType; - public PrestoFloatType(RealType floatType) { + public TrinoFloatType(RealType floatType) { this.floatType = floatType; } diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoIntegerType.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoIntegerType.java similarity index 65% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoIntegerType.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoIntegerType.java index ed1e3002..4b79c9bd 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoIntegerType.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoIntegerType.java @@ -3,17 +3,17 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.presto.types; +package com.linkedin.transport.trino.types; import com.linkedin.transport.api.types.StdIntegerType; -import io.prestosql.spi.type.IntegerType; +import io.trino.spi.type.IntegerType; -public class PrestoIntegerType implements StdIntegerType { +public class TrinoIntegerType implements StdIntegerType { final IntegerType integerType; - public PrestoIntegerType(IntegerType integerType) { + public TrinoIntegerType(IntegerType integerType) { this.integerType = integerType; } diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoLongType.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoLongType.java similarity index 66% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoLongType.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoLongType.java index f0dbb856..f31f7871 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoLongType.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoLongType.java @@ -3,17 +3,17 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.presto.types; +package com.linkedin.transport.trino.types; import com.linkedin.transport.api.types.StdLongType; -import io.prestosql.spi.type.BigintType; +import io.trino.spi.type.BigintType; -public class PrestoLongType implements StdLongType { +public class TrinoLongType implements StdLongType { final BigintType bigintType; - public PrestoLongType(BigintType bigintType) { + public TrinoLongType(BigintType bigintType) { this.bigintType = bigintType; } diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoMapType.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoMapType.java similarity index 58% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoMapType.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoMapType.java index d11c8189..94d70602 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoMapType.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoMapType.java @@ -3,30 +3,30 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.presto.types; +package com.linkedin.transport.trino.types; import com.linkedin.transport.api.types.StdMapType; import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.presto.PrestoWrapper; -import io.prestosql.spi.type.MapType; +import com.linkedin.transport.trino.TrinoWrapper; +import io.trino.spi.type.MapType; -public class PrestoMapType implements StdMapType { +public class TrinoMapType implements StdMapType { final MapType mapType; - public PrestoMapType(MapType mapType) { + public TrinoMapType(MapType mapType) { this.mapType = mapType; } @Override public StdType keyType() { - return PrestoWrapper.createStdType(mapType.getKeyType()); + return TrinoWrapper.createStdType(mapType.getKeyType()); } @Override public StdType valueType() { - return PrestoWrapper.createStdType(mapType.getKeyType()); + return TrinoWrapper.createStdType(mapType.getValueType()); } @Override diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoStructType.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoRowType.java similarity index 51% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoStructType.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoRowType.java index f94bd051..e4894727 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoStructType.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoRowType.java @@ -3,27 +3,26 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information.
*/ -package com.linkedin.transport.presto.types; +package com.linkedin.transport.trino.types; -import com.linkedin.transport.api.types.StdStructType; +import com.linkedin.transport.api.types.RowType; import com.linkedin.transport.api.types.StdType; -import com.linkedin.transport.presto.PrestoWrapper; -import io.prestosql.spi.type.RowType; +import com.linkedin.transport.trino.TrinoWrapper; import java.util.List; import java.util.stream.Collectors; -public class PrestoStructType implements StdStructType { +public class TrinoRowType implements RowType { - final RowType rowType; + final io.trino.spi.type.RowType rowType; - public PrestoStructType(RowType rowType) { + public TrinoRowType(io.trino.spi.type.RowType rowType) { this.rowType = rowType; } @Override public List fieldTypes() { - return rowType.getFields().stream().map(f -> PrestoWrapper.createStdType(f.getType())).collect(Collectors.toList()); + return rowType.getFields().stream().map(f -> TrinoWrapper.createStdType(f.getType())).collect(Collectors.toList()); } @Override diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoStringType.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoStringType.java similarity index 66% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoStringType.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoStringType.java index 24215f29..262ee736 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoStringType.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoStringType.java @@ -3,17 +3,17 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.presto.types; +package com.linkedin.transport.trino.types; import com.linkedin.transport.api.types.StdStringType; -import io.prestosql.spi.type.VarcharType; +import io.trino.spi.type.VarcharType; -public class PrestoStringType implements StdStringType { +public class TrinoStringType implements StdStringType { final VarcharType varcharType; - public PrestoStringType(VarcharType varcharType) { + public TrinoStringType(VarcharType varcharType) { this.varcharType = varcharType; } diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoUnknownType.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoUnknownType.java similarity index 66% rename from transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoUnknownType.java rename to transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoUnknownType.java index bd43692e..21d22393 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoUnknownType.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/types/TrinoUnknownType.java @@ -3,17 +3,17 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.presto.types; +package com.linkedin.transport.trino.types; import com.linkedin.transport.api.types.StdUnknownType; -import io.prestosql.type.UnknownType; +import io.trino.type.UnknownType; -public class PrestoUnknownType implements StdUnknownType { +public class TrinoUnknownType implements StdUnknownType { final UnknownType unknownType; - public PrestoUnknownType(UnknownType unknownType) { + public TrinoUnknownType(UnknownType unknownType) { this.unknownType = unknownType; } diff --git a/transportable-udfs-presto/src/test/java/com/linkedin/transport/presto/TestGetTypeVariableConstraints.java b/transportable-udfs-trino/src/test/java/com/linkedin/transport/trino/TestGetTypeVariableConstraints.java similarity index 94% rename from transportable-udfs-presto/src/test/java/com/linkedin/transport/presto/TestGetTypeVariableConstraints.java rename to transportable-udfs-trino/src/test/java/com/linkedin/transport/trino/TestGetTypeVariableConstraints.java index bd4f7a0e..6f2b49ef 100644 --- a/transportable-udfs-presto/src/test/java/com/linkedin/transport/presto/TestGetTypeVariableConstraints.java +++ b/transportable-udfs-trino/src/test/java/com/linkedin/transport/trino/TestGetTypeVariableConstraints.java @@ -3,16 +3,16 @@ * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ -package com.linkedin.transport.presto; +package com.linkedin.transport.trino; import com.google.common.collect.ImmutableList; import com.linkedin.transport.api.udf.StdUDF; -import io.prestosql.metadata.TypeVariableConstraint; +import io.trino.metadata.TypeVariableConstraint; import java.util.List; import org.testng.Assert; import org.testng.annotations.Test; -import static io.prestosql.metadata.Signature.*; +import static io.trino.metadata.Signature.*; public class TestGetTypeVariableConstraints { diff --git a/travis-build.sh b/travis-build.sh deleted file mode 100755 index a7444d32..00000000 --- a/travis-build.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/usr/bin/env bash - -# TravisCI calls this script to build and test the Transport code. -# Gradle commands that are specific to the release process are -# called directly from the Travis CI configuration file. -# The rationale for placing these commands in a separate script is -# to make it easier for contributors to run these checks before -# submitting a PR. - -set -e - -cd "$(dirname "$0")" - -./gradlew clean build -s -./gradlew -p transportable-udfs-examples clean build -s diff --git a/version.properties b/version.properties index 1b2be12a..a6c2404e 100644 --- a/version.properties +++ b/version.properties @@ -1,4 +1,3 @@ -#Version of the produced binaries. This file is intended to be checked-in. -#It will be automatically bumped by release automation. -version=0.0.62 -previousVersion=0.0.61 +# Version of the produced binaries. +# The version is inferred by shipkit-auto-version Gradle plugin (https://github.com/shipkit/shipkit-auto-version) +version=0.0.*