Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per our discussion, we should add a new ClientCapabilities flag, maybe RAW_VARIANT or BINARY_VARIANT, and send the client JSON if they don't support it.

For JDBC, getString() should return the JSON value (this is what Snowflake does). For getObject(), we can return our own TrinoVariant in the JDBC package, similar to TrinoIntervalDayTime.

Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public final class ClientStandardTypes
public static final String MAP = "map";
public static final String JSON = "json";
public static final String JSON_2016 = "json2016";
public static final String VARIANT = "variant";
Comment thread
electrum marked this conversation as resolved.
public static final String IPADDRESS = "ipaddress";
public static final String UUID = "uuid";
public static final String GEOMETRY = "Geometry";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
*/
package io.trino.client;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;
import com.fasterxml.jackson.databind.json.JsonMapper;
import com.google.common.collect.ImmutableList;

import java.io.IOException;
import java.io.StringWriter;
import java.util.Base64;
import java.util.HashMap;
import java.util.LinkedList;
Expand Down Expand Up @@ -64,6 +66,8 @@
import static io.trino.client.ClientStandardTypes.UUID;
import static io.trino.client.ClientStandardTypes.VARBINARY;
import static io.trino.client.ClientStandardTypes.VARCHAR;
import static io.trino.client.ClientStandardTypes.VARIANT;
import static io.trino.client.JsonIterators.createJsonFactory;
import static java.lang.String.format;
import static java.util.Collections.unmodifiableList;
import static java.util.Collections.unmodifiableMap;
Expand All @@ -81,6 +85,7 @@ private JsonDecodingUtils() {}
private static final RealDecoder REAL_DECODER = new RealDecoder();
private static final BooleanDecoder BOOLEAN_DECODER = new BooleanDecoder();
private static final StringDecoder STRING_DECODER = new StringDecoder();
private static final VariantDecoder VARIANT_DECODER = new VariantDecoder();
private static final Base64Decoder BASE_64_DECODER = new Base64Decoder();
private static final ObjectDecoder OBJECT_DECODER = new ObjectDecoder();

Expand Down Expand Up @@ -117,6 +122,8 @@ private static TypeDecoder createTypeDecoder(ClientTypeSignature signature)
return REAL_DECODER;
case BOOLEAN:
return BOOLEAN_DECODER;
case VARIANT:
return VARIANT_DECODER;
case ARRAY:
return new ArrayDecoder(signature);
case MAP:
Expand Down Expand Up @@ -287,6 +294,21 @@ public Object decode(JsonParser parser)
}
}

private static class VariantDecoder
implements TypeDecoder
{
@Override
public Object decode(JsonParser parser)
throws IOException
{
StringWriter writer = new StringWriter();
try (JsonGenerator generator = createJsonFactory().createGenerator(writer)) {
generator.copyCurrentStructure(parser);
}
return writer.toString();
}
}

private static class Base64Decoder
implements TypeDecoder
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import io.trino.spi.block.RunLengthBlockEncoding;
import io.trino.spi.block.ShortArrayBlockEncoding;
import io.trino.spi.block.VariableWidthBlockEncoding;
import io.trino.spi.block.VariantBlockEncoding;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -56,6 +57,7 @@ public BlockEncodingManager(BlockEncodingSimdSupport blockEncodingSimdSupport)
addBlockEncoding(new LongArrayBlockEncoding(simdSupport.vectorizeNullBitPacking(), simdSupport.compressLong(), simdSupport.expandLong()));
addBlockEncoding(new Fixed12BlockEncoding(simdSupport.vectorizeNullBitPacking()));
addBlockEncoding(new Int128ArrayBlockEncoding(simdSupport.vectorizeNullBitPacking()));
addBlockEncoding(new VariantBlockEncoding(simdSupport.vectorizeNullBitPacking()));
addBlockEncoding(new DictionaryBlockEncoding());
addBlockEncoding(new ArrayBlockEncoding(simdSupport.vectorizeNullBitPacking()));
addBlockEncoding(new MapBlockEncoding(simdSupport.vectorizeNullBitPacking()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.connector.system.GlobalSystemConnector;
import io.trino.spi.TrinoException;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.function.FunctionId;
import io.trino.spi.function.LongVariableConstraint;
import io.trino.spi.function.Signature;
Expand All @@ -30,7 +32,6 @@
import io.trino.spi.type.TypeSignature;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.type.FunctionType;
import io.trino.type.JsonType;
import io.trino.type.TypeCoercion;
import io.trino.type.UnknownType;

Expand All @@ -47,10 +48,13 @@
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSortedMap.toImmutableSortedMap;
import static io.trino.metadata.GlobalFunctionCatalog.BUILTIN_SCHEMA;
import static io.trino.metadata.OperatorNameUtil.mangleOperatorName;
import static io.trino.metadata.SignatureBinder.RelationshipType.EXACT;
import static io.trino.metadata.SignatureBinder.RelationshipType.EXPLICIT_COERCION_FROM;
import static io.trino.metadata.SignatureBinder.RelationshipType.EXPLICIT_COERCION_TO;
import static io.trino.metadata.SignatureBinder.RelationshipType.IMPLICIT_COERCION;
import static io.trino.spi.function.OperatorType.CAST;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.type.TypeCalculation.calculateLiteralValue;
import static io.trino.type.TypeCoercion.isCovariantTypeBase;
Expand Down Expand Up @@ -598,7 +602,8 @@ private boolean satisfiesCoercion(RelationshipType relationshipType, Type actual

private boolean canCast(Type fromType, Type toType)
{
if (toType instanceof UnknownType) {
// NULL can be cast to any type; avoid re-entering coercion cache.
if (fromType instanceof UnknownType || toType instanceof UnknownType) {
return true;
}
if (fromType instanceof RowType) {
Expand All @@ -615,14 +620,14 @@ private boolean canCast(Type fromType, Type toType)
}
return true;
}
if (toType instanceof JsonType) {
if (isRecursiveCastFromRow(toType)) {
return fromType.getTypeParameters().stream()
.allMatch(fromTypeParameter -> canCast(fromTypeParameter, toType));
}
return false;
}
if (fromType instanceof JsonType) {
if (toType instanceof RowType) {
if (toType instanceof RowType) {
if (isRecursiveCastToRow(fromType)) {
return toType.getTypeParameters().stream()
.allMatch(toTypeParameter -> canCast(fromType, toTypeParameter));
}
Expand All @@ -636,6 +641,61 @@ private boolean canCast(Type fromType, Type toType)
}
}

/// Check if there is a recursive variadic CAST from ROW.
/// This needs special handling because the cast is applied to each field of ROW individually.
private boolean isRecursiveCastFromRow(Type toType)
{
return metadata.getFunctions(null, new CatalogSchemaFunctionName(GlobalSystemConnector.NAME, BUILTIN_SCHEMA, mangleOperatorName(CAST))).stream()
.map(cast -> cast.functionMetadata().getSignature())
.anyMatch(signature -> isRecursiveCastFromRow(toType, signature));
}

private static boolean isRecursiveCastFromRow(Type toType, Signature signature)
{
// the return type must match toType
if (!toType.getTypeSignature().equals(signature.getReturnType())) {
return false;
}

// there must be exactly one type variable constraint and no long variable constraints
if (signature.getTypeVariableConstraints().size() != 1 || !signature.getLongVariableConstraints().isEmpty()) {
return false;
}
TypeVariableConstraint typeVariableConstraint = signature.getTypeVariableConstraints().getFirst();

// The argument type must be a type variable with variadic bound of "row"
return signature.getArgumentTypes().size() == 1 &&
signature.getArgumentTypes().getFirst().getBase().equals(typeVariableConstraint.getName()) &&
typeVariableConstraint.isRowType();
}

/// Check if there is a recursive variadic CAST to ROW.
/// This needs special handling because the cast is applied to each field of ROW individually.
private boolean isRecursiveCastToRow(Type fromType)
{
return metadata.getFunctions(null, new CatalogSchemaFunctionName(GlobalSystemConnector.NAME, BUILTIN_SCHEMA, mangleOperatorName(CAST))).stream()
.map(cast -> cast.functionMetadata().getSignature())
.anyMatch(signature -> isRecursiveCastToRow(fromType, signature));
}

private static boolean isRecursiveCastToRow(Type fromType, Signature signature)
{
// the argument type must match fromType
if (signature.getArgumentTypes().size() != 1 || !fromType.getTypeSignature().equals(signature.getArgumentTypes().getFirst())) {
return false;
}

// there must be exactly one type variable constraint and no long variable constraints
if (signature.getTypeVariableConstraints().size() != 1 || !signature.getLongVariableConstraints().isEmpty()) {
return false;
}
TypeVariableConstraint typeVariableConstraint = signature.getTypeVariableConstraints().getFirst();

// The return type must be a type variable with variadic bound of "row"
return signature.getReturnType().getBase().equals(typeVariableConstraint.getName()) &&
typeVariableConstraint.isRowType();
}

private static List<TypeSignature> getLambdaArgumentTypeSignatures(TypeSignature lambdaTypeSignature)
{
List<TypeParameter> parameters = lambdaTypeSignature.getParameters();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@
import io.trino.type.TinyintOperators;
import io.trino.type.UuidOperators;
import io.trino.type.VarcharOperators;
import io.trino.type.VariantFunctions;
import io.trino.type.VariantOperators;
import io.trino.type.setdigest.BuildSetDigestAggregation;
import io.trino.type.setdigest.MergeSetDigestAggregation;
import io.trino.type.setdigest.SetDigestFunctions;
Expand All @@ -289,6 +291,7 @@
import static io.trino.operator.scalar.ArraySubscriptOperator.ARRAY_SUBSCRIPT;
import static io.trino.operator.scalar.ArrayToElementConcatFunction.ARRAY_TO_ELEMENT_CONCAT_FUNCTION;
import static io.trino.operator.scalar.ArrayToJsonCast.ARRAY_TO_JSON;
import static io.trino.operator.scalar.ArrayToVariantCast.ARRAY_TO_VARIANT;
import static io.trino.operator.scalar.ArrayTransformFunction.ARRAY_TRANSFORM_FUNCTION;
import static io.trino.operator.scalar.CastFromUnknownOperator.CAST_FROM_UNKNOWN;
import static io.trino.operator.scalar.ConcatFunction.VARBINARY_CONCAT;
Expand All @@ -310,14 +313,19 @@
import static io.trino.operator.scalar.MapElementAtFunction.MAP_ELEMENT_AT;
import static io.trino.operator.scalar.MapFilterFunction.MAP_FILTER_FUNCTION;
import static io.trino.operator.scalar.MapToJsonCast.MAP_TO_JSON;
import static io.trino.operator.scalar.MapToVariantCast.MAP_TO_VARIANT;
import static io.trino.operator.scalar.MapTransformValuesFunction.MAP_TRANSFORM_VALUES_FUNCTION;
import static io.trino.operator.scalar.MapZipWithFunction.MAP_ZIP_WITH_FUNCTION;
import static io.trino.operator.scalar.MathFunctions.DECIMAL_MOD_FUNCTION;
import static io.trino.operator.scalar.Re2JCastToRegexpFunction.castCharToRe2JRegexp;
import static io.trino.operator.scalar.Re2JCastToRegexpFunction.castVarcharToRe2JRegexp;
import static io.trino.operator.scalar.RowToJsonCast.ROW_TO_JSON;
import static io.trino.operator.scalar.RowToRowCast.ROW_TO_ROW_CAST;
import static io.trino.operator.scalar.RowToVariantCast.ROW_TO_VARIANT;
import static io.trino.operator.scalar.TryCastFunction.TRY_CAST;
import static io.trino.operator.scalar.VariantToArrayCast.VARIANT_TO_ARRAY;
import static io.trino.operator.scalar.VariantToMapCast.VARIANT_TO_MAP;
import static io.trino.operator.scalar.VariantToRowCast.VARIANT_TO_ROW;
import static io.trino.operator.scalar.ZipFunction.ZIP_FUNCTIONS;
import static io.trino.operator.scalar.ZipWithFunction.ZIP_WITH_FUNCTION;
import static io.trino.operator.scalar.json.JsonArrayFunction.JSON_ARRAY_FUNCTION;
Expand All @@ -334,6 +342,7 @@
import static io.trino.type.DecimalCasts.DECIMAL_TO_SMALLINT_CAST;
import static io.trino.type.DecimalCasts.DECIMAL_TO_TINYINT_CAST;
import static io.trino.type.DecimalCasts.DECIMAL_TO_VARCHAR_CAST;
import static io.trino.type.DecimalCasts.DECIMAL_TO_VARIANT_CAST;
import static io.trino.type.DecimalCasts.DOUBLE_TO_DECIMAL_CAST;
import static io.trino.type.DecimalCasts.INTEGER_TO_DECIMAL_CAST;
import static io.trino.type.DecimalCasts.JSON_TO_DECIMAL_CAST;
Expand All @@ -342,6 +351,7 @@
import static io.trino.type.DecimalCasts.SMALLINT_TO_DECIMAL_CAST;
import static io.trino.type.DecimalCasts.TINYINT_TO_DECIMAL_CAST;
import static io.trino.type.DecimalCasts.VARCHAR_TO_DECIMAL_CAST;
import static io.trino.type.DecimalCasts.VARIANT_TO_DECIMAL_CAST;
import static io.trino.type.DecimalOperators.DECIMAL_ADD_OPERATOR;
import static io.trino.type.DecimalOperators.DECIMAL_DIVIDE_OPERATOR;
import static io.trino.type.DecimalOperators.DECIMAL_MODULUS_OPERATOR;
Expand Down Expand Up @@ -463,6 +473,17 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.scalars(JsonInputFunctions.class)
.scalars(JsonOutputFunctions.class)
.functions(JSON_OBJECT_FUNCTION, JSON_ARRAY_FUNCTION)
.scalars(VariantOperators.class)
.scalar(VariantOperators.VariantToTimeCast.class)
.scalar(VariantOperators.VariantFromTimeCast.class)
.scalar(VariantOperators.VariantToTimestampCast.class)
.scalar(VariantOperators.VariantFromTimestampCast.class)
.scalar(VariantOperators.VariantToTimestampWithTimeZoneCasts.class)
.scalar(VariantOperators.VariantFromTimestampWithTimeZoneCasts.class)
.functions(VARIANT_TO_ARRAY, ARRAY_TO_VARIANT)
.functions(VARIANT_TO_MAP, MAP_TO_VARIANT)
.functions(VARIANT_TO_ROW, ROW_TO_VARIANT)
.scalars(VariantFunctions.class)
.scalars(ColorFunctions.class)
.scalars(HyperLogLogFunctions.class)
.scalars(QuantileDigestFunctions.class)
Expand Down Expand Up @@ -561,6 +582,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.functions(VARCHAR_TO_DECIMAL_CAST, INTEGER_TO_DECIMAL_CAST, BIGINT_TO_DECIMAL_CAST, DOUBLE_TO_DECIMAL_CAST, REAL_TO_DECIMAL_CAST, BOOLEAN_TO_DECIMAL_CAST, TINYINT_TO_DECIMAL_CAST, SMALLINT_TO_DECIMAL_CAST)
.functions(NUMBER_TO_DECIMAL_CAST, DECIMAL_TO_NUMBER_CAST)
.functions(JSON_TO_DECIMAL_CAST, DECIMAL_TO_JSON_CAST)
.functions(VARIANT_TO_DECIMAL_CAST, DECIMAL_TO_VARIANT_CAST)
.functions(featuresConfig.isLegacyArithmeticDecimalOperators() ? LEGACY_DECIMAL_ADD_OPERATOR : DECIMAL_ADD_OPERATOR)
.functions(featuresConfig.isLegacyArithmeticDecimalOperators() ? LEGACY_DECIMAL_SUBTRACT_OPERATOR : DECIMAL_SUBTRACT_OPERATOR)
.functions(featuresConfig.isLegacyArithmeticDecimalOperators() ? LEGACY_DECIMAL_MULTIPLY_OPERATOR : DECIMAL_MULTIPLY_OPERATOR)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
import static io.trino.spi.type.TinyintType.TINYINT;
import static io.trino.spi.type.UuidType.UUID;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.spi.type.VariantType.VARIANT;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toTypeSignature;
import static io.trino.type.ArrayParametricType.ARRAY;
import static io.trino.type.CodePointsType.CODE_POINTS;
Expand Down Expand Up @@ -152,6 +153,7 @@ public TypeRegistry(TypeOperators typeOperators, FeaturesConfig featuresConfig)
addType(JSON_2016);
addType(COLOR);
addType(JSON);
addType(VARIANT);
addType(CODE_POINTS);
addType(IPADDRESS);
addType(UUID);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.scalar;

import com.google.common.collect.ImmutableList;
import io.trino.metadata.SqlScalarFunction;
import io.trino.spi.block.Block;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.Signature;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.variant.Variant;
import io.trino.util.variant.VariantWriter;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;

import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.OperatorType.CAST;
import static io.trino.spi.type.TypeSignature.arrayType;
import static io.trino.spi.type.VariantType.VARIANT;
import static io.trino.util.Failures.checkCondition;
import static io.trino.util.variant.VariantUtil.canCastToVariant;
import static java.lang.invoke.MethodType.methodType;

public class ArrayToVariantCast
extends SqlScalarFunction
{
public static final ArrayToVariantCast ARRAY_TO_VARIANT = new ArrayToVariantCast();

private static final MethodHandle METHOD_HANDLE;

static {
try {
METHOD_HANDLE = MethodHandles.lookup().findStatic(ArrayToVariantCast.class, "toVariant", methodType(Variant.class, VariantWriter.class, Block.class));
}
catch (IllegalAccessException | NoSuchMethodException e) {
throw new ExceptionInInitializerError(e);
}
}

private ArrayToVariantCast()
{
super(FunctionMetadata.operatorBuilder(CAST)
.signature(Signature.builder()
.castableToTypeParameter("T", VARIANT.getTypeSignature())
.returnType(VARIANT)
.argumentType(arrayType(new TypeSignature("T")))
.build())
.build());
}

@Override
protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature)
{
ArrayType arrayType = (ArrayType) boundSignature.getArgumentTypes().getFirst();
checkCondition(canCastToVariant(arrayType), INVALID_CAST_ARGUMENT, "Cannot cast %s to VARIANT", arrayType);

VariantWriter writer = VariantWriter.create(arrayType);
MethodHandle methodHandle = METHOD_HANDLE.bindTo(writer);
return new ChoicesSpecializedSqlScalarFunction(
boundSignature,
FAIL_ON_NULL,
ImmutableList.of(NEVER_NULL),
methodHandle);
}

private static Variant toVariant(VariantWriter writer, Block block)
{
return writer.write(block);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ private JsonToRowCast()
// this is technically a recursive constraint for cast, but TypeRegistry.canCast has explicit handling for json to row cast
TypeVariableConstraint.builder("T")
.rowType()
.castableFrom(JSON)
.build())
.returnType(new TypeSignature("T"))
.argumentType(JSON)
Expand Down
Loading