diff --git a/client/trino-client/src/main/java/io/trino/client/JsonDecodingUtils.java b/client/trino-client/src/main/java/io/trino/client/JsonDecodingUtils.java index 72f86dc9a6b8..c2e78e4dc1b1 100644 --- a/client/trino-client/src/main/java/io/trino/client/JsonDecodingUtils.java +++ b/client/trino-client/src/main/java/io/trino/client/JsonDecodingUtils.java @@ -75,6 +75,8 @@ public final class JsonDecodingUtils { + static final JsonMapper JSON_MAPPER = new JsonMapper(createJsonFactory()); + private JsonDecodingUtils() {} private static final BigIntegerDecoder BIG_INTEGER_DECODER = new BigIntegerDecoder(); @@ -302,7 +304,7 @@ public Object decode(JsonParser parser) throws IOException { StringWriter writer = new StringWriter(); - try (JsonGenerator generator = createJsonFactory().createGenerator(writer)) { + try (JsonGenerator generator = JSON_MAPPER.createGenerator(writer)) { generator.copyCurrentStructure(parser); } return writer.toString(); diff --git a/core/trino-main/src/main/java/io/trino/util/variant/VariantUtil.java b/core/trino-main/src/main/java/io/trino/util/variant/VariantUtil.java index 63e636e521f3..0f65bc3b2024 100644 --- a/core/trino-main/src/main/java/io/trino/util/variant/VariantUtil.java +++ b/core/trino-main/src/main/java/io/trino/util/variant/VariantUtil.java @@ -58,7 +58,6 @@ import io.trino.type.BigintOperators; import io.trino.type.BooleanOperators; import io.trino.type.DateOperators; -import io.trino.type.DateTimes; import io.trino.type.DoubleOperators; import io.trino.type.IntegerOperators; import io.trino.type.JsonType; @@ -106,14 +105,18 @@ import static io.trino.spi.type.VariantType.VARIANT; import static io.trino.spi.variant.Header.BasicType.PRIMITIVE; import static io.trino.spi.variant.Header.PrimitiveType.BINARY; +import static io.trino.type.BooleanOperators.castToVarchar; import static io.trino.type.DateTimes.MICROSECONDS_PER_DAY; import static io.trino.type.DateTimes.NANOSECONDS_PER_DAY; import static io.trino.type.DateTimes.PICOSECONDS_PER_DAY; +import static io.trino.type.DateTimes.formatTimestamp; +import static io.trino.type.DateTimes.formatTimestampWithTimeZone; import static io.trino.type.DateTimes.round; import static io.trino.type.JsonType.JSON; import static io.trino.util.JsonUtil.createJsonGenerator; import static java.lang.Float.floatToRawIntBits; -import static java.lang.Math.toIntExact; +import static java.lang.Math.floorDiv; +import static java.lang.Math.floorMod; import static java.lang.String.format; import static java.math.RoundingMode.HALF_UP; import static java.time.ZoneOffset.UTC; @@ -124,6 +127,7 @@ public final class VariantUtil .disable(CANONICALIZE_FIELD_NAMES) // prevents characters outside BMP (e.g., emoji) from being escaped as surrogate pairs .enable(COMBINE_UNICODE_SURROGATES_IN_UTF8) + .disable(ESCAPE_NON_ASCII) .build()); private VariantUtil() {} @@ -157,8 +161,8 @@ public static boolean canCastToVariant(Type type) return mapType.getKeyType() instanceof VarcharType && canCastToVariant(mapType.getValueType()); } - if (type instanceof RowType) { - return type.getTypeParameters().stream().allMatch(VariantUtil::canCastToVariant); + if (type instanceof RowType rowType) { + return rowType.getTypeParameters().stream().allMatch(VariantUtil::canCastToVariant); } return false; } @@ -191,8 +195,8 @@ public static boolean canCastFromVariant(Type type) if (type instanceof MapType mapType) { return mapType.getKeyType() instanceof VarcharType && canCastFromVariant(mapType.getValueType()); } - if (type instanceof RowType) { - return type.getTypeParameters().stream().allMatch(VariantUtil::canCastFromVariant); + if (type instanceof RowType rowType) { + return rowType.getTypeParameters().stream().allMatch(VariantUtil::canCastFromVariant); } return false; } @@ -204,8 +208,8 @@ public static Slice asVarchar(Variant variant) case PRIMITIVE -> switch (variant.primitiveType()) { case NULL -> null; case STRING -> variant.getString(); - case BOOLEAN_TRUE -> BooleanOperators.castToVarchar(UNBOUNDED_LENGTH, true); - case BOOLEAN_FALSE -> BooleanOperators.castToVarchar(UNBOUNDED_LENGTH, false); + case BOOLEAN_TRUE -> castToVarchar(UNBOUNDED_LENGTH, true); + case BOOLEAN_FALSE -> castToVarchar(UNBOUNDED_LENGTH, false); case INT8 -> utf8Slice(String.valueOf(variant.getByte())); case INT16 -> utf8Slice(String.valueOf(variant.getShort())); case INT32 -> utf8Slice(String.valueOf(variant.getInt())); @@ -217,22 +221,22 @@ public static Slice asVarchar(Variant variant) case TIME_NTZ_MICROS -> TimeOperators.castToVarchar(UNBOUNDED_LENGTH, 6, variant.getTimeMicros() * 1_000_000L); case TIMESTAMP_UTC_MICROS -> { long micros = variant.getTimestampMicros(); - long epochMillis = Math.floorDiv(micros, 1_000L); - int picosOfMilli = toIntExact(Math.floorMod(micros, 1_000L) * 1_000_000L); - yield utf8Slice(DateTimes.formatTimestampWithTimeZone(6, epochMillis, picosOfMilli, UTC_KEY.getZoneId())); + long epochMillis = floorDiv(micros, 1_000L); + int picosOfMilli = floorMod(micros, 1_000) * 1_000_000; + yield utf8Slice(formatTimestampWithTimeZone(6, epochMillis, picosOfMilli, UTC_KEY.getZoneId())); } - case TIMESTAMP_NTZ_MICROS -> utf8Slice(DateTimes.formatTimestamp(6, variant.getTimestampMicros(), 0, UTC)); + case TIMESTAMP_NTZ_MICROS -> utf8Slice(formatTimestamp(6, variant.getTimestampMicros(), 0, UTC)); case TIMESTAMP_UTC_NANOS -> { long nanos = variant.getTimestampNanos(); - long epochMillis = Math.floorDiv(nanos, 1_000_000L); - int picosOfMilli = toIntExact(Math.floorMod(nanos, 1_000_000L) * 1_000L); - yield utf8Slice(DateTimes.formatTimestampWithTimeZone(9, epochMillis, picosOfMilli, UTC_KEY.getZoneId())); + long epochMillis = floorDiv(nanos, 1_000_000L); + int picosOfMilli = floorMod(nanos, 1_000_000) * 1_000; + yield utf8Slice(formatTimestampWithTimeZone(9, epochMillis, picosOfMilli, UTC_KEY.getZoneId())); } case TIMESTAMP_NTZ_NANOS -> { long nanos = variant.getTimestampNanos(); - long epochMicros = Math.floorDiv(nanos, 1_000L); - int picosOfMicros = toIntExact(Math.floorMod(nanos, 1_000L) * 1_000L); - yield utf8Slice(DateTimes.formatTimestamp(9, epochMicros, picosOfMicros, UTC)); + long epochMicros = floorDiv(nanos, 1_000L); + int picosOfMicros = floorMod(nanos, 1_000) * 1_000; + yield utf8Slice(formatTimestamp(9, epochMicros, picosOfMicros, UTC)); } case UUID -> utf8Slice(variant.getUuid().toString()); default -> throw new VariantCastException("Unsupported VARIANT primitive type for cast to VARCHAR: " + variant.primitiveType()); @@ -490,8 +494,8 @@ public static Long asDate(Variant variant) case DATE -> (long) variant.getDate(); case TIMESTAMP_UTC_MICROS, TIMESTAMP_NTZ_MICROS -> { long micros = variant.getTimestampMicros(); - long epochSeconds = Math.floorDiv(micros, 1_000_000L); - int nanoAdjustment = (int) Math.floorMod(micros, 1_000_000L) * 1_000; + long epochSeconds = floorDiv(micros, 1_000_000L); + int nanoAdjustment = floorMod(micros, 1_000_000) * 1_000; yield Instant.ofEpochSecond(epochSeconds, nanoAdjustment) .atZone(UTC) .toLocalDate() @@ -499,8 +503,8 @@ public static Long asDate(Variant variant) } case TIMESTAMP_UTC_NANOS, TIMESTAMP_NTZ_NANOS -> { long nanos = variant.getTimestampNanos(); - long epochSeconds = Math.floorDiv(nanos, 1_000_000_000L); - int nanoAdjustment = (int) Math.floorMod(nanos, 1_000_000_000L); + long epochSeconds = floorDiv(nanos, 1_000_000_000L); + int nanoAdjustment = floorMod(nanos, 1_000_000_000); yield Instant.ofEpochSecond(epochSeconds, nanoAdjustment) .atZone(UTC) .toLocalDate() @@ -525,13 +529,13 @@ public static Long asTime(Variant variant, int precision) yield round(timePicos, MAX_PRECISION - precision) % PICOSECONDS_PER_DAY; } case TIMESTAMP_UTC_MICROS, TIMESTAMP_NTZ_MICROS -> { - long micros = Math.floorMod(variant.getTimestampMicros(), MICROSECONDS_PER_DAY); + long micros = floorMod(variant.getTimestampMicros(), MICROSECONDS_PER_DAY); long timePicos = micros * 1_000_000L; // round can round up to a value equal to 24h, so we need to compute module 24h yield round(timePicos, MAX_PRECISION - precision) % PICOSECONDS_PER_DAY; } case TIMESTAMP_UTC_NANOS, TIMESTAMP_NTZ_NANOS -> { - long nanos = Math.floorMod(variant.getTimestampNanos(), NANOSECONDS_PER_DAY); + long nanos = floorMod(variant.getTimestampNanos(), NANOSECONDS_PER_DAY); long timePicos = nanos * 1_000L; // round can round up to a value equal to 24h, so we need to compute module 24h yield round(timePicos, MAX_PRECISION - precision) % PICOSECONDS_PER_DAY; @@ -591,8 +595,8 @@ public static LongTimestamp asLongTimestamp(Variant variant, int precision) if (precision < 9) { nanos = round(nanos, 9 - precision); } - long micros = Math.floorDiv(nanos, 1_000L); - int picosOfMicro = toIntExact(Math.floorMod(nanos, 1_000L) * 1_000L); + long micros = floorDiv(nanos, 1_000L); + int picosOfMicro = floorMod(nanos, 1_000) * 1_000; yield new LongTimestamp(micros, picosOfMicro); } case STRING -> VarcharToTimestampCast.castToLongTimestamp(precision, variant.getString().toStringUtf8()); @@ -644,8 +648,8 @@ public static LongTimestampWithTimeZone asLongTimestampWithTimeZone(Variant vari if (precision < 6) { micros = round(micros, 6 - precision); } - long millis = Math.floorDiv(micros, 1_000L); - int picosOfMillis = toIntExact(Math.floorMod(micros, 1_000L) * 1_000_000L); + long millis = floorDiv(micros, 1_000L); + int picosOfMillis = floorMod(micros, 1_000) * 1_000_000; yield fromEpochMillisAndFraction(millis, picosOfMillis, UTC_KEY); } case TIMESTAMP_UTC_NANOS, TIMESTAMP_NTZ_NANOS -> { @@ -653,8 +657,8 @@ public static LongTimestampWithTimeZone asLongTimestampWithTimeZone(Variant vari if (precision < 9) { nanos = round(nanos, 9 - precision); } - long millis = Math.floorDiv(nanos, 1_000_000L); - int picosOfMillis = toIntExact(Math.floorMod(nanos, 1_000_000L) * 1_000L); + long millis = floorDiv(nanos, 1_000_000L); + int picosOfMillis = floorMod(nanos, 1_000_000) * 1_000; yield fromEpochMillisAndFraction(millis, picosOfMillis, UTC_KEY); } case STRING -> asLongTimestampWithTimeZone(variant.getString(), precision); @@ -1211,12 +1215,9 @@ private static void parseVariantToSingleRowBlock( public static Slice asJson(Variant variant) { - try { - SliceOutput output = new DynamicSliceOutput(40); - try (JsonGenerator jsonGenerator = createJsonGenerator(JSON_MAPPER, output)) { - jsonGenerator.configure(ESCAPE_NON_ASCII.mappedFeature(), false); - toJsonValue(jsonGenerator, variant); - } + try (SliceOutput output = new DynamicSliceOutput(40); JsonGenerator jsonGenerator = createJsonGenerator(JSON_MAPPER, output)) { + toJsonValue(jsonGenerator, variant); + jsonGenerator.flush(); return output.slice(); } catch (IOException e) { @@ -1224,78 +1225,80 @@ public static Slice asJson(Variant variant) } } - private static void toJsonValue(JsonGenerator jsonGenerator, Variant variant) + private static void toJsonValue(JsonGenerator generator, Variant variant) throws IOException { switch (variant.basicType()) { case PRIMITIVE -> { switch (variant.primitiveType()) { - case NULL -> jsonGenerator.writeNull(); + case NULL -> generator.writeNull(); case BINARY -> { Slice binary = variant.getBinary(); - jsonGenerator.writeBinary(binary.byteArray(), binary.byteArrayOffset(), binary.length()); + generator.writeBinary(binary.byteArray(), binary.byteArrayOffset(), binary.length()); } - case STRING -> jsonGenerator.writeString(variant.getString().toStringUtf8()); - case BOOLEAN_TRUE -> jsonGenerator.writeBoolean(true); - case BOOLEAN_FALSE -> jsonGenerator.writeBoolean(false); - case INT8 -> jsonGenerator.writeNumber(variant.getByte()); - case INT16 -> jsonGenerator.writeNumber(variant.getShort()); - case INT32 -> jsonGenerator.writeNumber(variant.getInt()); - case INT64 -> jsonGenerator.writeNumber(variant.getLong()); - case DECIMAL4, DECIMAL8, DECIMAL16 -> jsonGenerator.writeNumber(variant.getDecimal()); - case FLOAT -> jsonGenerator.writeNumber(variant.getFloat()); - case DOUBLE -> jsonGenerator.writeNumber(variant.getDouble()); - case DATE -> jsonGenerator.writeString(DateOperators.castToVarchar(UNBOUNDED_LENGTH, variant.getDate()).toStringUtf8()); - case TIME_NTZ_MICROS -> jsonGenerator.writeString(TimeOperators.castToVarchar(UNBOUNDED_LENGTH, 6, variant.getTimeMicros() * 1_000_000L).toStringUtf8()); + case STRING -> generator.writeString(variant.getString().toStringUtf8()); + case BOOLEAN_TRUE -> generator.writeBoolean(true); + case BOOLEAN_FALSE -> generator.writeBoolean(false); + case INT8 -> generator.writeNumber(variant.getByte()); + case INT16 -> generator.writeNumber(variant.getShort()); + case INT32 -> generator.writeNumber(variant.getInt()); + case INT64 -> generator.writeNumber(variant.getLong()); + case DECIMAL4, DECIMAL8, DECIMAL16 -> generator.writeNumber(variant.getDecimal()); + case FLOAT -> generator.writeNumber(variant.getFloat()); + case DOUBLE -> generator.writeNumber(variant.getDouble()); + case DATE -> generator.writeString(DateOperators.castToVarchar(UNBOUNDED_LENGTH, variant.getDate()).toStringUtf8()); + case TIME_NTZ_MICROS -> generator.writeString(TimeOperators.castToVarchar(UNBOUNDED_LENGTH, 6, variant.getTimeMicros() * 1_000_000L).toStringUtf8()); case TIMESTAMP_UTC_MICROS -> { long micros = variant.getTimestampMicros(); - long epochMillis = Math.floorDiv(micros, 1_000L); - int picosOfMilli = toIntExact(Math.floorMod(micros, 1_000L) * 1_000_000L); - jsonGenerator.writeString(DateTimes.formatTimestampWithTimeZone(6, epochMillis, picosOfMilli, UTC_KEY.getZoneId())); + long epochMillis = floorDiv(micros, 1_000L); + int picosOfMilli = floorMod(micros, 1_000) * 1_000_000; + generator.writeString(formatTimestampWithTimeZone(6, epochMillis, picosOfMilli, UTC_KEY.getZoneId())); } - case TIMESTAMP_NTZ_MICROS -> jsonGenerator.writeString(DateTimes.formatTimestamp(6, variant.getTimestampMicros(), 0, UTC)); + case TIMESTAMP_NTZ_MICROS -> generator.writeString(formatTimestamp(6, variant.getTimestampMicros(), 0, UTC)); case TIMESTAMP_UTC_NANOS -> { long nanos = variant.getTimestampNanos(); - long epochMillis = Math.floorDiv(nanos, 1_000_000L); - int picosOfMilli = toIntExact(Math.floorMod(nanos, 1_000_000L) * 1_000L); - jsonGenerator.writeString(DateTimes.formatTimestampWithTimeZone(9, epochMillis, picosOfMilli, UTC_KEY.getZoneId())); + long epochMillis = floorDiv(nanos, 1_000_000L); + + int picosOfMilli = floorMod(nanos, 1_000_000) * 1_000; + generator.writeString(formatTimestampWithTimeZone(9, epochMillis, picosOfMilli, UTC_KEY.getZoneId())); } case TIMESTAMP_NTZ_NANOS -> { long nanos = variant.getTimestampNanos(); - long epochMicros = Math.floorDiv(nanos, 1_000L); - int picosOfMicros = toIntExact(Math.floorMod(nanos, 1_000L) * 1_000L); - jsonGenerator.writeString(DateTimes.formatTimestamp(9, epochMicros, picosOfMicros, UTC)); + long epochMicros = floorDiv(nanos, 1_000L); + + int picosOfMicros = floorMod(nanos, 1_000) * 1_000; + generator.writeString(formatTimestamp(9, epochMicros, picosOfMicros, UTC)); } - case UUID -> jsonGenerator.writeString(variant.getUuid().toString()); + case UUID -> generator.writeString(variant.getUuid().toString()); } } - case SHORT_STRING -> jsonGenerator.writeString(variant.getString().toStringUtf8()); + case SHORT_STRING -> generator.writeString(variant.getString().toStringUtf8()); case ARRAY -> { - jsonGenerator.writeStartArray(); + generator.writeStartArray(); variant.arrayElements().forEach(element -> { try { - toJsonValue(jsonGenerator, element); + toJsonValue(generator, element); } catch (IOException e) { throw new UncheckedIOException(e); } }); - jsonGenerator.writeEndArray(); + generator.writeEndArray(); } case OBJECT -> { Metadata metadata = variant.metadata(); - jsonGenerator.writeStartObject(); + generator.writeStartObject(); variant.objectFields().forEach(fieldIdValue -> { try { String fieldName = metadata.get(fieldIdValue.fieldId()).toStringUtf8(); - jsonGenerator.writeFieldName(fieldName); - toJsonValue(jsonGenerator, fieldIdValue.value()); + generator.writeFieldName(fieldName); + toJsonValue(generator, fieldIdValue.value()); } catch (IOException e) { throw new UncheckedIOException(e); } }); - jsonGenerator.writeEndObject(); + generator.writeEndObject(); } } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/variant/Variant.java b/core/trino-spi/src/main/java/io/trino/spi/variant/Variant.java index ed72e22a0fee..dedc572f4dee 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/variant/Variant.java +++ b/core/trino-spi/src/main/java/io/trino/spi/variant/Variant.java @@ -41,6 +41,7 @@ import java.util.stream.Stream; import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.slice.Slices.wrappedBuffer; import static io.trino.spi.variant.Header.BasicType.OBJECT; import static io.trino.spi.variant.Header.BasicType.PRIMITIVE; import static io.trino.spi.variant.Header.arrayFieldOffsetSize; @@ -96,25 +97,22 @@ import static io.trino.spi.variant.VariantUtils.findFieldIndex; import static io.trino.spi.variant.VariantUtils.readOffset; import static io.trino.spi.variant.VariantUtils.verify; +import static java.lang.Math.floorDiv; +import static java.lang.Math.floorMod; import static java.lang.Math.multiplyExact; -import static java.lang.Math.toIntExact; import static java.util.Collections.unmodifiableList; import static java.util.Collections.unmodifiableMap; import static java.util.Objects.requireNonNull; import static java.util.Objects.requireNonNullElse; -public final class Variant +public record Variant(Slice data, Metadata metadata, BasicType basicType, PrimitiveType primitiveType) { - public static final Variant NULL_VALUE = from(EMPTY_METADATA, Slices.wrappedBuffer(primitiveHeader(PrimitiveType.NULL))); + public static final Variant NULL_VALUE = from(EMPTY_METADATA, wrappedBuffer(primitiveHeader(PrimitiveType.NULL))); public static final Variant EMPTY_ARRAY; public static final Variant EMPTY_OBJECT; - private final Slice data; - private final Metadata metadata; - private final BasicType basicType; - private final PrimitiveType primitiveType; static { - IntUnaryOperator emptyIndexedOperator = index -> { + IntUnaryOperator emptyIndexedOperator = _ -> { throw new IndexOutOfBoundsException(); }; @@ -127,42 +125,18 @@ public final class Variant EMPTY_OBJECT = from(EMPTY_METADATA, emptyObjectValue); } - public Variant(Slice data, Metadata metadata, BasicType basicType, PrimitiveType primitiveType) + public Variant { requireNonNull(data, "data is null"); requireNonNull(metadata, "metadata is null"); requireNonNull(basicType, "basicType is null"); - checkArgument(basicType == PRIMITIVE || primitiveType == null, "primitiveType must be null for non-primitive basicType"); - checkArgument(basicType != PRIMITIVE || primitiveType != null, "primitiveType must be non-null for primitive basicType"); + + checkArgument(basicType == PRIMITIVE == (primitiveType != null), "primitiveType must be non-null if and only if basicType is PRIMITIVE"); // not need to retain metadata for non-container types if (!basicType.isContainer()) { metadata = EMPTY_METADATA; } - this.data = data; - this.metadata = metadata; - this.basicType = basicType; - this.primitiveType = primitiveType; - } - - public Slice data() - { - return data; - } - - public Metadata metadata() - { - return metadata; - } - - public BasicType basicType() - { - return basicType; - } - - public PrimitiveType primitiveType() - { - return primitiveType; } public static Variant from(Metadata metadata, Slice data) @@ -560,13 +534,13 @@ public Instant getInstant() int nanoOfSecond; if (primitiveType == PrimitiveType.TIMESTAMP_UTC_MICROS) { long micros = getTimestampMicros(); - seconds = Math.floorDiv(micros, 1_000_000); - nanoOfSecond = toIntExact(Math.floorMod(micros, 1_000_000) * 1_000L); + seconds = floorDiv(micros, 1_000_000); + nanoOfSecond = floorMod(micros, 1_000_000) * 1_000; } else if (primitiveType == PrimitiveType.TIMESTAMP_UTC_NANOS) { long nanos = getTimestampNanos(); - seconds = Math.floorDiv(nanos, 1_000_000_000L); - nanoOfSecond = (int) Math.floorMod(nanos, 1_000_000_000L); + seconds = floorDiv(nanos, 1_000_000_000L); + nanoOfSecond = floorMod(nanos, 1_000_000_000); } else { throw new IllegalStateException("Expected primitive TIMESTAMP but got " + primitiveType); @@ -580,13 +554,13 @@ public LocalDateTime getLocalDateTime() int nanoOfSecond; if (primitiveType == PrimitiveType.TIMESTAMP_NTZ_MICROS) { long micros = getTimestampMicros(); - seconds = Math.floorDiv(micros, 1_000_000); - nanoOfSecond = toIntExact(Math.floorMod(micros, 1_000_000) * 1_000L); + seconds = floorDiv(micros, 1_000_000); + nanoOfSecond = floorMod(micros, 1_000_000) * 1_000; } else if (primitiveType == PrimitiveType.TIMESTAMP_NTZ_NANOS) { long nanos = getTimestampNanos(); - seconds = Math.floorDiv(nanos, 1_000_000_000L); - nanoOfSecond = (int) Math.floorMod(nanos, 1_000_000_000L); + seconds = floorDiv(nanos, 1_000_000_000L); + nanoOfSecond = floorMod(nanos, 1_000_000_000); } else { throw new IllegalStateException("Expected primitive TIMESTAMP but got " + primitiveType); @@ -636,7 +610,7 @@ public UUID getUuid() public int getArrayLength() { verifyType(BasicType.ARRAY); - int count = arrayIsLarge(data.getByte(0)) ? data.getInt(1) : (data.getByte(1) & 0xFF); + int count = arrayIsLarge(data.getByte(0)) ? data.getInt(1) : data.getByte(1) & 0xFF; checkState(count >= 0, () -> "Corrupt array count: " + count); return count; } @@ -648,7 +622,7 @@ public Variant getArrayElement(int index) boolean large = arrayIsLarge(header); int offSize = arrayFieldOffsetSize(header); - int count = large ? data.getInt(1) : (data.getByte(1) & 0xFF); + int count = large ? data.getInt(1) : data.getByte(1) & 0xFF; Objects.checkIndex(index, count); int offsetsStart = 1 + (large ? 4 : 1); @@ -668,7 +642,7 @@ public Stream arrayElements() boolean large = arrayIsLarge(header); int offsetSize = arrayFieldOffsetSize(header); - int count = large ? data.getInt(1) : (data.getByte(1) & 0xFF); + int count = large ? data.getInt(1) : data.getByte(1) & 0xFF; int offsetsStart = 1 + (large ? 4 : 1); int valuesStart = offsetsStart + (count + 1) * offsetSize; @@ -685,7 +659,7 @@ public Stream arrayElements() public int getObjectFieldCount() { verifyType(OBJECT); - int count = objectIsLarge(data.getByte(0)) ? data.getInt(1) : (data.getByte(1) & 0xFF); + int count = objectIsLarge(data.getByte(0)) ? data.getInt(1) : data.getByte(1) & 0xFF; checkState(count >= 0, () -> "Corrupt object field count: " + count); return count; } @@ -703,7 +677,7 @@ public Optional getObjectField(int fieldId) boolean large = objectIsLarge(header); int idSize = objectFieldIdSize(header); int offsetSize = objectFieldOffsetSize(header); - int count = large ? data.getInt(1) : (data.getByte(1) & 0xFF); + int count = large ? data.getInt(1) : data.getByte(1) & 0xFF; checkState(count >= 0, () -> "Corrupt object field count: " + count); int idsStart = 1 + (large ? 4 : 1); int offsetsStart = idsStart + count * idSize; @@ -728,7 +702,7 @@ public Optional getObjectField(Slice fieldName) boolean large = objectIsLarge(header); int idSize = objectFieldIdSize(header); int offsetSize = objectFieldOffsetSize(header); - int count = large ? data.getInt(1) : (data.getByte(1) & 0xFF); + int count = large ? data.getInt(1) : data.getByte(1) & 0xFF; checkState(count >= 0, () -> "Corrupt object field count: " + count); int idsStart = 1 + (large ? 4 : 1); int offsetsStart = idsStart + count * idSize; @@ -751,7 +725,7 @@ public Stream objectFieldNames() boolean large = objectIsLarge(header); int idSize = objectFieldIdSize(header); - int count = large ? data.getInt(1) : (data.getByte(1) & 0xFF); + int count = large ? data.getInt(1) : data.getByte(1) & 0xFF; int idsStart = 1 + (large ? 4 : 1); @@ -1004,29 +978,29 @@ private static int computeEncodedSize( case byte[] bytes -> encodedBinarySize(bytes.length); case String s -> encodedStringSize(utf8Slice(s).length()); case List list -> containerSizeCache.computeIfAbsent(list, _ -> { - int totalElementsLength = 0; - for (Object element : list) { - totalElementsLength += computeEncodedSize(element, fieldIdByName, fieldRemappers, containerSizeCache); - } - return encodedArraySize(list.size(), totalElementsLength); - }); + int totalElementsLength = 0; + for (Object element : list) { + totalElementsLength += computeEncodedSize(element, fieldIdByName, fieldRemappers, containerSizeCache); + } + return encodedArraySize(list.size(), totalElementsLength); + }); case Map map -> containerSizeCache.computeIfAbsent(map, _ -> { - if (map.isEmpty()) { - return ENCODED_EMPTY_OBJECT_SIZE; - } - - int maxFieldId = -1; - for (Object key : map.keySet()) { - maxFieldId = Math.max(maxFieldId, requireNonNull(fieldIdByName.get(castMapKey(key)))); - } - - int totalValuesLength = 0; - for (Object entry : map.values()) { - totalValuesLength += computeEncodedSize(entry, fieldIdByName, fieldRemappers, containerSizeCache); - } - - return encodedObjectSize(maxFieldId, map.size(), totalValuesLength); - }); + if (map.isEmpty()) { + return ENCODED_EMPTY_OBJECT_SIZE; + } + + int maxFieldId = -1; + for (Object key : map.keySet()) { + maxFieldId = Math.max(maxFieldId, requireNonNull(fieldIdByName.get(castMapKey(key)))); + } + + int totalValuesLength = 0; + for (Object entry : map.values()) { + totalValuesLength += computeEncodedSize(entry, fieldIdByName, fieldRemappers, containerSizeCache); + } + + return encodedObjectSize(maxFieldId, map.size(), totalValuesLength); + }); default -> throw new IllegalArgumentException("Unsupported object type for VARIANT: " + value.getClass().getName()); }; } @@ -1056,7 +1030,7 @@ private static int writeEncoded( case LocalDateTime dateTime -> encodeTimestampNanosNtz(dateTime, out, offset); case UUID uuid -> encodeUuid(uuid, out, offset); case Slice slice -> encodeBinary(slice, out, offset); - case byte[] bytes -> encodeBinary(Slices.wrappedBuffer(bytes), out, offset); + case byte[] bytes -> encodeBinary(wrappedBuffer(bytes), out, offset); case String string -> encodeString(utf8Slice(string), out, offset); case List list -> { int written = encodeArrayHeading( diff --git a/core/trino-spi/src/main/java/io/trino/spi/variant/VariantEncoder.java b/core/trino-spi/src/main/java/io/trino/spi/variant/VariantEncoder.java index e92fb14ccf99..cd360c6567c7 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/variant/VariantEncoder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/variant/VariantEncoder.java @@ -58,7 +58,9 @@ import static io.trino.spi.variant.VariantUtils.getOffsetSize; import static io.trino.spi.variant.VariantUtils.verify; import static io.trino.spi.variant.VariantUtils.writeOffset; +import static java.lang.Math.addExact; import static java.lang.Math.max; +import static java.lang.Math.multiplyExact; public final class VariantEncoder { @@ -286,11 +288,11 @@ public static int encodeTimestampNanosUtc(Instant value, Slice variant, int offs // For negative timestamps with a positive nano adjustment, shift one second into nanos first // so multiplyExact can still represent the full long nanoseconds domain (including Long.MIN_VALUE). if (epochSecond < 0 && nanoOfSecond > 0) { - epochSecond = Math.addExact(epochSecond, 1); + epochSecond = addExact(epochSecond, 1); nanoOfSecond -= 1_000_000_000; } - long epochNanos = Math.addExact(Math.multiplyExact(epochSecond, 1_000_000_000L), nanoOfSecond); + long epochNanos = addExact(multiplyExact(epochSecond, 1_000_000_000L), nanoOfSecond); return encodeTimestampNanosUtc(epochNanos, variant, offset); }