Skip to content

Commit 33da2b1

Browse files
committed
Add arrayType, rowType, mapType and varbinary support to format()
1 parent 8f376af commit 33da2b1

3 files changed

Lines changed: 178 additions & 6 deletions

File tree

core/trino-main/src/main/java/io/trino/operator/scalar/FormatFunction.java

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
*/
1414
package io.trino.operator.scalar;
1515

16+
import com.fasterxml.jackson.core.io.JsonStringEncoder;
1617
import com.google.common.collect.ImmutableList;
1718
import io.airlift.slice.Slice;
1819
import io.trino.annotation.UsedByGeneratedCode;
1920
import io.trino.metadata.SqlScalarFunction;
2021
import io.trino.spi.TrinoException;
2122
import io.trino.spi.block.Block;
23+
import io.trino.spi.block.SqlMap;
2224
import io.trino.spi.block.SqlRow;
2325
import io.trino.spi.connector.ConnectorSession;
2426
import io.trino.spi.function.BoundSignature;
@@ -28,21 +30,26 @@
2830
import io.trino.spi.function.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder;
2931
import io.trino.spi.function.FunctionMetadata;
3032
import io.trino.spi.function.Signature;
33+
import io.trino.spi.type.ArrayType;
3134
import io.trino.spi.type.CharType;
3235
import io.trino.spi.type.DecimalType;
3336
import io.trino.spi.type.Int128;
37+
import io.trino.spi.type.MapType;
3438
import io.trino.spi.type.RowType;
3539
import io.trino.spi.type.TimeType;
3640
import io.trino.spi.type.TimestampType;
3741
import io.trino.spi.type.TimestampWithTimeZoneType;
3842
import io.trino.spi.type.Type;
3943
import io.trino.spi.type.TypeSignature;
44+
import io.trino.spi.type.UuidType;
45+
import io.trino.spi.type.VarbinaryType;
4046
import io.trino.spi.type.VarcharType;
4147

4248
import java.lang.invoke.MethodHandle;
4349
import java.math.BigDecimal;
4450
import java.time.LocalDate;
4551
import java.time.LocalTime;
52+
import java.util.Base64;
4653
import java.util.IllegalFormatException;
4754
import java.util.List;
4855
import java.util.function.BiFunction;
@@ -124,7 +131,21 @@ private static void addDependencies(FunctionDependencyDeclarationBuilder builder
124131
type instanceof TimeType ||
125132
type instanceof DecimalType ||
126133
type instanceof VarcharType ||
127-
type instanceof CharType) {
134+
type instanceof CharType ||
135+
type instanceof VarbinaryType) {
136+
return;
137+
}
138+
if (type instanceof ArrayType arrayType) {
139+
addDependencies(builder, arrayType.getElementType());
140+
return;
141+
}
142+
if (type instanceof MapType mapType) {
143+
addDependencies(builder, mapType.getKeyType());
144+
addDependencies(builder, mapType.getValueType());
145+
return;
146+
}
147+
if (type instanceof RowType rowType) {
148+
rowType.getTypeParameters().forEach(t -> addDependencies(builder, t));
128149
return;
129150
}
130151

@@ -236,6 +257,18 @@ private static BiFunction<Block, Integer, Object> valueConverter(FunctionDepende
236257
if (type instanceof CharType charType) {
237258
return (block, position) -> padSpaces(charType.getSlice(block, position), charType).toStringUtf8();
238259
}
260+
if (type instanceof RowType rowType) {
261+
return (block, position) -> rowToString(functionDependencies, rowType, (SqlRow) rowType.getObject(block, position));
262+
}
263+
if (type instanceof MapType mapType) {
264+
return (block, position) -> mapToString(functionDependencies, mapType, (SqlMap) mapType.getObject(block, position));
265+
}
266+
if (type instanceof ArrayType arrayType) {
267+
return (block, position) -> arrayToString(functionDependencies, arrayType, (Block) arrayType.getObject(block, position));
268+
}
269+
if (type instanceof VarbinaryType varbinaryType) {
270+
return (block, position) -> Base64.getEncoder().encodeToString(varbinaryType.getSlice(block, position).getBytes());
271+
}
239272

240273
BiFunction<Block, Integer, Object> function;
241274
if (type.getJavaType() == long.class) {
@@ -258,6 +291,63 @@ else if (type.getJavaType() == Slice.class) {
258291
return (block, position) -> convertToString(handle, function.apply(block, position));
259292
}
260293

294+
private static Object quotedValue(FunctionDependencies functionDependencies, Type type, Block block, int position)
295+
{
296+
Object value = FormatFunction.converter(functionDependencies, type).apply(block, position);
297+
if (value != null && (
298+
type instanceof VarcharType ||
299+
type instanceof CharType ||
300+
type instanceof VarbinaryType ||
301+
type instanceof UuidType)) {
302+
return String.format("\"%s\"", new String(JsonStringEncoder.getInstance().quoteAsString((String) value)));
303+
}
304+
return value;
305+
}
306+
307+
private static String rowToString(FunctionDependencies functionDependencies, RowType rowType, SqlRow row)
308+
{
309+
List<RowType.Field> fields = rowType.getFields();
310+
boolean hasAllFieldNames = fields.stream().allMatch(field -> field.getName().isPresent());
311+
StringBuilder builder = new StringBuilder(hasAllFieldNames ? "{" : "[");
312+
int rawIndex = row.getRawIndex();
313+
for (int i = 0; i < fields.size(); i++) {
314+
builder.append(i == 0 ? "" : ", ");
315+
if (hasAllFieldNames) {
316+
String fieldName = fields.get(i).getName().get();
317+
builder.append('"').append(new String(JsonStringEncoder.getInstance().quoteAsString(fieldName))).append("\": ");
318+
}
319+
builder.append(quotedValue(functionDependencies, fields.get(i).getType(), row.getRawFieldBlock(i), rawIndex));
320+
}
321+
return builder.append(hasAllFieldNames ? '}' : ']').toString();
322+
}
323+
324+
private static String mapToString(FunctionDependencies functionDependencies, MapType mapType, SqlMap sqlMap)
325+
{
326+
StringBuilder builder = new StringBuilder("{");
327+
Block keys = sqlMap.getRawKeyBlock();
328+
Block values = sqlMap.getRawValueBlock();
329+
int rawOffset = sqlMap.getRawOffset();
330+
for (int i = 0; i < sqlMap.getSize(); i++) {
331+
builder
332+
.append(i == 0 ? "" : ", ")
333+
.append(quotedValue(functionDependencies, mapType.getKeyType(), keys, rawOffset + i))
334+
.append(": ")
335+
.append(quotedValue(functionDependencies, mapType.getValueType(), values, rawOffset + i));
336+
}
337+
return builder.append('}').toString();
338+
}
339+
340+
private static String arrayToString(FunctionDependencies functionDependencies, ArrayType arrayType, Block elementBlock)
341+
{
342+
StringBuilder builder = new StringBuilder("[");
343+
for (int i = 0; i < elementBlock.getPositionCount(); i++) {
344+
builder
345+
.append(i == 0 ? "" : ", ")
346+
.append(quotedValue(functionDependencies, arrayType.getElementType(), elementBlock, i));
347+
}
348+
return builder.append(']').toString();
349+
}
350+
261351
private static LocalTime toLocalTime(long value)
262352
{
263353
long nanoOfDay = roundDiv(value, PICOSECONDS_PER_NANOSECOND);

core/trino-main/src/test/java/io/trino/operator/scalar/TestFormatFunction.java

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
import java.util.Arrays;
2626

27-
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
2827
import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH;
2928
import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy;
3029
import static org.assertj.core.api.Assertions.assertThat;
@@ -149,6 +148,78 @@ public void testFormat()
149148
assertThat(format("%s", "cast('test' AS char(5))"))
150149
.isEqualTo("test ");
151150

151+
assertThat(format("%s", "cast(row('hello', 'world') AS row(greeting varchar, planet varchar))"))
152+
.isEqualTo("{\"greeting\": \"hello\", \"planet\": \"world\"}");
153+
154+
assertThat(format("%s", "row('hello', 'world')"))
155+
.isEqualTo("[\"hello\", \"world\"]");
156+
157+
assertThat(format("%s", "cast(row('hello', array['world']) AS row(greeting varchar, planet array(varchar)))"))
158+
.isEqualTo("{\"greeting\": \"hello\", \"planet\": [\"world\"]}");
159+
160+
assertThat(format("%s", "cast(row('hello', 1337) AS row(greeting varchar, planet integer))"))
161+
.isEqualTo("{\"greeting\": \"hello\", \"planet\": 1337}");
162+
163+
assertThat(format("%s", "cast(row('hello', from_base64('d29ybGQ=')) AS row(greeting varchar, planet varbinary))"))
164+
.isEqualTo("{\"greeting\": \"hello\", \"planet\": \"d29ybGQ=\"}");
165+
166+
assertThat(format("%s", "ARRAY['hello', 'world']"))
167+
.isEqualTo("[\"hello\", \"world\"]");
168+
169+
assertThat(format("%s", "ARRAY['hel\"l\\o\nworld']"))
170+
.isEqualTo("[\"hel\\\"l\\\\o\\nworld\"]");
171+
172+
assertThat(format("%s", "ARRAY[1, 2, 3]"))
173+
.isEqualTo("[1, 2, 3]");
174+
175+
assertThat(format("%s", "ARRAY[TRUE, FALSE]"))
176+
.isEqualTo("[true, false]");
177+
178+
assertThat(format("%s", "from_base64('d29ybGQ=')"))
179+
.isEqualTo("d29ybGQ=");
180+
181+
assertThat(format("%s", "map(ARRAY['greeting', 'planet'], ARRAY['hello', 'world'])"))
182+
.isEqualTo("{\"greeting\": \"hello\", \"planet\": \"world\"}");
183+
184+
assertThat(format("%s", "map(ARRAY['greeting', 'planet'], ARRAY[1, 2])"))
185+
.isEqualTo("{\"greeting\": 1, \"planet\": 2}");
186+
187+
assertThat(format("%s", "ARRAY[null, 'world']"))
188+
.isEqualTo("[null, \"world\"]");
189+
190+
assertThat(format("%s", "cast(row(null, 'world') AS row(greeting varchar, planet varchar))"))
191+
.isEqualTo("{\"greeting\": null, \"planet\": \"world\"}");
192+
193+
assertThat(format("%s", "map(ARRAY['greeting'], ARRAY[null])"))
194+
.isEqualTo("{\"greeting\": null}");
195+
196+
assertThat(format("%s", "ARRAY[cast('hello' AS char(5))]"))
197+
.isEqualTo("[\"hello\"]");
198+
199+
assertThat(format("%s", "cast(row(cast('hi' AS char(3))) AS row(greeting char(3)))"))
200+
.isEqualTo("{\"greeting\": \"hi \"}");
201+
202+
assertThat(format("%s", "ARRAY[ARRAY['a', 'b'], ARRAY['c']]"))
203+
.isEqualTo("[[\"a\", \"b\"], [\"c\"]]");
204+
205+
assertThat(format("%s", "map(ARRAY['nums'], ARRAY[ARRAY[1, 2]])"))
206+
.isEqualTo("{\"nums\": [1, 2]}");
207+
208+
assertThat(format("%s", "map(ARRAY[1, 2], ARRAY['hello', 'world'])"))
209+
.isEqualTo("{1: \"hello\", 2: \"world\"}");
210+
211+
assertThat(format("%s", "ARRAY[uuid '03780fd9-76cf-4366-b720-0cfc6b957e8f']"))
212+
.isEqualTo("[\"03780fd9-76cf-4366-b720-0cfc6b957e8f\"]");
213+
214+
assertThat(format("%s", "row(uuid '03780fd9-76cf-4366-b720-0cfc6b957e8f')"))
215+
.isEqualTo("[\"03780fd9-76cf-4366-b720-0cfc6b957e8f\"]");
216+
217+
assertThat(format("%s", "cast(row(uuid '03780fd9-76cf-4366-b720-0cfc6b957e8f') AS row(id uuid))"))
218+
.isEqualTo("{\"id\": \"03780fd9-76cf-4366-b720-0cfc6b957e8f\"}");
219+
220+
assertThat(format("%s", "map(ARRAY['id'], ARRAY[uuid '03780fd9-76cf-4366-b720-0cfc6b957e8f'])"))
221+
.isEqualTo("{\"id\": \"03780fd9-76cf-4366-b720-0cfc6b957e8f\"}");
222+
152223
assertTrinoExceptionThrownBy(format("%.4d", "8")::evaluate)
153224
.hasMessage("Invalid format string: %.4d (IllegalFormatPrecision: 4)");
154225

@@ -176,10 +247,6 @@ public void testFormat()
176247
assertTrinoExceptionThrownBy(format("%tT", "current_time")::evaluate)
177248
.hasMessage("Invalid format string: %tT (IllegalFormatConversion: T != java.lang.String)");
178249

179-
assertTrinoExceptionThrownBy(format("%s", "array[8]")::evaluate)
180-
.hasErrorCode(NOT_SUPPORTED)
181-
.hasMessage("line 1:20: Type not supported for formatting: array(integer)");
182-
183250
assertTrinoExceptionThrownBy(assertions.function("format", "5", "8")::evaluate)
184251
.hasErrorCode(TYPE_MISMATCH)
185252
.hasMessage("line 1:17: Type of first argument to format() must be VARCHAR (actual: integer)");

docs/src/main/sphinx/functions/conversion.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,21 @@ SELECT format('%2$s %3$s %1$s', 'a', 'b', 'c');
4646
4747
SELECT format('%1$tA, %1$tB %1$te, %1$tY', date '2006-07-04');
4848
-- 'Tuesday, July 4, 2006'
49+
50+
SELECT format('%s', cast(row('hello', 'world') AS row(greeting varchar, planet varchar)));
51+
-- '{"greeting": "hello", "planet": "world"}'
52+
53+
SELECT format('%s', row('hello', 'world'));
54+
-- '["hello", "world"]'
55+
56+
SELECT format('%s', ARRAY['hello', 'world']);
57+
-- '["hello", "world"]'
58+
59+
SELECT format('%s', map(ARRAY['greeting', 'planet'], ARRAY['hello', 'world']));
60+
-- '{"greeting": "hello", "planet": "world"}'
61+
62+
SELECT format('%s', from_base64('d29ybGQ='));
63+
-- 'd29ybGQ='
4964
```
5065
:::
5166

0 commit comments

Comments
 (0)