Skip to content

Commit a35b239

Browse files
committed
Add support to NUMBER to Python UDFs
1 parent e791b2c commit a35b239

4 files changed

Lines changed: 85 additions & 1 deletion

File tree

plugin/trino-functions-python/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
<dependency>
7676
<groupId>io.trino</groupId>
7777
<artifactId>trino-wasm-python</artifactId>
78-
<version>3.13-5</version>
78+
<version>3.13-6</version>
7979
</dependency>
8080

8181
<dependency>

plugin/trino-functions-python/src/main/java/io/trino/plugin/functions/python/TrinoType.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ enum TrinoType
3838
JSON(20),
3939
UUID(21),
4040
IPADDRESS(22),
41+
NUMBER(23),
4142
/**/;
4243

4344
private final int id;

plugin/trino-functions-python/src/main/java/io/trino/plugin/functions/python/TrinoTypes.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,19 @@
3535
import io.trino.spi.type.LongTimestamp;
3636
import io.trino.spi.type.LongTimestampWithTimeZone;
3737
import io.trino.spi.type.MapType;
38+
import io.trino.spi.type.NumberType;
3839
import io.trino.spi.type.RealType;
3940
import io.trino.spi.type.RowType;
4041
import io.trino.spi.type.SmallintType;
42+
import io.trino.spi.type.SqlNumber;
4143
import io.trino.spi.type.StandardTypes;
4244
import io.trino.spi.type.TimeType;
4345
import io.trino.spi.type.TimeWithTimeZoneType;
4446
import io.trino.spi.type.TimeZoneKey;
4547
import io.trino.spi.type.TimestampType;
4648
import io.trino.spi.type.TimestampWithTimeZoneType;
4749
import io.trino.spi.type.TinyintType;
50+
import io.trino.spi.type.TrinoNumber;
4851
import io.trino.spi.type.Type;
4952
import io.trino.spi.type.VarcharType;
5053

@@ -170,6 +173,7 @@ private static TrinoType singletonType(Type type)
170173
case StandardTypes.JSON -> TrinoType.JSON;
171174
case StandardTypes.UUID -> TrinoType.UUID;
172175
case StandardTypes.IPADDRESS -> TrinoType.IPADDRESS;
176+
case StandardTypes.NUMBER -> TrinoType.NUMBER;
173177
default -> throw new TrinoException(NOT_SUPPORTED, "Unsupported type: " + type);
174178
};
175179
}
@@ -202,6 +206,15 @@ private static void javaToBinary(Type type, Object value, SliceOutput output)
202206
: Decimals.toString((Int128) value, decimalType.getScale());
203207
writeVariableSlice(utf8Slice(decimalString), output);
204208
}
209+
case NumberType _ -> {
210+
TrinoNumber number = (TrinoNumber) value;
211+
String stringValue = switch (number.toBigDecimal()) {
212+
case TrinoNumber.BigDecimalValue(BigDecimal bigDecimal) -> bigDecimal.toString();
213+
case TrinoNumber.Infinity(boolean negative) -> negative ? "-Infinity" : "+Infinity";
214+
case TrinoNumber.NotANumber() -> "NaN";
215+
};
216+
writeVariableSlice(utf8Slice(stringValue), output);
217+
}
205218
case TimeWithTimeZoneType timeType -> {
206219
if (timeType.isShort()) {
207220
long time = (long) value;
@@ -284,6 +297,15 @@ private static void blockToBinary(Type type, Block block, int position, SliceOut
284297
: Decimals.toString((Int128) decimalType.getObject(block, position), decimalType.getScale());
285298
writeVariableSlice(utf8Slice(decimalString), output);
286299
}
300+
case NumberType numberType -> {
301+
SqlNumber value = (SqlNumber) numberType.getObjectValue(block, position);
302+
String stringValue = switch (value.value()) {
303+
case TrinoNumber.BigDecimalValue(BigDecimal bigDecimal) -> bigDecimal.toString();
304+
case TrinoNumber.Infinity(boolean negative) -> negative ? "-Infinity" : "+Infinity";
305+
case TrinoNumber.NotANumber() -> "NaN";
306+
};
307+
writeVariableSlice(utf8Slice(stringValue), output);
308+
}
287309
case DateType dateType -> output.writeInt(dateType.getInt(block, position));
288310
case TimeType timeType -> output.writeLong(picosToMicros(timeType.getLong(block, position)));
289311
case TimeWithTimeZoneType timeType -> {
@@ -394,6 +416,16 @@ public static Object binaryToJava(Type type, SliceInput input)
394416
e);
395417
}
396418
}
419+
case NumberType _ -> {
420+
String stringUtf8 = input.readSlice(input.readInt()).toStringUtf8();
421+
TrinoNumber.AsBigDecimal number = switch (stringUtf8) {
422+
case "NaN" -> new TrinoNumber.NotANumber();
423+
case "+Infinity" -> new TrinoNumber.Infinity(false);
424+
case "-Infinity" -> new TrinoNumber.Infinity(true);
425+
default -> new TrinoNumber.BigDecimalValue(new BigDecimal(stringUtf8));
426+
};
427+
yield TrinoNumber.from(number);
428+
}
397429
case TimeType timeType -> {
398430
long micros = roundMicros(input.readLong(), timeType.getPrecision()) % MICROSECONDS_PER_DAY;
399431
yield micros * PICOSECONDS_PER_MICROSECOND;

plugin/trino-functions-python/src/test/java/io/trino/plugin/functions/python/TestPythonFunctions.java

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME;
3636
import static io.trino.testing.TestingSession.testSessionBuilder;
3737
import static org.assertj.core.api.Assertions.assertThat;
38+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
3839
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
3940
import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT;
4041

@@ -751,6 +752,56 @@ SELECT bad_bigint_return()
751752
"TypeError: 'str' object cannot be interpreted as an integer");
752753
}
753754

755+
@Test
756+
public void testTypeNumber()
757+
{
758+
String query =
759+
"""
760+
WITH FUNCTION multiply(x number, y number)
761+
RETURNS number
762+
LANGUAGE PYTHON
763+
WITH (handler = 'multiply')
764+
AS $$
765+
from decimal import Decimal
766+
def multiply(x, y):
767+
return x * y
768+
$$
769+
""";
770+
771+
assertThat(assertions.query(
772+
query + "SELECT multiply(NUMBER '1.12345000000000123456789', NUMBER '2.5432100000000000000000000000000000000000001')"))
773+
.matches("VALUES NUMBER '2.857169274500003139765403527'");
774+
775+
assertThat(assertions.query(
776+
query + "SELECT multiply(NUMBER 'NaN', NUMBER '3.14')"))
777+
.matches("VALUES NUMBER 'NaN'");
778+
779+
assertThat(assertions.query(
780+
query + "SELECT multiply(NUMBER '-Infinity', NUMBER '3.14')"))
781+
.matches("VALUES NUMBER '-Infinity'");
782+
783+
assertThat(assertions.query(
784+
query + "SELECT multiply(NUMBER '+Infinity', NUMBER '3.14')"))
785+
.matches("VALUES NUMBER '+Infinity'");
786+
787+
assertThat(assertions.query(
788+
query + "SELECT multiply(NUMBER '+Infinity', NUMBER '-Infinity')"))
789+
.matches("VALUES NUMBER '-Infinity'");
790+
791+
assertThat(assertions.query(
792+
query + "SELECT multiply(NUMBER '-Infinity', NUMBER '-Infinity')"))
793+
.matches("VALUES NUMBER '+Infinity'");
794+
795+
assertThat(assertions.query(
796+
query + "SELECT multiply(NUMBER '-Infinity', NUMBER 'NaN')"))
797+
.matches("VALUES NUMBER 'NaN'");
798+
799+
assertThatThrownBy(() -> assertThat(assertions.query(
800+
query + "SELECT multiply(NULL, NUMBER '2.54321')"))
801+
.matches("VALUES NUMBER 'NaN'"))
802+
.hasMessageContaining("TypeError: unsupported operand type(s) for *: 'NoneType' and 'decimal.Decimal'");
803+
}
804+
754805
@Test
755806
public void testTypeInteger()
756807
{

0 commit comments

Comments
 (0)