diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayVectorFunctions.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayVectorFunctions.java index 44068fefa0f1..9190c761c1a8 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayVectorFunctions.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayVectorFunctions.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.function.Description; import io.trino.spi.function.ScalarFunction; +import io.trino.spi.function.SqlNullable; import io.trino.spi.function.SqlType; import io.trino.spi.type.StandardTypes; @@ -59,10 +60,15 @@ public static double dotProduct(@SqlType("array(double)") Block first, @SqlType( @Description("Calculates the cosine similarity between two vectors") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double cosineSimilarity(@SqlType("array(double)") Block first, @SqlType("array(double)") Block second) + @SqlNullable + public static Double cosineSimilarity(@SqlType("array(double)") Block first, @SqlType("array(double)") Block second) { checkCondition(first.getPositionCount() == second.getPositionCount(), INVALID_FUNCTION_ARGUMENT, "The arguments must have the same length"); + if (first.hasNull() || second.hasNull()) { + return null; + } + double firstMagnitude = 0.0; double secondMagnitude = 0.0; double dotProduct = 0.0; @@ -81,8 +87,13 @@ public static double cosineSimilarity(@SqlType("array(double)") Block first, @Sq @Description("Calculates the cosine distance between two vectors") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double cosineDistance(@SqlType("array(double)") Block first, @SqlType("array(double)") Block second) + @SqlNullable + public static Double cosineDistance(@SqlType("array(double)") Block first, @SqlType("array(double)") Block second) { - return 1.0 - cosineSimilarity(first, second); + Double cosineSimilarity = cosineSimilarity(first, second); + if (cosineSimilarity == null) { + return null; + } + return 1.0 - cosineSimilarity; } } diff --git a/core/trino-main/src/test/java/io/trino/block/AbstractTestBlock.java b/core/trino-main/src/test/java/io/trino/block/AbstractTestBlock.java index 4a5cfc007517..409ccffeaa7f 100644 --- a/core/trino-main/src/test/java/io/trino/block/AbstractTestBlock.java +++ b/core/trino-main/src/test/java/io/trino/block/AbstractTestBlock.java @@ -36,6 +36,7 @@ import java.lang.reflect.Modifier; import java.util.Arrays; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.stream.IntStream; @@ -198,6 +199,13 @@ private void assertBlockPositions(Block block, T[] expectedValues) for (int position = 0; position < block.getPositionCount(); position++) { assertBlockPosition(block, position, expectedValues[position]); } + if (Arrays.stream(expectedValues).anyMatch(Objects::isNull)) { + assertThat(block.hasNull()).isTrue(); + assertThat(block.mayHaveNull()).isTrue(); + } + else { + assertThat(block.hasNull()).isFalse(); + } } protected static List splitBlock(Block block, int count) diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayVectorFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayVectorFunctions.java index 6f73d5262323..7cebe1b743a5 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayVectorFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayVectorFunctions.java @@ -330,6 +330,11 @@ void testCosineDistance() .hasType(DOUBLE) .isEqualTo(NaN); + assertThat(assertions.function("cosine_distance", "ARRAY[1, 2]", "ARRAY[3, null]")) + .isNull(DOUBLE); + assertThat(assertions.function("cosine_distance", "ARRAY[1, null]", "ARRAY[3, 4]")) + .isNull(DOUBLE); + assertTrinoExceptionThrownBy(assertions.function("cosine_distance", "ARRAY[]", "ARRAY[]")::evaluate) .hasMessage("Vector magnitude cannot be zero"); assertTrinoExceptionThrownBy(assertions.function("cosine_distance", "ARRAY[]", "ARRAY[1]")::evaluate) @@ -341,4 +346,109 @@ void testCosineDistance() assertTrinoExceptionThrownBy(assertions.function("cosine_distance", "ARRAY[1, 2]", "ARRAY[1]")::evaluate) .hasMessage("The arguments must have the same length"); } + + @Test + void testCosineSimilarity() + { + assertThat(assertions.function("cosine_similarity", "ARRAY[1]", "ARRAY[2]")) + .hasType(DOUBLE) + .isEqualTo(1.0); + assertThat(assertions.function("cosine_similarity", "ARRAY[1, 2]", "ARRAY[3, 4]")) + .hasType(DOUBLE) + .isEqualTo(1.0 - 0.01613008990009257); + assertThat(assertions.function("cosine_similarity", "ARRAY[4, 5, 6]", "ARRAY[4, 5, 6]")) + .hasType(DOUBLE) + .isEqualTo(1.0); + assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '1.1', REAL '2.2', REAL '3.3']", "ARRAY[REAL '4.4', REAL '5.5', REAL '6.6']")) + .hasType(DOUBLE) + .isEqualTo(1.0 - 0.025368154060122383); + assertThat(assertions.function("cosine_similarity", "ARRAY[DOUBLE '1.1', DOUBLE '2.2', DOUBLE '3.3']", "ARRAY[DOUBLE '4.4', DOUBLE '5.5', DOUBLE '6.6']")) + .hasType(DOUBLE) + .isEqualTo(1.0 - 0.025368153802923676); + assertThat(assertions.function("cosine_similarity", "ARRAY[1.1, 2.2, 3.3]", "ARRAY[4.4, 5.5, 6.6]")) + .hasType(DOUBLE) + .isEqualTo(1.0 - 0.025368153802923676); + + // real type's min and max + assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '3.4028235e+38f']", "ARRAY[REAL '3.4028235e+38f']")) + .hasType(DOUBLE) + .isEqualTo(1.0); + assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '-3.4028235e+38f']", "ARRAY[REAL '-3.4028235e+38f']")) + .hasType(DOUBLE) + .isEqualTo(1.0); + assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '3.4028235e+38f']", "ARRAY[REAL '-3.4028235e+38f']")) + .hasType(DOUBLE) + .isEqualTo(1.0 - 2.0); + assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '-3.4028235e+38f']", "ARRAY[REAL '3.4028235e+38f']")) + .hasType(DOUBLE) + .isEqualTo(1.0 - 2.0); + assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '1.4E-45']", "ARRAY[REAL '1.4E-45']")) + .hasType(DOUBLE) + .isEqualTo(1.0); + assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '-1.4E-45']", "ARRAY[REAL '-1.4E-45']")) + .hasType(DOUBLE) + .isEqualTo(1.0); + assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '1.4E-45']", "ARRAY[REAL '-1.4E-45']")) + .hasType(DOUBLE) + .isEqualTo(1.0 - 2.0); + assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '-1.4E-45']", "ARRAY[REAL '1.4E-45']")) + .hasType(DOUBLE) + .isEqualTo(1.0 - 2.0); + assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '3.4028235e+38f']", "ARRAY[REAL '1.4E-45']")) + .hasType(DOUBLE) + .isEqualTo(1.0); + assertThat(assertions.function("cosine_similarity", "ARRAY[REAL '1.4E-45']", "ARRAY[REAL '3.4028235e+38f']")) + .hasType(DOUBLE) + .isEqualTo(1.0); + + // double type's min and max + assertThat(assertions.function("cosine_similarity", "ARRAY[DOUBLE '1.7976931348623157E+309']", "ARRAY[DOUBLE '1.7976931348623157E+309']")) + .hasType(DOUBLE) + .isEqualTo(NaN); + assertThat(assertions.function("cosine_similarity", "ARRAY[DOUBLE '-1.7976931348623157E+308']", "ARRAY[DOUBLE '-1.7976931348623157E+308']")) + .hasType(DOUBLE) + .isEqualTo(NaN); + assertThat(assertions.function("cosine_similarity", "ARRAY[DOUBLE '1.7976931348623157E+309']", "ARRAY[DOUBLE '-1.7976931348623157E+308']")) + .hasType(DOUBLE) + .isEqualTo(NaN); + assertThat(assertions.function("cosine_similarity", "ARRAY[DOUBLE '-1.7976931348623157E+308']", "ARRAY[DOUBLE '1.7976931348623157E+309']")) + .hasType(DOUBLE) + .isEqualTo(NaN); + + // NaN and infinity + assertThat(assertions.function("cosine_similarity", "ARRAY[1]", "ARRAY[nan()]")) + .hasType(DOUBLE) + .isEqualTo(NaN); + assertThat(assertions.function("cosine_similarity", "ARRAY[nan()]", "ARRAY[1]")) + .hasType(DOUBLE) + .isEqualTo(NaN); + assertThat(assertions.function("cosine_similarity", "ARRAY[1]", "ARRAY[-infinity()]")) + .hasType(DOUBLE) + .isEqualTo(NaN); + assertThat(assertions.function("cosine_similarity", "ARRAY[-infinity()]", "ARRAY[1]")) + .hasType(DOUBLE) + .isEqualTo(NaN); + assertThat(assertions.function("cosine_similarity", "ARRAY[1]", "ARRAY[infinity()]")) + .hasType(DOUBLE) + .isEqualTo(NaN); + assertThat(assertions.function("cosine_similarity", "ARRAY[infinity()]", "ARRAY[1]")) + .hasType(DOUBLE) + .isEqualTo(NaN); + + assertThat(assertions.function("cosine_similarity", "ARRAY[1, 2]", "ARRAY[3, null]")) + .isNull(DOUBLE); + assertThat(assertions.function("cosine_similarity", "ARRAY[1, null]", "ARRAY[3, 4]")) + .isNull(DOUBLE); + + assertTrinoExceptionThrownBy(assertions.function("cosine_similarity", "ARRAY[]", "ARRAY[]")::evaluate) + .hasMessage("Vector magnitude cannot be zero"); + assertTrinoExceptionThrownBy(assertions.function("cosine_similarity", "ARRAY[]", "ARRAY[1]")::evaluate) + .hasMessage("The arguments must have the same length"); + assertTrinoExceptionThrownBy(assertions.function("cosine_similarity", "ARRAY[1]", "ARRAY[]")::evaluate) + .hasMessage("The arguments must have the same length"); + assertTrinoExceptionThrownBy(assertions.function("cosine_similarity", "ARRAY[1]", "ARRAY[1, 2]")::evaluate) + .hasMessage("The arguments must have the same length"); + assertTrinoExceptionThrownBy(assertions.function("cosine_similarity", "ARRAY[1, 2]", "ARRAY[1]")::evaluate) + .hasMessage("The arguments must have the same length"); + } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlock.java index 5e079cf5c402..10530ad73ed5 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlock.java @@ -198,6 +198,20 @@ public boolean mayHaveNull() return valueIsNull != null; } + @Override + public boolean hasNull() + { + if (valueIsNull == null) { + return false; + } + for (int i = 0; i < positionCount; i++) { + if (valueIsNull[i + arrayOffset]) { + return true; + } + } + return false; + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Block.java b/core/trino-spi/src/main/java/io/trino/spi/block/Block.java index ea55eb882aac..28accd96d239 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Block.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Block.java @@ -149,6 +149,19 @@ default boolean mayHaveNull() return true; } + /** + * Does this block have a null value? This method is expected to be O(N). + */ + default boolean hasNull() + { + for (int i = 0; i < getPositionCount(); i++) { + if (isNull(i)) { + return true; + } + } + return false; + } + /** * Is the specified position null? * diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlock.java index e7dba0d9277c..4cf60baa2e84 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlock.java @@ -153,6 +153,20 @@ public boolean mayHaveNull() return valueIsNull != null; } + @Override + public boolean hasNull() + { + if (valueIsNull == null) { + return false; + } + for (int i = 0; i < positionCount; i++) { + if (valueIsNull[i + arrayOffset]) { + return true; + } + } + return false; + } + @Override public boolean isNull(int position) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/DictionaryBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/DictionaryBlock.java index e5f2a09402c6..775092971241 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/DictionaryBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/DictionaryBlock.java @@ -374,6 +374,12 @@ public boolean mayHaveNull() return mayHaveNull && dictionary.mayHaveNull(); } + @Override + public boolean hasNull() + { + return mayHaveNull && dictionary.hasNull(); + } + @Override public boolean isNull(int position) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12Block.java b/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12Block.java index 7ddbde46838c..c5583043d77b 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12Block.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12Block.java @@ -156,6 +156,20 @@ public boolean mayHaveNull() return valueIsNull != null; } + @Override + public boolean hasNull() + { + if (valueIsNull == null) { + return false; + } + for (int i = 0; i < positionCount; i++) { + if (valueIsNull[i + positionOffset]) { + return true; + } + } + return false; + } + @Override public boolean isNull(int position) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlock.java index 990b4449ab27..05317b595880 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlock.java @@ -150,6 +150,20 @@ public boolean mayHaveNull() return valueIsNull != null; } + @Override + public boolean hasNull() + { + if (valueIsNull == null) { + return false; + } + for (int i = 0; i < positionCount; i++) { + if (valueIsNull[i + positionOffset]) { + return true; + } + } + return false; + } + @Override public boolean isNull(int position) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlock.java index 5712e0d4d6c2..52205a3ef7b2 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlock.java @@ -136,6 +136,20 @@ public boolean mayHaveNull() return valueIsNull != null; } + @Override + public boolean hasNull() + { + if (valueIsNull == null) { + return false; + } + for (int i = 0; i < positionCount; i++) { + if (valueIsNull[i + arrayOffset]) { + return true; + } + } + return false; + } + @Override public boolean isNull(int position) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlock.java index a0159b064a5c..6e52514ac2b4 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlock.java @@ -163,6 +163,12 @@ public boolean mayHaveNull() return getBlock().mayHaveNull(); } + @Override + public boolean hasNull() + { + return getBlock().hasNull(); + } + public Block getBlock() { return lazyData.getTopLevelBlock(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlock.java index a181ace9ba5e..ea11602d8bed 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlock.java @@ -135,6 +135,20 @@ public boolean mayHaveNull() return valueIsNull != null; } + @Override + public boolean hasNull() + { + if (valueIsNull == null) { + return false; + } + for (int i = 0; i < positionCount; i++) { + if (valueIsNull[i + arrayOffset]) { + return true; + } + } + return false; + } + @Override public boolean isNull(int position) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/MapBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/MapBlock.java index d9c5d6658eff..6fc50f7400bc 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/MapBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/MapBlock.java @@ -231,6 +231,20 @@ public boolean mayHaveNull() return mapIsNull != null; } + @Override + public boolean hasNull() + { + if (mapIsNull == null) { + return false; + } + for (int i = 0; i < positionCount; i++) { + if (mapIsNull[i + startOffset]) { + return true; + } + } + return false; + } + @Override public int getPositionCount() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/RowBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/RowBlock.java index bc7bab51c1e2..ec909b80172d 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/RowBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/RowBlock.java @@ -152,6 +152,20 @@ public boolean mayHaveNull() return rowIsNull != null; } + @Override + public boolean hasNull() + { + if (rowIsNull == null) { + return false; + } + for (int i = 0; i < positionCount; i++) { + if (rowIsNull[i]) { + return true; + } + } + return false; + } + boolean[] getRawRowIsNull() { return rowIsNull; diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/RunLengthEncodedBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/RunLengthEncodedBlock.java index cd03ac2a46dc..a10b94eea451 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/RunLengthEncodedBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/RunLengthEncodedBlock.java @@ -196,6 +196,12 @@ public boolean mayHaveNull() return positionCount > 0 && value.isNull(0); } + @Override + public boolean hasNull() + { + return mayHaveNull(); + } + @Override public boolean isNull(int position) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlock.java index 811572f931fc..e83aae77df8e 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlock.java @@ -135,6 +135,20 @@ public boolean mayHaveNull() return valueIsNull != null; } + @Override + public boolean hasNull() + { + if (valueIsNull == null) { + return false; + } + for (int i = 0; i < positionCount; i++) { + if (valueIsNull[i + arrayOffset]) { + return true; + } + } + return false; + } + @Override public boolean isNull(int position) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlock.java index 6cd81a848927..c5511bbf8a13 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlock.java @@ -191,6 +191,20 @@ public boolean mayHaveNull() return valueIsNull != null; } + @Override + public boolean hasNull() + { + if (valueIsNull == null) { + return false; + } + for (int i = 0; i < positionCount; i++) { + if (valueIsNull[i + arrayOffset]) { + return true; + } + } + return false; + } + @Override public boolean isNull(int position) { diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlVectorType.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlVectorType.java index 3b0b21ef5778..b11aa808e5e6 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlVectorType.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlVectorType.java @@ -354,8 +354,8 @@ void testCosineDistanceUnsupportedPushdown() .isNotFullyPushedDown(ProjectNode.class); assertThat(query("SELECT id FROM " + table.getName() + " ORDER BY cosine_distance(v, ARRAY[REAL 'NaN']) LIMIT 1")) .isNotFullyPushedDown(ProjectNode.class); - assertQueryFails("SELECT id FROM " + table.getName() + " ORDER BY cosine_distance(v, ARRAY[CAST(NULL AS REAL)]) LIMIT 1", - "Vector magnitude cannot be zero"); + assertThat(query("SELECT id FROM " + table.getName() + " ORDER BY cosine_distance(v, ARRAY[CAST(NULL AS REAL)]) LIMIT 1")) + .isNotFullyPushedDown(ProjectNode.class); assertThat(query("SELECT id FROM " + table.getName() + " ORDER BY cosine_distance(v, NULL) LIMIT 1")) .isNotFullyPushedDown(ProjectNode.class); }