diff --git a/client/trino-client/src/main/java/io/trino/client/ClientStandardTypes.java b/client/trino-client/src/main/java/io/trino/client/ClientStandardTypes.java
index 69f2515cdf40..df399df9c7a9 100644
--- a/client/trino-client/src/main/java/io/trino/client/ClientStandardTypes.java
+++ b/client/trino-client/src/main/java/io/trino/client/ClientStandardTypes.java
@@ -47,6 +47,7 @@ public final class ClientStandardTypes
     public static final String MAP = "map";
     public static final String JSON = "json";
     public static final String JSON_2016 = "json2016";
+    public static final String VARIANT = "variant";
     public static final String IPADDRESS = "ipaddress";
     public static final String UUID = "uuid";
     public static final String GEOMETRY = "Geometry";
diff --git a/client/trino-client/src/main/java/io/trino/client/JsonDecodingUtils.java b/client/trino-client/src/main/java/io/trino/client/JsonDecodingUtils.java
index 461d0f2fda51..72f86dc9a6b8 100644
--- a/client/trino-client/src/main/java/io/trino/client/JsonDecodingUtils.java
+++ b/client/trino-client/src/main/java/io/trino/client/JsonDecodingUtils.java
@@ -13,12 +13,14 @@
  */
 package io.trino.client;
 
+import com.fasterxml.jackson.core.JsonGenerator;
 import com.fasterxml.jackson.core.JsonParser;
 import com.fasterxml.jackson.core.JsonToken;
 import com.fasterxml.jackson.databind.json.JsonMapper;
 import com.google.common.collect.ImmutableList;
 
 import java.io.IOException;
+import java.io.StringWriter;
 import java.util.Base64;
 import java.util.HashMap;
 import java.util.LinkedList;
@@ -64,6 +66,8 @@
 import static io.trino.client.ClientStandardTypes.UUID;
 import static io.trino.client.ClientStandardTypes.VARBINARY;
 import static io.trino.client.ClientStandardTypes.VARCHAR;
+import static io.trino.client.ClientStandardTypes.VARIANT;
+import static io.trino.client.JsonIterators.createJsonFactory;
 import static java.lang.String.format;
 import static java.util.Collections.unmodifiableList;
 import static java.util.Collections.unmodifiableMap;
@@ -81,6 +85,7 @@ private JsonDecodingUtils() {}
     private static final RealDecoder REAL_DECODER = new RealDecoder();
     private static final BooleanDecoder BOOLEAN_DECODER = new BooleanDecoder();
     private static final StringDecoder STRING_DECODER = new StringDecoder();
+    private static final VariantDecoder VARIANT_DECODER = new VariantDecoder();
     private static final Base64Decoder BASE_64_DECODER = new Base64Decoder();
     private static final ObjectDecoder OBJECT_DECODER = new ObjectDecoder();
 
@@ -117,6 +122,8 @@ private static TypeDecoder createTypeDecoder(ClientTypeSignature signature)
                 return REAL_DECODER;
             case BOOLEAN:
                 return BOOLEAN_DECODER;
+            case VARIANT:
+                return VARIANT_DECODER;
             case ARRAY:
                 return new ArrayDecoder(signature);
             case MAP:
@@ -287,6 +294,21 @@ public Object decode(JsonParser parser)
         }
     }
 
+    private static class VariantDecoder
+            implements TypeDecoder
+    {
+        @Override
+        public Object decode(JsonParser parser)
+                throws IOException
+        {
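+            // The server sends a variant value as plain JSON; copy the parser's
+            // current structure verbatim so the client sees the raw JSON text.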
+            StringWriter writer = new StringWriter();
+            try (JsonGenerator generator = createJsonFactory().createGenerator(writer)) {
+                generator.copyCurrentStructure(parser);
+            }
+            return writer.toString();
+        }
+    }
+
     private static class Base64Decoder
             implements TypeDecoder
     {
diff --git a/core/trino-main/src/main/java/io/trino/metadata/BlockEncodingManager.java b/core/trino-main/src/main/java/io/trino/metadata/BlockEncodingManager.java
index 52aab50f2e87..0ec8ecab5c6a 100644
--- a/core/trino-main/src/main/java/io/trino/metadata/BlockEncodingManager.java
+++ b/core/trino-main/src/main/java/io/trino/metadata/BlockEncodingManager.java
@@ -30,6 +30,7 @@
 import io.trino.spi.block.RunLengthBlockEncoding;
 import io.trino.spi.block.ShortArrayBlockEncoding;
 import io.trino.spi.block.VariableWidthBlockEncoding;
+import io.trino.spi.block.VariantBlockEncoding;
 
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
@@ -56,6 +57,7 @@ public BlockEncodingManager(BlockEncodingSimdSupport blockEncodingSimdSupport)
         addBlockEncoding(new LongArrayBlockEncoding(simdSupport.vectorizeNullBitPacking(), simdSupport.compressLong(), simdSupport.expandLong()));
         addBlockEncoding(new Fixed12BlockEncoding(simdSupport.vectorizeNullBitPacking()));
         addBlockEncoding(new Int128ArrayBlockEncoding(simdSupport.vectorizeNullBitPacking()));
+        addBlockEncoding(new VariantBlockEncoding(simdSupport.vectorizeNullBitPacking()));
         addBlockEncoding(new DictionaryBlockEncoding());
         addBlockEncoding(new ArrayBlockEncoding(simdSupport.vectorizeNullBitPacking()));
         addBlockEncoding(new MapBlockEncoding(simdSupport.vectorizeNullBitPacking()));
diff --git a/core/trino-main/src/main/java/io/trino/metadata/SignatureBinder.java b/core/trino-main/src/main/java/io/trino/metadata/SignatureBinder.java
index e0508a7621b1..c1b70f4deeae 100644
--- a/core/trino-main/src/main/java/io/trino/metadata/SignatureBinder.java
+++ b/core/trino-main/src/main/java/io/trino/metadata/SignatureBinder.java
@@ -17,8 +17,10 @@
 import com.google.common.base.VerifyException;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
+import io.trino.connector.system.GlobalSystemConnector;
 import io.trino.spi.TrinoException;
 import io.trino.spi.function.BoundSignature;
+import io.trino.spi.function.CatalogSchemaFunctionName;
 import io.trino.spi.function.FunctionId;
 import io.trino.spi.function.LongVariableConstraint;
 import io.trino.spi.function.Signature;
@@ -30,7 +32,6 @@
 import io.trino.spi.type.TypeSignature;
 import io.trino.sql.analyzer.TypeSignatureProvider;
 import io.trino.type.FunctionType;
-import io.trino.type.JsonType;
 import io.trino.type.TypeCoercion;
 import io.trino.type.UnknownType;
 
@@ -47,10 +48,13 @@
 import static com.google.common.base.Verify.verify;
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static com.google.common.collect.ImmutableSortedMap.toImmutableSortedMap;
+import static io.trino.metadata.GlobalFunctionCatalog.BUILTIN_SCHEMA;
+import static io.trino.metadata.OperatorNameUtil.mangleOperatorName;
 import static io.trino.metadata.SignatureBinder.RelationshipType.EXACT;
 import static io.trino.metadata.SignatureBinder.RelationshipType.EXPLICIT_COERCION_FROM;
 import static io.trino.metadata.SignatureBinder.RelationshipType.EXPLICIT_COERCION_TO;
 import static io.trino.metadata.SignatureBinder.RelationshipType.IMPLICIT_COERCION;
+import static io.trino.spi.function.OperatorType.CAST;
 import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
 import static io.trino.type.TypeCalculation.calculateLiteralValue;
 import static io.trino.type.TypeCoercion.isCovariantTypeBase;
@@ -598,7 +602,8 @@ private boolean satisfiesCoercion(RelationshipType relationshipType, Type actual
 
     private boolean canCast(Type fromType, Type toType)
     {
-        if (toType instanceof UnknownType) {
+        // NULL can be cast to any type; avoid re-entering coercion cache.
+        if (fromType instanceof UnknownType || toType instanceof UnknownType) {
             return true;
         }
         if (fromType instanceof RowType) {
@@ -615,14 +620,14 @@ private boolean canCast(Type fromType, Type toType)
                 }
                 return true;
             }
-            if (toType instanceof JsonType) {
+            if (isRecursiveCastFromRow(toType)) {
                 return fromType.getTypeParameters().stream()
                         .allMatch(fromTypeParameter -> canCast(fromTypeParameter, toType));
             }
             return false;
         }
-        if (fromType instanceof JsonType) {
-            if (toType instanceof RowType) {
+        if (toType instanceof RowType) {
+            if (isRecursiveCastToRow(fromType)) {
                 return toType.getTypeParameters().stream()
                         .allMatch(toTypeParameter -> canCast(fromType, toTypeParameter));
             }
@@ -636,6 +641,61 @@ private boolean canCast(Type fromType, Type toType)
         }
     }
 
+    /// Check if there is a recursive variadic CAST from ROW.
+    /// This needs special handling because the cast is applied to each field of ROW individually.
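+    /// For example, the ROW-to-JSON and ROW-to-VARIANT casts match this shape:
+    /// each is declared as a single-argument cast `(T) -> json` / `(T) -> variant` where `T` is a type variable with a variadic `row` bound.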
+    private boolean isRecursiveCastFromRow(Type toType)
+    {
+        return metadata.getFunctions(null, new CatalogSchemaFunctionName(GlobalSystemConnector.NAME, BUILTIN_SCHEMA, mangleOperatorName(CAST))).stream()
+                .map(cast -> cast.functionMetadata().getSignature())
+                .anyMatch(signature -> isRecursiveCastFromRow(toType, signature));
+    }
+
+    private static boolean isRecursiveCastFromRow(Type toType, Signature signature)
+    {
+        // the return type must match toType
+        if (!toType.getTypeSignature().equals(signature.getReturnType())) {
+            return false;
+        }
+
+        // there must be exactly one type variable constraint and no long variable constraints
+        if (signature.getTypeVariableConstraints().size() != 1 || !signature.getLongVariableConstraints().isEmpty()) {
+            return false;
+        }
+        TypeVariableConstraint typeVariableConstraint = signature.getTypeVariableConstraints().getFirst();
+
+        // The argument type must be a type variable with variadic bound of "row"
+        return signature.getArgumentTypes().size() == 1 &&
+                signature.getArgumentTypes().getFirst().getBase().equals(typeVariableConstraint.getName()) &&
+                typeVariableConstraint.isRowType();
+    }
+
+    /// Check if there is a recursive variadic CAST to ROW.
+    /// This needs special handling because the cast is applied to each field of ROW individually.
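+    /// For example, the JSON-to-ROW and VARIANT-to-ROW casts match this shape:
+    /// each is declared as a single-argument cast `(json) -> T` / `(variant) -> T` where `T` is a type variable with a variadic `row` bound.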
+    private boolean isRecursiveCastToRow(Type fromType)
+    {
+        return metadata.getFunctions(null, new CatalogSchemaFunctionName(GlobalSystemConnector.NAME, BUILTIN_SCHEMA, mangleOperatorName(CAST))).stream()
+                .map(cast -> cast.functionMetadata().getSignature())
+                .anyMatch(signature -> isRecursiveCastToRow(fromType, signature));
+    }
+
+    private static boolean isRecursiveCastToRow(Type fromType, Signature signature)
+    {
+        // the argument type must match fromType
+        if (signature.getArgumentTypes().size() != 1 || !fromType.getTypeSignature().equals(signature.getArgumentTypes().getFirst())) {
+            return false;
+        }
+
+        // there must be exactly one type variable constraint and no long variable constraints
+        if (signature.getTypeVariableConstraints().size() != 1 || !signature.getLongVariableConstraints().isEmpty()) {
+            return false;
+        }
+        TypeVariableConstraint typeVariableConstraint = signature.getTypeVariableConstraints().getFirst();
+
+        // The return type must be a type variable with variadic bound of "row"
+        return signature.getReturnType().getBase().equals(typeVariableConstraint.getName()) &&
+                typeVariableConstraint.isRowType();
+    }
+
     private static List<TypeSignature> getLambdaArgumentTypeSignatures(TypeSignature lambdaTypeSignature)
     {
         List<TypeParameter> parameters = lambdaTypeSignature.getParameters();
diff --git a/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java b/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java
index 9ebb7d6acae8..8c8730b7b38f 100644
--- a/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java
+++ b/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java
@@ -277,6 +277,8 @@
 import io.trino.type.TinyintOperators;
 import io.trino.type.UuidOperators;
 import io.trino.type.VarcharOperators;
+import io.trino.type.VariantFunctions;
+import io.trino.type.VariantOperators;
 import io.trino.type.setdigest.BuildSetDigestAggregation;
 import io.trino.type.setdigest.MergeSetDigestAggregation;
 import io.trino.type.setdigest.SetDigestFunctions;
@@ -289,6 +291,7 @@
 import static io.trino.operator.scalar.ArraySubscriptOperator.ARRAY_SUBSCRIPT;
 import static io.trino.operator.scalar.ArrayToElementConcatFunction.ARRAY_TO_ELEMENT_CONCAT_FUNCTION;
 import static io.trino.operator.scalar.ArrayToJsonCast.ARRAY_TO_JSON;
+import static io.trino.operator.scalar.ArrayToVariantCast.ARRAY_TO_VARIANT;
 import static io.trino.operator.scalar.ArrayTransformFunction.ARRAY_TRANSFORM_FUNCTION;
 import static io.trino.operator.scalar.CastFromUnknownOperator.CAST_FROM_UNKNOWN;
 import static io.trino.operator.scalar.ConcatFunction.VARBINARY_CONCAT;
@@ -310,6 +313,7 @@
 import static io.trino.operator.scalar.MapElementAtFunction.MAP_ELEMENT_AT;
 import static io.trino.operator.scalar.MapFilterFunction.MAP_FILTER_FUNCTION;
 import static io.trino.operator.scalar.MapToJsonCast.MAP_TO_JSON;
+import static io.trino.operator.scalar.MapToVariantCast.MAP_TO_VARIANT;
 import static io.trino.operator.scalar.MapTransformValuesFunction.MAP_TRANSFORM_VALUES_FUNCTION;
 import static io.trino.operator.scalar.MapZipWithFunction.MAP_ZIP_WITH_FUNCTION;
 import static io.trino.operator.scalar.MathFunctions.DECIMAL_MOD_FUNCTION;
@@ -317,7 +321,11 @@
 import static io.trino.operator.scalar.Re2JCastToRegexpFunction.castVarcharToRe2JRegexp;
 import static io.trino.operator.scalar.RowToJsonCast.ROW_TO_JSON;
 import static io.trino.operator.scalar.RowToRowCast.ROW_TO_ROW_CAST;
+import static io.trino.operator.scalar.RowToVariantCast.ROW_TO_VARIANT;
 import static io.trino.operator.scalar.TryCastFunction.TRY_CAST;
+import static io.trino.operator.scalar.VariantToArrayCast.VARIANT_TO_ARRAY;
+import static io.trino.operator.scalar.VariantToMapCast.VARIANT_TO_MAP;
+import static io.trino.operator.scalar.VariantToRowCast.VARIANT_TO_ROW;
 import static io.trino.operator.scalar.ZipFunction.ZIP_FUNCTIONS;
 import static io.trino.operator.scalar.ZipWithFunction.ZIP_WITH_FUNCTION;
 import static io.trino.operator.scalar.json.JsonArrayFunction.JSON_ARRAY_FUNCTION;
@@ -334,6 +342,7 @@
 import static io.trino.type.DecimalCasts.DECIMAL_TO_SMALLINT_CAST;
 import static io.trino.type.DecimalCasts.DECIMAL_TO_TINYINT_CAST;
 import static io.trino.type.DecimalCasts.DECIMAL_TO_VARCHAR_CAST;
+import static io.trino.type.DecimalCasts.DECIMAL_TO_VARIANT_CAST;
 import static io.trino.type.DecimalCasts.DOUBLE_TO_DECIMAL_CAST;
 import static io.trino.type.DecimalCasts.INTEGER_TO_DECIMAL_CAST;
 import static io.trino.type.DecimalCasts.JSON_TO_DECIMAL_CAST;
@@ -342,6 +351,7 @@
 import static io.trino.type.DecimalCasts.SMALLINT_TO_DECIMAL_CAST;
 import static io.trino.type.DecimalCasts.TINYINT_TO_DECIMAL_CAST;
 import static io.trino.type.DecimalCasts.VARCHAR_TO_DECIMAL_CAST;
+import static io.trino.type.DecimalCasts.VARIANT_TO_DECIMAL_CAST;
 import static io.trino.type.DecimalOperators.DECIMAL_ADD_OPERATOR;
 import static io.trino.type.DecimalOperators.DECIMAL_DIVIDE_OPERATOR;
 import static io.trino.type.DecimalOperators.DECIMAL_MODULUS_OPERATOR;
@@ -463,6 +473,17 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
                 .scalars(JsonInputFunctions.class)
                 .scalars(JsonOutputFunctions.class)
                 .functions(JSON_OBJECT_FUNCTION, JSON_ARRAY_FUNCTION)
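+                // VARIANT support: scalar operators, temporal casts, structural-type casts, and helper functions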
+                .scalars(VariantOperators.class)
+                .scalar(VariantOperators.VariantToTimeCast.class)
+                .scalar(VariantOperators.VariantFromTimeCast.class)
+                .scalar(VariantOperators.VariantToTimestampCast.class)
+                .scalar(VariantOperators.VariantFromTimestampCast.class)
+                .scalar(VariantOperators.VariantToTimestampWithTimeZoneCasts.class)
+                .scalar(VariantOperators.VariantFromTimestampWithTimeZoneCasts.class)
+                .functions(VARIANT_TO_ARRAY, ARRAY_TO_VARIANT)
+                .functions(VARIANT_TO_MAP, MAP_TO_VARIANT)
+                .functions(VARIANT_TO_ROW, ROW_TO_VARIANT)
+                .scalars(VariantFunctions.class)
                 .scalars(ColorFunctions.class)
                 .scalars(HyperLogLogFunctions.class)
                 .scalars(QuantileDigestFunctions.class)
@@ -561,6 +582,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
                 .functions(VARCHAR_TO_DECIMAL_CAST, INTEGER_TO_DECIMAL_CAST, BIGINT_TO_DECIMAL_CAST, DOUBLE_TO_DECIMAL_CAST, REAL_TO_DECIMAL_CAST, BOOLEAN_TO_DECIMAL_CAST, TINYINT_TO_DECIMAL_CAST, SMALLINT_TO_DECIMAL_CAST)
                 .functions(NUMBER_TO_DECIMAL_CAST, DECIMAL_TO_NUMBER_CAST)
                 .functions(JSON_TO_DECIMAL_CAST, DECIMAL_TO_JSON_CAST)
+                .functions(VARIANT_TO_DECIMAL_CAST, DECIMAL_TO_VARIANT_CAST)
                 .functions(featuresConfig.isLegacyArithmeticDecimalOperators() ? LEGACY_DECIMAL_ADD_OPERATOR : DECIMAL_ADD_OPERATOR)
                 .functions(featuresConfig.isLegacyArithmeticDecimalOperators() ? LEGACY_DECIMAL_SUBTRACT_OPERATOR : DECIMAL_SUBTRACT_OPERATOR)
                 .functions(featuresConfig.isLegacyArithmeticDecimalOperators() ? LEGACY_DECIMAL_MULTIPLY_OPERATOR : DECIMAL_MULTIPLY_OPERATOR)
diff --git a/core/trino-main/src/main/java/io/trino/metadata/TypeRegistry.java b/core/trino-main/src/main/java/io/trino/metadata/TypeRegistry.java
index ec66af9968ec..df2a2a921629 100644
--- a/core/trino-main/src/main/java/io/trino/metadata/TypeRegistry.java
+++ b/core/trino-main/src/main/java/io/trino/metadata/TypeRegistry.java
@@ -85,6 +85,7 @@
 import static io.trino.spi.type.TinyintType.TINYINT;
 import static io.trino.spi.type.UuidType.UUID;
 import static io.trino.spi.type.VarbinaryType.VARBINARY;
+import static io.trino.spi.type.VariantType.VARIANT;
 import static io.trino.sql.analyzer.TypeSignatureTranslator.toTypeSignature;
 import static io.trino.type.ArrayParametricType.ARRAY;
 import static io.trino.type.CodePointsType.CODE_POINTS;
@@ -152,6 +153,7 @@ public TypeRegistry(TypeOperators typeOperators, FeaturesConfig featuresConfig)
         addType(JSON_2016);
         addType(COLOR);
         addType(JSON);
+        addType(VARIANT);
         addType(CODE_POINTS);
         addType(IPADDRESS);
         addType(UUID);
diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayToVariantCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayToVariantCast.java
new file mode 100644
index 000000000000..5b0967bc8e56
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayToVariantCast.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.operator.scalar;
+
+import com.google.common.collect.ImmutableList;
+import io.trino.metadata.SqlScalarFunction;
+import io.trino.spi.block.Block;
+import io.trino.spi.function.BoundSignature;
+import io.trino.spi.function.FunctionMetadata;
+import io.trino.spi.function.Signature;
+import io.trino.spi.type.ArrayType;
+import io.trino.spi.type.TypeSignature;
+import io.trino.spi.variant.Variant;
+import io.trino.util.variant.VariantWriter;
+
+import java.lang.invoke.MethodHandle;
+import java.lang.invoke.MethodHandles;
+
+import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
+import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
+import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
+import static io.trino.spi.function.OperatorType.CAST;
+import static io.trino.spi.type.TypeSignature.arrayType;
+import static io.trino.spi.type.VariantType.VARIANT;
+import static io.trino.util.Failures.checkCondition;
+import static io.trino.util.variant.VariantUtil.canCastToVariant;
+import static java.lang.invoke.MethodType.methodType;
+
+public class ArrayToVariantCast
+        extends SqlScalarFunction
+{
+    public static final ArrayToVariantCast ARRAY_TO_VARIANT = new ArrayToVariantCast();
+
+    private static final MethodHandle METHOD_HANDLE;
+
+    static {
+        try {
+            METHOD_HANDLE = MethodHandles.lookup().findStatic(ArrayToVariantCast.class, "toVariant", methodType(Variant.class, VariantWriter.class, Block.class));
+        }
+        catch (IllegalAccessException | NoSuchMethodException e) {
+            throw new ExceptionInInitializerError(e);
+        }
+    }
+
+    private ArrayToVariantCast()
+    {
+        super(FunctionMetadata.operatorBuilder(CAST)
+                .signature(Signature.builder()
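+                        // every element type T must itself be castable to VARIANT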
+                        .castableToTypeParameter("T", VARIANT.getTypeSignature())
+                        .returnType(VARIANT)
+                        .argumentType(arrayType(new TypeSignature("T")))
+                        .build())
+                .build());
+    }
+
+    @Override
+    protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature)
+    {
+        ArrayType arrayType = (ArrayType) boundSignature.getArgumentTypes().getFirst();
+        checkCondition(canCastToVariant(arrayType), INVALID_CAST_ARGUMENT, "Cannot cast %s to VARIANT", arrayType);
+
+        VariantWriter writer = VariantWriter.create(arrayType);
+        MethodHandle methodHandle = METHOD_HANDLE.bindTo(writer);
+        return new ChoicesSpecializedSqlScalarFunction(
+                boundSignature,
+                FAIL_ON_NULL,
+                ImmutableList.of(NEVER_NULL),
+                methodHandle);
+    }
+
+    private static Variant toVariant(VariantWriter writer, Block block)
+    {
+        return writer.write(block);
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/JsonToRowCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/JsonToRowCast.java
index 65f925295d64..809b90f01283 100644
--- a/core/trino-main/src/main/java/io/trino/operator/scalar/JsonToRowCast.java
+++ b/core/trino-main/src/main/java/io/trino/operator/scalar/JsonToRowCast.java
@@ -66,7 +66,6 @@ private JsonToRowCast()
                         // this is technically a recursive constraint for cast, but TypeRegistry.canCast has explicit handling for json to row cast
                         TypeVariableConstraint.builder("T")
                                 .rowType()
-                                .castableFrom(JSON)
                                 .build())
                 .returnType(new TypeSignature("T"))
                 .argumentType(JSON)
diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapToVariantCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapToVariantCast.java
new file mode 100644
index 000000000000..5742372c32cd
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapToVariantCast.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.operator.scalar;
+
+import com.google.common.collect.ImmutableList;
+import io.trino.metadata.SqlScalarFunction;
+import io.trino.spi.block.SqlMap;
+import io.trino.spi.function.BoundSignature;
+import io.trino.spi.function.FunctionMetadata;
+import io.trino.spi.function.Signature;
+import io.trino.spi.type.MapType;
+import io.trino.spi.type.TypeSignature;
+import io.trino.spi.variant.Variant;
+import io.trino.util.variant.VariantWriter;
+
+import java.lang.invoke.MethodHandle;
+import java.lang.invoke.MethodHandles;
+
+import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
+import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
+import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
+import static io.trino.spi.function.OperatorType.CAST;
+import static io.trino.spi.type.TypeParameter.typeVariable;
+import static io.trino.spi.type.TypeSignature.mapType;
+import static io.trino.spi.type.VariantType.VARIANT;
+import static io.trino.util.Failures.checkCondition;
+import static io.trino.util.variant.VariantUtil.canCastToVariant;
+import static java.lang.invoke.MethodType.methodType;
+
+public class MapToVariantCast
+        extends SqlScalarFunction
+{
+    public static final MapToVariantCast MAP_TO_VARIANT = new MapToVariantCast();
+    private static final MethodHandle METHOD_HANDLE;
+
+    static {
+        try {
+            METHOD_HANDLE = MethodHandles.lookup().findStatic(MapToVariantCast.class, "toVariant", methodType(Variant.class, VariantWriter.class, SqlMap.class));
+        }
+        catch (IllegalAccessException | NoSuchMethodException e) {
+            throw new ExceptionInInitializerError(e);
+        }
+    }
+
+    private MapToVariantCast()
+    {
+        super(FunctionMetadata.operatorBuilder(CAST)
+                .signature(Signature.builder()
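+                        // keys must be varchar; every value type V must be castable to VARIANT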
+                        .longVariable("N")
+                        .castableToTypeParameter("V", VARIANT.getTypeSignature())
+                        .returnType(VARIANT)
+                        .argumentType(mapType(new TypeSignature("varchar", typeVariable("N")), new TypeSignature("V")))
+                        .build())
+                .build());
+    }
+
+    @Override
+    public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature)
+    {
+        MapType mapType = (MapType) boundSignature.getArgumentType(0);
+        checkCondition(canCastToVariant(mapType), INVALID_CAST_ARGUMENT, "Cannot cast %s to VARIANT", mapType);
+
+        VariantWriter writer = VariantWriter.create(mapType);
+        MethodHandle methodHandle = METHOD_HANDLE.bindTo(writer);
+
+        return new ChoicesSpecializedSqlScalarFunction(
+                boundSignature,
+                FAIL_ON_NULL,
+                ImmutableList.of(NEVER_NULL),
+                methodHandle);
+    }
+
+    private static Variant toVariant(VariantWriter writer, SqlMap map)
+    {
+        return writer.write(map);
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/RowToJsonCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/RowToJsonCast.java
index ff8e4e312255..303558aed01c 100644
--- a/core/trino-main/src/main/java/io/trino/operator/scalar/RowToJsonCast.java
+++ b/core/trino-main/src/main/java/io/trino/operator/scalar/RowToJsonCast.java
@@ -62,7 +62,6 @@ private RowToJsonCast()
                         // this is technically a recursive constraint for cast, but TypeRegistry.canCast has explicit handling for row to json cast
                         TypeVariableConstraint.builder("T")
                                 .rowType()
-                                .castableTo(JSON)
                                 .build())
                 .returnType(JSON)
                 .argumentType(new TypeSignature("T"))
diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/RowToVariantCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/RowToVariantCast.java
new file mode 100644
index 000000000000..c574e1b1c171
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/operator/scalar/RowToVariantCast.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.operator.scalar;
+
+import com.google.common.collect.ImmutableList;
+import io.trino.annotation.UsedByGeneratedCode;
+import io.trino.metadata.SqlScalarFunction;
+import io.trino.spi.block.SqlRow;
+import io.trino.spi.function.BoundSignature;
+import io.trino.spi.function.FunctionMetadata;
+import io.trino.spi.function.Signature;
+import io.trino.spi.function.TypeVariableConstraint;
+import io.trino.spi.type.Type;
+import io.trino.spi.type.TypeSignature;
+import io.trino.spi.variant.Variant;
+import io.trino.util.variant.VariantWriter;
+
+import java.lang.invoke.MethodHandle;
+import java.lang.invoke.MethodHandles;
+
+import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
+import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
+import static io.trino.spi.function.OperatorType.CAST;
+import static io.trino.spi.type.VariantType.VARIANT;
+import static java.lang.invoke.MethodType.methodType;
+
+public class RowToVariantCast
+        extends SqlScalarFunction
+{
+    public static final RowToVariantCast ROW_TO_VARIANT = new RowToVariantCast();
+
+    private static final MethodHandle METHOD_HANDLE;
+
+    static {
+        try {
+            METHOD_HANDLE = MethodHandles.lookup().findStatic(RowToVariantCast.class, "toVariant", methodType(Variant.class, VariantWriter.class, SqlRow.class));
+        }
+        catch (IllegalAccessException | NoSuchMethodException e) {
+            throw new ExceptionInInitializerError(e);
+        }
+    }
+
+    private RowToVariantCast()
+    {
+        super(FunctionMetadata.operatorBuilder(CAST)
+                .signature(Signature.builder()
+                        .typeVariableConstraint(
+                                // this is technically a recursive constraint for cast, but TypeRegistry.canCast has explicit handling for row to variant cast
+                                TypeVariableConstraint.builder("T")
+                                        .rowType()
+                                        .build())
+                        .returnType(VARIANT)
+                        .argumentType(new TypeSignature("T"))
+                        .build())
+                .build());
+    }
+
+    @Override
+    protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature)
+    {
+        Type type = boundSignature.getArgumentType(0);
+        MethodHandle methodHandle = METHOD_HANDLE.bindTo(VariantWriter.create(type));
+        return new ChoicesSpecializedSqlScalarFunction(
+                boundSignature,
+                FAIL_ON_NULL,
+                ImmutableList.of(NEVER_NULL),
+                methodHandle);
+    }
+
+    @UsedByGeneratedCode
+    public static Variant toVariant(VariantWriter variantWriter, SqlRow sqlRow)
+    {
+        return variantWriter.write(sqlRow);
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/VariantToArrayCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/VariantToArrayCast.java
new file mode 100644
index 000000000000..448fbfb0b826
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/operator/scalar/VariantToArrayCast.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.operator.scalar;
+
+import com.google.common.collect.ImmutableList;
+import io.trino.annotation.UsedByGeneratedCode;
+import io.trino.metadata.SqlScalarFunction;
+import io.trino.spi.block.Block;
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.function.BoundSignature;
+import io.trino.spi.function.FunctionMetadata;
+import io.trino.spi.function.Signature;
+import io.trino.spi.type.ArrayType;
+import io.trino.spi.type.Type;
+import io.trino.spi.type.TypeSignature;
+import io.trino.spi.variant.Variant;
+import io.trino.util.variant.VariantUtil.BlockBuilderAppender;
+
+import java.lang.invoke.MethodHandle;
+import java.lang.invoke.MethodHandles;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
+import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
+import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
+import static io.trino.spi.function.OperatorType.CAST;
+import static io.trino.spi.type.TypeSignature.arrayType;
+import static io.trino.spi.type.VariantType.VARIANT;
+import static io.trino.util.Failures.checkCondition;
+import static io.trino.util.variant.VariantUtil.canCastFromVariant;
+import static java.lang.invoke.MethodType.methodType;
+
+public class VariantToArrayCast
+        extends SqlScalarFunction
+{
+    public static final VariantToArrayCast VARIANT_TO_ARRAY = new VariantToArrayCast();
+
+    private static final MethodHandle METHOD_HANDLE;
+
+    static {
+        try {
+            METHOD_HANDLE = MethodHandles.lookup().findStatic(VariantToArrayCast.class, "toArray", methodType(Block.class, Type.class, BlockBuilderAppender.class, ConnectorSession.class, Variant.class));
+        }
+        catch (IllegalAccessException | NoSuchMethodException e) {
+            throw new ExceptionInInitializerError(e);
+        }
+    }
+
+    private VariantToArrayCast()
+    {
+        super(FunctionMetadata.operatorBuilder(CAST)
+                .signature(Signature.builder()
+                        .castableFromTypeParameter("T", VARIANT.getTypeSignature())
+                        .returnType(arrayType(new TypeSignature("T")))
+                        .argumentType(VARIANT)
+                        .build())
+                .nullable()
+                .build());
+    }
+
+    @Override
+    protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature)
+    {
+        checkArgument(boundSignature.getArity() == 1, "Expected arity to be 1");
+        ArrayType arrayType = (ArrayType) boundSignature.getReturnType();
+        checkCondition(canCastFromVariant(arrayType), INVALID_CAST_ARGUMENT, "Cannot cast VARIANT to %s", arrayType);
+
+        Type elementType = arrayType.getElementType();
+        BlockBuilderAppender arrayAppender = BlockBuilderAppender.createBlockBuilderAppender(elementType);
+        MethodHandle methodHandle = METHOD_HANDLE.bindTo(elementType).bindTo(arrayAppender);
+        return new ChoicesSpecializedSqlScalarFunction(
+                boundSignature,
+                NULLABLE_RETURN,
+                ImmutableList.of(NEVER_NULL),
+                methodHandle);
+    }
+
+    @UsedByGeneratedCode
+    public static Block toArray(Type elementType, BlockBuilderAppender elementAppender, ConnectorSession connectorSession, Variant variant)
+    {
+        if (variant.isNull()) {
+            return null;
+        }
+
+        BlockBuilder blockBuilder = elementType.createBlockBuilder(null, variant.getArrayLength());
+        variant.arrayElements().forEach(element -> elementAppender.append(element, blockBuilder));
+        return blockBuilder.build();
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/VariantToMapCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/VariantToMapCast.java
new file mode 100644
index 000000000000..019f4898492d
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/operator/scalar/VariantToMapCast.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.operator.scalar;
+
+import com.google.common.collect.ImmutableList;
+import io.trino.annotation.UsedByGeneratedCode;
+import io.trino.metadata.SqlScalarFunction;
+import io.trino.spi.block.Block;
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.block.SqlMap;
+import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.function.BoundSignature;
+import io.trino.spi.function.FunctionMetadata;
+import io.trino.spi.function.Signature;
+import io.trino.spi.type.MapType;
+import io.trino.spi.type.TypeSignature;
+import io.trino.spi.variant.Variant;
+import io.trino.util.variant.VariantUtil.BlockBuilderAppender;
+
+import java.lang.invoke.MethodHandle;
+import java.lang.invoke.MethodHandles;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
+import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
+import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
+import static io.trino.spi.function.OperatorType.CAST;
+import static io.trino.spi.type.TypeSignature.mapType;
+import static io.trino.spi.type.VarcharType.VARCHAR;
+import static io.trino.spi.type.VariantType.VARIANT;
+import static io.trino.util.Failures.checkCondition;
+import static io.trino.util.variant.VariantUtil.canCastFromVariant;
+import static java.lang.invoke.MethodType.methodType;
+
+public class VariantToMapCast
+        extends SqlScalarFunction
+{
+    public static final VariantToMapCast VARIANT_TO_MAP = new VariantToMapCast();
+    private static final MethodHandle METHOD_HANDLE;
+
+    static {
+        try {
+            METHOD_HANDLE = MethodHandles.lookup().findStatic(VariantToMapCast.class, "toMap", methodType(SqlMap.class, MapType.class, BlockBuilderAppender.class, ConnectorSession.class, Variant.class));
+        }
+        catch (IllegalAccessException | NoSuchMethodException e) {
+            throw new ExceptionInInitializerError(e);
+        }
+    }
+
+    private VariantToMapCast()
+    {
+        super(FunctionMetadata.operatorBuilder(CAST)
+                .signature(Signature.builder()
+                        .castableFromTypeParameter("K", VARCHAR.getTypeSignature())
+                        .castableFromTypeParameter("V", VARIANT.getTypeSignature())
+                        .returnType(mapType(new TypeSignature("K"), new TypeSignature("V")))
+                        .argumentType(VARIANT)
+                        .build())
+                .nullable()
+                .build());
+    }
+
+    @Override
+    protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature)
+    {
+        checkArgument(boundSignature.getArity() == 1, "Expected arity to be 1");
+        MapType mapType = (MapType) boundSignature.getReturnType();
+        checkCondition(canCastFromVariant(mapType), INVALID_CAST_ARGUMENT, "Cannot cast VARIANT to %s", mapType);
+
+        BlockBuilderAppender mapAppender = BlockBuilderAppender.createBlockBuilderAppender(mapType);
+        MethodHandle methodHandle = METHOD_HANDLE.bindTo(mapType).bindTo(mapAppender);
+        return new ChoicesSpecializedSqlScalarFunction(
+                boundSignature,
+                NULLABLE_RETURN,
+                ImmutableList.of(NEVER_NULL),
+                methodHandle);
+    }
+
+    @UsedByGeneratedCode
+    public static SqlMap toMap(MapType mapType, BlockBuilderAppender mapAppender, ConnectorSession connectorSession, Variant variant)
+    {
+        if (variant.isNull()) {
+            return null;
+        }
+
+        BlockBuilder blockBuilder = mapType.createBlockBuilder(null, 1);
+        mapAppender.append(variant, blockBuilder);
+        Block block = blockBuilder.build();
+        return mapType.getObject(block, 0);
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/VariantToRowCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/VariantToRowCast.java
new file mode 100644
index 000000000000..9aea9185112d
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/operator/scalar/VariantToRowCast.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.operator.scalar;
+
+import com.google.common.collect.ImmutableList;
+import io.trino.annotation.UsedByGeneratedCode;
+import io.trino.metadata.SqlScalarFunction;
+import io.trino.spi.block.Block;
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.block.SqlRow;
+import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.function.BoundSignature;
+import io.trino.spi.function.FunctionMetadata;
+import io.trino.spi.function.Signature;
+import io.trino.spi.function.TypeVariableConstraint;
+import io.trino.spi.type.RowType;
+import io.trino.spi.type.TypeSignature;
+import io.trino.spi.variant.Variant;
+import io.trino.util.variant.VariantUtil.BlockBuilderAppender;
+
+import java.lang.invoke.MethodHandle;
+import java.lang.invoke.MethodHandles;
+
+import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
+import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
+import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
+import static io.trino.spi.function.OperatorType.CAST;
+import static io.trino.spi.type.VariantType.VARIANT;
+import static io.trino.util.Failures.checkCondition;
+import static io.trino.util.variant.VariantUtil.BlockBuilderAppender.createBlockBuilderAppender;
+import static io.trino.util.variant.VariantUtil.canCastFromVariant;
+import static java.lang.invoke.MethodType.methodType;
+
+public class VariantToRowCast
+        extends SqlScalarFunction
+{
+    public static final VariantToRowCast VARIANT_TO_ROW = new VariantToRowCast();
+    private static final MethodHandle METHOD_HANDLE;
+
+    static {
+        try {
+            METHOD_HANDLE = MethodHandles.lookup().findStatic(VariantToRowCast.class, "toRow", methodType(SqlRow.class, RowType.class, BlockBuilderAppender.class, ConnectorSession.class, Variant.class));
+        }
+        catch (IllegalAccessException | NoSuchMethodException e) {
+            throw new ExceptionInInitializerError(e);
+        }
+    }
+
+    private VariantToRowCast()
+    {
+        super(FunctionMetadata.operatorBuilder(CAST)
+                .signature(Signature.builder()
+                        .typeVariableConstraint(
+                                // this is technically a recursive constraint for cast, but TypeRegistry.canCast has explicit handling for variant to row cast
+                                TypeVariableConstraint.builder("T")
+                                        .rowType()
+                                        .build())
+                        .returnType(new TypeSignature("T"))
+                        .argumentType(VARIANT)
+                        .build())
+                .nullable()
+                .build());
+    }
+
+    @Override
+    protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature)
+    {
+        RowType rowType = (RowType) boundSignature.getReturnType();
+        checkCondition(canCastFromVariant(rowType), INVALID_CAST_ARGUMENT, "Cannot cast VARIANT to %s", rowType);
+
+        BlockBuilderAppender fieldAppender = createBlockBuilderAppender(rowType);
+        MethodHandle methodHandle = METHOD_HANDLE.bindTo(rowType).bindTo(fieldAppender);
+        return new ChoicesSpecializedSqlScalarFunction(
+                boundSignature,
+                NULLABLE_RETURN,
+                ImmutableList.of(NEVER_NULL),
+                methodHandle);
+    }
+
+    @UsedByGeneratedCode
+    public static SqlRow toRow(RowType rowType, BlockBuilderAppender rowAppender, ConnectorSession connectorSession, Variant variant)
+    {
+        if (variant.isNull()) {
+            return null;
+        }
+
+        BlockBuilder blockBuilder = rowType.createBlockBuilder(null, 1);
+        rowAppender.append(variant, blockBuilder);
+        Block block = blockBuilder.build();
+        return rowType.getObject(block, 0);
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/OperatorValidator.java b/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/OperatorValidator.java
index 105b7ba1aaa9..f7852f8e02a5 100644
--- a/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/OperatorValidator.java
+++ b/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/OperatorValidator.java
@@ -53,15 +53,18 @@ public static void validateOperator(OperatorType operatorType, TypeSignature ret
                 break;
             case SUBSCRIPT:
                 validateOperatorSignature(operatorType, returnType, argumentTypes, 2);
-                checkArgument(argumentTypes.get(0).getBase().equals(StandardTypes.ARRAY) || argumentTypes.get(0).getBase().equals(StandardTypes.MAP), "First argument must be an ARRAY or MAP");
-                if (argumentTypes.get(0).getBase().equals(StandardTypes.ARRAY)) {
-                    checkArgument(argumentTypes.get(1).getBase().equals(StandardTypes.BIGINT), "Second argument must be a BIGINT");
-                    TypeSignature elementType = ((TypeParameter.Type) argumentTypes.get(0).getParameters().get(0)).type();
-                    checkArgument(returnType.equals(elementType), "[] return type does not match ARRAY element type");
-                }
-                else {
-                    TypeSignature valueType = ((TypeParameter.Type) argumentTypes.get(0).getParameters().get(1)).type();
-                    checkArgument(returnType.equals(valueType), "[] return type does not match MAP value type");
+                checkArgument(argumentTypes.get(0).getBase().equals(StandardTypes.ARRAY) || argumentTypes.get(0).getBase().equals(StandardTypes.MAP) || argumentTypes.get(0).getBase().equals(StandardTypes.VARIANT), "First argument must be an ARRAY, MAP, or VARIANT");
+                switch (argumentTypes.get(0).getBase()) {
+                    case StandardTypes.ARRAY -> {
+                        checkArgument(argumentTypes.get(1).getBase().equals(StandardTypes.BIGINT), "Second argument must be a BIGINT");
+                        TypeSignature elementType = ((TypeParameter.Type) argumentTypes.get(0).getParameters().get(0)).type();
+                        checkArgument(returnType.equals(elementType), "[] return type does not match ARRAY element type");
+                    }
+                    case StandardTypes.MAP -> {
+                        TypeSignature valueType = ((TypeParameter.Type) argumentTypes.get(0).getParameters().get(1)).type();
+                        checkArgument(returnType.equals(valueType), "[] return type does not match MAP value type");
+                    }
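+                    // a subscript on VARIANT returns VARIANT, so there is no element type to validate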
+                    default -> {}
                 }
                 break;
             case HASH_CODE:
diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/timestamp/VarcharToTimestampCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/timestamp/VarcharToTimestampCast.java
index 505e4a250d6c..489f3754e815 100644
--- a/core/trino-main/src/main/java/io/trino/operator/scalar/timestamp/VarcharToTimestampCast.java
+++ b/core/trino-main/src/main/java/io/trino/operator/scalar/timestamp/VarcharToTimestampCast.java
@@ -13,7 +13,6 @@
  */
 package io.trino.operator.scalar.timestamp;
 
-import com.google.common.annotations.VisibleForTesting;
 import io.airlift.slice.Slice;
 import io.trino.spi.TrinoException;
 import io.trino.spi.function.LiteralParameter;
@@ -69,7 +68,6 @@ public static LongTimestamp castToLong(@LiteralParameter("p") long precision, @S
         }
     }
 
-    @VisibleForTesting
     public static long castToShortTimestamp(int precision, String value)
     {
         checkArgument(precision <= MAX_SHORT_PRECISION, "precision must be less than max short timestamp precision");
@@ -125,7 +123,6 @@ public static long castToShortTimestamp(int precision, String value)
         return epochSecond * MICROSECONDS_PER_SECOND + rescale(fractionValue, actualPrecision, 6);
     }
 
-    @VisibleForTesting
     public static LongTimestamp castToLongTimestamp(int precision, String value)
     {
         checkArgument(precision > MAX_SHORT_PRECISION && precision <= MAX_PRECISION, "precision out of range");
diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/timestamptz/VarcharToTimestampWithTimeZoneCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/timestamptz/VarcharToTimestampWithTimeZoneCast.java
index f9ececb433b5..237127ccfec6 100644
--- a/core/trino-main/src/main/java/io/trino/operator/scalar/timestamptz/VarcharToTimestampWithTimeZoneCast.java
+++ b/core/trino-main/src/main/java/io/trino/operator/scalar/timestamptz/VarcharToTimestampWithTimeZoneCast.java
@@ -81,7 +81,7 @@ public static LongTimestampWithTimeZone castToLong(@LiteralParameter("p") long p
         }
     }
 
-    private static long toShort(int precision, String value, Function<String, ZoneId> zoneId)
+    public static long toShort(int precision, String value, Function<String, ZoneId> zoneId)
     {
         checkArgument(precision <= MAX_SHORT_PRECISION, "precision must be less than max short timestamp precision");
 
@@ -134,7 +134,7 @@ private static long toShort(int precision, String value, Function zoneId)
         }
     }
 
-    private static LongTimestampWithTimeZone toLong(int precision, String value, Function<String, ZoneId> zoneId)
+    public static LongTimestampWithTimeZone toLong(int precision, String value, Function<String, ZoneId> zoneId)
     {
         checkArgument(precision > MAX_SHORT_PRECISION && precision <= MAX_PRECISION, "precision out of range");
diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/JsonEncodingUtils.java b/core/trino-main/src/main/java/io/trino/server/protocol/JsonEncodingUtils.java
index 800efd3f28ae..64b23b88bd4f 100644
--- a/core/trino-main/src/main/java/io/trino/server/protocol/JsonEncodingUtils.java
+++ b/core/trino-main/src/main/java/io/trino/server/protocol/JsonEncodingUtils.java
@@ -44,8 +44,11 @@
 import io.trino.spi.type.Type;
 import io.trino.spi.type.VarbinaryType;
 import io.trino.spi.type.VarcharType;
+import io.trino.spi.type.VariantType;
+import io.trino.spi.variant.Variant;
 import io.trino.type.SqlIntervalDayTime;
 import io.trino.type.SqlIntervalYearMonth;
+import io.trino.util.variant.VariantUtil;
 
 import java.io.IOException;
 import java.math.BigDecimal;
@@ -64,6 +67,7 @@
 import static io.trino.spi.type.TinyintType.TINYINT;
 import static io.trino.spi.type.VarbinaryType.VARBINARY;
 import static io.trino.spi.type.VarcharType.VARCHAR;
+import static io.trino.spi.type.VariantType.VARIANT;
 import static java.util.Objects.requireNonNull;
 
 public final class JsonEncodingUtils
@@ -79,6 +83,7 @@ private JsonEncodingUtils() {}
     private static final TinyintEncoder TINYINT_ENCODER = new TinyintEncoder();
     private static final VarcharEncoder VARCHAR_ENCODER = new VarcharEncoder();
     private static final VarbinaryEncoder VARBINARY_ENCODER = new VarbinaryEncoder();
+    private static final VariantEncoder VARIANT_ENCODER = new VariantEncoder();
 
     public static TypeEncoder[] createTypeEncoders(Session session, List<Type> types)
     {
@@ -106,6 +111,7 @@ public static TypeEncoder createTypeEncoder(Type type, boolean supportsParametri
             case VarcharType _ -> VARCHAR_ENCODER;
             case VarbinaryType _ -> VARBINARY_ENCODER;
             case CharType charType -> new CharEncoder(charType.getLength());
+            case VariantType _ -> VARIANT_ENCODER;
             // TODO: add specialized Short/Long decimal encoders
             case ArrayType arrayType -> new ArrayEncoder(arrayType, createTypeEncoder(arrayType.getElementType(), supportsParametricDateTime));
             case MapType mapType -> new MapEncoder(mapType, createTypeEncoder(mapType.getValueType(), supportsParametricDateTime));
@@ -312,6 +318,23 @@ public void encode(JsonGenerator generator, Block block, int position)
         }
     }
 
+    private static final class VariantEncoder
+            implements TypeEncoder
+    {
+        @Override
+        public void encode(JsonGenerator generator, Block block, int position)
+                throws IOException
+        {
+            if (block.isNull(position)) {
+                generator.writeNull();
+                return;
+            }
+
+            Variant variant = VARIANT.getObject(block, position);
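+            // serialize the variant to JSON text and embed it as a raw value, without re-escaping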
+            generator.writeRawValue(VariantUtil.asJson(variant).toStringUtf8());
+        }
+    }
+
     private static final class ArrayEncoder
             implements TypeEncoder
     {
diff --git a/core/trino-main/src/main/java/io/trino/type/DecimalCasts.java b/core/trino-main/src/main/java/io/trino/type/DecimalCasts.java
index 8cfec7fb3518..22cd18d135e9 100644
--- a/core/trino-main/src/main/java/io/trino/type/DecimalCasts.java
+++ b/core/trino-main/src/main/java/io/trino/type/DecimalCasts.java
@@ -35,7 +35,9 @@
 import io.trino.spi.type.TrinoNumber;
 import io.trino.spi.type.TypeSignature;
 import io.trino.spi.type.VarcharType;
+import io.trino.spi.variant.Variant;
 import io.trino.util.JsonCastException;
+import io.trino.util.variant.VariantUtil;
 
 import java.io.IOException;
 import java.math.BigDecimal;
@@ -62,6 +64,7 @@
 import static io.trino.spi.type.TinyintType.TINYINT;
 import static io.trino.spi.type.TypeParameter.typeVariable;
 import static io.trino.spi.type.VarcharType.UNBOUNDED_LENGTH;
+import static io.trino.spi.type.VariantType.VARIANT;
 import static io.trino.type.JsonType.JSON;
 import static io.trino.util.Failures.checkCondition;
 import static io.trino.util.JsonUtil.createJsonFactory;
@@ -102,6 +105,8 @@ public final class DecimalCasts
     // that never happens for SliceOutput.
     public static final SqlScalarFunction DECIMAL_TO_JSON_CAST = castFunctionFromDecimalTo(JSON.getTypeSignature(), true, "shortDecimalToJson", "longDecimalToJson");
     public static final SqlScalarFunction JSON_TO_DECIMAL_CAST = castFunctionToDecimalFromBuilder(JSON.getTypeSignature(), true, false, "jsonToShortDecimal", "jsonToLongDecimal");
+    public static final SqlScalarFunction DECIMAL_TO_VARIANT_CAST = castFunctionFromDecimalTo(VARIANT.getTypeSignature(), true, "shortDecimalToVariant", "longDecimalToVariant");
+    public static final SqlScalarFunction VARIANT_TO_DECIMAL_CAST = castFunctionToDecimalFromBuilder(VARIANT.getTypeSignature(), true, false, "variantToShortDecimal", "variantToLongDecimal");
 
     private static final JsonMapper JSON_MAPPER = new JsonMapper(createJsonFactory());
 
@@ -661,6 +666,30 @@ public static Long jsonToShortDecimal(Slice json, long precision, long scale, lo
         }
     }
 
+    @UsedByGeneratedCode
+    public static Variant shortDecimalToVariant(long decimal, long precision, long scale, long tenToScale)
+    {
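+        // e.g. an unscaled value of 12345 with scale 2 represents 123.45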
+        return Variant.ofDecimal(BigDecimal.valueOf(decimal, DecimalConversions.intScale(scale)));
+    }
+
+    @UsedByGeneratedCode
+    public static Variant longDecimalToVariant(Int128 decimal, long precision, long scale, Int128 tenToScale)
+    {
+        return Variant.ofDecimal(new BigDecimal(decimal.toBigInteger(), DecimalConversions.intScale(scale)));
+    }
+
+    @UsedByGeneratedCode
+    public static Int128 variantToLongDecimal(Variant variant, long precision, long scale, Int128 tenToScale)
+    {
+        return VariantUtil.asLongDecimal(variant, intPrecision(precision), DecimalConversions.intScale(scale));
+    }
+
+    @UsedByGeneratedCode
+    public static Long variantToShortDecimal(Variant variant, long precision, long scale, long tenToScale)
+    {
+        return VariantUtil.asShortDecimal(variant, intPrecision(precision), DecimalConversions.intScale(scale));
+    }
+
     @SuppressWarnings("NumericCastThatLosesPrecision")
     private static int intPrecision(long precision)
     {
diff --git a/core/trino-main/src/main/java/io/trino/type/VariantFunctions.java b/core/trino-main/src/main/java/io/trino/type/VariantFunctions.java
new file mode 100644
index 000000000000..94f199e6b9b3
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/type/VariantFunctions.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.type;
+
+import io.airlift.slice.Slice;
+import io.trino.spi.function.ScalarFunction;
+import io.trino.spi.function.SqlType;
+import io.trino.spi.type.StandardTypes;
+import io.trino.spi.variant.Metadata;
+import io.trino.spi.variant.Variant;
+
+public final class VariantFunctions
+{
+    private VariantFunctions() {}
+
+    /**
+     * Decodes binary variant metadata and value into a {@link Variant}.
+     * Intended as a bridge when importing raw variant data; not meant to be called directly.
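+     * Both arguments use the binary variant encoding: {@code metadata} carries the
+     * variant metadata (e.g. the field-name dictionary) and {@code value} the encoded value bytes.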
+     */
+    @ScalarFunction(hidden = true)
+    @SqlType(StandardTypes.VARIANT)
+    public static Variant decodeVariant(@SqlType(StandardTypes.VARBINARY) Slice metadata, @SqlType(StandardTypes.VARBINARY) Slice value)
+    {
+        return Variant.from(Metadata.from(metadata), value);
+    }
+
+    @ScalarFunction
+    @SqlType(StandardTypes.BOOLEAN)
+    public static boolean variantIsNull(@SqlType(StandardTypes.VARIANT) Variant value)
+    {
+        return value.isNull();
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/type/VariantOperators.java b/core/trino-main/src/main/java/io/trino/type/VariantOperators.java
new file mode 100644
index 000000000000..47660bcefaf3
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/type/VariantOperators.java
@@ -0,0 +1,400 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.type;
+
+import io.airlift.slice.Slice;
+import io.trino.spi.TrinoException;
+import io.trino.spi.function.LiteralParameter;
+import io.trino.spi.function.LiteralParameters;
+import io.trino.spi.function.ScalarOperator;
+import io.trino.spi.function.SqlNullable;
+import io.trino.spi.function.SqlType;
+import io.trino.spi.type.LongTimestamp;
+import io.trino.spi.type.LongTimestampWithTimeZone;
+import io.trino.spi.type.StandardTypes;
+import io.trino.spi.variant.Header;
+import io.trino.spi.variant.Variant;
+import io.trino.util.variant.VariantUtil;
+import io.trino.util.variant.VariantWriter;
+
+import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
+import static io.trino.spi.function.OperatorType.CAST;
+import static io.trino.spi.function.OperatorType.SUBSCRIPT;
+import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc;
+import static java.lang.Float.intBitsToFloat;
+import static java.lang.Math.multiplyExact;
+import static java.lang.Math.toIntExact;
+import static java.util.Locale.ENGLISH;
+
+public final class VariantOperators
+{
+    private VariantOperators() {}
+
+    @ScalarOperator(SUBSCRIPT)
+    @SqlType(StandardTypes.VARIANT)
+    public static Variant dereference(@SqlType(StandardTypes.VARIANT) Variant value, @SqlType(StandardTypes.BIGINT) long index)
+    {
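+        // subscripts are 1-based, as with SQL arrays: value[1] is the first element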
+        if (value.basicType() != Header.BasicType.ARRAY) {
+            throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "VARIANT value is %s, not an array".formatted(variantTypeName(value)));
+        }
+        checkArrayIndex(index, value.getArrayLength());
+        return value.getArrayElement(toIntExact(index) - 1);
+    }
+
+    private static void checkArrayIndex(long index, int arrayLength)
+    {
+        if (index == 0) {
+            throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "VARIANT array indices start at 1");
+        }
+        if (index < 0) {
+            throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "VARIANT array subscript is negative: " + index);
+        }
+        if (index > arrayLength) {
+            throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "VARIANT array subscript must be less than or equal to array length: %d > %d".formatted(index, arrayLength));
+        }
+    }
+
+    @SqlNullable
+    @ScalarOperator(SUBSCRIPT)
+    @SqlType(StandardTypes.VARIANT)
+    public static Variant dereference(@SqlType(StandardTypes.VARIANT) Variant value, @SqlType(StandardTypes.VARCHAR) Slice fieldName)
+    {
+        if (value.basicType() != Header.BasicType.OBJECT) {
+            throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "VARIANT value is %s, not an object".formatted(variantTypeName(value)));
+        }
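+        // an absent field yields SQL NULL (the operator is @SqlNullable) rather than an error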
dereference(@SqlType(StandardTypes.VARIANT) Variant value, @SqlType(StandardTypes.VARCHAR) Slice fieldName) + { + if (value.basicType() != Header.BasicType.OBJECT) { + throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "VARIANT value is %s, not an object".formatted(variantTypeName(value))); + } + return value.getObjectField(fieldName).orElse(null); + } + + private static String variantTypeName(Variant value) + { + if (value == null) { + return "null"; + } + return switch (value.basicType()) { + case PRIMITIVE -> value.primitiveType().name().toLowerCase(ENGLISH); + case SHORT_STRING -> "string"; + case OBJECT, ARRAY -> value.basicType().name().toLowerCase(ENGLISH); + }; + } + + @SqlNullable + @ScalarOperator(CAST) + @SqlType(StandardTypes.BOOLEAN) + public static Boolean castToBoolean(@SqlType(StandardTypes.VARIANT) Variant value) + { + return VariantUtil.asBoolean(value); + } + + @ScalarOperator(CAST) + @SqlType(StandardTypes.VARIANT) + public static Variant castFromBoolean(@SqlType(StandardTypes.BOOLEAN) boolean value) + { + return Variant.ofBoolean(value); + } + + @SqlNullable + @ScalarOperator(CAST) + @SqlType(StandardTypes.TINYINT) + public static Long castToTinyint(@SqlType(StandardTypes.VARIANT) Variant value) + { + return VariantUtil.asTinyint(value); + } + + @ScalarOperator(CAST) + @SqlType(StandardTypes.VARIANT) + public static Variant castFromTinyint(@SqlType(StandardTypes.TINYINT) long value) + { + return Variant.ofByte((byte) value); + } + + @SqlNullable + @ScalarOperator(CAST) + @SqlType(StandardTypes.SMALLINT) + public static Long castToSmallint(@SqlType(StandardTypes.VARIANT) Variant value) + { + return VariantUtil.asSmallint(value); + } + + @ScalarOperator(CAST) + @SqlType(StandardTypes.VARIANT) + public static Variant castFromSmallint(@SqlType(StandardTypes.SMALLINT) long value) + { + return Variant.ofShort((short) value); + } + + @SqlNullable + @ScalarOperator(CAST) + @SqlType(StandardTypes.INTEGER) + public static Long castToInteger(@SqlType(StandardTypes.VARIANT) Variant value) + { + return VariantUtil.asInteger(value); + } + + @ScalarOperator(CAST) + @SqlType(StandardTypes.VARIANT) + public static Variant castFromInteger(@SqlType(StandardTypes.INTEGER) long value) + { + return Variant.ofInt((int) value); + } + + @SqlNullable + @ScalarOperator(CAST) + @SqlType(StandardTypes.BIGINT) + public static Long castToBigint(@SqlType(StandardTypes.VARIANT) Variant value) + { + return VariantUtil.asBigint(value); + } + + @ScalarOperator(CAST) + @SqlType(StandardTypes.VARIANT) + public static Variant castFromBigint(@SqlType(StandardTypes.BIGINT) long value) + { + return Variant.ofLong(value); + } + + @SqlNullable + @ScalarOperator(CAST) + @SqlType(StandardTypes.REAL) + public static Long castToReal(@SqlType(StandardTypes.VARIANT) Variant value) + { + return VariantUtil.asReal(value); + } + + @ScalarOperator(CAST) + @SqlType(StandardTypes.VARIANT) + public static Variant castFromReal(@SqlType(StandardTypes.REAL) long value) + { + float floatValue = intBitsToFloat((int) value); + return Variant.ofFloat(floatValue); + } + + @SqlNullable + @ScalarOperator(CAST) + @SqlType(StandardTypes.DOUBLE) + public static Double castToDouble(@SqlType(StandardTypes.VARIANT) Variant value) + { + return VariantUtil.asDouble(value); + } + + @ScalarOperator(CAST) + @SqlType(StandardTypes.VARIANT) + public static Variant castFromDouble(@SqlType(StandardTypes.DOUBLE) double value) + { + return Variant.ofDouble(value); + } + + @SqlNullable + @ScalarOperator(CAST) + @SqlType(StandardTypes.VARCHAR) + public 
static Slice castToVarchar(@SqlType(StandardTypes.VARIANT) Variant value) + { + return VariantUtil.asVarchar(value); + } + + @ScalarOperator(CAST) + @LiteralParameters("x") + @SqlType(StandardTypes.VARIANT) + public static Variant castFromVarchar(@SqlType("varchar(x)") Slice value) + { + return Variant.ofString(value); + } + + @SqlNullable + @ScalarOperator(CAST) + @SqlType(StandardTypes.VARBINARY) + public static Slice castToVarbinary(@SqlType(StandardTypes.VARIANT) Variant value) + { + return VariantUtil.asVarbinary(value); + } + + @ScalarOperator(CAST) + @SqlType(StandardTypes.VARIANT) + public static Variant castFromVarbinary(@SqlType(StandardTypes.VARBINARY) Slice value) + { + return Variant.ofBinary(value); + } + + @SqlNullable + @ScalarOperator(CAST) + @SqlType(StandardTypes.DATE) + public static Long castToDate(@SqlType(StandardTypes.VARIANT) Variant value) + { + return VariantUtil.asDate(value); + } + + @ScalarOperator(CAST) + @SqlType(StandardTypes.VARIANT) + public static Variant castFromDate(@SqlType(StandardTypes.DATE) long value) + { + return Variant.ofDate(toIntExact(value)); + } + + @SqlNullable + @ScalarOperator(CAST) + @SqlType(StandardTypes.UUID) + public static Slice castToUuid(@SqlType(StandardTypes.VARIANT) Variant value) + { + return VariantUtil.asUuid(value); + } + + @ScalarOperator(CAST) + @SqlType(StandardTypes.VARIANT) + public static Variant castFromUuid(@SqlType(StandardTypes.UUID) Slice value) + { + return Variant.ofUuid(value); + } + + @SqlNullable + @ScalarOperator(CAST) + @SqlType(StandardTypes.JSON) + public static Slice castToJson(@SqlType(StandardTypes.VARIANT) Variant value) + { + return VariantUtil.asJson(value); + } + + private static final VariantWriter JSON_VARIANT_WRITER = VariantWriter.create(JsonType.JSON); + + @ScalarOperator(CAST) + @SqlType(StandardTypes.VARIANT) + public static Variant castFromJson(@SqlType(StandardTypes.JSON) Slice value) + { + return JSON_VARIANT_WRITER.write(value); + } + + @ScalarOperator(CAST) + public static final class VariantToTimeCast + { + private VariantToTimeCast() {} + + @LiteralParameters("p") + @SqlNullable + @SqlType("time(p)") + public static Long castToTime(@LiteralParameter("p") long precision, @SqlType(StandardTypes.VARIANT) Variant value) + { + return VariantUtil.asTime(value, toIntExact(precision)); + } + } + + @ScalarOperator(CAST) + public static final class VariantFromTimeCast + { + private VariantFromTimeCast() {} + + @LiteralParameters("p") + @SqlType(StandardTypes.VARIANT) + public static Variant castFromTime(@LiteralParameter("p") long precision, @SqlType("time(p)") long epochPicos) + { + return Variant.ofTimeMicrosNtz(epochPicos / 1_000_000L); + } + } + + @ScalarOperator(CAST) + public static final class VariantToTimestampCast + { + private VariantToTimestampCast() {} + + @LiteralParameters("p") + @SqlNullable + @SqlType("timestamp(p)") + public static Long castToShortTimestamp(@LiteralParameter("p") long precision, @SqlType(StandardTypes.VARIANT) Variant value) + { + return VariantUtil.asShortTimestamp(value, toIntExact(precision)); + } + + @LiteralParameters("p") + @SqlNullable + @SqlType("timestamp(p)") + public static LongTimestamp castToLongTimestamp(@LiteralParameter("p") long precision, @SqlType(StandardTypes.VARIANT) Variant value) + { + return VariantUtil.asLongTimestamp(value, toIntExact(precision)); + } + } + + @ScalarOperator(CAST) + public static final class VariantFromTimestampCast + { + private VariantFromTimestampCast() {} + + @LiteralParameters("p") + 
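// Short timestamps (precision <= 6) arrive as a long of epoch microseconds and map directly onto the variant timestamp(micros) encoding; the LongTimestamp overload below handles precision > 6. + 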
@SqlType(StandardTypes.VARIANT) + public static Variant castFromTimestamp(@LiteralParameter("p") long precision, @SqlType("timestamp(p)") long epochMicros) + { + return Variant.ofTimestampMicrosNtz(epochMicros); + } + + @LiteralParameters("p") + @SqlType(StandardTypes.VARIANT) + public static Variant castFromTimestamp(@LiteralParameter("p") long precision, @SqlType("timestamp(p)") LongTimestamp timestamp) + { + long nanosFromMicros = multiplyExact(timestamp.getEpochMicros(), 1_000L); + long extraNanos = timestamp.getPicosOfMicro() / 1_000; // 1000 ps = 1 ns + long nanos = Math.addExact(nanosFromMicros, extraNanos); + + return Variant.ofTimestampNanosNtz(nanos); + } + } + + @ScalarOperator(CAST) + public static final class VariantToTimestampWithTimeZoneCasts + { + private VariantToTimestampWithTimeZoneCasts() {} + + @LiteralParameters("p") + @SqlNullable + @SqlType("timestamp(p) with time zone") + public static Long castToShortTimestampWithTimeZone(@LiteralParameter("p") long precision, @SqlType(StandardTypes.VARIANT) Variant variant) + { + return VariantUtil.asShortTimestampWithTimeZone(variant, toIntExact(precision)); + } + + @LiteralParameters("p") + @SqlNullable + @SqlType("timestamp(p) with time zone") + public static LongTimestampWithTimeZone castToLongTimestampWithTimeZone(@LiteralParameter("p") long precision, @SqlType(StandardTypes.VARIANT) Variant variant) + { + return VariantUtil.asLongTimestampWithTimeZone(variant, toIntExact(precision)); + } + } + + @ScalarOperator(CAST) + public static final class VariantFromTimestampWithTimeZoneCasts + { + private VariantFromTimestampWithTimeZoneCasts() {} + + @LiteralParameters("p") + @SqlType(StandardTypes.VARIANT) + public static Variant castFromTimestampWithTimeZone(@LiteralParameter("p") long precision, @SqlType("timestamp(p) with time zone") long packedEpochMillis) + { + long epochMillis = unpackMillisUtc(packedEpochMillis); + return Variant.ofTimestampMicrosUtc(multiplyExact(epochMillis, 1_000L)); + } + + @LiteralParameters("p") + @SqlType(StandardTypes.VARIANT) + public static Variant castFromTimestampWithTimeZone(@LiteralParameter("p") long precision, @SqlType("timestamp(p) with time zone") LongTimestampWithTimeZone timestamp) + { + if (precision <= 6) { + long microsFromMillis = multiplyExact(timestamp.getEpochMillis(), 1_000L); + int extraMicros = timestamp.getPicosOfMilli() / 1_000_000; + long epochMicros = Math.addExact(microsFromMillis, extraMicros); + return Variant.ofTimestampMicrosUtc(epochMicros); + } + + long nanosFromMillis = multiplyExact(timestamp.getEpochMillis(), 1_000_000L); + int extraNanos = timestamp.getPicosOfMilli() / 1_000; + long nanos = Math.addExact(nanosFromMillis, extraNanos); + return Variant.ofTimestampNanosUtc(nanos); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/util/variant/ArrayVariantWriter.java b/core/trino-main/src/main/java/io/trino/util/variant/ArrayVariantWriter.java new file mode 100644 index 000000000000..54b4419f54df --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/util/variant/ArrayVariantWriter.java @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.util.variant; + +import io.trino.spi.block.Block; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.Type; +import io.trino.spi.variant.Metadata; + +record ArrayVariantWriter(ArrayType type, VariantWriter elementWriter) + implements VariantWriter +{ + @Override + public PlannedValue plan(Metadata.Builder metadataBuilder, Object value) + { + if (value == null) { + return NullPlannedValue.NULL_PLANNED_VALUE; + } + + Type elementType = type.getElementType(); + Block array = (Block) value; + PlannedValue[] planned = new PlannedValue[array.getPositionCount()]; + for (int position = 0; position < array.getPositionCount(); position++) { + planned[position] = elementWriter.plan(metadataBuilder, elementType.getObject(array, position)); + } + return new PlannedArrayValue(planned); + } +} diff --git a/core/trino-main/src/main/java/io/trino/util/variant/JsonVariantWriter.java b/core/trino-main/src/main/java/io/trino/util/variant/JsonVariantWriter.java new file mode 100644 index 000000000000..782ac1b776b6 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/util/variant/JsonVariantWriter.java @@ -0,0 +1,324 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.util.variant; + +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; +import io.airlift.slice.Slice; +import io.trino.plugin.base.util.JsonUtils; +import io.trino.spi.TrinoException; +import io.trino.spi.variant.Metadata; +import it.unimi.dsi.fastutil.ints.IntArrayList; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; + +import java.io.IOException; +import java.io.InputStream; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.List; +import java.util.function.IntUnaryOperator; + +import static com.fasterxml.jackson.core.JsonTokenId.ID_FALSE; +import static com.fasterxml.jackson.core.JsonTokenId.ID_NULL; +import static com.fasterxml.jackson.core.JsonTokenId.ID_NUMBER_FLOAT; +import static com.fasterxml.jackson.core.JsonTokenId.ID_NUMBER_INT; +import static com.fasterxml.jackson.core.JsonTokenId.ID_START_ARRAY; +import static com.fasterxml.jackson.core.JsonTokenId.ID_START_OBJECT; +import static com.fasterxml.jackson.core.JsonTokenId.ID_STRING; +import static com.fasterxml.jackson.core.JsonTokenId.ID_TRUE; +import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.trino.spi.variant.VariantEncoder.ENCODED_BOOLEAN_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_DECIMAL16_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_DOUBLE_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_FLOAT_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_INT_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_LONG_SIZE; +import static io.trino.spi.variant.VariantEncoder.encodeBoolean; +import static io.trino.spi.variant.VariantEncoder.encodeDecimal16; +import static io.trino.spi.variant.VariantEncoder.encodeDouble; +import static io.trino.spi.variant.VariantEncoder.encodeFloat; +import static io.trino.spi.variant.VariantEncoder.encodeInt; +import static io.trino.spi.variant.VariantEncoder.encodeLong; +import static io.trino.spi.variant.VariantEncoder.encodeString; +import static io.trino.spi.variant.VariantEncoder.encodedStringSize; +import static io.trino.util.variant.NullPlannedValue.NULL_PLANNED_VALUE; +import static java.util.Objects.requireNonNull; + +final class JsonVariantWriter + implements VariantWriter +{ + public static final JsonVariantWriter JSON_VARIANT_WRITER = new JsonVariantWriter(); + + private static final JsonFactory JSON_FACTORY = JsonUtils.jsonFactory(); + + private JsonVariantWriter() {} + + @Override + public PlannedValue plan(Metadata.Builder metadataBuilder, Object value) + { + if (value == null) { + return NULL_PLANNED_VALUE; + } + + Slice json = (Slice) value; + + try (InputStream input = json.getInput(); JsonParser parser = JSON_FACTORY.createParser(input)) { + JsonToken token = parser.nextToken(); + if (token == null) { + throw new IllegalArgumentException("Invalid JSON input for VARIANT: empty input"); + } + + PlannedValue planned = planValue(parser, token, metadataBuilder); + + // Ensure we consumed exactly one top-level value + JsonToken trailing = parser.nextToken(); + if (trailing != null) { + throw new IllegalArgumentException("Invalid JSON input for VARIANT: trailing data after top-level value"); + } + + return planned; + } + catch (IOException e) { + throw new IllegalArgumentException("Invalid JSON input for VARIANT", 
e); + } + } + + private static PlannedValue planValue(JsonParser parser, JsonToken token, Metadata.Builder metadataBuilder) + throws IOException + { + return switch (token.id()) { + case ID_NULL -> NULL_PLANNED_VALUE; + case ID_TRUE -> new PlannedBooleanValue(true); + case ID_FALSE -> new PlannedBooleanValue(false); + case ID_STRING -> new PlannedStringValue(utf8Slice(parser.getText())); + case ID_NUMBER_INT, ID_NUMBER_FLOAT -> switch (parser.getNumberType()) { + case INT -> new PlannedIntValue(parser.getIntValue()); + case LONG -> new PlannedLongValue(parser.getLongValue()); + case BIG_INTEGER -> new PlannedDecimal16Value(parser.getBigIntegerValue(), 0); + case FLOAT -> new PlannedFloatValue(parser.getFloatValue()); + case DOUBLE -> new PlannedDoubleValue(parser.getDoubleValue()); + case BIG_DECIMAL -> { + BigDecimal decimal = parser.getDecimalValue(); + yield new PlannedDecimal16Value(decimal.unscaledValue(), decimal.scale()); + } + }; + case ID_START_OBJECT -> planObject(parser, metadataBuilder); + case ID_START_ARRAY -> planArray(parser, metadataBuilder); + default -> throw new IllegalArgumentException("Unsupported JSON token for VARIANT: " + token); + }; + } + + private static PlannedValue planObject(JsonParser parser, Metadata.Builder metadataBuilder) + throws IOException + { + // JSON object field order in VARIANT is lexicographical by field name bytes (via fieldId ordering), + // so we record provisional fieldIds and values now, and sort by final fieldIds at write time. + // + // Duplicate keys: fail (matches JSON -> MAP cast behavior). + IntArrayList fieldIds = new IntArrayList(); + List<PlannedValue> values = new ArrayList<>(); + + IntSet seenFields = new IntOpenHashSet(); + + for (JsonToken token = parser.nextToken(); token != JsonToken.END_OBJECT; token = parser.nextToken()) { + if (token == null) { + throw new IllegalArgumentException("Invalid JSON input for VARIANT: unexpected EOF in object"); + } + if (token != JsonToken.FIELD_NAME) { + throw new IllegalArgumentException("Invalid JSON input for VARIANT: expected FIELD_NAME but got " + token); + } + + String fieldName = parser.currentName(); + int provisionalFieldId = metadataBuilder.addFieldName(utf8Slice(fieldName)); + + if (!seenFields.add(provisionalFieldId)) { + throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "Duplicate map keys are not allowed"); + } + + JsonToken valueToken = parser.nextToken(); + if (valueToken == null) { + throw new IllegalArgumentException("Invalid JSON input for VARIANT: unexpected EOF after field name"); + } + + PlannedValue plannedValue = planValue(parser, valueToken, metadataBuilder); + + fieldIds.add(provisionalFieldId); + values.add(plannedValue); + } + + return new PlannedObjectValue(fieldIds.toIntArray(), values.toArray(PlannedValue[]::new)); + } + + private static PlannedValue planArray(JsonParser parser, Metadata.Builder metadataBuilder) + throws IOException + { + List<PlannedValue> elements = new ArrayList<>(); + for (JsonToken token = parser.nextToken(); token != JsonToken.END_ARRAY; token = parser.nextToken()) { + if (token == null) { + throw new IllegalArgumentException("Invalid JSON input for VARIANT: unexpected EOF in array"); + } + elements.add(planValue(parser, token, metadataBuilder)); + } + return new PlannedArrayValue(elements.toArray(PlannedValue[]::new)); + } + + private record PlannedBooleanValue(boolean value) + implements PlannedValue + { + @Override + public void finalize(IntUnaryOperator sortedFieldIdMapping) {} + + @Override + public int size() + { + return ENCODED_BOOLEAN_SIZE; + } + + 
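// write() emits this planned value's encoding into the output slice and returns the number of bytes written, starting at the given offset. + 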
@Override + public int write(Slice out, int offset) + { + return encodeBoolean(value, out, offset); + } + } + + private record PlannedIntValue(int value) + implements PlannedValue + { + @Override + public void finalize(IntUnaryOperator sortedFieldIdMapping) {} + + @Override + public int size() + { + return ENCODED_INT_SIZE; + } + + @Override + public int write(Slice out, int offset) + { + return encodeInt(value, out, offset); + } + } + + private record PlannedLongValue(long value) + implements PlannedValue + { + @Override + public void finalize(IntUnaryOperator sortedFieldIdMapping) {} + + @Override + public int size() + { + return ENCODED_LONG_SIZE; + } + + @Override + public int write(Slice out, int offset) + { + return encodeLong(value, out, offset); + } + } + + private record PlannedFloatValue(float value) + implements PlannedValue + { + @Override + public void finalize(IntUnaryOperator sortedFieldIdMapping) {} + + @Override + public int size() + { + return ENCODED_FLOAT_SIZE; + } + + @Override + public int write(Slice out, int offset) + { + return encodeFloat(value, out, offset); + } + } + + private record PlannedDoubleValue(double value) + implements PlannedValue + { + @Override + public void finalize(IntUnaryOperator sortedFieldIdMapping) {} + + @Override + public int size() + { + return ENCODED_DOUBLE_SIZE; + } + + @Override + public int write(Slice out, int offset) + { + return encodeDouble(value, out, offset); + } + } + + private record PlannedDecimal16Value(BigInteger unscaledValue, int scale) + implements PlannedValue + { + private PlannedDecimal16Value(BigInteger unscaledValue, int scale) + { + this.unscaledValue = requireNonNull(unscaledValue, "unscaledValue is null"); + this.scale = scale; + } + + @Override + public void finalize(IntUnaryOperator sortedFieldIdMapping) {} + + @Override + public int size() + { + return ENCODED_DECIMAL16_SIZE; + } + + @Override + public int write(Slice out, int offset) + { + return encodeDecimal16(unscaledValue, scale, out, offset); + } + } + + private record PlannedStringValue(Slice value) + implements PlannedValue + { + private PlannedStringValue(Slice value) + { + this.value = requireNonNull(value, "value is null"); + } + + @Override + public void finalize(IntUnaryOperator sortedFieldIdMapping) {} + + @Override + public int size() + { + return encodedStringSize(value.length()); + } + + @Override + public int write(Slice out, int offset) + { + return encodeString(value, out, offset); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/util/variant/MapVariantWriter.java b/core/trino-main/src/main/java/io/trino/util/variant/MapVariantWriter.java new file mode 100644 index 000000000000..19d216eaaa4b --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/util/variant/MapVariantWriter.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.util.variant; + +import io.trino.spi.block.Block; +import io.trino.spi.block.SqlMap; +import io.trino.spi.type.MapType; +import io.trino.spi.type.Type; +import io.trino.spi.variant.Metadata; + +import static io.trino.util.variant.PrimitiveMapVariantWriter.getMapKeys; + +public record MapVariantWriter(MapType type, VariantWriter valueWriter) + implements VariantWriter +{ + @Override + public PlannedValue plan(Metadata.Builder metadataBuilder, Object value) + { + if (value == null) { + return NullPlannedValue.NULL_PLANNED_VALUE; + } + + SqlMap sqlMap = (SqlMap) value; + int[] fieldIds = metadataBuilder.addFieldNames(getMapKeys(sqlMap)); + + PlannedValue[] plannedValues = new PlannedValue[sqlMap.getSize()]; + + Type valueType = type.getValueType(); + Block valueBlock = sqlMap.getRawValueBlock(); + int valueBlockOffset = sqlMap.getRawOffset(); + for (int entry = 0; entry < sqlMap.getSize(); entry++) { + plannedValues[entry] = valueWriter.plan(metadataBuilder, valueType.getObject(valueBlock, valueBlockOffset + entry)); + } + + return new PlannedObjectValue(fieldIds, plannedValues); + } +} diff --git a/core/trino-main/src/main/java/io/trino/util/variant/NullPlannedValue.java b/core/trino-main/src/main/java/io/trino/util/variant/NullPlannedValue.java new file mode 100644 index 000000000000..9e3580e40d05 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/util/variant/NullPlannedValue.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.util.variant; + +import io.airlift.slice.Slice; +import io.trino.util.variant.VariantWriter.PlannedValue; + +import java.util.function.IntUnaryOperator; + +import static io.trino.spi.variant.VariantEncoder.ENCODED_NULL_SIZE; +import static io.trino.spi.variant.VariantEncoder.encodeNull; + +final class NullPlannedValue + implements PlannedValue +{ + static final PlannedValue NULL_PLANNED_VALUE = new NullPlannedValue(); + + private NullPlannedValue() {} + + @Override + public void finalize(IntUnaryOperator sortedFieldIdMapping) {} + + @Override + public int size() + { + return ENCODED_NULL_SIZE; + } + + @Override + public int write(Slice out, int offset) + { + return encodeNull(out, offset); + } +} diff --git a/core/trino-main/src/main/java/io/trino/util/variant/PlannedArrayValue.java b/core/trino-main/src/main/java/io/trino/util/variant/PlannedArrayValue.java new file mode 100644 index 000000000000..c2ec3f1cbe46 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/util/variant/PlannedArrayValue.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.util.variant; + +import io.airlift.slice.Slice; + +import java.util.function.IntUnaryOperator; + +import static io.trino.spi.variant.VariantEncoder.encodeArrayHeading; +import static io.trino.spi.variant.VariantEncoder.encodedArraySize; + +final class PlannedArrayValue + implements VariantWriter.PlannedValue +{ + private final VariantWriter.PlannedValue[] elements; + private int size = -1; + + PlannedArrayValue(VariantWriter.PlannedValue[] elements) + { + this.elements = elements; + } + + @Override + public void finalize(IntUnaryOperator sortedFieldIdMapping) + { + if (size >= 0) { + throw new IllegalStateException("finalize() already called"); + } + + int totalElementLength = 0; + for (VariantWriter.PlannedValue element : elements) { + element.finalize(sortedFieldIdMapping); + totalElementLength += element.size(); + } + size = encodedArraySize(elements.length, totalElementLength); + } + + @Override + public int size() + { + if (size < 0) { + throw new IllegalStateException("finalize() must be called before size()"); + } + return size; + } + + @Override + public int write(Slice out, int offset) + { + int count = elements.length; + int headerSize = encodeArrayHeading(count, i -> elements[i].size(), out, offset); + + int written = headerSize; + for (int i = 0; i < count; i++) { + written += elements[i].write(out, offset + written); + } + return written; + } +} diff --git a/core/trino-main/src/main/java/io/trino/util/variant/PlannedObjectValue.java b/core/trino-main/src/main/java/io/trino/util/variant/PlannedObjectValue.java new file mode 100644 index 000000000000..d946c266800c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/util/variant/PlannedObjectValue.java @@ -0,0 +1,97 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.util.variant; + +import io.airlift.slice.Slice; + +import java.util.function.IntUnaryOperator; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.spi.variant.VariantEncoder.encodeObjectHeading; +import static io.trino.spi.variant.VariantEncoder.encodedObjectSize; +import static io.trino.util.variant.PrimitiveMapVariantWriter.determineWriteOrder; +import static java.util.Objects.requireNonNull; + +final class PlannedObjectValue + implements VariantWriter.PlannedValue +{ + // Initially, field IDs are provisional and must be remapped in finalize() + // This array is mutated in finalize() + private final int[] fieldIds; + private final VariantWriter.PlannedValue[] values; + + private int size = -1; + + PlannedObjectValue(int[] fieldIds, VariantWriter.PlannedValue[] values) + { + this.fieldIds = requireNonNull(fieldIds, "fieldIds is null"); + this.values = requireNonNull(values, "values is null"); + checkArgument(fieldIds.length == values.length, "fieldIds length %s does not match values length %s", fieldIds.length, values.length); + } + + @Override + public void finalize(IntUnaryOperator sortedFieldIdMapping) + { + if (size >= 0) { + throw new IllegalStateException("finalize() already called"); + } + + int maxFieldId = -1; + for (int i = 0; i < fieldIds.length; i++) { + int finalFieldId = sortedFieldIdMapping.applyAsInt(fieldIds[i]); + fieldIds[i] = finalFieldId; + maxFieldId = Math.max(maxFieldId, finalFieldId); + } + + int totalElementLength = 0; + for (VariantWriter.PlannedValue plannedValue : values) { + plannedValue.finalize(sortedFieldIdMapping); + totalElementLength += plannedValue.size(); + } + + size = encodedObjectSize(maxFieldId, fieldIds.length, totalElementLength); + } + + @Override + public int size() + { + if (size < 0) { + throw new IllegalStateException("finalize() must be called before size()"); + } + return size; + } + + @Override + public int write(Slice out, int offset) + { + if (size < 0) { + throw new IllegalStateException("finalize() must be called before write()"); + } + + int[] writeOrder = determineWriteOrder(fieldIds); + + int written = encodeObjectHeading( + fieldIds.length, + i -> fieldIds[writeOrder[i]], + i -> values[writeOrder[i]].size(), + out, + offset); + + for (int entry : writeOrder) { + written += values[entry].write(out, offset + written); + } + + return written; + } +} diff --git a/core/trino-main/src/main/java/io/trino/util/variant/PrimitiveArrayVariantWriter.java b/core/trino-main/src/main/java/io/trino/util/variant/PrimitiveArrayVariantWriter.java new file mode 100644 index 000000000000..4c45c3502d3c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/util/variant/PrimitiveArrayVariantWriter.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.util.variant; + +import io.airlift.slice.Slice; +import io.trino.spi.block.Block; +import io.trino.spi.type.ArrayType; +import io.trino.spi.variant.Metadata; + +import java.util.function.IntUnaryOperator; + +import static io.trino.spi.variant.VariantEncoder.encodeArrayHeading; +import static io.trino.spi.variant.VariantEncoder.encodeNull; +import static io.trino.spi.variant.VariantEncoder.encodedArraySize; + +record PrimitiveArrayVariantWriter(ArrayType type, PrimitiveVariantEncoder elementEncoder) + implements VariantWriter +{ + @Override + public PlannedValue plan(Metadata.Builder metadataBuilder, Object value) + { + if (value == null) { + return NullPlannedValue.NULL_PLANNED_VALUE; + } + + Block array = (Block) value; + int totalElementsLength = 0; + for (int i = 0; i < array.getPositionCount(); i++) { + totalElementsLength += elementEncoder.size(array, i); + } + int totalSize = encodedArraySize(array.getPositionCount(), totalElementsLength); + return new PrimitiveArrayPlannedValue(array, totalSize, elementEncoder); + } + + private record PrimitiveArrayPlannedValue(Block elements, int size, PrimitiveVariantEncoder primitiveEncoder) + implements PlannedValue + { + @Override + public void finalize(IntUnaryOperator sortedFieldIdMapping) {} + + @Override + public int write(Slice out, int offset) + { + int written = encodeArrayHeading(elements.getPositionCount(), position -> primitiveEncoder.size(elements, position), out, offset); + for (int i = 0; i < elements.getPositionCount(); i++) { + if (elements.isNull(i)) { + written += encodeNull(out, offset + written); + } + else { + written += primitiveEncoder.write(elements, i, out, offset + written); + } + } + return written; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/util/variant/PrimitiveMapVariantWriter.java b/core/trino-main/src/main/java/io/trino/util/variant/PrimitiveMapVariantWriter.java new file mode 100644 index 000000000000..c5ae8aba81a7 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/util/variant/PrimitiveMapVariantWriter.java @@ -0,0 +1,162 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.util.variant; + +import com.google.common.collect.ImmutableSet; +import io.airlift.slice.Slice; +import io.trino.spi.block.Block; +import io.trino.spi.block.SqlMap; +import io.trino.spi.block.VariableWidthBlock; +import io.trino.spi.type.MapType; +import io.trino.spi.variant.Metadata; +import it.unimi.dsi.fastutil.ints.IntArrays; + +import java.util.List; +import java.util.function.IntUnaryOperator; + +import static io.trino.spi.variant.VariantEncoder.encodeObjectHeading; +import static io.trino.spi.variant.VariantEncoder.encodedObjectSize; +import static java.util.Objects.requireNonNull; + +public record PrimitiveMapVariantWriter(MapType type, PrimitiveVariantEncoder valueEncoder) + implements VariantWriter +{ + @Override + public PlannedValue plan(Metadata.Builder metadataBuilder, Object value) + { + if (value == null) { + return NullPlannedValue.NULL_PLANNED_VALUE; + } + + SqlMap sqlMap = (SqlMap) value; + int[] fieldIds = metadataBuilder.addFieldNames(getMapKeys(sqlMap)); + + return new PlannedPrimitiveMapValue(fieldIds, sqlMap.getRawValueBlock(), sqlMap.getRawOffset(), valueEncoder); + } + + private static final class PlannedPrimitiveMapValue + implements PlannedValue + { + // Initially, field IDs are provisional and must be remapped in finalize() + // This array is mutated in finalize() + private final int[] fieldIds; + private final Block valueBlock; + private final int valueBlockOffset; + private final PrimitiveVariantEncoder valueEncoder; + + private int size = -1; + + private PlannedPrimitiveMapValue(int[] fieldIds, Block valueBlock, int valueBlockOffset, PrimitiveVariantEncoder valueEncoder) + { + this.fieldIds = requireNonNull(fieldIds, "fieldIds is null"); + this.valueBlock = requireNonNull(valueBlock, "valueBlock is null"); + this.valueBlockOffset = valueBlockOffset; + this.valueEncoder = requireNonNull(valueEncoder, "valueEncoder is null"); + } + + @Override + public void finalize(IntUnaryOperator sortedFieldIdMapping) + { + if (size >= 0) { + throw new IllegalStateException("finalize() already called"); + } + + int maxFieldId = -1; + for (int i = 0; i < fieldIds.length; i++) { + int finalFieldId = sortedFieldIdMapping.applyAsInt(fieldIds[i]); + fieldIds[i] = finalFieldId; + maxFieldId = Math.max(maxFieldId, finalFieldId); + } + + int totalElementLength = 0; + for (int i = 0; i < fieldIds.length; i++) { + totalElementLength += valueEncoder.size(valueBlock, valueBlockOffset + i); + } + size = encodedObjectSize(maxFieldId, fieldIds.length, totalElementLength); + } + + @Override + public int size() + { + if (size < 0) { + throw new IllegalStateException("finalize() must be called before size()"); + } + return size; + } + + @Override + public int write(Slice out, int offset) + { + if (size < 0) { + throw new IllegalStateException("finalize() must be called before write()"); + } + + int[] writeOrder = determineWriteOrder(fieldIds); + + int written = encodeObjectHeading( + fieldIds.length, + i -> fieldIds[writeOrder[i]], + i -> { + int entry = writeOrder[i]; + return valueEncoder.size(valueBlock, valueBlockOffset + entry); + }, + out, + offset); + + for (int entry : writeOrder) { + written += valueEncoder.write(valueBlock, valueBlockOffset + entry, out, offset + written); + } + return written; + } + } + + /// Extract map keys from the SqlMap, ensuring no null or duplicate keys + /// @throws IllegalArgumentException if a map key is null or if there are duplicate keys + static List<Slice> getMapKeys(SqlMap sqlMap) + { + Block keyBlock = sqlMap.getRawKeyBlock(); + 
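// Map keys are expected to be varchar here, so the underlying key block is variable-width. + 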
VariableWidthBlock underlyingBlock = (VariableWidthBlock) keyBlock.getUnderlyingValueBlock(); + int offset = sqlMap.getRawOffset(); + int size = sqlMap.getSize(); + ImmutableSet.Builder<Slice> keySet = ImmutableSet.builderWithExpectedSize(size); + for (int i = 0; i < size; i++) { + int underlyingPosition = keyBlock.getUnderlyingValuePosition(offset + i); + if (underlyingBlock.isNull(underlyingPosition)) { + throw new IllegalArgumentException("Map key is null"); + } + keySet.add(underlyingBlock.getSlice(underlyingPosition)); + } + // ImmutableSet as list preserves insertion order + List<Slice> keys = keySet.build().asList(); + if (keys.size() != size) { + throw new IllegalArgumentException("Map contains duplicate keys"); + } + return keys; + } + + /// Determine the order to write entries. Object fields must be written in lexicographical + /// order of the field names. Since the metadata dictionary is sorted, the fieldIds are also + /// in lexicographical order of the field names. + /// @param fieldIds from a globally sorted metadata dictionary + /// @return the order to write entries so that field names are in lexicographical order + static int[] determineWriteOrder(int[] fieldIds) + { + int[] writeOrder = new int[fieldIds.length]; + for (int i = 0; i < writeOrder.length; i++) { + writeOrder[i] = i; + } + IntArrays.unstableSort(writeOrder, 0, writeOrder.length, (left, right) -> Integer.compare(fieldIds[left], fieldIds[right])); + return writeOrder; + } +} diff --git a/core/trino-main/src/main/java/io/trino/util/variant/PrimitiveVariantEncoder.java b/core/trino-main/src/main/java/io/trino/util/variant/PrimitiveVariantEncoder.java new file mode 100644 index 000000000000..447865a80b89 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/util/variant/PrimitiveVariantEncoder.java @@ -0,0 +1,487 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.util.variant; + +import io.airlift.slice.Slice; +import io.trino.spi.block.Block; +import io.trino.spi.block.VariableWidthBlock; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.DateType; +import io.trino.spi.type.DecimalType; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.Int128; +import io.trino.spi.type.IntegerType; +import io.trino.spi.type.LongTimestamp; +import io.trino.spi.type.LongTimestampWithTimeZone; +import io.trino.spi.type.RealType; +import io.trino.spi.type.SmallintType; +import io.trino.spi.type.TimeType; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.TimestampWithTimeZoneType; +import io.trino.spi.type.TinyintType; +import io.trino.spi.type.Type; +import io.trino.spi.type.UuidType; +import io.trino.spi.type.VarbinaryType; +import io.trino.spi.type.VarcharType; +import io.trino.spi.variant.VariantEncoder; +import io.trino.type.UnknownType; + +import java.util.Optional; + +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.SmallintType.SMALLINT; +import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.spi.type.UuidType.UUID; +import static io.trino.spi.variant.VariantEncoder.ENCODED_BOOLEAN_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_BYTE_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_DATE_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_DECIMAL16_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_DECIMAL8_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_DOUBLE_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_FLOAT_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_INT_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_LONG_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_NULL_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_SHORT_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_TIMESTAMP_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_TIME_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_UUID_SIZE; +import static io.trino.spi.variant.VariantEncoder.encodeBinary; +import static io.trino.spi.variant.VariantEncoder.encodeBoolean; +import static io.trino.spi.variant.VariantEncoder.encodeByte; +import static io.trino.spi.variant.VariantEncoder.encodeDate; +import static io.trino.spi.variant.VariantEncoder.encodeDecimal16; +import static io.trino.spi.variant.VariantEncoder.encodeDecimal8; +import static io.trino.spi.variant.VariantEncoder.encodeDouble; +import static io.trino.spi.variant.VariantEncoder.encodeFloat; +import static io.trino.spi.variant.VariantEncoder.encodeInt; +import static io.trino.spi.variant.VariantEncoder.encodeLong; +import static io.trino.spi.variant.VariantEncoder.encodeNull; +import static io.trino.spi.variant.VariantEncoder.encodeShort; +import static io.trino.spi.variant.VariantEncoder.encodeString; +import static io.trino.spi.variant.VariantEncoder.encodeUuid; +import static io.trino.spi.variant.VariantEncoder.encodedBinarySize; +import static io.trino.spi.variant.VariantEncoder.encodedStringSize; +import static java.lang.Math.multiplyExact; + +/// Encodes a 
primitive value directly from a block into a variant Slice. +/// This intentionally does not extend VariantWriter, which avoids allocating the +/// PlannedValue objects that would otherwise be needed to retain the Block and position. +/// Because of this efficiency advantage, prefer this class over VariantWriter +/// wherever possible. +abstract class PrimitiveVariantEncoder +{ + static Optional<PrimitiveVariantEncoder> create(Type type) + { + return Optional.ofNullable(switch (type) { + case UnknownType _ -> new UnknownVariantEncoder(); + case BooleanType _ -> new BooleanVariantEncoder(); + case TinyintType _ -> new TinyintVariantEncoder(); + case SmallintType _ -> new SmallintVariantEncoder(); + case IntegerType _ -> new IntegerVariantEncoder(); + case BigintType _ -> new BigintVariantEncoder(); + case DecimalType t -> t.isShort() ? new ShortDecimalVariantEncoder(t) : new LongDecimalVariantEncoder(t); + case RealType _ -> new FloatVariantEncoder(); + case DoubleType _ -> new DoubleVariantEncoder(); + case DateType _ -> new DateVariantEncoder(); + case TimeType _ -> new TimeVariantEncoder(); + case TimestampType t -> t.isShort() ? new ShortTimestampVariantEncoder() : new LongTimestampVariantEncoder(t); + case TimestampWithTimeZoneType t -> t.isShort() ? new ShortTimestampWithTimezoneVariantEncoder() : new LongTimestampWithTimezoneVariantEncoder(t); + case UuidType _ -> new UuidVariantEncoder(); + case VarcharType _ -> new VarcharVariantEncoder(); + case VarbinaryType _ -> new VarbinaryVariantEncoder(); + default -> null; + }); + } + + abstract int size(Block block, int position); + + public final int write(Block block, int position, Slice out, int offset) + { + if (block.isNull(position)) { + return encodeNull(out, offset); + } + return writeNonNull(block, position, out, offset); + } + + abstract int writeNonNull(Block block, int position, Slice out, int offset); + + private static final class UnknownVariantEncoder + extends PrimitiveVariantEncoder + { + @Override + public int size(Block block, int position) + { + return ENCODED_NULL_SIZE; + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + return encodeNull(out, offset); + } + } + + private abstract static class FixedPrimitiveVariantEncoder + extends PrimitiveVariantEncoder + { + private final int encodedSize; + + FixedPrimitiveVariantEncoder(int encodedSize) + { + this.encodedSize = encodedSize; + } + + @Override + public final int size(Block block, int position) + { + if (block.isNull(position)) { + return ENCODED_NULL_SIZE; + } + return encodedSize; + } + } + + private static class BooleanVariantEncoder + extends FixedPrimitiveVariantEncoder + { + private BooleanVariantEncoder() + { + super(ENCODED_BOOLEAN_SIZE); + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + return encodeBoolean(BOOLEAN.getBoolean(block, position), out, offset); + } + } + + private static class TinyintVariantEncoder + extends FixedPrimitiveVariantEncoder + { + private TinyintVariantEncoder() + { + super(ENCODED_BYTE_SIZE); + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + return encodeByte(TINYINT.getByte(block, position), out, offset); + } + } + + private static class SmallintVariantEncoder + extends FixedPrimitiveVariantEncoder + { + private SmallintVariantEncoder() + { + super(ENCODED_SHORT_SIZE); + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + return 
encodeShort(SMALLINT.getShort(block, position), out, offset); + } + } + + private static class IntegerVariantEncoder + extends FixedPrimitiveVariantEncoder + { + private IntegerVariantEncoder() + { + super(ENCODED_INT_SIZE); + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + return encodeInt(IntegerType.INTEGER.getInt(block, position), out, offset); + } + } + + private static class BigintVariantEncoder + extends FixedPrimitiveVariantEncoder + { + private BigintVariantEncoder() + { + super(ENCODED_LONG_SIZE); + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + return encodeLong(BIGINT.getLong(block, position), out, offset); + } + } + + private static class ShortDecimalVariantEncoder + extends FixedPrimitiveVariantEncoder + { + private final DecimalType type; + + private ShortDecimalVariantEncoder(DecimalType type) + { + super(ENCODED_DECIMAL8_SIZE); + this.type = type; + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + return encodeDecimal8(type.getLong(block, position), type.getScale(), out, offset); + } + } + + private static class LongDecimalVariantEncoder + extends FixedPrimitiveVariantEncoder + { + private final DecimalType type; + + private LongDecimalVariantEncoder(DecimalType type) + { + super(ENCODED_DECIMAL16_SIZE); + this.type = type; + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + return encodeDecimal16((Int128) type.getObject(block, position), type.getScale(), out, offset); + } + } + + private static class FloatVariantEncoder + extends FixedPrimitiveVariantEncoder + { + private FloatVariantEncoder() + { + super(ENCODED_FLOAT_SIZE); + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + return encodeFloat(REAL.getFloat(block, position), out, offset); + } + } + + private static class DoubleVariantEncoder + extends FixedPrimitiveVariantEncoder + { + private DoubleVariantEncoder() + { + super(ENCODED_DOUBLE_SIZE); + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + return encodeDouble(DOUBLE.getDouble(block, position), out, offset); + } + } + + private static class DateVariantEncoder + extends FixedPrimitiveVariantEncoder + { + private DateVariantEncoder() + { + super(ENCODED_DATE_SIZE); + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + return encodeDate(DateType.DATE.getInt(block, position), out, offset); + } + } + + private static class TimeVariantEncoder + extends FixedPrimitiveVariantEncoder + { + private TimeVariantEncoder() + { + super(ENCODED_TIME_SIZE); + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + long epochPicos = BIGINT.getLong(block, position); + long micros = epochPicos / 1_000_000L; + return VariantEncoder.encodeTimeMicrosNtz(micros, out, offset); + } + } + + private static class ShortTimestampVariantEncoder + extends FixedPrimitiveVariantEncoder + { + private ShortTimestampVariantEncoder() + { + super(ENCODED_TIMESTAMP_SIZE); + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + long epochMicros = BIGINT.getLong(block, position); + long nanos = multiplyExact(epochMicros, 1_000L); + return VariantEncoder.encodeTimestampNanosNtz(nanos, out, offset); + } + } + + private static class LongTimestampVariantEncoder + extends 
FixedPrimitiveVariantEncoder + { + private final TimestampType type; + + private LongTimestampVariantEncoder(TimestampType type) + { + super(ENCODED_TIMESTAMP_SIZE); + this.type = type; + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + LongTimestamp timestamp = (LongTimestamp) type.getObject(block, position); + long nanosFromMicros = multiplyExact(timestamp.getEpochMicros(), 1_000L); + long extraNanos = timestamp.getPicosOfMicro() / 1_000; // 1000 ps = 1 ns + long nanos = Math.addExact(nanosFromMicros, extraNanos); + return VariantEncoder.encodeTimestampNanosNtz(nanos, out, offset); + } + } + + private static class ShortTimestampWithTimezoneVariantEncoder + extends FixedPrimitiveVariantEncoder + { + private ShortTimestampWithTimezoneVariantEncoder() + { + super(ENCODED_TIMESTAMP_SIZE); + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + long packedEpochMillis = BIGINT.getLong(block, position); + long epochMillis = unpackMillisUtc(packedEpochMillis); + long epochMicros = multiplyExact(epochMillis, 1_000L); + return VariantEncoder.encodeTimestampMicrosUtc(epochMicros, out, offset); + } + } + + private static class LongTimestampWithTimezoneVariantEncoder + extends FixedPrimitiveVariantEncoder + { + private final TimestampWithTimeZoneType type; + + private LongTimestampWithTimezoneVariantEncoder(TimestampWithTimeZoneType type) + { + super(ENCODED_TIMESTAMP_SIZE); + this.type = type; + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + LongTimestampWithTimeZone timestamp = (LongTimestampWithTimeZone) type.getObject(block, position); + + if (type.getPrecision() <= 6) { + long microsFromMillis = multiplyExact(timestamp.getEpochMillis(), 1_000L); + int extraMicros = timestamp.getPicosOfMilli() / 1_000_000; + long epochMicros = Math.addExact(microsFromMillis, extraMicros); + return VariantEncoder.encodeTimestampMicrosUtc(epochMicros, out, offset); + } + + long nanosFromMillis = multiplyExact(timestamp.getEpochMillis(), 1_000_000L); + int extraNanos = timestamp.getPicosOfMilli() / 1_000; + long nanos = Math.addExact(nanosFromMillis, extraNanos); + return VariantEncoder.encodeTimestampNanosUtc(nanos, out, offset); + } + } + + private static class UuidVariantEncoder + extends FixedPrimitiveVariantEncoder + { + private UuidVariantEncoder() + { + super(ENCODED_UUID_SIZE); + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + return encodeUuid(UUID.getSlice(block, position), out, offset); + } + } + + private static final class VarcharVariantEncoder + extends PrimitiveVariantEncoder + { + @Override + public int size(Block block, int position) + { + if (block.isNull(position)) { + return ENCODED_NULL_SIZE; + } + return encodedStringSize(getSliceLength(block, position)); + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + return encodeString(getSlice(block, position), out, offset); + } + } + + private static final class VarbinaryVariantEncoder + extends PrimitiveVariantEncoder + { + @Override + public int size(Block block, int position) + { + if (block.isNull(position)) { + return ENCODED_NULL_SIZE; + } + return encodedBinarySize(getSliceLength(block, position)); + } + + @Override + public int writeNonNull(Block block, int position, Slice out, int offset) + { + return encodeBinary(getSlice(block, position), out, offset); + } + } + + 
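// These helpers unwrap dictionary/RLE indirection so slices are read directly from the underlying variable-width block. + 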
private static Slice getSlice(Block block, int position) + { + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); + } + + private static int getSliceLength(Block block, int position) + { + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSliceLength(valuePosition); + } +} diff --git a/core/trino-main/src/main/java/io/trino/util/variant/RowVariantWriter.java b/core/trino-main/src/main/java/io/trino/util/variant/RowVariantWriter.java new file mode 100644 index 000000000000..5386cdb574ad --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/util/variant/RowVariantWriter.java @@ -0,0 +1,219 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.util.variant; + +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; +import io.trino.spi.block.SqlRow; +import io.trino.spi.type.RowType; +import io.trino.spi.type.RowType.Field; +import io.trino.spi.type.Type; +import io.trino.spi.variant.Metadata; +import it.unimi.dsi.fastutil.ints.IntArrays; +import it.unimi.dsi.fastutil.ints.IntComparator; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.function.IntUnaryOperator; + +import static com.google.common.base.Verify.verify; +import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; +import static io.trino.spi.variant.VariantEncoder.encodeObjectHeading; +import static io.trino.spi.variant.VariantEncoder.encodedObjectSize; +import static java.util.Objects.requireNonNull; + +final class RowVariantWriter + implements VariantWriter +{ + // all fields are in the order they will be written + private final List<Slice> fieldNames; + private final List<Type> fieldTypes; + private final List<Optional<PrimitiveVariantEncoder>> fieldPrimitiveEncoders; + private final List<Optional<VariantWriter>> fieldWriters; + private final IntUnaryOperator outputIndexToRowIndex; + + RowVariantWriter(RowType type) + { + List<Slice> fieldNames = type.getFields().stream() + .map(Field::getName) + .map(name -> name.orElseThrow(() -> new TrinoException(INVALID_CAST_ARGUMENT, "Cannot cast ROW with anonymous fields to VARIANT"))) + .map(Slices::utf8Slice) + .toList(); + int[] writeOrder = determineWriteOrder(fieldNames); + + // Build "in write order" views + int fieldCount = type.getFields().size(); + ImmutableList.Builder<Slice> fieldNamesInWriteOrder = ImmutableList.builderWithExpectedSize(fieldCount); + ImmutableList.Builder<Type> fieldTypesInWriteOrder = ImmutableList.builderWithExpectedSize(fieldCount); + ImmutableList.Builder<Optional<PrimitiveVariantEncoder>> primitiveEncodersInWriteOrder = ImmutableList.builderWithExpectedSize(fieldCount); + ImmutableList.Builder<Optional<VariantWriter>> writersInWriteOrder = ImmutableList.builderWithExpectedSize(fieldCount); + + for (int fieldIndex : writeOrder) { + 
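// Prefer the cheap primitive encoder when the field type supports one; otherwise fall back to a full VariantWriter for nested types. + 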
type.getTypeParameters().get(fieldIndex); + fieldNamesInWriteOrder.add(fieldNames.get(fieldIndex)); + fieldTypesInWriteOrder.add(fieldType); + + Optional<PrimitiveVariantEncoder> primitiveEncoder = PrimitiveVariantEncoder.create(fieldType); + primitiveEncodersInWriteOrder.add(primitiveEncoder); + + if (primitiveEncoder.isPresent()) { + writersInWriteOrder.add(Optional.empty()); + } + else { + writersInWriteOrder.add(Optional.of(VariantWriter.create(fieldType))); + } + } + + this.fieldNames = fieldNamesInWriteOrder.build(); + this.fieldTypes = fieldTypesInWriteOrder.build(); + this.fieldPrimitiveEncoders = primitiveEncodersInWriteOrder.build(); + this.fieldWriters = writersInWriteOrder.build(); + this.outputIndexToRowIndex = outputIndex -> writeOrder[outputIndex]; + } + + private static int[] determineWriteOrder(List<Slice> fieldNames) + { + int[] writeOrder = new int[fieldNames.size()]; + for (int i = 0; i < fieldNames.size(); i++) { + writeOrder[i] = i; + } + IntArrays.unstableSort(writeOrder, 0, writeOrder.length, IntComparator.comparing(fieldNames::get)); + return writeOrder; + } + + @Override + public PlannedValue plan(Metadata.Builder metadataBuilder, Object value) + { + if (value == null) { + return NullPlannedValue.NULL_PLANNED_VALUE; + } + + int[] provisionalFieldIds = metadataBuilder.addFieldNames(fieldNames); + + FieldBlocks blocks = new FieldBlocks((SqlRow) value, outputIndexToRowIndex); + List<Optional<PlannedValue>> plannedValues = new ArrayList<>(fieldWriters.size()); + for (int i = 0; i < fieldWriters.size(); i++) { + Optional<VariantWriter> fieldWriter = fieldWriters.get(i); + if (fieldWriter.isPresent()) { + Object fieldValue = fieldTypes.get(i).getObject(blocks.getField(i), blocks.getPosition()); + plannedValues.add(Optional.of(fieldWriter.get().plan(metadataBuilder, fieldValue))); + } + else { + plannedValues.add(Optional.empty()); + } + } + + return new PlannedRowValue(fieldPrimitiveEncoders, plannedValues, blocks, provisionalFieldIds); + } + + private static final class PlannedRowValue + implements PlannedValue + { + private final List<Optional<PrimitiveVariantEncoder>> fieldPrimitiveEncoders; + private final List<Optional<PlannedValue>> fieldPlannedValue; + private final FieldBlocks fieldBlocks; + + // Initially, field IDs are provisional and must be remapped in finalize() + // This array is mutated in finalize() + private final int[] fieldIds; + + private int size = -1; + + public PlannedRowValue(List<Optional<PrimitiveVariantEncoder>> fieldPrimitiveEncoders, List<Optional<PlannedValue>> fieldPlannedValue, FieldBlocks fieldBlocks, int[] fieldIds) + { + this.fieldPrimitiveEncoders = requireNonNull(fieldPrimitiveEncoders, "fieldPrimitiveEncoders is null"); + this.fieldPlannedValue = requireNonNull(fieldPlannedValue, "fieldPlannedValue is null"); + this.fieldBlocks = requireNonNull(fieldBlocks, "fieldBlocks is null"); + this.fieldIds = requireNonNull(fieldIds, "fieldIds is null"); + } + + @Override + public void finalize(IntUnaryOperator sortedFieldIdMapping) + { + verify(size < 0, "finalize() already called"); + + int maxFieldId = -1; + for (int i = 0; i < fieldIds.length; i++) { + int finalFieldId = sortedFieldIdMapping.applyAsInt(fieldIds[i]); + fieldIds[i] = finalFieldId; + maxFieldId = Math.max(maxFieldId, finalFieldId); + } + + int totalElementLength = 0; + for (int i = 0; i < fieldPrimitiveEncoders.size(); i++) { + Optional<PrimitiveVariantEncoder> primitiveElementEncoder = fieldPrimitiveEncoders.get(i); + if (primitiveElementEncoder.isPresent()) { + totalElementLength += primitiveElementEncoder.get().size(fieldBlocks.getField(i), fieldBlocks.getPosition()); + } + else { + PlannedValue plannedValue = fieldPlannedValue.get(i).orElseThrow(); +
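+ // finalize() recurses into each nested planned value, so every object in the tree
+ // remaps its provisional field IDs to the sorted dictionary's final IDs before any
+ // sizes are computed.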
plannedValue.finalize(sortedFieldIdMapping); + totalElementLength += plannedValue.size(); + } + } + + size = encodedObjectSize(maxFieldId, fieldPrimitiveEncoders.size(), totalElementLength); + } + + @Override + public int size() + { + verify(size >= 0, "finalize() must be called before size()"); + return size; + } + + @Override + public int write(Slice out, int offset) + { + verify(size >= 0, "finalize() must be called before write()"); + + int written = encodeObjectHeading( + fieldPrimitiveEncoders.size(), + i -> fieldIds[i], + i -> fieldPrimitiveEncoders.get(i) + .map(elementEncoder -> elementEncoder.size(fieldBlocks.getField(i), fieldBlocks.getPosition())) + .orElseGet(() -> fieldPlannedValue.get(i).orElseThrow().size()), + out, + offset); + + for (int i = 0; i < fieldPrimitiveEncoders.size(); i++) { + Optional<PrimitiveVariantEncoder> primitiveEncoder = fieldPrimitiveEncoders.get(i); + if (primitiveEncoder.isPresent()) { + written += primitiveEncoder.get().write(fieldBlocks.getField(i), fieldBlocks.getPosition(), out, offset + written); + } + else { + written += fieldPlannedValue.get(i).orElseThrow().write(out, offset + written); + } + } + return written; + } + } + + private record FieldBlocks(SqlRow row, IntUnaryOperator blockIndexMapper) + { + public int getPosition() + { + return row.getRawIndex(); + } + + public Block getField(int outputIndex) + { + return row.getRawFieldBlock(blockIndexMapper.applyAsInt(outputIndex)); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/util/variant/VariantCastException.java b/core/trino-main/src/main/java/io/trino/util/variant/VariantCastException.java new file mode 100644 index 000000000000..8c87afa7d1e7 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/util/variant/VariantCastException.java @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.util.variant; + +import io.trino.spi.TrinoException; + +import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; + +public class VariantCastException + extends TrinoException +{ + public VariantCastException(String message) + { + super(INVALID_CAST_ARGUMENT, message); + } +} diff --git a/core/trino-main/src/main/java/io/trino/util/variant/VariantUtil.java b/core/trino-main/src/main/java/io/trino/util/variant/VariantUtil.java new file mode 100644 index 000000000000..63e636e521f3 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/util/variant/VariantUtil.java @@ -0,0 +1,1302 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package io.trino.util.variant; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.json.JsonMapper; +import com.google.common.collect.ImmutableMap; +import io.airlift.slice.DynamicSliceOutput; +import io.airlift.slice.Slice; +import io.airlift.slice.SliceOutput; +import io.trino.operator.scalar.time.TimeOperators; +import io.trino.operator.scalar.timestamp.VarcharToTimestampCast; +import io.trino.operator.scalar.timestamptz.VarcharToTimestampWithTimeZoneCast; +import io.trino.spi.TrinoException; +import io.trino.spi.block.ArrayBlockBuilder; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.RowBlockBuilder; +import io.trino.spi.block.VariableWidthBlockBuilder; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.DateType; +import io.trino.spi.type.DecimalType; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.Int128; +import io.trino.spi.type.IntegerType; +import io.trino.spi.type.LongTimestamp; +import io.trino.spi.type.LongTimestampWithTimeZone; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RealType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.RowType.Field; +import io.trino.spi.type.SmallintType; +import io.trino.spi.type.TimeType; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.TimestampWithTimeZoneType; +import io.trino.spi.type.TinyintType; +import io.trino.spi.type.Type; +import io.trino.spi.type.UuidType; +import io.trino.spi.type.VarbinaryType; +import io.trino.spi.type.VarcharType; +import io.trino.spi.type.VariantType; +import io.trino.spi.variant.Header; +import io.trino.spi.variant.Metadata; +import io.trino.spi.variant.Variant; +import io.trino.type.BigintOperators; +import io.trino.type.BooleanOperators; +import io.trino.type.DateOperators; +import io.trino.type.DateTimes; +import io.trino.type.DoubleOperators; +import io.trino.type.IntegerOperators; +import io.trino.type.JsonType; +import io.trino.type.SmallintOperators; +import io.trino.type.TinyintOperators; +import io.trino.type.UnknownType; +import io.trino.type.UuidOperators; +import io.trino.type.VarcharOperators; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.math.BigDecimal; +import java.time.Instant; +import java.time.ZoneId; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +import static com.fasterxml.jackson.core.JsonFactory.Feature.CANONICALIZE_FIELD_NAMES; +import static com.fasterxml.jackson.core.json.JsonWriteFeature.COMBINE_UNICODE_SURROGATES_IN_UTF8; +import static com.fasterxml.jackson.core.json.JsonWriteFeature.ESCAPE_NON_ASCII; +import static com.google.common.base.Preconditions.checkArgument; +import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.plugin.base.util.JsonUtils.jsonFactoryBuilder; +import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; +import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.trino.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone; +import static io.trino.spi.type.DateType.DATE; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static 
io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.LongTimestampWithTimeZone.fromEpochMillisAndFraction; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.SmallintType.SMALLINT; +import static io.trino.spi.type.TimeType.MAX_PRECISION; +import static io.trino.spi.type.TimeZoneKey.UTC_KEY; +import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.spi.type.UuidType.UUID; +import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.UNBOUNDED_LENGTH; +import static io.trino.spi.type.VariantType.VARIANT; +import static io.trino.spi.variant.Header.BasicType.PRIMITIVE; +import static io.trino.spi.variant.Header.PrimitiveType.BINARY; +import static io.trino.type.DateTimes.MICROSECONDS_PER_DAY; +import static io.trino.type.DateTimes.NANOSECONDS_PER_DAY; +import static io.trino.type.DateTimes.PICOSECONDS_PER_DAY; +import static io.trino.type.DateTimes.round; +import static io.trino.type.JsonType.JSON; +import static io.trino.util.JsonUtil.createJsonGenerator; +import static java.lang.Float.floatToRawIntBits; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; +import static java.math.RoundingMode.HALF_UP; +import static java.time.ZoneOffset.UTC; + +public final class VariantUtil +{ + private static final JsonMapper JSON_MAPPER = new JsonMapper(jsonFactoryBuilder() + .disable(CANONICALIZE_FIELD_NAMES) + // prevents characters outside BMP (e.g., emoji) from being escaped as surrogate pairs + .enable(COMBINE_UNICODE_SURROGATES_IN_UTF8) + .build()); + + private VariantUtil() {} + + public static boolean canCastToVariant(Type type) + { + if (type instanceof UnknownType || + type instanceof BooleanType || + type instanceof TinyintType || + type instanceof SmallintType || + type instanceof IntegerType || + type instanceof BigintType || + type instanceof RealType || + type instanceof DoubleType || + type instanceof DecimalType || + type instanceof VarcharType || + type instanceof VarbinaryType || + type instanceof VariantType || + type instanceof TimestampType || + type instanceof TimestampWithTimeZoneType || + type instanceof DateType || + type instanceof TimeType || + type instanceof UuidType || + type instanceof JsonType) { + return true; + } + if (type instanceof ArrayType arrayType) { + return canCastToVariant(arrayType.getElementType()); + } + if (type instanceof MapType mapType) { + return mapType.getKeyType() instanceof VarcharType && + canCastToVariant(mapType.getValueType()); + } + if (type instanceof RowType) { + return type.getTypeParameters().stream().allMatch(VariantUtil::canCastToVariant); + } + return false; + } + + public static boolean canCastFromVariant(Type type) + { + if (type instanceof UnknownType || + type instanceof BooleanType || + type instanceof TinyintType || + type instanceof SmallintType || + type instanceof IntegerType || + type instanceof BigintType || + type instanceof RealType || + type instanceof DoubleType || + type instanceof DecimalType || + type instanceof VarcharType || + type instanceof VarbinaryType || + type instanceof VariantType || + type instanceof TimestampType || + type instanceof TimestampWithTimeZoneType || + type instanceof DateType || + type instanceof TimeType || + type instanceof UuidType || + type instanceof JsonType) { + return true; + } + if (type instanceof ArrayType arrayType) { + return canCastFromVariant(arrayType.getElementType()); + } + if (type instanceof MapType mapType) { + return mapType.getKeyType() 
instanceof VarcharType && canCastFromVariant(mapType.getValueType()); + } + if (type instanceof RowType) { + return type.getTypeParameters().stream().allMatch(VariantUtil::canCastFromVariant); + } + return false; + } + + // utility classes and functions for cast from VARIANT + public static Slice asVarchar(Variant variant) + { + return switch (variant.basicType()) { + case PRIMITIVE -> switch (variant.primitiveType()) { + case NULL -> null; + case STRING -> variant.getString(); + case BOOLEAN_TRUE -> BooleanOperators.castToVarchar(UNBOUNDED_LENGTH, true); + case BOOLEAN_FALSE -> BooleanOperators.castToVarchar(UNBOUNDED_LENGTH, false); + case INT8 -> utf8Slice(String.valueOf(variant.getByte())); + case INT16 -> utf8Slice(String.valueOf(variant.getShort())); + case INT32 -> utf8Slice(String.valueOf(variant.getInt())); + case INT64 -> utf8Slice(String.valueOf(variant.getLong())); + case DECIMAL4, DECIMAL8, DECIMAL16 -> utf8Slice(variant.getDecimal().toString()); + case FLOAT -> DoubleOperators.castToVarchar(UNBOUNDED_LENGTH, variant.getFloat()); + case DOUBLE -> DoubleOperators.castToVarchar(UNBOUNDED_LENGTH, variant.getDouble()); + case DATE -> DateOperators.castToVarchar(UNBOUNDED_LENGTH, variant.getDate()); + case TIME_NTZ_MICROS -> TimeOperators.castToVarchar(UNBOUNDED_LENGTH, 6, variant.getTimeMicros() * 1_000_000L); + case TIMESTAMP_UTC_MICROS -> { + long micros = variant.getTimestampMicros(); + long epochMillis = Math.floorDiv(micros, 1_000L); + int picosOfMilli = toIntExact(Math.floorMod(micros, 1_000L) * 1_000_000L); + yield utf8Slice(DateTimes.formatTimestampWithTimeZone(6, epochMillis, picosOfMilli, UTC_KEY.getZoneId())); + } + case TIMESTAMP_NTZ_MICROS -> utf8Slice(DateTimes.formatTimestamp(6, variant.getTimestampMicros(), 0, UTC)); + case TIMESTAMP_UTC_NANOS -> { + long nanos = variant.getTimestampNanos(); + long epochMillis = Math.floorDiv(nanos, 1_000_000L); + int picosOfMilli = toIntExact(Math.floorMod(nanos, 1_000_000L) * 1_000L); + yield utf8Slice(DateTimes.formatTimestampWithTimeZone(9, epochMillis, picosOfMilli, UTC_KEY.getZoneId())); + } + case TIMESTAMP_NTZ_NANOS -> { + long nanos = variant.getTimestampNanos(); + long epochMicros = Math.floorDiv(nanos, 1_000L); + int picosOfMicros = toIntExact(Math.floorMod(nanos, 1_000L) * 1_000L); + yield utf8Slice(DateTimes.formatTimestamp(9, epochMicros, picosOfMicros, UTC)); + } + case UUID -> utf8Slice(variant.getUuid().toString()); + default -> throw new VariantCastException("Unsupported VARIANT primitive type for cast to VARCHAR: " + variant.primitiveType()); + }; + case SHORT_STRING -> variant.getString(); + default -> throw new VariantCastException("Unsupported VARIANT type for cast to VARCHAR: " + variant.basicType()); + }; + } + + public static Boolean asBoolean(Variant variant) + { + return switch (variant.basicType()) { + case PRIMITIVE -> switch (variant.primitiveType()) { + case NULL -> null; + case BOOLEAN_TRUE -> true; + case BOOLEAN_FALSE -> false; + case STRING -> VarcharOperators.castToBoolean(variant.getString()); + case INT8 -> TinyintOperators.castToBoolean(variant.getByte()); + case INT16 -> SmallintOperators.castToBoolean(variant.getShort()); + case INT32 -> IntegerOperators.castToBoolean(variant.getInt()); + case INT64 -> BigintOperators.castToBoolean(variant.getLong()); + case DECIMAL4, DECIMAL8, DECIMAL16 -> variant.getDecimal().compareTo(BigDecimal.ZERO) != 0; + case FLOAT -> DoubleOperators.castToBoolean(variant.getFloat()); + case DOUBLE -> DoubleOperators.castToBoolean(variant.getDouble()); + default -> 
throw new VariantCastException("Unsupported VARIANT primitive type for cast to BOOLEAN: " + variant.primitiveType()); + }; + case SHORT_STRING -> VarcharOperators.castToBoolean(variant.getString()); + default -> throw new VariantCastException("Unsupported VARIANT type for cast to BOOLEAN: " + variant.basicType()); + }; + } + + public static Long asTinyint(Variant variant) + { + return switch (variant.basicType()) { + case PRIMITIVE -> switch (variant.primitiveType()) { + case NULL -> null; + case BOOLEAN_TRUE -> BooleanOperators.castToTinyint(true); + case BOOLEAN_FALSE -> BooleanOperators.castToTinyint(false); + case STRING -> VarcharOperators.castToTinyint(variant.getString()); + case INT8 -> (long) variant.getByte(); + case INT16 -> SmallintOperators.castToTinyint(variant.getShort()); + case INT32 -> IntegerOperators.castToTinyint(variant.getInt()); + case INT64 -> BigintOperators.castToTinyint(variant.getLong()); + case DECIMAL4, DECIMAL8, DECIMAL16 -> { + BigDecimal decimalValue = variant.getDecimal(); + try { + yield (long) decimalValue.byteValueExact(); + } + catch (ArithmeticException e) { + throw new TrinoException(NUMERIC_VALUE_OUT_OF_RANGE, "Out of range for tinyint: " + decimalValue, e); + } + } + case FLOAT -> DoubleOperators.castToTinyint(variant.getFloat()); + case DOUBLE -> DoubleOperators.castToTinyint(variant.getDouble()); + default -> throw new VariantCastException("Unsupported VARIANT primitive type for cast to TINYINT: " + variant.primitiveType()); + }; + case SHORT_STRING -> VarcharOperators.castToTinyint(variant.getString()); + default -> throw new VariantCastException("Unsupported VARIANT type for cast to TINYINT: " + variant.basicType()); + }; + } + + public static Long asSmallint(Variant variant) + { + return switch (variant.basicType()) { + case PRIMITIVE -> switch (variant.primitiveType()) { + case NULL -> null; + case BOOLEAN_TRUE -> BooleanOperators.castToSmallint(true); + case BOOLEAN_FALSE -> BooleanOperators.castToSmallint(false); + case STRING -> VarcharOperators.castToSmallint(variant.getString()); + case INT8 -> (long) variant.getByte(); + case INT16 -> (long) variant.getShort(); + case INT32 -> IntegerOperators.castToSmallint(variant.getInt()); + case INT64 -> BigintOperators.castToSmallint(variant.getLong()); + case DECIMAL4, DECIMAL8, DECIMAL16 -> { + BigDecimal decimalValue = variant.getDecimal(); + try { + yield (long) decimalValue.shortValueExact(); + } + catch (ArithmeticException e) { + throw new TrinoException(NUMERIC_VALUE_OUT_OF_RANGE, "Out of range for smallint: " + decimalValue, e); + } + } + case FLOAT -> DoubleOperators.castToSmallint(variant.getFloat()); + case DOUBLE -> DoubleOperators.castToSmallint(variant.getDouble()); + default -> throw new VariantCastException("Unsupported VARIANT primitive type for cast to SMALLINT: " + variant.primitiveType()); + }; + case SHORT_STRING -> VarcharOperators.castToSmallint(variant.getString()); + default -> throw new VariantCastException("Unsupported VARIANT type for cast to SMALLINT: " + variant.basicType()); + }; + } + + public static Long asInteger(Variant variant) + { + return switch (variant.basicType()) { + case PRIMITIVE -> switch (variant.primitiveType()) { + case NULL -> null; + case BOOLEAN_TRUE -> BooleanOperators.castToInteger(true); + case BOOLEAN_FALSE -> BooleanOperators.castToInteger(false); + case STRING -> VarcharOperators.castToInteger(variant.getString()); + case INT8 -> (long) variant.getByte(); + case INT16 -> (long) variant.getShort(); + case INT32 -> (long) variant.getInt(); + 
case INT64 -> BigintOperators.castToInteger(variant.getLong()); + case DECIMAL4, DECIMAL8, DECIMAL16 -> { + BigDecimal decimalValue = variant.getDecimal(); + try { + yield (long) decimalValue.intValueExact(); + } + catch (ArithmeticException e) { + throw new TrinoException(NUMERIC_VALUE_OUT_OF_RANGE, "Out of range for integer: " + decimalValue, e); + } + } + case FLOAT -> DoubleOperators.castToInteger(variant.getFloat()); + case DOUBLE -> DoubleOperators.castToInteger(variant.getDouble()); + default -> throw new VariantCastException("Unsupported VARIANT primitive type for cast to INTEGER: " + variant.primitiveType()); + }; + case SHORT_STRING -> VarcharOperators.castToInteger(variant.getString()); + default -> throw new VariantCastException("Unsupported VARIANT type for cast to INTEGER: " + variant.basicType()); + }; + } + + public static Long asBigint(Variant variant) + { + return switch (variant.basicType()) { + case PRIMITIVE -> switch (variant.primitiveType()) { + case NULL -> null; + case BOOLEAN_TRUE -> BooleanOperators.castToBigint(true); + case BOOLEAN_FALSE -> BooleanOperators.castToBigint(false); + case STRING -> VarcharOperators.castToBigint(variant.getString()); + case INT8 -> (long) variant.getByte(); + case INT16 -> (long) variant.getShort(); + case INT32 -> (long) variant.getInt(); + case INT64 -> variant.getLong(); + case DECIMAL4, DECIMAL8, DECIMAL16 -> { + BigDecimal decimalValue = variant.getDecimal(); + try { + yield decimalValue.longValueExact(); + } + catch (ArithmeticException e) { + throw new TrinoException(NUMERIC_VALUE_OUT_OF_RANGE, "Out of range for bigint: " + decimalValue, e); + } + } + case FLOAT -> DoubleOperators.castToLong(variant.getFloat()); + case DOUBLE -> DoubleOperators.castToLong(variant.getDouble()); + default -> throw new VariantCastException("Unsupported VARIANT primitive type for cast to BIGINT: " + variant.primitiveType()); + }; + case SHORT_STRING -> VarcharOperators.castToBigint(variant.getString()); + default -> throw new VariantCastException("Unsupported VARIANT type for cast to BIGINT: " + variant.basicType()); + }; + } + + public static Long asReal(Variant variant) + { + return switch (variant.basicType()) { + case PRIMITIVE -> switch (variant.primitiveType()) { + case NULL -> null; + case BOOLEAN_TRUE -> BooleanOperators.castToReal(true); + case BOOLEAN_FALSE -> BooleanOperators.castToReal(false); + case STRING -> VarcharOperators.castToFloat(variant.getString()); + case INT8 -> TinyintOperators.castToReal(variant.getByte()); + case INT16 -> SmallintOperators.castToReal(variant.getShort()); + case INT32 -> IntegerOperators.castToReal(variant.getInt()); + case INT64 -> BigintOperators.castToReal(variant.getLong()); + case DECIMAL4, DECIMAL8, DECIMAL16 -> (long) floatToRawIntBits(variant.getDecimal().floatValue()); + case FLOAT -> (long) floatToRawIntBits(variant.getFloat()); + case DOUBLE -> DoubleOperators.castToReal(variant.getDouble()); + default -> throw new VariantCastException("Unsupported VARIANT primitive type for cast to REAL: " + variant.primitiveType()); + }; + case SHORT_STRING -> VarcharOperators.castToFloat(variant.getString()); + default -> throw new VariantCastException("Unsupported VARIANT type for cast to REAL: " + variant.basicType()); + }; + } + + public static Double asDouble(Variant variant) + { + return switch (variant.basicType()) { + case PRIMITIVE -> switch (variant.primitiveType()) { + case NULL -> null; + case BOOLEAN_TRUE -> BooleanOperators.castToDouble(true); + case BOOLEAN_FALSE -> 
BooleanOperators.castToDouble(false); + case STRING -> VarcharOperators.castToDouble(variant.getString()); + case INT8 -> TinyintOperators.castToDouble(variant.getByte()); + case INT16 -> SmallintOperators.castToDouble(variant.getShort()); + case INT32 -> IntegerOperators.castToDouble(variant.getInt()); + case INT64 -> BigintOperators.castToDouble(variant.getLong()); + case DECIMAL4, DECIMAL8, DECIMAL16 -> variant.getDecimal().doubleValue(); + case FLOAT -> (double) variant.getFloat(); + case DOUBLE -> variant.getDouble(); + default -> throw new VariantCastException("Unsupported VARIANT primitive type for cast to DOUBLE: " + variant.primitiveType()); + }; + case SHORT_STRING -> VarcharOperators.castToDouble(variant.getString()); + default -> throw new VariantCastException("Unsupported VARIANT type for cast to DOUBLE: " + variant.basicType()); + }; + } + + public static Long asShortDecimal(Variant variant, int precision, int scale) + { + BigDecimal bigDecimal = asJavaDecimal(variant, precision, scale); + if (bigDecimal == null) { + return null; + } + return bigDecimal.unscaledValue().longValue(); + } + + public static Int128 asLongDecimal(Variant variant, int precision, int scale) + { + BigDecimal bigDecimal = asJavaDecimal(variant, precision, scale); + if (bigDecimal == null) { + return null; + } + return Int128.valueOf(bigDecimal.unscaledValue()); + } + + private static BigDecimal asJavaDecimal(Variant variant, int precision, int scale) + { + BigDecimal bigDecimal; + try { + bigDecimal = switch (variant.basicType()) { + case PRIMITIVE -> switch (variant.primitiveType()) { + case NULL -> null; + case BOOLEAN_TRUE -> BigDecimal.ONE; + case BOOLEAN_FALSE -> BigDecimal.ZERO; + case STRING -> new BigDecimal(variant.getString().toStringUtf8()); + case INT8 -> BigDecimal.valueOf(variant.getByte()); + case INT16 -> BigDecimal.valueOf(variant.getShort()); + case INT32 -> BigDecimal.valueOf(variant.getInt()); + case INT64 -> BigDecimal.valueOf(variant.getLong()); + case DECIMAL4, DECIMAL8, DECIMAL16 -> variant.getDecimal(); + case FLOAT -> BigDecimal.valueOf(variant.getFloat()); + case DOUBLE -> BigDecimal.valueOf(variant.getDouble()); + default -> throw new VariantCastException("Unsupported VARIANT primitive type for cast to DECIMAL(%s,%s): %s".formatted(precision, scale, variant.primitiveType())); + }; + case SHORT_STRING -> new BigDecimal(variant.getString().toStringUtf8()); + default -> throw new VariantCastException("Unsupported VARIANT type for cast to DECIMAL(%s,%s): %s".formatted(precision, scale, variant.basicType())); + }; + } + catch (NumberFormatException e) { + throw new TrinoException(INVALID_CAST_ARGUMENT, "Cannot cast input variant to DECIMAL(" + precision + "," + scale + ")", e); + } + if (bigDecimal == null) { + return null; + } + bigDecimal = bigDecimal.setScale(scale, HALF_UP); + if (bigDecimal.precision() > precision) { + throw new TrinoException(INVALID_CAST_ARGUMENT, "Cannot cast input variant to DECIMAL(" + precision + "," + scale + ")"); + } + return bigDecimal; + } + + public static Long asDate(Variant variant) + { + return switch (variant.basicType()) { + case PRIMITIVE -> switch (variant.primitiveType()) { + case NULL -> null; + case DATE -> (long) variant.getDate(); + case TIMESTAMP_UTC_MICROS, TIMESTAMP_NTZ_MICROS -> { + long micros = variant.getTimestampMicros(); + long epochSeconds = Math.floorDiv(micros, 1_000_000L); + int nanoAdjustment = (int) Math.floorMod(micros, 1_000_000L) * 1_000; + yield Instant.ofEpochSecond(epochSeconds, nanoAdjustment) + .atZone(UTC) + 
.toLocalDate() + .toEpochDay(); + } + case TIMESTAMP_UTC_NANOS, TIMESTAMP_NTZ_NANOS -> { + long nanos = variant.getTimestampNanos(); + long epochSeconds = Math.floorDiv(nanos, 1_000_000_000L); + int nanoAdjustment = (int) Math.floorMod(nanos, 1_000_000_000L); + yield Instant.ofEpochSecond(epochSeconds, nanoAdjustment) + .atZone(UTC) + .toLocalDate() + .toEpochDay(); + } + case STRING -> DateOperators.castFromVarchar(variant.getString()); + default -> throw new VariantCastException("Unsupported VARIANT primitive type for cast to DATE: " + variant.primitiveType()); + }; + case SHORT_STRING -> DateOperators.castFromVarchar(variant.getString()); + default -> throw new VariantCastException("Unsupported VARIANT type for cast to DATE: " + variant.basicType()); + }; + } + + public static Long asTime(Variant variant, int precision) + { + return switch (variant.basicType()) { + case PRIMITIVE -> switch (variant.primitiveType()) { + case NULL -> null; + case TIME_NTZ_MICROS -> { + long timePicos = variant.getTimeMicros() * 1_000_000L; + // round can round up to a value equal to 24h, so we need to compute modulo 24h + yield round(timePicos, MAX_PRECISION - precision) % PICOSECONDS_PER_DAY; + } + case TIMESTAMP_UTC_MICROS, TIMESTAMP_NTZ_MICROS -> { + long micros = Math.floorMod(variant.getTimestampMicros(), MICROSECONDS_PER_DAY); + long timePicos = micros * 1_000_000L; + // round can round up to a value equal to 24h, so we need to compute modulo 24h + yield round(timePicos, MAX_PRECISION - precision) % PICOSECONDS_PER_DAY; + } + case TIMESTAMP_UTC_NANOS, TIMESTAMP_NTZ_NANOS -> { + long nanos = Math.floorMod(variant.getTimestampNanos(), NANOSECONDS_PER_DAY); + long timePicos = nanos * 1_000L; + // round can round up to a value equal to 24h, so we need to compute modulo 24h + yield round(timePicos, MAX_PRECISION - precision) % PICOSECONDS_PER_DAY; + } + case STRING -> TimeOperators.castFromVarchar(precision, variant.getString()); + default -> throw new VariantCastException("Unsupported VARIANT primitive type for cast to TIME: " + variant.primitiveType()); + }; + case SHORT_STRING -> TimeOperators.castFromVarchar(precision, variant.getString()); + default -> throw new VariantCastException("Unsupported VARIANT type for cast to TIME: " + variant.basicType()); + }; + }
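The modulo-24h guard above matters at the edge of the day: rounding to a coarser precision can carry the time up to exactly 24:00:00, which must wrap to midnight. A standalone sketch of that arithmetic (the real code uses DateTimes.round; this models it with equivalent half-up rounding, and the class name and sample values are mine):

```java
public class TimeRoundingSketch
{
    private static final long PICOS_PER_DAY = 86_400L * 1_000_000_000_000L;

    public static void main(String[] args)
    {
        long timePicos = 86_399_999_950_000_000L; // 23:59:59.99995 in picoseconds
        long unit = 100_000_000L; // rounding unit for precision 4, i.e. 10^(12-4) ps

        // half-up rounding to precision 4 carries this up to exactly one full day...
        long rounded = Math.floorDiv(timePicos + unit / 2, unit) * unit;

        // ...so the modulo wraps it back to 0, i.e. 00:00:00.0000
        System.out.println(rounded % PICOS_PER_DAY); // prints 0
    }
}
```

+ + public static Long asShortTimestamp(Variant variant, int precision) + { + if (precision < 0 || precision > TimestampType.MAX_SHORT_PRECISION) { + throw new IllegalArgumentException("precision must be between 0 and " + TimestampType.MAX_SHORT_PRECISION); + } + + return switch (variant.basicType()) { + case PRIMITIVE -> switch (variant.primitiveType()) { + case NULL -> null; + case DATE -> TimeUnit.DAYS.toMicros(variant.getDate()); + case TIMESTAMP_UTC_MICROS, TIMESTAMP_NTZ_MICROS -> { + long micros = variant.getTimestampMicros(); + if (precision == 6) { + yield micros; + } + yield round(micros, 6 - precision); + } + case TIMESTAMP_UTC_NANOS, TIMESTAMP_NTZ_NANOS -> { + long nanos = variant.getTimestampNanos(); + // round is always required since the max precision is 6 (microseconds) + long roundedNanos = round(nanos, 9 - precision); + yield roundedNanos / 1_000; + } + case STRING -> VarcharToTimestampCast.castToShortTimestamp(precision, variant.getString().toStringUtf8()); + default -> throw new VariantCastException("Unsupported VARIANT primitive type for cast to TIMESTAMP(%d): %s".formatted(precision, variant.primitiveType())); + }; + case SHORT_STRING -> VarcharToTimestampCast.castToShortTimestamp(precision,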
variant.getString().toStringUtf8()); + default -> throw new VariantCastException("Unsupported VARIANT type for cast to TIMESTAMP(%d): %s".formatted(precision, variant.basicType())); + }; + } + + public static LongTimestamp asLongTimestamp(Variant variant, int precision) + { + if (precision <= TimestampType.MAX_SHORT_PRECISION || precision > TimestampType.MAX_PRECISION) { + throw new IllegalArgumentException("precision must be between %d and %d".formatted(TimestampType.MAX_SHORT_PRECISION, TimestampType.MAX_PRECISION)); + } + + return switch (variant.basicType()) { + case PRIMITIVE -> switch (variant.primitiveType()) { + case NULL -> null; + case DATE -> new LongTimestamp(TimeUnit.DAYS.toMicros(variant.getDate()), 0); + case TIMESTAMP_UTC_MICROS, TIMESTAMP_NTZ_MICROS -> new LongTimestamp(variant.getTimestampMicros(), 0); + case TIMESTAMP_UTC_NANOS, TIMESTAMP_NTZ_NANOS -> { + long nanos = variant.getTimestampNanos(); + if (precision < 9) { + nanos = round(nanos, 9 - precision); + } + long micros = Math.floorDiv(nanos, 1_000L); + int picosOfMicro = toIntExact(Math.floorMod(nanos, 1_000L) * 1_000L); + yield new LongTimestamp(micros, picosOfMicro); + } + case STRING -> VarcharToTimestampCast.castToLongTimestamp(precision, variant.getString().toStringUtf8()); + default -> throw new VariantCastException("Unsupported VARIANT primitive type for cast to TIMESTAMP(%d): %s".formatted(precision, variant.primitiveType())); + }; + case SHORT_STRING -> VarcharToTimestampCast.castToLongTimestamp(precision, variant.getString().toStringUtf8()); + default -> throw new VariantCastException("Unsupported VARIANT type for cast to TIMESTAMP(%d): %s".formatted(precision, variant.basicType())); + }; + } + + public static Long asShortTimestampWithTimeZone(Variant variant, int precision) + { + if (precision < 0 || precision > TimestampWithTimeZoneType.MAX_SHORT_PRECISION) { + throw new IllegalArgumentException("precision must be between 0 and " + TimestampWithTimeZoneType.MAX_SHORT_PRECISION); + } + + return switch (variant.basicType()) { + case PRIMITIVE -> switch (variant.primitiveType()) { + case NULL -> null; + case DATE -> packDateTimeWithZone(TimeUnit.DAYS.toMillis(variant.getDate()), UTC_KEY); + // round is always required as the max precision is 3 (milliseconds), and both micros and nanos have higher precision + case TIMESTAMP_UTC_MICROS, TIMESTAMP_NTZ_MICROS -> packDateTimeWithZone(round(variant.getTimestampMicros(), 6 - precision) / 1_000, UTC_KEY); + case TIMESTAMP_UTC_NANOS, TIMESTAMP_NTZ_NANOS -> packDateTimeWithZone(round(variant.getTimestampNanos(), 9 - precision) / 1_000_000, UTC_KEY); + case STRING -> asShortTimestampWithTimeZone(variant.getString(), precision); + default -> throw new VariantCastException("Unsupported VARIANT primitive type for cast to TIMESTAMP(%d) WITH TIME ZONE: %s".formatted(precision, variant.primitiveType())); + }; + case SHORT_STRING -> asShortTimestampWithTimeZone(variant.getString(), precision); + default -> throw new VariantCastException("Unsupported VARIANT type for cast to TIMESTAMP(%d) WITH TIME ZONE: %s".formatted(precision, variant.basicType())); + }; + } + + private static long asShortTimestampWithTimeZone(Slice varchar, int precision) + { + return VarcharToTimestampWithTimeZoneCast.toShort(precision, varchar.toStringUtf8(), timezone -> timezone == null ? 
UTC : ZoneId.of(timezone)); + } + + public static LongTimestampWithTimeZone asLongTimestampWithTimeZone(Variant variant, int precision) + { + if (precision <= TimestampWithTimeZoneType.MAX_SHORT_PRECISION || precision > TimestampWithTimeZoneType.MAX_PRECISION) { + throw new IllegalArgumentException("precision must be between %d and %d".formatted(TimestampWithTimeZoneType.MAX_SHORT_PRECISION, TimestampWithTimeZoneType.MAX_PRECISION)); + } + + return switch (variant.basicType()) { + case PRIMITIVE -> switch (variant.primitiveType()) { + case NULL -> null; + case DATE -> fromEpochMillisAndFraction(TimeUnit.DAYS.toMillis(variant.getDate()), 0, UTC_KEY); + case TIMESTAMP_UTC_MICROS, TIMESTAMP_NTZ_MICROS -> { + long micros = variant.getTimestampMicros(); + if (precision < 6) { + micros = round(micros, 6 - precision); + } + long millis = Math.floorDiv(micros, 1_000L); + int picosOfMillis = toIntExact(Math.floorMod(micros, 1_000L) * 1_000_000L); + yield fromEpochMillisAndFraction(millis, picosOfMillis, UTC_KEY); + } + case TIMESTAMP_UTC_NANOS, TIMESTAMP_NTZ_NANOS -> { + long nanos = variant.getTimestampNanos(); + if (precision < 9) { + nanos = round(nanos, 9 - precision); + } + long millis = Math.floorDiv(nanos, 1_000_000L); + int picosOfMillis = toIntExact(Math.floorMod(nanos, 1_000_000L) * 1_000L); + yield fromEpochMillisAndFraction(millis, picosOfMillis, UTC_KEY); + } + case STRING -> asLongTimestampWithTimeZone(variant.getString(), precision); + default -> throw new VariantCastException("Unsupported VARIANT primitive type for cast to TIMESTAMP(%d) WITH TIME ZONE: %s".formatted(precision, variant.primitiveType())); + }; + case SHORT_STRING -> asLongTimestampWithTimeZone(variant.getString(), precision); + default -> throw new VariantCastException("Unsupported VARIANT type for cast to TIMESTAMP(%d) WITH TIME ZONE: %s".formatted(precision, variant.basicType())); + }; + } + + private static LongTimestampWithTimeZone asLongTimestampWithTimeZone(Slice varchar, int precision) + { + return VarcharToTimestampWithTimeZoneCast.toLong(precision, varchar.toStringUtf8(), timezone -> timezone == null ? 
UTC : ZoneId.of(timezone)); + } + + public static Slice asUuid(Variant variant) + { + return switch (variant.basicType()) { + case PRIMITIVE -> switch (variant.primitiveType()) { + case NULL -> null; + case UUID -> variant.getUuidSlice(); + case STRING -> UuidOperators.castFromVarcharToUuid(variant.getString()); + default -> throw new VariantCastException("Unsupported VARIANT primitive type for cast to UUID: " + variant.primitiveType()); + }; + case SHORT_STRING -> UuidOperators.castFromVarcharToUuid(variant.getString()); + default -> throw new VariantCastException("Unsupported VARIANT type for cast to UUID: " + variant.basicType()); + }; + } + + public static Slice asVarbinary(Variant variant) + { + if (variant.isNull()) { + return null; + } + if (variant.basicType() != PRIMITIVE || variant.primitiveType() != BINARY) { + throw new VariantCastException("Unsupported VARIANT type for cast to VARBINARY: " + variant.basicType() + "/" + variant.primitiveType()); + } + return variant.getBinary(); + } + + // given a VARIANT value, write it to the BlockBuilder + public interface BlockBuilderAppender + { + void append(Variant variant, BlockBuilder blockBuilder); + + static BlockBuilderAppender createBlockBuilderAppender(Type type) + { + if (type instanceof BooleanType) { + return new BooleanBlockBuilderAppender(); + } + if (type instanceof TinyintType) { + return new TinyintBlockBuilderAppender(); + } + if (type instanceof SmallintType) { + return new SmallintBlockBuilderAppender(); + } + if (type instanceof IntegerType) { + return new IntegerBlockBuilderAppender(); + } + if (type instanceof BigintType) { + return new BigintBlockBuilderAppender(); + } + if (type instanceof RealType) { + return new RealBlockBuilderAppender(); + } + if (type instanceof DoubleType) { + return new DoubleBlockBuilderAppender(); + } + if (type instanceof DecimalType decimalType) { + if (decimalType.isShort()) { + return new ShortDecimalBlockBuilderAppender(decimalType); + } + + return new LongDecimalBlockBuilderAppender(decimalType); + } + if (type instanceof VarcharType) { + return new VarcharBlockBuilderAppender(type); + } + if (type instanceof VarbinaryType) { + return new VarbinaryBlockBuilderAppender(type); + } + if (type instanceof DateType) { + return new DateBlockBuilderAppender(); + } + if (type instanceof TimeType timeType) { + return new TimeBlockBuilderAppender(timeType); + } + if (type instanceof TimestampType timestampType) { + if (timestampType.isShort()) { + return new ShortTimestampBlockBuilderAppender(timestampType); + } + return new LongTimestampBlockBuilderAppender(timestampType); + } + if (type instanceof TimestampWithTimeZoneType timestampWithTimeZoneType) { + if (timestampWithTimeZoneType.isShort()) { + return new ShortTimestampWithTimeZoneBlockBuilderAppender(timestampWithTimeZoneType); + } + return new LongTimestampWithTimeZoneBlockBuilderAppender(timestampWithTimeZoneType); + } + if (type instanceof UuidType) { + return new UuidBlockBuilderAppender(); + } + if (type instanceof VariantType) { + return new VariantBlockBuilderAppender(); + } + if (type instanceof JsonType) { + return new JsonBlockBuilderAppender(); + } + if (type instanceof ArrayType arrayType) { + return new ArrayBlockBuilderAppender(createBlockBuilderAppender(arrayType.getElementType())); + } + if (type instanceof MapType mapType) { + checkArgument( + mapType.getKeyType() instanceof VarcharType, + "Only maps with VARCHAR keys are supported for cast from VARIANT, but got: %s", + mapType); + return new
MapBlockBuilderAppender(createBlockBuilderAppender(mapType.getValueType())); + } + if (type instanceof RowType rowType) { + List<Field> rowFields = rowType.getFields(); + BlockBuilderAppender[] fieldAppenders = new BlockBuilderAppender[rowFields.size()]; + for (int i = 0; i < fieldAppenders.length; i++) { + fieldAppenders[i] = createBlockBuilderAppender(rowFields.get(i).getType()); + } + return new RowBlockBuilderAppender(fieldAppenders, getFieldNameToIndex(rowFields)); + } + + throw new TrinoException(INVALID_FUNCTION_ARGUMENT, format("Unsupported type: %s", type)); + } + } + + private static class BooleanBlockBuilderAppender + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Boolean result = asBoolean(variant); + if (result == null) { + blockBuilder.appendNull(); + } + else { + BOOLEAN.writeBoolean(blockBuilder, result); + } + } + } + + private static class TinyintBlockBuilderAppender + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Long result = asTinyint(variant); + if (result == null) { + blockBuilder.appendNull(); + } + else { + TINYINT.writeLong(blockBuilder, result); + } + } + } + + private static class SmallintBlockBuilderAppender + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Long result = asSmallint(variant); + if (result == null) { + blockBuilder.appendNull(); + } + else { + SMALLINT.writeLong(blockBuilder, result); + } + } + } + + private static class IntegerBlockBuilderAppender + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Long result = asInteger(variant); + if (result == null) { + blockBuilder.appendNull(); + } + else { + INTEGER.writeLong(blockBuilder, result); + } + } + } + + private static class BigintBlockBuilderAppender + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Long result = asBigint(variant); + if (result == null) { + blockBuilder.appendNull(); + } + else { + BIGINT.writeLong(blockBuilder, result); + } + } + } + + private static class RealBlockBuilderAppender + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Long result = asReal(variant); + if (result == null) { + blockBuilder.appendNull(); + } + else { + REAL.writeLong(blockBuilder, result); + } + } + } + + private static class DoubleBlockBuilderAppender + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Double result = asDouble(variant); + if (result == null) { + blockBuilder.appendNull(); + } + else { + DOUBLE.writeDouble(blockBuilder, result); + } + } + } + + private record ShortDecimalBlockBuilderAppender(DecimalType type) + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Long result = asShortDecimal(variant, type.getPrecision(), type.getScale()); + + if (result == null) { + blockBuilder.appendNull(); + } + else { + type.writeLong(blockBuilder, result); + } + } + } + + private record LongDecimalBlockBuilderAppender(DecimalType type) + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Int128 result = asLongDecimal(variant, type.getPrecision(), type.getScale()); + + if (result ==
null) { + blockBuilder.appendNull(); + } + else { + type.writeObject(blockBuilder, result); + } + } + } + + private record VarcharBlockBuilderAppender(Type type) + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Slice result = asVarchar(variant); + if (result == null) { + blockBuilder.appendNull(); + } + else { + type.writeSlice(blockBuilder, result); + } + } + } + + private record VarbinaryBlockBuilderAppender(Type type) + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Slice result = asVarbinary(variant); + if (result == null) { + blockBuilder.appendNull(); + } + else { + VARBINARY.writeSlice(blockBuilder, result); + } + } + } + + private record DateBlockBuilderAppender() + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Long result = asDate(variant); + if (result == null) { + blockBuilder.appendNull(); + } + else { + DATE.writeLong(blockBuilder, result); + } + } + } + + private record TimeBlockBuilderAppender(TimeType type) + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Long result = asTime(variant, type.getPrecision()); + if (result == null) { + blockBuilder.appendNull(); + } + else { + type.writeLong(blockBuilder, result); + } + } + } + + private record ShortTimestampBlockBuilderAppender(TimestampType type) + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Long result = asShortTimestamp(variant, type.getPrecision()); + if (result == null) { + blockBuilder.appendNull(); + } + else { + type.writeLong(blockBuilder, result); + } + } + } + + private record LongTimestampBlockBuilderAppender(TimestampType type) + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + LongTimestamp result = asLongTimestamp(variant, type.getPrecision()); + if (result == null) { + blockBuilder.appendNull(); + } + else { + type.writeObject(blockBuilder, result); + } + } + } + + private record ShortTimestampWithTimeZoneBlockBuilderAppender(TimestampWithTimeZoneType type) + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Long result = asShortTimestampWithTimeZone(variant, type.getPrecision()); + if (result == null) { + blockBuilder.appendNull(); + } + else { + type.writeLong(blockBuilder, result); + } + } + } + + private record LongTimestampWithTimeZoneBlockBuilderAppender(TimestampWithTimeZoneType type) + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + LongTimestampWithTimeZone result = asLongTimestampWithTimeZone(variant, type.getPrecision()); + if (result == null) { + blockBuilder.appendNull(); + } + else { + type.writeObject(blockBuilder, result); + } + } + } + + private record UuidBlockBuilderAppender() + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Slice result = asUuid(variant); + if (result == null) { + blockBuilder.appendNull(); + } + else { + UUID.writeSlice(blockBuilder, result); + } + } + } + + private record VariantBlockBuilderAppender() + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + 
VARIANT.writeObject(blockBuilder, variant); + } + } + + private record JsonBlockBuilderAppender() + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + Slice result = asJson(variant); + if (result == null) { + blockBuilder.appendNull(); + } + else { + JSON.writeSlice(blockBuilder, result); + } + } + } + + private record ArrayBlockBuilderAppender(BlockBuilderAppender elementAppender) + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + if (variant.isNull()) { + blockBuilder.appendNull(); + return; + } + + if (variant.basicType() != Header.BasicType.ARRAY) { + throw new VariantCastException("Expected a variant array, but got " + variant.basicType()); + } + ((ArrayBlockBuilder) blockBuilder).buildEntry(elementBuilder -> + variant.arrayElements().forEach(element -> elementAppender.append(element, elementBuilder))); + } + } + + private record MapBlockBuilderAppender(BlockBuilderAppender valueAppender) + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + if (variant.isNull()) { + blockBuilder.appendNull(); + return; + } + + if (variant.basicType() != Header.BasicType.OBJECT) { + throw new VariantCastException(format("Expected a variant object, but got %s", variant.basicType())); + } + + MapBlockBuilder mapBlockBuilder = (MapBlockBuilder) blockBuilder; + Metadata metadata = variant.metadata(); + mapBlockBuilder.buildEntry((keyBuilder, valueBuilder) -> + variant.objectFields().forEach(fieldIdValue -> { + ((VariableWidthBlockBuilder) keyBuilder).writeEntry(metadata.get(fieldIdValue.fieldId())); + valueAppender.append(fieldIdValue.value(), valueBuilder); + })); + } + } + + private record RowBlockBuilderAppender(BlockBuilderAppender[] fieldAppenders, Optional<Map<String, Integer>> fieldNameToIndex) + implements BlockBuilderAppender + { + @Override + public void append(Variant variant, BlockBuilder blockBuilder) + { + if (variant.isNull()) { + blockBuilder.appendNull(); + return; + } + + if (variant.basicType() != Header.BasicType.OBJECT) { + throw new VariantCastException("Expected an object, but got " + variant.basicType()); + } + + ((RowBlockBuilder) blockBuilder).buildEntry(fieldBuilders -> parseVariantToSingleRowBlock(variant, fieldBuilders, fieldAppenders, fieldNameToIndex)); + } + } + + private static Optional<Map<String, Integer>> getFieldNameToIndex(List<Field> rowFields) + { + if (rowFields.getFirst().getName().isEmpty()) { + return Optional.empty(); + } + + ImmutableMap.Builder<String, Integer> fieldNameToIndex = ImmutableMap.builderWithExpectedSize(rowFields.size()); + for (int i = 0; i < rowFields.size(); i++) { + fieldNameToIndex.put(rowFields.get(i).getName().orElseThrow(), i); + } + return Optional.of(fieldNameToIndex.buildOrThrow()); + }
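When a VARIANT object is cast to a row, the keys from the variant metadata are lowercased and looked up in this name-to-index map; unmatched keys are skipped and unmatched row fields become NULL (see parseVariantToSingleRowBlock below). A minimal sketch of the lookup, with a hypothetical class name and field set of my own:

```java
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

public class FieldMatchSketch
{
    public static void main(String[] args)
    {
        // the row's field names (stored lowercase unless the field was quoted)
        Map<String, Integer> fieldNameToIndex = new HashMap<>();
        fieldNameToIndex.put("id", 0);
        fieldNameToIndex.put("name", 1);

        String variantKey = "Name"; // hypothetical key from the variant metadata
        Integer index = fieldNameToIndex.get(variantKey.toLowerCase(Locale.ENGLISH));
        System.out.println(index); // prints 1; a key like "age" would yield null and be ignored
    }
}
```

+ + private static void parseVariantToSingleRowBlock( + Variant variant, + List<BlockBuilder> fieldBuilders, + BlockBuilderAppender[] fieldAppenders, + Optional<Map<String, Integer>> fieldNameToIndex) + { + if (fieldNameToIndex.isEmpty()) { + throw new VariantCastException("Cannot cast VARIANT object to anonymous row type.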
Row fields must have names."); + } + boolean[] fieldWritten = new boolean[fieldAppenders.length]; + + Metadata metadata = variant.metadata(); + variant.objectFields().forEach(field -> { + String fieldName = metadata.get(field.fieldId()).toStringUtf8().toLowerCase(Locale.ENGLISH); + Integer fieldIndex = fieldNameToIndex.get().get(fieldName); + if (fieldIndex != null) { + if (fieldWritten[fieldIndex]) { + throw new VariantCastException("Duplicate field: " + fieldName); + } + fieldWritten[fieldIndex] = true; + fieldAppenders[fieldIndex].append(field.value(), fieldBuilders.get(fieldIndex)); + } + }); + + for (int i = 0; i < fieldWritten.length; i++) { + if (!fieldWritten[i]) { + fieldBuilders.get(i).appendNull(); + } + } + } + + public static Slice asJson(Variant variant) + { + try { + SliceOutput output = new DynamicSliceOutput(40); + try (JsonGenerator jsonGenerator = createJsonGenerator(JSON_MAPPER, output)) { + jsonGenerator.configure(ESCAPE_NON_ASCII.mappedFeature(), false); + toJsonValue(jsonGenerator, variant); + } + return output.slice(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static void toJsonValue(JsonGenerator jsonGenerator, Variant variant) + throws IOException + { + switch (variant.basicType()) { + case PRIMITIVE -> { + switch (variant.primitiveType()) { + case NULL -> jsonGenerator.writeNull(); + case BINARY -> { + Slice binary = variant.getBinary(); + jsonGenerator.writeBinary(binary.byteArray(), binary.byteArrayOffset(), binary.length()); + } + case STRING -> jsonGenerator.writeString(variant.getString().toStringUtf8()); + case BOOLEAN_TRUE -> jsonGenerator.writeBoolean(true); + case BOOLEAN_FALSE -> jsonGenerator.writeBoolean(false); + case INT8 -> jsonGenerator.writeNumber(variant.getByte()); + case INT16 -> jsonGenerator.writeNumber(variant.getShort()); + case INT32 -> jsonGenerator.writeNumber(variant.getInt()); + case INT64 -> jsonGenerator.writeNumber(variant.getLong()); + case DECIMAL4, DECIMAL8, DECIMAL16 -> jsonGenerator.writeNumber(variant.getDecimal()); + case FLOAT -> jsonGenerator.writeNumber(variant.getFloat()); + case DOUBLE -> jsonGenerator.writeNumber(variant.getDouble()); + case DATE -> jsonGenerator.writeString(DateOperators.castToVarchar(UNBOUNDED_LENGTH, variant.getDate()).toStringUtf8()); + case TIME_NTZ_MICROS -> jsonGenerator.writeString(TimeOperators.castToVarchar(UNBOUNDED_LENGTH, 6, variant.getTimeMicros() * 1_000_000L).toStringUtf8()); + case TIMESTAMP_UTC_MICROS -> { + long micros = variant.getTimestampMicros(); + long epochMillis = Math.floorDiv(micros, 1_000L); + int picosOfMilli = toIntExact(Math.floorMod(micros, 1_000L) * 1_000_000L); + jsonGenerator.writeString(DateTimes.formatTimestampWithTimeZone(6, epochMillis, picosOfMilli, UTC_KEY.getZoneId())); + } + case TIMESTAMP_NTZ_MICROS -> jsonGenerator.writeString(DateTimes.formatTimestamp(6, variant.getTimestampMicros(), 0, UTC)); + case TIMESTAMP_UTC_NANOS -> { + long nanos = variant.getTimestampNanos(); + long epochMillis = Math.floorDiv(nanos, 1_000_000L); + int picosOfMilli = toIntExact(Math.floorMod(nanos, 1_000_000L) * 1_000L); + jsonGenerator.writeString(DateTimes.formatTimestampWithTimeZone(9, epochMillis, picosOfMilli, UTC_KEY.getZoneId())); + } + case TIMESTAMP_NTZ_NANOS -> { + long nanos = variant.getTimestampNanos(); + long epochMicros = Math.floorDiv(nanos, 1_000L); + int picosOfMicros = toIntExact(Math.floorMod(nanos, 1_000L) * 1_000L); + jsonGenerator.writeString(DateTimes.formatTimestamp(9, epochMicros, picosOfMicros, UTC)); + } + case 
UUID -> jsonGenerator.writeString(variant.getUuid().toString()); + } + } + case SHORT_STRING -> jsonGenerator.writeString(variant.getString().toStringUtf8()); + case ARRAY -> { + jsonGenerator.writeStartArray(); + variant.arrayElements().forEach(element -> { + try { + toJsonValue(jsonGenerator, element); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + jsonGenerator.writeEndArray(); + } + case OBJECT -> { + Metadata metadata = variant.metadata(); + jsonGenerator.writeStartObject(); + variant.objectFields().forEach(fieldIdValue -> { + try { + String fieldName = metadata.get(fieldIdValue.fieldId()).toStringUtf8(); + jsonGenerator.writeFieldName(fieldName); + toJsonValue(jsonGenerator, fieldIdValue.value()); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + jsonGenerator.writeEndObject(); + } + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/util/variant/VariantVariantWriter.java b/core/trino-main/src/main/java/io/trino/util/variant/VariantVariantWriter.java new file mode 100644 index 000000000000..a87874114573 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/util/variant/VariantVariantWriter.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.util.variant; + +import io.airlift.slice.Slice; +import io.trino.spi.variant.Metadata; +import io.trino.spi.variant.Variant; +import io.trino.spi.variant.VariantFieldRemapper; + +import java.util.function.IntUnaryOperator; + +import static java.util.Objects.requireNonNull; + +final class VariantVariantWriter + implements VariantWriter +{ + public static final VariantVariantWriter VARIANT_VARIANT_WRITER = new VariantVariantWriter(); + + private VariantVariantWriter() {} + + @Override + public PlannedValue plan(Metadata.Builder metadataBuilder, Object value) + { + if (value == null) { + return NullPlannedValue.NULL_PLANNED_VALUE; + } + return new PlannedVariantValue(VariantFieldRemapper.create((Variant) value, metadataBuilder)); + } + + private record PlannedVariantValue(VariantFieldRemapper variantFieldRemapper) + implements PlannedValue + { + private PlannedVariantValue(VariantFieldRemapper variantFieldRemapper) + { + this.variantFieldRemapper = requireNonNull(variantFieldRemapper, "variantFieldRemapper is null"); + } + + @Override + public void finalize(IntUnaryOperator sortedFieldIdMapping) + { + variantFieldRemapper.finalize(sortedFieldIdMapping); + } + + @Override + public int size() + { + return variantFieldRemapper.size(); + } + + @Override + public int write(Slice out, int offset) + { + variantFieldRemapper.write(out, offset); + return size(); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/util/variant/VariantWriter.java b/core/trino-main/src/main/java/io/trino/util/variant/VariantWriter.java new file mode 100644 index 000000000000..4f6f6fed94c8 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/util/variant/VariantWriter.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 
2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.util.variant; + +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; +import io.trino.spi.type.VarcharType; +import io.trino.spi.type.VariantType; +import io.trino.spi.variant.Metadata; +import io.trino.spi.variant.Variant; +import io.trino.type.JsonType; + +import java.util.Optional; +import java.util.function.IntUnaryOperator; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.util.variant.JsonVariantWriter.JSON_VARIANT_WRITER; +import static io.trino.util.variant.VariantVariantWriter.VARIANT_VARIANT_WRITER; + +public interface VariantWriter +{ + static VariantWriter create(Type type) + { + return switch (type) { + case ArrayType arrayType -> { + Type elementType = arrayType.getElementType(); + Optional<PrimitiveVariantEncoder> primitiveElementEncoder = PrimitiveVariantEncoder.create(elementType); + if (primitiveElementEncoder.isPresent()) { + yield new PrimitiveArrayVariantWriter(arrayType, primitiveElementEncoder.get()); + } + yield new ArrayVariantWriter(arrayType, create(elementType)); + } + case MapType mapType -> { + checkArgument(mapType.getKeyType() instanceof VarcharType, "Map key type must be VARCHAR: %s", mapType.getKeyType()); + Type valueType = mapType.getValueType(); + Optional<PrimitiveVariantEncoder> primitiveValueEncoder = PrimitiveVariantEncoder.create(valueType); + if (primitiveValueEncoder.isPresent()) { + yield new PrimitiveMapVariantWriter(mapType, primitiveValueEncoder.get()); + } + yield new MapVariantWriter(mapType, create(valueType)); + } + case RowType rowType -> new RowVariantWriter(rowType); + case VariantType _ -> VARIANT_VARIANT_WRITER; + case JsonType _ -> JSON_VARIANT_WRITER; + // VariantWriter cannot be used for primitive types. Instead, use VariantEncoder directly, which is significantly more efficient. + default -> throw new IllegalArgumentException("Unsupported type for VariantWriter: " + type); + }; + } + + default Variant write(Object value) + { + Metadata.Builder metadataBuilder = Metadata.builder(); + PlannedValue plannedValue = plan(metadataBuilder, value); + Metadata.Builder.SortedMetadata build = metadataBuilder.buildSorted(); + IntUnaryOperator sortedFieldIdMapping = build.sortedFieldIdMapping(); + plannedValue.finalize(sortedFieldIdMapping); + Slice out = Slices.allocate(plannedValue.size()); + plannedValue.write(out, 0); + return Variant.from(build.metadata(), out); + } + + /// Plans the writing of the given value. + /// Fields required for writing the value will be registered in the provided Builder. + /// The returned PlannedValue can then be used to write the value after finalizing with remapped + /// field IDs from the `Metadata.Builder`.
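+ ///
+ /// Illustrative sketch of the full sequence (it mirrors the default `write(Object)` implementation above; `writer` stands for any `VariantWriter`):
+ /// ```
+ /// Metadata.Builder metadataBuilder = Metadata.builder();
+ /// PlannedValue planned = writer.plan(metadataBuilder, value);
+ /// Metadata.Builder.SortedMetadata sorted = metadataBuilder.buildSorted();
+ /// planned.finalize(sorted.sortedFieldIdMapping());
+ /// Slice out = Slices.allocate(planned.size());
+ /// planned.write(out, 0);
+ /// Variant variant = Variant.from(sorted.metadata(), out);
+ /// ```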
+ /// @param metadataBuilder the metadata builder to register required fields + /// @param value the stack value to plan writing for + PlannedValue plan(Metadata.Builder metadataBuilder, Object value); + + /// A planned value that can be finalized and written. + interface PlannedValue + { + /// Finalizes the planned value by remapping provisional field IDs to final sorted field IDs. + /// The system creates a globally sorted metadata dictionary after all values have been planned, + /// which may change the field IDs assigned during planning. This method can rely on the field + /// IDs being assigned in ascending order for determining the write order of object fields. + void finalize(IntUnaryOperator sortedFieldIdMapping); + + /// Returns the size in bytes required to write the value. + /// This must be called after finalize(). + int size(); + + /// Writes the value to the given output slice at the specified offset. + /// This must be called after finalize(). + /// This method can be called multiple times to write the same value to different output slices. + /// @return the number of bytes written + int write(Slice out, int offset); + } +} diff --git a/core/trino-main/src/test/java/io/trino/block/TestVariantBlock.java b/core/trino-main/src/test/java/io/trino/block/TestVariantBlock.java new file mode 100644 index 000000000000..9b7c68425aec --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/block/TestVariantBlock.java @@ -0,0 +1,105 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.block; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariantBlock; +import io.trino.spi.block.VariantBlockBuilder; +import io.trino.spi.variant.Variant; +import org.junit.jupiter.api.Test; + +import java.util.Random; + +import static java.util.Arrays.copyOfRange; +import static org.assertj.core.api.Assertions.assertThat; + +class TestVariantBlock + extends AbstractTestBlock +{ + @Test + void test() + { + Variant[] expectedValues = createTestValue(17); + assertFixedWithValues(expectedValues); + assertFixedWithValues(alternatingNullValues(expectedValues)); + } + + @Test + void testCopyRegion() + { + Variant[] expectedValues = createTestValue(100); + Block block = createBlockBuilderWithValues(expectedValues).build(); + Block actual = block.copyRegion(10, 10); + Block expected = createBlockBuilderWithValues(copyOfRange(expectedValues, 10, 20)).build(); + assertThat(actual.getPositionCount()).isEqualTo(expected.getPositionCount()); + assertThat(actual.getSizeInBytes()).isEqualTo(expected.getSizeInBytes()); + } + + @Test + void testCopyPositions() + { + Variant[] expectedValues = alternatingNullValues(createTestValue(17)); + BlockBuilder blockBuilder = createBlockBuilderWithValues(expectedValues); + assertBlockFilteredPositions(expectedValues, blockBuilder.build(), 0, 2, 4, 6, 7, 9, 10, 16); + } + + private void assertFixedWithValues(Variant[] expectedValues) + { + Block block = createBlockBuilderWithValues(expectedValues).build(); + assertBlock(block, expectedValues); + } + + private static BlockBuilder createBlockBuilderWithValues(Variant[] expectedValues) + { + VariantBlockBuilder blockBuilder = new VariantBlockBuilder(null, expectedValues.length); + writeValues(expectedValues, blockBuilder); + return blockBuilder; + } + + private static void writeValues(Variant[] expectedValues, VariantBlockBuilder blockBuilder) + { + for (Variant expectedValue : expectedValues) { + if (expectedValue == null) { + blockBuilder.appendNull(); + } + else { + blockBuilder.writeEntry(expectedValue); + } + } + } + + private static Variant[] createTestValue(int positionCount) + { + Variant[] expectedValues = new Variant[positionCount]; + Random random = new Random(0); + for (int position = 0; position < positionCount; position++) { + expectedValues[position] = Variant.ofInt(random.nextInt()); + } + return expectedValues; + } + + @Override + protected <T> void assertPositionValue(Block block, int position, T expectedValue) + { + if (expectedValue == null) { + assertThat(block.isNull(position)).isTrue(); + return; + } + + VariantBlock variantBlock = (VariantBlock) block; + Variant variant = variantBlock.getVariant(position); + assertThat(variant).isEqualTo(expectedValue); + } +} diff --git a/core/trino-main/src/test/java/io/trino/metadata/TestSignatureBinder.java b/core/trino-main/src/test/java/io/trino/metadata/TestSignatureBinder.java index 5365154d6495..8d0ae43a0c30 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/TestSignatureBinder.java +++ b/core/trino-main/src/test/java/io/trino/metadata/TestSignatureBinder.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Optional; +import java.util.Set; import static io.trino.metadata.SignatureBinder.applyBoundVariables; import static io.trino.spi.type.BigintType.BIGINT; @@ -513,6 +514,24 @@ public void testBindToUnparametrizedVarcharIsImpossible() .succeeds(); } + @Test + public void testUnknownToVariantIsCastableToWithoutRecursion() + { + // This forces SignatureBinder to
evaluate EXPLICIT_COERCION_TO via canCast(actualType, variant). + Signature function = functionSignature() + .returnType(BOOLEAN) + .argumentType(new TypeSignature("T")) + .typeVariableConstraint(TypeVariableConstraint.builder("T") + .castableTo(new TypeSignature("variant")) + .build()) + .build(); + + assertThat(function) + .boundTo(UNKNOWN) + .withCoercion() + .succeeds(); + } + @Test public void testBasic() { @@ -1139,6 +1158,60 @@ public void testCanCoerceFrom() .fails(); } + @Test + public void testRowIsCastableToVariantWhenFieldsAreCastable() + { + Signature function = functionSignature() + .returnType(BOOLEAN) + .argumentType(new TypeSignature("T")) + .typeVariableConstraint(TypeVariableConstraint.builder("T") + .rowType() + .castableTo(parseTypeSignature("variant", Set.of())) + .build()) + .build(); + + assertThat(function) + .boundTo(RowType.anonymous(ImmutableList.of(BIGINT, DOUBLE))) + .withCoercion() + .succeeds(); + } + + @Test + public void testRowIsNotCastableToArbitraryTypeWithoutRecursiveOperator() + { + Signature function = functionSignature() + .returnType(BOOLEAN) + .argumentType(new TypeSignature("T")) + .typeVariableConstraint(TypeVariableConstraint.builder("T") + .rowType() + .castableTo(TIMESTAMP_MILLIS.getTypeSignature()) + .build()) + .build(); + + assertThat(function) + .boundTo(RowType.anonymous(ImmutableList.of(BIGINT, DOUBLE))) + .withCoercion() + .fails(); + } + + @Test + public void testVariantIsCastableToRowWhenVariantIsCastableToEachField() + { + Signature function = functionSignature() + .returnType(BOOLEAN) + .argumentType(new TypeSignature("T")) + .typeVariableConstraint(TypeVariableConstraint.builder("T") + .rowType() + .castableFrom(parseTypeSignature("variant", Set.of())) + .build()) + .build(); + + assertThat(function) + .boundTo(RowType.anonymous(ImmutableList.of(BIGINT, DOUBLE))) + .withCoercion() + .succeeds(); + } + @Test public void testBindParameters() { diff --git a/core/trino-main/src/test/java/io/trino/server/protocol/TestJsonEncodingUtils.java b/core/trino-main/src/test/java/io/trino/server/protocol/TestJsonEncodingUtils.java index c411d54e3366..e39d3b1548fb 100644 --- a/core/trino-main/src/test/java/io/trino/server/protocol/TestJsonEncodingUtils.java +++ b/core/trino-main/src/test/java/io/trino/server/protocol/TestJsonEncodingUtils.java @@ -33,6 +33,7 @@ import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; +import io.trino.spi.variant.Variant; import org.junit.jupiter.api.Test; import java.io.ByteArrayInputStream; @@ -67,6 +68,7 @@ import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.spi.type.VariantType.VARIANT; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; @@ -406,6 +408,24 @@ public void testBooleanArraySerialization() .containsExactly(array(true, false, true, false, true, null)); } + @Test + public void testVariantSerialization() + throws IOException + { + List columns = ImmutableList.of(typed("col0", VARIANT)); + var blockBuilder = VARIANT.createBlockBuilder(null, 3); + blockBuilder.appendNull(); + VARIANT.writeObject(blockBuilder, Variant.ofObject(Map.of( + utf8Slice("a"), Variant.ofInt(1), + utf8Slice("b"), Variant.ofArray(List.of(Variant.ofBoolean(true), Variant.NULL_VALUE))))); + VARIANT.writeObject(blockBuilder, Variant.NULL_VALUE); + Block
block = blockBuilder.build(); + + Page page = page(block); + assertThat(roundTrip(columns, page, "[[null],[{\"a\":1,\"b\":[true,null]}],[null]]")) + .isEqualTo(column(null, "{\"a\":1,\"b\":[true,null]}", null)); + } + @Test public void testMapSerialization() throws IOException diff --git a/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java b/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java index 384ae9d12ad1..fe739a406fdc 100644 --- a/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java +++ b/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java @@ -29,6 +29,7 @@ import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; +import io.trino.spi.variant.Variant; import io.trino.type.BlockTypeOperators.BlockPositionEqual; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; import io.trino.type.BlockTypeOperators.BlockPositionIsIdentical; @@ -645,6 +646,9 @@ private static Object getNonNullValueForType(Type type) if (type.getJavaType() == LongTimestampWithTimeZone.class) { return LongTimestampWithTimeZone.fromEpochSecondsAndFraction(1, 0, UTC_KEY); } + if (type.getJavaType() == Variant.class) { + return Variant.ofString(Slices.utf8Slice("_")); + } switch (type) { case ArrayType arrayType -> { Type elementType = arrayType.getElementType(); diff --git a/core/trino-main/src/test/java/io/trino/type/TestRowOperators.java b/core/trino-main/src/test/java/io/trino/type/TestRowOperators.java index 53b6bc2d7dd5..31a267a30163 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestRowOperators.java +++ b/core/trino-main/src/test/java/io/trino/type/TestRowOperators.java @@ -553,6 +553,10 @@ public void testFieldAccessor() @Test public void testRowCast() { + assertThat(assertions.expression("cast(a AS row(aa bigint, bb double))[2]") + .binding("a", "row(2, CAST(null as double))")) + .isNull(DOUBLE); + assertThat(assertions.expression("cast(a AS row(aa bigint, bb bigint))[1]") .binding("a", "row(2, 3)")) .isEqualTo(2L); @@ -565,10 +569,6 @@ public void testRowCast() .binding("a", "row(2, 3)")) .isEqualTo(true); - assertThat(assertions.expression("cast(a AS row(aa bigint, bb double))[2]") - .binding("a", "row(2, CAST(null as double))")) - .isNull(DOUBLE); - assertThat(assertions.expression("cast(a AS row(aa bigint, bb varchar))[2]") .binding("a", "row(2, 'test_str')")) .hasType(VARCHAR) diff --git a/core/trino-main/src/test/java/io/trino/type/TestVariantFunctions.java b/core/trino-main/src/test/java/io/trino/type/TestVariantFunctions.java new file mode 100644 index 000000000000..0562b9d5d19b --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/type/TestVariantFunctions.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.type; + +import io.trino.operator.scalar.VarbinaryFunctions; +import io.trino.spi.variant.Variant; +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; + +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) +class TestVariantFunctions +{ + private QueryAssertions assertions; + + @BeforeAll + void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + void teardown() + { + assertions.close(); + assertions = null; + } + + @Test + void testVariantIsNull() + { + assertThat(assertions.expression("variant_is_null(a)") + .binding("a", toVariantLiteral(Variant.NULL_VALUE))) + .isEqualTo(true); + assertThat(assertions.expression("variant_is_null(a)") + .binding("a", toVariantLiteral(Variant.ofLong(42)))) + .isEqualTo(false); + } + + private static String toVariantLiteral(Variant variant) + { + String hexMetadata = VarbinaryFunctions.toHex(variant.metadata().toSlice()).toStringUtf8(); + String hexValue = VarbinaryFunctions.toHex(variant.data()).toStringUtf8(); + return String.format("decode_variant(X'%s', X'%s')", hexMetadata, hexValue); + } +} diff --git a/core/trino-main/src/test/java/io/trino/type/TestVariantOperators.java b/core/trino-main/src/test/java/io/trino/type/TestVariantOperators.java new file mode 100644 index 000000000000..3fa1bdb600b9 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/type/TestVariantOperators.java @@ -0,0 +1,1529 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.type; + +import com.google.common.collect.ImmutableSortedMap; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.operator.scalar.VarbinaryFunctions; +import io.trino.spi.type.SqlDate; +import io.trino.spi.type.SqlDecimal; +import io.trino.spi.type.SqlTime; +import io.trino.spi.type.SqlTimestamp; +import io.trino.spi.type.SqlTimestampWithTimeZone; +import io.trino.spi.type.SqlVarbinary; +import io.trino.spi.variant.Metadata; +import io.trino.spi.variant.ObjectFieldIdValue; +import io.trino.spi.variant.Variant; +import io.trino.sql.query.QueryAssertions; +import io.trino.testing.assertions.TrinoExceptionAssert; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static com.google.common.base.Verify.verify; +import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; +import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.trino.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; +import static io.trino.spi.variant.VariantEncoder.encodeObject; +import static io.trino.spi.variant.VariantEncoder.encodedObjectSize; +import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy; +import static java.math.BigDecimal.ONE; +import static java.math.RoundingMode.HALF_UP; +import static java.time.ZoneOffset.UTC; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.entry; +import static org.assertj.core.api.InstanceOfAssertFactories.list; +import static org.assertj.core.api.InstanceOfAssertFactories.map; +import static org.assertj.core.api.InstanceOfAssertFactories.type; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; + +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) +class TestVariantOperators +{ + private QueryAssertions assertions; + + @BeforeAll + void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + void teardown() + { + assertions.close(); + assertions = null; + } + + @Test + void testCastWithBoolean() + { + for (boolean value : new boolean[] {true, false}) { + assertCastToVariant("BOOLEAN '%s'".formatted(value), value); + assertCastFromVariant(Variant.ofBoolean(value), "BOOLEAN", value); + assertCastFromVariant(Variant.ofString(value ? 
"true" : "false"), "BOOLEAN", value); + } + + assertCastFromVariant(Variant.ofByte((byte) 0), "BOOLEAN", false); + assertCastFromVariant(Variant.ofByte((byte) 7), "BOOLEAN", true); + + assertCastFromVariant(Variant.ofShort((short) 0), "BOOLEAN", false); + assertCastFromVariant(Variant.ofShort((short) -3), "BOOLEAN", true); + + assertCastFromVariant(Variant.ofInt(0), "BOOLEAN", false); + assertCastFromVariant(Variant.ofInt(42), "BOOLEAN", true); + + assertCastFromVariant(Variant.ofLong(0L), "BOOLEAN", false); + assertCastFromVariant(Variant.ofLong(Long.MIN_VALUE), "BOOLEAN", true); + + assertCastFromVariant(Variant.ofFloat(0.0f), "BOOLEAN", false); + assertCastFromVariant(Variant.ofFloat(-0.1f), "BOOLEAN", true); + + assertCastFromVariant(Variant.ofDouble(0.0), "BOOLEAN", false); + assertCastFromVariant(Variant.ofDouble(123.456), "BOOLEAN", true); + + assertCastFromVariant(Variant.ofDecimal(new BigDecimal("0.0000")), "BOOLEAN", false); + assertCastFromVariant(Variant.ofDecimal(new BigDecimal("0.0001")), "BOOLEAN", true); + + assertCastFromVariant(Variant.NULL_VALUE, "BOOLEAN", null); + } + + @Test + void testCastWithTinyint() + { + for (byte value : new byte[] {0, 1, -1, 127, -128, 42, -42}) { + assertCastToVariant("TINYINT '%d'".formatted(value), value); + assertCastFromVariant(Variant.ofByte(value), "TINYINT", value); + assertCastFromVariant(Variant.ofString(String.valueOf(value)), "TINYINT", value); + } + + assertCastFromVariant(Variant.ofBoolean(true), "TINYINT", (byte) 1); + assertCastFromVariant(Variant.ofBoolean(false), "TINYINT", (byte) 0); + + assertCastFromVariant(Variant.ofShort((short) 42), "TINYINT", (byte) 42); + assertCastFromVariantThrows(Variant.ofShort((short) (Byte.MAX_VALUE + 1)), "TINYINT") + .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE) + .hasMessageContaining("Out of range for tinyint: " + (Byte.MAX_VALUE + 1)); + + assertCastFromVariant(Variant.ofInt(42), "TINYINT", (byte) 42); + assertCastFromVariantThrows(Variant.ofInt(Byte.MAX_VALUE + 1), "TINYINT") + .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE) + .hasMessageContaining("Out of range for tinyint: " + (Byte.MAX_VALUE + 1)); + + assertCastFromVariant(Variant.ofLong(42L), "TINYINT", (byte) 42); + assertCastFromVariantThrows(Variant.ofLong(Byte.MAX_VALUE + 1L), "TINYINT") + .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE) + .hasMessageContaining("Out of range for tinyint: " + (Byte.MAX_VALUE + 1L)); + + assertCastFromVariant(Variant.ofDecimal(new BigDecimal("42")), "TINYINT", (byte) 42); + assertCastFromVariantThrows(Variant.ofDecimal(new BigDecimal("42.5")), "TINYINT") + .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE) + .hasMessageContaining("Out of range for tinyint: 42.5"); + assertCastFromVariantThrows(Variant.ofDecimal(new BigDecimal(Byte.MAX_VALUE + 1L)), "TINYINT") + .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE) + .hasMessageContaining("Out of range for tinyint: " + (Byte.MAX_VALUE + 1L)); + + assertCastFromVariant(Variant.ofFloat(42.0f), "TINYINT", (byte) 42); + assertCastFromVariant(Variant.ofDouble(42.0d), "TINYINT", (byte) 42); + + assertCastFromVariant(Variant.NULL_VALUE, "TINYINT", null); + } + + @Test + void testCastWithSmallint() + { + for (short value : new short[] {0, 1, -1, Short.MAX_VALUE, Short.MIN_VALUE, 42, -42}) { + assertCastToVariant("SMALLINT '%d'".formatted(value), value); + assertCastFromVariant(Variant.ofShort(value), "SMALLINT", value); + assertCastFromVariant(Variant.ofString(String.valueOf(value)), "SMALLINT", value); + } + + assertCastFromVariant(Variant.ofBoolean(true), "SMALLINT", (short) 1); + 
assertCastFromVariant(Variant.ofBoolean(false), "SMALLINT", (short) 0); + + assertCastFromVariant(Variant.ofByte((byte) 123), "SMALLINT", (short) 123); + + assertCastFromVariant(Variant.ofInt(12345), "SMALLINT", (short) 12345); + assertCastFromVariantThrows(Variant.ofInt(Short.MAX_VALUE + 1), "SMALLINT") + .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE) + .hasMessageContaining("Out of range for smallint: " + (Short.MAX_VALUE + 1)); + + assertCastFromVariant(Variant.ofLong(12345L), "SMALLINT", (short) 12345); + assertCastFromVariantThrows(Variant.ofLong(Short.MAX_VALUE + 1), "SMALLINT") + .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE) + .hasMessageContaining("Out of range for smallint: " + (Short.MAX_VALUE + 1)); + + assertCastFromVariant(Variant.ofDecimal(new BigDecimal("123")), "SMALLINT", (short) 123); + assertCastFromVariantThrows(Variant.ofDecimal(new BigDecimal("123.5")), "SMALLINT") + .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE) + .hasMessageContaining("Out of range for smallint: 123.5"); + assertCastFromVariantThrows(Variant.ofDecimal(new BigDecimal(Short.MAX_VALUE + 1L)), "SMALLINT") + .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE) + .hasMessageContaining("Out of range for smallint: " + (Short.MAX_VALUE + 1L)); + + assertCastFromVariant(Variant.ofFloat(123.9f), "SMALLINT", (short) 124); + assertCastFromVariant(Variant.ofDouble(123.9d), "SMALLINT", (short) 124); + + assertCastFromVariant(Variant.NULL_VALUE, "SMALLINT", null); + } + + @Test + void testCastWithInteger() + { + for (int value : new int[] {0, 1, -1, Integer.MAX_VALUE, Integer.MIN_VALUE, 42, -42}) { + assertCastToVariant("INTEGER '%d'".formatted(value), value); + assertCastFromVariant(Variant.ofInt(value), "INTEGER", value); + assertCastFromVariant(Variant.ofString(String.valueOf(value)), "INTEGER", value); + } + + assertCastFromVariant(Variant.ofBoolean(true), "INTEGER", 1); + assertCastFromVariant(Variant.ofBoolean(false), "INTEGER", 0); + + assertCastFromVariant(Variant.ofByte((byte) 123), "INTEGER", 123); + assertCastFromVariant(Variant.ofShort((short) 12345), "INTEGER", 12345); + + assertCastFromVariant(Variant.ofLong(123456L), "INTEGER", 123456); + assertCastFromVariantThrows(Variant.ofLong(Integer.MAX_VALUE + 1L), "INTEGER") + .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE) + .hasMessageContaining("Out of range for integer: " + (Integer.MAX_VALUE + 1L)); + + assertCastFromVariant(Variant.ofDecimal(new BigDecimal("1234")), "INTEGER", 1234); + assertCastFromVariantThrows(Variant.ofDecimal(new BigDecimal("1234.5")), "INTEGER") + .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE) + .hasMessageContaining("Out of range for integer: 1234.5"); + assertCastFromVariantThrows(Variant.ofDecimal(new BigDecimal(Integer.MAX_VALUE + 1L)), "INTEGER") + .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE) + .hasMessageContaining("Out of range for integer: " + (Integer.MAX_VALUE + 1L)); + + assertCastFromVariant(Variant.ofFloat(1234.6f), "INTEGER", 1235); + assertCastFromVariant(Variant.ofDouble(1234.6d), "INTEGER", 1235); + + assertCastFromVariant(Variant.NULL_VALUE, "INTEGER", null); + } + + @Test + void testCastWithBigint() + { + for (long value : new long[] {0L, 1L, -1L, Long.MAX_VALUE, Long.MIN_VALUE, 42L, -42L}) { + assertCastToVariant("BIGINT '%d'".formatted(value), value); + assertCastFromVariant(Variant.ofLong(value), "BIGINT", value); + assertCastFromVariant(Variant.ofString(String.valueOf(value)), "BIGINT", value); + } + + assertCastFromVariant(Variant.ofBoolean(true), "BIGINT", 1L); + assertCastFromVariant(Variant.ofBoolean(false), "BIGINT", 0L); + + 
assertCastFromVariant(Variant.ofByte((byte) 0x12), "BIGINT", 0x12L); + assertCastFromVariant(Variant.ofShort((short) 0x1234), "BIGINT", 0x1234L); + assertCastFromVariant(Variant.ofInt(0x1234_5678), "BIGINT", 0x1234_5678L); + + assertCastFromVariant(Variant.ofFloat(1234.5678f), "BIGINT", 1235L); + assertCastFromVariant(Variant.ofDouble(1234.5678d), "BIGINT", 1235L); + assertCastFromVariant(Variant.ofDecimal(new BigDecimal("1234")), "BIGINT", 1234L); + assertCastFromVariantThrows(Variant.ofDecimal(new BigDecimal("1234.5")), "BIGINT") + .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE) + .hasMessageContaining("Out of range for bigint: 1234.5"); + assertCastFromVariantThrows(Variant.ofDecimal(new BigDecimal(Long.MAX_VALUE).add(ONE)), "BIGINT") + .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE) + .hasMessageContaining("Out of range for bigint: " + new BigDecimal(Long.MAX_VALUE).add(ONE)); + + assertCastFromVariant(Variant.NULL_VALUE, "BIGINT", null); + } + + @Test + void testCastWithReal() + { + for (float value : new float[] {0.0f, 1.0f, -1.0f, 1234.5f, -1234.5f}) { + assertCastToVariant("REAL '%s'".formatted(value), value); + assertCastFromVariant(Variant.ofFloat(value), "REAL", value); + } + + assertCastFromVariant(Variant.ofBoolean(true), "REAL", 1.0f); + assertCastFromVariant(Variant.ofBoolean(false), "REAL", 0.0f); + + assertCastFromVariant(Variant.ofByte((byte) 5), "REAL", 5.0f); + assertCastFromVariant(Variant.ofShort((short) -7), "REAL", -7.0f); + assertCastFromVariant(Variant.ofInt(123_456), "REAL", 123_456.0f); + assertCastFromVariant(Variant.ofLong(1234L), "REAL", 1234.0f); + + assertCastFromVariant(Variant.ofDecimal(new BigDecimal("1234.5")), "REAL", 1234.5f); + assertCastFromVariant(Variant.ofDouble(1.5d), "REAL", 1.5f); + + assertCastFromVariant(Variant.ofString("1.5"), "REAL", 1.5f); + + assertCastFromVariant(Variant.NULL_VALUE, "REAL", null); + } + + @Test + void testCastWithDouble() + { + for (double value : new double[] {0.0, 1.0, -1.0, 1234.5, -1234.5}) { + assertCastToVariant("DOUBLE '%s'".formatted(value), value); + assertCastFromVariant(Variant.ofDouble(value), "DOUBLE", value); + assertCastFromVariant(Variant.ofString(String.valueOf(value)), "DOUBLE", value); + } + + assertCastFromVariant(Variant.ofBoolean(true), "DOUBLE", 1.0); + assertCastFromVariant(Variant.ofBoolean(false), "DOUBLE", 0.0); + + assertCastFromVariant(Variant.ofByte((byte) 5), "DOUBLE", 5.0); + assertCastFromVariant(Variant.ofShort((short) -7), "DOUBLE", -7.0); + assertCastFromVariant(Variant.ofInt(123_456), "DOUBLE", 123_456.0); + assertCastFromVariant(Variant.ofLong(1234L), "DOUBLE", 1234.0); + + assertCastFromVariant(Variant.ofDecimal(new BigDecimal("1234.5")), "DOUBLE", 1234.5); + assertCastFromVariant(Variant.ofFloat(1.5f), "DOUBLE", 1.5); + + assertCastFromVariant(Variant.NULL_VALUE, "DOUBLE", null); + } + + @Test + void testCastWithShortDecimal() + { + BigDecimal[] values = { + new BigDecimal("0"), + new BigDecimal("1.23"), + new BigDecimal("-1.23"), + new BigDecimal("12345678.90"), + new BigDecimal("12345678.912345"), + }; + + for (BigDecimal value : values) { + assertCastToVariant("DECIMAL '%s'".formatted(value), value); + + BigDecimal scaled = value.setScale(2, HALF_UP); + assertCastFromVariant(Variant.ofDecimal(value), "DECIMAL(10,2)", new SqlDecimal(scaled.unscaledValue(), 10, 2)); + } + + assertCastFromVariant(Variant.ofBoolean(true), "DECIMAL(10,2)", new SqlDecimal(BigInteger.valueOf(100), 10, 2)); + assertCastFromVariant(Variant.ofBoolean(false), "DECIMAL(10,2)", new SqlDecimal(BigInteger.ZERO, 10, 2)); 
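+ // Expected SqlDecimal values below are built from the unscaled value: 1.00 at DECIMAL(10,2) is BigInteger 100, 5.00 is 500, and so on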
+ + assertCastFromVariant(Variant.ofByte((byte) 5), "DECIMAL(10,2)", new SqlDecimal(BigInteger.valueOf(500), 10, 2)); + assertCastFromVariant(Variant.ofShort((short) -7), "DECIMAL(10,2)", new SqlDecimal(BigInteger.valueOf(-700), 10, 2)); + assertCastFromVariant(Variant.ofInt(123_456), "DECIMAL(10,2)", new SqlDecimal(BigInteger.valueOf(12_345_600), 10, 2)); + assertCastFromVariant(Variant.ofLong(12_345_678L), "DECIMAL(10,2)", new SqlDecimal(BigInteger.valueOf(1_234_567_800L), 10, 2)); + + assertCastFromVariant(Variant.ofFloat(1.5f), "DECIMAL(10,2)", new SqlDecimal(BigInteger.valueOf(150), 10, 2)); + assertCastFromVariant(Variant.ofDouble(1.5d), "DECIMAL(10,2)", new SqlDecimal(BigInteger.valueOf(150), 10, 2)); + + assertCastFromVariant(Variant.ofString("1234.56"), "DECIMAL(10,2)", new SqlDecimal(BigInteger.valueOf(123_456), 10, 2)); + assertCastFromVariant(Variant.ofString("1234.5678"), "DECIMAL(10,2)", new SqlDecimal(BigInteger.valueOf(123_457), 10, 2)); + assertCastFromVariantThrows(Variant.ofString("hello"), "DECIMAL(10,2)") + .hasErrorCode(INVALID_CAST_ARGUMENT) + .hasMessage("Cannot cast input variant to DECIMAL(10,2)"); + + assertCastFromVariant(Variant.NULL_VALUE, "DECIMAL(10,2)", null); + } + + @Test + void testCastWithLongDecimal() + { + int precision = 20; + int scale = 4; + + BigDecimal[] values = { + new BigDecimal("0"), + new BigDecimal("1.2345"), + new BigDecimal("-1.2345"), + new BigDecimal("1234567890123.4567"), + new BigDecimal("123456789012345.6789"), + }; + + for (BigDecimal value : values) { + assertCastToVariant("DECIMAL '%s'".formatted(value), value); + + BigDecimal scaled = value.setScale(scale, HALF_UP); + assertCastFromVariant(Variant.ofDecimal(value), "DECIMAL(20,4)", new SqlDecimal(scaled.unscaledValue(), precision, scale)); + } + + assertCastFromVariant(Variant.ofBoolean(true), "DECIMAL(20,4)", new SqlDecimal(BigInteger.valueOf(10_000), precision, scale)); + assertCastFromVariant(Variant.ofBoolean(false), "DECIMAL(20,4)", new SqlDecimal(BigInteger.ZERO, precision, scale)); + + assertCastFromVariant(Variant.ofLong(1234L), "DECIMAL(20,4)", new SqlDecimal(BigInteger.valueOf(12_340_000L), precision, scale)); + + assertCastFromVariant(Variant.ofString("1234.5678"), "DECIMAL(20,4)", new SqlDecimal(BigInteger.valueOf(12_345_678L), precision, scale)); + assertCastFromVariant(Variant.ofString("1234.56789"), "DECIMAL(20,4)", new SqlDecimal(BigInteger.valueOf(12_345_679L), precision, scale)); + assertCastFromVariantThrows(Variant.ofString("this is not a decimal value"), "DECIMAL(20,4)") + .hasErrorCode(INVALID_CAST_ARGUMENT) + .hasMessage("Cannot cast input variant to DECIMAL(20,4)"); + + assertCastFromVariant(Variant.ofDouble(1.5d), "DECIMAL(20,4)", new SqlDecimal(BigInteger.valueOf(15_000), precision, scale)); + assertCastFromVariant(Variant.ofFloat(1.5f), "DECIMAL(20,4)", new SqlDecimal(BigInteger.valueOf(15_000), precision, scale)); + + assertCastFromVariant(Variant.NULL_VALUE, "DECIMAL(20,4)", null); + } + + @Test + void testCastWithVarchar() + { + for (String value : new String[] {"", "hello", "a somewhat longer string 123", "特殊字符", "emoji 😊", "x".repeat(1000)}) { + assertCastToVariant("VARCHAR '%s'".formatted(value), value); + assertCastFromVariant(Variant.ofString(value), "VARCHAR", value); + } + + assertCastFromVariant(Variant.ofBoolean(true), "VARCHAR", "true"); + assertCastFromVariant(Variant.ofBoolean(false), "VARCHAR", "false"); + + assertCastFromVariant(Variant.ofByte((byte) 5), "VARCHAR", "5"); + assertCastFromVariant(Variant.ofShort((short) -7), "VARCHAR", 
"-7"); + assertCastFromVariant(Variant.ofInt(123_456), "VARCHAR", "123456"); + assertCastFromVariant(Variant.ofLong(1234L), "VARCHAR", "1234"); + + BigDecimal decimal = new BigDecimal("1234.50"); + assertCastFromVariant(Variant.ofDecimal(decimal), "VARCHAR", "1234.50"); + + assertCastFromVariant(Variant.ofFloat(1.5f), "VARCHAR", "1.5E0"); + assertCastFromVariant(Variant.ofDouble(1.5d), "VARCHAR", "1.5E0"); + + LocalDate date = LocalDate.of(2024, 10, 24); + assertCastFromVariant(Variant.ofDate(date), "VARCHAR", "2024-10-24"); + + assertCastFromVariant(Variant.ofTimestampMicrosNtz(LocalDateTime.parse("2024-10-24T12:34:56.123456")), "VARCHAR", "2024-10-24 12:34:56.123456"); + + assertCastFromVariant(Variant.ofTimestampMicrosUtc(Instant.parse("2024-10-24T12:34:56.123456Z")), "VARCHAR", "2024-10-24 12:34:56.123456 UTC"); + + assertCastFromVariant(Variant.ofTimestampNanosNtz(LocalDateTime.parse("2024-10-24T12:34:56.123456789")), "VARCHAR", "2024-10-24 12:34:56.123456789"); + + assertCastFromVariant(Variant.ofTimestampNanosUtc(Instant.parse("2024-10-24T12:34:56.123456789Z")), "VARCHAR", "2024-10-24 12:34:56.123456789 UTC"); + + UUID uuid = UUID.fromString("123e4567-e89b-12d3-a456-426655440000"); + assertCastFromVariant(Variant.ofUuid(uuid), "VARCHAR", "123e4567-e89b-12d3-a456-426655440000"); + + assertCastFromVariant(Variant.ofString("short"), "VARCHAR", "short"); + + assertCastFromVariant(Variant.NULL_VALUE, "VARCHAR", null); + } + + @Test + void testCastWithVarbinary() + { + Slice data = Slices.wrappedBuffer(new byte[] {0x01, 0x02, 0x03}); + + assertCastToVariant("CAST(X'010203' as VARBINARY)", data); + + assertCastFromVariant(Variant.ofBinary(data), "VARBINARY", new SqlVarbinary(data.getBytes())); + + assertCastFromVariant(Variant.NULL_VALUE, "VARBINARY", null); + } + + @Test + void testCastWithDate() + { + LocalDate date = LocalDate.of(2024, 10, 24); + int days = (int) date.toEpochDay(); + + assertCastToVariant("DATE '2024-10-24'", date); + + assertCastFromVariant(Variant.ofDate(date), "DATE", new SqlDate(days)); + + long epochMicros = TimeUnit.DAYS.toMicros(days) + TimeUnit.HOURS.toMicros(10); + assertCastFromVariant(Variant.ofTimestampMicrosNtz(epochMicros), "DATE", new SqlDate(days)); + assertCastFromVariant(Variant.ofTimestampMicrosUtc(epochMicros), "DATE", new SqlDate(days)); + + long epochNanos = TimeUnit.DAYS.toNanos(days) + TimeUnit.HOURS.toNanos(10); + assertCastFromVariant(Variant.ofTimestampNanosNtz(epochNanos), "DATE", new SqlDate(days)); + assertCastFromVariant(Variant.ofTimestampNanosUtc(epochNanos), "DATE", new SqlDate(days)); + + assertCastFromVariant(Variant.ofString("2024-10-24"), "DATE", new SqlDate(days)); + + assertCastFromVariant(Variant.NULL_VALUE, "DATE", null); + } + + @Test + void testCastWithTime() + { + LocalTime localTime = LocalTime.of(22, 23, 24); + assertCastToVariant("TIME '22:23:24.123456'", localTime.withNano(123_456_000)); + assertCastToVariant("TIME '22:23:24.12345'", localTime.withNano(123_450_000)); + assertCastToVariant("TIME '22:23:24.1234'", localTime.withNano(123_400_000)); + assertCastToVariant("TIME '22:23:24.123'", localTime.withNano(123_000_000)); + assertCastToVariant("TIME '22:23:24.12'", localTime.withNano(120_000_000)); + assertCastToVariant("TIME '22:23:24.1'", localTime.withNano(100_000_000)); + assertCastToVariant("TIME '22:23:24'", localTime); + + long epochNanos = localTime.withNano(123_456_123).toNanoOfDay(); + long epochMicros = epochNanos / 1000; + assertCastFromVariant(Variant.ofTimeMicrosNtz(epochMicros), "TIME(6)", 
SqlTime.newInstance(6, epochMicros * 1_000_000L)); + assertCastFromVariant(Variant.ofTimeMicrosNtz(epochMicros), "VARCHAR", "22:23:24.123456"); + + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(Variant.ofTimeMicrosNtz(epochMicros)))) + .isEqualTo("\"22:23:24.123456\""); + + assertCastFromVariant(Variant.ofTimeMicrosNtz(epochMicros), "TIME(3)", SqlTime.newInstance(3, (epochMicros / 1_000L) * 1_000_000_000L)); + + assertCastFromVariant(Variant.ofString("22:23:24.123456"), "TIME(6)", SqlTime.newInstance(6, epochMicros * 1_000_000L)); + + LocalDateTime localDateTime = LocalDateTime.parse("2024-10-24T22:23:24.123456123"); + assertCastFromVariant(Variant.ofTimestampMicrosNtz(localDateTime), "TIME(6)", SqlTime.newInstance(6, epochMicros * 1_000_000L)); + assertCastFromVariant(Variant.ofTimestampNanosNtz(localDateTime), "TIME(6)", SqlTime.newInstance(6, epochMicros * 1_000_000L)); + assertCastFromVariant(Variant.ofTimestampNanosNtz(localDateTime), "TIME(9)", SqlTime.newInstance(9, epochNanos * 1_000L)); + assertCastFromVariant(Variant.ofTimestampNanosNtz(localDateTime), "TIME(12)", SqlTime.newInstance(12, epochNanos * 1_000L)); + + Instant instant = Instant.parse("2024-10-24T22:23:24.123456123Z"); + assertCastFromVariant(Variant.ofTimestampMicrosUtc(instant), "TIME(6)", SqlTime.newInstance(6, epochMicros * 1_000_000L)); + assertCastFromVariant(Variant.ofTimestampNanosUtc(instant), "TIME(6)", SqlTime.newInstance(6, epochMicros * 1_000_000L)); + assertCastFromVariant(Variant.ofTimestampNanosUtc(instant), "TIME(9)", SqlTime.newInstance(9, epochNanos * 1_000L)); + assertCastFromVariant(Variant.ofTimestampNanosUtc(instant), "TIME(12)", SqlTime.newInstance(12, epochNanos * 1_000L)); + + LocalDateTime negativeLocalDateTimeMicros = LocalDateTime.parse("1969-12-31T23:59:59.999999"); + LocalDateTime negativeLocalDateTimeNanos = LocalDateTime.parse("1969-12-31T23:59:59.999999999"); + Instant negativeInstantMicros = Instant.parse("1969-12-31T23:59:59.999999Z"); + Instant negativeInstantNanos = Instant.parse("1969-12-31T23:59:59.999999999Z"); + assertCastFromVariant(Variant.ofTimestampMicrosNtz(negativeLocalDateTimeMicros), "TIME(6)", SqlTime.newInstance(6, 86_399_999_999_000_000L)); + assertCastFromVariant(Variant.ofTimestampMicrosUtc(negativeInstantMicros), "TIME(6)", SqlTime.newInstance(6, 86_399_999_999_000_000L)); + assertCastFromVariant(Variant.ofTimestampNanosNtz(negativeLocalDateTimeNanos), "TIME(9)", SqlTime.newInstance(9, 86_399_999_999_999_000L)); + assertCastFromVariant(Variant.ofTimestampNanosUtc(negativeInstantNanos), "TIME(9)", SqlTime.newInstance(9, 86_399_999_999_999_000L)); + + assertCastFromVariant(Variant.NULL_VALUE, "TIME(3)", null); + assertCastFromVariant(Variant.NULL_VALUE, "TIME(6)", null); + } + + @Test + void testCastWithTimestampShort() + { + LocalDateTime localDateTime = LocalDateTime.of(2024, 10, 24, 12, 34, 56); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.123456'", localDateTime.withNano(123_456_000)); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.12345'", localDateTime.withNano(123_450_000)); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.1234'", localDateTime.withNano(123_400_000)); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.123'", localDateTime.withNano(123_000_000)); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.12'", localDateTime.withNano(120_000_000)); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.1'", localDateTime.withNano(100_000_000)); + assertCastToVariant("TIMESTAMP '2024-10-24 
12:34:56'", localDateTime); + + long epochMicros = Instant.parse("2024-10-24T12:34:56.123Z").toEpochMilli() * 1000 + 456; + + assertCastFromVariant(Variant.ofTimestampMicrosNtz(epochMicros), "TIMESTAMP(6)", SqlTimestamp.newInstance(6, epochMicros, 0)); + assertCastFromVariant(Variant.ofTimestampMicrosUtc(epochMicros), "TIMESTAMP(6)", SqlTimestamp.newInstance(6, epochMicros, 0)); + + assertCastFromVariant(Variant.ofTimestampMicrosNtz(epochMicros), "TIMESTAMP(2)", SqlTimestamp.newInstance(2, (epochMicros / 10_000) * 10_000, 0)); + assertCastFromVariant(Variant.ofTimestampMicrosUtc(epochMicros), "TIMESTAMP(2)", SqlTimestamp.newInstance(2, (epochMicros / 10_000) * 10_000, 0)); + + long epochNanos = Instant.parse("2024-10-24T12:34:56.123Z").toEpochMilli() * 1_000_000 + 456_431; + + assertCastFromVariant(Variant.ofTimestampNanosNtz(epochNanos), "TIMESTAMP(6)", SqlTimestamp.newInstance(6, epochMicros, 0)); + assertCastFromVariant(Variant.ofTimestampNanosUtc(epochNanos), "TIMESTAMP(6)", SqlTimestamp.newInstance(6, epochMicros, 0)); + + assertCastFromVariant(Variant.ofTimestampNanosNtz(epochNanos), "TIMESTAMP(2)", SqlTimestamp.newInstance(2, (epochNanos / 10_000_000) * 10_000, 0)); + assertCastFromVariant(Variant.ofTimestampNanosUtc(epochNanos), "TIMESTAMP(2)", SqlTimestamp.newInstance(2, (epochNanos / 10_000_000) * 10_000, 0)); + + assertCastFromVariant(Variant.ofString("2024-10-24 12:34:56.123456"), "TIMESTAMP(6)", SqlTimestamp.newInstance(6, epochMicros, 0)); + + LocalDate date = LocalDate.of(2024, 10, 24); + assertCastFromVariant(Variant.ofDate(date), "TIMESTAMP(6)", SqlTimestamp.newInstance(6, TimeUnit.DAYS.toMicros(date.toEpochDay()), 0)); + + assertCastFromVariant(Variant.NULL_VALUE, "TIMESTAMP(3)", null); + assertCastFromVariant(Variant.NULL_VALUE, "TIMESTAMP(6)", null); + } + + @Test + void testCastWithTimestampLong() + { + LocalDateTime localDateTime = LocalDateTime.of(2024, 10, 24, 12, 34, 56); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.123456789'", localDateTime.withNano(123_456_789)); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.12345678'", localDateTime.withNano(123_456_780)); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.1234567'", localDateTime.withNano(123_456_700)); + + long epochSeconds = Instant.parse("2024-10-24T12:34:56.123Z").getEpochSecond(); + long nanosOfSecond = 123_456_789L; + long nanos = epochSeconds * 1_000_000_000L + nanosOfSecond; + + assertCastFromVariant(Variant.ofTimestampNanosNtz(nanos), "TIMESTAMP(9)", SqlTimestamp.fromSeconds(9, epochSeconds, nanosOfSecond)); + assertCastFromVariant(Variant.ofTimestampNanosUtc(nanos), "TIMESTAMP(9)", SqlTimestamp.fromSeconds(9, epochSeconds, nanosOfSecond)); + + long roundedNanosOfSecondP7 = ((nanosOfSecond + 50L) / 100L) * 100L; + assertCastFromVariant(Variant.ofTimestampNanosNtz(nanos), "TIMESTAMP(7)", SqlTimestamp.fromSeconds(7, epochSeconds, roundedNanosOfSecondP7)); + assertCastFromVariant(Variant.ofTimestampNanosUtc(nanos), "TIMESTAMP(7)", SqlTimestamp.fromSeconds(7, epochSeconds, roundedNanosOfSecondP7)); + + long microsOfSecond = 123_456L; + long epochMicros = epochSeconds * 1_000_000L + microsOfSecond; + + assertCastFromVariant(Variant.ofTimestampMicrosNtz(epochMicros), "TIMESTAMP(9)", SqlTimestamp.fromSeconds(9, epochSeconds, microsOfSecond * 1_000L)); + assertCastFromVariant(Variant.ofTimestampMicrosUtc(epochMicros), "TIMESTAMP(9)", SqlTimestamp.fromSeconds(9, epochSeconds, microsOfSecond * 1_000L)); + + assertCastFromVariant(Variant.ofString("2024-10-24 12:34:56.123456789"), 
"TIMESTAMP(9)", SqlTimestamp.fromSeconds(9, epochSeconds, nanosOfSecond)); + + LocalDate date = LocalDate.of(2024, 10, 24); + long dateMicros = TimeUnit.DAYS.toMicros(date.toEpochDay()); + assertCastFromVariant(Variant.ofDate(date), "TIMESTAMP(9)", SqlTimestamp.newInstance(9, dateMicros, 0)); + + assertCastFromVariant(Variant.NULL_VALUE, "TIMESTAMP(9)", null); + assertCastFromVariant(Variant.NULL_VALUE, "TIMESTAMP(12)", null); + } + + @Test + void testCastWithTimestampWithTimeZoneShort() + { + long epochSecond = Instant.parse("2024-10-24T12:34:56Z").getEpochSecond(); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.123Z'", Instant.ofEpochSecond(epochSecond, 123_000_000)); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.12Z'", Instant.ofEpochSecond(epochSecond, 120_000_000)); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.1Z'", Instant.ofEpochSecond(epochSecond, 100_000_000)); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56Z'", Instant.ofEpochSecond(epochSecond, 0)); + + Instant instant = Instant.parse("2024-10-24T12:34:56.789Z"); + long epochMicros = instant.toEpochMilli() * 1_000L; + + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.789 UTC'", instant); + + assertCastFromVariant(Variant.ofTimestampMicrosUtc(epochMicros), "TIMESTAMP(3) WITH TIME ZONE", SqlTimestampWithTimeZone.fromInstant(3, instant, UTC)); + + assertCastFromVariant(Variant.ofTimestampMicrosNtz(epochMicros), "TIMESTAMP(3) WITH TIME ZONE", SqlTimestampWithTimeZone.fromInstant(3, instant, UTC)); + + assertCastFromVariant(Variant.ofTimestampMicrosUtc(epochMicros), "TIMESTAMP(4) WITH TIME ZONE", SqlTimestampWithTimeZone.fromInstant(4, instant, UTC)); + + LocalDate date = LocalDate.of(2024, 10, 24); + Instant dateInstant = date.atStartOfDay(UTC).toInstant(); + assertCastFromVariant(Variant.ofDate(date), "TIMESTAMP(3) WITH TIME ZONE", SqlTimestampWithTimeZone.fromInstant(3, dateInstant, UTC)); + + assertCastFromVariant(Variant.ofString("2024-10-24 12:34:56.789 UTC"), "TIMESTAMP(3) WITH TIME ZONE", SqlTimestampWithTimeZone.fromInstant(3, instant, UTC)); + + assertThat(assertions.expression("cast(a as VARIANT)") + .binding("a", "TIMESTAMP '2024-10-24 12:34:56.789 UTC'")) + .asInstanceOf(type(Variant.class)) + .extracting(Variant::getTimestampMicros) + .isEqualTo(epochMicros); + + assertCastFromVariant(Variant.NULL_VALUE, "TIMESTAMP(3) WITH TIME ZONE", null); + } + + @Test + void testCastWithTimestampWithTimeZoneLong() + { + long epochSecond = Instant.parse("2024-10-24T12:34:56Z").getEpochSecond(); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.123456789Z'", Instant.ofEpochSecond(epochSecond, 123_456_789)); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.12345678Z'", Instant.ofEpochSecond(epochSecond, 123_456_780)); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.1234567Z'", Instant.ofEpochSecond(epochSecond, 123_456_700)); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.123456Z'", Instant.ofEpochSecond(epochSecond, 123_456_000)); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.12345Z'", Instant.ofEpochSecond(epochSecond, 123_450_000)); + assertCastToVariant("TIMESTAMP '2024-10-24 12:34:56.1234Z'", Instant.ofEpochSecond(epochSecond, 123_400_000)); + + Instant instant = Instant.parse("2024-10-24T12:34:56.123456789Z"); + long epochNanos = instant.toEpochMilli() * 1_000_000L + 456_789L; + + assertCastFromVariant(Variant.ofTimestampNanosUtc(epochNanos), "TIMESTAMP(9) WITH TIME ZONE", SqlTimestampWithTimeZone.fromInstant(9, instant, UTC)); + + 
assertCastFromVariant(Variant.ofTimestampNanosNtz(epochNanos), "TIMESTAMP(9) WITH TIME ZONE", SqlTimestampWithTimeZone.fromInstant(9, instant, UTC)); + + assertCastFromVariant(Variant.ofTimestampNanosUtc(epochNanos), "TIMESTAMP(7) WITH TIME ZONE", SqlTimestampWithTimeZone.fromInstant(7, instant, UTC)); + + LocalDate date = LocalDate.of(2024, 10, 24); + Instant dateInstant = date.atStartOfDay(UTC).toInstant(); + assertCastFromVariant(Variant.ofDate(date), "TIMESTAMP(9) WITH TIME ZONE", SqlTimestampWithTimeZone.fromInstant(9, dateInstant, UTC)); + + assertCastFromVariant(Variant.ofString("2024-10-24 12:34:56.123456789 UTC"), "TIMESTAMP(9) WITH TIME ZONE", SqlTimestampWithTimeZone.fromInstant(9, instant, UTC)); + + assertThat(assertions.expression("cast(a as VARIANT)") + .binding("a", "TIMESTAMP '2024-10-24 12:34:56.123456789 UTC'")) + .asInstanceOf(type(Variant.class)) + .extracting(Variant::getTimestampNanos) + .isEqualTo(epochNanos); + + assertCastFromVariant(Variant.NULL_VALUE, "TIMESTAMP(9) WITH TIME ZONE", null); + } + + @Test + void testCastWithUuid() + { + UUID uuid = UUID.fromString("123e4567-e89b-12d3-a456-426655440000"); + + assertCastToVariant("UUID '123e4567-e89b-12d3-a456-426655440000'", uuid); + + assertCastFromVariant(Variant.ofUuid(uuid), "UUID", uuid.toString()); + assertCastFromVariant(Variant.ofString(uuid.toString()), "UUID", uuid.toString()); + + assertCastFromVariant(Variant.NULL_VALUE, "UUID", null); + } + + @Test + void testCastWithJson() + { + // STRING → JSON + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(Variant.ofString("hello")))) + .isEqualTo("\"hello\""); + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(Variant.ofString("emoji 😊")))) + .isEqualTo("\"emoji 😊\""); + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(Variant.ofString("中文字符")))) + .isEqualTo("\"中文字符\""); + + // BOOLEAN → JSON + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(Variant.ofBoolean(true)))) + .isEqualTo("true"); + + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(Variant.ofBoolean(false)))) + .isEqualTo("false"); + + // TINYINT → JSON + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(Variant.ofByte((byte) 5)))) + .isEqualTo("5"); + + // SMALLINT → JSON + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(Variant.ofShort((short) -7)))) + .isEqualTo("-7"); + + // INTEGER → JSON + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(Variant.ofInt(123_456)))) + .isEqualTo("123456"); + + // BIGINT → JSON + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(Variant.ofLong(1234L)))) + .isEqualTo("1234"); + + // DECIMAL → JSON + BigDecimal decimal = new BigDecimal("1234.50"); + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(Variant.ofDecimal(decimal)))) + .isEqualTo("1234.50"); + + // REAL → JSON + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(Variant.ofFloat(1.5f)))) + .isEqualTo("1.5"); + + // DOUBLE → JSON + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(Variant.ofDouble(1.5d)))) + .isEqualTo("1.5"); + + // DATE → JSON (string) + LocalDate date = LocalDate.of(2024, 10, 24); + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", 
toVariantLiteral(Variant.ofDate(date)))) + .isEqualTo("\"2024-10-24\""); + + // TIMESTAMP_MICROS_NTZ → JSON (string) + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral( + Variant.ofTimestampMicrosNtz(LocalDateTime.parse("2024-10-24T12:34:56.123456"))))) + .isEqualTo("\"2024-10-24 12:34:56.123456\""); + + // TIMESTAMP_MICROS_UTC → JSON (string with zone) + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral( + Variant.ofTimestampMicrosUtc(Instant.parse("2024-10-24T12:34:56.123456Z"))))) + .isEqualTo("\"2024-10-24 12:34:56.123456 UTC\""); + + // TIMESTAMP_NANOS_NTZ → JSON (string) + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral( + Variant.ofTimestampNanosNtz(LocalDateTime.parse("2024-10-24T12:34:56.123456789"))))) + .isEqualTo("\"2024-10-24 12:34:56.123456789\""); + + // TIMESTAMP_NANOS_UTC → JSON (string with zone) + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral( + Variant.ofTimestampNanosUtc(Instant.parse("2024-10-24T12:34:56.123456789Z"))))) + .isEqualTo("\"2024-10-24 12:34:56.123456789 UTC\""); + + // UUID → JSON (string) + UUID uuid = UUID.fromString("123e4567-e89b-12d3-a456-426655440000"); + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(Variant.ofUuid(uuid)))) + .isEqualTo("\"123e4567-e89b-12d3-a456-426655440000\""); + + // BINARY → JSON (base64 string of "abc" → "YWJj") + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(Variant.ofBinary(utf8Slice("abc"))))) + .isEqualTo("\"YWJj\""); + + // ARRAY → JSON + Variant arrayVariant = Variant.ofArray(List.of(Variant.ofInt(1), Variant.ofString("two"))); + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(arrayVariant))) + .isEqualTo("[1,\"two\"]"); + + // OBJECT → JSON + Variant objectVariant = Variant.ofObject(Map.of( + utf8Slice("a"), Variant.ofInt(1), + utf8Slice("b"), Variant.ofString("two"))); + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(objectVariant))) + .isEqualTo("{\"a\":1,\"b\":\"two\"}"); + + // VARIANT primitive NULL → JSON 'null' (as a JSON value, not SQL NULL) + assertThat(assertions.expression("cast(a as JSON)") + .binding("a", toVariantLiteral(Variant.NULL_VALUE))) + .isEqualTo("null"); + } + + @Test + void testCastJsonPrimitivesToVariant() + { + assertThat(assertions.expression("CAST(a AS VARIANT)") + .binding("a", "JSON 'null'")) + .asInstanceOf(type(Variant.class)) + .extracting(Variant::isNull) + .isEqualTo(true); + + assertCastToVariant("JSON 'true'", true); + assertCastToVariant("JSON 'false'", false); + + assertCastToVariant("JSON '123'", 123); + assertCastToVariant("JSON '1234567890123'", 1234567890123L); + + BigDecimal bigDecimal = new BigDecimal("12345678901234567890123456789012345678"); + assertCastToVariant("JSON '%s'".formatted(bigDecimal), bigDecimal); + + assertCastToVariant("JSON '1234.50'", 1234.5); + + assertCastToVariant("JSON '\"hello\"'", "hello"); + assertCastToVariant("JSON '\"emoji 😊\"'", "emoji 😊"); + + assertCastToVariant("JSON '[]'", List.of()); + assertCastToVariant("JSON '[1, 2, 3]'", List.of(1, 2, 3)); + assertCastToVariant("JSON '{}'", Map.of()); + assertCastToVariant("JSON '{\"a\": 1, \"b\": \"two\"}'", Map.of("a", 1, "b", "two")); + } + + @Test + void
testCastJsonArraysAndObjectsRoundTrip + { + // Simple array + assertThat(assertions.expression("CAST(CAST(a AS VARIANT) AS JSON)") + .binding("a", "JSON '[1, 2, 3]'")) + .isEqualTo("[1,2,3]"); + + // Nested arrays and objects + assertThat(assertions.expression("CAST(CAST(a AS VARIANT) AS JSON)") + .binding("a", "JSON '{\"a\": [1, {\"b\": true}], \"c\": null}'")) + .isEqualTo("{\"a\":[1,{\"b\":true}],\"c\":null}"); + + // Empty array + assertThat(assertions.expression("CAST(CAST(a AS VARIANT) AS JSON)") + .binding("a", "JSON '[]'")) + .isEqualTo("[]"); + + // Empty object + assertThat(assertions.expression("CAST(CAST(a AS VARIANT) AS JSON)") + .binding("a", "JSON '{}'")) + .isEqualTo("{}"); + } + + @Test + void testCastJsonToVariantMetadataAndFieldOrdering() + { + // Top-level array of objects with different field sets and order + // This exercises: + // * global metadata over the whole JSON tree + // * object-field encoding sorted by UTF-8 name (via Slice sort) + // * correct field-id assignment per object + String json = """ + [ + {"b": 1, "a": 2}, + {"c": 3, "a": 4} + ] + """; + + assertThat(assertions.expression("CAST(a AS VARIANT)") + .binding("a", "JSON '%s'".formatted(json))) + .asInstanceOf(type(Variant.class)) + .satisfies(variant -> { + assertThat(variant.metadata().isSorted()).isTrue(); + // Root is an array + List<Variant> elements = variant.arrayElements().toList(); + assertThat(elements).hasSize(2); + + Metadata metadata = variant.metadata(); + + int idA = metadata.id(utf8Slice("a")); + int idB = metadata.id(utf8Slice("b")); + int idC = metadata.id(utf8Slice("c")); + + // First object: {"b":1,"a":2} + Variant object1 = elements.getFirst(); + // verify fields are written in sorted order by field name UTF-8 bytes + assertThat(object1.objectFields().toList()) + .extracting(ObjectFieldIdValue::fieldId) + .containsExactly(idA, idB); + + assertThat(object1.getObjectField(idA).orElseThrow().getInt()).isEqualTo(2); + assertThat(object1.getObjectField(idB).orElseThrow().getInt()).isEqualTo(1); + + // Second object: {"c":3,"a":4} + Variant object2 = elements.get(1); + // verify fields are written in sorted order by field name UTF-8 bytes + assertThat(object2.objectFields().toList()) + .extracting(ObjectFieldIdValue::fieldId) + .containsExactly(idA, idC); + + assertThat(object2.getObjectField(idA).orElseThrow().getInt()).isEqualTo(4); + assertThat(object2.getObjectField(idC).orElseThrow().getInt()).isEqualTo(3); + }); + } + + @Test + void testCastJsonToVariantUtf8FieldOrdering() + { + // Use some non-ASCII field names to exercise the Slice/UTF-8 sort. + // These are chosen just to make sure we're not assuming ASCII-only.
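+ // In binary UTF-8 order, "e" (0x65) sorts before "é" (0xC3 0xA9), which sorts before "Ω" (0xCE 0xA9);
+ // the metadata assertions below rely on exactly that ordering.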
+ String json = """ + { + "é": 1, + "e": 2, + "Ω": 3 + } + """; + + assertThat(assertions.expression("CAST(a AS VARIANT)") + .binding("a", "JSON '%s'".formatted(json))) + .asInstanceOf(type(Variant.class)) + .satisfies(variant -> { + Metadata metadata = variant.metadata(); + assertThat(metadata.isSorted()).isTrue(); + + // Binary UTF-8 ordering: "e" < "é" < "Ω" + assertThat(metadata.get(0)).isEqualTo(utf8Slice("e")); + assertThat(metadata.get(1)).isEqualTo(utf8Slice("é")); + assertThat(metadata.get(2)).isEqualTo(utf8Slice("Ω")); + + // Object fields should reference those metadata entries + assertThat(variant.objectFields() + .map(ObjectFieldIdValue::fieldId) + .map(metadata::get) + .map(Slice::toStringUtf8) + .toList()) + .containsExactly("e", "é", "Ω"); + }); + + // Also round-trip JSON -> VARIANT -> JSON structurally + assertThat(assertions.expression("CAST(CAST(a AS VARIANT) AS JSON)") + .binding("a", "JSON '%s'".formatted(json.replace("'", "''")))) + .isEqualTo("{\"e\":2,\"é\":1,\"Ω\":3}"); + } + + @Test + void testVariantToArrayCast() + { + Variant arrayVariant = Variant.ofArray(Arrays.asList( + Variant.ofBoolean(true), + Variant.ofByte((byte) 10), + Variant.ofShort((short) 20), + Variant.ofInt(30), + Variant.ofLong(40), + Variant.ofFloat(50), + Variant.ofDouble(60), + Variant.ofString("70"))); + // VARIANT -> ARRAY<INTEGER> + assertThat(assertions.expression("cast(a as ARRAY<INTEGER>)") + .binding("a", toVariantLiteral(arrayVariant))) + .asInstanceOf(list(Integer.class)) + .containsExactly(1, 10, 20, 30, 40, 50, 60, 70); + + // VARIANT -> ARRAY<TINYINT> + assertThat(assertions.expression("cast(a as ARRAY<TINYINT>)") + .binding("a", toVariantLiteral(arrayVariant))) + .asInstanceOf(list(Byte.class)) + .containsExactly((byte) 1, (byte) 10, (byte) 20, (byte) 30, (byte) 40, (byte) 50, (byte) 60, (byte) 70); + + // VARIANT -> ARRAY<VARIANT> + assertThat(assertions.expression("cast(a as ARRAY<VARIANT>)") + .binding("a", toVariantLiteral(arrayVariant))) + .asInstanceOf(list(Variant.class)) + .satisfies(list -> { + assertThat(list).hasSize(8); + assertThat((Variant) list.get(0)).extracting(Variant::getBoolean).isEqualTo(true); + assertThat((Variant) list.get(3)).extracting(Variant::getInt).isEqualTo(30); + assertThat((Variant) list.get(7)).extracting(variant -> variant.getString().toStringUtf8()).isEqualTo("70"); + }); + + Variant jsonArrayVariant = Variant.ofArray(List.of(Variant.ofInt(1), Variant.ofString("two"))); + assertThat(assertions.expression("cast(cast(a as ARRAY<VARIANT>) as JSON)") + .binding("a", toVariantLiteral(jsonArrayVariant))) + .isEqualTo("[1,\"two\"]"); + } + + @Test + void testArrayToVariantCast() + { + List<Integer> intElements = Arrays.asList(1, 10, 20, 30, 40, 50, 60, 70); + String intArrayLiteral = intElements.stream() + .map(String::valueOf) + .collect(Collectors.joining(", ", "ARRAY[", "]")); + assertCastToVariant(intArrayLiteral, intElements); + + List<String> stringElements = Arrays.asList("one", "two", "three"); + String stringArrayLiteral = stringElements.stream() + .map(value -> "'" + value + "'") + .collect(Collectors.joining(", ", "ARRAY[", "]")); + assertCastToVariant(stringArrayLiteral, stringElements); + } + + @Test + void testMapToVariantCast() + { + // Basic map cast (string keys, primitive values) + // Variant preserves key case + assertCastToVariant( + "MAP(ARRAY['banAna', 'appLE', 'chERry'], ARRAY[2, 1, 3])", + Map.of( + "appLE", 1, + "banAna", 2, + "chERry", 3)); + + // Nested map values (generic map writer path) + assertCastToVariant( + "MAP(ARRAY['b', 'a'], ARRAY[MAP(ARRAY['y'], ARRAY[10]), MAP(ARRAY['x'], ARRAY[20])])", +
Map.of( + "a", Map.of("x", 20), + "b", Map.of("y", 10))); + + // Map values that are arrays (generic map writer path) + assertCastToVariant( + "MAP(ARRAY['b', 'a'], ARRAY[ARRAY[1, 2, 3], ARRAY[4, 5]])", + Map.of( + "a", List.of(4, 5), + "b", List.of(1, 2, 3))); + + // Note, duplicate keys and null keys cannot be tested easily because Trino MAP constructor + // itself enforces non-null unique keys. This could be done with a custom function, + // but is not worth the effort. + } + + @Test + void testRowToVariantCast() + { + // Trino row field names are always upper-cased, and this is preserved in Variant + assertCastToVariant( + "ROW(1 AS a, 'two' AS b, TRUE AS c)", + Map.of( + "A", 1, + "B", "two", + "C", true)); + + assertCastToVariant( + "ROW('two' AS b, 1 AS a, TRUE AS c)", + Map.of( + "A", 1, + "B", "two", + "C", true)); + + String oneOfEverythingRowLiteral = """ + ROW( + NULL AS a, + TRUE AS b, + 123 AS c, + 1234567890123 AS d, + 1234.56 AS e, + REAL '1.5' AS f, + 'hello' AS g, + DATE '2024-10-24' AS h, + TIMESTAMP '2024-10-24 12:34:56.123456' AS i, + UUID '123e4567-e89b-12d3-a456-426655440000' AS j, + ARRAY[1, 2, 3] AS k, + ROW('y' as x, 10 as z) AS l + ) + """; + Map<String, Object> expectedMap = new LinkedHashMap<>(); + expectedMap.put("A", null); + expectedMap.put("B", true); + expectedMap.put("C", 123); + expectedMap.put("D", 1234567890123L); + expectedMap.put("E", new BigDecimal("1234.56")); + expectedMap.put("F", 1.5f); + expectedMap.put("G", "hello"); + expectedMap.put("H", LocalDate.of(2024, 10, 24)); + expectedMap.put("I", LocalDateTime.of(2024, 10, 24, 12, 34, 56, 123456000)); + expectedMap.put("J", UUID.fromString("123e4567-e89b-12d3-a456-426655440000")); + expectedMap.put("K", List.of(1, 2, 3)); + expectedMap.put("L", Map.of( + "X", "y", + "Z", 10)); + + assertCastToVariant(oneOfEverythingRowLiteral, expectedMap); + } + + @Test + void testVariantLeafNoMetadataIsCopied() + { + Variant leaf = Variant.ofInt(123); + + assertThat(assertions.expression("cast(a as VARIANT)") + .binding("a", "ARRAY[%s, %s]".formatted( + toVariantLiteral(leaf), + toVariantLiteral(leaf)))) + .asInstanceOf(type(Variant.class)) + .satisfies(variant -> { + assertThat(variant.metadata().dictionarySize()).isEqualTo(0); + assertThat(variant.arrayElements() + .map(Variant::getInt) + .toList()) + .isEqualTo(List.of(123, 123)); + }); + } + + @Test + void testVariantLeafSameSizeRemap() + { + Variant leaf1 = Variant.ofObject(Map.of( + utf8Slice("b"), Variant.ofInt(1), + utf8Slice("a"), Variant.ofString("x"))); + + Variant leaf2 = Variant.ofObject(Map.of( + utf8Slice("c"), Variant.ofInt(2))); + + assertThat(assertions.expression("cast(a as VARIANT)") + .binding("a", "ARRAY[%s, %s]".formatted( + toVariantLiteral(leaf1), + toVariantLiteral(leaf2)))) + .asInstanceOf(type(Variant.class)) + .satisfies(variant -> { + assertThat(variant.metadata().dictionarySize()).isEqualTo(3); + assertThat(variant.metadata().isSorted()).isTrue(); + + assertThat(variant.arrayElements() + .map(Variant::toObject) + .toList()) + .isEqualTo(List.of( + Map.of("a", "x", "b", 1), + Map.of("c", 2))); + }); + } + + @Test + void testVariantLeafResizeRemap() + { + Variant big = objectVariantWithManyFields(260, "a"); + + Variant small = Variant.ofObject(Map.of( + utf8Slice("z"), Variant.ofInt(7))); + + assertThat(assertions.expression("cast(a as VARIANT)") + .binding("a", "ARRAY[%s, %s]".formatted( + toVariantLiteral(big), + toVariantLiteral(small)))) + .asInstanceOf(type(Variant.class)) + .satisfies(variant -> { +
assertThat(variant.metadata().dictionarySize()).isGreaterThan(256); + assertThat(variant.metadata().isSorted()).isTrue(); + + List<Variant> elements = variant.arrayElements().toList(); + assertThat(elements).hasSize(2); + assertThat(elements.get(1).toObject()).isEqualTo(Map.of("z", 7)); + }); + } + + @Test + void testVariantLeafDeepRecursionRemap() + { + Variant leaf = Variant.ofObject(ImmutableSortedMap.of( + utf8Slice("outer"), + Variant.ofObject(Map.of( + utf8Slice("inner"), + Variant.ofArray(List.of( + Variant.ofObject(Map.of(utf8Slice("x"), Variant.ofInt(10))), + Variant.ofObject(Map.of(utf8Slice("y"), Variant.ofInt(20))))))))); + + Variant big = objectVariantWithManyFields(300, "a"); + + assertThat(assertions.expression("cast(a as VARIANT)") + .binding("a", "ARRAY[%s, %s]".formatted( + toVariantLiteral(big), + toVariantLiteral(leaf)))) + .asInstanceOf(type(Variant.class)) + .satisfies(variant -> { + assertThat(variant.metadata().dictionarySize()).isGreaterThan(256); + assertThat(variant.metadata().isSorted()).isTrue(); + + List<Variant> elements = variant.arrayElements().toList(); + assertThat(elements).hasSize(2); + assertThat(elements.get(1).toObject()) + .isEqualTo(Map.of( + "outer", Map.of( + "inner", List.of( + Map.of("x", 10), + Map.of("y", 20))))); + }); + } + + @Test + void testVariantLeafArraySameSizeRemapCopiesArrayHeader() + { + // Big object contributes lots of metadata entries, but stays < 256 so field-id width stays 1 byte. + Variant big = objectVariantWithManyFields(200, "a"); + + // Leaf is an ARRAY whose elements are OBJECTs. Variant.ofArray() will normalize the leaf's + // metadata before the outer cast sees it, but the outer cast still merges in big's names and + // remaps the nested field ids while staying in SAME_SIZE territory (1-byte ids). + Metadata leafMetadata = Metadata.of(List.of(utf8Slice("b"), utf8Slice("a"))); + Variant leafArray = Variant.ofArray(List.of( + createObjectWithSortedFields(leafMetadata, List.of( + new ObjectField(0, Variant.ofInt(1)), + new ObjectField(1, Variant.ofString("x")))), + createObjectWithSortedFields(leafMetadata, List.of( + new ObjectField(0, Variant.ofInt(2)), + new ObjectField(1, Variant.ofString("y")))))); + + assertThat(assertions.expression("cast(a as VARIANT)") + .binding("a", "ARRAY[%s, %s]".formatted( + toVariantLiteral(big), + toVariantLiteral(leafArray)))) + .asInstanceOf(type(Variant.class)) + .satisfies(variant -> { + // Ensure we stayed in SAME_SIZE territory (unsigned byte ids) + assertThat(variant.metadata().dictionarySize()).isLessThanOrEqualTo(256); + assertThat(variant.metadata().isSorted()).isTrue(); + + List<Variant> elements = variant.arrayElements().toList(); + assertThat(elements).hasSize(2); + + // Second element is the leaf ARRAY; verify semantics survived remap.
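+ // toObject() resolves nested field ids through the merged metadata, so a stale id would surface here as a wrong or missing key.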
+ assertThat(elements.get(1).toObject()) + .isEqualTo(List.of( + Map.of("a", "x", "b", 1), + Map.of("a", "y", "b", 2))); + }); + } + + private static Variant objectVariantWithManyFields(int fieldCount, String prefix) + { + List<Slice> names = new ArrayList<>(fieldCount); + List<ObjectField> fields = new ArrayList<>(fieldCount); + for (int i = 0; i < fieldCount; i++) { + // fixed width keeps lexicographic ordering predictable + Slice name = utf8Slice("%s%03d".formatted(prefix, i)); + names.add(name); + fields.add(new ObjectField(i, Variant.ofInt(i))); + } + return createObjectWithSortedFields(Metadata.of(names), fields); + } + + @Test + void testArrayToVariantCastWithVariantElements() + { + Variant leaf1 = createObjectWithSortedFields( + Metadata.of(List.of(utf8Slice("b"), utf8Slice("a"))), + List.of( + new ObjectField(0, Variant.ofInt(1)), + new ObjectField(1, Variant.ofString("x")))); + + Variant leaf2 = createObjectWithSortedFields( + Metadata.of(List.of(utf8Slice("c"), utf8Slice("a"))), + List.of( + new ObjectField(0, Variant.ofInt(2)), + new ObjectField(1, Variant.ofInt(3)))); + + assertCastToVariant( + "ARRAY[%s, %s]".formatted(toVariantLiteral(leaf1), toVariantLiteral(leaf2)), + List.of( + Map.of("a", "x", "b", 1), + Map.of("a", 3, "c", 2)) + ); + } + + @Test + void testMapToVariantCastWithVariantValues() + { + Variant leaf1 = createObjectWithSortedFields( + Metadata.of(List.of(utf8Slice("b"), utf8Slice("a"))), + List.of( + new ObjectField(0, Variant.ofInt(1)), + new ObjectField(1, Variant.ofString("x")))); + + Variant leaf2 = createObjectWithSortedFields( + Metadata.of(List.of(utf8Slice("c"), utf8Slice("a"))), + List.of( + new ObjectField(0, Variant.ofInt(2)), + new ObjectField(1, Variant.ofInt(3)))); + + assertCastToVariant( + "MAP(ARRAY['k2', 'k1'], ARRAY[%s, %s])".formatted(toVariantLiteral(leaf2), toVariantLiteral(leaf1)), + Map.of( + "k1", Map.of("a", "x", "b", 1), + "k2", Map.of("a", 3, "c", 2))); + } + + @Test + void testRowToVariantCastWithVariantFields() + { + Variant leaf1 = createObjectWithSortedFields( + Metadata.of(List.of(utf8Slice("b"), utf8Slice("a"))), + List.of( + new ObjectField(0, Variant.ofInt(1)), + new ObjectField(1, Variant.ofString("x")))); + + Variant leaf2 = createObjectWithSortedFields( + Metadata.of(List.of(utf8Slice("c"), utf8Slice("a"))), + List.of( + new ObjectField(0, Variant.ofInt(2)), + new ObjectField(1, Variant.ofInt(3)))); + + assertCastToVariant( + "ROW(%s AS v1, 10 AS x, %s AS v2)".formatted(toVariantLiteral(leaf1), toVariantLiteral(leaf2)), + Map.of( + "V1", Map.of("a", "x", "b", 1), + "X", 10, + "V2", Map.of("a", 3, "c", 2))); + } + + @Test + void testArrayDereference() + { + Variant arrayVariant = Variant.fromObject(List.of(10, true, "😊")); + + assertThat(assertions.expression("a[1]") + .binding("a", toVariantLiteral(arrayVariant))) + .asInstanceOf(type(Variant.class)) + .extracting(Variant::getInt) + .isEqualTo(10); + assertThat(assertions.expression("a[3]") + .binding("a", toVariantLiteral(arrayVariant))) + .asInstanceOf(type(Variant.class)) + .extracting(variant -> variant.getString().toStringUtf8()) + .isEqualTo("😊"); + + // out of bounds is an error + assertTrinoExceptionThrownBy(() -> assertions.expression("a[0]") + .binding("a", toVariantLiteral(arrayVariant)) + .evaluate()) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT) + .hasMessage("VARIANT array indices start at 1"); + assertTrinoExceptionThrownBy(() -> assertions.expression("a[4]") + .binding("a", toVariantLiteral(arrayVariant)) + .evaluate()) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT) +
.hasMessage("VARIANT array subscript must be less than or equal to array length: 4 > 3"); + assertTrinoExceptionThrownBy(() -> assertions.expression("a[-4]") + .binding("a", toVariantLiteral(arrayVariant)) + .evaluate()) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT) + .hasMessage("VARIANT array subscript is negative: -4"); + + assertTrinoExceptionThrownBy(() -> assertions.expression("a[1]") + .binding("a", toVariantLiteral(Variant.ofInt(1))) + .evaluate()) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT) + .hasMessage("VARIANT value is int32, not an array"); + } + + @Test + void testObjectDereference() + { + Variant arrayVariant = createObjectWithSortedFields( + Metadata.of(List.of(utf8Slice("apple"), utf8Slice("BANANA"), utf8Slice("c"))), + List.of( + new ObjectField(1, Variant.ofInt(10)), + new ObjectField(2, Variant.ofBoolean(true)), + new ObjectField(0, Variant.ofString("😊")))); + + assertThat(assertions.expression("a['apple']") + .binding("a", toVariantLiteral(arrayVariant))) + .asInstanceOf(type(Variant.class)) + .extracting(variant -> variant.getString().toStringUtf8()) + .isEqualTo("😊"); + assertThat(assertions.expression("a['BANANA']") + .binding("a", toVariantLiteral(arrayVariant))) + .asInstanceOf(type(Variant.class)) + .extracting(Variant::getInt) + .isEqualTo(10); + assertThat(assertions.expression("a['c']") + .binding("a", toVariantLiteral(arrayVariant))) + .asInstanceOf(type(Variant.class)) + .extracting(Variant::getBoolean) + .isEqualTo(true); + + // unknown key returns null + assertThat(assertions.expression("a['unknown']") + .binding("a", toVariantLiteral(arrayVariant))) + .isNull(); + + // keys are case sensitive + assertThat(assertions.expression("a['Apple']") + .binding("a", toVariantLiteral(arrayVariant))) + .isNull(); + assertThat(assertions.expression("a['banana']") + .binding("a", toVariantLiteral(arrayVariant))) + .isNull(); + + assertTrinoExceptionThrownBy(() -> assertions.expression("a['apple']") + .binding("a", toVariantLiteral(Variant.ofInt(1))) + .evaluate()) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT) + .hasMessage("VARIANT value is int32, not an object"); + } + + private void assertCastToVariant(String sqlLiteral, Object expected) + { + assertThat(assertions.expression("CAST(a as VARIANT)") + .binding("a", sqlLiteral)) + .asInstanceOf(type(Variant.class)) + .extracting(Variant::toObject) + .isEqualTo(expected); + + assertThat(assertions.expression("CAST(a AS VARIANT)") + .binding("a", "ARRAY[%s, %s, %s]".formatted(sqlLiteral, sqlLiteral, sqlLiteral))) + .asInstanceOf(type(Variant.class)) + .extracting(Variant::toObject) + .isEqualTo(Arrays.asList(expected, expected, expected)); + + assertThat(assertions.expression("CAST(a AS VARIANT)") + .binding("a", "MAP(ARRAY['key1', 'key2', 'key3'], ARRAY[%s, %s, %s])".formatted(sqlLiteral, sqlLiteral, sqlLiteral))) + .asInstanceOf(type(Variant.class)) + .extracting(Variant::toObject) + .isEqualTo(Map.of( + "key1", expected, + "key2", expected, + "key3", expected)); + + assertThat(assertions.expression("CAST(a AS VARIANT)") + .binding("a", "ROW(%s AS col1, %s AS col2, %s AS col3)".formatted(sqlLiteral, sqlLiteral, sqlLiteral))) + .asInstanceOf(type(Variant.class)) + .extracting(Variant::toObject) + .isEqualTo(Map.of( + "COL1", expected, + "COL2", expected, + "COL3", expected)); + + assertThat(assertions.expression("CAST(a as VARIANT) = CAST(b as VARIANT)") + .binding("a", sqlLiteral) + .binding("b", sqlLiteral)) + .isEqualTo(true); + + assertThat(assertions.expression("CAST(a as VARIANT) = CAST(b as VARIANT)") + .binding("a", 
"ARRAY[%s, %s, %s]".formatted(sqlLiteral, sqlLiteral, sqlLiteral)) + .binding("b", "ARRAY[%s, %s, %s]".formatted(sqlLiteral, sqlLiteral, sqlLiteral))) + .isEqualTo(true); + + assertThat(assertions.expression("CAST(a as VARIANT) = CAST(b as VARIANT)") + .binding("a", "MAP(ARRAY['key1', 'key2', 'key3'], ARRAY[%s, %s, %s])".formatted(sqlLiteral, sqlLiteral, sqlLiteral)) + .binding("b", "MAP(ARRAY['key1', 'key2', 'key3'], ARRAY[%s, %s, %s])".formatted(sqlLiteral, sqlLiteral, sqlLiteral))) + .isEqualTo(true); + + assertThat(assertions.expression("CAST(a as VARIANT) = CAST(b as VARIANT)") + .binding("a", "ROW(%s AS col1, %s AS col2, %s AS col3)".formatted(sqlLiteral, sqlLiteral, sqlLiteral)) + .binding("b", "ROW(%s AS col1, %s AS col2, %s AS col3)".formatted(sqlLiteral, sqlLiteral, sqlLiteral))) + .isEqualTo(true); + } + + private void assertCastFromVariant(Variant variant, String type, Object value) + { + assertThat(assertions.expression("CAST(a as %s)".formatted(type)) + .binding("a", toVariantLiteral(variant))) + .isEqualTo(value); + + assertThat(assertions.expression("CAST(a AS ARRAY<%s>)".formatted(type)) + .binding("a", toVariantLiteral(Variant.ofArray(Arrays.asList(variant, variant, variant))))) + .asInstanceOf(list(Object.class)) + .containsExactly(value, value, value); + + assertThat(assertions.expression("CAST(a AS MAP)".formatted(type)) + .binding( + "a", toVariantLiteral(createObjectWithSortedFields( + Metadata.of(List.of(utf8Slice("key1"), utf8Slice("key2"), utf8Slice("key3"))), + List.of( + new ObjectField(0, variant), + new ObjectField(1, variant), + new ObjectField(2, variant)))))) + .asInstanceOf(map(String.class, Object.class)) + .containsExactly( + entry("key1", value), + entry("key2", value), + entry("key3", value)); + + assertThat(assertions.expression("CAST(a AS ROW(col1 %s, col2 %s, col3 %s))".formatted(type, type, type)) + .binding( + "a", toVariantLiteral(createObjectWithSortedFields( + Metadata.of(List.of(utf8Slice("col1"), utf8Slice("col2"), utf8Slice("col3"))), + List.of( + new ObjectField(0, variant), + new ObjectField(1, variant), + new ObjectField(2, variant)))))) + .asInstanceOf(list(Object.class)) + .containsExactly(value, value, value); + } + + private TrinoExceptionAssert assertCastFromVariantThrows(Variant variant, String type) + { + return assertTrinoExceptionThrownBy(() -> assertions.expression("CAST(a as %s)".formatted(type)) + .binding("a", toVariantLiteral(variant)) + .evaluate()); + } + + private static String toVariantLiteral(Variant variant) + { + String hexMetadata = VarbinaryFunctions.toHex(variant.metadata().toSlice()).toStringUtf8(); + String hexValue = VarbinaryFunctions.toHex(variant.data()).toStringUtf8(); + return String.format("decode_variant(X'%s', X'%s')", hexMetadata, hexValue); + } + + private static Variant createObjectWithSortedFields(Metadata metadata, List fields) + { + return Variant.from(metadata, encodeObjectWithSortedFields(fields)); + } + + // Builds a variant with the exact specified field order. Variants by spec are required to have fields sorted ordered by field name. + // This method assumes that the caller has already sorted the fields by field name. + // This method is necessary to build test variants without global sorting in the metadata dictionary, as all convenience methods + // on Variant build a metadata dictionary with global sorting. 
+ private static Slice encodeObjectWithSortedFields(List<ObjectField> fields) + { + int expectedSize = encodedObjectSize( + fields.stream() + .mapToInt(ObjectField::fieldId) + .max() + .orElse(0), + fields.size(), + fields.stream() + .mapToInt(field -> field.variantValue().length()) + .sum()); + Slice output = Slices.allocate(expectedSize); + + int written = encodeObject( + fields.size(), + i -> fields.get(i).fieldId(), + i -> fields.get(i).variantValue(), + output, + 0); + verify(written == expectedSize, "written size does not match expected size"); + return output; + } + + private record ObjectField(int fieldId, Slice variantValue) + { + private ObjectField(int fieldId, Variant variant) + { + this(fieldId, variant.data()); + assertThat(variant.metadata()).isEqualTo(Metadata.EMPTY_METADATA); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/type/TestVariantType.java b/core/trino-main/src/test/java/io/trino/type/TestVariantType.java new file mode 100644 index 000000000000..95db6a42b480 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/type/TestVariantType.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.type; + +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; +import io.trino.spi.variant.Variant; +import org.junit.jupiter.api.Test; + +import static io.trino.spi.type.VariantType.VARIANT; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class TestVariantType + extends AbstractTestType +{ + TestVariantType() + { + super(VARIANT, Variant.class, createTestBlock()); + } + + public static ValueBlock createTestBlock() + { + BlockBuilder blockBuilder = VARIANT.createBlockBuilder(null, 15); + VARIANT.writeObject(blockBuilder, Variant.NULL_VALUE); + VARIANT.writeObject(blockBuilder, Variant.NULL_VALUE); + VARIANT.writeObject(blockBuilder, Variant.ofBoolean(false)); + VARIANT.writeObject(blockBuilder, Variant.ofBoolean(false)); + VARIANT.writeObject(blockBuilder, Variant.ofBoolean(true)); + VARIANT.writeObject(blockBuilder, Variant.ofBoolean(true)); + VARIANT.writeObject(blockBuilder, Variant.ofInt(11)); + VARIANT.writeObject(blockBuilder, Variant.ofInt(11)); + VARIANT.writeObject(blockBuilder, Variant.ofString("hello")); + VARIANT.writeObject(blockBuilder, Variant.ofString("hello")); + VARIANT.writeObject(blockBuilder, Variant.ofDouble(44.44)); + return blockBuilder.buildValueBlock(); + } + + @Override + protected Object getGreaterValue(Object value) + { + throw new UnsupportedOperationException(); + } + + @Test + public void testRange() + { + assertThat(type.getRange()) + .isEmpty(); + } + + @Test + public void testPreviousValue() + { + assertThatThrownBy(() -> type.getPreviousValue(getSampleValue())) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Type is not orderable: " + type); + } + + @Test + public void testNextValue() + { + assertThatThrownBy(() -> type.getNextValue(getSampleValue())) +
.isInstanceOf(IllegalStateException.class) + .hasMessage("Type is not orderable: " + type); + } +} diff --git a/core/trino-spi/pom.xml b/core/trino-spi/pom.xml index 79f5bfc9c68f..8bf9a4043423 100644 --- a/core/trino-spi/pom.xml +++ b/core/trino-spi/pom.xml @@ -18,7 +18,6 @@ ${air.check.skip-basic} - com.fasterxml.jackson.core @@ -120,6 +119,18 @@ test + + org.apache.iceberg + iceberg-api + test + + + + org.apache.iceberg + iceberg-core + test + + org.assertj assertj-core @@ -271,6 +282,17 @@ + + + org.basepom.maven + duplicate-finder-maven-plugin + + + + ^iceberg-build\.properties$ + + + diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockBuilder.java index 89a8a599a90c..9e815f149d55 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockBuilder.java @@ -33,7 +33,7 @@ public class VariableWidthBlockBuilder implements BlockBuilder { private static final int INSTANCE_SIZE = instanceSize(VariableWidthBlockBuilder.class); - private static final Block NULL_VALUE_BLOCK = new VariableWidthBlock(0, 1, EMPTY_SLICE, new int[] {0, 0}, new boolean[] {true}); + static final Block NULL_VALUE_BLOCK = new VariableWidthBlock(0, 1, EMPTY_SLICE, new int[] {0, 0}, new boolean[] {true}); private static final int SIZE_IN_BYTES_PER_POSITION = Integer.BYTES + Byte.BYTES; private final BlockBuilderStatus blockBuilderStatus; diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/VariantBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/VariantBlock.java new file mode 100644 index 000000000000..c7608d2cef22 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/block/VariantBlock.java @@ -0,0 +1,473 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.block; + +import io.airlift.slice.Slice; +import io.trino.spi.variant.Header; +import io.trino.spi.variant.Metadata; +import io.trino.spi.variant.Variant; +import jakarta.annotation.Nullable; + +import java.util.Optional; +import java.util.function.ObjLongConsumer; + +import static io.airlift.slice.SizeOf.instanceSize; +import static io.airlift.slice.SizeOf.sizeOf; +import static io.trino.spi.block.BlockUtil.checkArrayRange; +import static io.trino.spi.block.BlockUtil.checkReadablePosition; +import static io.trino.spi.block.BlockUtil.checkValidRegion; +import static io.trino.spi.block.BlockUtil.compactArray; +import static java.util.Objects.requireNonNull; + +public final class VariantBlock + implements ValueBlock +{ + private static final int INSTANCE_SIZE = instanceSize(VariantBlock.class); + + private final int startOffset; + private final int positionCount; + @Nullable + private final boolean[] isNull; + /** + * Metadata and value blocks have the same position count as this variant block. The field value of a null variant must be null. 
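+ * Both are variable-width byte blocks: {@code metadata} holds the encoded metadata dictionary and {@code values} holds the encoded variant value bytes.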
+ */ + private final Block metadata; + private final Block values; + + private volatile long sizeInBytes = -1; + private volatile long retainedSizeInBytes = -1; + + /** + * Creates a variant block directly from metadata and value blocks. + */ + public static VariantBlock create(int positionCount, Block metadata, Block values, Optional<boolean[]> isNullOptional) + { + // verify that field values for null variants are null + if (isNullOptional.isPresent()) { + boolean[] isNull = isNullOptional.get(); + checkArrayRange(isNull, 0, positionCount); + verifyPositionsAreNull(metadata, isNull, positionCount, "Metadata"); + verifyPositionsAreNull(values, isNull, positionCount, "Values"); + } + + return createInternal(0, positionCount, isNullOptional.orElse(null), metadata, values); + } + + private static void verifyPositionsAreNull(Block block, boolean[] isNull, int positionCount, String name) + { + for (int position = 0; position < positionCount; position++) { + if (isNull[position] && !block.isNull(position)) { + throw new IllegalArgumentException("%s for null variant must be null: position %d".formatted(name, position)); + } + } + } + + static VariantBlock createInternal(int startOffset, int positionCount, @Nullable boolean[] isNull, Block metadata, Block values) + { + return new VariantBlock(startOffset, positionCount, isNull, metadata, values); + } + + /** + * Use create or createInternal instead of this constructor. Callers are assumed to have validated the arguments. + */ + private VariantBlock(int startOffset, int positionCount, @Nullable boolean[] isNull, Block metadata, Block values) + { + if (startOffset < 0) { + throw new IllegalArgumentException("startOffset is negative"); + } + + if (positionCount < 0) { + throw new IllegalArgumentException("positionCount is negative"); + } + + if (isNull != null && isNull.length - startOffset < positionCount) { + throw new IllegalArgumentException("isNull length is less than positionCount"); + } + + requireNonNull(metadata, "metadata is null"); + requireNonNull(values, "values is null"); + if (metadata.getPositionCount() != values.getPositionCount()) { + throw new IllegalArgumentException("metadata and values blocks must have the same position count"); + } + if (metadata.getPositionCount() - startOffset < positionCount) { + throw new IllegalArgumentException("metadata and values length is less than positionCount"); + } + + this.startOffset = startOffset; + this.positionCount = positionCount; + this.isNull = positionCount == 0 ?
null : isNull; + this.metadata = metadata; + this.values = values; + } + + public Block getMetadata() + { + if ((startOffset == 0) && (metadata.getPositionCount() == positionCount)) { + return metadata; + } + return metadata.getRegion(startOffset, positionCount); + } + + public Block getValues() + { + if ((startOffset == 0) && (values.getPositionCount() == positionCount)) { + return values; + } + return values.getRegion(startOffset, positionCount); + } + + public Header.BasicType getBasicType(int position) + { + VariableWidthBlock variableWidthBlock = (VariableWidthBlock) values.getUnderlyingValueBlock(); + int valuePosition = values.getUnderlyingValuePosition(position); + + Slice rawSlice = variableWidthBlock.getRawSlice(); + int rawSliceOffset = variableWidthBlock.getRawSliceOffset(valuePosition); + return Header.getBasicType(rawSlice.getByte(rawSliceOffset)); + } + + public int getRawOffset() + { + return startOffset; + } + + public Block getRawMetadata() + { + return metadata; + } + + public Block getRawValues() + { + return values; + } + + public int getOffsetBase() + { + return startOffset; + } + + @Override + public boolean mayHaveNull() + { + return isNull != null; + } + + @Override + public boolean hasNull() + { + if (isNull == null) { + return false; + } + for (int i = 0; i < positionCount; i++) { + if (isNull[startOffset + i]) { + return true; + } + } + return false; + } + + @Nullable + public boolean[] getRawIsNull() + { + return isNull; + } + + @Override + public int getPositionCount() + { + return positionCount; + } + + @Override + public long getSizeInBytes() + { + if (sizeInBytes >= 0) { + return sizeInBytes; + } + + long sizeInBytes = Byte.BYTES * (long) positionCount; + sizeInBytes += metadata.getRegionSizeInBytes(startOffset, positionCount); + sizeInBytes += values.getRegionSizeInBytes(startOffset, positionCount); + this.sizeInBytes = sizeInBytes; + return sizeInBytes; + } + + @Override + public long getRetainedSizeInBytes() + { + long retainedSizeInBytes = this.retainedSizeInBytes; + if (retainedSizeInBytes < 0) { + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(isNull); + retainedSizeInBytes += metadata.getRetainedSizeInBytes(); + retainedSizeInBytes += values.getRetainedSizeInBytes(); + this.retainedSizeInBytes = retainedSizeInBytes; + } + return retainedSizeInBytes; + } + + @Override + public void retainedBytesForEachPart(ObjLongConsumer<Object> consumer) + { + consumer.accept(metadata, metadata.getRetainedSizeInBytes()); + consumer.accept(values, values.getRetainedSizeInBytes()); + if (isNull != null) { + consumer.accept(isNull, sizeOf(isNull)); + } + consumer.accept(this, INSTANCE_SIZE); + } + + @Override + public String toString() + { + return "VariantBlock{startOffset=%d, positionCount=%d}".formatted(startOffset, positionCount); + } + + @Override + public VariantBlock copyWithAppendedNull() + { + boolean[] newIsNull = new boolean[positionCount + 1]; + if (isNull != null) { + checkArrayRange(isNull, startOffset, positionCount); + System.arraycopy(isNull, startOffset, newIsNull, 0, positionCount); + } + newIsNull[positionCount] = true; + + Block newMetadata = getMetadata().copyWithAppendedNull(); + Block newValues = getValues().copyWithAppendedNull(); + return new VariantBlock(0, positionCount + 1, newIsNull, newMetadata, newValues); + } + + @Override + public VariantBlock copyPositions(int[] positions, int offset, int length) + { + checkArrayRange(positions, offset, length); + + Block newMetadata = copyBlockPositions(positions, offset, length, metadata, startOffset,
positionCount); + Block newValues = copyBlockPositions(positions, offset, length, values, startOffset, positionCount); + + boolean[] newIsNull = null; + if (isNull != null) { + boolean hasNull = false; + newIsNull = new boolean[length]; + for (int i = 0; i < length; i++) { + boolean isNull = this.isNull[startOffset + positions[offset + i]]; + newIsNull[i] = isNull; + hasNull |= isNull; + } + if (!hasNull) { + newIsNull = null; + } + } + + return new VariantBlock(0, length, newIsNull, newMetadata, newValues); + } + + private static Block copyBlockPositions(int[] positions, int offset, int length, Block block, int blockOffset, int blockLength) + { + // If the variant block has a non-zero starting offset, we have to create a temporary block starting + // from the correct offset before copying positions + if (blockOffset != 0) { + block = block.getRegion(blockOffset, blockLength); + } + return block.copyPositions(positions, offset, length); + } + + @Override + public VariantBlock getRegion(int positionOffset, int length) + { + checkValidRegion(positionCount, positionOffset, length); + + return new VariantBlock(startOffset + positionOffset, length, isNull, metadata, values); + } + + @Override + public long getRegionSizeInBytes(int position, int length) + { + checkValidRegion(positionCount, position, length); + + long regionSizeInBytes = Byte.BYTES * (long) length; + regionSizeInBytes += metadata.getRegionSizeInBytes(startOffset + position, length); + regionSizeInBytes += values.getRegionSizeInBytes(startOffset + position, length); + return regionSizeInBytes; + } + + @Override + public VariantBlock copyRegion(int positionOffset, int length) + { + checkValidRegion(positionCount, positionOffset, length); + + Block newMetadata = metadata.copyRegion(startOffset + positionOffset, length); + Block newValues = values.copyRegion(startOffset + positionOffset, length); + + boolean[] newIsNull = isNull == null ? null : compactArray(isNull, startOffset + positionOffset, length); + if (startOffset == 0 && newIsNull == isNull && metadata == newMetadata && values == newValues) { + return this; + } + return new VariantBlock(0, length, newIsNull, newMetadata, newValues); + } + + public Variant getVariant(int position) + { + checkReadablePosition(this, position); + if (isNull(position)) { + throw new IllegalStateException("Position is null"); + } + return Variant.from(Metadata.from(getSlice(metadata, startOffset + position)), getSlice(values, startOffset + position)); + } + + private static Slice getSlice(Block block, int position) + { + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); + } + + @Override + public VariantBlock getSingleValueBlock(int position) + { + checkReadablePosition(this, position); + + Block newMetadata = metadata.getSingleValueBlock(startOffset + position); + Block newValues = values.getSingleValueBlock(startOffset + position); + boolean[] newIsNull = isNull(position) ? 
new boolean[] {true} : null; + return new VariantBlock(0, 1, newIsNull, newMetadata, newValues); + } + + @Override + public long getEstimatedDataSizeForStats(int position) + { + checkReadablePosition(this, position); + + if (isNull(position)) { + return 0; + } + + return metadata.getEstimatedDataSizeForStats(startOffset + position) + + values.getEstimatedDataSizeForStats(startOffset + position); + } + + @Override + public boolean isNull(int position) + { + if (!mayHaveNull()) { + return false; + } + checkReadablePosition(this, position); + return isNull[startOffset + position]; + } + + public record VariantNestedBlocks(Block metadataBlock, Block valueBlock) + { + public VariantNestedBlocks + { + requireNonNull(metadataBlock, "metadataBlock is null"); + requireNonNull(valueBlock, "valueBlock is null"); + } + } + + /** + * Returns the nested variant fields from the specified block. The block may be a RunLengthEncodedBlock or + a DictionaryBlock, but the underlying block must be a VariantBlock. The returned nested blocks will be the same + length as the specified block, which means they are not null suppressed. + */ + // this code was copied from RowBlock + public static VariantNestedBlocks getNestedFields(Block block) + { + if (block instanceof RunLengthEncodedBlock runLengthEncodedBlock) { + VariantBlock variantBlock = (VariantBlock) runLengthEncodedBlock.getValue(); + return new VariantNestedBlocks( + RunLengthEncodedBlock.create(variantBlock.getMetadata(), runLengthEncodedBlock.getPositionCount()), + RunLengthEncodedBlock.create(variantBlock.getValues(), runLengthEncodedBlock.getPositionCount())); + } + if (block instanceof DictionaryBlock dictionaryBlock) { + VariantBlock variantBlock = (VariantBlock) dictionaryBlock.getDictionary(); + return new VariantNestedBlocks( + dictionaryBlock.createProjection(variantBlock.getMetadata()), + dictionaryBlock.createProjection(variantBlock.getValues())); + } + if (block instanceof VariantBlock variantBlock) { + return new VariantNestedBlocks(variantBlock.getMetadata(), variantBlock.getValues()); + } + throw new IllegalArgumentException("Unexpected block type: " + block.getClass().getSimpleName()); + } + + /** + * Returns the nested variant fields from the specified block with null variants suppressed. The block may be a RunLengthEncodedBlock or + a DictionaryBlock, but the underlying block must be a VariantBlock. The returned nested blocks will not be the same + length as the specified block if it contains null variants.
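+ * Callers that need to align the returned blocks with the original positions must track the non-null positions themselves.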
+ */ + // this code was copied from RowBlock + public static VariantNestedBlocks getNullSuppressedNestedFields(Block block) + { + if (!block.mayHaveNull()) { + return getNestedFields(block); + } + + return switch (block) { + case RunLengthEncodedBlock runLengthEncodedBlock -> { + VariantBlock variantBlock = (VariantBlock) runLengthEncodedBlock.getValue(); + if (!variantBlock.isNull(0)) { + throw new IllegalStateException("Expected run length encoded block value to be null"); + } + // all values are null, so return a zero-length block of the correct type + yield new VariantNestedBlocks( + variantBlock.getMetadata().getRegion(0, 0), + variantBlock.getValues().getRegion(0, 0)); + } + case DictionaryBlock dictionaryBlock -> { + int[] newIds = new int[dictionaryBlock.getPositionCount()]; + int idCount = 0; + for (int position = 0; position < newIds.length; position++) { + if (!dictionaryBlock.isNull(position)) { + newIds[idCount] = dictionaryBlock.getId(position); + idCount++; + } + } + int nonNullPositionCount = idCount; + VariantBlock variantBlock = (VariantBlock) dictionaryBlock.getDictionary(); + yield new VariantNestedBlocks( + DictionaryBlock.create(nonNullPositionCount, variantBlock.getMetadata(), newIds), + DictionaryBlock.create(nonNullPositionCount, variantBlock.getValues(), newIds)); + } + case VariantBlock variantBlock -> { + int[] nonNullPositions = new int[variantBlock.getPositionCount()]; + int idCount = 0; + for (int position = 0; position < nonNullPositions.length; position++) { + if (!variantBlock.isNull(position)) { + nonNullPositions[idCount] = position; + idCount++; + } + } + int nonNullPositionCount = idCount; + yield new VariantNestedBlocks( + DictionaryBlock.create(nonNullPositionCount, variantBlock.getMetadata(), nonNullPositions), + DictionaryBlock.create(nonNullPositionCount, variantBlock.getValues(), nonNullPositions)); + } + default -> throw new IllegalArgumentException("Unexpected block type: " + block.getClass().getSimpleName()); + }; + } + + @Override + public VariantBlock getUnderlyingValueBlock() + { + return this; + } + + @Override + public Optional<ByteArrayBlock> getNulls() + { + return BlockUtil.getNulls(isNull, startOffset, positionCount); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/VariantBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/VariantBlockBuilder.java new file mode 100644 index 000000000000..d999832fc598 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/block/VariantBlockBuilder.java @@ -0,0 +1,398 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package io.trino.spi.block; + +import io.airlift.slice.Slice; +import io.trino.spi.variant.Variant; +import jakarta.annotation.Nullable; + +import java.util.Arrays; + +import static io.airlift.slice.SizeOf.instanceSize; +import static io.airlift.slice.SizeOf.sizeOf; +import static java.util.Objects.checkIndex; +import static java.util.Objects.requireNonNull; + +public class VariantBlockBuilder + implements BlockBuilder +{ + private static final int INSTANCE_SIZE = instanceSize(VariantBlockBuilder.class); + private static final int VARIANT_ENTRY_SIZE = Integer.BYTES + Byte.BYTES; + private static final VariantBlock NULL_VALUE_BLOCK = VariantBlock.createInternal( + 0, + 1, + new boolean[] {true}, + VariableWidthBlockBuilder.NULL_VALUE_BLOCK, + VariableWidthBlockBuilder.NULL_VALUE_BLOCK); + + @Nullable + private final BlockBuilderStatus blockBuilderStatus; + + private int positionCount; + private boolean[] variantIsNull; + private final VariableWidthBlockBuilder metadataBlockBuilder; + private final VariableWidthBlockBuilder valuesBlockBuilder; + + private boolean hasNullVariant; + private boolean hasNonNullVariant; + + public VariantBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + { + this(blockBuilderStatus, expectedEntries, expectedEntries * 9); + } + + public VariantBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytes) + { + this( + blockBuilderStatus, + new VariableWidthBlockBuilder(blockBuilderStatus, expectedEntries, 64), + new VariableWidthBlockBuilder(blockBuilderStatus, expectedEntries, expectedBytes), + new boolean[expectedEntries]); + } + + private VariantBlockBuilder( + @Nullable BlockBuilderStatus blockBuilderStatus, + VariableWidthBlockBuilder metadataBlockBuilder, + VariableWidthBlockBuilder valuesBlockBuilder, + boolean[] variantIsNull) + { + this.blockBuilderStatus = blockBuilderStatus; + this.positionCount = 0; + this.variantIsNull = requireNonNull(variantIsNull, "variantIsNull is null"); + this.metadataBlockBuilder = requireNonNull(metadataBlockBuilder, "metadataBlockBuilder is null"); + this.valuesBlockBuilder = requireNonNull(valuesBlockBuilder, "valuesBlockBuilder is null"); + } + + @Override + public int getPositionCount() + { + return positionCount; + } + + @Override + public long getSizeInBytes() + { + long sizeInBytes = VARIANT_ENTRY_SIZE * (long) positionCount; + sizeInBytes += metadataBlockBuilder.getSizeInBytes(); + sizeInBytes += valuesBlockBuilder.getSizeInBytes(); + return sizeInBytes; + } + + @Override + public long getRetainedSizeInBytes() + { + long size = INSTANCE_SIZE + sizeOf(variantIsNull); + size += metadataBlockBuilder.getRetainedSizeInBytes(); + size += valuesBlockBuilder.getRetainedSizeInBytes(); + if (blockBuilderStatus != null) { + size += BlockBuilderStatus.INSTANCE_SIZE; + } + return size; + } + + public void writeEntry(Variant variant) + { + metadataBlockBuilder.writeEntry(variant.metadata().toSlice()); + valuesBlockBuilder.writeEntry(variant.data()); + entryAdded(false); + } + + public void writeEntry(Slice metadata, Slice value) + { + metadataBlockBuilder.writeEntry(metadata); + valuesBlockBuilder.writeEntry(value); + entryAdded(false); + } + + @Override + public void append(ValueBlock block, int position) + { + VariantBlock variantBlock = (VariantBlock) block; + if (block.isNull(position)) { + appendNull(); + return; + } + + Block rawMetadataBlock = variantBlock.getRawMetadata(); + Block rawValuesBlock = variantBlock.getRawValues(); + int startOffset = 
variantBlock.getOffsetBase(); + + appendToField(rawMetadataBlock, startOffset + position, metadataBlockBuilder); + appendToField(rawValuesBlock, startOffset + position, valuesBlockBuilder); + entryAdded(false); + } + + private static void appendToField(Block fieldBlock, int position, BlockBuilder fieldBlockBuilder) + { + switch (fieldBlock) { + case RunLengthEncodedBlock rleBlock -> fieldBlockBuilder.append(rleBlock.getValue(), 0); + case DictionaryBlock dictionaryBlock -> fieldBlockBuilder.append(dictionaryBlock.getDictionary(), dictionaryBlock.getId(position)); + case ValueBlock valueBlock -> fieldBlockBuilder.append(valueBlock, position); + } + } + + @Override + public void appendRange(ValueBlock block, int offset, int length) + { + if (length == 0) { + return; + } + + VariantBlock variantBlock = (VariantBlock) block; + ensureCapacity(positionCount + length); + + Block rawMetadataBlock = variantBlock.getRawMetadata(); + Block rawValuesBlock = variantBlock.getRawValues(); + int startOffset = variantBlock.getOffsetBase(); + + appendRangeToField(rawMetadataBlock, startOffset + offset, length, metadataBlockBuilder); + appendRangeToField(rawValuesBlock, startOffset + offset, length, valuesBlockBuilder); + + boolean[] rawVariantIsNull = variantBlock.getRawIsNull(); + if (rawVariantIsNull != null) { + // Track nulls one position at a time until both a null and a non-null entry have been seen; + // after that the flags cannot change, so the remaining null flags are bulk copied. + for (int i = 0; i < length; i++) { + boolean isNull = rawVariantIsNull[startOffset + offset + i]; + hasNullVariant |= isNull; + hasNonNullVariant |= !isNull; + if (hasNullVariant & hasNonNullVariant) { + System.arraycopy(rawVariantIsNull, startOffset + offset + i, variantIsNull, positionCount + i, length - i); + break; + } + else { + variantIsNull[positionCount + i] = isNull; + } + } + } + else { + Arrays.fill(variantIsNull, positionCount, positionCount + length, false); + hasNonNullVariant = true; + } + positionCount += length; + + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes(VARIANT_ENTRY_SIZE * length); + } + } + + private static void appendRangeToField(Block fieldBlock, int offset, int length, BlockBuilder fieldBlockBuilder) + { + switch (fieldBlock) { + case RunLengthEncodedBlock rleBlock -> fieldBlockBuilder.appendRepeated(rleBlock.getValue(), 0, length); + case DictionaryBlock dictionaryBlock -> { + int[] rawIds = dictionaryBlock.getRawIds(); + int rawIdsOffset = dictionaryBlock.getRawIdsOffset(); + fieldBlockBuilder.appendPositions(dictionaryBlock.getDictionary(), rawIds, rawIdsOffset + offset, length); + } + case ValueBlock valueBlock -> fieldBlockBuilder.appendRange(valueBlock, offset, length); + } + } + + @Override + public void appendRepeated(ValueBlock block, int position, int count) + { + if (count == 0) { + return; + } + + VariantBlock variantBlock = (VariantBlock) block; + ensureCapacity(positionCount + count); + + Block rawMetadataBlock = variantBlock.getRawMetadata(); + Block rawValuesBlock = variantBlock.getRawValues(); + int startOffset = variantBlock.getOffsetBase(); + + appendRepeatedToField(rawMetadataBlock, startOffset + position, count, metadataBlockBuilder); + appendRepeatedToField(rawValuesBlock, startOffset + position, count, valuesBlockBuilder); + + if (variantBlock.isNull(position)) { + Arrays.fill(variantIsNull, positionCount, positionCount + count, true); + hasNullVariant = true; + } + else { + Arrays.fill(variantIsNull, positionCount, positionCount + count, false); + hasNonNullVariant = true; + } + + positionCount += count; + + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes(VARIANT_ENTRY_SIZE * count); + } + } + + private static
void appendRepeatedToField(Block fieldBlock, int position, int count, BlockBuilder fieldBlockBuilder) + { + switch (fieldBlock) { + case RunLengthEncodedBlock rleBlock -> fieldBlockBuilder.appendRepeated(rleBlock.getValue(), 0, count); + case DictionaryBlock dictionaryBlock -> fieldBlockBuilder.appendRepeated(dictionaryBlock.getDictionary(), dictionaryBlock.getId(position), count); + case ValueBlock valueBlock -> fieldBlockBuilder.appendRepeated(valueBlock, position, count); + } + } + + @Override + public void appendPositions(ValueBlock block, int[] positions, int offset, int length) + { + if (length == 0) { + return; + } + + VariantBlock variantBlock = (VariantBlock) block; + ensureCapacity(positionCount + length); + + Block rawMetadataBlock = variantBlock.getRawMetadata(); + Block rawValuesBlock = variantBlock.getRawValues(); + int startOffset = variantBlock.getOffsetBase(); + + if (startOffset == 0) { + appendPositionsToField(rawMetadataBlock, positions, offset, length, metadataBlockBuilder); + appendPositionsToField(rawValuesBlock, positions, offset, length, valuesBlockBuilder); + } + else { + int[] adjustedPositions = new int[length]; + for (int i = offset; i < offset + length; i++) { + adjustedPositions[i - offset] = startOffset + positions[i]; + } + + appendPositionsToField(rawMetadataBlock, adjustedPositions, 0, length, metadataBlockBuilder); + appendPositionsToField(rawValuesBlock, adjustedPositions, 0, length, valuesBlockBuilder); + } + + boolean[] rawVariantIsNull = variantBlock.getRawIsNull(); + if (rawVariantIsNull != null) { + for (int i = 0; i < length; i++) { + if (rawVariantIsNull[startOffset + positions[offset + i]]) { + variantIsNull[positionCount + i] = true; + hasNullVariant = true; + } + else { + variantIsNull[positionCount + i] = false; + hasNonNullVariant = true; + } + } + } + else { + Arrays.fill(variantIsNull, positionCount, positionCount + length, false); + hasNonNullVariant = true; + } + positionCount += length; + + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes(VARIANT_ENTRY_SIZE * length); + } + } + + private static void appendPositionsToField(Block fieldBlock, int[] positions, int offset, int length, BlockBuilder fieldBlockBuilder) + { + switch (fieldBlock) { + case RunLengthEncodedBlock rleBlock -> fieldBlockBuilder.appendRepeated(rleBlock.getValue(), 0, length); + case DictionaryBlock dictionaryBlock -> { + int[] newPositions = new int[length]; + for (int i = 0; i < newPositions.length; i++) { + newPositions[i] = dictionaryBlock.getId(positions[offset + i]); + } + fieldBlockBuilder.appendPositions(dictionaryBlock.getDictionary(), newPositions, 0, length); + } + case ValueBlock valueBlock -> fieldBlockBuilder.appendPositions(valueBlock, positions, offset, length); + } + } + + @Override + public BlockBuilder appendNull() + { + metadataBlockBuilder.appendNull(); + valuesBlockBuilder.appendNull(); + entryAdded(true); + return this; + } + + @Override + public void resetTo(int position) + { + checkIndex(position, positionCount + 1); + positionCount = position; + metadataBlockBuilder.resetTo(position); + valuesBlockBuilder.resetTo(position); + + // recompute the null flags for the retained prefix + hasNullVariant = false; + hasNonNullVariant = false; + for (int index = 0; index < position; index++) { + hasNullVariant |= variantIsNull[index]; + hasNonNullVariant |= !variantIsNull[index]; + } + } + + private void entryAdded(boolean isNull) + { + ensureCapacity(positionCount + 1); +
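+ // Record the null flag and whether any null / non-null entry has been seen, so build() can emit an RLE null block when every position is null.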
variantIsNull[positionCount] = isNull; + hasNullVariant |= isNull; + hasNonNullVariant |= !isNull; + positionCount++; + + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes(VARIANT_ENTRY_SIZE); + } + } + + @Override + public Block build() + { + if (!hasNonNullVariant) { + return RunLengthEncodedBlock.create(NULL_VALUE_BLOCK, positionCount); + } + return buildValueBlock(); + } + + @Override + public VariantBlock buildValueBlock() + { + Block metadataBlock = metadataBlockBuilder.buildValueBlock(); + Block valuesBlock = valuesBlockBuilder.buildValueBlock(); + return VariantBlock.createInternal(0, positionCount, hasNullVariant ? variantIsNull : null, metadataBlock, valuesBlock); + } + + private void ensureCapacity(int capacity) + { + if (variantIsNull.length >= capacity) { + return; + } + + int newSize = BlockUtil.calculateNewArraySize(variantIsNull.length, capacity); + variantIsNull = Arrays.copyOf(variantIsNull, newSize); + } + + @Override + public String toString() + { + return "VariantBlockBuilder{metadataBlockBuilder=%s, valuesBlockBuilder=%s}".formatted(metadataBlockBuilder, valuesBlockBuilder); + } + + @Override + public BlockBuilder newBlockBuilderLike(int expectedEntries, BlockBuilderStatus blockBuilderStatus) + { + return new VariantBlockBuilder( + blockBuilderStatus, + (VariableWidthBlockBuilder) metadataBlockBuilder.newBlockBuilderLike(blockBuilderStatus), + (VariableWidthBlockBuilder) valuesBlockBuilder.newBlockBuilderLike(blockBuilderStatus), + new boolean[expectedEntries]); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/VariantBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/VariantBlockEncoding.java new file mode 100644 index 000000000000..8dbbe37e055a --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/block/VariantBlockEncoding.java @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + package io.trino.spi.block; + + import io.airlift.slice.SliceInput; + import io.airlift.slice.SliceOutput; + + import static io.trino.spi.block.EncoderUtil.decodeNullBitsScalar; + import static io.trino.spi.block.EncoderUtil.decodeNullBitsVectorized; + import static io.trino.spi.block.EncoderUtil.encodeNullsAsBitsScalar; + import static io.trino.spi.block.EncoderUtil.encodeNullsAsBitsVectorized; + + public class VariantBlockEncoding + implements BlockEncoding + { + public static final String NAME = "VARIANT"; + + private final boolean vectorizeNullBitPacking; + + public VariantBlockEncoding(boolean vectorizeNullBitPacking) + { + this.vectorizeNullBitPacking = vectorizeNullBitPacking; + } + + @Override + public String getName() + { + return NAME; + } + + @Override + public Class<? extends Block> getBlockClass() + { + return VariantBlock.class; + } + + @Override + public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceOutput, Block block) + { + VariantBlock variantBlock = (VariantBlock) block; + + sliceOutput.appendInt(variantBlock.getPositionCount()); + + blockEncodingSerde.writeBlock(sliceOutput, variantBlock.getMetadata()); + blockEncodingSerde.writeBlock(sliceOutput, variantBlock.getValues()); + + if (vectorizeNullBitPacking) { + encodeNullsAsBitsVectorized(sliceOutput, variantBlock.getRawIsNull(), variantBlock.getOffsetBase(), variantBlock.getPositionCount()); + } + else { + encodeNullsAsBitsScalar(sliceOutput, variantBlock.getRawIsNull(), variantBlock.getOffsetBase(), variantBlock.getPositionCount()); + } + } + + @Override + public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) + { + int positionCount = sliceInput.readInt(); + + Block metadataBlock = blockEncodingSerde.readBlock(sliceInput); + Block valuesBlock = blockEncodingSerde.readBlock(sliceInput); + + boolean[] variantIsNull; + if (vectorizeNullBitPacking) { + variantIsNull = decodeNullBitsVectorized(sliceInput, positionCount).orElse(null); + } + else { + variantIsNull = decodeNullBitsScalar(sliceInput, positionCount).orElse(null); + } + + return VariantBlock.createInternal(0, positionCount, variantIsNull, metadataBlock, valuesBlock); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/StandardTypes.java b/core/trino-spi/src/main/java/io/trino/spi/type/StandardTypes.java index 1bd67f37a158..1e68fcc034dc 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/StandardTypes.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/StandardTypes.java @@ -47,6 +47,7 @@ public final class StandardTypes public static final String JSON_2016 = "json2016"; public static final String IPADDRESS = "ipaddress"; public static final String UUID = "uuid"; + public static final String VARIANT = "variant"; public static final String GEOMETRY = "Geometry"; public static final String SPHERICAL_GEOGRAPHY = "SphericalGeography"; // SphericalGeographyType.NAME public static final String BING_TILE = "BingTile"; // BingTileType.NAME diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/VariantType.java b/core/trino-spi/src/main/java/io/trino/spi/type/VariantType.java new file mode 100644 index 000000000000..6fd4857920f8 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/type/VariantType.java @@ -0,0 +1,296 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.type; + +import io.airlift.slice.Slice; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.PageBuilderStatus; +import io.trino.spi.block.VariableWidthBlock; +import io.trino.spi.block.VariantBlock; +import io.trino.spi.block.VariantBlockBuilder; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.FlatFixed; +import io.trino.spi.function.FlatFixedOffset; +import io.trino.spi.function.FlatVariableOffset; +import io.trino.spi.function.FlatVariableWidth; +import io.trino.spi.function.ScalarOperator; +import io.trino.spi.variant.Header; +import io.trino.spi.variant.Metadata; +import io.trino.spi.variant.Variant; + +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.nio.ByteOrder; + +import static io.airlift.slice.Slices.wrappedBuffer; +import static io.trino.spi.function.OperatorType.EQUAL; +import static io.trino.spi.function.OperatorType.READ_VALUE; +import static io.trino.spi.function.OperatorType.XX_HASH_64; +import static io.trino.spi.type.TypeOperatorDeclaration.extractOperatorDeclaration; +import static io.trino.spi.variant.Metadata.EMPTY_METADATA; +import static io.trino.spi.variant.Metadata.EMPTY_METADATA_SLICE; +import static java.lang.Math.abs; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; +import static java.lang.invoke.MethodHandles.lookup; + +public class VariantType + extends AbstractType + implements VariableWidthType +{ + public static final String NAME = "variant"; + public static final VariantType VARIANT = new VariantType(); + + private static final VarHandle INT_HANDLE = MethodHandles.byteArrayViewVarHandle(int[].class, ByteOrder.LITTLE_ENDIAN); + + private static final int EXPECTED_BYTES_PER_ENTRY = 32; + private static final TypeOperatorDeclaration DEFAULT_READ_OPERATORS = extractOperatorDeclaration(VariantType.class, lookup(), Variant.class); + + private VariantType() + { + super(new TypeSignature(StandardTypes.VARIANT), Variant.class, VariantBlock.class); + } + + @Override + public String getDisplayName() + { + return NAME; + } + + @Override + public VariantBlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) + { + int maxBlockSizeInBytes; + if (blockBuilderStatus == null) { + maxBlockSizeInBytes = PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; + } + else { + maxBlockSizeInBytes = blockBuilderStatus.getMaxPageSizeInBytes(); + } + + // it is guaranteed Math.min will not overflow; safe to cast + int expectedBytes = (int) min((long) expectedEntries * expectedBytesPerEntry, maxBlockSizeInBytes); + return new VariantBlockBuilder( + blockBuilderStatus, + expectedBytesPerEntry == 0 ? 
expectedEntries : min(expectedEntries, maxBlockSizeInBytes / expectedBytesPerEntry), + expectedBytes); + } + + @Override + public VariantBlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + { + return createBlockBuilder(blockBuilderStatus, expectedEntries, EXPECTED_BYTES_PER_ENTRY); + } + + @Override + public boolean isComparable() + { + return true; + } + + @Override + public TypeOperatorDeclaration getTypeOperatorDeclaration(TypeOperators typeOperators) + { + return DEFAULT_READ_OPERATORS; + } + + @Override + public Variant getObject(Block block, int position) + { + VariantBlock valueBlock = (VariantBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getVariant(valuePosition); + } + + @Override + public void writeObject(BlockBuilder blockBuilder, Object value) + { + ((VariantBlockBuilder) blockBuilder).writeEntry((Variant) value); + } + + @Override + public Variant getObjectValue(Block block, int position) + { + if (block.isNull(position)) { + return null; + } + VariantBlock valueBlock = (VariantBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getVariant(valuePosition); + } + + // There are two layouts, one for primitive values and one for nested values that may need metadata. + // fixed: + // total variable length (4 bytes); the high bit indicates presence of metadata + // variable: + // metadata length (4 bytes) (if present) + // metadata bytes (metadata length bytes) (if present) + // value bytes (total length - metadata length bytes) + + @Override + public int getFlatFixedSize() + { + return Integer.BYTES; + } + + @Override + public boolean isFlatVariableWidth() + { + return true; + } + + @Override + public int getFlatVariableWidthSize(Block block, int position) + { + VariantBlock variantBlock = (VariantBlock) block.getUnderlyingValueBlock(); + int rawPosition = block.getUnderlyingValuePosition(position) + variantBlock.getRawOffset(); + if (variantBlock.getBasicType(rawPosition).isContainer()) { + long length = Integer.BYTES; + length += getSliceLength(variantBlock.getRawMetadata(), rawPosition); + length += getSliceLength(variantBlock.getValues(), rawPosition); + return toIntExact(length); + } + return getSliceLength(variantBlock.getValues(), rawPosition); + } + + private static int getSliceLength(Block nestedBlock, int position) + { + VariableWidthBlock variableWidthBlock = (VariableWidthBlock) nestedBlock.getUnderlyingValueBlock(); + return variableWidthBlock.getSliceLength(nestedBlock.getUnderlyingValuePosition(position)); + } + + @Override + public int getFlatVariableWidthLength(byte[] fixedSizeSlice, int fixedSizeOffset) + { + return abs((int) INT_HANDLE.get(fixedSizeSlice, fixedSizeOffset)); + } + + @ScalarOperator(READ_VALUE) + private static Variant readFlatToStack( + @FlatFixed byte[] fixedSizeSlice, + @FlatFixedOffset int fixedSizeOffset, + @FlatVariableWidth byte[] variableSizeSlice, + @FlatVariableOffset int variableSizeOffset) + { + int fixedValue = (int) INT_HANDLE.get(fixedSizeSlice, fixedSizeOffset); + if (fixedValue > 0) { + return Variant.from(EMPTY_METADATA, wrappedBuffer(variableSizeSlice, variableSizeOffset, fixedValue)); + } + + int metadataLength = (int) INT_HANDLE.get(variableSizeSlice, variableSizeOffset); + Slice metadataSlice = wrappedBuffer(variableSizeSlice, variableSizeOffset + Integer.BYTES, metadataLength); + Metadata metadata = Metadata.from(metadataSlice); + + Slice
valueSlice = wrappedBuffer( + variableSizeSlice, + variableSizeOffset + Integer.BYTES + metadataLength, + abs(fixedValue) - Integer.BYTES - metadataLength); + + return Variant.from(metadata, valueSlice); + } + + @ScalarOperator(READ_VALUE) + private static void readFlatToBlock( + @FlatFixed byte[] fixedSizeSlice, + @FlatFixedOffset int fixedSizeOffset, + @FlatVariableWidth byte[] variableSizeSlice, + @FlatVariableOffset int variableSizeOffset, + BlockBuilder blockBuilder) + { + int fixedValue = (int) INT_HANDLE.get(fixedSizeSlice, fixedSizeOffset); + if (fixedValue > 0) { + ((VariantBlockBuilder) blockBuilder).writeEntry( + EMPTY_METADATA_SLICE, + wrappedBuffer(variableSizeSlice, variableSizeOffset, fixedValue)); + return; + } + + int metadataLength = (int) INT_HANDLE.get(variableSizeSlice, variableSizeOffset); + Slice metadataSlice = wrappedBuffer(variableSizeSlice, variableSizeOffset + Integer.BYTES, metadataLength); + + Slice valueSlice = wrappedBuffer( + variableSizeSlice, + variableSizeOffset + Integer.BYTES + metadataLength, + abs(fixedValue) - Integer.BYTES - metadataLength); + + ((VariantBlockBuilder) blockBuilder).writeEntry(metadataSlice, valueSlice); + } + + @ScalarOperator(READ_VALUE) + private static void writeFlatFromStack( + Variant value, + @FlatFixed byte[] fixedSizeSlice, + @FlatFixedOffset int fixedSizeOffset, + @FlatVariableWidth byte[] variableSizeSlice, + @FlatVariableOffset int variableSizeOffset) + { + Metadata metadata = value.metadata(); + Slice metadataSlice; + metadataSlice = metadata == EMPTY_METADATA ? null : metadata.toSlice(); + Slice data = value.data(); + writeFlat(metadataSlice, data, fixedSizeSlice, fixedSizeOffset, variableSizeSlice, variableSizeOffset); + } + + @ScalarOperator(READ_VALUE) + private static void writeFlatFromBlock( + @BlockPosition VariantBlock block, + @BlockIndex int position, + @FlatFixed byte[] fixedSizeSlice, + @FlatFixedOffset int fixedSizeOffset, + @FlatVariableWidth byte[] variableSizeSlice, + @FlatVariableOffset int variableSizeOffset) + { + int rawPosition = position + block.getRawOffset(); + Slice metadataSlice = ((VariableWidthBlock) block.getRawMetadata().getUnderlyingValueBlock()) + .getSlice(block.getRawMetadata().getUnderlyingValuePosition(rawPosition)); + Slice data = ((VariableWidthBlock) block.getRawValues().getUnderlyingValueBlock()) + .getSlice(block.getRawValues().getUnderlyingValuePosition(rawPosition)); + writeFlat(metadataSlice, data, fixedSizeSlice, fixedSizeOffset, variableSizeSlice, variableSizeOffset); + } + + private static void writeFlat(Slice metadataSlice, Slice data, byte[] fixedSizeSlice, int fixedSizeOffset, byte[] variableSizeSlice, int variableSizeOffset) + { + if (metadataSlice == null || !Header.getBasicType(data.getByte(0)).isContainer()) { + int length = data.length(); + INT_HANDLE.set(fixedSizeSlice, fixedSizeOffset, length); + data.getBytes(0, variableSizeSlice, variableSizeOffset, length); + return; + } + + int fixedValue = -(Integer.BYTES + metadataSlice.length() + data.length()); + + // fixed part is negative total length + INT_HANDLE.set(fixedSizeSlice, fixedSizeOffset, fixedValue); + + // variable part is: metadata length, metadata, data + INT_HANDLE.set(variableSizeSlice, variableSizeOffset, metadataSlice.length()); + metadataSlice.getBytes(0, variableSizeSlice, variableSizeOffset + Integer.BYTES, metadataSlice.length()); + data.getBytes(0, variableSizeSlice, variableSizeOffset + Integer.BYTES + metadataSlice.length(), data.length()); + } + + @ScalarOperator(EQUAL) + private static 
boolean equalOperator(Variant left, Variant right) + { + return left.equals(right); + } + + @ScalarOperator(XX_HASH_64) + private static long xxHash64Operator(Variant value) + { + return value.longHashCode(); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/variant/Header.java b/core/trino-spi/src/main/java/io/trino/spi/variant/Header.java new file mode 100644 index 000000000000..57ef4b03db79 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/variant/Header.java @@ -0,0 +1,209 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.variant; + +import static io.trino.spi.variant.Header.BasicType.PRIMITIVE; +import static io.trino.spi.variant.Header.BasicType.SHORT_STRING; +import static io.trino.spi.variant.VariantUtils.checkArgument; + +public final class Header +{ + public static final int SHORT_STRING_MAX_LENGTH = 63; + public static final byte VERSION = 0x01; + + private Header() {} + + public static BasicType getBasicType(byte header) + { + return BasicType.fromHeader(header); + } + + public static PrimitiveType getPrimitiveType(byte header) + { + return PrimitiveType.fromHeader(header); + } + + public static int shortStringLength(byte header) + { + // Bits 2-7 represent the length of the short string + return (header & 0b1111_1100) >>> 2; + } + + /// The number of bytes used to encode the field offsets. + /// The value is between 1 and 4. + @SuppressWarnings("JavaExistingMethodCanBeUsed") + public static int objectFieldOffsetSize(byte header) + { + // Bits 2-3 represent the size of the object field offsets + int sizeBits = (header & 0b0000_1100) >>> 2; + return sizeBits + 1; + } + + /// The number of bytes used to encode the field ids + /// The value is between 1 and 4. + public static int objectFieldIdSize(byte header) + { + // Bits 4-5 represent the size of the object field IDs + int sizeBits = (header & 0b0011_0000) >>> 4; + return sizeBits + 1; + } + + // If true, 4 bytes are used to encode the field count; otherwise, 1 byte is used. + public static boolean objectIsLarge(byte header) + { + // Bit 6 represents whether the object is large + return (header & 0b0100_0000) != 0; + } + + /// The number of bytes used to encode the array field offsets. + public static int arrayFieldOffsetSize(byte header) + { + // Bits 2-3 represent the size of the array field offsets + int sizeBits = (header & 0b0000_1100) >>> 2; + return sizeBits + 1; + } + + /// If true, 4 bytes are used to encode the element count; otherwise, 1 byte is used. 
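+ /// Illustrative example (editorial addition, not part of the original change): a header byte of + /// 0b0001_0111 decodes as basic type ARRAY (bits 0-1 = 0b11), a field offset size of 2 + /// (bits 2-3 = 0b01), and a large array (bit 4 set), so the element count occupies 4 bytes.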
+ public static boolean arrayIsLarge(byte header) + { + // Bit 4 represents whether the array is large + return (header & 0b0001_0000) != 0; + } + + public static byte primitiveHeader(PrimitiveType primitiveType) + { + return (byte) (PRIMITIVE.ordinal() | primitiveType.ordinal() << 2); + } + + public static byte shortStringHeader(int length) + { + checkArgument(length >= 0 && length <= SHORT_STRING_MAX_LENGTH, () -> "Short string length must be between 0 and %s: %s".formatted(SHORT_STRING_MAX_LENGTH, length)); + return (byte) (SHORT_STRING.ordinal() | (length << 2)); + } + + public static byte objectHeader(int fieldIdSize, int fieldOffsetSize, boolean isLarge) + { + // Bits 0-1 represent the basic type + int header = BasicType.OBJECT.ordinal(); + // Bits 2-3 represent the size of the field offsets + checkArgument(fieldOffsetSize >= 1 && fieldOffsetSize <= 4, () -> "fieldOffsetSize must be between 1 and 4: %s".formatted(fieldOffsetSize)); + header |= (fieldOffsetSize - 1) << 2; + // Bits 4-5 represent the size of the field IDs + checkArgument(fieldIdSize >= 1 && fieldIdSize <= 4, () -> "fieldIdSize must be between 1 and 4: %s".formatted(fieldIdSize)); + header |= (fieldIdSize - 1) << 4; + // Bit 6 represents whether the object is large + if (isLarge) { + header |= 0b0100_0000; + } + return (byte) header; + } + + public static byte arrayHeader(int fieldOffsetSize, boolean isLarge) + { + // Bits 0-1 represent the basic type + int header = BasicType.ARRAY.ordinal(); + // Bits 2-3 represent the size of the field offsets + checkArgument(fieldOffsetSize >= 1 && fieldOffsetSize <= 4, () -> "fieldOffsetSize must be between 1 and 4: %s".formatted(fieldOffsetSize)); + header |= (fieldOffsetSize - 1) << 2; + // Bit 4 represents whether the array is large + if (isLarge) { + header |= 0b0001_0000; + } + return (byte) header; + } + + public static int metadataVersion(byte header) + { + // Bits 0-3: version + return header & 0b0000_1111; + } + + public static boolean metadataIsSorted(byte header) + { + // Bit 4: sorted + return (header & 0b0001_0000) != 0; + } + + public static int metadataOffsetSize(byte header) + { + // Bits 6-7: offset size + int sizeBits = (header & 0b1100_0000) >>> 6; + return sizeBits + 1; + } + + public static byte metadataHeader(boolean sorted, int offsetSize) + { + // Bits 0-3: version + int header = VERSION; + // Bit 4: sorted + if (sorted) { + header |= 0b0001_0000; + } + // Bit 6-7: offset size + checkArgument(offsetSize >= 1 && offsetSize <= 4, () -> "offsetSize must be between 1 and 4: %s".formatted(offsetSize)); + header |= (offsetSize - 1) << 6; + return (byte) header; + } + + public enum BasicType + { + PRIMITIVE, + SHORT_STRING, + OBJECT, + ARRAY; + + public boolean isContainer() + { + return this == OBJECT || this == ARRAY; + } + + private static BasicType fromHeader(byte header) + { + int basicTypeBits = (header & 0b0000_0011); + return values()[basicTypeBits]; + } + } + + public enum PrimitiveType + { + NULL, + BOOLEAN_TRUE, + BOOLEAN_FALSE, + INT8, + INT16, + INT32, + INT64, + DOUBLE, + DECIMAL4, + DECIMAL8, + DECIMAL16, + DATE, + TIMESTAMP_UTC_MICROS, + TIMESTAMP_NTZ_MICROS, + FLOAT, + BINARY, + STRING, + TIME_NTZ_MICROS, + TIMESTAMP_UTC_NANOS, + TIMESTAMP_NTZ_NANOS, + UUID; + + private static PrimitiveType fromHeader(byte header) + { + // Bits 2-7 represent the primitive type + int primitiveTypeBits = (header & 0b1111_1100) >>> 2; + return values()[primitiveTypeBits]; + } + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/variant/Int2IntOpenHashMap.java 
b/core/trino-spi/src/main/java/io/trino/spi/variant/Int2IntOpenHashMap.java new file mode 100644 index 000000000000..19593c6df500 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/variant/Int2IntOpenHashMap.java @@ -0,0 +1,286 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.variant; + +// Note: this code was forked from fastutil (http://fastutil.di.unimi.it/) Int2IntOpenHashMap +// and mimics that code style. +// Copyright (C) 2002-2019 Sebastiano Vigna +@SuppressWarnings("WhileCanBeDoWhile") +class Int2IntOpenHashMap +{ + public static final int DEFAULT_RETURN_VALUE = -1; + + /** + * 2<sup>32</sup> · φ, φ = (√5 − 1)/2. + */ + private static final int INT_PHI = 0x9E3779B9; + /** + * The default load factor of a hash table. + */ + private static final float DEFAULT_LOAD_FACTOR = 0.75f; + /** + * The array of keys. + */ + protected int[] key; + /** + * The array of values. + */ + protected int[] value; + /** + * The mask for wrapping a position counter. + */ + protected int mask; + /** + * Whether this map contains the key zero. + */ + private boolean containsNullKey; + /** + * The current table size. + */ + protected int n; + /** + * Threshold after which we rehash. It must be the table size times {@link #f}. + */ + protected int maxFill; + /** + * Number of entries in the set (including the key zero, if present). + */ + protected int size; + /** + * The acceptable load factor. + */ + protected final float f; + + public Int2IntOpenHashMap(final int expected) + { + this(expected, DEFAULT_LOAD_FACTOR); + } + + /** + * Creates a new hash map. + * + *
+ * <p>
+ * The actual table size will be the least power of two greater than + * {@code expected}/{@code f}. + * + * @param expected the expected number of elements in the hash map. + * @param f the load factor. + */ + + private Int2IntOpenHashMap(final int expected, final float f) + { + if (f <= 0 || f > 1) { + throw new IllegalArgumentException("Load factor must be greater than 0 and smaller than or equal to 1"); + } + if (expected < 0) { + throw new IllegalArgumentException("The expected number of elements must be nonnegative"); + } + this.f = f; + n = arraySize(expected, f); + mask = n - 1; + maxFill = maxFill(n, f); + key = new int[n + 1]; + value = new int[n + 1]; + } + + public int putIfAbsent(final int k, final int v) + { + final int pos = find(k); + if (pos >= 0) { + return value[pos]; + } + insert(-pos - 1, k, v); + return DEFAULT_RETURN_VALUE; + } + + public int get(final int k) + { + if (k == 0) { + return containsNullKey ? value[n] : DEFAULT_RETURN_VALUE; + } + final int[] key = this.key; + // The starting point. + int pos = mix(k) & mask; + int curr = key[pos]; + if (curr == 0) { + return DEFAULT_RETURN_VALUE; + } + if (k == curr) { + return value[pos]; + } + // There's always an unused entry. + while (true) { + pos = (pos + 1) & mask; + curr = key[pos]; + if (curr == 0) { + return DEFAULT_RETURN_VALUE; + } + if (k == curr) { + return value[pos]; + } + } + } + + public boolean containsKey(final int k) + { + if (k == 0) { + return containsNullKey; + } + final int[] key = this.key; + // The starting point. + int pos = mix(k) & mask; + int curr = key[pos]; + if (curr == 0) { + return false; + } + if (k == curr) { + return true; + } + // There's always an unused entry. + while (true) { + pos = (pos + 1) & mask; + curr = key[pos]; + if (curr == 0) { + return false; + } + if (k == curr) { + return true; + } + } + } + + private void insert(final int pos, final int k, final int v) + { + if (pos == n) { + containsNullKey = true; + } + key[pos] = k; + value[pos] = v; + if (size++ >= maxFill) { + rehash(arraySize(size + 1, f)); + } + } + + private int find(final int k) + { + if (k == 0) { + return containsNullKey ? n : -(n + 1); + } + final int[] key = this.key; + int pos = mix(k) & mask; + int curr = key[pos]; + // The starting point. + if (curr == 0) { + return -(pos + 1); + } + if (k == curr) { + return pos; + } + // There's always an unused entry. + while (true) { + pos = (pos + 1) & mask; + curr = key[pos]; + if (curr == 0) { + return -(pos + 1); + } + if (k == curr) { + return pos; + } + } + } + + /** + * Rehashes the map. + * + *
+ * <p>
+ * This method implements the basic rehashing strategy, and may be overridden by + * subclasses implementing different rehashing strategies (e.g., disk-based + * rehashing). However, you should not override this method unless you + * understand the internal workings of this class. + * + * @param newN the new size + */ + private void rehash(final int newN) + { + final int[] key = this.key; + final int[] value = this.value; + final int mask = newN - 1; // Note that this is used by the hashing macro + final int[] newKey = new int[newN + 1]; + final int[] newValue = new int[newN + 1]; + int i = n; + int pos; + for (int j = realSize(); j-- != 0; ) { + --i; + while (key[i] == 0) { + --i; + } + pos = mix(key[i]) & mask; + if (!(newKey[pos] == 0)) { + pos = (pos + 1) & mask; + while (!(newKey[pos] == 0)) { + pos = (pos + 1) & mask; + } + } + newKey[pos] = key[i]; + newValue[pos] = value[i]; + } + newValue[newN] = value[n]; + n = newN; + this.mask = mask; + maxFill = maxFill(n, f); + this.key = newKey; + this.value = newValue; + } + + private int realSize() + { + return containsNullKey ? size - 1 : size; + } + + private static int mix(final int x) + { + final int h = x * INT_PHI; + return h ^ (h >>> 16); + } + + private static int maxFill(final int n, final float f) + { + /* We must guarantee that there is always at least + * one free entry (even with pathological load factors). */ + return Math.min((int) Math.ceil(n * f), n - 1); + } + + private static int arraySize(final int expected, final float f) + { + final long s = Math.max(2, nextPowerOfTwo((long) Math.ceil(expected / f))); + if (s > (1 << 30)) { + throw new IllegalArgumentException("Too large (" + expected + " expected elements with load factor " + f + ")"); + } + return (int) s; + } + + private static long nextPowerOfTwo(long x) + { + if (x == 0) { + return 1; + } + x--; + x |= x >> 1; + x |= x >> 2; + x |= x >> 4; + x |= x >> 8; + x |= x >> 16; + return (x | x >> 32) + 1; + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/variant/Metadata.java b/core/trino-spi/src/main/java/io/trino/spi/variant/Metadata.java new file mode 100644 index 000000000000..ed610c49d2c1 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/variant/Metadata.java @@ -0,0 +1,438 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.variant; + +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.IntUnaryOperator; + +import static io.trino.spi.variant.Header.metadataHeader; +import static io.trino.spi.variant.Header.metadataIsSorted; +import static io.trino.spi.variant.Header.metadataOffsetSize; +import static io.trino.spi.variant.Header.metadataVersion; +import static io.trino.spi.variant.VariantUtils.checkArgument; +import static io.trino.spi.variant.VariantUtils.checkState; +import static io.trino.spi.variant.VariantUtils.getOffsetSize; +import static io.trino.spi.variant.VariantUtils.readOffset; +import static io.trino.spi.variant.VariantUtils.writeOffset; +import static java.util.Objects.checkIndex; +import static java.util.Objects.requireNonNull; + +public final class Metadata +{ + public static final Slice EMPTY_METADATA_SLICE = Slices.wrappedBuffer(metadataHeader(false, 1), (byte) 0, (byte) 0); + public static final Metadata EMPTY_METADATA = new Metadata(EMPTY_METADATA_SLICE, false, 0, 1); + + private final Slice metadata; + private final boolean sorted; + private final int dictionarySize; + private final int offsetSize; + + private Metadata(Slice metadata, boolean sorted, int dictionarySize, int offsetSize) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.sorted = sorted; + checkArgument(dictionarySize >= 0, "dictionarySize is negative"); + checkArgument(offsetSize >= 1 && offsetSize <= 4, () -> "offsetSize is out of range: " + offsetSize); + this.dictionarySize = dictionarySize; + this.offsetSize = offsetSize; + } + + public static Metadata from(Slice metadata) + { + if (metadata == EMPTY_METADATA_SLICE) { + return EMPTY_METADATA; + } + + // Quick validations + checkArgument(metadata.length() >= 3, "metadata is empty"); + + // Basic header/version check + byte header = metadata.getByte(0); + int version = metadataVersion(header); + checkArgument(version == Header.VERSION, () -> "Unsupported metadata version: " + version); + + // Minimal structural check: + // - we can read dictionarySize + // - the implied dictionary region is within the slice bounds + int offsetSize = metadataOffsetSize(header); + checkArgument(metadata.length() >= 1 + offsetSize, "metadata is too short for dictionary size"); + + int dictionarySize = readOffset(metadata, 1, offsetSize); + checkArgument(dictionarySize >= 0, "Negative dictionary size: " + dictionarySize); + + // compute dictionaryStart = offsetsBase + (dictionarySize + 1) * offsetSize + long dictionaryStart = 1 + (long) offsetSize + (long) (dictionarySize + 1) * offsetSize; + checkArgument(dictionaryStart <= metadata.length(), "metadata is too short for dictionary offsets"); + + // At this point: + // - header is valid + // - version matches + // - offsets array and dictionary region fit in the slice + // + // NOT validated here (but validated when created): + // - offsets[0] == 0 + // - offsets are non-decreasing + // - last offset == dictionary length + + if (dictionarySize == 0) { + metadata = EMPTY_METADATA_SLICE; + } + return new Metadata(metadata, metadataIsSorted(header), dictionarySize, offsetSize); + } + + public static Metadata of(Collection<Slice> fieldNames) + { + requireNonNull(fieldNames, "fieldNames is null"); + Set<Slice> distinctFieldNames = new
HashSet<>(fieldNames.size()); + for (Slice fieldName : fieldNames) { + requireNonNull(fieldName, "fieldName is null"); + checkArgument(fieldName.length() > 0, "empty field names are not allowed"); + checkArgument(distinctFieldNames.add(fieldName), () -> "duplicate field name: " + fieldName.toStringUtf8()); + } + return createMetadata(fieldNames); + } + + private static Metadata createMetadata(Collection<Slice> fieldNames) + { + if (fieldNames.isEmpty()) { + return EMPTY_METADATA; + } + + int dictionarySize = fieldNames.size(); + + // Compute total dictionary length + int dictionaryLength = 0; + for (Slice fieldName : fieldNames) { + dictionaryLength += fieldName.length(); + } + + boolean sorted = VariantUtils.isSorted(fieldNames); + + int offsetSize = getOffsetSize(dictionaryLength); + + // Layout: + // [ header(1) ] + // [ dictionarySize (offsetSize bytes) ] + // [ (dictionarySize + 1) offsets (each offsetSize bytes) ] + // [ dictionary bytes (dictionaryLength bytes) ] + int headerAndSizeBytes = 1 + offsetSize; + int offsetsBytes = (dictionarySize + 1) * offsetSize; + int dictionaryStart = headerAndSizeBytes + offsetsBytes; + int totalSize = dictionaryStart + dictionaryLength; + + Slice metadata = Slices.allocate(totalSize); + int position = 0; + + // Header + metadata.setByte(position, metadataHeader(sorted, offsetSize)); + position += 1; + + // Dictionary size + writeOffset(metadata, position, dictionarySize, offsetSize); + position += offsetSize; + + // Dictionary offsets (relative to start of dictionary data) + int currentOffset = 0; + + // The first offset is always 0 + writeOffset(metadata, position, currentOffset, offsetSize); + position += offsetSize; + + // Subsequent offsets are cumulative lengths + for (Slice fieldName : fieldNames) { + currentOffset += fieldName.length(); + writeOffset(metadata, position, currentOffset, offsetSize); + position += offsetSize; + } + + // Dictionary data: write directly into the final metadata slice + int dictionaryPosition = dictionaryStart; + for (Slice fieldName : fieldNames) { + metadata.setBytes(dictionaryPosition, fieldName); + dictionaryPosition += fieldName.length(); + } + + return new Metadata(metadata, sorted, dictionarySize, offsetSize); + } + + public boolean isEmpty() + { + return metadata == EMPTY_METADATA_SLICE; + } + + public boolean isSorted() + { + return sorted; + } + + /// Returns the ID for a {@code name} in the dictionary, or -1 if not present.
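+ /// + /// A hedged usage sketch (names illustrative): for metadata built as + /// Metadata.of(List.of(utf8Slice("a"), utf8Slice("b"))), id(utf8Slice("b")) returns 1 + /// and id(utf8Slice("c")) returns -1.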
+ public int id(Slice name) + { + requireNonNull(name, "name is null"); + if (dictionarySize == 0) { + return -1; + } + + int offsetsBase = 1 + offsetSize; + int dictionaryStart = offsetsBase + (dictionarySize + 1) * offsetSize; + + if (sorted && dictionarySize >= VariantUtils.BINARY_SEARCH_THRESHOLD) { + return binarySearchIds(offsetsBase, offsetSize, dictionaryStart, dictionarySize, name); + } + return linearSearch(offsetsBase, offsetSize, dictionaryStart, dictionarySize, name); + } + + private int linearSearch( + int offsetsBase, + int offsetSize, + int dictionaryStart, + int dictionarySize, + Slice name) + { + int position = offsetsBase; + + // First offset (must be 0, already validated when created/loaded) + int start = readOffset(metadata, position, offsetSize); + position += offsetSize; + + for (int id = 0; id < dictionarySize; id++) { + int end = readOffset(metadata, position, offsetSize); + if (metadata.equals(dictionaryStart + start, end - start, name, 0, name.length())) { + return id; + } + + start = end; + position += offsetSize; + } + + return -1; + } + + private int binarySearchIds( + int offsetsBase, + int offsetSize, + int dictionaryStart, + int dictionarySize, + Slice target) + { + int low = 0; + int high = dictionarySize - 1; + + while (low <= high) { + int mid = (low + high) >>> 1; + + int midOffsetPos = offsetsBase + (mid * offsetSize); + int start = readOffset(metadata, midOffsetPos, offsetSize); + int end = readOffset(metadata, midOffsetPos + offsetSize, offsetSize); + + int compare = metadata.compareTo(dictionaryStart + start, end - start, target, 0, target.length()); + if (compare < 0) { + low = mid + 1; + } + else if (compare > 0) { + high = mid - 1; + } + else { + return mid; + } + } + + return -1; + } + + /// Returns the field name for an ID in metadata. + /// + /// @throws IndexOutOfBoundsException if the id is out of range + public Slice get(int id) + { + checkIndex(id, dictionarySize); + + int offsetsBase = 1 + offsetSize; + + int offsetPosition = offsetsBase + (id * offsetSize); + int start = readOffset(metadata, offsetPosition, offsetSize); + int end = readOffset(metadata, offsetPosition + offsetSize, offsetSize); + int length = end - start; + + // Offsets are relative to the start of the dictionary region + int dictionaryStart = offsetsBase + (dictionarySize + 1) * offsetSize; + return metadata.slice(dictionaryStart + start, length); + } + + /// Returns the size of the metadata dictionary. 
+ public int dictionarySize() + { + return dictionarySize; + } + + public void validateFully() + { + int position = 1 + offsetSize; + int previous = -1; + for (int i = 0; i < dictionarySize + 1; i++) { + int value = readOffset(metadata, position, offsetSize); + checkArgument(i != 0 || value == 0, "First dictionary offset must be 0"); + checkArgument(i == 0 || value > previous, "dictionary offsets must be strictly increasing"); + previous = value; + position += offsetSize; + } + + int dictionaryStart = 1 + offsetSize + (dictionarySize + 1) * offsetSize; + int dictionaryLength = metadata.length() - dictionaryStart; + checkArgument(previous == dictionaryLength, "Last dictionary offset must equal dictionary length"); + } + + public Slice toSlice() + { + return metadata; + } + + @Override + public boolean equals(Object other) + { + if (!(other instanceof Metadata that)) { + return false; + } + return metadata.equals(that.metadata); + } + + @Override + public int hashCode() + { + return metadata.hashCode(); + } + + @Override + public String toString() + { + return "Metadata[dictionarySize=%d, sorted=%s]".formatted(dictionarySize, sorted); + } + + public static Builder builder() + { + return new Builder(); + } + + public static final class Builder + { + private Map<Slice, Integer> nameToId = new HashMap<>(); + private List<Slice> names = new ArrayList<>(); + + private Builder() {} + + /// Add a single field name and return the provisional fieldId. + /// Final field-id must be resolved using the remap index returned by buildSorted(). + public int addFieldName(Slice fieldName) + { + checkNotBuilt(); + requireNonNull(fieldName, "fieldName is null"); + Integer existing = nameToId.get(fieldName); + if (existing != null) { + return existing; + } + + int index = names.size(); + names.add(fieldName); + nameToId.put(fieldName, index); + return index; + } + + /// Add multiple field names and return the provisional fieldIds. + /// Final field-ids must be resolved using the remap index returned by buildSorted(). + public int[] addFieldNames(List<Slice> fieldNames) + { + checkNotBuilt(); + int[] provisionalFieldIds = new int[fieldNames.size()]; + for (int oldId = 0; oldId < fieldNames.size(); oldId++) { + // final field-id may differ from this temporary index, + // but the remap is in terms of final ids; we fix it in build(). + provisionalFieldIds[oldId] = addFieldName(fieldNames.get(oldId)); + } + return provisionalFieldIds; + } + + public int dictionarySize() + { + checkNotBuilt(); + return names.size(); + } + + /// Build final Metadata: + /// - Field names sorted by UTF-8 + /// - Field-ids are 0..N-1 in that sorted order + /// + /// Also returns a remap from "builder index" -> final field-id, + /// so you can fix any remap arrays that were filled with builder indices.
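+ /// + /// A hedged usage sketch (names illustrative): + /// Metadata.Builder builder = Metadata.builder(); + /// int provisional = builder.addFieldName(utf8Slice("b")); // builder index 0 + /// builder.addFieldName(utf8Slice("a")); // builder index 1 + /// SortedMetadata sorted = builder.buildSorted(); + /// sorted.sortedFieldIdMapping().applyAsInt(provisional); // 1, because "a" sorts before "b"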
+ public SortedMetadata buildSorted() + { + checkNotBuilt(); + if (names.isEmpty()) { + Metadata empty = EMPTY_METADATA; + return new SortedMetadata(empty, index -> { + throw new IndexOutOfBoundsException("Metadata is empty"); + }); + } + + // Build array of indices [0..n-1] and sort by UTF-8 name + Integer[] order = new Integer[names.size()]; + for (int i = 0; i < names.size(); i++) { + order[i] = i; + } + Arrays.sort(order, Comparator.comparing(names::get, Slice::compareTo)); + + // builderIndex -> finalFieldId + int[] builderIndexToFieldId = new int[names.size()]; + + // Build final names in sorted order + List<Slice> sortedNames = new ArrayList<>(names.size()); + for (int fieldId = 0; fieldId < names.size(); fieldId++) { + int builderIndex = order[fieldId]; + sortedNames.add(names.get(builderIndex)); + builderIndexToFieldId[builderIndex] = fieldId; + } + + // Mark builder as built + nameToId = null; + names = null; + + Metadata metadata = createMetadata(sortedNames); + return new SortedMetadata(metadata, index -> builderIndexToFieldId[index]); + } + + public record SortedMetadata(Metadata metadata, IntUnaryOperator sortedFieldIdMapping) + { + public SortedMetadata(Metadata metadata, IntUnaryOperator sortedFieldIdMapping) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.sortedFieldIdMapping = requireNonNull(sortedFieldIdMapping, "sortedFieldIdMapping is null"); + } + } + + private void checkNotBuilt() + { + checkState(nameToId != null && names != null, "Builder has already been built"); + } + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/variant/ObjectFieldIdValue.java b/core/trino-spi/src/main/java/io/trino/spi/variant/ObjectFieldIdValue.java new file mode 100644 index 000000000000..9ac535d98626 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/variant/ObjectFieldIdValue.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.variant; + +import static io.trino.spi.variant.VariantUtils.checkArgument; +import static java.util.Objects.requireNonNull; + +public record ObjectFieldIdValue(int fieldId, Variant value) +{ + public ObjectFieldIdValue + { + checkArgument(fieldId >= 0, "fieldId must be non-negative"); + requireNonNull(value, "value is null"); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/variant/Variant.java b/core/trino-spi/src/main/java/io/trino/spi/variant/Variant.java new file mode 100644 index 000000000000..ed72e22a0fee --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/variant/Variant.java @@ -0,0 +1,1160 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.variant; + +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.spi.variant.Header.BasicType; +import io.trino.spi.variant.Header.PrimitiveType; +import io.trino.spi.variant.Metadata.Builder.SortedMetadata; +import io.trino.spi.variant.VariantDecoder.ObjectLayout; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneOffset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.UUID; +import java.util.function.IntUnaryOperator; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.spi.variant.Header.BasicType.OBJECT; +import static io.trino.spi.variant.Header.BasicType.PRIMITIVE; +import static io.trino.spi.variant.Header.arrayFieldOffsetSize; +import static io.trino.spi.variant.Header.arrayIsLarge; +import static io.trino.spi.variant.Header.getBasicType; +import static io.trino.spi.variant.Header.getPrimitiveType; +import static io.trino.spi.variant.Header.objectFieldIdSize; +import static io.trino.spi.variant.Header.objectFieldOffsetSize; +import static io.trino.spi.variant.Header.objectIsLarge; +import static io.trino.spi.variant.Header.primitiveHeader; +import static io.trino.spi.variant.Header.shortStringLength; +import static io.trino.spi.variant.Metadata.EMPTY_METADATA; +import static io.trino.spi.variant.VariantEncoder.ENCODED_BOOLEAN_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_BYTE_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_DATE_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_DOUBLE_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_EMPTY_OBJECT_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_FLOAT_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_INT_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_LONG_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_NULL_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_SHORT_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_TIMESTAMP_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_TIME_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_UUID_SIZE; +import static io.trino.spi.variant.VariantEncoder.encodeArrayHeading; +import static io.trino.spi.variant.VariantEncoder.encodeBinary; +import static io.trino.spi.variant.VariantEncoder.encodeBoolean; +import static io.trino.spi.variant.VariantEncoder.encodeByte; +import static io.trino.spi.variant.VariantEncoder.encodeDate; +import static io.trino.spi.variant.VariantEncoder.encodeDecimal; +import static io.trino.spi.variant.VariantEncoder.encodeDouble; +import static io.trino.spi.variant.VariantEncoder.encodeFloat; 
+import static io.trino.spi.variant.VariantEncoder.encodeInt; +import static io.trino.spi.variant.VariantEncoder.encodeLong; +import static io.trino.spi.variant.VariantEncoder.encodeNull; +import static io.trino.spi.variant.VariantEncoder.encodeObjectHeading; +import static io.trino.spi.variant.VariantEncoder.encodeShort; +import static io.trino.spi.variant.VariantEncoder.encodeString; +import static io.trino.spi.variant.VariantEncoder.encodeTimeMicrosNtz; +import static io.trino.spi.variant.VariantEncoder.encodeTimestampMicrosNtz; +import static io.trino.spi.variant.VariantEncoder.encodeTimestampMicrosUtc; +import static io.trino.spi.variant.VariantEncoder.encodeTimestampNanosNtz; +import static io.trino.spi.variant.VariantEncoder.encodeTimestampNanosUtc; +import static io.trino.spi.variant.VariantEncoder.encodeUuid; +import static io.trino.spi.variant.VariantEncoder.encodedArraySize; +import static io.trino.spi.variant.VariantEncoder.encodedBinarySize; +import static io.trino.spi.variant.VariantEncoder.encodedDecimalSize; +import static io.trino.spi.variant.VariantEncoder.encodedObjectSize; +import static io.trino.spi.variant.VariantEncoder.encodedStringSize; +import static io.trino.spi.variant.VariantUtils.checkArgument; +import static io.trino.spi.variant.VariantUtils.checkState; +import static io.trino.spi.variant.VariantUtils.findFieldIndex; +import static io.trino.spi.variant.VariantUtils.readOffset; +import static io.trino.spi.variant.VariantUtils.verify; +import static java.lang.Math.multiplyExact; +import static java.lang.Math.toIntExact; +import static java.util.Collections.unmodifiableList; +import static java.util.Collections.unmodifiableMap; +import static java.util.Objects.requireNonNull; +import static java.util.Objects.requireNonNullElse; + +public final class Variant +{ + public static final Variant NULL_VALUE = from(EMPTY_METADATA, Slices.wrappedBuffer(primitiveHeader(PrimitiveType.NULL))); + public static final Variant EMPTY_ARRAY; + public static final Variant EMPTY_OBJECT; + private final Slice data; + private final Metadata metadata; + private final BasicType basicType; + private final PrimitiveType primitiveType; + + static { + IntUnaryOperator emptyIndexedOperator = index -> { + throw new IndexOutOfBoundsException(); + }; + + Slice emptyArrayValue = Slices.allocate(encodedArraySize(0, 0)); + encodeArrayHeading(0, emptyIndexedOperator, emptyArrayValue, 0); + EMPTY_ARRAY = from(EMPTY_METADATA, emptyArrayValue); + + Slice emptyObjectValue = Slices.allocate(encodedObjectSize(0, 0, 0)); + encodeObjectHeading(0, emptyIndexedOperator, emptyIndexedOperator, emptyObjectValue, 0); + EMPTY_OBJECT = from(EMPTY_METADATA, emptyObjectValue); + } + + public Variant(Slice data, Metadata metadata, BasicType basicType, PrimitiveType primitiveType) + { + requireNonNull(data, "data is null"); + requireNonNull(metadata, "metadata is null"); + requireNonNull(basicType, "basicType is null"); + checkArgument(basicType == PRIMITIVE || primitiveType == null, "primitiveType must be null for non-primitive basicType"); + checkArgument(basicType != PRIMITIVE || primitiveType != null, "primitiveType must be non-null for primitive basicType"); + + // no need to retain metadata for non-container types + if (!basicType.isContainer()) { + metadata = EMPTY_METADATA; + } + this.data = data; + this.metadata = metadata; + this.basicType = basicType; + this.primitiveType = primitiveType; + } + + public Slice data() + { + return data; + } + + public Metadata metadata() + { + return metadata; + } + +
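+ // Illustrative note (editorial addition): primitives never retain a dictionary, so for + // example Variant.ofInt(42) has basicType() == PRIMITIVE, primitiveType() == INT32 and + // metadata() == EMPTY_METADATA, while Variant.ofObject(...) produces basicType() == OBJECT + // with a sorted field-name dictionary in metadata(). +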
public BasicType basicType() + { + return basicType; + } + + public PrimitiveType primitiveType() + { + return primitiveType; + } + + public static Variant from(Metadata metadata, Slice data) + { + byte header = data.getByte(0); + BasicType basicType = getBasicType(header); + PrimitiveType primitiveType = basicType == PRIMITIVE ? getPrimitiveType(header) : null; + return new Variant(data, metadata, basicType, primitiveType); + } + + public static Variant ofBoolean(boolean value) + { + Slice data = Slices.allocate(ENCODED_BOOLEAN_SIZE); + encodeBoolean(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofByte(byte value) + { + Slice data = Slices.allocate(ENCODED_BYTE_SIZE); + encodeByte(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofShort(short value) + { + Slice data = Slices.allocate(ENCODED_SHORT_SIZE); + encodeShort(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofInt(int value) + { + Slice data = Slices.allocate(ENCODED_INT_SIZE); + encodeInt(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofLong(long value) + { + Slice data = Slices.allocate(ENCODED_LONG_SIZE); + encodeLong(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofDecimal(BigDecimal value) + { + requireNonNull(value, "value is null"); + Slice data = Slices.allocate(encodedDecimalSize(value)); + encodeDecimal(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofFloat(float value) + { + Slice data = Slices.allocate(ENCODED_FLOAT_SIZE); + encodeFloat(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofDouble(double value) + { + Slice data = Slices.allocate(ENCODED_DOUBLE_SIZE); + encodeDouble(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofDate(int value) + { + Slice data = Slices.allocate(ENCODED_DATE_SIZE); + encodeDate(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofDate(LocalDate value) + { + Slice data = Slices.allocate(ENCODED_DATE_SIZE); + encodeDate(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofTimeMicrosNtz(long value) + { + Slice data = Slices.allocate(ENCODED_TIME_SIZE); + encodeTimeMicrosNtz(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofTimeMicrosNtz(LocalTime value) + { + Slice data = Slices.allocate(ENCODED_TIME_SIZE); + encodeTimeMicrosNtz(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofTimestampMicrosUtc(long value) + { + Slice data = Slices.allocate(ENCODED_TIMESTAMP_SIZE); + encodeTimestampMicrosUtc(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofTimestampMicrosUtc(Instant value) + { + Slice data = Slices.allocate(ENCODED_TIMESTAMP_SIZE); + encodeTimestampMicrosUtc(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofTimestampNanosUtc(long value) + { + Slice data = Slices.allocate(ENCODED_TIMESTAMP_SIZE); + encodeTimestampNanosUtc(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofTimestampNanosUtc(Instant value) + { + Slice data = Slices.allocate(ENCODED_TIMESTAMP_SIZE); + encodeTimestampNanosUtc(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofTimestampMicrosNtz(long value) + { + Slice data = Slices.allocate(ENCODED_TIMESTAMP_SIZE); + 
encodeTimestampMicrosNtz(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofTimestampMicrosNtz(LocalDateTime value) + { + Slice data = Slices.allocate(ENCODED_TIMESTAMP_SIZE); + encodeTimestampMicrosNtz(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofTimestampNanosNtz(long value) + { + Slice data = Slices.allocate(ENCODED_TIMESTAMP_SIZE); + encodeTimestampNanosNtz(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofTimestampNanosNtz(LocalDateTime value) + { + Slice data = Slices.allocate(ENCODED_TIMESTAMP_SIZE); + encodeTimestampNanosNtz(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofBinary(Slice value) + { + Slice data = Slices.allocate(encodedBinarySize(value.length())); + encodeBinary(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofString(String value) + { + return ofString(utf8Slice(value)); + } + + public static Variant ofString(Slice value) + { + Slice data = Slices.allocate(encodedStringSize(value.length())); + encodeString(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofUuid(UUID value) + { + Slice data = Slices.allocate(ENCODED_UUID_SIZE); + encodeUuid(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofUuid(Slice value) + { + Slice data = Slices.allocate(ENCODED_UUID_SIZE); + encodeUuid(value, data, 0); + return from(EMPTY_METADATA, data); + } + + public static Variant ofArray(List<Variant> elements) + { + if (elements.isEmpty()) { + return EMPTY_ARRAY; + } + Metadata.Builder metadataBuilder = Metadata.builder(); + List<VariantFieldRemapper> remappers = elements.stream() + .map(variant -> VariantFieldRemapper.create(variant, metadataBuilder)) + .toList(); + + // finalize the metadata and remappers + SortedMetadata sortedMetadata = metadataBuilder.buildSorted(); + IntUnaryOperator sortedFieldIdMapping = sortedMetadata.sortedFieldIdMapping(); + remappers.forEach(remapper -> remapper.finalize(sortedFieldIdMapping)); + + // allocate the output slice + int totalSize = remappers.stream() + .mapToInt(VariantFieldRemapper::size) + .sum(); + Slice output = Slices.allocate(encodedArraySize(elements.size(), totalSize)); + + // write array header + int written = encodeArrayHeading(elements.size(), index -> remappers.get(index).size(), output, 0); + + // write remapped variants + for (VariantFieldRemapper remapper : remappers) { + remapper.write(output, written); + written += remapper.size(); + } + verify(written == output.length(), "Encoded size does not match expected size"); + return new Variant(output, sortedMetadata.metadata(), BasicType.ARRAY, null); + } + + public static Variant ofObject(Map<Slice, Variant> fields) + { + if (fields.isEmpty()) { + return EMPTY_OBJECT; + } + + Metadata.Builder metadataBuilder = Metadata.builder(); + int[] fieldIds = fields.keySet().stream().mapToInt(metadataBuilder::addFieldName).toArray(); + List<VariantFieldRemapper> remappers = fields.values().stream().map(variant -> VariantFieldRemapper.create(variant, metadataBuilder)).toList(); + + // finalize the metadata and remappers + SortedMetadata sortedMetadata = metadataBuilder.buildSorted(); + IntUnaryOperator sortedFieldIdMapping = sortedMetadata.sortedFieldIdMapping(); + for (int i = 0; i < fieldIds.length; i++) { + fieldIds[i] = sortedFieldIdMapping.applyAsInt(fieldIds[i]); + } + remappers.forEach(remapper -> remapper.finalize(sortedFieldIdMapping)); + + // determine write order (objects must be written in
lexicographical field-name order) + // the high 32 bits are the fieldId, the low 32 bits are the original index + long[] writeOrder = computeWriteOrder(fieldIds); + for (int writeIndex = 1; writeIndex < writeOrder.length; writeIndex++) { + verify(writeOrderFieldId(writeOrder, writeIndex) > writeOrderFieldId(writeOrder, writeIndex - 1), "Duplicate field IDs are not allowed in VARIANT objects"); + } + + // allocate the output slice + int totalSize = remappers.stream() + .mapToInt(VariantFieldRemapper::size) + .sum(); + int maxFieldId = writeOrderFieldId(writeOrder, writeOrder.length - 1); + Slice output = Slices.allocate(encodedObjectSize(maxFieldId, fieldIds.length, totalSize)); + + // write object header + int written = encodeObjectHeading( + fieldIds.length, + i -> writeOrderFieldId(writeOrder, i), + i -> remappers.get(writeOrderOriginalIndex(writeOrder, i)).size(), + output, + 0); + + // write remapped variants + for (int i = 0; i < writeOrder.length; i++) { + VariantFieldRemapper remapper = remappers.get(writeOrderOriginalIndex(writeOrder, i)); + remapper.write(output, written); + written += remapper.size(); + } + verify(written == output.length(), "Encoded size does not match expected size"); + return new Variant(output, sortedMetadata.metadata(), OBJECT, null); + } + + public boolean isNull() + { + return primitiveType == PrimitiveType.NULL; + } + + public boolean getBoolean() + { + if (primitiveType == PrimitiveType.BOOLEAN_TRUE) { + return true; + } + if (primitiveType == PrimitiveType.BOOLEAN_FALSE) { + return false; + } + throw new IllegalStateException("Expected primitive BOOLEAN but got " + primitiveType); + } + + public byte getByte() + { + verifyType(PrimitiveType.INT8); + return data.getByte(1); + } + + public short getShort() + { + verifyType(PrimitiveType.INT16); + return data.getShort(1); + } + + public int getInt() + { + verifyType(PrimitiveType.INT32); + return data.getInt(1); + } + + public long getLong() + { + verifyType(PrimitiveType.INT64); + return data.getLong(1); + } + + public BigDecimal getDecimal() + { + checkState(primitiveType == PrimitiveType.DECIMAL4 || + primitiveType == PrimitiveType.DECIMAL8 || + primitiveType == PrimitiveType.DECIMAL16, + () -> "Expected DECIMAL primitive but got %s".formatted(requireNonNullElse(primitiveType, basicType))); + + int scale = data.getByte(1); + checkState(scale >= 0 && scale <= 38, () -> "Corrupt DECIMAL scale: %s".formatted(scale)); + + return switch (primitiveType) { + case DECIMAL4 -> { + int unscaled = data.getInt(2); + yield BigDecimal.valueOf(unscaled, scale); + } + case DECIMAL8 -> { + long unscaled = data.getLong(2); + yield BigDecimal.valueOf(unscaled, scale); + } + case DECIMAL16 -> { + // 16-byte little-endian two's complement → BigInteger (big-endian) + byte[] bigEndian = new byte[16]; + for (int i = 0; i < 16; i++) { + bigEndian[15 - i] = data.getByte(2 + i); + } + BigInteger unscaled = new BigInteger(bigEndian); + yield new BigDecimal(unscaled, scale); + } + default -> throw new VerifyException("Expected DECIMAL primitive but got " + requireNonNullElse(primitiveType, basicType)); + }; + } + + public float getFloat() + { + verifyType(PrimitiveType.FLOAT); + return data.getFloat(1); + } + + public double getDouble() + { + verifyType(PrimitiveType.DOUBLE); + return data.getDouble(1); + } + + public int getDate() + { + verifyType(PrimitiveType.DATE); + return data.getInt(1); + } + + public LocalDate getLocalDate() + { + return LocalDate.ofEpochDay(getDate()); + } + + public long getTimeMicros() + { + 
checkState(primitiveType == PrimitiveType.TIME_NTZ_MICROS, + () -> "Expected primitive TIME in microseconds but got %s".formatted(primitiveType)); + return data.getLong(1); + } + + public LocalTime getLocalTime() + { + verifyType(PrimitiveType.TIME_NTZ_MICROS); + long microsOfDay = getTimeMicros(); + long nanoOfDay = multiplyExact(microsOfDay, 1_000L); + return LocalTime.ofNanoOfDay(nanoOfDay); + } + + public long getTimestampMicros() + { + checkState(primitiveType == PrimitiveType.TIMESTAMP_UTC_MICROS || primitiveType == PrimitiveType.TIMESTAMP_NTZ_MICROS, + () -> "Expected primitive TIMESTAMP in microseconds but got %s".formatted(primitiveType)); + return data.getLong(1); + } + + public long getTimestampNanos() + { + checkState(primitiveType == PrimitiveType.TIMESTAMP_UTC_NANOS || primitiveType == PrimitiveType.TIMESTAMP_NTZ_NANOS, + () -> "Expected primitive TIMESTAMP in nanoseconds but got %s".formatted(primitiveType)); + return data.getLong(1); + } + + public Instant getInstant() + { + long seconds; + int nanoOfSecond; + if (primitiveType == PrimitiveType.TIMESTAMP_UTC_MICROS) { + long micros = getTimestampMicros(); + seconds = Math.floorDiv(micros, 1_000_000); + nanoOfSecond = toIntExact(Math.floorMod(micros, 1_000_000) * 1_000L); + } + else if (primitiveType == PrimitiveType.TIMESTAMP_UTC_NANOS) { + long nanos = getTimestampNanos(); + seconds = Math.floorDiv(nanos, 1_000_000_000L); + nanoOfSecond = (int) Math.floorMod(nanos, 1_000_000_000L); + } + else { + throw new IllegalStateException("Expected primitive TIMESTAMP but got " + primitiveType); + } + return Instant.ofEpochSecond(seconds, nanoOfSecond); + } + + public LocalDateTime getLocalDateTime() + { + long seconds; + int nanoOfSecond; + if (primitiveType == PrimitiveType.TIMESTAMP_NTZ_MICROS) { + long micros = getTimestampMicros(); + seconds = Math.floorDiv(micros, 1_000_000); + nanoOfSecond = toIntExact(Math.floorMod(micros, 1_000_000) * 1_000L); + } + else if (primitiveType == PrimitiveType.TIMESTAMP_NTZ_NANOS) { + long nanos = getTimestampNanos(); + seconds = Math.floorDiv(nanos, 1_000_000_000L); + nanoOfSecond = (int) Math.floorMod(nanos, 1_000_000_000L); + } + else { + throw new IllegalStateException("Expected primitive TIMESTAMP but got " + primitiveType); + } + return LocalDateTime.ofEpochSecond(seconds, nanoOfSecond, ZoneOffset.UTC); + } + + public Slice getBinary() + { + verifyType(PrimitiveType.BINARY); + int length = data.getInt(1); + return data.slice(5, length); + } + + public Slice getString() + { + if (basicType == BasicType.SHORT_STRING) { + int length = shortStringLength(data.getByte(0)); + return data.slice(1, length); + } + + verifyType(PrimitiveType.STRING); + int length = data.getInt(1); + return data.slice(5, length); + } + + public String getStringUtf8() + { + return getString().toStringUtf8(); + } + + public Slice getUuidSlice() + { + verifyType(PrimitiveType.UUID); + return data.slice(1, 16); + } + + public UUID getUuid() + { + verifyType(PrimitiveType.UUID); + // UUID is 16-byte big-endian + long mostSigBits = Long.reverseBytes(data.getLong(1)); + long leastSigBits = Long.reverseBytes(data.getLong(9)); + return new UUID(mostSigBits, leastSigBits); + } + + public int getArrayLength() + { + verifyType(BasicType.ARRAY); + int count = arrayIsLarge(data.getByte(0)) ? 
data.getInt(1) : (data.getByte(1) & 0xFF); + checkState(count >= 0, () -> "Corrupt array count: " + count); + return count; + } + + public Variant getArrayElement(int index) + { + verifyType(BasicType.ARRAY); + byte header = data.getByte(0); + boolean large = arrayIsLarge(header); + int offSize = arrayFieldOffsetSize(header); + + int count = large ? data.getInt(1) : (data.getByte(1) & 0xFF); + Objects.checkIndex(index, count); + + int offsetsStart = 1 + (large ? 4 : 1); + int valuesStart = offsetsStart + (count + 1) * offSize; + + int offsetPosition = offsetsStart + (index * offSize); + int start = valuesStart + readOffset(data, offsetPosition, offSize); + int end = valuesStart + readOffset(data, offsetPosition + offSize, offSize); + return from(metadata, data.slice(start, end - start)); + } + + public Stream<Variant> arrayElements() + { + verifyType(BasicType.ARRAY); + + byte header = data.getByte(0); + boolean large = arrayIsLarge(header); + int offsetSize = arrayFieldOffsetSize(header); + + int count = large ? data.getInt(1) : (data.getByte(1) & 0xFF); + + int offsetsStart = 1 + (large ? 4 : 1); + int valuesStart = offsetsStart + (count + 1) * offsetSize; + + return IntStream.range(0, count) + .mapToObj(index -> { + int offsetPosition = offsetsStart + (index * offsetSize); + int start = valuesStart + readOffset(data, offsetPosition, offsetSize); + int end = valuesStart + readOffset(data, offsetPosition + offsetSize, offsetSize); + return from(metadata, data.slice(start, end - start)); + }); + } + + public int getObjectFieldCount() + { + verifyType(OBJECT); + int count = objectIsLarge(data.getByte(0)) ? data.getInt(1) : (data.getByte(1) & 0xFF); + checkState(count >= 0, () -> "Corrupt object field count: " + count); + return count; + } + + public Optional<Variant> getObjectField(int fieldId) + { + verifyType(OBJECT); + checkArgument(fieldId >= 0 && fieldId < metadata.dictionarySize(), + () -> "Invalid fieldId %d, valid range is [0, %d)".formatted(fieldId, metadata.dictionarySize())); + + // This intentionally avoids decoding the full ObjectLayout. Most production lookups come + // from dereference and read a single field from a short-lived Variant instance, so paying + // the full object traversal setup cost up front would be counterproductive here. + byte header = data.getByte(0); + boolean large = objectIsLarge(header); + int idSize = objectFieldIdSize(header); + int offsetSize = objectFieldOffsetSize(header); + int count = large ? data.getInt(1) : (data.getByte(1) & 0xFF); + checkState(count >= 0, () -> "Corrupt object field count: " + count); + int idsStart = 1 + (large ? 4 : 1); + int offsetsStart = idsStart + count * idSize; + int valuesStart = offsetsStart + (count + 1) * offsetSize; + + for (int i = 0; i < count; i++) { + int currentFieldId = readOffset(data, idsStart + i * idSize, idSize); + if (currentFieldId == fieldId) { + int startOffset = readOffset(data, offsetsStart + i * offsetSize, offsetSize); + int start = valuesStart + startOffset; + return Optional.of(from(metadata, data.slice(start, VariantDecoder.valueSize(data, start)))); + } + } + return Optional.empty(); + } + + public Optional<Variant> getObjectField(Slice fieldName) + { + verifyType(OBJECT); + + byte header = data.getByte(0); + boolean large = objectIsLarge(header); + int idSize = objectFieldIdSize(header); + int offsetSize = objectFieldOffsetSize(header); + int count = large ? data.getInt(1) : (data.getByte(1) & 0xFF); + checkState(count >= 0, () -> "Corrupt object field count: " + count); + int idsStart = 1 + (large ?
4 : 1); + int offsetsStart = idsStart + count * idSize; + int valuesStart = offsetsStart + (count + 1) * offsetSize; + int fieldIndex = findFieldIndex(fieldName, metadata, data, count, idsStart, idSize); + if (fieldIndex < 0) { + return Optional.empty(); + } + + int startOffset = readOffset(data, offsetsStart + fieldIndex * offsetSize, offsetSize); + int start = valuesStart + startOffset; + return Optional.of(from(metadata, data.slice(start, VariantDecoder.valueSize(data, start)))); + } + + public Stream<Slice> objectFieldNames() + { + verifyType(OBJECT); + + byte header = data.getByte(0); + boolean large = objectIsLarge(header); + int idSize = objectFieldIdSize(header); + + int count = large ? data.getInt(1) : (data.getByte(1) & 0xFF); + + int idsStart = 1 + (large ? 4 : 1); + + return IntStream.range(0, count) + .mapToObj(i -> metadata.get(readOffset(data, idsStart + i * idSize, idSize))); + } + + public Stream<Variant> objectValues() + { + verifyType(OBJECT); + ObjectLayout layout = (ObjectLayout) VariantDecoder.decode(data, 0); + + return IntStream.range(0, layout.count()) + .mapToObj(i -> { + int start = layout.valueStart(i); + int end = layout.valueEnd(i); + return from(metadata, data.slice(start, end - start)); + }); + } + + public Stream<ObjectFieldIdValue> objectFields() + { + verifyType(OBJECT); + ObjectLayout layout = (ObjectLayout) VariantDecoder.decode(data, 0); + + return IntStream.range(0, layout.count()) + .mapToObj(i -> { + int fieldId = layout.fieldId(i); + int start = layout.valueStart(i); + int end = layout.valueEnd(i); + Variant value = from(metadata, data.slice(start, end - start)); + + return new ObjectFieldIdValue(fieldId, value); + }); + } + + @Override + public boolean equals(Object other) + { + if (this == other) { + return true; + } + if (!(other instanceof Variant rightValue)) { + return false; + } + return VariantUtils.equals( + metadata, data, 0, + rightValue.metadata, rightValue.data, 0); + } + + @Override + public int hashCode() + { + return Long.hashCode(longHashCode()); + } + + public long longHashCode() + { + return VariantUtils.hashCode(metadata, data, 0); + } + + /// Converts this Variant into a plain Java object graph. + /// + /// This method is intended for debugging and testing. + /// The returned structure is composed of standard Java types and + /// **unmodifiable** containers.
+ /// + /// - Objects become {@code Map<String, Object>} + /// - Arrays become {@code List<Object>} + /// - Nested {@code Variant} values are recursively converted + /// + /// ## Mapping + /// + /// | Variant kind | Variant type | Java type returned | + /// |-------------|--------------|--------------------| + /// | PRIMITIVE | NULL | {@code null} | + /// | PRIMITIVE | BOOLEAN_TRUE / BOOLEAN_FALSE | {@code Boolean} | + /// | PRIMITIVE | INT8 | {@code Byte} | + /// | PRIMITIVE | INT16 | {@code Short} | + /// | PRIMITIVE | INT32 | {@code Integer} | + /// | PRIMITIVE | INT64 | {@code Long} | + /// | PRIMITIVE | FLOAT | {@code Float} | + /// | PRIMITIVE | DOUBLE | {@code Double} | + /// | PRIMITIVE | DECIMAL4 / DECIMAL8 / DECIMAL16 | {@code BigDecimal} | + /// | PRIMITIVE | DATE | {@code LocalDate} | + /// | PRIMITIVE | TIME_NTZ_MICROS | {@code LocalTime} | + /// | PRIMITIVE | TIMESTAMP_UTC_MICROS / TIMESTAMP_UTC_NANOS | {@code Instant} | + /// | PRIMITIVE | TIMESTAMP_NTZ_MICROS / TIMESTAMP_NTZ_NANOS | {@code LocalDateTime} | + /// | PRIMITIVE | BINARY | {@code Slice} | + /// | PRIMITIVE | STRING | {@code String} | + /// | PRIMITIVE | UUID | {@code UUID} | + /// | SHORT_STRING | SHORT_STRING | {@code String} | + /// | OBJECT | OBJECT | {@code Map<String, Object>} | + /// | ARRAY | ARRAY | {@code List<Object>} | + /// + public Object toObject() + { + return switch (basicType()) { + case PRIMITIVE -> switch (primitiveType()) { + case NULL -> null; + case BOOLEAN_TRUE -> true; + case BOOLEAN_FALSE -> false; + case INT8 -> getByte(); + case INT16 -> getShort(); + case INT32 -> getInt(); + case INT64 -> getLong(); + case DOUBLE -> getDouble(); + case DECIMAL4, DECIMAL8, DECIMAL16 -> getDecimal(); + case DATE -> getLocalDate(); + case TIMESTAMP_UTC_MICROS, TIMESTAMP_UTC_NANOS -> getInstant(); + case TIMESTAMP_NTZ_MICROS, TIMESTAMP_NTZ_NANOS -> getLocalDateTime(); + case FLOAT -> getFloat(); + case BINARY -> getBinary(); + case STRING -> getStringUtf8(); + case TIME_NTZ_MICROS -> getLocalTime(); + case UUID -> getUuid(); + }; + case SHORT_STRING -> getStringUtf8(); + case OBJECT -> { + // values can be null, so we can't use the simple toMap collectors + Map<String, Object> map = new HashMap<>(getObjectFieldCount()); + objectFields().forEach(field -> map.put(metadata().get(field.fieldId()).toStringUtf8(), field.value().toObject())); + yield unmodifiableMap(map); + } + case ARRAY -> { + // values can be null, so we can't use the simple toList collectors + List<Object> array = new ArrayList<>(getArrayLength()); + arrayElements().map(Variant::toObject).forEach(array::add); + yield unmodifiableList(array); + } + }; + } + + /// Creates a {@link Variant} from a plain Java object graph. + /// + /// This method is intended for debugging and testing. + /// It performs multiple full-tree passes and is **not** optimized for performance. + /// + /// Supported container types are traversed recursively, and a single shared + /// metadata dictionary is constructed for all objects in the graph.
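+ /// For example (illustrative): {@code fromObject(Map.of("b", 1, "a", List.of("x")))} yields a VARIANT object whose shared metadata dictionary contains the field names {@code a} and {@code b}.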
+ /// + /// ## Supported input types + /// + /// - Objects become VARIANT objects + /// - Lists become VARIANT arrays + /// - Nested {@code Variant} values are rewritten into the resulting metadata + /// + /// ## Mapping + /// + /// | Java type | Variant kind | Variant type | + /// |----------|--------------|--------------| + /// | {@code null} | PRIMITIVE | NULL | + /// | {@code Boolean} | PRIMITIVE | BOOLEAN_TRUE / BOOLEAN_FALSE | + /// | {@code Byte} | PRIMITIVE | INT8 | + /// | {@code Short} | PRIMITIVE | INT16 | + /// | {@code Integer} | PRIMITIVE | INT32 | + /// | {@code Long} | PRIMITIVE | INT64 | + /// | {@code Float} | PRIMITIVE | FLOAT | + /// | {@code Double} | PRIMITIVE | DOUBLE | + /// | {@code BigDecimal} | PRIMITIVE | DECIMAL4 / DECIMAL8 / DECIMAL16 | + /// | {@code LocalDate} | PRIMITIVE | DATE | + /// | {@code LocalTime} | PRIMITIVE | TIME_NTZ_MICROS | + /// | {@code Instant} | PRIMITIVE | TIMESTAMP_UTC_NANOS | + /// | {@code LocalDateTime} | PRIMITIVE | TIMESTAMP_NTZ_NANOS | + /// | {@code UUID} | PRIMITIVE | UUID | + /// | {@code Slice} | PRIMITIVE | BINARY | + /// | {@code byte[]} | PRIMITIVE | BINARY | + /// | {@code String} | PRIMITIVE / SHORT_STRING | STRING | + /// | {@code Map} | OBJECT | OBJECT | + /// | {@code List} | ARRAY | ARRAY | + /// | {@code Variant} | *rewritten* | *preserved semantics* | + /// + /// @throws IllegalArgumentException if an unsupported Java type is encountered + /// @throws IllegalArgumentException if a map key is {@code null} or not a {@code String} or {@code Slice} + public static Variant fromObject(Object value) + { + if (value == null) { + return NULL_VALUE; + } + + // Pass 1: collect all field names across the tree + Metadata.Builder metadataBuilder = Metadata.builder(); + IdentityHashMap<Variant, VariantFieldRemapper> fieldRemappers = new IdentityHashMap<>(); + collectFieldNames(value, metadataBuilder, fieldRemappers); + SortedMetadata sortedMetadata = metadataBuilder.buildSorted(); + + fieldRemappers.values().forEach(remapper -> remapper.finalize(sortedMetadata.sortedFieldIdMapping())); + Map<String, Integer> fieldIdByName = new HashMap<>(sortedMetadata.metadata().dictionarySize()); + for (int i = 0; i < sortedMetadata.metadata().dictionarySize(); i++) { + fieldIdByName.put(sortedMetadata.metadata().get(i).toStringUtf8(), i); + } + + // Pass 2: compute total size + IdentityHashMap<Object, Integer> containerSizeCache = new IdentityHashMap<>(); + int totalSize = computeEncodedSize(value, fieldIdByName, fieldRemappers, containerSizeCache); + + // Pass 3: write + Slice data = Slices.allocate(totalSize); + int written = writeEncoded(value, fieldIdByName, data, 0, fieldRemappers, containerSizeCache); + verify(written == totalSize, "Encoded size does not match expected size"); + return from(sortedMetadata.metadata(), data); + } + + private static void collectFieldNames(Object value, Metadata.Builder metadataBuilder, IdentityHashMap<Variant, VariantFieldRemapper> fieldRemappers) + { + switch (value) { + case null -> {} + case Variant v -> fieldRemappers.computeIfAbsent(v, _ -> VariantFieldRemapper.create(v, metadataBuilder)); + case Map<?, ?> map -> { + for (Object key : map.keySet()) { + metadataBuilder.addFieldName(utf8Slice(castMapKey(key))); + } + for (Object child : map.values()) { + collectFieldNames(child, metadataBuilder, fieldRemappers); + } + } + case List<?> list -> { + for (Object child : list) { + collectFieldNames(child, metadataBuilder, fieldRemappers); + } + } + default -> { + // primitives/leaf types: nothing to collect + } + } + } + + private static int computeEncodedSize( + Object value, + Map<String, Integer> fieldIdByName, + IdentityHashMap<Variant, VariantFieldRemapper> fieldRemappers,
+ IdentityHashMap<Object, Integer> containerSizeCache) + { + return switch (value) { + case null -> ENCODED_NULL_SIZE; + case Variant v -> requireNonNull(fieldRemappers.get(v), "missing remapper").size(); + case Boolean _ -> ENCODED_BOOLEAN_SIZE; + case Byte _ -> ENCODED_BYTE_SIZE; + case Short _ -> ENCODED_SHORT_SIZE; + case Integer _ -> ENCODED_INT_SIZE; + case Long _ -> ENCODED_LONG_SIZE; + case Float _ -> ENCODED_FLOAT_SIZE; + case Double _ -> ENCODED_DOUBLE_SIZE; + case BigDecimal decimal -> encodedDecimalSize(decimal); + case LocalDate _ -> ENCODED_DATE_SIZE; + case LocalTime _ -> ENCODED_TIME_SIZE; + case Instant _, LocalDateTime _ -> ENCODED_TIMESTAMP_SIZE; + case UUID _ -> ENCODED_UUID_SIZE; + case Slice slice -> encodedBinarySize(slice.length()); + case byte[] bytes -> encodedBinarySize(bytes.length); + case String s -> encodedStringSize(utf8Slice(s).length()); + case List<?> list -> containerSizeCache.computeIfAbsent(list, _ -> { + int totalElementsLength = 0; + for (Object element : list) { + totalElementsLength += computeEncodedSize(element, fieldIdByName, fieldRemappers, containerSizeCache); + } + return encodedArraySize(list.size(), totalElementsLength); + }); + case Map<?, ?> map -> containerSizeCache.computeIfAbsent(map, _ -> { + if (map.isEmpty()) { + return ENCODED_EMPTY_OBJECT_SIZE; + } + + int maxFieldId = -1; + for (Object key : map.keySet()) { + maxFieldId = Math.max(maxFieldId, requireNonNull(fieldIdByName.get(castMapKey(key)))); + } + + int totalValuesLength = 0; + for (Object entry : map.values()) { + totalValuesLength += computeEncodedSize(entry, fieldIdByName, fieldRemappers, containerSizeCache); + } + + return encodedObjectSize(maxFieldId, map.size(), totalValuesLength); + }); + default -> throw new IllegalArgumentException("Unsupported object type for VARIANT: " + value.getClass().getName()); + }; + } + + private static int writeEncoded( + Object value, + Map<String, Integer> fieldIdByName, + Slice out, + int offset, + IdentityHashMap<Variant, VariantFieldRemapper> fieldRemappers, + IdentityHashMap<Object, Integer> containerSizeCache) + { + return switch (value) { + case null -> encodeNull(out, offset); + case Variant v -> requireNonNull(fieldRemappers.get(v), "missing remapper").write(out, offset); + case Boolean v -> encodeBoolean(v, out, offset); + case Byte v -> encodeByte(v, out, offset); + case Short v -> encodeShort(v, out, offset); + case Integer v -> encodeInt(v, out, offset); + case Long v -> encodeLong(v, out, offset); + case Float v -> encodeFloat(v, out, offset); + case Double v -> encodeDouble(v, out, offset); + case BigDecimal decimal -> encodeDecimal(decimal, out, offset); + case LocalDate date -> encodeDate(date, out, offset); + case LocalTime time -> encodeTimeMicrosNtz(time, out, offset); + case Instant instant -> encodeTimestampNanosUtc(instant, out, offset); + case LocalDateTime dateTime -> encodeTimestampNanosNtz(dateTime, out, offset); + case UUID uuid -> encodeUuid(uuid, out, offset); + case Slice slice -> encodeBinary(slice, out, offset); + case byte[] bytes -> encodeBinary(Slices.wrappedBuffer(bytes), out, offset); + case String string -> encodeString(utf8Slice(string), out, offset); + case List<?> list -> { + int written = encodeArrayHeading( + list.size(), + i -> computeEncodedSize(list.get(i), fieldIdByName, fieldRemappers, containerSizeCache), + out, + offset); + for (Object element : list) { + written += writeEncoded(element, fieldIdByName, out, offset + written, fieldRemappers, containerSizeCache); + } + yield written; + } + case Map<?, ?> map -> { + int count = map.size(); + if (count == 0) { + // empty object (header only) +
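// 3 bytes total: header, zero field count, and the required first offset (ENCODED_EMPTY_OBJECT_SIZE) +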
yield encodeObjectHeading(0, _ -> 0, _ -> 0, out, offset); + } + + int[] fieldIds = new int[count]; + Object[] values = new Object[count]; + + int i = 0; + for (Map.Entry<?, ?> entry : map.entrySet()) { + fieldIds[i] = requireNonNull(fieldIdByName.get(castMapKey(entry.getKey()))); + values[i] = entry.getValue(); + i++; + } + + // the high 32 bits are the fieldId, the low 32 bits are the original index + long[] writeOrder = computeWriteOrder(fieldIds); + for (int writeIndex = 1; writeIndex < writeOrder.length; writeIndex++) { + verify(writeOrderFieldId(writeOrder, writeIndex) > writeOrderFieldId(writeOrder, writeIndex - 1), "Duplicate field IDs are not allowed in VARIANT objects"); + } + int written = encodeObjectHeading( + count, + index -> writeOrderFieldId(writeOrder, index), + index -> computeEncodedSize(values[writeOrderOriginalIndex(writeOrder, index)], fieldIdByName, fieldRemappers, containerSizeCache), + out, + offset); + + for (int writeOrderIndex = 0; writeOrderIndex < writeOrder.length; writeOrderIndex++) { + written += writeEncoded(values[writeOrderOriginalIndex(writeOrder, writeOrderIndex)], fieldIdByName, out, offset + written, fieldRemappers, containerSizeCache); + } + yield written; + } + default -> throw new IllegalArgumentException("Unsupported object type for VARIANT: " + value.getClass().getName()); + }; + } + + /// Returns an ordering of fields for writing. The result is packed into a single + /// long per field, where the high 32 bits are the fieldId and the low 32 bits are the + /// original index. + private static long[] computeWriteOrder(int[] fieldIds) + { + long[] order = new long[fieldIds.length]; + for (int i = 0; i < fieldIds.length; i++) { + order[i] = (((long) fieldIds[i]) << 32) | (i & 0xffff_ffffL); + } + Arrays.sort(order); + return order; + } + + private static int writeOrderOriginalIndex(long[] writeOrder, int writeOrderIndex) + { + return (int) writeOrder[writeOrderIndex]; + } + + private static int writeOrderFieldId(long[] writeOrder, int writeOrderIndex) + { + return (int) (writeOrder[writeOrderIndex] >>> 32); + } + + private static String castMapKey(Object key) + { + return switch (key) { + case null -> throw new IllegalArgumentException("Map key is null"); + case String name -> name; + case Slice name -> name.toStringUtf8(); + default -> throw new IllegalArgumentException("Map key must be a String or Slice: " + key.getClass().getName()); + }; + } + + private void verifyType(BasicType expected) + { + checkState(basicType == expected, () -> "Expected basic %s but got %s".formatted(expected, basicType)); + } + + private void verifyType(PrimitiveType expected) + { + checkState(primitiveType == expected, () -> "Expected primitive %s but got %s".formatted(expected, requireNonNullElse(primitiveType, basicType))); + } + + @Override + public String toString() + { + return "Variant[" + + "basicType=" + basicType + ", " + + "primitiveType=" + primitiveType + ']'; + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/variant/VariantDecoder.java b/core/trino-spi/src/main/java/io/trino/spi/variant/VariantDecoder.java new file mode 100644 index 000000000000..77b96d55dc0c --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/variant/VariantDecoder.java @@ -0,0 +1,227 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.variant; + +import io.airlift.slice.Slice; + +import java.util.Arrays; + +import static io.trino.spi.variant.Header.arrayFieldOffsetSize; +import static io.trino.spi.variant.Header.arrayIsLarge; +import static io.trino.spi.variant.Header.getPrimitiveType; +import static io.trino.spi.variant.Header.objectFieldIdSize; +import static io.trino.spi.variant.Header.objectFieldOffsetSize; +import static io.trino.spi.variant.Header.objectIsLarge; +import static io.trino.spi.variant.Header.shortStringLength; +import static io.trino.spi.variant.VariantEncoder.ENCODED_BOOLEAN_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_BYTE_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_DATE_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_DECIMAL16_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_DECIMAL4_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_DECIMAL8_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_DOUBLE_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_FLOAT_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_INT_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_LONG_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_NULL_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_SHORT_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_TIMESTAMP_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_TIME_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_UUID_SIZE; +import static io.trino.spi.variant.VariantUtils.checkState; +import static io.trino.spi.variant.VariantUtils.readOffset; + +public final class VariantDecoder +{ + private VariantDecoder() {} + + public static int valueSize(Slice data, int offset) + { + byte header = data.getByte(offset); + return switch (Header.getBasicType(header)) { + case PRIMITIVE -> switch (getPrimitiveType(header)) { + case NULL -> ENCODED_NULL_SIZE; + case BOOLEAN_TRUE, BOOLEAN_FALSE -> ENCODED_BOOLEAN_SIZE; + case INT8 -> ENCODED_BYTE_SIZE; + case INT16 -> ENCODED_SHORT_SIZE; + case INT32 -> ENCODED_INT_SIZE; + case INT64 -> ENCODED_LONG_SIZE; + case DOUBLE -> ENCODED_DOUBLE_SIZE; + case DECIMAL4 -> ENCODED_DECIMAL4_SIZE; + case DECIMAL8 -> ENCODED_DECIMAL8_SIZE; + case DECIMAL16 -> ENCODED_DECIMAL16_SIZE; + case DATE -> ENCODED_DATE_SIZE; + case TIMESTAMP_UTC_MICROS, TIMESTAMP_NTZ_MICROS, TIMESTAMP_UTC_NANOS, TIMESTAMP_NTZ_NANOS -> ENCODED_TIMESTAMP_SIZE; + case FLOAT -> ENCODED_FLOAT_SIZE; + case BINARY, STRING -> 5 + data.getInt(offset + 1); + case TIME_NTZ_MICROS -> ENCODED_TIME_SIZE; + case UUID -> ENCODED_UUID_SIZE; + }; + case SHORT_STRING -> 1 + shortStringLength(header); + case ARRAY -> arraySize(data, offset, header); + case OBJECT -> objectSize(data, offset, header); + }; + } + + private static int arraySize(Slice data, int offset, byte header) + { + boolean large = arrayIsLarge(header); + int offsetSize = arrayFieldOffsetSize(header); + int count = large ? 
data.getInt(offset + 1) : (data.getByte(offset + 1) & 0xFF); + int offsetsStart = offset + 1 + (large ? 4 : 1); + int valuesStart = offsetsStart + (count + 1) * offsetSize; + return valuesStart - offset + readOffset(data, offsetsStart + count * offsetSize, offsetSize); + } + + private static int objectSize(Slice data, int offset, byte header) + { + boolean large = objectIsLarge(header); + int offsetSize = objectFieldOffsetSize(header); + int idSize = objectFieldIdSize(header); + int count = large ? data.getInt(offset + 1) : (data.getByte(offset + 1) & 0xFF); + int idsStart = offset + 1 + (large ? 4 : 1); + int offsetsStart = idsStart + count * idSize; + int valuesStart = offsetsStart + (count + 1) * offsetSize; + return valuesStart - offset + readOffset(data, offsetsStart + count * offsetSize, offsetSize); + } + + public static VariantLayout decode(Slice data, int offset) + { + byte header = data.getByte(offset); + return switch (Header.getBasicType(header)) { + case PRIMITIVE, SHORT_STRING -> PrimitiveLayout.PRIMITIVE; + case ARRAY -> ArrayLayout.decode(data, offset, header); + case OBJECT -> ObjectLayout.decode(data, offset, header); + }; + } + + public sealed interface VariantLayout + permits PrimitiveLayout, ArrayLayout, ObjectLayout {} + + public enum PrimitiveLayout + implements VariantLayout + { + PRIMITIVE + } + + public record ArrayLayout( + Slice data, + int offset, + int count, + int offsetSize, + int offsetsStart, + int valuesStart) + implements VariantLayout + { + static ArrayLayout decode(Slice data, int offset, byte header) + { + boolean large = arrayIsLarge(header); + int offsetSize = arrayFieldOffsetSize(header); + + int count = large ? data.getInt(offset + 1) : (data.getByte(offset + 1) & 0xFF); + int offsetsStart = offset + 1 + (large ? 4 : 1); + int valuesStart = offsetsStart + (count + 1) * offsetSize; + + return new ArrayLayout(data, offset, count, offsetSize, offsetsStart, valuesStart); + } + + int headerSize() + { + return valuesStart - offset; + } + + int elementStart(int index) + { + int offsetPosition = offsetsStart + index * offsetSize; + return valuesStart + readOffset(data, offsetPosition, offsetSize); + } + + int elementEnd(int index) + { + int offsetPosition = offsetsStart + (index + 1) * offsetSize; + return valuesStart + readOffset(data, offsetPosition, offsetSize); + } + } + + public record ObjectLayout( + Slice data, + int count, + int idSize, + int offsetSize, + int idsStart, + int offsetsStart, + int valuesStart, + int[] valueOffsets, + int[] valueLengths) + implements VariantLayout + { + static ObjectLayout decode(Slice data, int offset, byte header) + { + boolean large = objectIsLarge(header); + int idSize = objectFieldIdSize(header); + int offsetSize = objectFieldOffsetSize(header); + + int count = large ? data.getInt(offset + 1) : (data.getByte(offset + 1) & 0xFF); + checkState(count >= 0, () -> "Corrupt object field count: " + count); + int idsStart = offset + 1 + (large ? 
4 : 1); + int offsetsStart = idsStart + count * idSize; + int valuesStart = offsetsStart + (count + 1) * offsetSize; + if (count == 0) { + return new ObjectLayout(data, count, idSize, offsetSize, idsStart, offsetsStart, valuesStart, new int[0], new int[0]); + } + + int[] valueOffsets = new int[count]; + int[] valueLengths = new int[count]; + for (int index = 0; index < count; index++) { + valueOffsets[index] = readOffset(data, offsetsStart + index * offsetSize, offsetSize); + } + + // Object encodings store field IDs in name order and store value start offsets, + // but they do not store value lengths or guarantee that the physical value order + // matches the field-ID order. To derive lengths for object traversal, readers must + // find the next larger offset for each field. That makes object decoding + // unavoidably more expensive than simple adjacent subtraction. + long[] sortedOffsetIndexPairs = new long[count]; + for (int index = 0; index < count; index++) { + sortedOffsetIndexPairs[index] = (((long) valueOffsets[index]) << Integer.SIZE) | index; + } + Arrays.sort(sortedOffsetIndexPairs); + + for (int index = count - 1, nextOffset = readOffset(data, offsetsStart + count * offsetSize, offsetSize); index >= 0; index--) { + long pair = sortedOffsetIndexPairs[index]; + int start = (int) (pair >>> Integer.SIZE); + int originalIndex = (int) pair; + valueLengths[originalIndex] = nextOffset - start; + nextOffset = start; + } + + return new ObjectLayout(data, count, idSize, offsetSize, idsStart, offsetsStart, valuesStart, valueOffsets, valueLengths); + } + + int fieldId(int index) + { + return readOffset(data, idsStart + index * idSize, idSize); + } + + int valueStart(int index) + { + return valuesStart + valueOffsets[index]; + } + + int valueEnd(int index) + { + return valuesStart + valueOffsets[index] + valueLengths[index]; + } + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/variant/VariantEncoder.java b/core/trino-spi/src/main/java/io/trino/spi/variant/VariantEncoder.java new file mode 100644 index 000000000000..e92fb14ccf99 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/variant/VariantEncoder.java @@ -0,0 +1,538 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.variant; + +import io.airlift.slice.Slice; +import io.trino.spi.type.Int128; +import io.trino.spi.variant.Header.PrimitiveType; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneOffset; +import java.util.List; +import java.util.UUID; +import java.util.function.IntFunction; +import java.util.function.IntUnaryOperator; + +import static io.trino.spi.variant.Header.PrimitiveType.BINARY; +import static io.trino.spi.variant.Header.PrimitiveType.BOOLEAN_FALSE; +import static io.trino.spi.variant.Header.PrimitiveType.BOOLEAN_TRUE; +import static io.trino.spi.variant.Header.PrimitiveType.DATE; +import static io.trino.spi.variant.Header.PrimitiveType.DECIMAL16; +import static io.trino.spi.variant.Header.PrimitiveType.DECIMAL4; +import static io.trino.spi.variant.Header.PrimitiveType.DECIMAL8; +import static io.trino.spi.variant.Header.PrimitiveType.DOUBLE; +import static io.trino.spi.variant.Header.PrimitiveType.FLOAT; +import static io.trino.spi.variant.Header.PrimitiveType.INT16; +import static io.trino.spi.variant.Header.PrimitiveType.INT32; +import static io.trino.spi.variant.Header.PrimitiveType.INT64; +import static io.trino.spi.variant.Header.PrimitiveType.INT8; +import static io.trino.spi.variant.Header.PrimitiveType.NULL; +import static io.trino.spi.variant.Header.PrimitiveType.STRING; +import static io.trino.spi.variant.Header.PrimitiveType.TIMESTAMP_NTZ_MICROS; +import static io.trino.spi.variant.Header.PrimitiveType.TIMESTAMP_NTZ_NANOS; +import static io.trino.spi.variant.Header.PrimitiveType.TIMESTAMP_UTC_MICROS; +import static io.trino.spi.variant.Header.PrimitiveType.TIMESTAMP_UTC_NANOS; +import static io.trino.spi.variant.Header.PrimitiveType.TIME_NTZ_MICROS; +import static io.trino.spi.variant.Header.SHORT_STRING_MAX_LENGTH; +import static io.trino.spi.variant.Header.arrayHeader; +import static io.trino.spi.variant.Header.objectHeader; +import static io.trino.spi.variant.Header.primitiveHeader; +import static io.trino.spi.variant.Header.shortStringHeader; +import static io.trino.spi.variant.VariantUtils.checkArgument; +import static io.trino.spi.variant.VariantUtils.getOffsetSize; +import static io.trino.spi.variant.VariantUtils.verify; +import static io.trino.spi.variant.VariantUtils.writeOffset; +import static java.lang.Math.max; + +public final class VariantEncoder +{ + public static final int ENCODED_NULL_SIZE = 1; + public static final int ENCODED_BOOLEAN_SIZE = 1; + public static final int ENCODED_BYTE_SIZE = 2; + public static final int ENCODED_SHORT_SIZE = 3; + public static final int ENCODED_INT_SIZE = 5; + public static final int ENCODED_LONG_SIZE = 9; + public static final int ENCODED_DECIMAL4_SIZE = 6; + public static final int ENCODED_DECIMAL8_SIZE = 10; + public static final int ENCODED_DECIMAL16_SIZE = 18; + public static final int ENCODED_FLOAT_SIZE = 5; + public static final int ENCODED_DOUBLE_SIZE = 9; + public static final int ENCODED_DATE_SIZE = 5; + public static final int ENCODED_TIME_SIZE = 9; + public static final int ENCODED_TIMESTAMP_SIZE = 9; + public static final int ENCODED_UUID_SIZE = 17; + public static final int ENCODED_EMPTY_OBJECT_SIZE = 3; + + private VariantEncoder() {} + + public static int encodeNull(Slice variant, int offset) + { + variant.setByte(offset, primitiveHeader(NULL)); + return ENCODED_NULL_SIZE; + } + + public static int encodeBoolean(boolean value, Slice 
variant, int offset) + { + if (value) { + variant.setByte(offset, primitiveHeader(BOOLEAN_TRUE)); + } + else { + variant.setByte(offset, primitiveHeader(BOOLEAN_FALSE)); + } + return ENCODED_BOOLEAN_SIZE; + } + + public static int encodeByte(byte value, Slice variant, int offset) + { + variant.setByte(offset, primitiveHeader(INT8)); + variant.setByte(offset + 1, value); + return ENCODED_BYTE_SIZE; + } + + public static int encodeShort(short value, Slice variant, int offset) + { + variant.setByte(offset, primitiveHeader(INT16)); + variant.setShort(offset + 1, value); + return ENCODED_SHORT_SIZE; + } + + public static int encodeInt(int value, Slice variant, int offset) + { + variant.setByte(offset, primitiveHeader(INT32)); + variant.setInt(offset + 1, value); + return ENCODED_INT_SIZE; + } + + public static int encodeLong(long value, Slice variant, int offset) + { + variant.setByte(offset, primitiveHeader(INT64)); + variant.setLong(offset + 1, value); + return ENCODED_LONG_SIZE; + } + + public static int encodedDecimalSize(BigDecimal value) + { + int precision = value.precision(); + if (precision <= 9) { + return ENCODED_DECIMAL4_SIZE; + } + else if (precision <= 18) { + return ENCODED_DECIMAL8_SIZE; + } + else if (precision <= 38) { + return ENCODED_DECIMAL16_SIZE; + } + throw new IllegalArgumentException("Decimal precision out of range: " + precision); + } + + public static int encodeDecimal(BigDecimal value, Slice data, int offset) + { + // We could improve the fit into int/long by checking the actual unscaled value, + // but this code matches the existing Iceberg logic. + int precision = value.precision(); + BigInteger unscaled = value.unscaledValue(); + int scale = value.scale(); + if (precision <= 9) { + return encodeDecimal4(unscaled.intValue(), scale, data, offset); + } + else if (precision <= 18) { + return encodeDecimal8(unscaled.longValue(), scale, data, offset); + } + else if (precision <= 38) { + return encodeDecimal16(unscaled, scale, data, offset); + } + throw new IllegalArgumentException("Decimal precision out of range: " + precision); + } + + public static int encodeDecimal4(int unscaled, int scale, Slice data, int offset) + { + validateDecimalScale(scale); + data.setByte(offset, primitiveHeader(DECIMAL4)); + data.setByte(offset + 1, (byte) scale); + data.setInt(offset + 2, unscaled); + return ENCODED_DECIMAL4_SIZE; + } + + public static int encodeDecimal8(long unscaled, int scale, Slice data, int offset) + { + validateDecimalScale(scale); + data.setByte(offset, primitiveHeader(DECIMAL8)); + data.setByte(offset + 1, (byte) scale); + data.setLong(offset + 2, unscaled); + return ENCODED_DECIMAL8_SIZE; + } + + public static int encodeDecimal16(BigInteger unscaled, int scale, Slice data, int offset) + { + long low = unscaled.longValue(); + long high; + try { + high = unscaled.shiftRight(64).longValueExact(); + } + catch (ArithmeticException e) { + throw new ArithmeticException("BigInteger out of Int128 range"); + } + + return encodeDecimal16(high, low, scale, data, offset); + } + + public static int encodeDecimal16(Int128 unscaled, int scale, Slice data, int offset) + { + return encodeDecimal16(unscaled.getHigh(), unscaled.getLow(), scale, data, offset); + } + + public static int encodeDecimal16(long high, long low, int scale, Slice data, int offset) + { + validateDecimalScale(scale); + + data.setByte(offset, primitiveHeader(DECIMAL16)); + data.setByte(offset + 1, (byte) scale); + + // int128 little-endian: low 64 bits first, then high 64 bits + data.setLong(offset + 2, low); + 
data.setLong(offset + 10, high); + + return ENCODED_DECIMAL16_SIZE; + } + + private static void validateDecimalScale(int scale) + { + checkArgument(scale >= 0 && scale <= 38, () -> "Invalid decimal scale: %s (expected 0..38)".formatted(scale)); + } + + public static int encodeFloat(float value, Slice variant, int offset) + { + variant.setByte(offset, primitiveHeader(FLOAT)); + variant.setFloat(offset + 1, value); + return ENCODED_FLOAT_SIZE; + } + + public static int encodeDouble(double value, Slice variant, int offset) + { + variant.setByte(offset, primitiveHeader(DOUBLE)); + variant.setDouble(offset + 1, value); + return ENCODED_DOUBLE_SIZE; + } + + public static int encodeDate(int value, Slice variant, int offset) + { + variant.setByte(offset, primitiveHeader(DATE)); + variant.setInt(offset + 1, value); + return ENCODED_DATE_SIZE; + } + + public static int encodeDate(LocalDate value, Slice variant, int offset) + { + return encodeDate((int) value.toEpochDay(), variant, offset); + } + + public static int encodeTimeMicrosNtz(long value, Slice variant, int offset) + { + variant.setByte(offset, primitiveHeader(TIME_NTZ_MICROS)); + variant.setLong(offset + 1, value); + return ENCODED_TIME_SIZE; + } + + public static int encodeTimeMicrosNtz(LocalTime value, Slice variant, int offset) + { + long nanoOfDay = value.toNanoOfDay(); + long microsOfDay = nanoOfDay / 1_000; + return encodeTimeMicrosNtz(microsOfDay, variant, offset); + } + + public static int encodeTimestampMicrosUtc(long value, Slice variant, int offset) + { + variant.setByte(offset, primitiveHeader(TIMESTAMP_UTC_MICROS)); + variant.setLong(offset + 1, value); + return ENCODED_TIMESTAMP_SIZE; + } + + public static int encodeTimestampMicrosUtc(Instant value, Slice variant, int offset) + { + long epochSecond = value.getEpochSecond(); + int nanoOfSecond = value.getNano(); + long epochMicros = epochSecond * 1_000_000 + (nanoOfSecond / 1_000); + return encodeTimestampMicrosUtc(epochMicros, variant, offset); + } + + public static int encodeTimestampNanosUtc(long value, Slice variant, int offset) + { + variant.setByte(offset, primitiveHeader(TIMESTAMP_UTC_NANOS)); + variant.setLong(offset + 1, value); + return ENCODED_TIMESTAMP_SIZE; + } + + public static int encodeTimestampNanosUtc(Instant value, Slice variant, int offset) + { + long epochSecond = value.getEpochSecond(); + int nanoOfSecond = value.getNano(); + + // For negative timestamps with a positive nano adjustment, shift one second into nanos first + // so multiplyExact can still represent the full long nanoseconds domain (including Long.MIN_VALUE). 
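+ // For example, Instant.ofEpochSecond(-5, 300_000_000), i.e. (-5 s, +0.3 s), becomes (-4 s, -0.7 s) and encodes to -4_700_000_000 nanoseconds.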
+ if (epochSecond < 0 && nanoOfSecond > 0) { + epochSecond = Math.addExact(epochSecond, 1); + nanoOfSecond -= 1_000_000_000; + } + + long epochNanos = Math.addExact(Math.multiplyExact(epochSecond, 1_000_000_000L), nanoOfSecond); + return encodeTimestampNanosUtc(epochNanos, variant, offset); + } + + public static int encodeTimestampMicrosNtz(long value, Slice variant, int offset) + { + variant.setByte(offset, primitiveHeader(TIMESTAMP_NTZ_MICROS)); + variant.setLong(offset + 1, value); + return ENCODED_TIMESTAMP_SIZE; + } + + public static int encodeTimestampMicrosNtz(LocalDateTime value, Slice variant, int offset) + { + long seconds = value.toEpochSecond(ZoneOffset.UTC); + int nanoOfSecond = value.getNano(); + long micros = seconds * 1_000_000 + (nanoOfSecond / 1_000); + return encodeTimestampMicrosNtz(micros, variant, offset); + } + + public static int encodeTimestampNanosNtz(long value, Slice variant, int offset) + { + variant.setByte(offset, primitiveHeader(TIMESTAMP_NTZ_NANOS)); + variant.setLong(offset + 1, value); + return ENCODED_TIMESTAMP_SIZE; + } + + public static int encodeTimestampNanosNtz(LocalDateTime value, Slice variant, int offset) + { + long seconds = value.toEpochSecond(ZoneOffset.UTC); + int nanoOfSecond = value.getNano(); + long nanos = seconds * 1_000_000_000 + nanoOfSecond; + return encodeTimestampNanosNtz(nanos, variant, offset); + } + + public static int encodedBinarySize(int length) + { + return 5 + length; + } + + public static int encodeBinary(Slice value, Slice variant, int offset) + { + variant.setByte(offset, primitiveHeader(BINARY)); + variant.setInt(offset + 1, value.length()); + variant.setBytes(offset + 5, value); + return encodedBinarySize(value.length()); + } + + public static int encodedStringSize(int length) + { + if (length <= SHORT_STRING_MAX_LENGTH) { + return 1 + length; + } + return 5 + length; + } + + public static int encodeString(Slice value, Slice variant, int offset) + { + if (value.length() <= SHORT_STRING_MAX_LENGTH) { + variant.setByte(offset, shortStringHeader(value.length())); + variant.setBytes(offset + 1, value); + return 1 + value.length(); + } + variant.setByte(offset, primitiveHeader(STRING)); + variant.setInt(offset + 1, value.length()); + variant.setBytes(offset + 5, value); + return 5 + value.length(); + } + + public static int encodeUuid(Slice value, Slice variant, int offset) + { + checkArgument(value.length() == 16, "UUID slice must be 16 bytes long"); + variant.setByte(offset, primitiveHeader(PrimitiveType.UUID)); + variant.setBytes(offset + 1, value); + return ENCODED_UUID_SIZE; + } + + public static int encodeUuid(UUID value, Slice variant, int offset) + { + variant.setByte(offset, primitiveHeader(PrimitiveType.UUID)); + variant.setLong(offset + 1, Long.reverseBytes(value.getMostSignificantBits())); + variant.setLong(offset + 9, Long.reverseBytes(value.getLeastSignificantBits())); + return ENCODED_UUID_SIZE; + } + + public static int encodedArraySize(int elementCount, int totalElementsLength) + { + boolean large = elementCount > 255; + + int offsetSize = getOffsetSize(totalElementsLength); + int offsetsLength = (elementCount + 1) * offsetSize; + + return 1 + (large ? 
4 : 1) + offsetsLength + totalElementsLength; + } + + public static int encodeArrayHeading(int elementCount, IntUnaryOperator elementLength, Slice variant, int offset) + { + boolean large = elementCount > 255; + + int totalElementsLength = 0; + for (int i = 0; i < elementCount; i++) { + totalElementsLength += elementLength.applyAsInt(i); + } + int offsetSize = getOffsetSize(totalElementsLength); + int offsetsLength = (elementCount + 1) * offsetSize; + + int headerSize = 1 + (large ? 4 : 1) + offsetsLength; + int expectedVariantSize = headerSize + totalElementsLength; + checkArgument(variant.length() >= offset + expectedVariantSize, () -> "Variant slice is too small to encode array of size " + expectedVariantSize); + + int position = offset; + // write header + variant.setByte(position, arrayHeader(offsetSize, large)); + position += 1; + + // write element count + if (large) { + variant.setInt(offset + 1, elementCount); + position += 4; + } + else { + variant.setByte(offset + 1, (byte) elementCount); + position += 1; + } + + // write offsets + int dataOffset = 0; + writeOffset(variant, position, dataOffset, offsetSize); + position += offsetSize; + for (int i = 0; i < elementCount; i++) { + dataOffset += elementLength.applyAsInt(i); + writeOffset(variant, position, dataOffset, offsetSize); + position += offsetSize; + } + verify(position == offset + headerSize, "Encoded size does not match expected size"); + return headerSize; + } + + public static int encodeArray(List<Slice> elements, Slice variant, int offset) + { + int written = encodeArrayHeading(elements.size(), index -> elements.get(index).length(), variant, offset); + + // write elements + for (Slice element : elements) { + variant.setBytes(offset + written, element); + written += element.length(); + } + return written; + } + + public static int encodedObjectSize(int maxField, int elementCount, int totalElementsLength) + { + boolean large = elementCount > 255; + + int fieldIdSize = getOffsetSize(maxField); + int fieldIdsLength = elementCount * fieldIdSize; + int offsetSize = getOffsetSize(totalElementsLength); + int offsetsLength = (elementCount + 1) * offsetSize; + + return 1 + (large ? 4 : 1) + fieldIdsLength + offsetsLength + totalElementsLength; + } + + /// Encodes an object with the given field count, field IDs, and field values into the provided variant slice at the specified offset. + /// The field IDs and field values are provided as functions that take an index and return the corresponding value. The field IDs must + /// be returned in sorted order, or the resulting encoding will be invalid. + public static int encodeObject(int fieldCount, IntUnaryOperator fieldIds, IntFunction<Slice> fieldValue, Slice variant, int offset) + { + int written = encodeObjectHeading( + fieldCount, + fieldIds, + index -> fieldValue.apply(index).length(), + variant, + offset); + for (int i = 0; i < fieldCount; i++) { + Slice data = fieldValue.apply(i); + variant.setBytes(offset + written, data); + written += data.length(); + } + return written; + } + + /// Encodes the heading of an object with the given field count, field IDs, and field lengths into the provided variant slice at the specified offset. + /// The field IDs and field lengths are provided as functions that take an index and return the corresponding value. The field IDs must + /// be returned in sorted order, or the resulting encoding will be invalid.
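+ /// For example (illustrative): for an object whose fields carry IDs 0, 2, and 5, {@code fieldIds} must return 0, 2, 5 for indexes 0, 1, 2, {@code fieldLength} must return the encoded length of the value at the same index, and the caller must then write the values in that same order.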
+ public static int encodeObjectHeading(int fieldCount, IntUnaryOperator fieldIds, IntUnaryOperator fieldLength, Slice variant, int offset) + { + if (fieldCount == 0) { + checkArgument(variant.length() >= offset + ENCODED_EMPTY_OBJECT_SIZE, "Variant slice is too small to encode empty object"); + variant.setByte(offset, objectHeader(1, 1, false)); + variant.setByte(offset + 1, 0); // zero elements + variant.setByte(offset + 2, 0); // first offset is zero (required) + return ENCODED_EMPTY_OBJECT_SIZE; + } + + boolean large = fieldCount > 255; + + int maxFieldId = -1; + int totalElementsLength = 0; + for (int i = 0; i < fieldCount; i++) { + maxFieldId = max(maxFieldId, fieldIds.applyAsInt(i)); + totalElementsLength += fieldLength.applyAsInt(i); + } + + int fieldIdSize = getOffsetSize(maxFieldId); + int fieldIdsLength = fieldCount * fieldIdSize; + int offsetSize = getOffsetSize(totalElementsLength); + int offsetsLength = (fieldCount + 1) * offsetSize; + + int headerSize = 1 + (large ? 4 : 1) + fieldIdsLength + offsetsLength; + int expectedVariantSize = headerSize + totalElementsLength; + checkArgument(variant.length() >= offset + expectedVariantSize, () -> "Variant slice is too small to encode object of size " + expectedVariantSize); + + int position = offset; + // write header + variant.setByte(position, objectHeader(fieldIdSize, offsetSize, large)); + position += 1; + + // write element count + if (large) { + variant.setInt(position, fieldCount); + position += 4; + } + else { + variant.setByte(position, (byte) fieldCount); + position += 1; + } + + // write field IDs + for (int i = 0; i < fieldCount; i++) { + writeOffset(variant, position, fieldIds.applyAsInt(i), fieldIdSize); + position += fieldIdSize; + } + + // write offsets + int dataOffset = 0; + writeOffset(variant, position, dataOffset, offsetSize); + position += offsetSize; + for (int i = 0; i < fieldCount; i++) { + dataOffset += fieldLength.applyAsInt(i); + writeOffset(variant, position, dataOffset, offsetSize); + position += offsetSize; + } + + verify(position == offset + headerSize, "Encoded size does not match expected size"); + return headerSize; + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/variant/VariantEquality.java b/core/trino-spi/src/main/java/io/trino/spi/variant/VariantEquality.java new file mode 100644 index 000000000000..cd373fe21d3a --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/variant/VariantEquality.java @@ -0,0 +1,452 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package io.trino.spi.variant; + +import io.airlift.slice.Slice; +import io.trino.spi.type.Int128; +import io.trino.spi.type.Int128Math; + +import java.lang.runtime.ExactConversionsSupport; + +import static io.trino.spi.type.Decimals.longTenToNth; +import static io.trino.spi.variant.Header.BasicType.SHORT_STRING; +import static io.trino.spi.variant.Header.PrimitiveType; +import static io.trino.spi.variant.Header.PrimitiveType.DOUBLE; +import static io.trino.spi.variant.Header.PrimitiveType.FLOAT; +import static io.trino.spi.variant.Header.PrimitiveType.INT16; +import static io.trino.spi.variant.Header.PrimitiveType.INT32; +import static io.trino.spi.variant.Header.PrimitiveType.INT64; +import static io.trino.spi.variant.Header.PrimitiveType.INT8; +import static io.trino.spi.variant.Header.arrayFieldOffsetSize; +import static io.trino.spi.variant.Header.arrayIsLarge; +import static io.trino.spi.variant.Header.getBasicType; +import static io.trino.spi.variant.Header.getPrimitiveType; +import static io.trino.spi.variant.Header.objectFieldIdSize; +import static io.trino.spi.variant.Header.objectFieldOffsetSize; +import static io.trino.spi.variant.Header.objectIsLarge; +import static io.trino.spi.variant.Header.shortStringLength; +import static io.trino.spi.variant.VariantUtils.readOffset; + +final class VariantEquality +{ + private VariantEquality() {} + + public static boolean equals( + Metadata leftMetadata, + Slice leftSlice, + int leftOffset, + Metadata rightMetadata, + Slice rightSlice, + int rightOffset) + { + if (leftMetadata == rightMetadata && leftSlice == rightSlice && leftOffset == rightOffset) { + return true; + } + + byte leftHeader = leftSlice.getByte(leftOffset); + byte rightHeader = rightSlice.getByte(rightOffset); + + ValueClass leftValueClass = ValueClass.classify(leftHeader); + ValueClass rightValueClass = ValueClass.classify(rightHeader); + + if (leftValueClass != rightValueClass) { + return false; + } + + return switch (leftValueClass) { + case NULL -> true; + case BOOLEAN -> leftHeader == rightHeader; + case NUMERIC -> equalsNumeric(leftHeader, leftSlice, leftOffset, rightHeader, rightSlice, rightOffset); + case DATE -> leftSlice.getInt(leftOffset + 1) == rightSlice.getInt(rightOffset + 1); + case TIME_NTZ -> leftSlice.getLong(leftOffset + 1) == rightSlice.getLong(rightOffset + 1); + case TIMESTAMP_UTC, TIMESTAMP_NTZ -> equalsTimestamp(leftHeader, leftSlice, leftOffset, rightHeader, rightSlice, rightOffset); + case BINARY -> leftSlice.equals(leftOffset + 5, leftSlice.getInt(leftOffset + 1), rightSlice, rightOffset + 5, rightSlice.getInt(rightOffset + 1)); + case STRING -> equalsStringLike(leftSlice, leftOffset, leftHeader, rightSlice, rightOffset, rightHeader); + case UUID -> leftSlice.equals(leftOffset + 1, 16, rightSlice, rightOffset + 1, 16); + case OBJECT -> equalsObject(leftMetadata, leftSlice, leftOffset, rightMetadata, rightSlice, rightOffset); + case ARRAY -> equalsArray(leftMetadata, leftSlice, leftOffset, rightMetadata, rightSlice, rightOffset); + }; + } + + private static boolean equalsNumeric( + byte leftHeader, + Slice leftSlice, + int leftOffset, + byte rightHeader, + Slice rightSlice, + int rightOffset) + { + PrimitiveType leftType = getPrimitiveType(leftHeader); + PrimitiveType rightType = getPrimitiveType(rightHeader); + + if (VariantUtils.isExactNumeric(leftType) && VariantUtils.isExactNumeric(rightType)) { + return equalsExactNumeric(leftHeader, leftSlice, leftOffset, rightHeader, rightSlice, rightOffset); + } + + if 
(VariantUtils.isFloatingNumeric(leftType) && VariantUtils.isFloatingNumeric(rightType)) { + return equalsFloatingNumeric(leftType, leftSlice, leftOffset, rightType, rightSlice, rightOffset); + } + + if (VariantUtils.isExactNumeric(leftType)) { + return equalsExactAndFloatingNumeric(leftType, leftSlice, leftOffset, rightType, rightSlice, rightOffset); + } + return equalsExactAndFloatingNumeric(rightType, rightSlice, rightOffset, leftType, leftSlice, leftOffset); + } + + private static boolean equalsExactNumeric( + byte leftHeader, + Slice leftSlice, + int leftOffset, + byte rightHeader, + Slice rightSlice, + int rightOffset) + { + PrimitiveType leftType = getPrimitiveType(leftHeader); + PrimitiveType rightType = getPrimitiveType(rightHeader); + + if ((leftType == INT8 || leftType == INT16 || leftType == INT32 || leftType == INT64) && + (rightType == INT8 || rightType == INT16 || rightType == INT32 || rightType == INT64)) { + long leftValue = switch (leftType) { + case INT8 -> leftSlice.getByte(leftOffset + 1); + case INT16 -> leftSlice.getShort(leftOffset + 1); + case INT32 -> leftSlice.getInt(leftOffset + 1); + case INT64 -> leftSlice.getLong(leftOffset + 1); + default -> throw new VerifyException("Unexpected integer type: " + leftType); + }; + long rightValue = switch (rightType) { + case INT8 -> rightSlice.getByte(rightOffset + 1); + case INT16 -> rightSlice.getShort(rightOffset + 1); + case INT32 -> rightSlice.getInt(rightOffset + 1); + case INT64 -> rightSlice.getLong(rightOffset + 1); + default -> throw new VerifyException("Unexpected integer type: " + rightType); + }; + return leftValue == rightValue; + } + return VariantUtils.decodeExactNumericCanonical(leftType, leftSlice, leftOffset) + .equals(VariantUtils.decodeExactNumericCanonical(rightType, rightSlice, rightOffset)); + } + + @SuppressWarnings("FloatingPointEquality") + private static boolean equalsFloatingNumeric( + PrimitiveType leftType, + Slice leftSlice, + int leftOffset, + PrimitiveType rightType, + Slice rightSlice, + int rightOffset) + { + double leftValue = VariantUtils.floatingAsDouble(leftType, leftSlice, leftOffset); + double rightValue = VariantUtils.floatingAsDouble(rightType, rightSlice, rightOffset); + + if (Double.isNaN(leftValue) || Double.isNaN(rightValue)) { + return false; + } + + return leftValue == rightValue; + } + + private static boolean equalsExactAndFloatingNumeric( + PrimitiveType exactType, + Slice exactSlice, + int exactOffset, + PrimitiveType floatingType, + Slice floatingSlice, + int floatingOffset) + { + double floatingValue = VariantUtils.floatingAsDouble(floatingType, floatingSlice, floatingOffset); + if (!Double.isFinite(floatingValue)) { + return false; + } + + if (floatingValue == 0.0 || ExactConversionsSupport.isDoubleToLongExact(floatingValue)) { + return equalsExactNumericWithLong(exactType, exactSlice, exactOffset, (long) floatingValue); + } + + VariantUtils.Decimal128Canonical floatingDecimal = VariantUtils.tryToDecimal128Exact(floatingValue); + if (floatingDecimal == null) { + return false; + } + + VariantUtils.Decimal128Canonical exactDecimal = VariantUtils.decodeExactNumericCanonical(exactType, exactSlice, exactOffset); + return exactDecimal.equals(floatingDecimal); + } + + private static boolean equalsExactNumericWithLong(PrimitiveType exactType, Slice exactSlice, int exactOffset, long value) + { + return switch (exactType) { + case INT8 -> exactSlice.getByte(exactOffset + 1) == value; + case INT16 -> exactSlice.getShort(exactOffset + 1) == value; + case INT32 -> 
exactSlice.getInt(exactOffset + 1) == value; + case INT64 -> exactSlice.getLong(exactOffset + 1) == value; + case DECIMAL4 -> decimalEqualsLong(exactSlice.getInt(exactOffset + 2), exactSlice.getByte(exactOffset + 1) & 0xFF, value); + case DECIMAL8 -> decimalEqualsLong(exactSlice.getLong(exactOffset + 2), exactSlice.getByte(exactOffset + 1) & 0xFF, value); + case DECIMAL16 -> decimalEqualsLong( + Int128.valueOf(exactSlice.getLong(exactOffset + 10), exactSlice.getLong(exactOffset + 2)), + exactSlice.getByte(exactOffset + 1) & 0xFF, + value); + default -> throw new VerifyException("Unexpected exact numeric type: " + exactType); + }; + } + + private static boolean decimalEqualsLong(long unscaled, int scale, long value) + { + if (scale == 0) { + return unscaled == value; + } + if (unscaled == 0 || value == 0) { + return unscaled == 0 && value == 0; + } + if ((unscaled < 0) != (value < 0)) { + return false; + } + if (scale > 18) { + return false; + } + try { + return unscaled == Math.multiplyExact(value, longTenToNth(scale)); + } + catch (ArithmeticException ignored) { + return false; + } + } + + private static boolean decimalEqualsLong(Int128 unscaled, int scale, long value) + { + if (scale == 0) { + return unscaled.getHigh() == (value >> 63) && unscaled.getLow() == value; + } + if (unscaled.isZero() || value == 0) { + return unscaled.isZero() && value == 0; + } + if (unscaled.isNegative() != (value < 0)) { + return false; + } + + if (scale <= 18) { + try { + long scaled = Math.multiplyExact(value, longTenToNth(scale)); + return unscaled.getHigh() == (scaled >> 63) && unscaled.getLow() == scaled; + } + catch (ArithmeticException ignored) { + return false; + } + } + + Int128 powerOfTen = Int128Math.powerOfTen(scale); + long[] scaledValue = new long[2]; + try { + Int128Math.multiply(value >> 63, value, powerOfTen.getHigh(), powerOfTen.getLow(), scaledValue, 0); + } + catch (ArithmeticException ignored) { + return false; + } + return unscaled.getHigh() == scaledValue[0] && unscaled.getLow() == scaledValue[1]; + } + + private static boolean equalsTimestamp( + byte leftHeader, + Slice leftSlice, + int leftOffset, + byte rightHeader, + Slice rightSlice, + int rightOffset) + { + long leftValue = leftSlice.getLong(leftOffset + 1); + long rightValue = rightSlice.getLong(rightOffset + 1); + + if (leftHeader == rightHeader) { + return leftValue == rightValue; + } + + PrimitiveType leftType = getPrimitiveType(leftHeader); + if (leftType == PrimitiveType.TIMESTAMP_UTC_MICROS || leftType == PrimitiveType.TIMESTAMP_NTZ_MICROS) { + return rightValue % 1_000L == 0 && leftValue == (rightValue / 1_000L); + } + return leftValue % 1_000L == 0 && (leftValue / 1_000L) == rightValue; + } + + private static boolean equalsStringLike( + Slice leftSlice, + int leftOffset, + byte leftHeader, + Slice rightSlice, + int rightOffset, + byte rightHeader) + { + int leftStringOffset; + int leftStringLength; + if (getBasicType(leftHeader) == SHORT_STRING) { + leftStringLength = shortStringLength(leftHeader); + leftStringOffset = leftOffset + 1; + } + else { + leftStringLength = leftSlice.getInt(leftOffset + 1); + leftStringOffset = leftOffset + 5; + } + + int rightStringOffset; + int rightStringLength; + if (getBasicType(rightHeader) == SHORT_STRING) { + rightStringLength = shortStringLength(rightHeader); + rightStringOffset = rightOffset + 1; + } + else { + rightStringLength = rightSlice.getInt(rightOffset + 1); + rightStringOffset = rightOffset + 5; + } + + return leftSlice.equals(leftStringOffset, leftStringLength, rightSlice, 
rightStringOffset, rightStringLength); + } + + private static boolean equalsObject( + Metadata leftMetadata, + Slice leftSlice, + int leftOffset, + Metadata rightMetadata, + Slice rightSlice, + int rightOffset) + { + byte leftHeader = leftSlice.getByte(leftOffset); + byte rightHeader = rightSlice.getByte(rightOffset); + + boolean leftLarge = objectIsLarge(leftHeader); + boolean rightLarge = objectIsLarge(rightHeader); + + int leftCount = leftLarge ? leftSlice.getInt(leftOffset + 1) : (leftSlice.getByte(leftOffset + 1) & 0xFF); + int rightCount = rightLarge ? rightSlice.getInt(rightOffset + 1) : (rightSlice.getByte(rightOffset + 1) & 0xFF); + if (leftCount != rightCount) { + return false; + } + + int leftIdSize = objectFieldIdSize(leftHeader); + int leftFieldOffsetSize = objectFieldOffsetSize(leftHeader); + int rightIdSize = objectFieldIdSize(rightHeader); + int rightFieldOffsetSize = objectFieldOffsetSize(rightHeader); + + int leftIdsStart = leftOffset + 1 + (leftLarge ? 4 : 1); + int rightIdsStart = rightOffset + 1 + (rightLarge ? 4 : 1); + int leftOffsetsStart = leftIdsStart + leftCount * leftIdSize; + int rightOffsetsStart = rightIdsStart + rightCount * rightIdSize; + int leftValuesStart = leftOffsetsStart + (leftCount + 1) * leftFieldOffsetSize; + int rightValuesStart = rightOffsetsStart + (rightCount + 1) * rightFieldOffsetSize; + + if (leftMetadata == rightMetadata) { + for (int index = 0; index < leftCount; index++) { + int leftFieldId = readOffset(leftSlice, leftIdsStart + index * leftIdSize, leftIdSize); + int rightFieldId = readOffset(rightSlice, rightIdsStart + index * rightIdSize, rightIdSize); + if (leftFieldId != rightFieldId) { + return false; + } + } + } + else { + for (int index = 0; index < leftCount; index++) { + int leftFieldId = readOffset(leftSlice, leftIdsStart + index * leftIdSize, leftIdSize); + int rightFieldId = readOffset(rightSlice, rightIdsStart + index * rightIdSize, rightIdSize); + + Slice leftKey = leftMetadata.get(leftFieldId); + Slice rightKey = rightMetadata.get(rightFieldId); + if (!leftKey.equals(rightKey)) { + return false; + } + } + } + + for (int index = 0; index < leftCount; index++) { + int leftValueStart = leftValuesStart + readOffset(leftSlice, leftOffsetsStart + index * leftFieldOffsetSize, leftFieldOffsetSize); + int rightValueStart = rightValuesStart + readOffset(rightSlice, rightOffsetsStart + index * rightFieldOffsetSize, rightFieldOffsetSize); + + if (!equals(leftMetadata, leftSlice, leftValueStart, rightMetadata, rightSlice, rightValueStart)) { + return false; + } + } + + return true; + } + + private static boolean equalsArray( + Metadata leftMetadata, + Slice leftSlice, + int leftOffset, + Metadata rightMetadata, + Slice rightSlice, + int rightOffset) + { + byte leftHeader = leftSlice.getByte(leftOffset); + byte rightHeader = rightSlice.getByte(rightOffset); + + boolean leftLarge = arrayIsLarge(leftHeader); + boolean rightLarge = arrayIsLarge(rightHeader); + + int leftCount = leftLarge ? leftSlice.getInt(leftOffset + 1) : (leftSlice.getByte(leftOffset + 1) & 0xFF); + int rightCount = rightLarge ? rightSlice.getInt(rightOffset + 1) : (rightSlice.getByte(rightOffset + 1) & 0xFF); + if (leftCount != rightCount) { + return false; + } + + int leftOffsetSize = arrayFieldOffsetSize(leftHeader); + int rightOffsetSize = arrayFieldOffsetSize(rightHeader); + int leftOffsetsStart = leftOffset + 1 + (leftLarge ? 4 : 1); + int rightOffsetsStart = rightOffset + 1 + (rightLarge ? 
4 : 1); + int leftValuesStart = leftOffsetsStart + (leftCount + 1) * leftOffsetSize; + int rightValuesStart = rightOffsetsStart + (rightCount + 1) * rightOffsetSize; + + for (int index = 0; index < leftCount; index++) { + int leftElementStart = leftValuesStart + readOffset(leftSlice, leftOffsetsStart + index * leftOffsetSize, leftOffsetSize); + int rightElementStart = rightValuesStart + readOffset(rightSlice, rightOffsetsStart + index * rightOffsetSize, rightOffsetSize); + + if (!equals(leftMetadata, leftSlice, leftElementStart, rightMetadata, rightSlice, rightElementStart)) { + return false; + } + } + return true; + } + + private enum ValueClass + { + NULL, + BOOLEAN, + NUMERIC, + DATE, + TIME_NTZ, + TIMESTAMP_UTC, + TIMESTAMP_NTZ, + BINARY, + STRING, + UUID, + OBJECT, + ARRAY; + + private static ValueClass classify(byte header) + { + return switch (getBasicType(header)) { + case PRIMITIVE -> switch (getPrimitiveType(header)) { + case NULL -> NULL; + case BOOLEAN_TRUE, BOOLEAN_FALSE -> BOOLEAN; + case INT8, INT16, INT32, INT64, DECIMAL4, DECIMAL8, DECIMAL16, FLOAT, DOUBLE -> NUMERIC; + case DATE -> DATE; + case TIME_NTZ_MICROS -> TIME_NTZ; + case TIMESTAMP_UTC_MICROS, TIMESTAMP_UTC_NANOS -> TIMESTAMP_UTC; + case TIMESTAMP_NTZ_MICROS, TIMESTAMP_NTZ_NANOS -> TIMESTAMP_NTZ; + case BINARY -> BINARY; + case STRING -> STRING; + case UUID -> UUID; + }; + case SHORT_STRING -> STRING; + case OBJECT -> OBJECT; + case ARRAY -> ARRAY; + }; + } + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/variant/VariantFieldRemapper.java b/core/trino-spi/src/main/java/io/trino/spi/variant/VariantFieldRemapper.java new file mode 100644 index 000000000000..3b24fbb7122f --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/variant/VariantFieldRemapper.java @@ -0,0 +1,362 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.variant; + +import io.airlift.slice.Slice; + +import java.util.Arrays; +import java.util.function.IntUnaryOperator; + +import static io.trino.spi.variant.VariantDecoder.decode; +import static io.trino.spi.variant.VariantUtils.getOffsetSize; +import static io.trino.spi.variant.VariantUtils.verify; +import static java.lang.Math.max; +import static java.util.Objects.checkFromIndexSize; +import static java.util.Objects.requireNonNull; + +/// Remaps field IDs in a Variant for a new metadata dictionary. +/// Remapping is necessary when merging Variants with different metadata dictionaries into +/// a single Variant (e.g., array or object) which requires a unified metadata dictionary. +/// Remapping is a complex operation because the width used to encode field IDs depends on the +/// maximum field ID, and the width used to encode object offsets depends on the total size of the object. Thus, when +/// multiple Variants are merged, the field IDs and object sizes may change, and the entire Variant +/// must be rewritten with the new encodings. +/// +/// Remapping is done in three steps: +/// 1.
Initial creation of the remapper with provisional field IDs assigned to the fields +/// using the `Metadata.Builder` to track which fields are used. +/// 2. Finalization of the remapper where provisional field IDs are updated to final field IDs +/// after the globally sorted metadata dictionary is created. After finalization, the +/// final size of the remapped Variant can be fetched with the `size()` method. +/// 3. Writing the remapped Variant to an output slice with the `write()` method. +public final class VariantFieldRemapper +{ + private enum RemapMode + { + NONE, + IDENTITY, + SAME_SIZE, + RESIZE, + } + + private final int[] fieldIdMapping; + private final int originalFieldIdEncodedWidth; + + private final Slice variant; + + private RemapMode remapMode; + private int size = -1; + + private Int2IntOpenHashMap containerSizeCache; + + public static VariantFieldRemapper create(Variant variant, Metadata.Builder metadataBuilder) + { + requireNonNull(variant, "variant is null"); + requireNonNull(metadataBuilder, "metadataBuilder is null"); + + // Fast path: no fields, so there is nothing to merge/remap. + if (variant.metadata().dictionarySize() == 0) { + return new VariantFieldRemapper(variant.data()); + } + return new Builder(variant, metadataBuilder).build(); + } + + private VariantFieldRemapper(Slice variant) + { + this.variant = requireNonNull(variant, "variant is null"); + + this.fieldIdMapping = new int[0]; + this.originalFieldIdEncodedWidth = -1; + this.remapMode = RemapMode.NONE; + this.size = variant.length(); + } + + private VariantFieldRemapper(Slice variant, int[] fieldIdMapping, int originalFieldIdEncodedWidth) + { + this.fieldIdMapping = requireNonNull(fieldIdMapping, "fieldIdMapping is null"); + this.originalFieldIdEncodedWidth = originalFieldIdEncodedWidth; + this.variant = requireNonNull(variant, "variant is null"); + } + + /// Finalizes the field remapping by updating the provisional field IDs to final field IDs. + /// The system creates a globally sorted metadata dictionary after all values have been planned, + /// which may change the field IDs assigned during initial setup. + /// This method must be called before `size()` or `write()`. + // Note: This method can rely on the field IDs being assigned in ascending order for determining the write order of object fields. + public void finalize(IntUnaryOperator remapFieldIds) + { + if (remapMode == RemapMode.NONE) { + return; + } + if (remapMode != null) { + throw new IllegalStateException("finalize() already called"); + } + + boolean identity = true; + int maxFieldId = -1; + for (int variantFieldId = 0; variantFieldId < fieldIdMapping.length; variantFieldId++) { + int provisionalFieldId = fieldIdMapping[variantFieldId]; + if (provisionalFieldId >= 0) { + int finalFieldId = remapFieldIds.applyAsInt(provisionalFieldId); + fieldIdMapping[variantFieldId] = finalFieldId; + identity &= (finalFieldId == variantFieldId); + maxFieldId = max(maxFieldId, finalFieldId); + } + } + if (identity) { + remapMode = RemapMode.IDENTITY; + } + else if (originalFieldIdEncodedWidth == 1 && getOffsetSize(maxFieldId) == 1) { + // fast path where original and final field ID encoded widths are both 1 byte + // this cannot be used when either width is larger than 1 byte, because an object inside the + // variant could have originally been encoded with a larger field ID width, and + // with the compacted metadata dictionary, the max field ID could be smaller, allowing + // the field ID width (and with it the object's offset size) to be reduced.
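+ // For illustration: an inner object whose largest original field ID needed 2 bytes could, after compaction, have a largest remapped ID that fits in 1 byte, so that object would re-encode smaller and a byte-for-byte copy would miscount.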
+ remapMode = RemapMode.SAME_SIZE; + } + else { + remapMode = RemapMode.RESIZE; + } + + // compute size of remapped variant data + if (remapMode == RemapMode.IDENTITY) { + size = variant.length(); + } + else { + containerSizeCache = new Int2IntOpenHashMap(16); + size = calculateFullyRemappedSize(variant, 0, variant.length(), containerSizeCache); + } + } + + /// Returns the size, in bytes, required to write the variant. + /// This cannot be called before `finalize()`. + public int size() + { + if (remapMode == null) { + throw new IllegalStateException("size() called before finalize()"); + } + return size; + } + + /// Writes the value to the given output slice at the specified offset. + /// This must be called after `finalize()`. + /// This method can be called multiple times to write the same value to different output slices. + /// + /// @return the number of bytes written, which is equal to `size()` + public int write(Slice output, int outputOffset) + { + if (remapMode == null) { + throw new IllegalStateException("write() called before finalize()"); + } + checkFromIndexSize(outputOffset, size, output.length()); + + if (remapMode == RemapMode.IDENTITY || remapMode == RemapMode.NONE) { + output.setBytes(outputOffset, variant); + return size; + } + + int written = write(variant, 0, output, outputOffset, variant.length()); + verify(written == size, "unexpected size mismatch in write"); + return written; + } + + private int calculateFullyRemappedSize(Slice data, int offset, int length, Int2IntOpenHashMap containerSizeCache) + { + return switch (decode(data, offset)) { + case VariantDecoder.ArrayLayout array -> { + int totalChildrenLength = 0; + for (int index = 0; index < array.count(); index++) { + int start = array.elementStart(index); + int end = array.elementEnd(index); + totalChildrenLength += calculateFullyRemappedSize(data, start, end - start, containerSizeCache); + } + int arrayTotalSize = VariantEncoder.encodedArraySize(array.count(), totalChildrenLength); + containerSizeCache.putIfAbsent(offset, arrayTotalSize); + yield arrayTotalSize; + } + case VariantDecoder.ObjectLayout object -> { + int maxFieldId = -1; + int totalChildrenLength = 0; + + for (int index = 0; index < object.count(); index++) { + int remappedFieldId = fieldIdMapping[object.fieldId(index)]; + maxFieldId = max(maxFieldId, remappedFieldId); + + int start = object.valueStart(index); + int end = object.valueEnd(index); + totalChildrenLength += calculateFullyRemappedSize(data, start, end - start, containerSizeCache); + } + + int objectTotalSize = VariantEncoder.encodedObjectSize(maxFieldId, object.count(), totalChildrenLength); + containerSizeCache.putIfAbsent(offset, objectTotalSize); + yield objectTotalSize; + } + case VariantDecoder.PrimitiveLayout _ -> length; + }; + } + + private int write(Slice input, int inputOffset, Slice output, int outputOffset, int length) + { + return switch (decode(input, inputOffset)) { + case VariantDecoder.ArrayLayout array -> { + // write array header + int written = switch (remapMode) { + case SAME_SIZE -> { + // Header is unchanged byte-for-byte when SAME_SIZE.
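+ // (in SAME_SIZE mode every nested value keeps its encoded length, so the element offsets in the header are unchanged as well; only field ID bytes inside object headers differ)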
+ int headerSize = array.headerSize(); + output.setBytes(outputOffset, input, inputOffset, headerSize); + yield headerSize; + } + case RESIZE -> VariantEncoder.encodeArrayHeading( + array.count(), + i -> { + int start = array.elementStart(i); + + int cachedContainerSize = containerSizeCache.get(start); + if (cachedContainerSize != Int2IntOpenHashMap.DEFAULT_RETURN_VALUE) { + return cachedContainerSize; + } + return array.elementEnd(i) - start; + }, + output, + outputOffset); + case IDENTITY, NONE -> throw new VerifyException("unexpected remap mode " + remapMode); + }; + + // write remapped elements + for (int index = 0; index < array.count(); index++) { + int start = array.elementStart(index); + int end = array.elementEnd(index); + written += write(input, start, output, outputOffset + written, end - start); + } + + int expectedSize = (remapMode == RemapMode.SAME_SIZE) ? length : containerSizeCache.get(inputOffset); + verify(written == expectedSize, "unexpected size mismatch in complete remap"); + yield written; + } + case VariantDecoder.ObjectLayout object -> { + // write the object header with remapped field IDs + int written = VariantEncoder.encodeObjectHeading( + object.count(), + i -> fieldIdMapping[object.fieldId(i)], + i -> { + int start = object.valueStart(i); + if (remapMode == RemapMode.RESIZE) { + int cachedSize = containerSizeCache.get(start); + if (cachedSize != Int2IntOpenHashMap.DEFAULT_RETURN_VALUE) { + return cachedSize; + } + } + return object.valueEnd(i) - start; + }, + output, + outputOffset); + + // write remapped elements + for (int index = 0; index < object.count(); index++) { + int start = object.valueStart(index); + int end = object.valueEnd(index); + written += write(input, start, output, outputOffset + written, end - start); + } + + int expectedSize = (remapMode == RemapMode.SAME_SIZE) ? 
length : containerSizeCache.get(inputOffset); + verify(written == expectedSize, "unexpected size mismatch in complete remap"); + yield written; + } + case VariantDecoder.PrimitiveLayout _ -> { + output.setBytes(outputOffset, input, inputOffset, length); + yield length; + } + }; + } + + private static final class Builder + { + private final Variant variant; + private final Metadata.Builder metadataBuilder; + private int[] variantFieldIdToProvisionalFieldId; + private int maxEnabledFieldId = -1; + private int enabledFieldCount; + + private Builder(Variant variant, Metadata.Builder metadataBuilder) + { + this.variant = requireNonNull(variant, "variant is null"); + this.metadataBuilder = requireNonNull(metadataBuilder, "metadataBuilder is null"); + + if (variant.metadata().dictionarySize() == 0) { + variantFieldIdToProvisionalFieldId = new int[0]; + return; + } + + mapVariantFields(variant.data(), 0); + } + + private boolean isFullyMapped() + { + return enabledFieldCount == variant.metadata().dictionarySize(); + } + + public void enableField(int variantFieldId) + { + if (variantFieldIdToProvisionalFieldId == null) { + variantFieldIdToProvisionalFieldId = new int[variant.metadata().dictionarySize()]; + Arrays.fill(variantFieldIdToProvisionalFieldId, -1); + } + if (variantFieldIdToProvisionalFieldId[variantFieldId] == -1) { + variantFieldIdToProvisionalFieldId[variantFieldId] = metadataBuilder.addFieldName(variant.metadata().get(variantFieldId)); + if (variantFieldId > maxEnabledFieldId) { + maxEnabledFieldId = variantFieldId; + } + enabledFieldCount++; + } + } + + private void mapVariantFields(Slice data, int offset) + { + switch (decode(data, offset)) { + case VariantDecoder.ArrayLayout array -> { + for (int index = 0; index < array.count(); index++) { + mapVariantFields(data, array.elementStart(index)); + if (isFullyMapped()) { + return; + } + } + } + case VariantDecoder.ObjectLayout object -> { + for (int index = 0; index < object.count(); index++) { + enableField(object.fieldId(index)); + if (isFullyMapped()) { + return; + } + mapVariantFields(data, object.valueStart(index)); + if (isFullyMapped()) { + return; + } + } + } + case VariantDecoder.PrimitiveLayout _ -> { + // No fields to map + } + } + } + + public VariantFieldRemapper build() + { + if (variantFieldIdToProvisionalFieldId == null) { + return new VariantFieldRemapper(variant.data()); + } + return new VariantFieldRemapper(variant.data(), variantFieldIdToProvisionalFieldId, getOffsetSize(maxEnabledFieldId)); + } + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/variant/VariantHashing.java b/core/trino-spi/src/main/java/io/trino/spi/variant/VariantHashing.java new file mode 100644 index 000000000000..059f1f52b526 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/variant/VariantHashing.java @@ -0,0 +1,313 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.variant; + +import io.airlift.slice.Slice; +import io.airlift.slice.XxHash64; +import io.trino.spi.type.Int128; + +import java.lang.runtime.ExactConversionsSupport; + +import static io.trino.spi.variant.Header.BasicType.SHORT_STRING; +import static io.trino.spi.variant.Header.PrimitiveType; +import static io.trino.spi.variant.Header.PrimitiveType.DOUBLE; +import static io.trino.spi.variant.Header.PrimitiveType.FLOAT; +import static io.trino.spi.variant.Header.PrimitiveType.INT16; +import static io.trino.spi.variant.Header.PrimitiveType.INT32; +import static io.trino.spi.variant.Header.PrimitiveType.INT64; +import static io.trino.spi.variant.Header.PrimitiveType.INT8; +import static io.trino.spi.variant.Header.arrayFieldOffsetSize; +import static io.trino.spi.variant.Header.arrayIsLarge; +import static io.trino.spi.variant.Header.getBasicType; +import static io.trino.spi.variant.Header.getPrimitiveType; +import static io.trino.spi.variant.Header.objectFieldIdSize; +import static io.trino.spi.variant.Header.objectFieldOffsetSize; +import static io.trino.spi.variant.Header.objectIsLarge; +import static io.trino.spi.variant.Header.shortStringLength; +import static io.trino.spi.variant.VariantUtils.readOffset; +import static java.lang.Double.doubleToLongBits; + +final class VariantHashing +{ + private VariantHashing() {} + + public static long hashCode(Metadata metadata, Slice slice, int offset) + { + VariantHash variantHash = new VariantHash(0); + variantHash.hashVariant(metadata, slice, offset); + return variantHash.finish(); + } + + private static final class VariantHash + { + private static final long PRIME64_1 = 0x9E3779B185EBCA87L; + private static final long PRIME64_2 = 0xC2B2AE3D27D4EB4FL; + private static final long PRIME64_3 = 0x165667B19E3779F9L; + private static final long PRIME64_4 = 0x85EBCA77C2B2AE63L; + private static final long PRIME64_5 = 0x27D4EB2F165667C5L; + + private long hash; + private long totalLength; + + VariantHash(long seed) + { + this.hash = seed + PRIME64_5; + } + + public void hashVariant(Metadata metadata, Slice slice, int offset) + { + byte header = slice.getByte(offset); + ValueClass valueClass = ValueClass.classify(header); + addInt(valueClass.hashTag()); + switch (valueClass) { + case NULL -> addLong(0); + case BOOLEAN -> addLong(getPrimitiveType(header) == PrimitiveType.BOOLEAN_TRUE ? 1 : 2); + case NUMERIC -> hashNumeric(getPrimitiveType(header), slice, offset); + case DATE -> addInt(slice.getInt(offset + 1)); + case TIME_NTZ -> addLong(slice.getLong(offset + 1)); + case TIMESTAMP_UTC, TIMESTAMP_NTZ -> hashTimestamp(header, slice, offset); + case BINARY -> addBytesHash(slice, offset + 5, slice.getInt(offset + 1)); + case STRING -> hashStringLike(slice, offset, header); + case UUID -> addBytesHash(slice, offset + 1, 16); + case OBJECT -> hashObject(metadata, slice, offset); + case ARRAY -> hashArray(metadata, slice, offset); + } + } + + private void hashNumeric(PrimitiveType primitiveType, Slice slice, int offset) + { + if (VariantUtils.isExactNumeric(primitiveType)) { + hashCanonicalDecimal(VariantUtils.decodeExactNumericCanonical(primitiveType, slice, offset)); + return; + } + + double value = VariantUtils.floatingAsDouble(primitiveType, slice, offset); + if (Double.isNaN(value)) { + addInt(4); + return; + } + if (Double.isInfinite(value)) { + addInt(value > 0 ? 
2 : 3); + return; + } + if (value == 0.0 || ExactConversionsSupport.isDoubleToLongExact(value)) { + hashCanonicalLongDecimal((long) value); + return; + } + + VariantUtils.Decimal128Canonical decimal = VariantUtils.tryToDecimal128Exact(value); + if (decimal != null) { + hashCanonicalDecimal(decimal); + return; + } + + addInt(1); + addDouble(value); + } + + private void hashCanonicalDecimal(VariantUtils.Decimal128Canonical decimal) + { + addInt(0); + if (decimal.scale() != 0) { + addInt(decimal.scale()); + } + Int128 unscaled = decimal.unscaled(); + if (VariantUtils.fitsInLong(unscaled)) { + addLong(unscaled.getLow()); + return; + } + addLong(unscaled.getHigh()); + addLong(unscaled.getLow()); + } + + private void hashCanonicalLongDecimal(long value) + { + addInt(0); + addLong(value); + } + + private void addInt(int value) + { + addLong(value); + } + + private void addLong(long value) + { + totalLength += Long.BYTES; + hash ^= round(value); + hash = Long.rotateLeft(hash, 27) * PRIME64_1 + PRIME64_4; + } + + private void addDouble(double value) + { + addLong(doubleToLongBits(value)); + } + + void addBytesHash(Slice slice, int offset, int length) + { + long bytesXxHash = XxHash64.hash(slice, offset, length); + addBytesHash(bytesXxHash, length); + } + + void addBytesHash(long bytesXxHash, int length) + { + totalLength += length; + hash ^= round(bytesXxHash); + hash = Long.rotateLeft(hash, 27) * PRIME64_1 + PRIME64_4; + } + + private void hashStringLike(Slice slice, int offset, byte header) + { + int stringOffset; + int stringLength; + if (getBasicType(header) == SHORT_STRING) { + stringLength = shortStringLength(header); + stringOffset = offset + 1; + } + else { + stringLength = slice.getInt(offset + 1); + stringOffset = offset + 5; + } + + addBytesHash(slice, stringOffset, stringLength); + } + + private void hashTimestamp(byte header, Slice slice, int offset) + { + long value = slice.getLong(offset + 1); + + PrimitiveType primitiveType = getPrimitiveType(header); + if (primitiveType == PrimitiveType.TIMESTAMP_UTC_MICROS || primitiveType == PrimitiveType.TIMESTAMP_NTZ_MICROS) { + addLong(0); + addLong(value); + return; + } + + if (value % 1_000L == 0) { + addLong(0); + addLong(value / 1_000L); + return; + } + addLong(1); + addLong(value); + } + + private void hashObject(Metadata metadata, Slice slice, int offset) + { + byte header = slice.getByte(offset); + boolean large = objectIsLarge(header); + int count = large ? slice.getInt(offset + 1) : (slice.getByte(offset + 1) & 0xFF); + + int idSize = objectFieldIdSize(header); + int fieldOffsetSize = objectFieldOffsetSize(header); + + int idsStart = offset + 1 + (large ? 4 : 1); + int offsetsStart = idsStart + count * idSize; + int valuesStart = offsetsStart + (count + 1) * fieldOffsetSize; + + addInt(count); + for (int index = 0; index < count; index++) { + int fieldId = readOffset(slice, idsStart + index * idSize, idSize); + Slice key = metadata.get(fieldId); + + int valueStart = valuesStart + readOffset(slice, offsetsStart + index * fieldOffsetSize, fieldOffsetSize); + addBytesHash(key, 0, key.length()); + hashVariant(metadata, slice, valueStart); + } + } + + private void hashArray(Metadata metadata, Slice slice, int offset) + { + byte header = slice.getByte(offset); + boolean large = arrayIsLarge(header); + int count = large ? slice.getInt(offset + 1) : (slice.getByte(offset + 1) & 0xFF); + + int offsetSize = arrayFieldOffsetSize(header); + int offsetsStart = offset + 1 + (large ? 
4 : 1); + int valuesStart = offsetsStart + (count + 1) * offsetSize; + + addInt(count); + for (int index = 0; index < count; index++) { + int elementStart = valuesStart + readOffset(slice, offsetsStart + index * offsetSize, offsetSize); + hashVariant(metadata, slice, elementStart); + } + } + + long finish() + { + long h = hash + totalLength; + + h ^= h >>> 33; + h *= PRIME64_2; + h ^= h >>> 29; + h *= PRIME64_3; + h ^= h >>> 32; + return h; + } + + private static long round(long value) + { + return Long.rotateLeft(value * PRIME64_2, 31) * PRIME64_1; + } + } + + private enum ValueClass + { + NULL(0), + BOOLEAN(1), + NUMERIC(2), + DATE(5), + TIME_NTZ(6), + TIMESTAMP_UTC(7), + TIMESTAMP_NTZ(8), + BINARY(9), + STRING(10), + UUID(11), + OBJECT(12), + ARRAY(13); + + private final int hashTag; + + ValueClass(int hashTag) + { + this.hashTag = hashTag; + } + + private int hashTag() + { + return hashTag; + } + + private static ValueClass classify(byte header) + { + return switch (getBasicType(header)) { + case PRIMITIVE -> switch (getPrimitiveType(header)) { + case NULL -> NULL; + case BOOLEAN_TRUE, BOOLEAN_FALSE -> BOOLEAN; + case INT8, INT16, INT32, INT64, DECIMAL4, DECIMAL8, DECIMAL16, FLOAT, DOUBLE -> NUMERIC; + case DATE -> DATE; + case TIME_NTZ_MICROS -> TIME_NTZ; + case TIMESTAMP_UTC_MICROS, TIMESTAMP_UTC_NANOS -> TIMESTAMP_UTC; + case TIMESTAMP_NTZ_MICROS, TIMESTAMP_NTZ_NANOS -> TIMESTAMP_NTZ; + case BINARY -> BINARY; + case STRING -> STRING; + case UUID -> UUID; + }; + case SHORT_STRING -> STRING; + case OBJECT -> OBJECT; + case ARRAY -> ARRAY; + }; + } + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/variant/VariantUtils.java b/core/trino-spi/src/main/java/io/trino/spi/variant/VariantUtils.java new file mode 100644 index 000000000000..d469f5848472 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/variant/VariantUtils.java @@ -0,0 +1,440 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.variant; + +import io.airlift.slice.Slice; +import io.airlift.slice.SliceInput; +import io.airlift.slice.SliceOutput; +import io.trino.spi.type.Int128; +import io.trino.spi.type.Int128Math; + +import java.util.Collection; +import java.util.Iterator; +import java.util.function.Supplier; + +import static io.trino.spi.variant.Header.PrimitiveType; +import static io.trino.spi.variant.Header.PrimitiveType.DOUBLE; +import static io.trino.spi.variant.Header.PrimitiveType.FLOAT; +import static io.trino.spi.variant.Header.PrimitiveType.INT16; +import static io.trino.spi.variant.Header.PrimitiveType.INT32; +import static io.trino.spi.variant.Header.PrimitiveType.INT64; +import static io.trino.spi.variant.Header.PrimitiveType.INT8; +import static java.lang.Double.doubleToRawLongBits; + +public final class VariantUtils +{ + private static final Int128[] INT128_POWERS_OF_FIVE = new Int128[39]; + // |value| >= 1e38 cannot be represented by variant DECIMAL(38, s) for any allowed scale s in [0, 38]. 
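+ // (a variant DECIMAL16 carries at most 38 decimal digits of unscaled precision, i.e. values up to 10^38 - 1, so the bound below holds even at scale 0)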
+ private static final double DECIMAL128_MAGNITUDE_UPPER_BOUND = 1.0e38; + private static final Decimal128Canonical ZERO_DECIMAL = new Decimal128Canonical(Int128.ZERO, 0); + + static { + INT128_POWERS_OF_FIVE[0] = Int128.ONE; + for (int i = 1; i < INT128_POWERS_OF_FIVE.length; i++) { + INT128_POWERS_OF_FIVE[i] = Int128Math.multiply(INT128_POWERS_OF_FIVE[i - 1], 5L); + } + } + + // Threshold to switch from linear search to binary search for field indexes + static final int BINARY_SEARCH_THRESHOLD = 64; + + private VariantUtils() {} + + public static void writeOffset(Slice out, int offset, int size, int offsetSize) + { + switch (offsetSize) { + case 1 -> out.setByte(offset, (byte) size); + case 2 -> out.setShort(offset, (short) size); + // little endian + case 3 -> { + out.setByte(offset, (byte) (size & 0xFF)); + out.setByte(offset + 1, (byte) ((size >> 8) & 0xFF)); + out.setByte(offset + 2, (byte) ((size >> 16) & 0xFF)); + } + case 4 -> out.setInt(offset, size); + default -> throw new IllegalArgumentException("Unsupported offset size: " + offsetSize); + } + } + + public static void writeOffset(SliceOutput out, int size, int offsetSize) + { + switch (offsetSize) { + case 1 -> out.writeByte((byte) size); + case 2 -> out.writeShort((short) size); + // little endian + case 3 -> { + out.writeByte((byte) (size & 0xFF)); + out.writeByte((byte) ((size >> 8) & 0xFF)); + out.writeByte((byte) ((size >> 16) & 0xFF)); + } + case 4 -> out.writeInt(size); + default -> throw new IllegalArgumentException("Unsupported offset size: " + offsetSize); + } + } + + public static int readOffset(Slice data, int offset, int size) + { + return switch (size) { + case 1 -> data.getByte(offset) & 0xFF; + case 2 -> data.getShort(offset) & 0xFFFF; + // In all current usages, there is an extra byte at the end, so we can read as int directly + // This method is used for field ids and field offsets. In the case of fieldIds, they are + // always followed by the offsets, which must have at least one zero offset. In the case of + // offsets, the only way to get 3-byte offsets is to have a lot of data. So reading 4 bytes is safe. 
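+ // e.g., little-endian bytes 12 34 56 XX read as the int 0xXX563412; masking with 0xFFFFFF keeps the 3-byte value 0x563412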
+ case 3 -> data.getInt(offset) & 0xFFFFFF; + case 4 -> data.getInt(offset); + default -> throw new IllegalArgumentException("Unsupported offset size: " + size); + }; + } + + public static int readOffset(SliceInput data, int size) + { + return switch (size) { + case 1 -> data.readUnsignedByte(); + case 2 -> data.readUnsignedShort(); + case 3 -> data.readUnsignedShort() | + data.readUnsignedByte() << 16; + case 4 -> data.readInt(); + default -> throw new IllegalArgumentException("Unsupported offset size: " + size); + }; + } + + public static int getOffsetSize(int[] offsets) + { + return getOffsetSize(offsets[offsets.length - 1]); + } + + public static int getOffsetSize(int maxOffset) + { + if (maxOffset > 0xFFFFFF) { + return 4; + } + else if (maxOffset > 0xFFFF) { + return 3; + } + else if (maxOffset > 0xFF) { + return 2; + } + return 1; + } + + public static boolean isSorted(Collection<Slice> fieldNames) + { + // Iceberg does not consider an empty dictionary as sorted + if (fieldNames.isEmpty()) { + return false; + } + + Iterator<Slice> iterator = fieldNames.iterator(); + if (iterator.hasNext()) { + Slice previous = iterator.next(); + while (iterator.hasNext()) { + Slice next = iterator.next(); + if (previous.compareTo(next) > 0) { + return false; + } + previous = next; + } + } + return true; + } + + public static int findFieldIndex(Slice fieldName, Metadata metadata, Slice object, int fieldCount, int fieldIdsOffset, int fieldIdSize) + { + if (fieldCount > BINARY_SEARCH_THRESHOLD) { + return binarySearchFieldIndexes(fieldName, metadata, object, fieldCount, fieldIdsOffset, fieldIdSize); + } + + // linear search + for (int fieldIndex = 0; fieldIndex < fieldCount; fieldIndex++) { + if (metadata.get(readOffset(object, fieldIdsOffset + fieldIndex * fieldIdSize, fieldIdSize)).equals(fieldName)) { + return fieldIndex; + } + } + return -1; + } + + private static int binarySearchFieldIndexes(Slice fieldName, Metadata metadata, Slice object, int fieldCount, int fieldIdsOffset, int fieldIdSize) + { + // fieldIds are sorted, use binary search + int low = 0; + int high = fieldCount - 1; + while (low <= high) { + int mid = (low + high) >>> 1; + int midFieldId = readOffset(object, fieldIdsOffset + mid * fieldIdSize, fieldIdSize); + int compare = metadata.get(midFieldId).compareTo(fieldName); + if (compare < 0) { + low = mid + 1; + } + else if (compare > 0) { + high = mid - 1; + } + else { + return mid; + } + } + return -1; + } + + public static boolean equals( + Metadata leftMetadata, + Slice leftSlice, + int leftOffset, + Metadata rightMetadata, + Slice rightSlice, + int rightOffset) + { + return VariantEquality.equals(leftMetadata, leftSlice, leftOffset, rightMetadata, rightSlice, rightOffset); + } + + static boolean isExactNumeric(PrimitiveType primitiveType) + { + return switch (primitiveType) { + case INT8, INT16, INT32, INT64, DECIMAL4, DECIMAL8, DECIMAL16 -> true; + default -> false; + }; + } + + static boolean isFloatingNumeric(PrimitiveType primitiveType) + { + return primitiveType == FLOAT || primitiveType == DOUBLE; + } + + static double floatingAsDouble(PrimitiveType primitiveType, Slice slice, int offset) + { + return switch (primitiveType) { + case FLOAT -> slice.getFloat(offset + 1); + case DOUBLE -> slice.getDouble(offset + 1); + default -> throw new VerifyException("Unexpected floating type: " + primitiveType); + }; + } + + static Decimal128Canonical decodeExactNumericCanonical(PrimitiveType primitiveType, Slice slice, int offset) + { + return switch (primitiveType) { + case INT8 -> new
Decimal128Canonical(Int128.valueOf(slice.getByte(offset + 1)), 0); + case INT16 -> new Decimal128Canonical(Int128.valueOf(slice.getShort(offset + 1)), 0); + case INT32 -> new Decimal128Canonical(Int128.valueOf(slice.getInt(offset + 1)), 0); + case INT64 -> new Decimal128Canonical(Int128.valueOf(slice.getLong(offset + 1)), 0); + case DECIMAL4 -> canonicalizeDecimal(slice.getInt(offset + 2), slice.getByte(offset + 1) & 0xFF); + case DECIMAL8 -> canonicalizeDecimal(slice.getLong(offset + 2), slice.getByte(offset + 1) & 0xFF); + case DECIMAL16 -> canonicalizeDecimal( + Int128.valueOf(slice.getLong(offset + 10), slice.getLong(offset + 2)), + slice.getByte(offset + 1) & 0xFF); + default -> throw new VerifyException("Unexpected exact numeric type: " + primitiveType); + }; + } + + private static Decimal128Canonical canonicalizeDecimal(long unscaled, int scale) + { + if (unscaled == 0) { + return ZERO_DECIMAL; + } + while (scale > 0 && unscaled % 10 == 0) { + unscaled /= 10; + scale--; + } + return new Decimal128Canonical(Int128.valueOf(unscaled), scale); + } + + private static Decimal128Canonical canonicalizeDecimal(Int128 unscaled, int scale) + { + if (unscaled.isZero()) { + return ZERO_DECIMAL; + } + if (fitsInLong(unscaled)) { + return canonicalizeDecimal(unscaled.getLow(), scale); + } + + long currentHigh = unscaled.getHigh(); + long currentLow = unscaled.getLow(); + long[] quotient = new long[2]; + long[] product = new long[2]; + while (scale > 0) { + if (!tryDivideByTenExactly(currentHigh, currentLow, quotient, product)) { + break; + } + currentHigh = quotient[0]; + currentLow = quotient[1]; + scale--; + } + return new Decimal128Canonical(Int128.valueOf(currentHigh, currentLow), scale); + } + + static Decimal128Canonical tryToDecimal128Exact(double value) + { + if (value == 0.0) { + // Handles both +0.0 and -0.0 and avoids zero-specific exponent/significand corner cases. + return ZERO_DECIMAL; + } + if (!Double.isFinite(value) || !canFitInVariantDecimal128(value)) { + return null; + } + + // Decode IEEE-754 double bit fields: + // sign bit (bit 63), exponent field (bits 62..52), fraction field (bits 51..0). + long bits = doubleToRawLongBits(value); + boolean negative = bits < 0; + int exponentBits = (int) ((bits >>> 52) & 0x7FFL); + long fractionBits = bits & ((1L << 52) - 1); + + // Re-express the value as: signedInteger * 2^binaryExponent, where signedInteger is an exact integer. + long significand; + int binaryExponent; + if (exponentBits == 0) { + // Subnormal numbers have no implicit leading 1 bit. + significand = fractionBits; + binaryExponent = -1074; + } + else { + // Normalized numbers have an implicit leading 1 bit before the fraction bits. + significand = (1L << 52) | fractionBits; + binaryExponent = exponentBits - 1075; + } + + // Pull out all factors of 2 from the integer part. + // This minimizes the eventual decimal scale because every removed factor of 2 lets us + // replace one "divide by 2" in the exponent with one "multiply by 5" in the unscaled value: + // (integer * 2^-k) == ((integer / 2^t) * 5^(k-t)) * 10^-(k-t), for maximal t. + int trailingZeroBits = Long.numberOfTrailingZeros(significand); + significand >>>= trailingZeroBits; + binaryExponent += trailingZeroBits; + + Int128 unscaled; + int scale; + if (binaryExponent >= 0) { + // Value is already an integer after applying the remaining power-of-two exponent: + // value = significand * 2^binaryExponent, so decimal scale is 0. 
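+ // e.g., 48.0 decomposes to significand=3, binaryExponent=4 after stripping trailing zero bits, yielding unscaled = 3 << 4 = 48 with scale 0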
+ scale = 0; + unscaled = tryShiftLeftToPositiveInt128(significand, binaryExponent); + if (unscaled == null) { + return null; + } + } + else { + // We still have a negative binary exponent: value = significand / 2^(-binaryExponent). + // Convert to decimal form by multiplying by 5^scale and setting decimal scale=scale: + // significand / 2^scale == (significand * 5^scale) / 10^scale. + scale = -binaryExponent; + if (scale > 38) { + return null; + } + try { + unscaled = Int128Math.multiply(Int128.valueOf(significand), INT128_POWERS_OF_FIVE[scale]); + } + catch (ArithmeticException ignored) { + return null; + } + } + + if (negative) { + unscaled = Int128Math.negate(unscaled); + } + return canonicalizeDecimal(unscaled, scale); + } + + private static boolean canFitInVariantDecimal128(double value) + { + return Math.abs(value) < DECIMAL128_MAGNITUDE_UPPER_BOUND; + } + + private static boolean tryDivideByTenExactly(long high, long low, long[] quotient, long[] product) + { + // Compute quotient once, then verify exactness via multiply-back. + Int128Math.rescaleTruncate(high, low, -1, quotient, 0); + Int128Math.multiply(quotient[0], quotient[1], 0, 10L, product, 0); + return product[0] == high && product[1] == low; + } + + private static Int128 tryShiftLeftToPositiveInt128(long value, int shift) + { + if (shift >= 128) { + return null; + } + + long high; + long low; + if (shift == 0) { + high = 0; + low = value; + } + else if (shift < 64) { + high = value >>> (64 - shift); + low = value << shift; + } + else { + high = value << (shift - 64); + low = 0; + } + + // Positive Int128 values must keep the sign bit (bit 127) clear. + if (high < 0) { + return null; + } + return Int128.valueOf(high, low); + } + + static boolean fitsInLong(Int128 value) + { + return value.getHigh() == (value.getLow() >> 63); + } + + static record Decimal128Canonical(Int128 unscaled, int scale) + { + Decimal128Canonical + { + checkArgument(scale >= 0 && scale <= 38, () -> "Invalid decimal scale: %s".formatted(scale)); + } + } + + public static long hashCode(Metadata metadata, Slice slice, int offset) + { + return VariantHashing.hashCode(metadata, slice, offset); + } + + static void checkArgument(boolean test, String message) + { + if (!test) { + throw new IllegalArgumentException(message); + } + } + + static void checkArgument(boolean test, Supplier<String> message) + { + if (!test) { + throw new IllegalArgumentException(message.get()); + } + } + + static void checkState(boolean test, String message) + { + if (!test) { + throw new IllegalStateException(message); + } + } + + static void checkState(boolean test, Supplier<String> message) + { + if (!test) { + throw new IllegalStateException(message.get()); + } + } + + static void verify(boolean test, String message) + { + if (!test) { + throw new VerifyException(message); + } + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/variant/VerifyException.java b/core/trino-spi/src/main/java/io/trino/spi/variant/VerifyException.java new file mode 100644 index 000000000000..0c263f14f7a7 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/variant/VerifyException.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.variant; + +class VerifyException + extends RuntimeException +{ + VerifyException(String message) + { + super(message); + } +} diff --git a/core/trino-spi/src/main/java/module-info.java b/core/trino-spi/src/main/java/module-info.java index 10f9218e0934..270445b97474 100644 --- a/core/trino-spi/src/main/java/module-info.java +++ b/core/trino-spi/src/main/java/module-info.java @@ -39,4 +39,5 @@ exports io.trino.spi.statistics; exports io.trino.spi.transaction; exports io.trino.spi.type; + exports io.trino.spi.variant; } diff --git a/core/trino-spi/src/test/java/io/trino/spi/block/TestVariantBlockBuilder.java b/core/trino-spi/src/test/java/io/trino/spi/block/TestVariantBlockBuilder.java new file mode 100644 index 000000000000..3a08b69602ee --- /dev/null +++ b/core/trino-spi/src/test/java/io/trino/spi/block/TestVariantBlockBuilder.java @@ -0,0 +1,167 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.block; + +import io.trino.spi.variant.Header; +import io.trino.spi.variant.Variant; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; + +final class TestVariantBlockBuilder + extends AbstractTestBlockBuilder<Variant> +{ + private static final int VARIANT_ENTRY_SIZE = Integer.BYTES + Byte.BYTES; + private static final int NULL_VARIANT_FIELD_ENTRY_SIZE = (Integer.BYTES + Byte.BYTES) * 2; + + @Test + public void testAppendRangeUpdatesStatus() + { + int length = 4; + VariantBlock source = createAllNullSourceBlock(length + 1); + PageBuilderStatus pageBuilderStatus = new PageBuilderStatus(Integer.MAX_VALUE); + VariantBlockBuilder blockBuilder = new VariantBlockBuilder(pageBuilderStatus.createBlockBuilderStatus(), 1); + + blockBuilder.appendRange(source, 1, length); + + assertThat(pageBuilderStatus.getSizeInBytes()) + .isEqualTo((long) length * (VARIANT_ENTRY_SIZE + NULL_VARIANT_FIELD_ENTRY_SIZE)); + } + + @Test + public void testAppendPositionsUpdatesStatusWithOffsetSource() + { + int length = 3; + VariantBlock source = createAllNullSourceBlock(length + 3).getRegion(1, length + 1); + PageBuilderStatus pageBuilderStatus = new PageBuilderStatus(Integer.MAX_VALUE); + VariantBlockBuilder blockBuilder = new VariantBlockBuilder(pageBuilderStatus.createBlockBuilderStatus(), 1); + + // appendPositions reads only the three entries starting at offset 1, so the -1 padding is never accessed.
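+ // positions {3, 0, 2} reference null entries in the all-null source, so this exercises the null-size accounting path of the builder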
+ blockBuilder.appendPositions(source, new int[] {-1, 3, 0, 2, -1}, 1, length); + + assertThat(pageBuilderStatus.getSizeInBytes()) + .isEqualTo((long) length * (VARIANT_ENTRY_SIZE + NULL_VARIANT_FIELD_ENTRY_SIZE)); + } + + @Test + public void testResetToRecomputesNullFlags() + { + VariantBlockBuilder blockBuilder = new VariantBlockBuilder(null, 2); + blockBuilder.appendNull(); + blockBuilder.resetTo(0); + blockBuilder.writeEntry(Variant.ofLong(1)); + + VariantBlock block = blockBuilder.buildValueBlock(); + + assertThat(block.mayHaveNull()).isFalse(); + assertThat(block.isNull(0)).isFalse(); + } + + @Test + public void testGetBasicTypeWithDictionaryBackedFields() + { + VariantBlockBuilder sourceBuilder = new VariantBlockBuilder(null, 2); + sourceBuilder.writeEntry(Variant.ofBoolean(true)); + sourceBuilder.writeEntry(Variant.ofString("abc")); + VariantBlock source = sourceBuilder.buildValueBlock(); + + int[] ids = {1, 0}; + Block metadata = DictionaryBlock.create(ids.length, source.getRawMetadata(), ids); + Block values = DictionaryBlock.create(ids.length, source.getRawValues(), ids); + VariantBlock dictionaryBlock = VariantBlock.create(ids.length, metadata, values, Optional.empty()); + + assertThat(dictionaryBlock.getBasicType(0)).isEqualTo(Header.BasicType.SHORT_STRING); + assertThat(dictionaryBlock.getBasicType(1)).isEqualTo(Header.BasicType.PRIMITIVE); + } + + @Test + public void testCopyWithAppendedNullOnSlicedBlock() + { + VariantBlockBuilder blockBuilder = new VariantBlockBuilder(null, 4); + blockBuilder.writeEntry(Variant.ofLong(10)); + blockBuilder.writeEntry(Variant.ofLong(20)); + blockBuilder.writeEntry(Variant.ofLong(30)); + blockBuilder.writeEntry(Variant.ofLong(40)); + + VariantBlock sliced = blockBuilder.buildValueBlock().getRegion(1, 2); + VariantBlock appended = sliced.copyWithAppendedNull(); + + assertThat(appended.getPositionCount()).isEqualTo(3); + assertThat(appended.getVariant(0)).isEqualTo(Variant.ofLong(20)); + assertThat(appended.getVariant(1)).isEqualTo(Variant.ofLong(30)); + assertThat(appended.isNull(2)).isTrue(); + } + + @Override + protected BlockBuilder createBlockBuilder() + { + return new VariantBlockBuilder(null, 1); + } + + @Override + protected List<Variant> getTestValues() + { + return List.of(Variant.ofBoolean(true), Variant.ofByte((byte) 90), Variant.ofInt(91), Variant.ofDouble(92.12), Variant.ofString("ninety three")); + } + + @Override + protected Variant getUnusedTestValue() + { + return Variant.ofString("unused value"); + } + + @Override + protected ValueBlock blockFromValues(Iterable<Variant> values) + { + VariantBlockBuilder blockBuilder = new VariantBlockBuilder(null, 1); + for (Variant value : values) { + if (value == null) { + blockBuilder.appendNull(); + } + else { + blockBuilder.writeEntry(value); + } + } + return blockBuilder.buildValueBlock(); + } + + @Override + protected List<Variant> blockToValues(ValueBlock valueBlock) + { + VariantBlock variantBlock = (VariantBlock) valueBlock; + List<Variant> actualValues = new ArrayList<>(variantBlock.getPositionCount()); + for (int i = 0; i < variantBlock.getPositionCount(); i++) { + if (variantBlock.isNull(i)) { + actualValues.add(null); + } + else { + actualValues.add(variantBlock.getVariant(i)); + } + } + return actualValues; + } + + private static VariantBlock createAllNullSourceBlock(int positions) + { + VariantBlockBuilder blockBuilder = new VariantBlockBuilder(null, 1); + for (int i = 0; i < positions; i++) { + blockBuilder.appendNull(); + } + return blockBuilder.buildValueBlock(); + } +} diff --git
a/core/trino-spi/src/test/java/io/trino/spi/block/TestVariantBlockEncoding.java b/core/trino-spi/src/test/java/io/trino/spi/block/TestVariantBlockEncoding.java new file mode 100644 index 000000000000..eba53295763e --- /dev/null +++ b/core/trino-spi/src/test/java/io/trino/spi/block/TestVariantBlockEncoding.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.block; + +import io.trino.spi.type.Type; +import io.trino.spi.variant.Variant; + +import java.util.Random; + +import static io.trino.spi.type.VariantType.VARIANT; + +final class TestVariantBlockEncoding + extends BaseBlockEncodingTest<Variant> +{ + @Override + protected Type getType() + { + return VARIANT; + } + + @Override + protected void write(BlockBuilder blockBuilder, Variant value) + { + VARIANT.writeObject(blockBuilder, value); + } + + @Override + protected Variant randomValue(Random random) + { + return Variant.ofLong(random.nextLong()); + } +} diff --git a/core/trino-spi/src/test/java/io/trino/spi/block/TestingBlockEncodingSerde.java b/core/trino-spi/src/test/java/io/trino/spi/block/TestingBlockEncodingSerde.java index bb20ddc23af5..11ea4134faf1 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/block/TestingBlockEncodingSerde.java +++ b/core/trino-spi/src/test/java/io/trino/spi/block/TestingBlockEncodingSerde.java @@ -57,6 +57,7 @@ public TestingBlockEncodingSerde(Function types) addBlockEncoding(new LongArrayBlockEncoding(true, true, true)); addBlockEncoding(new Fixed12BlockEncoding(true)); addBlockEncoding(new Int128ArrayBlockEncoding(true)); + addBlockEncoding(new VariantBlockEncoding(true)); addBlockEncoding(new DictionaryBlockEncoding()); addBlockEncoding(new ArrayBlockEncoding(true)); addBlockEncoding(new MapBlockEncoding(true)); diff --git a/core/trino-spi/src/test/java/io/trino/spi/variant/TestMetadata.java b/core/trino-spi/src/test/java/io/trino/spi/variant/TestMetadata.java new file mode 100644 index 000000000000..7455470b0f3e --- /dev/null +++ b/core/trino-spi/src/test/java/io/trino/spi/variant/TestMetadata.java @@ -0,0 +1,150 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
*/ +package io.trino.spi.variant; + +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import org.apache.iceberg.variants.VariantMetadata; +import org.apache.iceberg.variants.Variants; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.IntStream; + +import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.slice.Slices.wrappedBuffer; +import static io.trino.spi.variant.Header.metadataIsSorted; +import static io.trino.spi.variant.Metadata.EMPTY_METADATA; +import static io.trino.spi.variant.Metadata.EMPTY_METADATA_SLICE; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class TestMetadata +{ + // from the Iceberg implementation, an empty metadata has 3 bytes: header + offsetCount + 0th offset + private static final Slice SERIALIZED_EMPTY_METADATA = wrappedBuffer(new byte[] {0x01, 0x00, 0x00}); + + @Test + void testEmpty() + { + assertThat(EMPTY_METADATA_SLICE).isEqualTo(SERIALIZED_EMPTY_METADATA); + + assertThat(Metadata.of(List.of())).isSameAs(EMPTY_METADATA); + assertThat(EMPTY_METADATA.dictionarySize()).isEqualTo(0); + assertThat(EMPTY_METADATA.toSlice()).isSameAs(EMPTY_METADATA_SLICE); + assertThat(EMPTY_METADATA.id(Slices.utf8Slice("non_existent_field"))).isEqualTo(-1); + assertThatThrownBy(() -> EMPTY_METADATA.get(0)).isInstanceOf(IndexOutOfBoundsException.class); + + EMPTY_METADATA.validateFully(); + + // double check compatibility with Iceberg serialization + assertThat(serializeIcebergVariantMetadata(Variants.emptyMetadata())).isEqualTo(SERIALIZED_EMPTY_METADATA); + } + + @Test + void testSortedFields() + { + // one byte offsets (total size under 256 bytes) + assertMetadata(List.of(Slices.utf8Slice("apple"), Slices.utf8Slice("banana"), Slices.utf8Slice("cherry")), true); + // two byte offsets (total size between 256 and 65,535 bytes) + assertMetadata(generateFieldNames(100), true); + // three byte offsets (total size between 65,536 and 16,777,215 bytes) + assertMetadata(generateFieldNames(10_000), true); + // four byte offsets (total size over 16,777,215 bytes) + assertMetadata(generateFieldNames(5_000_000), true); + } + + private static List<Slice> generateFieldNames(int endExclusive) + { + return IntStream.range(0, endExclusive).mapToObj(i -> Slices.utf8Slice("field_%7d".formatted(i))).toList(); + } + + @Test + void testUnorderedFields() + { + assertMetadata(List.of(Slices.utf8Slice("banana"), Slices.utf8Slice("cherry"), Slices.utf8Slice("apple")), false); + List<Slice> fields = Arrays.asList(IntStream.range(0, 1000).mapToObj(i -> Slices.utf8Slice("field_" + i)).toArray(Slice[]::new)); + Collections.shuffle(fields); + assertMetadata(fields, false); + } + + @Test + void testDuplicateFieldsRejected() + { + assertThatThrownBy(() -> Metadata.of(List.of(utf8Slice("apple"), utf8Slice("apple")))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("duplicate field name: apple"); + } + + @Test + void testEmptyFieldRejected() + { + assertThatThrownBy(() -> Metadata.of(List.of(utf8Slice("")))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("empty field names are not allowed"); + } + + private static void assertMetadata(List<Slice> fieldNames, boolean expectedSorted) + { + Metadata metadata = Metadata.of(fieldNames); + metadata.validateFully(); +
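// the sorted flag is a single bit in the first byte of the serialized metadata header +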
assertThat(metadataIsSorted(metadata.toSlice().getByte(0))).isEqualTo(expectedSorted); + assertThat(metadata.dictionarySize()).isEqualTo(fieldNames.size()); + + // verify lookups by index + for (int i = 0; i < metadata.dictionarySize(); i++) { + assertThat(metadata.get(i)).isEqualTo(fieldNames.get(i)); + } + + // verify lookups by name + if (fieldNames.size() <= 500) { + for (int i = 0; i < fieldNames.size(); i++) { + assertThat(metadata.id(fieldNames.get(i))).isEqualTo(i); + } + } + else { + // for large dictionaries, select random lookups to verify + ThreadLocalRandom.current().ints(500, 0, fieldNames.size()) + .forEach(i -> assertThat(metadata.id(fieldNames.get(i))).isEqualTo(i)); + } + assertThat(metadata.id(Slices.utf8Slice(fieldNames.getFirst().toStringUtf8() + " non_existent_field"))).isEqualTo(-1); + + // verify direct construction creates same metadata + Metadata actual = Metadata.of(fieldNames); + assertThat(actual).isEqualTo(metadata); + + // verify serialization + assertThat(metadata.toSlice()).isEqualTo(serializeIcebergVariantMetadata(fieldNames)); + } + + private static Slice serializeIcebergVariantMetadata(List<Slice> fieldNames) + { + var metadata = Variants.metadata(fieldNames.stream().map(Slice::toStringUtf8).toList()); + return serializeIcebergVariantMetadata(metadata); + } + + private static Slice serializeIcebergVariantMetadata(VariantMetadata metadata) + { + int size = metadata.sizeInBytes(); + byte[] array = new byte[size]; + ByteBuffer valueBuf = ByteBuffer.wrap(array).order(ByteOrder.LITTLE_ENDIAN); + metadata.writeTo(valueBuf, 0); + return wrappedBuffer(array); + } +} diff --git a/core/trino-spi/src/test/java/io/trino/spi/variant/TestVariant.java b/core/trino-spi/src/test/java/io/trino/spi/variant/TestVariant.java new file mode 100644 index 000000000000..d0feb8332413 --- /dev/null +++ b/core/trino-spi/src/test/java/io/trino/spi/variant/TestVariant.java @@ -0,0 +1,1488 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package io.trino.spi.variant; + +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import org.apache.iceberg.variants.ValueArray; +import org.apache.iceberg.variants.VariantMetadata; +import org.apache.iceberg.variants.VariantObject; +import org.apache.iceberg.variants.VariantValue; +import org.apache.iceberg.variants.Variants; +import org.junit.jupiter.api.Test; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.math.RoundingMode; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneOffset; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.IntStream; + +import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.slice.Slices.wrappedBuffer; +import static io.trino.spi.variant.Metadata.EMPTY_METADATA; +import static io.trino.spi.variant.VariantEncoder.ENCODED_DECIMAL16_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_DECIMAL4_SIZE; +import static io.trino.spi.variant.VariantEncoder.ENCODED_DECIMAL8_SIZE; +import static io.trino.spi.variant.VariantEncoder.encodeDecimal16; +import static io.trino.spi.variant.VariantEncoder.encodeDecimal4; +import static io.trino.spi.variant.VariantEncoder.encodeDecimal8; +import static io.trino.spi.variant.VariantEncoder.encodeObject; +import static io.trino.spi.variant.VariantEncoder.encodedArraySize; +import static io.trino.spi.variant.VariantEncoder.encodedObjectSize; +import static io.trino.spi.variant.VariantUtils.checkArgument; +import static io.trino.spi.variant.VariantUtils.getOffsetSize; +import static io.trino.spi.variant.VariantUtils.verify; +import static io.trino.spi.variant.VariantUtils.writeOffset; +import static java.lang.Math.toIntExact; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class TestVariant +{ + @Test + void testNullVariant() + { + assertPrimitiveEncoding( + 1, + null, + _ -> Variant.NULL_VALUE, + variant -> { + assertThat(variant.isNull()).isTrue(); + return null; + }, + (_, variant, offset) -> VariantEncoder.encodeNull(variant, offset), + serializeIcebergVariant(Variants.ofNull())); + } + + @Test + void testBoolean() + { + assertPrimitiveEncoding( + 1, + true, + Variant::ofBoolean, + Variant::getBoolean, + VariantEncoder::encodeBoolean, + serializeIcebergVariant(Variants.of(true))); + assertPrimitiveEncoding( + 1, + false, + Variant::ofBoolean, + Variant::getBoolean, + VariantEncoder::encodeBoolean, + serializeIcebergVariant(Variants.of(false))); + + assertThat(Variant.ofBoolean(true)).isNotEqualTo(Variant.ofBoolean(false)); + assertThat(Variant.ofBoolean(true).longHashCode()).isNotEqualTo(Variant.ofBoolean(false).longHashCode()); + } + + @Test + void testByte() + { + for (byte value : new byte[] {Byte.MIN_VALUE, -42, -1, 0, 42, 1, Byte.MAX_VALUE}) { + assertPrimitiveEncoding( + 2, + value, + Variant::ofByte, + Variant::getByte, + VariantEncoder::encodeByte, + serializeIcebergVariant(Variants.of(value))); + + assertEqualAndSameHash(Variant.ofByte(value), Variant.ofByte(value)); + assertEqualAndSameHash(Variant.ofByte(value), Variant.ofShort(value)); + 
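+            // the same value must also compare equal (and hash equal) against the wider integer encodings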
assertEqualAndSameHash(Variant.ofByte(value), Variant.ofInt(value)); + assertEqualAndSameHash(Variant.ofByte(value), Variant.ofLong(value)); + assertEqualAndSameHash(Variant.ofByte(value), BigDecimal.valueOf(value)); + } + assertThat(Variant.ofByte((byte) 0x12)).isNotEqualTo(Variant.ofByte((byte) 0x34)); + assertThat(Variant.ofByte((byte) 0x12).longHashCode()).isNotEqualTo(Variant.ofByte((byte) 0x34).longHashCode()); + } + + @Test + void testShort() + { + for (short value : new short[] {Short.MIN_VALUE, Byte.MIN_VALUE, -42, -1, 0, 42, 1, Byte.MAX_VALUE, Short.MAX_VALUE}) { + assertPrimitiveEncoding( + 3, + value, + Variant::ofShort, + Variant::getShort, + VariantEncoder::encodeShort, + serializeIcebergVariant(Variants.of(value))); + + if ((byte) value == value) { + assertEqualAndSameHash(Variant.ofShort(value), Variant.ofByte((byte) value)); + } + else { + assertNotEqualAndDifferentHash(Variant.ofShort(value), Variant.ofByte((byte) value)); + } + assertEqualAndSameHash(Variant.ofShort(value), Variant.ofShort(value)); + assertEqualAndSameHash(Variant.ofShort(value), Variant.ofInt(value)); + assertEqualAndSameHash(Variant.ofShort(value), Variant.ofLong(value)); + assertEqualAndSameHash(Variant.ofShort(value), BigDecimal.valueOf(value)); + } + + assertThat(Variant.ofShort((short) 0x1234)).isNotEqualTo(Variant.ofShort((short) 0x5678)); + assertThat(Variant.ofShort((short) 0x1234).longHashCode()).isNotEqualTo(Variant.ofShort((short) 0x5678).longHashCode()); + } + + @Test + void testInt() + { + for (int value : new int[] {Integer.MIN_VALUE, Short.MIN_VALUE, Byte.MIN_VALUE, -42, -1, 0, 42, 1, Byte.MAX_VALUE, Short.MAX_VALUE, Integer.MAX_VALUE}) { + assertPrimitiveEncoding( + 5, + value, + Variant::ofInt, + Variant::getInt, + VariantEncoder::encodeInt, + serializeIcebergVariant(Variants.of(value))); + + if ((byte) value == value) { + assertEqualAndSameHash(Variant.ofInt(value), Variant.ofByte((byte) value)); + } + else { + assertNotEqualAndDifferentHash(Variant.ofInt(value), Variant.ofByte((byte) value)); + } + if ((short) value == value) { + assertEqualAndSameHash(Variant.ofInt(value), Variant.ofShort((short) value)); + } + else { + assertNotEqualAndDifferentHash(Variant.ofInt(value), Variant.ofShort((short) value)); + } + assertEqualAndSameHash(Variant.ofInt(value), Variant.ofInt(value)); + assertEqualAndSameHash(Variant.ofInt(value), Variant.ofLong(value)); + assertEqualAndSameHash(Variant.ofInt(value), BigDecimal.valueOf(value)); + } + + assertThat(Variant.ofInt(0x12345678)).isNotEqualTo(Variant.ofInt(0x9ABCDEF0)); + assertThat(Variant.ofInt(0x12345678).longHashCode()).isNotEqualTo(Variant.ofInt(0x9ABCDEF0).longHashCode()); + } + + @Test + void testLong() + { + assertPrimitiveEncoding( + 9, + 0x11223344_55667788L, + Variant::ofLong, + Variant::getLong, + VariantEncoder::encodeLong, + serializeIcebergVariant(Variants.of(0x11223344_55667788L))); + + for (long value : new long[] { + Long.MIN_VALUE, + Integer.MIN_VALUE, + Short.MIN_VALUE, + Byte.MIN_VALUE, + 42L, + -1L, + 0L, + -42L, + 1L, + Byte.MAX_VALUE, + Short.MAX_VALUE, + Integer.MAX_VALUE, + Long.MAX_VALUE}) { + assertPrimitiveEncoding( + 9, + value, + Variant::ofLong, + Variant::getLong, + VariantEncoder::encodeLong, + serializeIcebergVariant(Variants.of(value))); + + if ((byte) value == value) { + assertEqualAndSameHash(Variant.ofLong(value), Variant.ofByte((byte) value)); + } + else { + assertNotEqualAndDifferentHash(Variant.ofLong(value), Variant.ofByte((byte) value)); + } + if ((short) value == value) { + 
assertEqualAndSameHash(Variant.ofLong(value), Variant.ofShort((short) value));
+            }
+            else {
+                assertNotEqualAndDifferentHash(Variant.ofLong(value), Variant.ofShort((short) value));
+            }
+            if ((int) value == value) {
+                assertEqualAndSameHash(Variant.ofLong(value), Variant.ofInt((int) value));
+            }
+            else {
+                assertNotEqualAndDifferentHash(Variant.ofLong(value), Variant.ofInt((int) value));
+            }
+            assertEqualAndSameHash(Variant.ofLong(value), Variant.ofLong(value));
+            assertEqualAndSameHash(Variant.ofLong(value), BigDecimal.valueOf(value));
+        }
+
+        assertThat(Variant.ofLong(0x12345678_9ABCDEF0L)).isNotEqualTo(Variant.ofLong(42));
+        assertThat(Variant.ofLong(0x12345678_9ABCDEF0L).longHashCode()).isNotEqualTo(Variant.ofLong(42).longHashCode());
+    }
+
+    @Test
+    void testFieldRemapperResizesNestedObjectWhenGlobalFieldIdWidthStaysWide()
+    {
+        List<Slice> fieldNames = new ArrayList<>();
+        List<Slice> arrayElements = new ArrayList<>();
+        for (int fieldId = 0; fieldId < 256; fieldId++) {
+            fieldNames.add(utf8Slice("m%03d".formatted(fieldId)));
+            arrayElements.add(encodeObjectWithSortedFields(List.of(new ObjectField(fieldId, Variant.ofInt(fieldId)))));
+        }
+        fieldNames.add(utf8Slice("a0"));
+        fieldNames.add(utf8Slice("a1"));
+        arrayElements.add(encodeObjectWithSortedFields(List.of(
+                new ObjectField(256, Variant.ofInt(10_000)),
+                new ObjectField(257, Variant.ofInt(20_000)))));
+
+        Metadata metadata = Metadata.of(fieldNames);
+        Slice arrayData = Slices.allocate(encodedArraySize(arrayElements.size(), arrayElements.stream().mapToInt(Slice::length).sum()));
+        assertThat(VariantEncoder.encodeArray(arrayElements, arrayData, 0)).isEqualTo(arrayData.length());
+
+        Variant variant = Variant.from(metadata, arrayData);
+        Metadata.Builder metadataBuilder = Metadata.builder();
+        VariantFieldRemapper remapper = VariantFieldRemapper.create(variant, metadataBuilder);
+        Metadata.Builder.SortedMetadata sortedMetadata = metadataBuilder.buildSorted();
+        remapper.finalize(sortedMetadata.sortedFieldIdMapping());
+
+        Slice remappedData = Slices.allocate(remapper.size());
+        assertThat(remapper.write(remappedData, 0)).isEqualTo(remappedData.length());
+
+        Variant remapped = Variant.from(sortedMetadata.metadata(), remappedData);
+        assertThat(remapped.getArrayLength()).isEqualTo(arrayElements.size());
+
+        Variant specialObject = remapped.getArrayElement(arrayElements.size() - 1);
+        assertThat(specialObject.getObjectField(utf8Slice("a0"))).hasValueSatisfying(value -> assertThat(value.getInt()).isEqualTo(10_000));
+        assertThat(specialObject.getObjectField(utf8Slice("a1"))).hasValueSatisfying(value -> assertThat(value.getInt()).isEqualTo(20_000));
+    }
+
+    @Test
+    void testDecimal()
+    {
+        for (String string : List.of("0", "1", "-1", "123456789", "123456700", "-123456789", "-123456700")) {
+            BigInteger unscaled = new BigInteger(string);
+            for (int scale = 0; scale <= 9; scale++) {
+                BigDecimal decimal4 = new BigDecimal(unscaled, scale);
+                assertPrimitiveEncoding(
+                        6,
+                        decimal4,
+                        Variant::ofDecimal,
+                        Variant::getDecimal,
+                        (BigDecimal value, Slice variant, int offset) -> encodeDecimal4(value.unscaledValue().intValueExact(), value.scale(), variant, offset),
+                        serializeIcebergVariant(Variants.of(decimal4)));
+                assertEqualAndSameHash(Variant.ofDecimal(decimal4), decimal4);
+                BigDecimal stripTrailingZeros = decimal4.stripTrailingZeros();
+                if (stripTrailingZeros.scale() < 0) {
+                    stripTrailingZeros = stripTrailingZeros.setScale(0, RoundingMode.UNNECESSARY);
+                }
+                if (decimal4.scale() != stripTrailingZeros.scale()) {
assertEqualAndSameHash(Variant.ofDecimal(decimal4), stripTrailingZeros); + } + } + } + + for (String string : List.of("1234567890", "123456789012345678", "123456789012345600", "-1234567890", "-123456789012345678", "-123456789012345600")) { + BigInteger unscaled = new BigInteger(string); + for (int scale = 0; scale <= 18; scale++) { + BigDecimal decimal8 = new BigDecimal(unscaled, scale); + assertPrimitiveEncoding( + 10, + decimal8, + Variant::ofDecimal, + Variant::getDecimal, + (BigDecimal decimal, Slice slice, int offset) -> encodeDecimal8(decimal.unscaledValue().longValueExact(), decimal.scale(), slice, offset), + serializeIcebergVariant(Variants.of(decimal8))); + assertEqualAndSameHash(Variant.ofDecimal(decimal8), decimal8); + BigDecimal stripTrailingZeros = decimal8.stripTrailingZeros(); + if (stripTrailingZeros.scale() < 0) { + stripTrailingZeros = stripTrailingZeros.setScale(0, RoundingMode.UNNECESSARY); + } + if (decimal8.scale() != stripTrailingZeros.scale()) { + assertEqualAndSameHash(Variant.ofDecimal(decimal8), stripTrailingZeros); + } + } + } + + for (String string : List.of( + "1234567890123456789", + "12345678901234567890123456789012345678", + "12345678901234567890123456789012345600", + "-1234567890123456789", + "-12345678901234567890123456789012345678", + "-12345678901234567890123456789012345600", + "18446744073709551616", + "9223372036854775808", + "-9223372036854775809")) { + BigInteger unscaled = new BigInteger(string); + for (int scale = 0; scale <= 38; scale++) { + BigDecimal decimal16 = new BigDecimal(unscaled, scale); + assertPrimitiveEncoding( + 18, + decimal16, + Variant::ofDecimal, + Variant::getDecimal, + (BigDecimal decimal, Slice slice, int offset) -> encodeDecimal16(decimal16.unscaledValue(), decimal.scale(), slice, offset), + serializeIcebergVariant(Variants.of(decimal16))); + assertEqualAndSameHash(Variant.ofDecimal(decimal16), decimal16); + BigDecimal stripTrailingZeros = decimal16.stripTrailingZeros(); + if (stripTrailingZeros.scale() < 0) { + stripTrailingZeros = stripTrailingZeros.setScale(0, RoundingMode.UNNECESSARY); + } + if (decimal16.scale() != stripTrailingZeros.scale()) { + assertEqualAndSameHash(Variant.ofDecimal(decimal16), stripTrailingZeros); + } + } + } + + assertNotEqualAndDifferentHash(Variant.ofDecimal(new BigDecimal("1")), new BigDecimal("1e-18")); + assertNotEqualAndDifferentHash(Variant.ofDecimal(new BigDecimal("1")), new BigDecimal("2e-19")); + + assertEqualAndSameHash(Variant.ofDecimal(new BigDecimal("1")), new BigDecimal("1.00")); + assertEqualAndSameHash(Variant.ofDecimal(new BigDecimal("0")), new BigDecimal("0.00")); + + assertNotEqualAndDifferentHash(Variant.ofDecimal(new BigDecimal("123.45")), new BigDecimal("678.9")); + assertNotEqualAndDifferentHash(Variant.ofDecimal(new BigDecimal("123.45")), new BigDecimal("678.90")); + assertNotEqualAndDifferentHash(Variant.ofDecimal(new BigDecimal("123.45")), new BigDecimal("678.900")); + + assertThat(Variant.ofDecimal(new BigDecimal("1e-19"))).isNotEqualTo(Variant.ofDecimal(new BigDecimal("0"))); + } + + @Test + void testFloat() + { + for (float value : new float[] {0.0f, -0.0f, 1.0f, -1.0f, 3.14f, -3.14f, Float.MAX_VALUE, Float.MIN_VALUE, Float.MIN_NORMAL, Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY}) { + assertPrimitiveEncoding( + 5, + value, + Variant::ofFloat, + Variant::getFloat, + VariantEncoder::encodeFloat, + serializeIcebergVariant(Variants.of(value))); + } + + assertEqualAndSameHash(Variant.ofFloat(0.0f), Variant.ofFloat(-0.0f)); + + 
assertNotEqualAndDifferentHash(Variant.ofFloat(1.2345f), Variant.ofFloat(6.7890f)); + + assertThat(Variant.ofFloat(Float.NaN)).isNotEqualTo(Variant.ofFloat(Float.NaN)); + assertThat(Variant.ofFloat(Float.NaN).longHashCode()).isEqualTo(Variant.ofFloat(Float.NaN).longHashCode()); + } + + @Test + void testDouble() + { + for (double value : new double[] {0.0, -0.0, 1.0, -1.0, Math.PI, -Math.PI, Double.MAX_VALUE, Double.MIN_VALUE, Double.MIN_NORMAL, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY}) { + assertPrimitiveEncoding( + 9, + value, + Variant::ofDouble, + Variant::getDouble, + VariantEncoder::encodeDouble, + serializeIcebergVariant(Variants.of(value))); + } + + assertEqualAndSameHash(Variant.ofDouble(0.0), Variant.ofDouble(-0.0)); + + assertNotEqualAndDifferentHash(Variant.ofDouble(1.2345), Variant.ofDouble(6.7890)); + + assertThat(Variant.ofDouble(Double.NaN)).isNotEqualTo(Variant.ofDouble(Double.NaN)); + assertThat(Variant.ofDouble(Double.NaN).longHashCode()).isEqualTo(Variant.ofDouble(Double.NaN).longHashCode()); + } + + @Test + void testDate() + { + for (int date : List.of(0, 1, -1, 36525, -36525, 20000, -20000, Integer.MAX_VALUE, Integer.MIN_VALUE)) { + assertPrimitiveEncoding( + 5, + toIntExact(date), + Variant::ofDate, + Variant::getDate, + VariantEncoder::encodeDate, + serializeIcebergVariant(Variants.ofDate(date))); + assertPrimitiveEncoding( + 5, + LocalDate.ofEpochDay(date), + Variant::ofDate, + Variant::getLocalDate, + VariantEncoder::encodeDate, + serializeIcebergVariant(Variants.ofDate(date))); + } + + assertThat(Variant.ofDate(1000)).isNotEqualTo(Variant.ofDate(2000)); + assertThat(Variant.ofDate(1000).longHashCode()).isNotEqualTo(Variant.ofDate(2000).longHashCode()); + } + + @Test + void testTimeMicros() + { + for (long micros : List.of(0L, 1L, 86399999999L)) { + assertPrimitiveEncoding( + 9, + micros, + Variant::ofTimeMicrosNtz, + Variant::getTimeMicros, + VariantEncoder::encodeTimeMicrosNtz, + serializeIcebergVariant(Variants.ofTime(micros))); + assertPrimitiveEncoding( + 9, + LocalTime.ofNanoOfDay(micros * 1_000L), + Variant::ofTimeMicrosNtz, + Variant::getLocalTime, + VariantEncoder::encodeTimeMicrosNtz, + serializeIcebergVariant(Variants.ofTime(micros))); + } + + assertNotEqualAndDifferentHash(Variant.ofTimeMicrosNtz(86399999999L), Variant.ofTimeMicrosNtz(42L)); + } + + @Test + void testTimestampMicros() + { + for (long micros : List.of(0L, 1L, -1L, 1625079045123456L, -1625079045123456L, Long.MAX_VALUE, Long.MIN_VALUE)) { + long seconds = Math.floorDiv(micros, 1_000_000); + int nanoOfSecond = toIntExact(Math.floorMod(micros, 1_000_000) * 1_000L); + + assertPrimitiveEncoding( + 9, + micros, + Variant::ofTimestampMicrosUtc, + Variant::getTimestampMicros, + VariantEncoder::encodeTimestampMicrosUtc, + serializeIcebergVariant(Variants.ofTimestamptz(micros))); + assertPrimitiveEncoding( + 9, + Instant.ofEpochSecond(seconds, nanoOfSecond), + Variant::ofTimestampMicrosUtc, + Variant::getInstant, + VariantEncoder::encodeTimestampMicrosUtc, + serializeIcebergVariant(Variants.ofTimestamptz(micros))); + + assertPrimitiveEncoding( + 9, + micros, + Variant::ofTimestampMicrosNtz, + Variant::getTimestampMicros, + VariantEncoder::encodeTimestampMicrosNtz, + serializeIcebergVariant(Variants.ofTimestampntz(micros))); + assertPrimitiveEncoding( + 9, + LocalDateTime.ofEpochSecond(seconds, nanoOfSecond, ZoneOffset.UTC), + Variant::ofTimestampMicrosNtz, + Variant::getLocalDateTime, + VariantEncoder::encodeTimestampMicrosNtz, + serializeIcebergVariant(Variants.ofTimestampntz(micros))); 
+ } + + assertThat(Variant.ofTimestampMicrosUtc(1625079045123456L)).isNotEqualTo(Variant.ofTimestampMicrosUtc(1625079045123457L)); + assertThat(Variant.ofTimestampMicrosUtc(1625079045123456L).longHashCode()).isNotEqualTo(Variant.ofTimestampMicrosUtc(1625079045123457L).longHashCode()); + + assertThat(Variant.ofTimestampMicrosNtz(1625079045123456L)).isNotEqualTo(Variant.ofTimestampMicrosNtz(1625079045123457L)); + assertThat(Variant.ofTimestampMicrosNtz(1625079045123456L).longHashCode()).isNotEqualTo(Variant.ofTimestampMicrosNtz(1625079045123457L).longHashCode()); + } + + @Test + void testTimestampNanosUtc() + { + for (long nanos : List.of(0L, 1L, -1L, 1625079045123456L, -1625079045123456L, Long.MAX_VALUE, Long.MIN_VALUE)) { + long seconds = Math.floorDiv(nanos, 1_000_000_000L); + int nanoOfSecond = (int) Math.floorMod(nanos, 1_000_000_000L); + + assertPrimitiveEncoding( + 9, + nanos, + Variant::ofTimestampNanosUtc, + Variant::getTimestampNanos, + VariantEncoder::encodeTimestampNanosUtc, + serializeIcebergVariant(Variants.ofTimestamptzNanos(nanos))); + assertPrimitiveEncoding( + 9, + Instant.ofEpochSecond(seconds, nanoOfSecond), + Variant::ofTimestampNanosUtc, + Variant::getInstant, + VariantEncoder::encodeTimestampNanosUtc, + serializeIcebergVariant(Variants.ofTimestamptzNanos(nanos))); + + assertPrimitiveEncoding( + 9, + nanos, + Variant::ofTimestampNanosNtz, + Variant::getTimestampNanos, + VariantEncoder::encodeTimestampNanosNtz, + serializeIcebergVariant(Variants.ofTimestampntzNanos(nanos))); + assertPrimitiveEncoding( + 9, + LocalDateTime.ofEpochSecond(seconds, nanoOfSecond, ZoneOffset.UTC), + Variant::ofTimestampNanosNtz, + Variant::getLocalDateTime, + VariantEncoder::encodeTimestampNanosNtz, + serializeIcebergVariant(Variants.ofTimestampntzNanos(nanos))); + } + + assertThat(Variant.ofTimestampNanosUtc(1625079045123456789L)).isNotEqualTo(Variant.ofTimestampNanosUtc(1625079045123456790L)); + assertThat(Variant.ofTimestampNanosUtc(1625079045123456789L).longHashCode()).isNotEqualTo(Variant.ofTimestampNanosUtc(1625079045123456790L).longHashCode()); + + assertThat(Variant.ofTimestampNanosNtz(1625079045123456789L)).isNotEqualTo(Variant.ofTimestampNanosNtz(1625079045123456790L)); + assertThat(Variant.ofTimestampNanosNtz(1625079045123456789L).longHashCode()).isNotEqualTo(Variant.ofTimestampNanosNtz(1625079045123456790L).longHashCode()); + + assertThatThrownBy(() -> Variant.ofTimestampNanosUtc(Instant.MAX)) + .isInstanceOf(ArithmeticException.class); + assertThatThrownBy(() -> Variant.ofTimestampNanosUtc(Instant.MIN)) + .isInstanceOf(ArithmeticException.class); + } + + @Test + void testBinary() + { + assertPrimitiveEncoding( + 5, + Slices.EMPTY_SLICE, + Variant::ofBinary, + Variant::getBinary, + VariantEncoder::encodeBinary, + serializeIcebergVariant(Variants.of(ByteBuffer.allocate(0)))); + + byte[] binaryData = new byte[] {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09}; + Slice binarySlice = wrappedBuffer(binaryData); + + assertPrimitiveEncoding( + 5 + binaryData.length, + binarySlice, + Variant::ofBinary, + Variant::getBinary, + VariantEncoder::encodeBinary, + serializeIcebergVariant(Variants.of(ByteBuffer.wrap(binaryData)))); + + assertThat(Variant.ofBinary(wrappedBuffer(new byte[] {0x01, 0x02}))).isNotEqualTo(Variant.ofBinary(wrappedBuffer(new byte[] {0x03, 0x04}))); + assertThat(Variant.ofBinary(wrappedBuffer(new byte[] {0x01, 0x02})).longHashCode()).isNotEqualTo(Variant.ofBinary(wrappedBuffer(new byte[] {0x03, 0x04})).longHashCode()); + } + + @Test + void testString() + { + 
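+        // strings of at most 63 bytes use the compact short-string encoding (1-byte header); longer strings add a 4-byte length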
for (String text : List.of("", "Hello, Iceberg Variants!", Strings.repeat("Hello, Iceberg Variants!", 100))) {
+            Slice textSlice = utf8Slice(text);
+            int expectedSize = textSlice.length() <= 63 ? 1 + textSlice.length() : 5 + textSlice.length();
+            assertPrimitiveEncoding(
+                    expectedSize,
+                    textSlice,
+                    Variant::ofString,
+                    Variant::getString,
+                    VariantEncoder::encodeString,
+                    serializeIcebergVariant(Variants.of(text)));
+        }
+
+        assertThat(Variant.ofString("Variant A")).isNotEqualTo(Variant.ofString("Variant B"));
+        assertThat(Variant.ofString("Variant A").longHashCode()).isNotEqualTo(Variant.ofString("Variant B").longHashCode());
+    }
+
+    @Test
+    void testUuid()
+    {
+        String uuidString = "123e4567-e89b-12d3-a456-426614174000";
+        UUID uuid = UUID.fromString(uuidString);
+        assertPrimitiveEncoding(
+                17,
+                uuid,
+                Variant::ofUuid,
+                Variant::getUuid,
+                VariantEncoder::encodeUuid,
+                serializeIcebergVariant(Variants.ofUUID(uuidString)));
+
+        Slice uuidSlice = Slices.allocate(16);
+        uuidSlice.setLong(0, Long.reverseBytes(uuid.getMostSignificantBits()));
+        uuidSlice.setLong(8, Long.reverseBytes(uuid.getLeastSignificantBits()));
+        assertPrimitiveEncoding(
+                17,
+                uuidSlice,
+                Variant::ofUuid,
+                Variant::getUuidSlice,
+                VariantEncoder::encodeUuid,
+                serializeIcebergVariant(Variants.ofUUID(uuidString)));
+
+        assertThat(Variant.ofUuid(UUID.fromString("123e4567-e89b-12d3-a456-426614174000")))
+                .isNotEqualTo(Variant.ofUuid(UUID.fromString("223e4567-e89b-12d3-a456-426614174000")));
+    }
+
+    private static <T> void assertPrimitiveEncoding(int dataSize, T value, VariantFactory<T> factory, ValueGetter<T> getter, Encoder<T> encoder, Slice expectedEncoding)
+    {
+        Variant variant = factory.create(value);
+        assertThat(getter.get(variant)).isEqualTo(value);
+        // ensure variant equals itself
+        assertEqualAndSameHash(variant, variant);
+
+        Slice buffer = Slices.allocate(dataSize);
+        assertThat(encoder.encode(value, buffer, 0)).isEqualTo(dataSize);
+        assertThat(Variant.from(EMPTY_METADATA, buffer)).isEqualTo(variant);
+        assertThat(buffer).isEqualTo(expectedEncoding);
+
+        buffer = Slices.allocate(dataSize + 10);
+        assertThat(encoder.encode(value, buffer, 5)).isEqualTo(dataSize);
+        assertThat(Variant.from(EMPTY_METADATA, buffer.slice(5, dataSize))).isEqualTo(variant);
+        assertThat(buffer.slice(5, dataSize)).isEqualTo(expectedEncoding);
+        assertThat(buffer.slice(0, 5)).isEqualTo(Slices.allocate(5));
+        assertThat(buffer.slice(5 + dataSize, 5)).isEqualTo(Slices.allocate(5));
+
+        Variant copy = Variant.from(EMPTY_METADATA, buffer.slice(5, dataSize));
+        assertEqualAndSameHash(copy, variant);
+    }
+
+    @Test
+    void testArrayOneOfEach()
+    {
+        List<Variant> elements = new ArrayList<>();
+        ValueArray array = Variants.array();
+
+        elements.add(Variant.ofByte((byte) 0x12));
+        array.add(Variants.of((byte) 0x12));
+        elements.add(Variant.ofShort((short) 0x3456));
+        array.add(Variants.of((short) 0x3456));
+        elements.add(Variant.ofInt(0x789ABCDE));
+        array.add(Variants.of(0x789ABCDE));
+        elements.add(Variant.ofLong(0x1122334455667788L));
+        array.add(Variants.of(0x1122334455667788L));
+        elements.add(Variant.ofFloat(3.14f));
+        array.add(Variants.of(3.14f));
+        elements.add(Variant.ofDouble(Math.E));
+        array.add(Variants.of(Math.E));
+        elements.add(Variant.ofDate(LocalDate.of(2021, 5, 18)));
+        array.add(Variants.ofDate(18765));
+        elements.add(Variant.ofTimeMicrosNtz(86399999999L));
+        array.add(Variants.ofTime(86399999999L));
+        elements.add(Variant.ofTimestampMicrosUtc(Instant.ofEpochSecond(1625079045, 123456000)));
+        array.add(Variants.ofTimestamptz(1625079045123456L));
+        elements.add(Variant.ofTimestampNanosUtc(Instant.ofEpochSecond(1625079045, 123456789)));
+        array.add(Variants.ofTimestamptzNanos(1625079045123456789L));
+        elements.add(Variant.ofBinary(wrappedBuffer(new byte[] {0x0A, 0x0B, 0x0C})));
+        array.add(Variants.of(ByteBuffer.wrap(new byte[] {0x0A, 0x0B, 0x0C})));
+        elements.add(Variant.ofString("Iceberg Variants"));
+        array.add(Variants.of("Iceberg Variants"));
+        elements.add(Variant.ofUuid(UUID.fromString("123e4567-e89b-12d3-a456-426614174000")));
+        array.add(Variants.ofUUID("123e4567-e89b-12d3-a456-426614174000"));
+
+        assertArrayEncoding(elements, serializeIcebergVariant(array));
+    }
+
+    @Test
+    void testArrayOffsetSizes()
+    {
+        int expectedOffsetSize = 1;
+        // it is not possible to run the larger tests in CI due to memory constraints
+        for (int size : List.of(0xFF / 9, 0xFFFF / 9/*, 0xFFFFFF / 9, (0xFFFFFF / 9) + 1*/)) {
+            List<Variant> elements = new ArrayList<>(size);
+            ValueArray array = Variants.array();
+            int totalSize = 0;
+            for (int i = 0; i < size; i++) {
+                elements.add(Variant.ofLong(i));
+                array.add(Variants.of((long) i));
+                totalSize += 9;
+            }
+            assertThat(VariantUtils.getOffsetSize(totalSize)).isEqualTo(expectedOffsetSize);
+
+            assertArrayEncoding(elements, serializeIcebergVariant(array));
+
+            expectedOffsetSize++;
+        }
+    }
+
+    @Test
+    void testReadOffsetSliceInputUnsignedByte()
+    {
+        assertThat(VariantUtils.readOffset(wrappedBuffer((byte) 0x7F).getInput(), 1)).isEqualTo(0x7F);
+        assertThat(VariantUtils.readOffset(wrappedBuffer((byte) 0x80).getInput(), 1)).isEqualTo(0x80);
+        assertThat(VariantUtils.readOffset(wrappedBuffer((byte) 0xFF).getInput(), 1)).isEqualTo(0xFF);
+
+        assertThat(VariantUtils.readOffset(wrappedBuffer((byte) 0xFF, (byte) 0x7F).getInput(), 2)).isEqualTo(0x7FFF);
+        assertThat(VariantUtils.readOffset(wrappedBuffer((byte) 0x00, (byte) 0x80).getInput(), 2)).isEqualTo(0x8000);
+        assertThat(VariantUtils.readOffset(wrappedBuffer((byte) 0xFF, (byte) 0xFF).getInput(), 2)).isEqualTo(0xFFFF);
+
+        assertThat(VariantUtils.readOffset(wrappedBuffer((byte) 0xFF, (byte) 0xFF, (byte) 0x7F).getInput(), 3)).isEqualTo(0x7FFFFF);
+        assertThat(VariantUtils.readOffset(wrappedBuffer((byte) 0x00, (byte) 0x00, (byte) 0x80).getInput(), 3)).isEqualTo(0x800000);
+        assertThat(VariantUtils.readOffset(wrappedBuffer((byte) 0xFF, (byte) 0xFF, (byte) 0xFF).getInput(), 3)).isEqualTo(0xFFFFFF);
+    }
+
+    private static void assertArrayEncoding(List<Variant> elements, Slice expectedEncoding)
+    {
+        Slice buffer = Slices.allocate(encodedArraySize(elements.size(), elements.stream().mapToInt(element -> element.data().length()).sum()));
+        assertThat(VariantEncoder.encodeArray(elements.stream().map(Variant::data).toList(), buffer, 0)).isEqualTo(expectedEncoding.length());
+
+        Variant variant = Variant.from(EMPTY_METADATA, buffer);
+        assertThat(Variant.ofArray(elements)).isEqualTo(variant);
+        assertThat(Variant.ofArray(elements).data()).isEqualTo(buffer);
+        assertThat(Variant.fromObject(elements).toObject()).isEqualTo(variant.toObject());
+
+        assertThat(variant.getArrayLength()).isEqualTo(elements.size());
+        assertThat(variant.arrayElements().toList()).isEqualTo(elements);
+        assertThat(IntStream.range(0, elements.size()).mapToObj(variant::getArrayElement).toList())
+                .isEqualTo(elements);
+        assertThat(buffer).isEqualTo(expectedEncoding);
+
+        buffer = Slices.allocate(buffer.length() + 10);
+        assertThat(VariantEncoder.encodeArray(elements.stream().map(Variant::data).toList(), buffer, 5)).isEqualTo(expectedEncoding.length());
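+        // decode from a non-zero offset to confirm the reader does not assume the array starts at the beginning of the buffer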
+        Variant bufferCopy = Variant.from(EMPTY_METADATA, buffer.slice(5, expectedEncoding.length()));
+        assertThat(Variant.ofArray(elements)).isEqualTo(bufferCopy);
+        assertThat(bufferCopy.getArrayLength()).isEqualTo(elements.size());
+        assertThat(bufferCopy.arrayElements().toList()).isEqualTo(elements);
+        assertThat(buffer.slice(5, expectedEncoding.length())).isEqualTo(expectedEncoding);
+
+        assertEqualAndSameHash(bufferCopy, variant);
+
+        List<Object> expectedJavaObjects = elements.stream()
+                .map(Variant::toObject)
+                .toList();
+        assertThat(variant.toObject()).isEqualTo(expectedJavaObjects);
+        assertThat(Variant.fromObject(expectedJavaObjects).toObject()).isEqualTo(expectedJavaObjects);
+
+        Variant fromObjectCopy = Variant.fromObject(elements);
+        assertThat(fromObjectCopy.toObject()).isEqualTo(expectedJavaObjects);
+        assertEqualAndSameHash(fromObjectCopy, variant);
+
+        assertThat(Variant.fromObject(List.of(expectedJavaObjects))).isNotEqualTo(variant);
+        assertThat(Variant.fromObject(List.of(expectedJavaObjects)).longHashCode()).isNotEqualTo(variant.longHashCode());
+    }
+
+    @Test
+    void testObjectOneOfEach()
+    {
+        List<Slice> fieldNames = new ArrayList<>();
+        List<ObjectField> fields = new ArrayList<>();
+        List<VariantValue> icebergValues = new ArrayList<>();
+
+        int fieldId = 0;
+        fieldNames.add(utf8Slice("byteField"));
+        fields.add(new ObjectField(fieldId, Variant.ofByte((byte) 0x12)));
+        icebergValues.add(Variants.of((byte) 0x12));
+        fieldId++;
+
+        fieldNames.add(utf8Slice("shortField"));
+        fields.add(new ObjectField(fieldId, Variant.ofShort((short) 0x3456)));
+        icebergValues.add(Variants.of((short) 0x3456));
+        fieldId++;
+
+        fieldNames.add(utf8Slice("intField"));
+        fields.add(new ObjectField(fieldId, Variant.ofInt(0x789ABCDE)));
+        icebergValues.add(Variants.of(0x789ABCDE));
+        fieldId++;
+
+        fieldNames.add(utf8Slice("longField"));
+        fields.add(new ObjectField(fieldId, Variant.ofLong(0x1122334455667788L)));
+        icebergValues.add(Variants.of(0x1122334455667788L));
+        fieldId++;
+
+        fieldNames.add(utf8Slice("floatField"));
+        fields.add(new ObjectField(fieldId, Variant.ofFloat(3.14f)));
+        icebergValues.add(Variants.of(3.14f));
+        fieldId++;
+
+        fieldNames.add(utf8Slice("doubleField"));
+        fields.add(new ObjectField(fieldId, Variant.ofDouble(Math.E)));
+        icebergValues.add(Variants.of(Math.E));
+        fieldId++;
+
+        fieldNames.add(utf8Slice("dateField"));
+        fields.add(new ObjectField(fieldId, Variant.ofDate(LocalDate.of(2021, 5, 18))));
+        icebergValues.add(Variants.ofDate(18765));
+        fieldId++;
+
+        fieldNames.add(utf8Slice("timeMicrosField"));
+        fields.add(new ObjectField(fieldId, Variant.ofTimeMicrosNtz(86399999999L)));
+        icebergValues.add(Variants.ofTime(86399999999L));
+        fieldId++;
+
+        fieldNames.add(utf8Slice("timestampMicrosUtcField"));
+        fields.add(new ObjectField(fieldId, Variant.ofTimestampMicrosUtc(Instant.ofEpochSecond(1625079045, 123456000))));
+        icebergValues.add(Variants.ofTimestamptz(1625079045123456L));
+        fieldId++;
+
+        fieldNames.add(utf8Slice("timestampNanosUtcField"));
+        fields.add(new ObjectField(fieldId, Variant.ofTimestampNanosUtc(Instant.ofEpochSecond(1625079045, 123456789))));
+        icebergValues.add(Variants.ofTimestamptzNanos(1625079045123456789L));
+        fieldId++;
+
+        fieldNames.add(utf8Slice("binaryField"));
+        fields.add(new ObjectField(fieldId, Variant.ofBinary(wrappedBuffer(new byte[] {0x0A, 0x0B, 0x0C}))));
+        icebergValues.add(Variants.of(ByteBuffer.wrap(new byte[] {0x0A, 0x0B, 0x0C})));
+        fieldId++;
+
+        fieldNames.add(utf8Slice("stringField"));
+        fields.add(new ObjectField(fieldId, Variant.ofString("Iceberg Variants")));
icebergValues.add(Variants.of("Iceberg Variants")); + fieldId++; + fieldNames.add(utf8Slice("uuidField")); + fields.add(new ObjectField(fieldId, Variant.ofUuid(UUID.fromString("123e4567-e89b-12d3-a456-426614174000")))); + icebergValues.add(Variants.ofUUID("123e4567-e89b-12d3-a456-426614174000")); + + TestingMetadata testingMetadata = TestingMetadata.of(fieldNames); + assertObjectEncoding(testingMetadata.metadata(), fields, testingMetadata.icebergMetadata(), icebergValues); + } + + @Test + void testObjectOffsetSizes() + { + // Build a single large dictionary once that is shared across all tests + TestingMetadata testingMetadata = TestingMetadata.of(IntStream.range(0, 0xFFFFFF + 1) + .mapToObj("%08d"::formatted) + .map(Slices::utf8Slice) + .peek(fieldName -> assertThat(fieldName.length()).isEqualTo(8)) + .toList()); + + int expectedFieldIdSize = 1; + // it is not possible to run the larger tests in CI due to memory constraints + for (int fieldCount : List.of(0xFF, 0xFFFF/*, 0xFFFFFF, 0xFFFFFF + 1*/)) { + List fields = new ArrayList<>(); + List icebergValues = new ArrayList<>(); + + for (int i = 0; i < fieldCount; i++) { + fields.add(new ObjectField(i, Variant.ofLong(i))); + icebergValues.add(Variants.of((long) i)); + } + assertThat(VariantUtils.getOffsetSize(fieldCount)).isEqualTo(expectedFieldIdSize); + + assertObjectEncoding(testingMetadata.metadata(), fields, testingMetadata.icebergMetadata(), icebergValues); + + expectedFieldIdSize++; + } + } + + private static void assertObjectEncoding(Metadata metadata, List fields, VariantMetadata variantMetadata, List icebergValues) + { + // sort the fields by field name to determine expected order + List sortedFields = fields.stream() + .sorted(Comparator.comparing(objectFieldIdValue -> metadata.get(objectFieldIdValue.fieldId()))) + .toList(); + + // Iceberg code does not have a method to encode a variant object; it only supports deserialization. 
+        Slice buffer = encodeObjectWithSortedFields(sortedFields);
+
+        List<String> sortedFieldNames = sortedFields.stream()
+                .map(ObjectField::fieldId)
+                .map(metadata::get)
+                .map(Slice::toStringUtf8)
+                .toList();
+
+        Variant variant = Variant.from(metadata, buffer);
+        assertThat(variant.getObjectFieldCount()).isEqualTo(sortedFields.size());
+        assertThat(variant.objectFieldNames().map(Slice::toStringUtf8).toList()).isEqualTo(sortedFieldNames);
+        assertThat(variant.objectValues().map(Variant::data).toList()).isEqualTo(sortedFields.stream().map(ObjectField::variantValue).toList());
+        assertThat(variant.objectFields().map(field -> new ObjectField(field.fieldId(), field.value().data())).toList()).isEqualTo(sortedFields);
+
+        // verify variant equals itself
+        assertEqualAndSameHash(variant, variant);
+
+        Map<String, Object> expectedJavaObject = new HashMap<>();
+        for (ObjectField field : sortedFields) {
+            expectedJavaObject.put(metadata.get(field.fieldId()).toStringUtf8(), Variant.from(metadata, field.variantValue()).toObject());
+        }
+        assertThat(variant.toObject()).isEqualTo(expectedJavaObject);
+        assertThat(Variant.fromObject(expectedJavaObject).toObject()).isEqualTo(expectedJavaObject);
+
+        Map<Slice, Variant> variantFields = new HashMap<>();
+        for (ObjectField field : sortedFields) {
+            variantFields.put(metadata.get(field.fieldId()), Variant.from(metadata, field.variantValue()));
+        }
+        Variant ofObjectCopy = Variant.ofObject(variantFields);
+        assertThat(ofObjectCopy.toObject()).isEqualTo(expectedJavaObject);
+        assertEqualAndSameHash(ofObjectCopy, variant);
+
+        Variant fromObjectCopy = Variant.fromObject(variantFields);
+        assertThat(fromObjectCopy.toObject()).isEqualTo(expectedJavaObject);
+        assertEqualAndSameHash(fromObjectCopy, variant);
+
+        if (sortedFields.size() < 500) {
+            for (ObjectField field : sortedFields) {
+                assertFieldLookup(field, variant, metadata);
+            }
+        }
+        else {
+            ThreadLocalRandom.current().ints(500, 0, sortedFields.size())
+                    .mapToObj(sortedFields::get)
+                    .forEach(field -> assertFieldLookup(field, variant, metadata));
+        }
+
+        VariantObject icebergObject = (VariantObject) VariantValue.from(variantMetadata, buffer.toByteBuffer().order(ByteOrder.LITTLE_ENDIAN));
+        assertThat(icebergObject.numFields()).isEqualTo(sortedFields.size());
+        assertThat(ImmutableList.copyOf(icebergObject.fieldNames())).isEqualTo(sortedFieldNames);
+        if (sortedFields.size() < 500) {
+            for (int i = 0; i < sortedFields.size(); i++) {
+                assertThat(icebergObject.get(variantMetadata.get(i))).isEqualTo(icebergValues.get(i));
+            }
+        }
+        else {
+            ThreadLocalRandom.current().ints(500, 0, sortedFields.size())
+                    .forEach(i -> assertThat(icebergObject.get(variantMetadata.get(i))).isEqualTo(icebergValues.get(i)));
+        }
+
+        Slice offsetBuffer = Slices.allocate(buffer.length() + 10);
+        assertThat(encodeObject(
+                sortedFields.size(),
+                i -> sortedFields.get(i).fieldId(),
+                i -> sortedFields.get(i).variantValue(),
+                offsetBuffer,
+                5))
+                .isEqualTo(buffer.length());
+        assertThat(offsetBuffer.slice(5, buffer.length())).isEqualTo(buffer);
+        assertThat(offsetBuffer.slice(0, 5)).isEqualTo(Slices.allocate(5));
+        assertThat(offsetBuffer.slice(5 + buffer.length(), 5)).isEqualTo(Slices.allocate(5));
+
+        assertThat(Variant.fromObject(Map.of("all", variant))).isNotEqualTo(variant);
+        assertThat(Variant.fromObject(Map.of("all", variant)).longHashCode()).isNotEqualTo(variant.longHashCode());
+    }
+
+    private static void assertFieldLookup(ObjectField field, Variant deserializedObjectVariant, Metadata metadata)
+    {
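+        // a field must resolve identically by field id and by field name through the metadata dictionary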
+        assertThat(deserializedObjectVariant.getObjectField(field.fieldId()).map(Variant::data)).contains(field.variantValue());
+        Slice fieldName = metadata.get(field.fieldId());
+        assertThat(metadata.id(fieldName)).isEqualTo(field.fieldId());
+        assertThat(deserializedObjectVariant.getObjectField(fieldName).map(Variant::data)).contains(field.variantValue());
+    }
+
+    private record TestingMetadata(Metadata metadata, VariantMetadata icebergMetadata)
+    {
+        private static TestingMetadata of(List<Slice> fieldNames)
+        {
+            Metadata metadata = Metadata.of(fieldNames);
+            // this saves some memory, but Iceberg does eventually inflate accessed field names to strings
+            VariantMetadata icebergMetadata = Variants.metadata(metadata.toSlice().toByteBuffer().order(ByteOrder.LITTLE_ENDIAN));
+            return new TestingMetadata(metadata, icebergMetadata);
+        }
+    }
+
+    private interface VariantFactory<T>
+    {
+        Variant create(T value);
+    }
+
+    private interface ValueGetter<T>
+    {
+        T get(Variant value);
+    }
+
+    private interface Encoder<T>
+    {
+        int encode(T value, Slice buffer, int offset);
+    }
+
+    private static Slice serializeIcebergVariant(VariantValue value)
+    {
+        int size = value.sizeInBytes();
+        byte[] array = new byte[size];
+        ByteBuffer valueBuf = ByteBuffer.wrap(array).order(ByteOrder.LITTLE_ENDIAN);
+        value.writeTo(valueBuf, 0);
+        return wrappedBuffer(array);
+    }
+
+    @Test
+    void testFromObjectPrimitiveRoundTrip()
+    {
+        assertThat(Variant.fromObject(null)).isEqualTo(Variant.NULL_VALUE);
+
+        assertThat(Variant.fromObject(true).toObject()).isEqualTo(true);
+        assertThat(Variant.fromObject((byte) 12).toObject()).isEqualTo((byte) 12);
+        assertThat(Variant.fromObject((short) 123).toObject()).isEqualTo((short) 123);
+        assertThat(Variant.fromObject(123).toObject()).isEqualTo(123);
+        assertThat(Variant.fromObject(123L).toObject()).isEqualTo(123L);
+        assertThat(Variant.fromObject(1.5f).toObject()).isEqualTo(1.5f);
+        assertThat(Variant.fromObject(1.5d).toObject()).isEqualTo(1.5d);
+
+        BigDecimal decimal = new BigDecimal("1234.5678");
+        assertThat(Variant.fromObject(decimal).toObject()).isEqualTo(decimal);
+
+        LocalDate date = LocalDate.of(2024, 10, 24);
+        assertThat(Variant.fromObject(date).toObject()).isEqualTo(date);
+
+        Instant instant = Instant.parse("2024-10-24T12:34:56.123456789Z");
+        assertThat(Variant.fromObject(instant).toObject()).isEqualTo(instant);
+
+        LocalDateTime dateTime = LocalDateTime.parse("2024-10-24T12:34:56.123456789");
+        assertThat(Variant.fromObject(dateTime).toObject()).isEqualTo(dateTime);
+
+        UUID uuid = UUID.fromString("123e4567-e89b-12d3-a456-426614174000");
+        assertThat(Variant.fromObject(uuid).toObject()).isEqualTo(uuid);
+
+        Slice binary = wrappedBuffer(new byte[] {0x01, 0x02, 0x03});
+        assertThat(Variant.fromObject(binary).toObject()).isEqualTo(binary);
+        assertThat(Variant.fromObject(binary.getBytes()).toObject()).isEqualTo(binary);
+    }
+
+    @Test
+    void testOfArrayAndFromObjectArray()
+    {
+        List<Variant> elements = List.of(Variant.ofInt(1), Variant.ofString("two"), Variant.NULL_VALUE, Variant.ofBoolean(true));
+
+        Variant array = Variant.ofArray(elements);
+        assertThat(array.getArrayLength()).isEqualTo(elements.size());
+        assertThat(array.arrayElements().toList()).isEqualTo(elements);
+
+        assertThat(Variant.fromObject(elements).toObject()).isEqualTo(array.toObject());
+    }
+
+    @Test
+    void testOfObjectAndFromObjectObject()
+    {
+        Map<Slice, Variant> fields = new HashMap<>();
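+        // insert keys out of order; Variant.ofObject must still produce fields ordered by name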
fields.put(utf8Slice("b"), Variant.ofInt(2)); + fields.put(utf8Slice("a"), Variant.ofInt(1)); + fields.put(utf8Slice("c"), Variant.ofString("three")); + + Variant object = Variant.ofObject(fields); + + assertThat(object.getObjectFieldCount()).isEqualTo(3); + assertThat(object.objectFieldNames().map(Slice::toStringUtf8).toList()).isEqualTo(List.of("a", "b", "c")); + assertThat(object.toObject()).isEqualTo(Map.of("a", 1, "b", 2, "c", "three")); + + Map javaObject = Map.of("b", 2, "a", 1, "c", "three"); + assertThat(Variant.fromObject(javaObject).toObject()).isEqualTo(object.toObject()); + } + + @Test + void testObjectVariantEncodingWithOutOfOrderValueRegion() + { + Variant variant = objectVariantWithValueOrder( + List.of("flag", "id", "vc"), + List.of(Variant.ofBoolean(true), Variant.ofByte((byte) 42), Variant.ofString("x")), + List.of(1, 2, 0)); + + Slice metadataSlice = variant.metadata().toSlice(); + Slice valueSlice = variant.data(); + + VariantMetadata icebergMetadata = VariantMetadata.from(metadataSlice.toByteBuffer().order(ByteOrder.LITTLE_ENDIAN)); + VariantObject icebergObject = (VariantObject) VariantValue.from(icebergMetadata, valueSlice.toByteBuffer().order(ByteOrder.LITTLE_ENDIAN)); + assertThat(ImmutableList.copyOf(icebergObject.fieldNames())).isEqualTo(List.of("flag", "id", "vc")); + + assertThat(variant.getObjectFieldCount()).isEqualTo(3); + assertThat(variant.objectFieldNames().map(Slice::toStringUtf8).toList()).isEqualTo(List.of("flag", "id", "vc")); + assertThat(variant.objectValues().map(Variant::toObject).toList()).isEqualTo(List.of(true, (byte) 42, "x")); + assertThat(variant.getObjectField(utf8Slice("flag"))).hasValueSatisfying(value -> assertThat(value.getBoolean()).isTrue()); + assertThat(variant.getObjectField(utf8Slice("id"))).hasValueSatisfying(value -> assertThat(value.getByte()).isEqualTo((byte) 42)); + assertThat(variant.getObjectField(utf8Slice("vc"))).hasValueSatisfying(value -> assertThat(value.getStringUtf8()).isEqualTo("x")); + assertThat(variant.objectFields() + .map(field -> Map.entry(variant.metadata().get(field.fieldId()).toStringUtf8(), field.value().toObject())) + .toList()) + .isEqualTo(List.of(Map.entry("flag", true), Map.entry("id", (byte) 42), Map.entry("vc", "x"))); + assertThat(variant.toObject()).isEqualTo(Map.of("flag", true, "id", (byte) 42, "vc", "x")); + assertEqualAndSameHash(variant, Variant.fromObject(Map.of("flag", true, "id", 42, "vc", "x"))); + } + + @Test + void testObjectVariantEncodingWithOutOfOrderValueRegionAndFieldRemapper() + { + Variant variant = objectVariantWithValueOrder( + List.of("flag", "id", "vc"), + List.of(Variant.ofBoolean(true), Variant.ofByte((byte) 42), Variant.ofString("x")), + List.of(1, 2, 0)); + + // Run the variant through the remapper. This verifies the remapper code can process variants with + // fields not written in canonical order. 
+ Metadata.Builder metadataBuilder = Metadata.builder(); + VariantFieldRemapper remapper = VariantFieldRemapper.create(variant, metadataBuilder); + Metadata.Builder.SortedMetadata sortedMetadata = metadataBuilder.buildSorted(); + remapper.finalize(sortedMetadata.sortedFieldIdMapping()); + + Slice remappedData = Slices.allocate(remapper.size()); + assertThat(remapper.write(remappedData, 0)).isEqualTo(remappedData.length()); + + Variant remapped = Variant.from(sortedMetadata.metadata(), remappedData); + assertThat(remapped.toObject()).isEqualTo(Map.of("flag", true, "id", (byte) 42, "vc", "x")); + assertEqualAndSameHash(remapped, variant); + } + + @Test + void testFromObjectMapKeyValidation() + { + Map nonStringKey = new HashMap<>(); + nonStringKey.put(123, 1); + assertThatThrownBy(() -> Variant.fromObject(nonStringKey)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Map key must be a String"); + + Map nullKey = new HashMap<>(); + nullKey.put(null, 1); + assertThatThrownBy(() -> Variant.fromObject(nullKey)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Map key is null"); + } + + @Test + void testEqualsShortStringAndLongStringAreEquivalent() + { + Slice utf8 = utf8Slice("abc"); + + // Short-string encoding (VariantEncoder will pick short string) + Slice shortEncoded = Slices.allocate(1 + utf8.length()); + assertThat(VariantEncoder.encodeString(utf8, shortEncoded, 0)).isEqualTo(1 + utf8.length()); + Variant shortVariant = Variant.from(EMPTY_METADATA, shortEncoded); + + // Long-string encoding forced: [header][int length][bytes] + Slice longEncoded = Slices.allocate(1 + Integer.BYTES + utf8.length()); + longEncoded.setByte(0, Header.primitiveHeader(Header.PrimitiveType.STRING)); + longEncoded.setInt(1, utf8.length()); + longEncoded.setBytes(1 + Integer.BYTES, utf8); + Variant longVariant = Variant.from(EMPTY_METADATA, longEncoded); + + assertEqualAndSameHash(shortVariant, longVariant); + } + + @Test + void testEqualsIntAndDecimalAreEquivalent() + { + Variant intVariant = Variant.ofByte((byte) 1); + + // decimal16 scale=2, unscaled=100 + Slice decimalEncoded = Slices.allocate(ENCODED_DECIMAL16_SIZE); + encodeDecimal16(0, 100, 2, decimalEncoded, 0); + Variant decimalVariant = Variant.from(EMPTY_METADATA, decimalEncoded); + + assertEqualAndSameHash(intVariant, decimalVariant); + } + + @Test + void testEqualsTimestampUtcMicrosAndNanosAreEquivalentWhenExactlyRepresentable() + { + long micros = 1_700_000_000_000_000L; + long nanos = micros * 1_000L; + + Variant microsVariant = Variant.ofTimestampMicrosUtc(micros); + Variant nanosVariant = Variant.ofTimestampNanosUtc(nanos); + + assertEqualAndSameHash(microsVariant, nanosVariant); + } + + @Test + void testEqualsTimestampUtcMicrosAndNanosNotEquivalentWhenNanosHasSubMicroRemainder() + { + long micros = 1_700_000_000_000_000L; + long nanos = micros * 1_000L + 1; // not representable in micros + + Variant microsVariant = Variant.ofTimestampMicrosUtc(micros); + Variant nanosVariant = Variant.ofTimestampNanosUtc(nanos); + + assertThat(microsVariant).isNotEqualTo(nanosVariant); + } + + @Test + void testFloatAndDoubleAreEquivalentWhenNumericallyEqual() + { + assertEqualAndSameHash(Variant.ofFloat(1.0f), Variant.ofDouble(1.0)); + assertEqualAndSameHash(Variant.ofFloat(-0.0f), Variant.ofDouble(0.0)); + assertEqualAndSameHash(Variant.ofFloat(Float.MIN_VALUE), Variant.ofDouble((double) Float.MIN_VALUE)); + } + + @Test + void testEqualsNumericExactAndFloatingWhenExactlyRepresentable() + { + for (long value : 
List.of(-1L, 0L, 1L, 42L, 1L << 40, 1L << 53)) { + assertEqualAndSameHash(Variant.ofDouble((double) value), Variant.ofLong(value)); + } + for (int value : List.of(-1, 0, 1, 42, 1 << 20)) { + assertEqualAndSameHash(Variant.ofFloat((float) value), Variant.ofInt(value)); + } + + for (String decimalString : List.of("0.5", "0.25", "0.125", "1.5", "-1.25", "10.75")) { + BigDecimal decimal = new BigDecimal(decimalString); + assertEqualAndSameHash(Variant.ofDouble(decimal.doubleValue()), decimal); + } + assertEqualAndSameHash(Variant.ofFloat(0.5f), new BigDecimal("0.5")); + + BigInteger largeExactInteger = BigInteger.ONE.shiftLeft(70); + assertEqualAndSameHash( + Variant.ofDouble(Math.scalb(1.0, 70)), + Variant.ofDecimal(new BigDecimal(largeExactInteger, 0))); + } + + @Test + void testEqualsNumericExactAndFloatingWhenNotExactlyRepresentable() + { + for (String decimalString : List.of("0.1", "0.2", "0.3", "0.7", "1.1")) { + BigDecimal decimal = new BigDecimal(decimalString); + assertNotEqualAndDifferentHash(Variant.ofDouble(decimal.doubleValue()), decimal); + assertNotEqualAndDifferentHash(Variant.ofFloat(decimal.floatValue()), decimal); + } + + assertThat(Variant.ofDouble(Double.MIN_VALUE)).isNotEqualTo(Variant.ofDecimal(BigDecimal.ZERO)); + assertThat(Variant.ofFloat(Float.MIN_VALUE)).isNotEqualTo(Variant.ofDecimal(BigDecimal.ZERO)); + } + + @Test + void testEqualsNumericInfinityAndNaNAcrossFloatingTypes() + { + assertEqualAndSameHash(Variant.ofFloat(Float.POSITIVE_INFINITY), Variant.ofDouble(Double.POSITIVE_INFINITY)); + assertEqualAndSameHash(Variant.ofFloat(Float.NEGATIVE_INFINITY), Variant.ofDouble(Double.NEGATIVE_INFINITY)); + + assertThat(Variant.ofDouble(Double.POSITIVE_INFINITY)).isNotEqualTo(Variant.ofLong(Long.MAX_VALUE)); + assertThat(Variant.ofDouble(Double.NEGATIVE_INFINITY)).isNotEqualTo(Variant.ofLong(Long.MIN_VALUE)); + + assertThat(Variant.ofFloat(Float.NaN)).isNotEqualTo(Variant.ofDouble(Double.NaN)); + } + + @Test + void testEqualsNumericDecimal16LargeScaleExactLongViaDouble() + { + Slice encoded = Slices.allocate(ENCODED_DECIMAL16_SIZE); + encodeDecimal16(BigInteger.TEN.pow(19), 19, encoded, 0); + Variant decimal = Variant.from(EMPTY_METADATA, encoded); + + assertEqualAndSameHash(decimal, Variant.ofDouble(1.0)); + assertEqualAndSameHash(decimal, Variant.ofDouble(1.0f)); + } + + @Test + void testEqualsNumericDecimal16LargeScaleLongMultiplyOverflowViaDouble() + { + Slice encoded = Slices.allocate(ENCODED_DECIMAL16_SIZE); + encodeDecimal16(BigInteger.ONE, 38, encoded, 0); + Variant decimal = Variant.from(EMPTY_METADATA, encoded); + + // In decimal-vs-floating comparison this goes through the exact-long route for 2.0, + // and value * 10^38 overflows Int128 during the DECIMAL16 exact-multiple check. 
+ assertThat(decimal).isNotEqualTo(Variant.ofDouble(2.0)); + assertThat(decimal).isNotEqualTo(Variant.ofDouble(-2.0)); + } + + @Test + void testEqualsNumericLargeMagnitudeDoubleCannotConvertToDecimal128Exact() + { + assertThat(Variant.ofLong(0)).isNotEqualTo(Variant.ofDouble(1e38)); + assertThat(Variant.ofLong(0)).isNotEqualTo(Variant.ofDouble(-1e38)); + } + + @Test + void testEqualsNumericDoubleDecimalConversionOverflowFallsBackToNotEqual() + { + double value = Math.scalb((double) ((1L << 53) - 1), -38); + assertThat(Variant.ofLong(0)).isNotEqualTo(Variant.ofDouble(value)); + assertThat(Variant.ofLong(0)).isNotEqualTo(Variant.ofDouble(-value)); + } + + @Test + void testEqualsNumericZeroDecimal16HighScaleCanonicalizesWithIntegerZero() + { + Slice encoded = Slices.allocate(ENCODED_DECIMAL16_SIZE); + encodeDecimal16(BigInteger.ZERO, 38, encoded, 0); + Variant decimalZero = Variant.from(EMPTY_METADATA, encoded); + + assertEqualAndSameHash(decimalZero, Variant.ofLong(0)); + assertEqualAndSameHash(decimalZero, Variant.ofDouble(0.0)); + assertEqualAndSameHash(decimalZero, Variant.ofDouble(-0.0)); + } + + @Test + void testDoubleNaNPayloadsHashEqualButAreNotEqual() + { + Variant nan1 = Variant.ofDouble(Double.longBitsToDouble(0x7ff8_0000_0000_0001L)); + Variant nan2 = Variant.ofDouble(Double.longBitsToDouble(0x7ff8_0000_0000_0002L)); + + assertThat(nan1).isNotEqualTo(nan2); + assertThat(nan1.longHashCode()).isEqualTo(nan2.longHashCode()); + } + + @Test + void testEqualsObjectWhenMetadataDictionariesDiffer() + { + Variant left = Variant.from( + Metadata.of(List.of(utf8Slice("a"), utf8Slice("b"))), + encodeObjectWithSortedFields(List.of( + new ObjectField(0, Variant.ofInt(123)), + new ObjectField(1, Variant.ofString("hello"))))); + Variant right = Variant.from( + Metadata.of(List.of(utf8Slice("b"), utf8Slice("a"))), + encodeObjectWithSortedFields(List.of( + new ObjectField(1, Variant.ofInt(123)), + new ObjectField(0, Variant.ofString("hello"))))); + + // ensure we got different metadata dictionaries + assertThat(left.metadata().get(0)).isNotEqualTo(right.metadata().get(0)); + assertThat(left.metadata().get(1)).isNotEqualTo(right.metadata().get(1)); + + assertEqualAndSameHash(left, right); + } + + private static void assertEqualAndSameHash(Variant leftValue, Variant rightValue) + { + assertThat(leftValue).isEqualTo(rightValue); + assertThat(leftValue.longHashCode()).isEqualTo(rightValue.longHashCode()); + // some implementations have separate left and right side handling, so check both directions + assertThat(rightValue).isEqualTo(leftValue); + assertThat(rightValue.longHashCode()).isEqualTo(leftValue.longHashCode()); + + Variant leftArray = Variant.ofArray(List.of(leftValue)); + Variant rightArray = Variant.ofArray(List.of(rightValue)); + assertThat(leftArray).isEqualTo(rightArray); + assertThat(leftArray.longHashCode()).isEqualTo(rightArray.longHashCode()); + assertThat(rightArray).isEqualTo(leftArray); + assertThat(rightArray.longHashCode()).isEqualTo(leftArray.longHashCode()); + + Variant leftObject = Variant.ofObject(Map.of(utf8Slice("field"), leftValue)); + Variant rightObject = Variant.ofObject(Map.of(utf8Slice("field"), rightValue)); + assertThat(leftObject).isEqualTo(rightObject); + assertThat(leftObject.longHashCode()).isEqualTo(rightObject.longHashCode()); + assertThat(rightObject).isEqualTo(leftObject); + assertThat(rightObject.longHashCode()).isEqualTo(leftObject.longHashCode()); + } + + private static void assertEqualAndSameHash(Variant leftValue, BigDecimal bigDecimal) + { + for 
(Variant rightValue : allDecimalEncodings(bigDecimal)) {
+            assertEqualAndSameHash(leftValue, rightValue);
+        }
+    }
+
+    private static void assertNotEqualAndDifferentHash(Variant leftValue, Variant rightValue)
+    {
+        assertThat(leftValue).isNotEqualTo(rightValue);
+        assertThat(leftValue.longHashCode()).isNotEqualTo(rightValue.longHashCode());
+        // some implementations have separate left and right side handling, so check both directions
+        assertThat(rightValue).isNotEqualTo(leftValue);
+        assertThat(rightValue.longHashCode()).isNotEqualTo(leftValue.longHashCode());
+    }
+
+    private static void assertNotEqualAndDifferentHash(Variant leftValue, BigDecimal bigDecimal)
+    {
+        for (Variant rightValue : allDecimalEncodings(bigDecimal)) {
+            assertNotEqualAndDifferentHash(leftValue, rightValue);
+        }
+    }
+
+    private static List<Variant> allDecimalEncodings(BigDecimal bigDecimal)
+    {
+        List<Variant> variants = new ArrayList<>();
+        BigInteger unscaled = bigDecimal.unscaledValue();
+        int scale = bigDecimal.scale();
+        if (unscaled.bitLength() < 32) {
+            Slice data = Slices.allocate(ENCODED_DECIMAL4_SIZE);
+            encodeDecimal4(unscaled.intValue(), scale, data, 0);
+            variants.add(Variant.from(EMPTY_METADATA, data));
+        }
+        if (unscaled.bitLength() < 64) {
+            Slice data = Slices.allocate(ENCODED_DECIMAL8_SIZE);
+            encodeDecimal8(unscaled.longValue(), scale, data, 0);
+            variants.add(Variant.from(EMPTY_METADATA, data));
+        }
+        if (unscaled.bitLength() > 128) {
+            throw new IllegalArgumentException("Decimal precision out of range: " + unscaled.bitLength());
+        }
+        Slice data = Slices.allocate(ENCODED_DECIMAL16_SIZE);
+        encodeDecimal16(unscaled, scale, data, 0);
+        variants.add(Variant.from(EMPTY_METADATA, data));
+        return variants;
+    }
+
+    // Builds a variant with the exact specified field order. Variants by spec are required to have object fields ordered by field name.
+    // This method assumes that the caller has already sorted the fields by field name.
+    // This method is necessary to build test variants without global sorting in the metadata dictionary, as all convenience methods
+    // on Variant build a metadata dictionary with global sorting.
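+    // The output buffer is sized up front with encodedObjectSize, from the maximum field id, the field count, and the total length of the field values.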
+    private static Slice encodeObjectWithSortedFields(List<ObjectField> fields)
+    {
+        int expectedSize = encodedObjectSize(
+                fields.stream()
+                        .mapToInt(ObjectField::fieldId)
+                        .max()
+                        .orElse(0),
+                fields.size(),
+                fields.stream()
+                        .mapToInt(field -> field.variantValue().length())
+                        .sum());
+        Slice output = Slices.allocate(expectedSize);
+
+        int written = encodeObject(
+                fields.size(),
+                i -> fields.get(i).fieldId(),
+                i -> fields.get(i).variantValue(),
+                output,
+                0);
+        verify(written == expectedSize, "written size does not match expected size");
+        return output;
+    }
+
+    private static Variant objectVariantWithValueOrder(List<String> fieldNames, List<Variant> fieldValues, List<Integer> physicalOrder)
+    {
+        assertThat(fieldNames).hasSameSizeAs(fieldValues);
+        assertThat(physicalOrder).hasSize(fieldValues.size());
+
+        int fieldCount = fieldValues.size();
+        int[] seen = new int[fieldCount];
+        int[] valueOffsets = new int[fieldCount];
+        int totalLength = 0;
+        for (int logicalIndex : physicalOrder) {
+            checkArgument(logicalIndex >= 0 && logicalIndex < fieldCount, () -> "Invalid field index: " + logicalIndex);
+            checkArgument(++seen[logicalIndex] == 1, () -> "Duplicate field index: " + logicalIndex);
+            valueOffsets[logicalIndex] = totalLength;
+            totalLength += fieldValues.get(logicalIndex).data().length();
+        }
+        for (int logicalIndex = 0; logicalIndex < fieldCount; logicalIndex++) {
+            checkArgument(seen[logicalIndex] == 1, "Missing field index: " + logicalIndex);
+        }
+
+        Metadata metadata = Metadata.of(fieldNames.stream().map(Slices::utf8Slice).toList());
+        int fieldIdSize = getOffsetSize(fieldCount - 1);
+        int offsetSize = getOffsetSize(totalLength);
+        boolean large = fieldCount > 255;
+
+        int headerSize = encodedObjectSize(fieldCount - 1, fieldCount, totalLength) - totalLength;
+        Slice data = Slices.allocate(headerSize + totalLength);
+
+        int position = 0;
+        data.setByte(position, Header.objectHeader(fieldIdSize, offsetSize, large));
+        position++;
+        if (large) {
+            data.setInt(position, fieldCount);
+            position += Integer.BYTES;
+        }
+        else {
+            data.setByte(position, (byte) fieldCount);
+            position++;
+        }
+
+        for (int fieldId = 0; fieldId < fieldCount; fieldId++) {
+            writeOffset(data, position, fieldId, fieldIdSize);
+            position += fieldIdSize;
+        }
+        for (int fieldId = 0; fieldId < fieldCount; fieldId++) {
+            writeOffset(data, position, valueOffsets[fieldId], offsetSize);
+            position += offsetSize;
+        }
+        writeOffset(data, position, totalLength, offsetSize);
+
+        int writePosition = headerSize;
+        for (int logicalIndex : physicalOrder) {
+            Slice fieldValue = fieldValues.get(logicalIndex).data();
+            data.setBytes(writePosition, fieldValue);
+            writePosition += fieldValue.length();
+        }
+
+        return Variant.from(metadata, data);
+    }
+
+    private record ObjectField(int fieldId, Slice variantValue)
+    {
+        private ObjectField(int fieldId, Variant variant)
+        {
+            this(fieldId, variant.data());
+            assertThat(variant.metadata()).isEqualTo(EMPTY_METADATA);
+        }
+    }
+}
diff --git a/docs/src/main/sphinx/connector/iceberg.md b/docs/src/main/sphinx/connector/iceberg.md
index bd17f6b60aa1..def1c4b8ed43 100644
--- a/docs/src/main/sphinx/connector/iceberg.md
+++ b/docs/src/main/sphinx/connector/iceberg.md
@@ -4,10 +4,11 @@
 ```
 
-Apache Iceberg is an open table format for huge analytic datasets. The Iceberg
-connector allows querying data stored in files written in Iceberg format, as
-defined in the [Iceberg Table Spec](https://iceberg.apache.org/spec/). The
-connector supports Apache Iceberg table spec versions 1 and 2.
+Apache Iceberg is an open table format for huge analytic datasets. +The Iceberg connector allows querying data stored in files written in Iceberg +format, as defined in the [Iceberg Table Spec](https://iceberg.apache.org/spec/). +The connector supports Apache Iceberg table spec versions 1 and 2. +Support for format version 3 is experimental. The table state is maintained in metadata files. All changes to table state create a new metadata file and replace the old metadata with an atomic @@ -341,6 +342,8 @@ the following table: - `VARBINARY` * - `FIXED (L)` - `VARBINARY` +* - `VARIANT` + - `VARIANT` * - `STRUCT(...)` - `ROW(...)` * - `LIST(e)` @@ -395,6 +398,8 @@ the following table: - `UUID` * - `VARBINARY` - `BINARY` +* - `VARIANT` + - `VARIANT` * - `ROW(...)` - `STRUCT(...)` * - `ARRAY(e)` @@ -403,6 +408,12 @@ the following table: - `MAP(k,v)` ::: +:::{note} +Iceberg `VARIANT` is supported only for tables using Iceberg format version `3` +or higher. To create a table with `VARIANT` columns, set +`format_version = 3` in the `WITH` clause. The default is `2`. +::: + No other types are supported. ## Security @@ -1078,6 +1089,7 @@ connector using a {doc}`WITH ` clause. for row level deletes. Version `3` support is experimental; row-level updates, deletes, and OPTIMIZE are not supported. Tables with v3 features such as column default values and encryption are not supported. + Version `3` is required for tables containing `VARIANT` columns. * - `max_commit_retry` - Number of times to retry a commit before failing. Defaults to the value of the `iceberg.max-commit-retry` catalog configuration property, which diff --git a/docs/src/main/sphinx/functions.md b/docs/src/main/sphinx/functions.md index 6acac7f6ef0e..7e812a7282a3 100644 --- a/docs/src/main/sphinx/functions.md +++ b/docs/src/main/sphinx/functions.md @@ -63,5 +63,6 @@ Teradata T-Digest URL UUID +Variant Window ``` diff --git a/docs/src/main/sphinx/functions/variant.md b/docs/src/main/sphinx/functions/variant.md new file mode 100644 index 000000000000..e0ac1f3ffb04 --- /dev/null +++ b/docs/src/main/sphinx/functions/variant.md @@ -0,0 +1,157 @@ +# VARIANT functions and operators + +The `VARIANT` type represents a semi-structured value as defined by the +[Apache Iceberg Variant specification](https://iceberg.apache.org/spec/#semi-structured-types). + +`VARIANT` values are created and decoded using casts, and dereferenced +using the SQL subscript operator (`[]`). + +## Equality semantics + +Two `VARIANT` values are equal when they represent the same logical value, +regardless of internal encoding details. + +This means equality is based on value semantics, not byte-for-byte encoding. +For example: + +* Numbers compare by numeric value across numeric encodings. +* Strings compare by string bytes, regardless of short-string or regular string + encoding. +* Timestamps compare by the point in time or value they represent, even when + encoded at different precisions (microseconds vs. nanoseconds), provided the + values are exactly representable at both precisions. +* `TIMESTAMP` and `TIMESTAMP WITH TIME ZONE` remain distinct timestamp kinds + and are not equal to each other. + +For numbers, additional edge-case rules apply: + +* Integer and decimal forms are compared by exact numeric value: + `1`, `1.0`, and `1.00` are equal. +* Floating-point values (`REAL`, `DOUBLE`) are equal to exact numerics only + when the floating-point value can be represented exactly as a variant decimal. 
+ Example: `0.5` equals + `DECIMAL '0.5'`, but `0.1` does not equal + `DECIMAL '0.1'`, because binary floating-point cannot represent `0.1` exactly. +* `+0.0` and `-0.0` are equal. +* `NaN` is not equal to any value, including itself. + +## Subscript operator + +Elements of a `VARIANT` value can be accessed using the SQL subscript operator +(`[]`). The result of a subscript operation is always a `VARIANT` value. + +### Objects + +When the underlying value is an object, use a `VARCHAR` key: + +```sql +variant_expression['key'] +``` + +If the specified key does not exist in the object, the result is SQL `NULL`. + +### Arrays + +When the underlying value is an array, use a `BIGINT` index with one-based indexing: + +```sql +variant_expression[index] +``` + +The standard SQL array indexing rules apply: + +* Indexes start at `1` +* Index `0` or negative indexes are invalid and result in an error +* An index greater than the array length results in an error + +## Functions + +:::{function} variant_is_null(variant) -> boolean +Returns `true` if the input value represents a *variant null*. + +This function distinguishes a variant null value from SQL `NULL`. + +* Returns `true` if the value is a variant null +* Returns `false` for all other variant values +* Returns SQL `NULL` if the input is SQL `NULL` + +Example: + +```sql +SELECT variant_is_null(CAST(JSON 'null' AS VARIANT)); -- true +SELECT variant_is_null(CAST(42 AS VARIANT)); -- false +SELECT variant_is_null(NULL); -- NULL +``` + +::: + +## Cast to VARIANT + +The following SQL types can be cast to `VARIANT`: + +### Scalar types + +* `BOOLEAN` +* `TINYINT` +* `SMALLINT` +* `INTEGER` +* `BIGINT` +* `REAL` +* `DOUBLE` +* `DECIMAL` +* `VARCHAR` +* `VARBINARY` +* `DATE` +* `TIME(p)` +* `TIMESTAMP(p)` +* `TIMESTAMP(p) WITH TIME ZONE` +* `UUID` +* `JSON` +* `VARIANT` + +### Container types + +* `ARRAY` +* `MAP` (with `VARCHAR` key type) +* `ROW` + +Container values may contain any supported scalar or container type, including +nested containers, `JSON`, and `VARIANT` values. + +## Cast from VARIANT + +A `VARIANT` value can be cast to the following SQL types when the underlying +value is compatible with the target type. + +Standard Trino cast coercions apply. For example, a `VARIANT` value containing +a string can be cast to a numeric type if the string represents a valid value +for the target type and fits within its range. + +### Scalar types + +* `BOOLEAN` +* `TINYINT` +* `SMALLINT` +* `INTEGER` +* `BIGINT` +* `REAL` +* `DOUBLE` +* `DECIMAL` +* `VARCHAR` +* `VARBINARY` +* `DATE` +* `TIME(p)` +* `TIMESTAMP(p)` +* `TIMESTAMP(p) WITH TIME ZONE` +* `UUID` +* `JSON` +* `VARIANT` + +### Container types + +* `ARRAY` +* `MAP` (with `VARCHAR` key type) +* `ROW` + +Casting to container types is supported when the structure of the target type +is compatible with the contents of the `VARIANT` value. If the underlying value +is incompatible with the requested type, the cast fails. diff --git a/docs/src/main/sphinx/language/types.md b/docs/src/main/sphinx/language/types.md index 9e50f8418401..07dcb4285aa5 100644 --- a/docs/src/main/sphinx/language/types.md +++ b/docs/src/main/sphinx/language/types.md @@ -235,11 +235,50 @@ Binary literals ignore any whitespace characters. For example, the literal Binary strings with length are not yet supported: `varbinary(n)` ::: +(json-data-type)= ### `JSON` JSON value type, which can be a JSON object, a JSON array, a JSON number, a JSON string, `true`, `false` or `null`. 
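+
+For example, each of the following literals produces a value of type `JSON`:
+
+```sql
+SELECT JSON '{"language": "SQL", "version": 2023}';
+SELECT JSON '[1, 2.5, null]';
+SELECT JSON 'true';
+```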
+(variant-data-type)= +### `VARIANT` + +A semi-structured value type. A `VARIANT` value can represent any of the following: + +- object (key-value structure) +- array +- string +- number (integer, decimal, and floating-point) +- boolean +- null +- date and time values + +`VARIANT` is designed for efficient processing of semi-structured data, and is +commonly used with connectors and file formats that support a native variant type. + +`VARIANT` differs from {ref}`json-data-type` in that it preserves the full +underlying value type, rather than reducing values to a limited set of JSON +types. + +Examples: +```sql +SELECT typeof(CAST(JSON '{"a": 1, "b": [true, null]}' AS VARIANT)); +-- variant + +SELECT CAST(CAST(JSON '123' AS VARIANT) AS BIGINT); +-- 123 +``` + +`VARIANT` follows the [Variant encoding specification](https://github.com/apache/parquet-format/blob/master/VariantEncoding.md) +maintained in the Apache Parquet project and adopted by Apache Iceberg. +Trino implements this specification directly, including its type system, value +encoding, and semantics. + +This ensures consistent behavior when reading and writing variant values across +systems that support the same specification. + +See also {doc}`/functions/variant`. + (date-time-data-types)= ## Date and time diff --git a/lib/trino-orc/src/main/java/io/trino/orc/reader/ColumnReaders.java b/lib/trino-orc/src/main/java/io/trino/orc/reader/ColumnReaders.java index cdf3ab60e415..4d929467c015 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/reader/ColumnReaders.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/reader/ColumnReaders.java @@ -23,10 +23,11 @@ import io.trino.spi.type.TimeType; import io.trino.spi.type.Type; import io.trino.spi.type.UuidType; +import io.trino.spi.type.VariantType; -import static com.google.common.base.Preconditions.checkArgument; import static io.trino.orc.metadata.OrcType.OrcTypeKind.BINARY; import static io.trino.orc.metadata.OrcType.OrcTypeKind.LONG; +import static io.trino.orc.metadata.OrcType.OrcTypeKind.STRUCT; import static io.trino.orc.reader.ReaderUtils.invalidStreamType; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.TimeType.TIME_MICROS; @@ -35,6 +36,7 @@ public final class ColumnReaders { public static final String ICEBERG_BINARY_TYPE = "iceberg.binary-type"; public static final String ICEBERG_LONG_TYPE = "iceberg.long-type"; + public static final String ICEBERG_VARIANT_TYPE_KIND = "iceberg.variant-type"; private ColumnReaders() {} @@ -55,13 +57,17 @@ public static ColumnReader createColumnReader( return new TimeColumnReader(type, column, memoryContext.newLocalMemoryContext(ColumnReaders.class.getSimpleName())); } if (type instanceof UuidType) { - checkArgument(orcTypeKind == BINARY, "UUID type can only be read from BINARY column but got %s", column); - checkArgument( - "UUID".equals(column.getAttributes().get(ICEBERG_BINARY_TYPE)), - "Expected ORC column for UUID data to be annotated with %s=UUID: %s", - ICEBERG_BINARY_TYPE, column); + if (orcTypeKind != BINARY || !"UUID".equals(column.getAttributes().get(ICEBERG_BINARY_TYPE))) { + throw invalidStreamType(column, type); + } return new UuidColumnReader(column); } + if (type instanceof VariantType) { + if (orcTypeKind != STRUCT || !"true".equals(column.getAttributes().get(ICEBERG_VARIANT_TYPE_KIND))) { + throw invalidStreamType(column, type); + } + return new VariantColumnReader(column, memoryContext); + } return switch (orcTypeKind) { case BOOLEAN -> new BooleanColumnReader(type, column, memoryContext.newLocalMemoryContext(ColumnReaders.class.getSimpleName())); diff 
--git a/lib/trino-orc/src/main/java/io/trino/orc/reader/VariantColumnReader.java b/lib/trino-orc/src/main/java/io/trino/orc/reader/VariantColumnReader.java new file mode 100644 index 000000000000..22b3dd5dfd1e --- /dev/null +++ b/lib/trino-orc/src/main/java/io/trino/orc/reader/VariantColumnReader.java @@ -0,0 +1,213 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.orc.reader; + +import com.google.common.io.Closer; +import io.trino.memory.context.AggregatedMemoryContext; +import io.trino.orc.OrcColumn; +import io.trino.orc.OrcCorruptionException; +import io.trino.orc.metadata.ColumnEncoding; +import io.trino.orc.metadata.ColumnMetadata; +import io.trino.orc.stream.BooleanInputStream; +import io.trino.orc.stream.InputStreamSource; +import io.trino.orc.stream.InputStreamSources; +import io.trino.spi.block.Block; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.VariantBlock; +import jakarta.annotation.Nullable; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.time.ZoneId; +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static io.airlift.slice.SizeOf.instanceSize; +import static io.trino.orc.metadata.Stream.StreamKind.PRESENT; +import static io.trino.orc.reader.ReaderUtils.toNotNullSupressedBlock; +import static io.trino.orc.stream.MissingInputStreamSource.missingStreamSource; +import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static java.util.Objects.requireNonNull; + +/** + * Reads a variant column from ORC. + * Variant is stored as a struct with two binary fields: metadata and value. 
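+ * Top-level nulls are tracked by this reader's PRESENT stream; the nested metadata and
+ * value readers only read the non-null positions, and their blocks are expanded back to
+ * full size before the VariantBlock is assembled.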
+ */ +public class VariantColumnReader + implements ColumnReader +{ + private static final int INSTANCE_SIZE = instanceSize(VariantColumnReader.class); + + private final OrcColumn column; + private final ColumnReader metadataReader; + private final ColumnReader valueReader; + + private int readOffset; + private int nextBatchSize; + + private InputStreamSource<BooleanInputStream> presentStreamSource = missingStreamSource(BooleanInputStream.class); + @Nullable + private BooleanInputStream presentStream; + + private boolean rowGroupOpen; + + public VariantColumnReader(OrcColumn column, AggregatedMemoryContext memoryContext) + throws OrcCorruptionException + { + this.column = requireNonNull(column, "column is null"); + + List<OrcColumn> nestedColumns = column.getNestedColumns(); + checkArgument(nestedColumns.size() == 2, "Variant column must have exactly 2 children (metadata, value), but found %s", nestedColumns.size()); + + // Fields are ordered: metadata, value + OrcColumn metadataColumn = nestedColumns.get(0); + OrcColumn valueColumn = nestedColumns.get(1); + + this.metadataReader = new SliceColumnReader(VARBINARY, metadataColumn, memoryContext); + this.valueReader = new SliceColumnReader(VARBINARY, valueColumn, memoryContext); + } + + @Override + public void prepareNextRead(int batchSize) + { + readOffset += nextBatchSize; + nextBatchSize = batchSize; + } + + @Override + public Block readBlock() + throws IOException + { + if (!rowGroupOpen) { + openRowGroup(); + } + + if (readOffset > 0) { + if (presentStream != null) { + // skip ahead the present bit reader, but count the set bits + // and use this as the skip size for the field readers + readOffset = presentStream.countBitsSet(readOffset); + } + metadataReader.prepareNextRead(readOffset); + valueReader.prepareNextRead(readOffset); + } + + boolean[] nullVector = null; + Block metadataBlock; + Block valueBlock; + + if (presentStream == null) { + metadataReader.prepareNextRead(nextBatchSize); + valueReader.prepareNextRead(nextBatchSize); + metadataBlock = metadataReader.readBlock(); + valueBlock = valueReader.readBlock(); + } + else { + nullVector = new boolean[nextBatchSize]; + int nullValues = presentStream.getUnsetBits(nextBatchSize, nullVector); + if (nullValues != nextBatchSize) { + int nonNullCount = nextBatchSize - nullValues; + metadataReader.prepareNextRead(nonNullCount); + valueReader.prepareNextRead(nonNullCount); + + Block rawMetadata = metadataReader.readBlock(); + Block rawValue = valueReader.readBlock(); + + metadataBlock = toNotNullSupressedBlock(nextBatchSize, nullVector, rawMetadata); + valueBlock = toNotNullSupressedBlock(nextBatchSize, nullVector, rawValue); + } + else { + // All values are null + metadataBlock = RunLengthEncodedBlock.create(VARBINARY.createBlockBuilder(null, 0).appendNull().build(), nextBatchSize); + valueBlock = RunLengthEncodedBlock.create(VARBINARY.createBlockBuilder(null, 0).appendNull().build(), nextBatchSize); + } + } + + VariantBlock variantBlock = VariantBlock.create(nextBatchSize, metadataBlock, valueBlock, Optional.ofNullable(nullVector)); + + readOffset = 0; + nextBatchSize = 0; + + return variantBlock; + } + + private void openRowGroup() + throws IOException + { + presentStream = presentStreamSource.openStream(); + rowGroupOpen = true; + } + + @Override + public void startStripe(ZoneId fileTimeZone, InputStreamSources dictionaryStreamSources, ColumnMetadata<ColumnEncoding> encoding) + throws IOException + { + presentStreamSource = missingStreamSource(BooleanInputStream.class); + + readOffset = 0; + nextBatchSize = 0; + + presentStream = 
null; + + rowGroupOpen = false; + + metadataReader.startStripe(fileTimeZone, dictionaryStreamSources, encoding); + valueReader.startStripe(fileTimeZone, dictionaryStreamSources, encoding); + } + + @Override + public void startRowGroup(InputStreamSources dataStreamSources) + throws IOException + { + presentStreamSource = dataStreamSources.getInputStreamSource(column, PRESENT, BooleanInputStream.class); + + readOffset = 0; + nextBatchSize = 0; + + presentStream = null; + + rowGroupOpen = false; + + metadataReader.startRowGroup(dataStreamSources); + valueReader.startRowGroup(dataStreamSources); + } + + @Override + public String toString() + { + return toStringHelper(this) + .addValue(column) + .toString(); + } + + @Override + public void close() + { + try (Closer closer = Closer.create()) { + closer.register(metadataReader::close); + closer.register(valueReader::close); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + metadataReader.getRetainedSizeInBytes() + valueReader.getRetainedSizeInBytes(); + } +} diff --git a/lib/trino-orc/src/main/java/io/trino/orc/writer/ColumnWriters.java b/lib/trino-orc/src/main/java/io/trino/orc/writer/ColumnWriters.java index 0e2884a1fb44..30d9dd2193f1 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/writer/ColumnWriters.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/writer/ColumnWriters.java @@ -32,12 +32,16 @@ import io.trino.spi.type.RowType; import io.trino.spi.type.TimeType; import io.trino.spi.type.Type; +import io.trino.spi.type.VarbinaryType; +import io.trino.spi.type.VariantType; import java.util.function.Supplier; import static com.google.common.base.Preconditions.checkArgument; import static io.trino.orc.metadata.OrcType.OrcTypeKind.LONG; +import static io.trino.orc.metadata.OrcType.OrcTypeKind.STRUCT; import static io.trino.orc.reader.ColumnReaders.ICEBERG_LONG_TYPE; +import static io.trino.orc.reader.ColumnReaders.ICEBERG_VARIANT_TYPE_KIND; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; @@ -63,6 +67,18 @@ public static ColumnWriter createColumnWriter( checkArgument("TIME".equals(orcType.getAttributes().get(ICEBERG_LONG_TYPE)), "wrong attributes %s for type %s", orcType.getAttributes(), type); return new TimeColumnWriter(columnId, type, compression, bufferSize, () -> new TimeMicrosStatisticsBuilder(bloomFilterBuilder.get())); } + if (type instanceof VariantType) { + checkArgument(orcType.getOrcTypeKind() == STRUCT, "wrong ORC type %s for type %s", orcType, type); + checkArgument("true".equals(orcType.getAttributes().get(ICEBERG_VARIANT_TYPE_KIND)), "wrong attributes %s for type %s", orcType.getAttributes(), type); + checkArgument(orcType.getFieldCount() == 2, "Variant ORC struct must have 2 fields (metadata, value), but found %s", orcType.getFieldCount()); + checkArgument(orcTypes.get(orcType.getFieldTypeIndex(0)).getOrcTypeKind() == OrcType.OrcTypeKind.BINARY, "Variant ORC metadata field must be BINARY but found %s", orcTypes.get(orcType.getFieldTypeIndex(0))); + checkArgument(orcTypes.get(orcType.getFieldTypeIndex(1)).getOrcTypeKind() == OrcType.OrcTypeKind.BINARY, "Variant ORC value field must be BINARY but found %s", orcTypes.get(orcType.getFieldTypeIndex(1))); + + // Fields are ordered: metadata, value + ColumnWriter metadataWriter = createColumnWriter(orcType.getFieldTypeIndex(0), orcTypes, VarbinaryType.VARBINARY, compression, bufferSize, stringStatisticsLimit, bloomFilterBuilder, 
shouldCompactMinMax); + ColumnWriter valueWriter = createColumnWriter(orcType.getFieldTypeIndex(1), orcTypes, VarbinaryType.VARBINARY, compression, bufferSize, stringStatisticsLimit, bloomFilterBuilder, shouldCompactMinMax); + return new VariantColumnWriter(columnId, compression, bufferSize, metadataWriter, valueWriter); + } switch (orcType.getOrcTypeKind()) { case BOOLEAN: return new BooleanColumnWriter(columnId, type, compression, bufferSize); diff --git a/lib/trino-orc/src/main/java/io/trino/orc/writer/VariantColumnWriter.java b/lib/trino-orc/src/main/java/io/trino/orc/writer/VariantColumnWriter.java new file mode 100644 index 000000000000..d14268e4a73a --- /dev/null +++ b/lib/trino-orc/src/main/java/io/trino/orc/writer/VariantColumnWriter.java @@ -0,0 +1,252 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.orc.writer; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.slice.Slice; +import io.trino.orc.checkpoint.BooleanStreamCheckpoint; +import io.trino.orc.metadata.ColumnEncoding; +import io.trino.orc.metadata.CompressedMetadataWriter; +import io.trino.orc.metadata.CompressionKind; +import io.trino.orc.metadata.OrcColumnId; +import io.trino.orc.metadata.RowGroupIndex; +import io.trino.orc.metadata.Stream; +import io.trino.orc.metadata.Stream.StreamKind; +import io.trino.orc.metadata.statistics.ColumnStatistics; +import io.trino.orc.stream.PresentOutputStream; +import io.trino.orc.stream.StreamDataOutput; +import io.trino.spi.block.Block; +import io.trino.spi.block.VariantBlock; +import io.trino.spi.block.VariantBlock.VariantNestedBlocks; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static io.airlift.slice.SizeOf.instanceSize; +import static io.trino.orc.metadata.ColumnEncoding.ColumnEncodingKind.DIRECT; +import static io.trino.orc.metadata.CompressionKind.NONE; +import static java.util.Objects.requireNonNull; + +/** + * Writes a variant column to ORC. + * Variant is stored as a struct with two binary fields: metadata and value. 
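+ * Only the top-level PRESENT stream is written by this writer itself; the nested metadata
+ * and value writers receive null-suppressed blocks containing just the non-null positions.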
+ */ +public class VariantColumnWriter + implements ColumnWriter +{ + private static final int INSTANCE_SIZE = instanceSize(VariantColumnWriter.class); + private static final ColumnEncoding COLUMN_ENCODING = new ColumnEncoding(DIRECT, 0); + + private final OrcColumnId columnId; + private final boolean compressed; + private final PresentOutputStream presentStream; + private final ColumnWriter metadataWriter; + private final ColumnWriter valueWriter; + + private final List<ColumnStatistics> rowGroupColumnStatistics = new ArrayList<>(); + + private int nonNullValueCount; + + private boolean closed; + + public VariantColumnWriter(OrcColumnId columnId, CompressionKind compression, int bufferSize, ColumnWriter metadataWriter, ColumnWriter valueWriter) + { + this.columnId = columnId; + this.compressed = requireNonNull(compression, "compression is null") != NONE; + this.metadataWriter = requireNonNull(metadataWriter, "metadataWriter is null"); + this.valueWriter = requireNonNull(valueWriter, "valueWriter is null"); + this.presentStream = new PresentOutputStream(compression, bufferSize); + } + + @Override + public List<ColumnWriter> getNestedColumnWriters() + { + return ImmutableList.<ColumnWriter>builder() + .add(metadataWriter) + .addAll(metadataWriter.getNestedColumnWriters()) + .add(valueWriter) + .addAll(valueWriter.getNestedColumnWriters()) + .build(); + } + + @Override + public Map<OrcColumnId, ColumnEncoding> getColumnEncodings() + { + ImmutableMap.Builder<OrcColumnId, ColumnEncoding> encodings = ImmutableMap.builder(); + encodings.put(columnId, COLUMN_ENCODING); + encodings.putAll(metadataWriter.getColumnEncodings()); + encodings.putAll(valueWriter.getColumnEncodings()); + return encodings.buildOrThrow(); + } + + @Override + public void beginRowGroup() + { + presentStream.recordCheckpoint(); + metadataWriter.beginRowGroup(); + valueWriter.beginRowGroup(); + } + + @Override + public void writeBlock(Block block) + { + checkState(!closed); + checkArgument(block.getPositionCount() > 0, "Block is empty"); + + // record nulls + for (int position = 0; position < block.getPositionCount(); position++) { + boolean present = !block.isNull(position); + presentStream.writeBoolean(present); + if (present) { + nonNullValueCount++; + } + } + + // write null-suppressed field values + VariantNestedBlocks nested = VariantBlock.getNullSuppressedNestedFields(block); + Block metadataBlock = nested.metadataBlock(); + Block valueBlock = nested.valueBlock(); + + if (metadataBlock.getPositionCount() > 0) { + metadataWriter.writeBlock(metadataBlock); + valueWriter.writeBlock(valueBlock); + } + } + + @Override + public Map<OrcColumnId, ColumnStatistics> finishRowGroup() + { + checkState(!closed); + ColumnStatistics statistics = new ColumnStatistics((long) nonNullValueCount, 0, null, null, null, null, null, null, null, null, null, null); + rowGroupColumnStatistics.add(statistics); + nonNullValueCount = 0; + + ImmutableMap.Builder<OrcColumnId, ColumnStatistics> columnStatistics = ImmutableMap.builder(); + columnStatistics.put(columnId, statistics); + columnStatistics.putAll(metadataWriter.finishRowGroup()); + columnStatistics.putAll(valueWriter.finishRowGroup()); + return columnStatistics.buildOrThrow(); + } + + @Override + public void close() + { + closed = true; + metadataWriter.close(); + valueWriter.close(); + presentStream.close(); + } + + @Override + public Map<OrcColumnId, ColumnStatistics> getColumnStripeStatistics() + { + checkState(closed); + ImmutableMap.Builder<OrcColumnId, ColumnStatistics> columnStatistics = ImmutableMap.builder(); + columnStatistics.put(columnId, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); + columnStatistics.putAll(metadataWriter.getColumnStripeStatistics()); + 
columnStatistics.putAll(valueWriter.getColumnStripeStatistics()); + return columnStatistics.buildOrThrow(); + } + + @Override + public List<StreamDataOutput> getIndexStreams(CompressedMetadataWriter metadataWriter) + throws IOException + { + checkState(closed); + + ImmutableList.Builder<RowGroupIndex> rowGroupIndexes = ImmutableList.builder(); + + Optional<List<BooleanStreamCheckpoint>> presentCheckpoints = presentStream.getCheckpoints(); + for (int i = 0; i < rowGroupColumnStatistics.size(); i++) { + int groupId = i; + ColumnStatistics columnStatistics = rowGroupColumnStatistics.get(groupId); + Optional<BooleanStreamCheckpoint> presentCheckpoint = presentCheckpoints.map(checkpoints -> checkpoints.get(groupId)); + List<Integer> positions = createVariantColumnPositionList(compressed, presentCheckpoint); + rowGroupIndexes.add(new RowGroupIndex(positions, columnStatistics)); + } + + Slice slice = metadataWriter.writeRowIndexes(rowGroupIndexes.build()); + Stream stream = new Stream(columnId, StreamKind.ROW_INDEX, slice.length(), false); + + ImmutableList.Builder<StreamDataOutput> indexStreams = ImmutableList.builder(); + indexStreams.add(new StreamDataOutput(slice, stream)); + indexStreams.addAll(this.metadataWriter.getIndexStreams(metadataWriter)); + indexStreams.addAll(this.metadataWriter.getBloomFilters(metadataWriter)); + indexStreams.addAll(this.valueWriter.getIndexStreams(metadataWriter)); + indexStreams.addAll(this.valueWriter.getBloomFilters(metadataWriter)); + return indexStreams.build(); + } + + @Override + public List<StreamDataOutput> getBloomFilters(CompressedMetadataWriter metadataWriter) + { + return ImmutableList.of(); + } + + private static List<Integer> createVariantColumnPositionList( + boolean compressed, + Optional<BooleanStreamCheckpoint> presentCheckpoint) + { + ImmutableList.Builder<Integer> positionList = ImmutableList.builder(); + presentCheckpoint.ifPresent(booleanStreamCheckpoint -> positionList.addAll(booleanStreamCheckpoint.toPositionList(compressed))); + return positionList.build(); + } + + @Override + public List<StreamDataOutput> getDataStreams() + { + checkState(closed); + + ImmutableList.Builder<StreamDataOutput> outputDataStreams = ImmutableList.builder(); + presentStream.getStreamDataOutput(columnId).ifPresent(outputDataStreams::add); + outputDataStreams.addAll(metadataWriter.getDataStreams()); + outputDataStreams.addAll(valueWriter.getDataStreams()); + return outputDataStreams.build(); + } + + @Override + public long getBufferedBytes() + { + return presentStream.getBufferedBytes() + metadataWriter.getBufferedBytes() + valueWriter.getBufferedBytes(); + } + + @Override + public long getRetainedBytes() + { + long retainedBytes = INSTANCE_SIZE + presentStream.getRetainedBytes(); + retainedBytes += metadataWriter.getRetainedBytes(); + retainedBytes += valueWriter.getRetainedBytes(); + for (ColumnStatistics statistics : rowGroupColumnStatistics) { + retainedBytes += statistics.getRetainedSizeInBytes(); + } + return retainedBytes; + } + + @Override + public void reset() + { + closed = false; + presentStream.reset(); + metadataWriter.reset(); + valueWriter.reset(); + rowGroupColumnStatistics.clear(); + nonNullValueCount = 0; + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetTypeUtils.java b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetTypeUtils.java index b48e2da49c18..5c36ad68aa4c 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetTypeUtils.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/ParquetTypeUtils.java @@ -43,6 +43,7 @@ import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.type.StandardTypes.JSON; import static io.trino.spi.type.VarbinaryType.VARBINARY; +import 
static io.trino.spi.type.VariantType.VARIANT; import static java.lang.String.format; import static org.apache.parquet.schema.Type.Repetition.OPTIONAL; import static org.apache.parquet.schema.Type.Repetition.REPEATED; @@ -370,7 +371,7 @@ private static Optional constructField(Type type, ColumnIO columnIO, bool private static boolean isVariantType(Type type, ColumnIO columnIO) { - return type.getBaseName().equals(JSON) && + return (type == VARIANT || type.getBaseName().equals(JSON)) && columnIO instanceof GroupColumnIO groupColumnIo && groupColumnIo.getChildrenCount() == 2 && groupColumnIo.getChild("value") != null && diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java index 627fdc6864b5..5abb6544de33 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java @@ -48,6 +48,7 @@ import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.RowBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.VariantBlock; import io.trino.spi.connector.SourcePage; import io.trino.spi.metrics.Metric; import io.trino.spi.metrics.Metrics; @@ -93,6 +94,7 @@ import static io.trino.parquet.reader.PageReader.createPageReader; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.spi.type.VariantType.VARIANT; import static java.lang.Math.max; import static java.lang.Math.min; import static java.lang.Math.toIntExact; @@ -520,6 +522,37 @@ private ColumnChunk readVariant(VariantField field) throws IOException { ColumnChunk metadataChunk = readColumnChunk(field.getMetadata()); + ColumnChunk valueChunk = readColumnChunk(field.getValue()); + + // position count and nulls are derived from metadata def levels + int positionsCount = metadataChunk.getDefinitionLevels().length; + int variantDefLevel = field.getDefinitionLevel(); + boolean[] isNull = null; + for (int i = 0; i < positionsCount; i++) { + if (metadataChunk.getDefinitionLevels()[i] < variantDefLevel) { + if (isNull == null) { + isNull = new boolean[positionsCount]; + } + isNull[i] = true; + } + } + + // if isNull is present, we need to convert the blocks to not-null-suppressed blocks + Block metadataBlock = metadataChunk.getBlock(); + Block valueBlock = valueChunk.getBlock(); + if (isNull != null) { + metadataBlock = toNotNullSupressedBlock(positionsCount, isNull, metadataBlock); + valueBlock = toNotNullSupressedBlock(positionsCount, isNull, valueBlock); + } + + Block variantBlock = VariantBlock.create(positionsCount, metadataBlock, valueBlock, Optional.ofNullable(isNull)); + return new ColumnChunk(variantBlock, metadataChunk.getDefinitionLevels(), metadataChunk.getRepetitionLevels()); + } + + private ColumnChunk readVariantAsJson(VariantField field) + throws IOException + { + ColumnChunk metadataChunk = readColumnChunk(field.getMetadata()); int positionCount = metadataChunk.getBlock().getPositionCount(); BlockBuilder variantBlock = VARCHAR.createBlockBuilder(null, max(1, positionCount)); @@ -755,7 +788,13 @@ private ColumnChunk readColumnChunk(Field field) { ColumnChunk columnChunk; if (field instanceof VariantField variantField) { - columnChunk = readVariant(variantField); + if (variantField.getType() == VARIANT) { + // Directly read VARIANT as a single block + columnChunk = readVariant(variantField); + } + else { + columnChunk = 
readVariantAsJson(variantField); + } } else if (field.getType() instanceof RowType) { columnChunk = readStruct((GroupField) field); diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetTypeVisitor.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetTypeVisitor.java index a33010d4eab2..e9c9f125f918 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetTypeVisitor.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetTypeVisitor.java @@ -14,6 +14,7 @@ package io.trino.parquet.writer; import com.google.common.collect.Lists; +import io.trino.spi.variant.Header; import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.MessageType; @@ -104,6 +105,10 @@ else if (repeatedKeyValue.getFieldCount() == 1) { visitor.fieldNames.pop(); } } + if (LogicalTypeAnnotation.variantType(Header.VERSION).equals(annotation)) { + checkArgument(group.getFieldCount() == 2, "Invalid variant: expected 2 fields (metadata, value): %s", group); + return visitor.variant(group); + } return visitor.struct(group, visitFields(group, visitor)); } @@ -148,6 +153,11 @@ public T map(GroupType map, T key, T value) return null; } + public T variant(GroupType variant) + { + return null; + } + public T primitive(PrimitiveType primitive) { return null; diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriters.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriters.java index d6fd05a86f85..dcf363b4d8ec 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriters.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriters.java @@ -15,6 +15,7 @@ import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ObjectArrays; import io.trino.parquet.writer.valuewriter.BigintValueWriter; import io.trino.parquet.writer.valuewriter.BinaryValueWriter; import io.trino.parquet.writer.valuewriter.BooleanValueWriter; @@ -44,6 +45,7 @@ import io.trino.spi.type.UuidType; import io.trino.spi.type.VarbinaryType; import io.trino.spi.type.VarcharType; +import io.trino.spi.variant.Header; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.values.ValuesWriter; import org.apache.parquet.column.values.bloomfilter.AdaptiveBlockSplitBloomFilter; @@ -55,6 +57,8 @@ import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName; +import org.apache.parquet.schema.Type.Repetition; import org.joda.time.DateTimeZone; import java.util.Iterator; @@ -84,6 +88,7 @@ import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS; import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_NANOS; import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.spi.type.VarbinaryType.VARBINARY; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32; @@ -127,7 +132,7 @@ static PrimitiveValueWriter getValueWriter(ValuesWriter valuesWriter, Type type, return new TimeMicrosValueWriter(valuesWriter, parquetType); } if (type instanceof TimestampType) { - if 
(parquetType.getPrimitiveTypeName().equals(PrimitiveType.PrimitiveTypeName.INT96)) { + if (parquetType.getPrimitiveTypeName().equals(PrimitiveTypeName.INT96)) { checkArgument(parquetTimeZone.isPresent(), "parquetTimeZone must be provided for INT96 timestamps"); return new Int96TimestampValueWriter(valuesWriter, type, parquetType, parquetTimeZone.get()); } @@ -262,14 +267,58 @@ public ColumnWriter map(GroupType map, ColumnWriter key, ColumnWriter value) return new MapColumnWriter(key, value, fieldDefinitionLevel, fieldRepetitionLevel); } + @Override + public ColumnWriter variant(GroupType variant) + { + checkArgument( + LogicalTypeAnnotation.variantType(Header.VERSION).equals(variant.getLogicalTypeAnnotation()), + "VARIANT group must be annotated with VARIANT logical type: %s", variant); + checkArgument( + variant.getFieldCount() == 2, + "Unsupported VARIANT schema (expected exactly 2 fields: metadata, value): %s", variant); + + org.apache.parquet.schema.Type metadataType = variant.getType("metadata"); + org.apache.parquet.schema.Type valueType = variant.getType("value"); + + PrimitiveType metadataPrimitive = metadataType.asPrimitiveType(); + PrimitiveType valuePrimitive = valueType.asPrimitiveType(); + + checkArgument( + metadataPrimitive.getPrimitiveTypeName() == PrimitiveTypeName.BINARY, + "VARIANT metadata field must be binary: %s", metadataPrimitive); + checkArgument( + valuePrimitive.getPrimitiveTypeName() == PrimitiveTypeName.BINARY, + "VARIANT value field must be binary: %s", valuePrimitive); + + checkArgument( + metadataPrimitive.getRepetition() == Repetition.REQUIRED, + "VARIANT metadata field must be required: %s", metadataPrimitive); + + // For now, we only support the unshredded form: required value + checkArgument( + valuePrimitive.getRepetition() == Repetition.REQUIRED, + "VARIANT value field must be required (unshredded only supported): %s", valuePrimitive); + + String[] path = currentPath(); + ColumnWriter metadataColumnWriter = primitive(metadataPrimitive, ObjectArrays.concat(path, "metadata"), VARBINARY); + ColumnWriter valueColumnWriter = primitive(valuePrimitive, ObjectArrays.concat(path, "value"), VARBINARY); + int fieldDefinitionLevel = type.getMaxDefinitionLevel(path); + return new VariantColumnWriter(metadataColumnWriter, valueColumnWriter, fieldDefinitionLevel); + } + @Override public ColumnWriter primitive(PrimitiveType primitive) { String[] path = currentPath(); + Type trinoType = requireNonNull(trinoTypes.get(ImmutableList.copyOf(path)), "Trino type is null"); + return primitive(primitive, path, trinoType); + } + + private PrimitiveColumnWriter primitive(PrimitiveType primitive, String[] path, Type trinoType) + { int fieldDefinitionLevel = type.getMaxDefinitionLevel(path); int fieldRepetitionLevel = type.getMaxRepetitionLevel(path); ColumnDescriptor columnDescriptor = new ColumnDescriptor(path, primitive, fieldRepetitionLevel, fieldDefinitionLevel); - Type trinoType = requireNonNull(trinoTypes.get(ImmutableList.copyOf(path)), "Trino type is null"); Optional bloomFilter = createBloomFilter(bloomFilterColumns, maxBloomFilterSize, bloomFilterFpp, columnDescriptor, trinoType); return new PrimitiveColumnWriter( columnDescriptor, diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/VariantColumnWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/VariantColumnWriter.java new file mode 100644 index 000000000000..4dd44debf6b4 --- /dev/null +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/VariantColumnWriter.java @@ -0,0 
+1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.parquet.writer; + +import com.google.common.collect.ImmutableList; +import io.trino.parquet.writer.repdef.DefLevelWriterProvider; +import io.trino.parquet.writer.repdef.DefLevelWriterProviders; +import io.trino.parquet.writer.repdef.RepLevelWriterProvider; +import io.trino.parquet.writer.repdef.RepLevelWriterProviders; +import io.trino.spi.block.Block; +import io.trino.spi.block.VariantBlock; + +import java.io.IOException; +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.airlift.slice.SizeOf.instanceSize; +import static java.util.Objects.requireNonNull; + +public class VariantColumnWriter + implements ColumnWriter +{ + private static final int INSTANCE_SIZE = instanceSize(VariantColumnWriter.class); + + private final ColumnWriter metadataColumnWriter; + private final ColumnWriter valueColumnWriter; + private final int maxDefinitionLevel; + + public VariantColumnWriter(ColumnWriter metadataColumnWriter, ColumnWriter valueColumnWriter, int maxDefinitionLevel) + { + this.metadataColumnWriter = requireNonNull(metadataColumnWriter, "metadataColumnWriter is null"); + this.valueColumnWriter = requireNonNull(valueColumnWriter, "valueColumnWriter is null"); + this.maxDefinitionLevel = maxDefinitionLevel; + } + + @Override + public void writeBlock(ColumnChunk columnChunk) + throws IOException + { + Block block = columnChunk.getBlock(); + // This must be a VariantBlock + VariantBlock.VariantNestedBlocks nested = VariantBlock.getNullSuppressedNestedFields(block); + + Block metadataBlock = nested.metadataBlock(); + Block valueBlock = nested.valueBlock(); + + checkArgument( + metadataBlock.getPositionCount() == valueBlock.getPositionCount(), + "metadata and value blocks must have the same position count"); + + // IMPORTANT: we add a VARIANT-level def/rep provider here, + // just like StructColumnWriter does for RowBlock. 
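+        // The variant-level provider emits maxDefinitionLevel - 1 for a null variant and
+        // delegates runs of consecutive non-null positions to the nested leaf writers.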
+ List<DefLevelWriterProvider> defLevelWriterProviders = ImmutableList.<DefLevelWriterProvider>builder() + .addAll(columnChunk.getDefLevelWriterProviders()) + .add(DefLevelWriterProviders.of(block, maxDefinitionLevel)) + .build(); + + List<RepLevelWriterProvider> repLevelWriterProviders = ImmutableList.<RepLevelWriterProvider>builder() + .addAll(columnChunk.getRepLevelWriterProviders()) + .add(RepLevelWriterProviders.of(block)) + .build(); + + // Push the two leaf blocks down with the augmented providers + metadataColumnWriter.writeBlock(new ColumnChunk(metadataBlock, defLevelWriterProviders, repLevelWriterProviders)); + valueColumnWriter.writeBlock(new ColumnChunk(valueBlock, defLevelWriterProviders, repLevelWriterProviders)); + } + + @Override + public void close() + { + metadataColumnWriter.close(); + valueColumnWriter.close(); + } + + @Override + public List<BufferData> getBuffer() + throws IOException + { + ImmutableList.Builder<BufferData> builder = ImmutableList.builder(); + builder.addAll(metadataColumnWriter.getBuffer()); + builder.addAll(valueColumnWriter.getBuffer()); + return builder.build(); + } + + @Override + public long getBufferedBytes() + { + return metadataColumnWriter.getBufferedBytes() + valueColumnWriter.getBufferedBytes(); + } + + @Override + public long getRetainedBytes() + { + return INSTANCE_SIZE + metadataColumnWriter.getRetainedBytes() + valueColumnWriter.getRetainedBytes(); + } +} diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/DefLevelWriterProviders.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/DefLevelWriterProviders.java index ffd630ad2f63..a3b0dda42731 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/DefLevelWriterProviders.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/DefLevelWriterProviders.java @@ -20,6 +20,7 @@ import io.trino.spi.block.ColumnarMap; import io.trino.spi.block.MapBlock; import io.trino.spi.block.RowBlock; +import io.trino.spi.block.VariantBlock; import java.util.Optional; @@ -27,15 +28,21 @@ import static java.lang.String.format; import static java.util.Objects.requireNonNull; -public class DefLevelWriterProviders +public final class DefLevelWriterProviders { private DefLevelWriterProviders() {} public static DefLevelWriterProvider of(Block block, int maxDefinitionLevel) { - if (block.getUnderlyingValueBlock() instanceof RowBlock) { + Block valueBlock = block.getUnderlyingValueBlock(); + + if (valueBlock instanceof RowBlock) { return new RowDefLevelWriterProvider(block, maxDefinitionLevel); } + if (valueBlock instanceof VariantBlock) { + // Treat VARIANT like a struct/row at the group level + return new VariantDefLevelWriterProvider(block, maxDefinitionLevel); + } return new PrimitiveDefLevelWriterProvider(block, maxDefinitionLevel); } @@ -335,6 +342,73 @@ public ValuesCount writeDefinitionLevels(int positionsCount) } } + static class VariantDefLevelWriterProvider + implements DefLevelWriterProvider + { + private final Block block; + private final int maxDefinitionLevel; + + VariantDefLevelWriterProvider(Block block, int maxDefinitionLevel) + { + this.block = requireNonNull(block, "block is null"); + this.maxDefinitionLevel = maxDefinitionLevel; + } + + @Override + public DefinitionLevelWriter getDefinitionLevelWriter(Optional<DefinitionLevelWriter> nestedWriterOptional, ColumnDescriptorValuesWriter encoder) + { + checkArgument(nestedWriterOptional.isPresent(), "nestedWriter should be present for variant definition level writer"); + return new DefinitionLevelWriter() + { + private final DefinitionLevelWriter nestedWriter = nestedWriterOptional.orElseThrow(); + + private 
int offset; + + @Override + public ValuesCount writeDefinitionLevels() + { + return writeDefinitionLevels(block.getPositionCount()); + } + + @Override + public ValuesCount writeDefinitionLevels(int positionsCount) + { + checkValidPosition(offset, positionsCount, block.getPositionCount()); + if (!block.mayHaveNull()) { + // No null variants: just pass through to nested writer + offset += positionsCount; + return nestedWriter.writeDefinitionLevels(positionsCount); + } + + int maxDefinitionValuesCount = 0; + int totalValuesCount = 0; + for (int position = offset; position < offset + positionsCount; ) { + if (block.isNull(position)) { + // VARIANT group is null at this position + encoder.writeInteger(maxDefinitionLevel - 1); + totalValuesCount++; + position++; + } + else { + // Consecutive non-null variants: delegate to nested writer + int consecutiveNonNullsCount = 1; + position++; + while (position < offset + positionsCount && !block.isNull(position)) { + position++; + consecutiveNonNullsCount++; + } + ValuesCount valuesCount = nestedWriter.writeDefinitionLevels(consecutiveNonNullsCount); + maxDefinitionValuesCount += valuesCount.maxDefinitionLevelValuesCount(); + totalValuesCount += valuesCount.totalValuesCount(); + } + } + offset += positionsCount; + return new ValuesCount(totalValuesCount, maxDefinitionValuesCount); + } + }; + } + } + private static void checkValidPosition(int offset, int positionsCount, int totalPositionsCount) { if (offset < 0 || positionsCount < 0 || offset + positionsCount > totalPositionsCount) { diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/RepLevelWriterProviders.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/RepLevelWriterProviders.java index 7fdc0f48e4d2..0081c7656a27 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/RepLevelWriterProviders.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/repdef/RepLevelWriterProviders.java @@ -20,6 +20,7 @@ import io.trino.spi.block.ColumnarMap; import io.trino.spi.block.MapBlock; import io.trino.spi.block.RowBlock; +import io.trino.spi.block.VariantBlock; import java.util.Optional; @@ -33,9 +34,13 @@ private RepLevelWriterProviders() {} public static RepLevelWriterProvider of(Block block) { - if (block.getUnderlyingValueBlock() instanceof RowBlock) { + Block valueBlock = block.getUnderlyingValueBlock(); + if (valueBlock instanceof RowBlock) { return new RowRepLevelWriterProvider(block); } + if (valueBlock instanceof VariantBlock) { + return new VariantRepLevelWriterProvider(block); + } return new PrimitiveRepLevelWriterProvider(block); } @@ -145,6 +150,64 @@ public void writeRepetitionLevels(int parentLevel, int positionsCount) } } + static class VariantRepLevelWriterProvider + implements RepLevelWriterProvider + { + private final Block block; + + VariantRepLevelWriterProvider(Block block) + { + this.block = requireNonNull(block, "block is null"); + checkArgument(block.getUnderlyingValueBlock() instanceof VariantBlock, "block is not a variant block"); + } + + @Override + public RepetitionLevelWriter getRepetitionLevelWriter(Optional<RepetitionLevelWriter> nestedWriterOptional, ColumnDescriptorValuesWriter encoder) + { + checkArgument(nestedWriterOptional.isPresent(), "nestedWriter should be present for variant repetition level writer"); + return new RepetitionLevelWriter() + { + private final RepetitionLevelWriter nestedWriter = nestedWriterOptional.orElseThrow(); + + private int offset; + + @Override + public void writeRepetitionLevels(int parentLevel) + 
{ + writeRepetitionLevels(parentLevel, block.getPositionCount()); + } + + @Override + public void writeRepetitionLevels(int parentLevel, int positionsCount) + { + checkValidPosition(offset, positionsCount, block.getPositionCount()); + if (!block.mayHaveNull()) { + nestedWriter.writeRepetitionLevels(parentLevel, positionsCount); + offset += positionsCount; + return; + } + + for (int position = offset; position < offset + positionsCount; ) { + if (block.isNull(position)) { + encoder.writeInteger(parentLevel); + position++; + } + else { + int consecutiveNonNullsCount = 1; + position++; + while (position < offset + positionsCount && !block.isNull(position)) { + position++; + consecutiveNonNullsCount++; + } + nestedWriter.writeRepetitionLevels(parentLevel, consecutiveNonNullsCount); + } + } + offset += positionsCount; + } + }; + } + } + static class ColumnMapRepLevelWriterProvider implements RepLevelWriterProvider { diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/ExpressionConverter.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/ExpressionConverter.java index 5954aeb5b799..1d6fe24d4855 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/ExpressionConverter.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/ExpressionConverter.java @@ -65,6 +65,7 @@ import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_MICROSECOND; import static io.trino.spi.type.UuidType.UUID; import static io.trino.spi.type.UuidType.trinoUuidToJavaUuid; +import static io.trino.spi.type.VariantType.VARIANT; import static java.lang.Float.intBitsToFloat; import static java.lang.Math.toIntExact; import static java.lang.String.format; @@ -92,6 +93,11 @@ public static boolean isConvertibleToIcebergExpression(Domain domain) return false; } + if (domain.getType() == VARIANT) { + // Iceberg does not support filtering on VARIANT type, but simple checks always work + return domain.isOnlyNull() || domain.getValues().isAll(); + } + if (domain.getType() == UUID) { // Iceberg orders UUID values differently than Trino (perhaps due to https://bugs.openjdk.org/browse/JDK-7025832), so allow only IS NULL / IS NOT NULL checks return domain.isOnlyNull() || domain.getValues().isAll(); diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroDataConversion.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroDataConversion.java index 6369f518df16..52027e58d87e 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroDataConversion.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroDataConversion.java @@ -25,6 +25,7 @@ import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.SqlMap; import io.trino.spi.block.SqlRow; +import io.trino.spi.block.VariantBlockBuilder; import io.trino.spi.type.ArrayType; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; @@ -36,11 +37,15 @@ import io.trino.spi.type.Type; import io.trino.spi.type.VarbinaryType; import io.trino.spi.type.VarcharType; +import io.trino.spi.type.VariantType; import jakarta.annotation.Nullable; import org.apache.iceberg.Schema; import org.apache.iceberg.data.GenericRecord; import org.apache.iceberg.data.Record; import org.apache.iceberg.types.Types; +import org.apache.iceberg.variants.Variant; +import org.apache.iceberg.variants.VariantMetadata; +import org.apache.iceberg.variants.VariantValue; import java.math.BigDecimal; import java.math.BigInteger; @@ -62,6 +67,7 @@ 
import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.slice.Slices.wrappedHeapBuffer; import static io.trino.plugin.iceberg.util.Timestamps.getTimestampTzMicros; import static io.trino.plugin.iceberg.util.Timestamps.timestampToNanos; import static io.trino.plugin.iceberg.util.Timestamps.timestampTzFromMicros; @@ -87,6 +93,7 @@ import static io.trino.spi.type.UuidType.trinoUuidToJavaUuid; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static java.lang.Float.floatToRawIntBits; +import static java.nio.ByteOrder.LITTLE_ENDIAN; import static java.util.Objects.requireNonNull; import static org.apache.iceberg.types.Type.TypeID.FIXED; import static org.apache.iceberg.util.DateTimeUtil.microsFromTimestamp; @@ -265,6 +272,18 @@ public static Object toIcebergAvroObject(Type type, org.apache.iceberg.types.Typ return record; } + if (type instanceof VariantType variantType) { + // Iceberg's Avro DataWriter requires org.apache.iceberg.variants.Variant objects + // for variant type serialization. This is the only place we bridge to Iceberg's variant class. + io.trino.spi.variant.Variant variant = variantType.getObject(block, position); + + ByteBuffer metadataBuffer = variant.metadata().toSlice().toByteBuffer().order(LITTLE_ENDIAN); + ByteBuffer valueBuffer = variant.data().toByteBuffer().order(LITTLE_ENDIAN); + + VariantMetadata metadata = VariantMetadata.from(metadataBuffer); + VariantValue value = VariantValue.from(metadata, valueBuffer); + return Variant.of(metadata, value); + } throw new TrinoException(NOT_SUPPORTED, "unsupported type: " + type); } @@ -312,8 +331,9 @@ public static void serializeToTrinoBlock(Type type, org.apache.iceberg.types.Typ if (type instanceof VarbinaryType) { if (icebergType.typeId().equals(FIXED)) { VARBINARY.writeSlice(builder, Slices.wrappedBuffer((byte[]) object)); + return; } - VARBINARY.writeSlice(builder, Slices.wrappedHeapBuffer((ByteBuffer) object)); + VARBINARY.writeSlice(builder, wrappedHeapBuffer((ByteBuffer) object)); return; } if (type.equals(DATE)) { @@ -384,6 +404,24 @@ public static void serializeToTrinoBlock(Type type, org.apache.iceberg.types.Typ }); return; } + if (type instanceof VariantType) { + // Iceberg's Avro reader returns org.apache.iceberg.variants.Variant objects. + // This is the only place we bridge from Iceberg's variant class back to Trino. 
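+            // Each side is serialized into a heap ByteBuffer with writeTo, rewound, and then
+            // written as a single (metadata, value) entry on the VariantBlockBuilder.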
+ Variant variant = (Variant) object; + VariantMetadata metadata = variant.metadata(); + VariantValue value = variant.value(); + + ByteBuffer metadataBuffer = ByteBuffer.allocate(metadata.sizeInBytes()); + metadata.writeTo(metadataBuffer, 0); + metadataBuffer.rewind(); + + ByteBuffer valueBuffer = ByteBuffer.allocate(value.sizeInBytes()); + value.writeTo(valueBuffer, 0); + valueBuffer.rewind(); + + ((VariantBlockBuilder) builder).writeEntry(wrappedHeapBuffer(metadataBuffer), wrappedHeapBuffer(valueBuffer)); + return; + } throw new TrinoException(NOT_SUPPORTED, "unsupported type: " + type); } } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java index 28e119a909b6..2b7a08cd66fc 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java @@ -440,6 +440,7 @@ import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_MILLISECOND; import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.spi.type.VariantType.VARIANT; import static java.lang.Boolean.parseBoolean; import static java.lang.Math.floorDiv; import static java.lang.Math.max; @@ -3424,6 +3425,7 @@ private TableStatisticsMetadata getStatisticsCollectionMetadata( io.trino.spi.type.Type type = column.getType(); return !(type instanceof MapType || type instanceof ArrayType || type instanceof RowType); // is scalar type }) + .filter(column -> column.getType() != VARIANT) // variant does not support NDV statistics .map(ColumnMetadata::getName) .collect(toImmutableSet()); diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetColumnIOConverter.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetColumnIOConverter.java index 30351c7ac192..c27ffffadca3 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetColumnIOConverter.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergParquetColumnIOConverter.java @@ -17,6 +17,7 @@ import io.trino.parquet.Field; import io.trino.parquet.GroupField; import io.trino.parquet.PrimitiveField; +import io.trino.parquet.VariantField; import io.trino.spi.type.ArrayType; import io.trino.spi.type.MapType; import io.trino.spi.type.RowType; @@ -33,6 +34,8 @@ import static io.trino.parquet.ParquetTypeUtils.getArrayElementColumn; import static io.trino.parquet.ParquetTypeUtils.getMapKeyValueColumn; import static io.trino.parquet.ParquetTypeUtils.lookupColumnById; +import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VariantType.VARIANT; import static java.util.Objects.requireNonNull; import static org.apache.parquet.schema.Type.Repetition.OPTIONAL; @@ -96,10 +99,46 @@ public static Optional constructField(FieldContext context, ColumnIO colu Optional field = constructField(new FieldContext(arrayType.getElementType(), elementIdentity), getArrayElementColumn(groupColumnIO.getChild(0))); return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(field))); } + if (type == VARIANT) { + GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; + + // Expect the Iceberg VARIANT Parquet shape: + // optional group variant (VARIANT) { + // required binary metadata; + // required binary value; + // } + if 
(groupColumnIO.getChildrenCount() != 2) { + throw new IllegalArgumentException("Invalid VARIANT column, expected exactly 2 children but found: " + groupColumnIO.getChildrenCount()); + } + + // Both should be primitive binary columns + PrimitiveColumnIO metadataPrimitive = getRequiredPrimitiveChild(groupColumnIO, "metadata"); + PrimitiveColumnIO valuePrimitive = getRequiredPrimitiveChild(groupColumnIO, "value"); + + // metadata and value are required in unshredded form + Field metadataField = new PrimitiveField(VARBINARY, true, metadataPrimitive.getColumnDescriptor(), metadataPrimitive.getId()); + Field valueField = new PrimitiveField(VARBINARY, true, valuePrimitive.getColumnDescriptor(), valuePrimitive.getId()); + + return Optional.of(new VariantField(type, repetitionLevel, definitionLevel, required, valueField, metadataField)); + } PrimitiveColumnIO primitiveColumnIO = (PrimitiveColumnIO) columnIO; return Optional.of(new PrimitiveField(type, required, primitiveColumnIO.getColumnDescriptor(), primitiveColumnIO.getId())); } + private static PrimitiveColumnIO getRequiredPrimitiveChild(GroupColumnIO groupColumnIO, String childName) + { + ColumnIO child = groupColumnIO.getChild(childName); + if (child == null) { + throw new IllegalArgumentException("Invalid VARIANT column, missing child '%s' in parent group '%s'" + .formatted(childName, groupColumnIO.getType().getName())); + } + if (!(child instanceof PrimitiveColumnIO primitiveChild)) { + throw new IllegalArgumentException("Invalid VARIANT column, child '%s' in parent group '%s' must be primitive but is %s" + .formatted(childName, groupColumnIO.getType().getName(), child.getClass().getSimpleName())); + } + return primitiveChild; + } + public record FieldContext(Type type, ColumnIdentity columnIdentity) { public FieldContext diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergUtil.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergUtil.java index 1d85f8fda563..40b871ee1f2b 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergUtil.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergUtil.java @@ -565,6 +565,10 @@ private static Stream> primitiveFieldTypes(NestedF return primitiveFieldTypes(fieldType.asNestedType().fields()); } + if (fieldType.isVariantType()) { + return Stream.empty(); + } + throw new IllegalStateException("Unsupported field type: " + nestedField); } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/TableStatisticsWriter.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/TableStatisticsWriter.java index 8e2588c029b0..39c132629519 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/TableStatisticsWriter.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/TableStatisticsWriter.java @@ -130,6 +130,9 @@ private GenericStatisticsFile writeStatisticsFile(ConnectorSession session, Tabl if (type instanceof Type.NestedType nestedType) { return nestedType.fields(); } + if (type instanceof Types.VariantType) { + return ImmutableList.of(); + } if (type instanceof Type.PrimitiveType) { return ImmutableList.of(); } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/TypeConverter.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/TypeConverter.java index b7df1f27ac87..1f14e88cd4ee 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/TypeConverter.java +++ 
b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/TypeConverter.java @@ -55,6 +55,7 @@ import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS; import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_NANOS; import static io.trino.spi.type.UuidType.UUID; +import static io.trino.spi.type.VariantType.VARIANT; import static java.lang.String.format; import static java.util.Locale.ENGLISH; @@ -107,8 +108,7 @@ public static Type toTrinoType(org.apache.iceberg.types.Type type, TypeManager t .map(field -> new RowType.Field(Optional.of(field.name()), toTrinoType(field.type(), typeManager))) .collect(toImmutableList())); case VARIANT: - // TODO https://github.com/trinodb/trino/issues/24538 Support variant type - break; + return VARIANT; case GEOMETRY: case GEOGRAPHY: case UNKNOWN: @@ -174,6 +174,9 @@ private static org.apache.iceberg.types.Type toIcebergTypeInternal(Type type, Op if (type.equals(UUID)) { return Types.UUIDType.get(); } + if (type.equals(VARIANT)) { + return Types.VariantType.get(); + } if (type instanceof RowType rowType) { return fromRow(rowType, columnIdentity, nextFieldId); } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/util/HiveSchemaUtil.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/util/HiveSchemaUtil.java index 603736a8cc41..0dca5351ff78 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/util/HiveSchemaUtil.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/util/HiveSchemaUtil.java @@ -46,8 +46,7 @@ private static String convertToTypeString(Type type) case FIXED, BINARY -> "binary"; case DECIMAL -> "decimal(%s,%s)".formatted(((DecimalType) type).precision(), ((DecimalType) type).scale()); case UNKNOWN, GEOMETRY, GEOGRAPHY -> throw new TrinoException(NOT_SUPPORTED, "Unsupported Iceberg type: " + type); - // TODO https://github.com/trinodb/trino/issues/24538 Support variant type - case VARIANT -> throw new TrinoException(NOT_SUPPORTED, "Unsupported Iceberg type: VARIANT"); + case VARIANT -> "struct"; case LIST -> "array<%s>".formatted(convert(type.asListType().elementType())); case MAP -> "map<%s,%s>".formatted(convert(type.asMapType().keyType()), convert(type.asMapType().valueType())); case STRUCT -> "struct<%s>".formatted(type.asStructType().fields().stream() diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/util/OrcMetrics.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/util/OrcMetrics.java index a9660de54aa5..b73f27de3814 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/util/OrcMetrics.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/util/OrcMetrics.java @@ -214,8 +214,11 @@ private static void populateExcludedColumns(ColumnMetadata orcColumns, } return; case STRUCT: + // Variant types are stored as STRUCT with iceberg.variant-type=true + // The children (metadata, value) are internal and should be excluded + boolean isVariantType = "true".equals(orcColumn.getAttributes().get(OrcTypeConverter.ICEBERG_VARIANT_TYPE_KIND)); for (OrcColumnId child : orcColumn.getFieldTypeIndexes()) { - populateExcludedColumns(orcColumns, child, exclude, excludedColumns); + populateExcludedColumns(orcColumns, child, exclude || isVariantType, excludedColumns); } return; default: diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/util/OrcTypeConverter.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/util/OrcTypeConverter.java index 
4b6c39fbc6b5..4813ebaa5c3c 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/util/OrcTypeConverter.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/util/OrcTypeConverter.java @@ -46,6 +46,7 @@ public final class OrcTypeConverter public static final String ICEBERG_TIMESTAMP_UNIT = "iceberg.timestamp-unit"; public static final String ICEBERG_TIMESTAMP_UNIT_MICROS = "MICROS"; public static final String ICEBERG_TIMESTAMP_UNIT_NANOS = "NANOS"; + public static final String ICEBERG_VARIANT_TYPE_KIND = "iceberg.variant-type"; private OrcTypeConverter() {} @@ -99,7 +100,8 @@ private static List toOrcType(int nextFieldTypeIndex, Type type, Map throw new TrinoException(NOT_SUPPORTED, "Unsupported Iceberg type: " + type); + case VARIANT -> toOrcVariantType(nextFieldTypeIndex, attributes); + case GEOMETRY, GEOGRAPHY, UNKNOWN -> throw new TrinoException(NOT_SUPPORTED, "Unsupported Iceberg type: " + type); case STRUCT -> toOrcStructType(nextFieldTypeIndex, (StructType) type, attributes); case LIST -> toOrcListType(nextFieldTypeIndex, (ListType) type, attributes); case MAP -> toOrcMapType(nextFieldTypeIndex, (MapType) type, attributes); @@ -189,4 +191,47 @@ private static List toOrcMapType(int nextFieldTypeIndex, MapType mapTyp .addAll(valueTypes) .build(); } + + private static List toOrcVariantType(int nextFieldTypeIndex, Map attributes) + { + // Variant is stored as a struct with two binary fields: metadata and value + // The struct is marked with iceberg.variant-type=true to identify it as a variant + int metadataFieldIndex = nextFieldTypeIndex + 1; + int valueFieldIndex = nextFieldTypeIndex + 2; + + Map variantAttributes = ImmutableMap.builder() + .putAll(attributes) + .put(ICEBERG_VARIANT_TYPE_KIND, "true") + .buildOrThrow(); + + OrcType metadataType = new OrcType( + OrcTypeKind.BINARY, + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of(ORC_ICEBERG_REQUIRED_KEY, "true")); + + OrcType valueType = new OrcType( + OrcTypeKind.BINARY, + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of(ORC_ICEBERG_REQUIRED_KEY, "true")); + + return ImmutableList.of( + new OrcType( + OrcTypeKind.STRUCT, + ImmutableList.of(new OrcColumnId(metadataFieldIndex), new OrcColumnId(valueFieldIndex)), + ImmutableList.of("metadata", "value"), + Optional.empty(), + Optional.empty(), + Optional.empty(), + variantAttributes), + metadataType, + valueType); + } } diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java index fdc930664841..51c1283d2f7a 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java @@ -4913,10 +4913,12 @@ public void testAllAvailableTypes() " a_timestamp timestamp(6), " + " a_timestamptz timestamp(6) with time zone, " + " a_uuid uuid, " + + " a_variant variant, " + " a_row row(id integer, vc varchar), " + " an_array array(varchar), " + " a_map map(integer, varchar) " + - ")"); + ")" + + " WITH (FORMAT_VERSION = 3)"); String values = "VALUES (" + "true, " + @@ -4933,11 +4935,12 @@ public void testAllAvailableTypes() "TIMESTAMP '2021-07-24 03:43:57.987654'," + "TIMESTAMP '2021-07-24 04:43:57.987654 UTC', " + "UUID '20050910-1330-11e9-ffff-2a86e4085a59', " + + 
"CAST(42 as VARIANT), " + "CAST(ROW(42, 'this is a random value') AS ROW(id int, vc varchar)), " + "ARRAY[VARCHAR 'uno', 'dos', 'tres'], " + "map(ARRAY[1,2], ARRAY['ek', VARCHAR 'one'])) "; - String nullValues = nCopies(17, "NULL").stream() + String nullValues = nCopies(18, "NULL").stream() .collect(joining(", ", "VALUES (", ")")); assertUpdate("INSERT INTO test_all_types " + values, 1); @@ -4963,6 +4966,7 @@ public void testAllAvailableTypes() "AND a_timestamp = TIMESTAMP '2021-07-24 03:43:57.987654' " + "AND a_timestamptz = TIMESTAMP '2021-07-24 04:43:57.987654 UTC' " + "AND a_uuid = UUID '20050910-1330-11e9-ffff-2a86e4085a59' " + + "AND a_variant = CAST(42 as VARIANT) " + "AND a_row = CAST(ROW(42, 'this is a random value') AS ROW(id int, vc varchar)) " + "AND an_array = ARRAY[VARCHAR 'uno', 'dos', 'tres'] " + "AND a_map = map(ARRAY[1,2], ARRAY['ek', VARCHAR 'one']) " + @@ -4984,6 +4988,7 @@ public void testAllAvailableTypes() "AND a_timestamp IS NULL " + "AND a_timestamptz IS NULL " + "AND a_uuid IS NULL " + + "AND a_variant IS NULL " + "AND a_row IS NULL " + "AND an_array IS NULL " + "AND a_map IS NULL " + @@ -5010,6 +5015,7 @@ public void testAllAvailableTypes() " ('a_timestamp', NULL, 1e0, 0.5e0, NULL, " + (format == ORC ? "'2021-07-24 03:43:57.987000', '2021-07-24 03:43:57.987999'" : "'2021-07-24 03:43:57.987654', '2021-07-24 03:43:57.987654'") + "), " + " ('a_timestamptz', NULL, 1e0, 0.5e0, NULL, '2021-07-24 04:43:57.987 UTC', '2021-07-24 04:43:57.987 UTC'), " + " ('a_uuid', NULL, 1e0, 0.5e0, NULL, NULL, NULL), " + + " ('a_variant', NULL, NULL, " + (format == ORC ? "0.5e0" : "NULL") + ", NULL, NULL, NULL), " + " ('a_row', NULL, NULL, " + (format == ORC ? "0.5" : "NULL") + ", NULL, NULL, NULL), " + " ('an_array', NULL, NULL, " + (format == ORC ? "0.5" : "NULL") + ", NULL, NULL, NULL), " + " ('a_map', NULL, NULL, " + (format == ORC ? "0.5" : "NULL") + ", NULL, NULL, NULL), " + @@ -5033,6 +5039,7 @@ public void testAllAvailableTypes() " ('a_timestamp', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " + " ('a_timestamptz', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " + " ('a_uuid', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " + + " ('a_variant', NULL, NULL, NULL, NULL, NULL, NULL), " + " ('a_row', NULL, NULL, NULL, NULL, NULL, NULL), " + " ('an_array', NULL, NULL, NULL, NULL, NULL, NULL), " + " ('a_map', NULL, NULL, NULL, NULL, NULL, NULL), " + @@ -5059,6 +5066,7 @@ public void testAllAvailableTypes() " ('a_timestamp', NULL, 1e0, 0.5e0, NULL, " + (format == ORC ? "'2021-07-24 03:43:57.987000', '2021-07-24 03:43:57.987999'" : "'2021-07-24 03:43:57.987654', '2021-07-24 03:43:57.987654'") + "), " + " ('a_timestamptz', NULL, 1e0, 0.5e0, NULL, '2021-07-24 04:43:57.987 UTC', '2021-07-24 04:43:57.987 UTC'), " + " ('a_uuid', NULL, 1e0, 0.5e0, NULL, NULL, NULL), " + + " ('a_variant', NULL, NULL, " + (format == ORC ? "0.5e0" : "NULL") + ", NULL, NULL, NULL), " + " ('a_row', NULL, NULL, " + (format == ORC ? "0.5" : "NULL") + ", NULL, NULL, NULL), " + " ('an_array', NULL, NULL, " + (format == ORC ? "0.5" : "NULL") + ", NULL, NULL, NULL), " + " ('a_map', NULL, NULL, " + (format == ORC ? 
"0.5" : "NULL") + ", NULL, NULL, NULL), " + @@ -5082,6 +5090,7 @@ public void testAllAvailableTypes() " ('a_timestamp', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " + " ('a_timestamptz', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " + " ('a_uuid', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " + + " ('a_variant', NULL, NULL, NULL, NULL, NULL, NULL), " + " ('a_row', NULL, NULL, NULL, NULL, NULL, NULL), " + " ('an_array', NULL, NULL, NULL, NULL, NULL, NULL), " + " ('a_map', NULL, NULL, NULL, NULL, NULL, NULL), " + @@ -5184,6 +5193,69 @@ public void testAllAvailableTypes() assertUpdate("DROP TABLE test_all_types"); } + @Test + public void testVariantValueTypes() + { + try (TestTable table = newTrinoTable( + "test_variant_value_types", + "(v variant) WITH (FORMAT_VERSION = 3)")) { + String values = """ + VALUES + CAST(NULL AS VARIANT), + CAST(true AS VARIANT), + CAST(TINYINT '1' AS VARIANT), + CAST(SMALLINT '1' AS VARIANT), + CAST(INTEGER '1' AS VARIANT), + CAST(BIGINT '1' AS VARIANT), + CAST(REAL '1.0' AS VARIANT), + CAST(DOUBLE '1.0' AS VARIANT), + CAST(DECIMAL '1.23' AS VARIANT), + CAST('hello' AS VARIANT), + CAST(X'000102f0feff' AS VARIANT), + CAST(DATE '2021-07-24' AS VARIANT), + CAST(TIME '02:43:57.987654' AS VARIANT), + CAST(TIMESTAMP '2021-07-24 03:43:57.987654' AS VARIANT), + CAST(TIMESTAMP '2021-07-24 03:43:57.987654321' AS VARIANT), + CAST(TIMESTAMP '2021-07-24 04:43:57.987654 UTC' AS VARIANT), + CAST(TIMESTAMP '2021-07-24 04:43:57.987654321 UTC' AS VARIANT), + CAST(UUID '20050910-1330-11e9-ffff-2a86e4085a59' AS VARIANT), + CAST(ARRAY['uno', 'dos'] AS VARIANT), + CAST(MAP(ARRAY['a', 'b'], ARRAY[1, 2]) AS VARIANT), + CAST(CAST(ROW(42, 'x') AS ROW(id integer, vc varchar)) AS VARIANT) + """; + + assertUpdate("INSERT INTO " + table.getName() + " " + values, 21); + + assertThat(query("SELECT * FROM " + table.getName())) + .matches(values); + } + } + + @Test + public void testNestedVariant() + { + // Tests variant nested inside array and row types + try (TestTable table = newTrinoTable( + "test_nested_variant", + "(" + + "variant_array array(variant), " + + "variant_row row(v variant, i integer)) " + + "WITH (FORMAT_VERSION = 3)")) { + assertUpdate("INSERT INTO " + table.getName() + " VALUES (" + + "ARRAY[CAST(1 AS VARIANT), CAST('hello' AS VARIANT), CAST(NULL AS VARIANT)], " + + "CAST(ROW(42, 123) AS ROW(v variant, i integer)))", 1); + assertUpdate("INSERT INTO " + table.getName() + " VALUES (NULL, NULL)", 1); + + assertThat(query("SELECT * FROM " + table.getName())) + .matches(""" + VALUES ( + ARRAY[CAST(1 AS VARIANT), CAST('hello' AS VARIANT), CAST(NULL AS VARIANT)], + CAST(ROW(42, 123) AS ROW(v variant, i integer))), + (NULL, NULL) + """); + } + } + @Test public void testRepartitionDataOnCtas() { diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergAvroDataConversion.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergAvroDataConversion.java new file mode 100644 index 000000000000..7dc1e2d571cb --- /dev/null +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergAvroDataConversion.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.iceberg; + +import io.airlift.slice.Slices; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import org.apache.iceberg.types.Types; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VariantType.VARIANT; +import static org.assertj.core.api.Assertions.assertThat; + +class TestIcebergAvroDataConversion +{ + @Test + void testToIcebergAvroObjectWithDictionaryWrappedVariantBlock() + { + BlockBuilder blockBuilder = VARIANT.createBlockBuilder(null, 2); + VARIANT.writeObject(blockBuilder, io.trino.spi.variant.Variant.ofLong(41)); + VARIANT.writeObject(blockBuilder, io.trino.spi.variant.Variant.ofLong(42)); + Block dictionary = DictionaryBlock.create(3, blockBuilder.build(), new int[] {1, 0, 1}); + + Object converted = IcebergAvroDataConversion.toIcebergAvroObject(VARIANT, Types.VariantType.get(), dictionary, 0); + + assertThat(converted).isInstanceOf(org.apache.iceberg.variants.Variant.class); + assertThat(org.apache.iceberg.variants.Variant.toString((org.apache.iceberg.variants.Variant) converted)) + .contains("type=INT64") + .contains("value=42"); + } + + @Test + void testToIcebergAvroObjectWithRunLengthEncodedVariantBlock() + { + BlockBuilder blockBuilder = VARIANT.createBlockBuilder(null, 1); + VARIANT.writeObject(blockBuilder, io.trino.spi.variant.Variant.ofString("hello")); + Block rle = RunLengthEncodedBlock.create(blockBuilder.build(), 3); + + Object converted = IcebergAvroDataConversion.toIcebergAvroObject(VARIANT, Types.VariantType.get(), rle, 1); + + assertThat(converted).isInstanceOf(org.apache.iceberg.variants.Variant.class); + assertThat(org.apache.iceberg.variants.Variant.toString((org.apache.iceberg.variants.Variant) converted)) + .contains("type=STRING") + .contains("value=hello"); + } + + @Test + void testSerializeFixedBinaryToTrinoBlock() + { + BlockBuilder blockBuilder = VARBINARY.createBlockBuilder(null, 1); + + IcebergAvroDataConversion.serializeToTrinoBlock(VARBINARY, Types.FixedType.ofLength(3), blockBuilder, new byte[] {1, 2, 3}); + + assertThat(VARBINARY.getSlice(blockBuilder.build(), 0)).isEqualTo(Slices.wrappedBuffer(new byte[] {1, 2, 3})); + } + + @Test + void testSerializeBinaryToTrinoBlock() + { + BlockBuilder blockBuilder = VARBINARY.createBlockBuilder(null, 1); + + IcebergAvroDataConversion.serializeToTrinoBlock(VARBINARY, Types.BinaryType.get(), blockBuilder, ByteBuffer.wrap(new byte[] {4, 5, 6})); + + assertThat(VARBINARY.getSlice(blockBuilder.build(), 0)).isEqualTo(Slices.wrappedBuffer(new byte[] {4, 5, 6})); + } +} diff --git a/pom.xml b/pom.xml index 451ecbf5f0cf..73b73f072273 100644 --- a/pom.xml +++ b/pom.xml @@ -192,7 +192,7 @@ 1.12.797 4.17.0 8.1.1 - 123 + 126 1.24 2.0.0 v24.12.0 diff --git a/spark-variant-compatibility.md b/spark-variant-compatibility.md new file mode 100644 index 000000000000..3c9e24a3efb3 --- /dev/null +++ b/spark-variant-compatibility.md @@ -0,0 +1,229 @@ +# Spark 
Variant Compatibility
+
+This note records the current Spark/Iceberg compatibility story for `VARIANT` in this branch.
+
+## What Trino Implements
+
+Trino implements the Iceberg `VARIANT` type in this branch, and we currently verify Spark interop through product
+tests in:
+
+- `testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java`
+- `testing/trino-product-tests-launcher/src/main/java/io/trino/tests/product/launcher/env/environment/EnvSinglenodeSparkIcebergRest.java`
+
+Those tests currently cover:
+
+- `Trino writes -> Spark reads` for `AVRO`
+- `Trino writes -> Spark reads` for `PARQUET`
+- `Spark writes -> Trino reads` for `AVRO`
+- `Spark writes -> Trino reads` for `PARQUET`
+
+Trino does not expose the Iceberg Hadoop catalog as a first-class catalog type. The supported catalog types are
+defined in:
+
+- `plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/CatalogType.java`
+
+There is no `HADOOP` entry there.
+
+However, Trino can still interoperate with Hadoop-catalog Iceberg tables after registration by location. That path is
+exercised in:
+
+- `plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergRegisterTableProcedure.java`
+
+Specifically, `testRegisterHadoopTableAndRead()` shows Trino registering and reading a Hadoop-created Iceberg table.
+
+## What Spark Implements
+
+Spark 4 has native `variant` support:
+
+- [Spark `VariantType`](https://github.com/apache/spark/blob/master/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala)
+
+The main compatibility limitations are not about Spark SQL syntax. They come down to:
+
+- which Iceberg catalog path is being used
+- which Iceberg REST server version is being used
+- which file formats Spark Iceberg can actually read and write for `VARIANT`
+- which Iceberg variant primitive encodings Spark's runtime understands
+
+## Catalog Support
+
+### Hive Metastore catalog
+
+Spark `VARIANT` table creation through the Iceberg Hive Metastore path is not supported.
+
+Reason:
+
+- upstream Iceberg Hive schema conversion still has no `VARIANT` mapping and throws an unsupported-type error
+
+Source:
+
+- [Iceberg `HiveSchemaUtil.java`](https://github.com/apache/iceberg/blob/main/hive-metastore/src/main/java/org/apache/iceberg/hive/HiveSchemaUtil.java)
+
+This is an upstream Iceberg Hive-catalog limitation, not a Trino syntax issue.
+
+### REST catalog
+
+Spark `VARIANT` works with the Iceberg REST catalog only if the REST server is new enough.
+
+Old REST server behavior:
+
+- older Iceberg schema parsing used `Types.fromPrimitiveString(...)`
+- that rejects `variant`
+
+Source:
+
+- [Iceberg 1.6.0 `SchemaParser.java`](https://github.com/apache/iceberg/blob/apache-iceberg-1.6.0/core/src/main/java/org/apache/iceberg/SchemaParser.java)
+
+Newer REST server behavior:
+
+- newer Iceberg schema parsing uses `Types.fromTypeName(...)`
+- that accepts `variant`
+
+Source:
+
+- [Iceberg 1.10.1 `SchemaParser.java`](https://github.com/apache/iceberg/blob/apache-iceberg-1.10.1/core/src/main/java/org/apache/iceberg/SchemaParser.java)
+
+In our test environment, the original REST setup was on the older `tabulario/iceberg-rest:1.5.0` line, which is why
+Spark table creation failed there. We now use a newer local REST image in the product test environment.
+
+### Hadoop/storage-based catalog
+
+Spark supports `VARIANT` with Iceberg's Hadoop catalog. Here "storage-based catalog" means Iceberg's Hadoop catalog,
+where table metadata is stored directly in the warehouse instead of going through Hive Metastore or REST.
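+As a concrete illustration, here is a minimal Java sketch of this path. The catalog name, namespace, warehouse path,
+and table name are hypothetical, and it assumes the Iceberg Spark runtime is on the classpath.
+
+```java
+import org.apache.spark.sql.SparkSession;
+
+public final class HadoopCatalogVariantSketch
+{
+    public static void main(String[] args)
+    {
+        // Hadoop catalog: table metadata lives under the warehouse path,
+        // with no Hive Metastore or REST server involved
+        SparkSession spark = SparkSession.builder()
+                .master("local[*]")
+                .config("spark.sql.catalog.hadoop_cat", "org.apache.iceberg.spark.SparkCatalog")
+                .config("spark.sql.catalog.hadoop_cat.type", "hadoop")
+                .config("spark.sql.catalog.hadoop_cat.warehouse", "/tmp/iceberg-warehouse")
+                .getOrCreate();
+
+        // VARIANT requires Iceberg format version 3
+        spark.sql("CREATE TABLE hadoop_cat.db.variant_demo (id INT, v VARIANT) USING iceberg " +
+                "TBLPROPERTIES ('format-version'='3')");
+        spark.sql("INSERT INTO hadoop_cat.db.variant_demo SELECT 1, parse_json('{\"a\": 1}')");
+        spark.sql("SELECT id, to_json(v) FROM hadoop_cat.db.variant_demo").show();
+    }
+}
+```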
+
+Upstream evidence:
+
+- Iceberg's own Spark variant test explicitly uses a Hadoop catalog to avoid Hive schema conversion
+
+Source:
+
+- [Iceberg `TestSparkVariantRead.java`](https://github.com/apache/iceberg/blob/main/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkVariantRead.java)
+
+## File Format Support
+
+### AVRO
+
+This is implemented in Trino and currently works in our Spark compatibility tests.
+
+### PARQUET
+
+This is implemented in Trino and currently works in our Spark compatibility tests.
+
+### ORC
+
+Spark does not currently support the `VARIANT` ORC path well enough for us to include it in the compatibility matrix.
+
+The limitation is on the Spark/Iceberg side, not in Trino:
+
+- Trino implements Iceberg `VARIANT`, including the ORC path
+- Spark's Iceberg ORC read/write path has no working `VARIANT` support
+
+Why we consider it unsupported:
+
+- the Spark Iceberg ORC reader and writer code paths contain no variant-aware handling analogous to what the
+  variant-specific Spark tests exercise
+- in local product testing, Spark ORC variant interop failed in those code paths
+
+Sources:
+
+- [Iceberg `SparkOrcReader.java`](https://github.com/apache/iceberg/blob/main/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java)
+- [Iceberg `SparkOrcWriter.java`](https://github.com/apache/iceberg/blob/main/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java)
+
+So the current state is:
+
+- `AVRO`: yes
+- `PARQUET`: yes
+- `ORC`: no; Spark/Iceberg ORC `VARIANT` support is not yet usable for interop
+
+## Variant Primitive Encodings
+
+Trino follows the Iceberg variant encoding set. The compatibility gap is that Spark 4's variant runtime only
+understands a subset of the Iceberg variant primitive encodings.
+
+### Currently verified in both directions
+
+These values are currently covered by product tests for both `Trino writes -> Spark reads` and
+`Spark writes -> Trino reads`:
+
+- `null`
+- variant `null`
+- `boolean`
+- `tinyint`
+- `smallint`
+- `integer`
+- `bigint`
+- `real`
+- `double`
+- `decimal`
+- `varchar`
+- `date`
+- array values
+
+### Currently verified only for `Trino writes -> Spark reads`
+
+These values are already covered when Trino writes Iceberg `VARIANT` data and Spark reads it, but are not yet in the
+verified Spark-written overlap:
+
+- `varbinary`
+- map/object values
+- row/object values
+- `uuid`
+
+### Iceberg variant primitive encodings that Spark does not currently understand
+
+These Iceberg variant primitive encodings are not understood by the Spark 4.0 variant runtime shipped in our
+test image:
+
+- `TIME_NTZ_MICROS`
+- `TIMESTAMP_UTC_NANOS`
+- `TIMESTAMP_NTZ_NANOS`
+
+Why:
+
+1. Iceberg 1.10.1 itself defines these physical types in its variant model:
+   - [Iceberg `PhysicalType.java`](https://github.com/apache/iceberg/blob/apache-iceberg-1.10.1/api/src/main/java/org/apache/iceberg/variants/PhysicalType.java)
+   - [Iceberg `Primitives.java`](https://github.com/apache/iceberg/blob/apache-iceberg-1.10.1/api/src/main/java/org/apache/iceberg/variants/Primitives.java)
+2. Trino's variant header numbering matches those same physical encodings:
+   - `core/trino-spi/src/main/java/io/trino/spi/variant/Header.java`
+3. Spark's runtime `VariantUtil` recognizes only a fixed set of primitive type IDs and throws
+   `UNKNOWN_PRIMITIVE_TYPE_IN_VARIANT` for anything else:
+   - [Spark `VariantUtil.java`](https://github.com/apache/spark/blob/master/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java)
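+To make the ID numbering concrete, here is a minimal Java sketch of how a reader extracts the primitive type ID from
+a variant value header byte. It assumes the Variant spec's header layout (the lower two bits select the basic type;
+for primitives, the upper six bits carry the type ID), and the method name is hypothetical.
+
+```java
+import java.util.OptionalInt;
+
+final class VariantHeaderSketch
+{
+    private VariantHeaderSketch() {}
+
+    // Returns the primitive type ID encoded in a variant value header byte,
+    // or empty when the header selects a short string, object, or array instead
+    static OptionalInt primitiveTypeId(byte header)
+    {
+        int basicType = header & 0b11; // 0 = primitive, 1 = short string, 2 = object, 3 = array
+        if (basicType != 0) {
+            return OptionalInt.empty();
+        }
+        return OptionalInt.of((header >> 2) & 0b11_1111); // upper six bits
+    }
+}
+```
+
+Trino's `Header.java` and Iceberg's `PhysicalType` agree on this numbering; the gap is only in which IDs the Spark
+runtime accepts.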
+In the Spark 4.0.0 runtime shipped in our test image, the unsupported IDs line up with Trino's numbering exactly:
+
+- `17` -> `TIME_NTZ_MICROS`
+- `18` -> `TIMESTAMP_UTC_NANOS`
+- `19` -> `TIMESTAMP_NTZ_NANOS`
+
+This is why some valid Iceberg/Trino-written `VARIANT` values still cannot be read by Spark even when the catalog and
+file format are otherwise supported.
+
+## Summary
+
+The short version is:
+
+- Trino implements Iceberg `VARIANT`
+- Spark also implements `variant`
+- the Hive Metastore catalog is still blocked upstream in Iceberg
+- the REST catalog works only with a new enough Iceberg REST server
+- the Hadoop/storage-based catalog works on the Spark side
+- ORC is currently blocked by Spark/Iceberg support, not by Trino
+- Spark still does not understand every Iceberg variant primitive encoding
+
+## References
+
+- [Iceberg `HiveSchemaUtil.java`](https://github.com/apache/iceberg/blob/main/hive-metastore/src/main/java/org/apache/iceberg/hive/HiveSchemaUtil.java)
+- [Iceberg 1.6.0 `SchemaParser.java`](https://github.com/apache/iceberg/blob/apache-iceberg-1.6.0/core/src/main/java/org/apache/iceberg/SchemaParser.java)
+- [Iceberg 1.10.1 `SchemaParser.java`](https://github.com/apache/iceberg/blob/apache-iceberg-1.10.1/core/src/main/java/org/apache/iceberg/SchemaParser.java)
+- [Iceberg `TestSparkVariantRead.java`](https://github.com/apache/iceberg/blob/main/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkVariantRead.java)
+- [Iceberg `SparkOrcReader.java`](https://github.com/apache/iceberg/blob/main/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java)
+- [Iceberg `SparkOrcWriter.java`](https://github.com/apache/iceberg/blob/main/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java)
+- [Iceberg `PhysicalType.java`](https://github.com/apache/iceberg/blob/apache-iceberg-1.10.1/api/src/main/java/org/apache/iceberg/variants/PhysicalType.java)
+- [Iceberg `Primitives.java`](https://github.com/apache/iceberg/blob/apache-iceberg-1.10.1/api/src/main/java/org/apache/iceberg/variants/Primitives.java)
+- [Spark `VariantUtil.java`](https://github.com/apache/spark/blob/master/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java)
+- [Spark `VariantType.scala`](https://github.com/apache/spark/blob/master/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala)
+- `core/trino-spi/src/main/java/io/trino/spi/variant/Header.java`
+- `plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/CatalogType.java`
+- `plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergRegisterTableProcedure.java`
+- `testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java`
diff --git a/testing/trino-product-tests-launcher/src/main/java/io/trino/tests/product/launcher/env/environment/EnvSinglenodeSparkIcebergRest.java b/testing/trino-product-tests-launcher/src/main/java/io/trino/tests/product/launcher/env/environment/EnvSinglenodeSparkIcebergRest.java
index 2da33f6db6a9..5fd22b512262 100644
--- a/testing/trino-product-tests-launcher/src/main/java/io/trino/tests/product/launcher/env/environment/EnvSinglenodeSparkIcebergRest.java
+++ 
b/testing/trino-product-tests-launcher/src/main/java/io/trino/tests/product/launcher/env/environment/EnvSinglenodeSparkIcebergRest.java @@ -48,7 +48,7 @@ public class EnvSinglenodeSparkIcebergRest private static final int REST_SERVER_PORT = 8181; private static final String SPARK_CONTAINER_NAME = "spark"; private static final String REST_CONTAINER_NAME = "iceberg-with-rest"; - private static final String REST_SERVER_IMAGE = "tabulario/iceberg-rest:1.5.0"; + private static final String REST_SERVER_IMAGE = "ghcr.io/trinodb/testing/iceberg-rest"; private static final String CATALOG_WAREHOUSE = "hdfs://hadoop-master:9000/user/hive/warehouse"; private final DockerFiles dockerFiles; @@ -80,10 +80,19 @@ public void extendEnvironment(Environment.Builder builder) @SuppressWarnings("resource") private DockerContainer createRESTContainer() { - DockerContainer container = new DockerContainer(REST_SERVER_IMAGE, REST_CONTAINER_NAME) + DockerContainer container = new DockerContainer(REST_SERVER_IMAGE + ":" + hadoopImagesVersion, REST_CONTAINER_NAME) .withEnv("CATALOG_WAREHOUSE", CATALOG_WAREHOUSE) .withEnv("REST_PORT", Integer.toString(REST_SERVER_PORT)) + .withEnv("HADOOP_CONF_DIR", "/etc/hadoop/conf") .withEnv("HADOOP_USER_NAME", "hive") + .withCopyFileToContainer( + forHostPath(dockerFiles.getDockerFilesHostPath( + "conf/environment/singlenode-spark-iceberg-rest/core-site.xml")), + "/etc/hadoop/conf/core-site.xml") + .withCopyFileToContainer( + forHostPath(dockerFiles.getDockerFilesHostPath( + "conf/environment/singlenode-spark-iceberg-rest/hdfs-site.xml")), + "/etc/hadoop/conf/hdfs-site.xml") .withStartupCheckStrategy(new IsRunningStartupCheckStrategy()) .waitingFor(forSelectedPorts(REST_SERVER_PORT)); diff --git a/testing/trino-product-tests-launcher/src/main/resources/docker/trino-product-tests/conf/environment/singlenode-spark-iceberg-rest/core-site.xml b/testing/trino-product-tests-launcher/src/main/resources/docker/trino-product-tests/conf/environment/singlenode-spark-iceberg-rest/core-site.xml new file mode 100644 index 000000000000..044c9921d1fb --- /dev/null +++ b/testing/trino-product-tests-launcher/src/main/resources/docker/trino-product-tests/conf/environment/singlenode-spark-iceberg-rest/core-site.xml @@ -0,0 +1,7 @@ + + + + fs.defaultFS + hdfs://hadoop-master:9000 + + diff --git a/testing/trino-product-tests-launcher/src/main/resources/docker/trino-product-tests/conf/environment/singlenode-spark-iceberg-rest/hdfs-site.xml b/testing/trino-product-tests-launcher/src/main/resources/docker/trino-product-tests/conf/environment/singlenode-spark-iceberg-rest/hdfs-site.xml new file mode 100644 index 000000000000..049ff0207b3d --- /dev/null +++ b/testing/trino-product-tests-launcher/src/main/resources/docker/trino-product-tests/conf/environment/singlenode-spark-iceberg-rest/hdfs-site.xml @@ -0,0 +1,12 @@ + + + + fs.viewfs.mounttable.hadoop-viewfs.link./default + hdfs://hadoop-master:9000/user/hive/warehouse + + + + dfs.safemode.threshold.pct + 0 + + diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java index 6d4286939c3a..57de3fdd8f61 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java @@ -2238,6 +2238,81 @@ public void 
testTrinoReadsSparkV3DeletionVectors(StorageFormat storageFormat) onSpark().executeQuery("DROP TABLE " + sparkTableName); } + @Test(groups = {ICEBERG_REST, PROFILE_SPECIFIC_TESTS}, dataProvider = "variantStorageFormats") + public void testSparkReadsTrinoVariantData(StorageFormat storageFormat) + { + String tableName = toLowerCase(format("test_spark_reads_trino_variant_%s_%s", storageFormat.name(), randomNameSuffix())); + String sparkTableName = sparkTableName(tableName); + String trinoTableName = trinoTableName(tableName); + + // Keep TIME_NTZ_MICROS, TIMESTAMP_UTC_NANOS, and TIMESTAMP_NTZ_NANOS out until Spark + // supports the full Iceberg VARIANT primitive set. + onTrino().executeQuery("CREATE TABLE " + trinoTableName + "(id INT, v VARIANT) " + + "WITH(format_version = 3, format = '" + storageFormat.name() + "')"); + onTrino().executeQuery( + """ + INSERT INTO %s VALUES + (1, CAST(NULL AS VARIANT)), + (2, CAST(JSON 'null' AS VARIANT)), + (3, CAST(true AS VARIANT)), + (4, CAST(TINYINT '1' AS VARIANT)), + (5, CAST(SMALLINT '1' AS VARIANT)), + (6, CAST(INTEGER '-2' AS VARIANT)), + (7, CAST(BIGINT '1234567890123' AS VARIANT)), + (8, CAST(REAL '1.5' AS VARIANT)), + (9, CAST(DOUBLE '2.5' AS VARIANT)), + (10, CAST(DECIMAL '123.45' AS VARIANT)), + (11, CAST('hello "variant"' AS VARIANT)), + (12, CAST(ARRAY[1, 2, 3] AS VARIANT)), + (13, CAST(MAP(ARRAY['a', 'b'], ARRAY[1, 2]) AS VARIANT)), + (14, CAST(CAST(ROW(42, 'x', true) AS ROW(id integer, vc varchar, flag boolean)) AS VARIANT)) + """.formatted(trinoTableName)); + + QueryResult trinoResult = onTrino().executeQuery("SELECT id, json_format(CAST(v AS JSON)) FROM " + trinoTableName + " ORDER BY id"); + QueryResult sparkResult = onSpark().executeQuery("SELECT id, to_json(v) FROM " + sparkTableName + " ORDER BY id"); + assertResultsEqual(trinoResult, sparkResult); + + onSpark().executeQuery("DROP TABLE " + sparkTableName); + } + + @Test(groups = {ICEBERG_REST, PROFILE_SPECIFIC_TESTS}, dataProvider = "variantStorageFormats") + public void testTrinoReadsSparkVariantData(StorageFormat storageFormat) + { + String tableName = toLowerCase(format("test_trino_reads_spark_variant_%s_%s", storageFormat.name(), randomNameSuffix())); + String sparkTableName = sparkTableName(tableName); + String trinoTableName = trinoTableName(tableName); + + // This is the currently verified Spark-written overlap. + // Keep TIME_NTZ_MICROS, TIMESTAMP_UTC_NANOS, and TIMESTAMP_NTZ_NANOS out until Spark + // implements the full Iceberg VARIANT primitive set. 
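+ // Scalar values below use CAST(... AS VARIANT); array and struct values use parse_json/to_variant_object,
+ // which is how this test constructs non-scalar variants on the Spark side.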
+ onSpark().executeQuery("CREATE TABLE " + sparkTableName + "(id INT, v VARIANT) " + + "USING ICEBERG TBLPROPERTIES ('format-version'='3', 'write.format.default'='" + storageFormat.name() + "')"); + onSpark().executeQuery( + """ + INSERT INTO %s VALUES + (1, CAST(NULL AS VARIANT)), + (2, parse_json('null')), + (3, CAST(true AS VARIANT)), + (4, CAST(CAST(1 AS TINYINT) AS VARIANT)), + (5, CAST(CAST(1 AS SMALLINT) AS VARIANT)), + (6, CAST(CAST(-2 AS INT) AS VARIANT)), + (7, CAST(CAST(1234567890123 AS BIGINT) AS VARIANT)), + (8, CAST(CAST(1.5 AS FLOAT) AS VARIANT)), + (9, CAST(CAST(2.5 AS DOUBLE) AS VARIANT)), + (10, CAST(CAST(123.45 AS DECIMAL(5, 2)) AS VARIANT)), + (11, CAST('hello "variant"' AS VARIANT)), + (12, CAST(DATE '2021-07-24' AS VARIANT)), + (13, to_variant_object(array(1, 2, 3))), + (14, to_variant_object(named_struct('flag', true, 'id', CAST(42 AS TINYINT), 'vc', 'x'))) + """.formatted(sparkTableName)); + + QueryResult sparkResult = onSpark().executeQuery("SELECT id, to_json(v) FROM " + sparkTableName + " ORDER BY id"); + QueryResult trinoResult = onTrino().executeQuery("SELECT id, json_format(CAST(v AS JSON)) FROM " + trinoTableName + " ORDER BY id"); + assertResultsEqual(sparkResult, trinoResult); + + onSpark().executeQuery("DROP TABLE " + sparkTableName); + } + @Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}, dataProvider = "storageFormats") public void testDeleteAfterPartitionEvolution(StorageFormat storageFormat) { @@ -2393,6 +2468,17 @@ private io.trino.jdbc.Row.Builder rowBuilder() return io.trino.jdbc.Row.builder(); } + private static void assertResultsEqual(QueryResult first, QueryResult second) + { + assertThat(first).containsOnly(second.rows().stream() + .map(Row::new) + .collect(toImmutableList())); + + assertThat(second).containsOnly(first.rows().stream() + .map(Row::new) + .collect(toImmutableList())); + } + @DataProvider public static Object[][] specVersions() { @@ -2407,6 +2493,16 @@ public static Object[][] storageFormats() .toArray(Object[][]::new); } + @DataProvider + public static Object[][] variantStorageFormats() + { + // Spark/Iceberg VARIANT interoperability currently works for AVRO and PARQUET. + // ORC VARIANT is not supported in Spark's current Iceberg integration. + return Stream.of(StorageFormat.AVRO, StorageFormat.PARQUET) + .map(storageFormat -> new Object[] {storageFormat}) + .toArray(Object[][]::new); + } + // Provides each supported table formats paired with each delete file format. 
@DataProvider public static Object[][] tableFormatWithDeleteFormat() diff --git a/testing/trino-testing/src/main/java/io/trino/testing/TestingTrinoClient.java b/testing/trino-testing/src/main/java/io/trino/testing/TestingTrinoClient.java index 5476cc88a385..ea39d51e101c 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/TestingTrinoClient.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/TestingTrinoClient.java @@ -14,6 +14,7 @@ package io.trino.testing; import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slices; import io.trino.Session; import io.trino.client.IntervalDayTime; import io.trino.client.IntervalYearMonth; @@ -37,6 +38,7 @@ import io.trino.spi.type.VarcharType; import io.trino.type.SqlIntervalDayTime; import io.trino.type.SqlIntervalYearMonth; +import io.trino.util.variant.VariantWriter; import okhttp3.OkHttpClient; import java.math.BigDecimal; @@ -70,6 +72,7 @@ import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.UuidType.UUID; import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VariantType.VARIANT; import static io.trino.testing.MaterializedResult.DEFAULT_PRECISION; import static io.trino.type.IntervalDayTimeType.INTERVAL_DAY_TIME; import static io.trino.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH; @@ -102,6 +105,7 @@ public class TestingTrinoClient .append(timestampFormat) .appendPattern(" VV") .toFormatter(); + private static final VariantWriter JSON_VARIANT_WRITER = VariantWriter.create(JSON); public TestingTrinoClient(TestingTrinoServer trinoServer, Session defaultSession) { @@ -330,6 +334,9 @@ private static Object convertToRowValue(Type type, Object value) //noinspection RedundantCast return (String) value; } + if (type == VARIANT) { + return JSON_VARIANT_WRITER.write(Slices.utf8Slice((String) value)); + } if (type instanceof ArrayType arrayType) { return ((List) value).stream() .map(element -> convertToRowValue(arrayType.getElementType(), element))