diff --git a/.gitignore b/.gitignore index 451d292d7..20e4540d9 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ *.iml # gradle build .gradle -build \ No newline at end of file +build +out \ No newline at end of file diff --git a/api/src/main/java/com/netflix/iceberg/FileFormat.java b/api/src/main/java/com/netflix/iceberg/FileFormat.java index 71d6d5e44..8814ada8d 100644 --- a/api/src/main/java/com/netflix/iceberg/FileFormat.java +++ b/api/src/main/java/com/netflix/iceberg/FileFormat.java @@ -22,6 +22,7 @@ * Enum of supported file formats. */ public enum FileFormat { + ORC("orc"), PARQUET("parquet"), AVRO("avro"); diff --git a/api/src/main/java/com/netflix/iceberg/UpdateProperties.java b/api/src/main/java/com/netflix/iceberg/UpdateProperties.java index e8159bba5..65cc9baa9 100644 --- a/api/src/main/java/com/netflix/iceberg/UpdateProperties.java +++ b/api/src/main/java/com/netflix/iceberg/UpdateProperties.java @@ -47,4 +47,10 @@ public interface UpdateProperties extends PendingUpdate> { */ UpdateProperties remove(String key); + /** + * Set the default file format for the table. 
+ * @param format + * @return this + */ + UpdateProperties defaultFormat(FileFormat format); } diff --git a/build.gradle b/build.gradle index 16b6476b4..c091c2ff6 100644 --- a/build.gradle +++ b/build.gradle @@ -60,6 +60,7 @@ subprojects { ext { avroVersion = '1.8.2' + orcVersion = '1.4.2' parquetVersion = '1.9.1-SNAPSHOT' jacksonVersion = '2.6.7' @@ -114,6 +115,19 @@ project(':iceberg-core') { } } +project(':iceberg-orc') { + dependencies { + compile project(':iceberg-api') + compile project(':iceberg-core') + + compile "org.apache.orc:orc-core:$orcVersion:nohive" + + compileOnly('org.apache.hadoop:hadoop-client:2.7.3') { + exclude group: 'org.apache.avro', module: 'avro' + } + } +} + project(':iceberg-parquet') { dependencies { compile project(':iceberg-api') @@ -137,6 +151,7 @@ project(':iceberg-spark') { compile project(':iceberg-common') compile project(':iceberg-avro') compile project(':iceberg-core') + compile project(':iceberg-orc') compile project(':iceberg-parquet') compileOnly "org.apache.avro:avro:$avroVersion" @@ -174,10 +189,12 @@ project(':iceberg-runtime') { shadow project(':iceberg-common') shadow project(':iceberg-avro') shadow project(':iceberg-core') + shadow project(':iceberg-orc') shadow project(':iceberg-parquet') shadow project(':iceberg-spark') shadow "org.apache.avro:avro:$avroVersion" + shadow "org.apache.orc:orc-core:$orcVersion:nohive" shadow "org.apache.parquet:parquet-avro:$parquetVersion" } diff --git a/core/src/main/java/com/netflix/iceberg/PropertiesUpdate.java b/core/src/main/java/com/netflix/iceberg/PropertiesUpdate.java index 46295ee50..6044eaaa0 100644 --- a/core/src/main/java/com/netflix/iceberg/PropertiesUpdate.java +++ b/core/src/main/java/com/netflix/iceberg/PropertiesUpdate.java @@ -68,6 +68,12 @@ public UpdateProperties remove(String key) { return this; } + @Override + public UpdateProperties defaultFormat(FileFormat format) { + set(TableProperties.DEFAULT_FILE_FORMAT, format.name()); + return this; + } + @Override public 
Map apply() { this.base = ops.refresh(); diff --git a/orc/src/main/java/com/netflix/iceberg/orc/ColumnIdMap.java b/orc/src/main/java/com/netflix/iceberg/orc/ColumnIdMap.java new file mode 100644 index 000000000..fa266aaa4 --- /dev/null +++ b/orc/src/main/java/com/netflix/iceberg/orc/ColumnIdMap.java @@ -0,0 +1,126 @@ +/* + * Copyright 2018 Hortonworks + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.netflix.iceberg.orc; + +import org.apache.orc.TypeDescription; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Collection; +import java.util.IdentityHashMap; +import java.util.Map; +import java.util.Set; + +/** + * The mapping from ORC's TypeDescription to the Iceberg column ids. + * + * Keep the API limited to Map rather than a concrete type so that we can + * change it later. 
+ */
+public class ColumnIdMap implements Map<TypeDescription, Integer> {
+
+  // IdentityHashMap compares keys by reference (==) rather than equals(),
+  // so each TypeDescription instance gets its own entry.
+  private final IdentityHashMap<TypeDescription, Integer> idMap =
+      new IdentityHashMap<>();
+
+  @Override
+  public int size() {
+    return idMap.size();
+  }
+
+  @Override
+  public boolean isEmpty() {
+    return idMap.isEmpty();
+  }
+
+  @Override
+  public boolean containsKey(Object key) {
+    return idMap.containsKey(key);
+  }
+
+  @Override
+  public boolean containsValue(Object value) {
+    return idMap.containsValue(value);
+  }
+
+  @Override
+  public Integer get(Object key) {
+    return idMap.get(key);
+  }
+
+  @Override
+  public Integer put(TypeDescription key, Integer value) {
+    return idMap.put(key, value);
+  }
+
+  @Override
+  public Integer remove(Object key) {
+    return idMap.remove(key);
+  }
+
+  @Override
+  public void putAll(Map<? extends TypeDescription, ? extends Integer> map) {
+    idMap.putAll(map);
+  }
+
+  @Override
+  public void clear() {
+    idMap.clear();
+  }
+
+  @Override
+  public Set<TypeDescription> keySet() {
+    return idMap.keySet();
+  }
+
+  @Override
+  public Collection<Integer> values() {
+    return idMap.values();
+  }
+
+  @Override
+  public Set<Map.Entry<TypeDescription, Integer>> entrySet() {
+    return idMap.entrySet();
+  }
+
+  /**
+   * Encodes the map as UTF-8 text of the form {@code orcId:icebergId,orcId:icebergId},
+   * where {@code orcId} is the key's {@link TypeDescription#getId()}.
+   *
+   * @return a buffer wrapping the encoded bytes; it has no content when the map is empty
+   */
+  public ByteBuffer serialize() {
+    StringBuilder buffer = new StringBuilder();
+    for (Map.Entry<TypeDescription, Integer> entry : idMap.entrySet()) {
+      // every entry appends at least one character, so a non-empty buffer
+      // means an earlier entry needs a separator
+      if (buffer.length() > 0) {
+        buffer.append(',');
+      }
+      buffer.append(entry.getKey().getId());
+      buffer.append(':');
+      buffer.append(entry.getValue().intValue());
+    }
+    return ByteBuffer.wrap(buffer.toString().getBytes(StandardCharsets.UTF_8));
+  }
+
+  /**
+   * Rebuilds a map from the text produced by {@link #serialize()}.
+   *
+   * @param schema the ORC schema used to resolve each serialized ORC column id
+   *               through {@link TypeDescription#findSubtype(int)}
+   * @param serial the serialized mapping, decoded as UTF-8
+   * @return the reconstructed map; empty when {@code serial} has no content
+   * @throws NumberFormatException if an id in the text is not an integer
+   */
+  public static ColumnIdMap deserialize(TypeDescription schema,
+                                        ByteBuffer serial) {
+    ColumnIdMap result = new ColumnIdMap();
+    String text = StandardCharsets.UTF_8.decode(serial).toString();
+    // An empty map serializes to an empty string. "".split(",") yields one
+    // empty token, which Integer.parseInt rejects, so return before parsing.
+    if (text.isEmpty()) {
+      return result;
+    }
+    for (String part : text.split(",")) {
+      String[] subparts = part.split(":");
+      result.put(schema.findSubtype(Integer.parseInt(subparts[0])),
+          Integer.parseInt(subparts[1]));
+    }
+    return result;
+  }
+}
diff --git a/orc/src/main/java/com/netflix/iceberg/orc/ORC.java
b/orc/src/main/java/com/netflix/iceberg/orc/ORC.java
new file mode 100644
index 000000000..9de6e26e2
--- /dev/null
+++ b/orc/src/main/java/com/netflix/iceberg/orc/ORC.java
@@ -0,0 +1,146 @@
+/*
+ * Copyright 2018 Hortonworks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.netflix.iceberg.orc;
+
+import com.google.common.base.Preconditions;
+import com.netflix.iceberg.Schema;
+import com.netflix.iceberg.hadoop.HadoopInputFile;
+import com.netflix.iceberg.hadoop.HadoopOutputFile;
+import com.netflix.iceberg.io.InputFile;
+import com.netflix.iceberg.io.OutputFile;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.orc.OrcFile;
+import org.apache.orc.Reader;
+import org.apache.orc.TypeDescription;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+
+/**
+ * Entry points for building ORC file appenders and readers.
+ */
+public class ORC {
+  private ORC() {
+  }
+
+  /**
+   * Starts building an appender that writes to the given file.
+   *
+   * @param file the output file; must not be null
+   * @return a new write builder
+   */
+  public static WriteBuilder write(OutputFile file) {
+    return new WriteBuilder(file);
+  }
+
+  /**
+   * Collects the schema, configuration overrides and user metadata needed
+   * to create an {@link OrcFileAppender}.
+   */
+  public static class WriteBuilder {
+    private final OutputFile file;
+    private final Configuration conf;
+    private final Map<String, byte[]> metadata = new HashMap<>();
+    private Schema schema = null;
+
+    private WriteBuilder(OutputFile file) {
+      // same guard as ReadBuilder, so a null file fails here with a message
+      Preconditions.checkNotNull(file, "Output file cannot be null");
+      this.file = file;
+      // start from the file's own Hadoop configuration when it has one
+      if (file instanceof HadoopOutputFile) {
+        conf = new Configuration(((HadoopOutputFile) file).getConf());
+      } else {
+        conf = new Configuration();
+      }
+    }
+
+    /**
+     * Adds a user metadata entry; the value is stored as UTF-8 bytes.
+     *
+     * @param property the metadata key
+     * @param value the metadata value
+     * @return this builder for method chaining
+     */
+    public WriteBuilder metadata(String property, String value) {
+      metadata.put(property, value.getBytes(StandardCharsets.UTF_8));
+      return this;
+    }
+
+    /**
+     * Sets a property on the configuration used to create the writer options.
+     *
+     * @param property the configuration key
+     * @param value the configuration value
+     * @return this builder for method chaining
+     */
+    public WriteBuilder config(String property, String value) {
+      conf.set(property, value);
+      return this;
+    }
+
+    /**
+     * Sets the Iceberg schema of the rows to be written.
+     *
+     * @param schema the Iceberg schema
+     * @return this builder for method chaining
+     */
+    public WriteBuilder schema(Schema schema) {
+      this.schema = schema;
+      return this;
+    }
+
+    /**
+     * Creates the appender.
+     *
+     * @return a new appender for the configured file
+     * @throws NullPointerException if no schema was set
+     */
+    public OrcFileAppender build() {
+      // same message as ReadBuilder.build; without it the failure would be
+      // an unexplained NullPointerException during schema conversion
+      Preconditions.checkNotNull(schema, "Schema is required");
+      OrcFile.WriterOptions options =
+          OrcFile.writerOptions(conf);
+      return new OrcFileAppender(schema, file, options, metadata);
+    }
+  }
+
+  /**
+   * Starts building a reader over the given file.
+   *
+   * @param file the input file; must not be null
+   * @return a new read builder
+   */
+  public static ReadBuilder read(InputFile file) {
+    return new ReadBuilder(file);
+  }
+
+  /**
+   * Collects the schema, configuration overrides and optional byte range
+   * needed to create an {@link OrcIterator}.
+   */
+  public static class ReadBuilder {
+    private final InputFile file;
+    private final Configuration conf;
+    private Schema schema = null;
+    // both stay null unless split(...) is called, meaning no range restriction
+    private Long start = null;
+    private Long length = null;
+
+    private ReadBuilder(InputFile file) {
+      Preconditions.checkNotNull(file, "Input file cannot be null");
+      this.file = file;
+      if (file instanceof HadoopInputFile) {
+        conf = new Configuration(((HadoopInputFile) file).getConf());
+      } else {
+        conf = new Configuration();
+      }
+    }
+
+    /**
+     * Restricts the read to the given range: [start, start + length).
+     *
+     * @param start the start position for this read
+     * @param length the length of the range this read should scan
+     * @return this builder for method chaining
+     */
+    public ReadBuilder split(long start, long length) {
+      this.start = start;
+      this.length = length;
+      return this;
+    }
+
+    /**
+     * Sets the Iceberg schema to read; it is converted to an ORC schema in
+     * {@link #build()}.
+     *
+     * @param schema the Iceberg schema
+     * @return this builder for method chaining
+     */
+    public ReadBuilder schema(Schema schema) {
+      this.schema = schema;
+      return this;
+    }
+
+    /**
+     * Sets a property on the configuration used to create the reader options.
+     *
+     * @param property the configuration key
+     * @param value the configuration value
+     * @return this builder for method chaining
+     */
+    public ReadBuilder config(String property, String value) {
+      conf.set(property, value);
+      return this;
+    }
+
+    /**
+     * Opens the file and creates the batch iterator.
+     *
+     * @return an iterator over the file's row batches
+     * @throws NullPointerException if no schema was set
+     * @throws RuntimeException wrapping the IOException if the file cannot be opened
+     */
+    public OrcIterator build() {
+      Preconditions.checkNotNull(schema, "Schema is required");
+      try {
+        Path path = new Path(file.location());
+        Reader reader = OrcFile.createReader(path, OrcFile.readerOptions(conf));
+        ColumnIdMap columnIds = new ColumnIdMap();
+        TypeDescription orcSchema = TypeConversion.toOrc(schema, columnIds);
+        Reader.Options options = reader.options();
+        // start and length are only ever set together, by split(...)
+        if (start != null) {
+          options.range(start, length);
+        }
+        options.schema(orcSchema);
+        return new OrcIterator(path, orcSchema, reader.rows(options));
+      } catch (IOException e) {
+        throw new RuntimeException("Can't open " + file.location(), e);
+      }
+    }
+  }
+}
diff --git a/orc/src/main/java/com/netflix/iceberg/orc/OrcFileAppender.java b/orc/src/main/java/com/netflix/iceberg/orc/OrcFileAppender.java
new file mode 100644
index 000000000..6ad82ee4d
--- /dev/null
+++ b/orc/src/main/java/com/netflix/iceberg/orc/OrcFileAppender.java
@@ -0,0 +1,108 @@
+/*
+ * Copyright 2018 Hortonworks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.netflix.iceberg.orc;
+
+import com.netflix.iceberg.Metrics;
+import com.netflix.iceberg.Schema;
+import com.netflix.iceberg.io.FileAppender;
+import com.netflix.iceberg.io.OutputFile;
+import org.apache.hadoop.fs.Path;
+import org.apache.orc.ColumnStatistics;
+import org.apache.orc.TypeDescription;
+import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch;
+import org.apache.orc.OrcFile;
+import org.apache.orc.Writer;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Create a file appender for ORC.
+ *
+ * Rows are appended as ORC {@link VectorizedRowBatch}es. The mapping from
+ * ORC columns to Iceberg field ids is stored in the file's user metadata
+ * under {@link #COLUMN_NUMBERS_ATTRIBUTE}.
+ */
+public class OrcFileAppender implements FileAppender<VectorizedRowBatch> {
+  private final Writer writer;
+  private final TypeDescription orcSchema;
+  // filled in by TypeConversion.toOrc in the constructor
+  private final ColumnIdMap columnIds = new ColumnIdMap();
+  private final Path path;
+
+  public static final String COLUMN_NUMBERS_ATTRIBUTE = "iceberg.column.ids";
+
+  /**
+   * Converts the schema, creates the ORC writer and records the user metadata.
+   *
+   * @param schema the Iceberg schema of the rows to write
+   * @param file the output file; only its location is used
+   * @param options ORC writer options; the converted schema is set on them
+   * @param metadata extra user metadata entries to store in the file
+   * @throws RuntimeException wrapping the IOException if the writer cannot be created
+   */
+  OrcFileAppender(Schema schema,
+                  OutputFile file,
+                  OrcFile.WriterOptions options,
+                  Map<String, byte[]> metadata) {
+    orcSchema = TypeConversion.toOrc(schema, columnIds);
+    options.setSchema(orcSchema);
+    path = new Path(file.location());
+    try {
+      writer = OrcFile.createWriter(path, options);
+    } catch (IOException e) {
+      throw new RuntimeException("Can't create file " + path, e);
+    }
+    writer.addUserMetadata(COLUMN_NUMBERS_ATTRIBUTE, columnIds.serialize());
+    metadata.forEach(
+        (key, value) -> writer.addUserMetadata(key, ByteBuffer.wrap(value)));
+  }
+
+  /**
+   * Appends one batch of rows to the file.
+   *
+   * @param datum the batch to write
+   * @throws RuntimeException wrapping the IOException if the write fails
+   */
+  @Override
+  public void add(VectorizedRowBatch datum) {
+    try {
+      writer.addRowBatch(datum);
+    } catch (IOException e) {
+      throw new RuntimeException("Problem writing to ORC file " + path, e);
+    }
+  }
+
+  /**
+   * Builds Iceberg metrics from the writer's row count and column statistics.
+   *
+   * @return metrics with the row count, value counts for every column that
+   *         has an Iceberg id, and null counts for the top-level columns
+   * @throws RuntimeException wrapping the IOException if statistics cannot be read
+   */
+  @Override
+  public Metrics metrics() {
+    try {
+      long rows = writer.getNumberOfRows();
+      ColumnStatistics[] stats = writer.getStatistics();
+      // we don't currently have columnSizes or distinct counts.
+      Map<Integer, Long> valueCounts = new HashMap<>();
+      Map<Integer, Long> nullCounts = new HashMap<>();
+      // index the Iceberg field ids by ORC column id, the same index used for stats
+      Integer[] icebergIds = new Integer[orcSchema.getMaximumId() + 1];
+      for (TypeDescription type : columnIds.keySet()) {
+        icebergIds[type.getId()] = columnIds.get(type);
+      }
+      // NOTE(review): starts at 1; index 0 is presumably the root struct,
+      // which gets no Iceberg id - confirm against ORC's column numbering
+      for (int c = 1; c < stats.length; ++c) {
+        if (icebergIds[c] != null) {
+          valueCounts.put(icebergIds[c], stats[c].getNumberOfValues());
+        }
+      }
+      // NOTE(review): null counts cover only the direct children of the root,
+      // presumably because rows - values is only meaningful for columns with
+      // one value per row - confirm
+      for (TypeDescription child : orcSchema.getChildren()) {
+        int c = child.getId();
+        if (icebergIds[c] != null) {
+          nullCounts.put(icebergIds[c], rows - stats[c].getNumberOfValues());
+        }
+      }
+      return new Metrics(rows, null, valueCounts, nullCounts);
+    } catch (IOException e) {
+      throw new RuntimeException("Can't get statistics " + path, e);
+    }
+  }
+
+  @Override
+  public void close() throws IOException {
+    writer.close();
+  }
+
+  /**
+   * @return the ORC schema converted from the Iceberg schema given to the constructor
+   */
+  public TypeDescription getSchema() {
+    return orcSchema;
+  }
+}
diff --git a/orc/src/main/java/com/netflix/iceberg/orc/OrcIterator.java b/orc/src/main/java/com/netflix/iceberg/orc/OrcIterator.java
new file mode 100644
index 000000000..3519b8ce9
--- /dev/null
+++ b/orc/src/main/java/com/netflix/iceberg/orc/OrcIterator.java
@@ -0,0 +1,75 @@
+/*
+ * Copyright 2018 Hortonworks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.netflix.iceberg.orc;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.orc.RecordReader;
+import org.apache.orc.TypeDescription;
+import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
+/**
+ * An adaptor so that the ORC RecordReader can be used as an Iterator.
+ * Because the same VectorizedRowBatch is reused on each call to next,
+ * it gets changed when hasNext or next is called.
+ */
+public class OrcIterator implements Iterator<VectorizedRowBatch>, Closeable {
+  private final Path filename;
+  private final RecordReader rows;
+  // the single batch instance that every call to next() returns
+  private final VectorizedRowBatch batch;
+  // true while the batch holds data that next() has not handed out yet
+  private boolean advanced = false;
+
+  /**
+   * @param filename the file being read; used only in error messages
+   * @param schema the ORC schema used to allocate the reusable batch
+   * @param rows the record reader to pull batches from
+   */
+  OrcIterator(Path filename, TypeDescription schema, RecordReader rows) {
+    this.filename = filename;
+    this.rows = rows;
+    this.batch = schema.createRowBatch();
+  }
+
+  @Override
+  public void close() throws IOException {
+    rows.close();
+  }
+
+  /**
+   * Loads the next batch unless an unconsumed one is already loaded.
+   *
+   * @throws RuntimeException wrapping the IOException if the read fails
+   */
+  private void advance() {
+    if (!advanced) {
+      try {
+        rows.nextBatch(batch);
+      } catch (IOException e) {
+        throw new RuntimeException("Problem reading ORC file " + filename, e);
+      }
+      advanced = true;
+    }
+  }
+
+  /**
+   * @return true if another non-empty batch is available; this may overwrite
+   *         the batch returned by the previous call to next
+   */
+  @Override
+  public boolean hasNext() {
+    advance();
+    return batch.size > 0;
+  }
+
+  /**
+   * @return the shared batch, loaded with the next rows
+   * @throws NoSuchElementException if the reader has no more rows
+   */
+  @Override
+  public VectorizedRowBatch next() {
+    // hasNext() loads the next batch if needed; an empty batch means the
+    // reader is exhausted, which the Iterator contract reports by throwing
+    if (!hasNext()) {
+      throw new NoSuchElementException("No more rows in ORC file " + filename);
+    }
+    // mark it as used
+    advanced = false;
+    return batch;
+  }
+}
diff --git a/orc/src/main/java/com/netflix/iceberg/orc/TypeConversion.java b/orc/src/main/java/com/netflix/iceberg/orc/TypeConversion.java
new file mode 100644
index 000000000..41db8ee7b
--- /dev/null
+++ b/orc/src/main/java/com/netflix/iceberg/orc/TypeConversion.java
@@ -0,0 +1,189 @@
+/*
+ * Copyright 2018 Hortonworks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.netflix.iceberg.orc;
+
+import com.netflix.iceberg.Schema;
+import com.netflix.iceberg.types.Type;
+import com.netflix.iceberg.types.Types;
+import org.apache.orc.TypeDescription;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Conversions between Iceberg schemas and ORC type descriptions.
+ */
+public class TypeConversion {
+
+  /**
+   * Convert a given Iceberg schema to ORC.
+   * @param schema the Iceberg schema to convert
+   * @param columnIds an output with the column ids
+   * @return the ORC schema
+   */
+  public static TypeDescription toOrc(Schema schema,
+                                      ColumnIdMap columnIds) {
+    // the root struct has no Iceberg field id, so none is recorded for it
+    return toOrc(null, schema.asStruct(), columnIds);
+  }
+
+  /**
+   * Convert one Iceberg type, and everything nested in it, to ORC.
+   * @param fieldId the Iceberg id of the field holding this type, or null
+   *                when there is none to record
+   * @param type the Iceberg type to convert
+   * @param columnIds an output that receives the id of every converted type
+   *                  that has one
+   * @return the ORC type
+   * @throws IllegalArgumentException if the Iceberg type is not handled
+   */
+  static TypeDescription toOrc(Integer fieldId,
+                               Type type,
+                               ColumnIdMap columnIds) {
+    TypeDescription result;
+    switch (type.typeId()) {
+      case BOOLEAN:
+        result = TypeDescription.createBoolean();
+        break;
+      case INTEGER:
+        result = TypeDescription.createInt();
+        break;
+      case LONG:
+        result = TypeDescription.createLong();
+        break;
+      case FLOAT:
+        result = TypeDescription.createFloat();
+        break;
+      case DOUBLE:
+        result = TypeDescription.createDouble();
+        break;
+      case DATE:
+        result = TypeDescription.createDate();
+        break;
+      case TIME:
+        // Iceberg time is microseconds from midnight (up to 86,400,000,000),
+        // which overflows ORC's 32-bit int, so it is stored as a long
+        result = TypeDescription.createLong();
+        break;
+      case TIMESTAMP:
+        result = TypeDescription.createTimestamp();
+        break;
+      case STRING:
+        result = TypeDescription.createString();
+        break;
+      case UUID:
+      case FIXED:
+      case BINARY:
+        result = TypeDescription.createBinary();
+        break;
+      case DECIMAL: {
+        Types.DecimalType decimal = (Types.DecimalType) type;
+        // NOTE(review): scale is set before precision; presumably each setter
+        // validates against the current value of the other - confirm before
+        // reordering
+        result = TypeDescription.createDecimal()
+            .withScale(decimal.scale())
+            .withPrecision(decimal.precision());
+        break;
+      }
+      case STRUCT: {
+        result = TypeDescription.createStruct();
+        for (Types.NestedField field : type.asStructType().fields()) {
+          result.addField(field.name(), toOrc(field.fieldId(), field.type(), columnIds));
+        }
+        break;
+      }
+      case LIST: {
+        Types.ListType list = (Types.ListType) type;
+        result = TypeDescription.createList(toOrc(list.elementId(), list.elementType(),
+            columnIds));
+        break;
+      }
+      case MAP: {
+        Types.MapType map = (Types.MapType) type;
+        TypeDescription key = toOrc(map.keyId(), map.keyType(), columnIds);
+        result = TypeDescription.createMap(key,
+            toOrc(map.valueId(), map.valueType(), columnIds));
+        break;
+      }
+      default:
+        throw new IllegalArgumentException("Unhandled type " + type.typeId());
+    }
+    if (fieldId != null) {
+      columnIds.put(result, fieldId);
+    }
+    return result;
+  }
+
+  /**
+   * Convert an ORC schema to an Iceberg schema.
+   * @param schema the ORC schema
+   * @param columnIds the column ids
+   * @return the Iceberg schema
+   */
+  public static Schema fromOrc(TypeDescription schema, ColumnIdMap columnIds) {
+    return new Schema(convertOrcToType(schema, columnIds).asStructType().fields());
+  }
+
+  /**
+   * Convert one ORC type, and everything nested in it, to Iceberg. Every
+   * nested field, list element and map value is created as optional.
+   * @param schema the ORC type to convert
+   * @param columnIds the Iceberg id of each nested ORC type
+   * @return the Iceberg type
+   * @throws IllegalArgumentException for types that are not handled and for
+   *         maps whose key is not a string, char or varchar
+   */
+  static Type convertOrcToType(TypeDescription schema, ColumnIdMap columnIds) {
+    switch (schema.getCategory()) {
+      case BOOLEAN:
+        return Types.BooleanType.get();
+      case BYTE:
+      case SHORT:
+      case INT:
+        return Types.IntegerType.get();
+      case LONG:
+        return Types.LongType.get();
+      case FLOAT:
+        return Types.FloatType.get();
+      case DOUBLE:
+        return Types.DoubleType.get();
+      case STRING:
+      case CHAR:
+      case VARCHAR:
+        return Types.StringType.get();
+      case BINARY:
+        return Types.BinaryType.get();
+      case DATE:
+        return Types.DateType.get();
+      case TIMESTAMP:
+        return Types.TimestampType.withoutZone();
+      case DECIMAL:
+        return Types.DecimalType.of(schema.getPrecision(), schema.getScale());
+      case STRUCT: {
+        List<String> fieldNames = schema.getFieldNames();
+        List<TypeDescription> fieldTypes = schema.getChildren();
+        List<Types.NestedField> fields = new ArrayList<>(fieldNames.size());
+        for (int c = 0; c < fieldNames.size(); ++c) {
+          String name = fieldNames.get(c);
+          TypeDescription type = fieldTypes.get(c);
+          fields.add(Types.NestedField.optional(columnIds.get(type), name,
+              convertOrcToType(type, columnIds)));
+        }
+        return Types.StructType.of(fields);
+      }
+      case LIST: {
+        TypeDescription child = schema.getChildren().get(0);
+        return Types.ListType.ofOptional(columnIds.get(child),
+            convertOrcToType(child, columnIds));
+      }
+      case MAP: {
+        TypeDescription key = schema.getChildren().get(0);
+        TypeDescription value = schema.getChildren().get(1);
+        switch (key.getCategory()) {
+          case STRING:
+          case CHAR:
+          case VARCHAR:
+            return Types.MapType.ofOptional(columnIds.get(key),
+                columnIds.get(value), convertOrcToType(value, columnIds));
+          default:
+            throw new IllegalArgumentException("Can't handle maps with " + key +
+                " as key.");
+        }
+      }
+      default:
+        // We don't have an answer for union types.
+        throw new IllegalArgumentException("Can't handle " + schema);
+    }
+  }
+}
diff --git a/settings.gradle b/settings.gradle
index 14d74f95f..42d463ce8 100644
--- a/settings.gradle
+++ b/settings.gradle
@@ -19,6 +19,7 @@ include 'api'
 include 'common'
 include 'avro'
 include 'core'
+include 'orc'
 include 'parquet'
 include 'spark'
 include 'runtime'
@@ -27,6 +28,7 @@ project(':api').name = 'iceberg-api'
 project(':common').name = 'iceberg-common'
 project(':avro').name = 'iceberg-avro'
 project(':core').name = 'iceberg-core'
+project(':orc').name = 'iceberg-orc'
 project(':parquet').name = 'iceberg-parquet'
 project(':spark').name = 'iceberg-spark'
 project(':runtime').name = 'iceberg-runtime'
diff --git a/spark/src/main/java/com/netflix/iceberg/spark/data/SparkOrcReader.java b/spark/src/main/java/com/netflix/iceberg/spark/data/SparkOrcReader.java
new file mode 100644
index 000000000..d3d2ec1d0
--- /dev/null
+++ b/spark/src/main/java/com/netflix/iceberg/spark/data/SparkOrcReader.java
@@ -0,0 +1,871 @@
+/*
+ * Copyright 2018 Hortonworks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.netflix.iceberg.spark.data; + +import com.netflix.iceberg.FileScanTask; +import com.netflix.iceberg.Schema; +import com.netflix.iceberg.io.InputFile; +import com.netflix.iceberg.orc.ColumnIdMap; +import com.netflix.iceberg.orc.ORC; +import com.netflix.iceberg.orc.OrcIterator; +import com.netflix.iceberg.orc.TypeConversion; +import org.apache.orc.TypeDescription; +import org.apache.orc.storage.common.type.FastHiveDecimal; +import org.apache.orc.storage.ql.exec.vector.BytesColumnVector; +import org.apache.orc.storage.ql.exec.vector.ColumnVector; +import org.apache.orc.storage.ql.exec.vector.DecimalColumnVector; +import org.apache.orc.storage.ql.exec.vector.DoubleColumnVector; +import org.apache.orc.storage.ql.exec.vector.ListColumnVector; +import org.apache.orc.storage.ql.exec.vector.LongColumnVector; +import org.apache.orc.storage.ql.exec.vector.MapColumnVector; +import org.apache.orc.storage.ql.exec.vector.StructColumnVector; +import org.apache.orc.storage.ql.exec.vector.TimestampColumnVector; +import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; +import org.apache.orc.storage.serde2.io.DateWritable; +import org.apache.orc.storage.serde2.io.HiveDecimalWritable; +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter; +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.util.SerializableConfiguration; + +import java.io.Closeable; +import java.io.IOException; +import java.math.BigDecimal; +import java.sql.Timestamp; +import 
java.util.Iterator; +import java.util.List; + +/** + * Converts the OrcIterator, which returns ORC's VectorizedRowBatch, to a + * set of Spark's UnsafeRows. + * + * It minimizes allocations by reusing most of the objects in the implementation. + */ +public class SparkOrcReader implements Iterator, Closeable { + private final static int INITIAL_SIZE = 128 * 1024; + private final OrcIterator reader; + private final TypeDescription orcSchema; + private final UnsafeRow row; + private final BufferHolder holder; + private final UnsafeRowWriter writer; + private int nextRow = 0; + private VectorizedRowBatch current = null; + private Converter[] converter; + + public SparkOrcReader(InputFile location, + FileScanTask task, + Schema readSchema, + SerializableConfiguration conf) { + ColumnIdMap columnIds = new ColumnIdMap(); + orcSchema = TypeConversion.toOrc(readSchema, columnIds); + reader = ORC.read(location) + .split(task.start(), task.length()) + .schema(readSchema) + .build(); + int numFields = readSchema.columns().size(); + row = new UnsafeRow(numFields); + holder = new BufferHolder(row, INITIAL_SIZE); + writer = new UnsafeRowWriter(holder, numFields); + converter = new Converter[numFields]; + for(int c=0; c < numFields; ++c) { + converter[c] = buildConverter(holder, orcSchema.getChildren().get(c)); + } + } + + @Override + public boolean hasNext() { + return (current != null && nextRow < current.size) || reader.hasNext(); + } + + @Override + public UnsafeRow next() { + if (current == null || nextRow >= current.size) { + current = reader.next(); + nextRow = 0; + } + // Reset the holder to start the buffer over again. + // BufferHolder.reset does the wrong thing... 
+ holder.cursor = Platform.BYTE_ARRAY_OFFSET; + writer.reset(); + for(int c=0; c < current.cols.length; ++c) { + converter[c].convert(writer, c, current.cols[c], nextRow); + } + nextRow++; + return row; + } + + @Override + public void close() throws IOException { + reader.close(); + } + + static void printRow(SpecializedGetters row, TypeDescription schema) { + List children = schema.getChildren(); + System.out.print("{"); + for(int c = 0; c < children.size(); ++c) { + System.out.print("\"" + schema.getFieldNames().get(c) + "\": "); + printRow(row, c, children.get(c)); + } + System.out.print("}"); + } + + static void printRow(SpecializedGetters row, int ord, TypeDescription schema) { + switch (schema.getCategory()) { + case BOOLEAN: + System.out.print(row.getBoolean(ord)); + break; + case BYTE: + System.out.print(row.getByte(ord)); + break; + case SHORT: + System.out.print(row.getShort(ord)); + break; + case INT: + System.out.print(row.getInt(ord)); + break; + case LONG: + System.out.print(row.getLong(ord)); + break; + case FLOAT: + System.out.print(row.getFloat(ord)); + break; + case DOUBLE: + System.out.print(row.getDouble(ord)); + break; + case CHAR: + case VARCHAR: + case STRING: + System.out.print("\"" + row.getUTF8String(ord) + "\""); + break; + case BINARY: { + byte[] bin = row.getBinary(ord); + if (bin == null) { + System.out.print("null"); + } else { + System.out.print("["); + for (int i = 0; i < bin.length; ++i) { + if (i != 0) { + System.out.print(", "); + } + int v = bin[i] & 0xff; + if (v < 16) { + System.out.print("0" + Integer.toHexString(v)); + } else { + System.out.print(Integer.toHexString(v)); + } + } + System.out.print("]"); + } + break; + } + case DECIMAL: + System.out.print(row.getDecimal(ord, schema.getPrecision(), schema.getScale())); + break; + case DATE: + System.out.print("\"" + new DateWritable(row.getInt(ord)) + "\""); + break; + case TIMESTAMP: + System.out.print("\"" + new Timestamp(row.getLong(ord)) + "\""); + break; + case STRUCT: + 
printRow(row.getStruct(ord, schema.getChildren().size()), schema); + break; + case LIST: { + TypeDescription child = schema.getChildren().get(0); + System.out.print("["); + ArrayData list = row.getArray(ord); + for(int e=0; e < list.numElements(); ++e) { + if (e != 0) { + System.out.print(", "); + } + printRow(list, e, child); + } + System.out.print("]"); + break; + } + case MAP: { + TypeDescription keyType = schema.getChildren().get(0); + TypeDescription valueType = schema.getChildren().get(1); + MapData map = row.getMap(ord); + ArrayData keys = map.keyArray(); + ArrayData values = map.valueArray(); + System.out.print("["); + for(int e=0; e < map.numElements(); ++e) { + if (e != 0) { + System.out.print(", "); + } + printRow(keys, e, keyType); + System.out.print(": "); + printRow(values, e, valueType); + } + System.out.print("]"); + break; + } + default: + throw new IllegalArgumentException("Unhandled type " + schema); + } + } + static int getArrayElementSize(TypeDescription type) { + switch (type.getCategory()) { + case BOOLEAN: + case BYTE: + return 1; + case SHORT: + return 2; + case INT: + case FLOAT: + return 4; + default: + return 8; + } + } + + /** + * The common interface for converting from an ORC ColumnVector to a Spark + * UnsafeRow. UnsafeRows need two different interfaces for writers and thus + * we have two methods: the first is for structs (UnsafeRowWriter) and the + * second is for lists and maps (UnsafeArrayWriter). If Spark adds a common + * interface similar to SpecializedGetters we could use that and a single set of + * methods. 
+ */ + interface Converter { + void convert(UnsafeRowWriter writer, int column, ColumnVector vector, int row); + void convert(UnsafeArrayWriter writer, int element, ColumnVector vector, + int row); + } + + private static class BooleanConverter implements Converter { + @Override + public void convert(UnsafeRowWriter writer, int column, ColumnVector vector, int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNullAt(column); + } else { + writer.write(column, ((LongColumnVector) vector).vector[row] != 0); + } + } + + @Override + public void convert(UnsafeArrayWriter writer, int element, + ColumnVector vector, int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNull(element); + } else { + writer.write(element, ((LongColumnVector) vector).vector[row] != 0); + } + } + } + + private static class ByteConverter implements Converter { + @Override + public void convert(UnsafeRowWriter writer, int column, ColumnVector vector, + int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNullAt(column); + } else { + writer.write(column, (byte) ((LongColumnVector) vector).vector[row]); + } + } + + @Override + public void convert(UnsafeArrayWriter writer, int element, + ColumnVector vector, int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNull(element); + } else { + writer.write(element, (byte) ((LongColumnVector) vector).vector[row]); + } + } + } + + private static class ShortConverter implements Converter { + @Override + public void convert(UnsafeRowWriter writer, int column, ColumnVector vector, + int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNullAt(column); + } else { + writer.write(column, (short) ((LongColumnVector) vector).vector[row]); + } + } + + @Override + public void 
convert(UnsafeArrayWriter writer, int element, + ColumnVector vector, int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNull(element); + } else { + writer.write(element, (short) ((LongColumnVector) vector).vector[row]); + } + } + } + + private static class IntConverter implements Converter { + @Override + public void convert(UnsafeRowWriter writer, int column, ColumnVector vector, + int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNullAt(column); + } else { + writer.write(column, (int) ((LongColumnVector) vector).vector[row]); + } + } + + @Override + public void convert(UnsafeArrayWriter writer, int element, + ColumnVector vector, int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNull(element); + } else { + writer.write(element, (int) ((LongColumnVector) vector).vector[row]); + } + } + } + + private static class LongConverter implements Converter { + @Override + public void convert(UnsafeRowWriter writer, int column, ColumnVector vector, + int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNullAt(column); + } else { + writer.write(column, ((LongColumnVector) vector).vector[row]); + } + } + + @Override + public void convert(UnsafeArrayWriter writer, int element, + ColumnVector vector, int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNull(element); + } else { + writer.write(element, ((LongColumnVector) vector).vector[row]); + } + } + } + + private static class FloatConverter implements Converter { + @Override + public void convert(UnsafeRowWriter writer, int column, ColumnVector vector, + int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNullAt(column); + } else { + writer.write(column, (float) 
((DoubleColumnVector) vector).vector[row]); + } + } + + @Override + public void convert(UnsafeArrayWriter writer, int element, + ColumnVector vector, int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNull(element); + } else { + writer.write(element, (float) ((DoubleColumnVector) vector).vector[row]); + } + } + } + + private static class DoubleConverter implements Converter { + @Override + public void convert(UnsafeRowWriter writer, int column, ColumnVector vector, + int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNullAt(column); + } else { + writer.write(column, ((DoubleColumnVector) vector).vector[row]); + } + } + + @Override + public void convert(UnsafeArrayWriter writer, int element, + ColumnVector vector, int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNull(element); + } else { + writer.write(element, ((DoubleColumnVector) vector).vector[row]); + } + } + } + + private static class TimestampConverter implements Converter { + + private long convert(TimestampColumnVector vector, int row) { + // compute microseconds past 1970. 
+ long micros = (vector.time[row]/1000) * 1_000_000 + vector.nanos[row] / 1000; + return micros; + } + + @Override + public void convert(UnsafeRowWriter writer, int column, ColumnVector vector, + int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNullAt(column); + } else { + writer.write(column, convert((TimestampColumnVector) vector, row)); + } + } + + @Override + public void convert(UnsafeArrayWriter writer, int element, + ColumnVector vector, int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNull(element); + } else { + writer.write(element, convert((TimestampColumnVector) vector, row)); + } + } + } + + /** + * UnsafeArrayWriter doesn't have a binary form that lets the user pass an + * offset and length, so I've added one here. It is the minor tweak of the + * UnsafeArrayWriter.write(int, byte[]) method. + * @param holder the BufferHolder where the bytes are being written + * @param writer the UnsafeArrayWriter + * @param ordinal the element that we are writing into + * @param input the input bytes + * @param offset the first byte from input + * @param length the number of bytes to write + */ + static void write(BufferHolder holder, UnsafeArrayWriter writer, int ordinal, + byte[] input, int offset, int length) { + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(length); + + // grow the global buffer before writing data. + holder.grow(roundedSize); + + if ((length & 0x07) > 0) { + Platform.putLong(holder.buffer, holder.cursor + ((length >> 3) << 3), 0L); + } + + // Write the bytes to the variable length portion. + Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset, + holder.buffer, holder.cursor, length); + + writer.setOffsetAndSize(ordinal, holder.cursor, length); + + // move the cursor forward. 
+ holder.cursor += roundedSize; + } + + private static class BinaryConverter implements Converter { + private final BufferHolder holder; + + BinaryConverter(BufferHolder holder) { + this.holder = holder; + } + + @Override + public void convert(UnsafeRowWriter writer, int column, ColumnVector vector, + int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNullAt(column); + } else { + BytesColumnVector v = (BytesColumnVector) vector; + writer.write(column, v.vector[row], v.start[row], v.length[row]); + } + } + + @Override + public void convert(UnsafeArrayWriter writer, int element, + ColumnVector vector, int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNull(element); + } else { + BytesColumnVector v = (BytesColumnVector) vector; + write(holder, writer, element, v.vector[row], v.start[row], + v.length[row]); + } + } + } + + /** + * This hack is to get the unscaled value (for precision <= 18) quickly. + * This can be replaced when we upgrade to storage-api 2.5.0. 
+ */ + static class DecimalHack extends FastHiveDecimal { + long unscaledLong(FastHiveDecimal value) { + fastSet(value); + return fastSignum * fast1 * 10_000_000_000_000_000L + fast0; + } + } + + private static class Decimal18Converter implements Converter { + final DecimalHack hack = new DecimalHack(); + final int precision; + final int scale; + + Decimal18Converter(int precision, int scale) { + this.precision = precision; + this.scale = scale; + } + + @Override + public void convert(UnsafeRowWriter writer, int column, ColumnVector vector, + int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNullAt(column); + } else { + HiveDecimalWritable v = ((DecimalColumnVector) vector).vector[row]; + writer.write(column, + new Decimal().set(hack.unscaledLong(v), precision, v.scale()), + precision, scale); + } + } + + @Override + public void convert(UnsafeArrayWriter writer, int element, + ColumnVector vector, int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNull(element); + } else { + HiveDecimalWritable v = ((DecimalColumnVector) vector).vector[row]; + writer.write(element, + new Decimal().set(hack.unscaledLong(v), precision, v.scale()), + precision, scale); + } + } + } + + private static class Decimal38Converter implements Converter { + final int precision; + final int scale; + + Decimal38Converter(int precision, int scale) { + this.precision = precision; + this.scale = scale; + } + + @Override + public void convert(UnsafeRowWriter writer, int column, ColumnVector vector, + int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNullAt(column); + } else { + BigDecimal v = ((DecimalColumnVector) vector).vector[row] + .getHiveDecimal().bigDecimalValue(); + writer.write(column, + new Decimal().set(new scala.math.BigDecimal(v), precision, scale), + precision, scale); + } + } + + @Override + public 
void convert(UnsafeArrayWriter writer, int element, + ColumnVector vector, int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNull(element); + } else { + BigDecimal v = ((DecimalColumnVector) vector).vector[row] + .getHiveDecimal().bigDecimalValue(); + writer.write(element, + new Decimal().set(new scala.math.BigDecimal(v), precision, scale), + precision, scale); + } + } + } + + private static class StructConverter implements Converter { + private final BufferHolder holder; + private final Converter[] children; + private final UnsafeRowWriter childWriter; + + StructConverter(BufferHolder holder, TypeDescription schema) { + this.holder = holder; + children = new Converter[schema.getChildren().size()]; + for(int c=0; c < children.length; ++c) { + children[c] = buildConverter(holder, schema.getChildren().get(c)); + } + childWriter = new UnsafeRowWriter(holder, children.length); + } + + int writeStruct(StructColumnVector vector, int row) { + int start = holder.cursor; + childWriter.reset(); + for(int c=0; c < children.length; ++c) { + children[c].convert(childWriter, c, vector.fields[c], row); + } + return start; + } + + @Override + public void convert(UnsafeRowWriter writer, int column, ColumnVector vector, int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNullAt(column); + } else { + int start = writeStruct((StructColumnVector) vector, row); + writer.setOffsetAndSize(column, start, holder.cursor - start); + } + } + + @Override + public void convert(UnsafeArrayWriter writer, int element, + ColumnVector vector, int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNull(element); + } else { + int start = writeStruct((StructColumnVector) vector, row); + writer.setOffsetAndSize(element, start, holder.cursor - start); + } + } + } + + private static class ListConverter implements Converter { + 
private final BufferHolder holder; + private final Converter children; + private final UnsafeArrayWriter childWriter; + private final int elementSize; + + ListConverter(BufferHolder holder, TypeDescription schema) { + this.holder = holder; + TypeDescription child = schema.getChildren().get(0); + children = buildConverter(holder, child); + childWriter = new UnsafeArrayWriter(); + elementSize = getArrayElementSize(child); + } + + int writeList(ListColumnVector v, int row) { + int offset = (int) v.offsets[row]; + int length = (int) v.lengths[row]; + int start = holder.cursor; + childWriter.initialize(holder, length, elementSize); + for(int c=0; c < length; ++c) { + children.convert(childWriter, c, v.child, offset + c); + } + return start; + } + + @Override + public void convert(UnsafeRowWriter writer, int column, ColumnVector vector, + int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNullAt(column); + } else { + int start = writeList((ListColumnVector) vector, row); + writer.setOffsetAndSize(column, start, holder.cursor - start); + } + } + + @Override + public void convert(UnsafeArrayWriter writer, int element, + ColumnVector vector, int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNull(element); + } else { + int start = writeList((ListColumnVector) vector, row); + writer.setOffsetAndSize(element, start, holder.cursor - start); + } + } + } + + private static class MapConverter implements Converter { + private final BufferHolder holder; + private final Converter keyConvert; + private final Converter valueConvert; + private final UnsafeArrayWriter childWriter; + private final int keySize; + private final int valueSize; + + MapConverter(BufferHolder holder, TypeDescription schema) { + this.holder = holder; + TypeDescription keyType = schema.getChildren().get(0); + TypeDescription valueType = schema.getChildren().get(1); + keyConvert = 
buildConverter(holder, keyType); + keySize = getArrayElementSize(keyType); + valueConvert = buildConverter(holder, valueType); + valueSize = getArrayElementSize(valueType); + childWriter = new UnsafeArrayWriter(); + } + + int writeMap(MapColumnVector v, int row) { + int offset = (int) v.offsets[row]; + int length = (int) v.lengths[row]; + int start = holder.cursor; + // save room for the key size + final int KEY_SIZE_BYTES = 8; + holder.grow(KEY_SIZE_BYTES); + holder.cursor += KEY_SIZE_BYTES; + // serialize the keys + childWriter.initialize(holder, length, keySize); + for(int c=0; c < length; ++c) { + keyConvert.convert(childWriter, c, v.keys, offset + c); + } + // store the serialized size of the keys + Platform.putLong(holder.buffer, start, holder.cursor - start - KEY_SIZE_BYTES); + // serialize the values + childWriter.initialize(holder, length, valueSize); + for(int c=0; c < length; ++c) { + valueConvert.convert(childWriter, c, v.values, offset + c); + } + return start; + } + + @Override + public void convert(UnsafeRowWriter writer, int column, ColumnVector vector, + int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNullAt(column); + } else { + int start = writeMap((MapColumnVector) vector, row); + writer.setOffsetAndSize(column, start, holder.cursor - start); + } + } + + @Override + public void convert(UnsafeArrayWriter writer, int element, + ColumnVector vector, int row) { + if (vector.isRepeating) { + row = 0; + } + if (!vector.noNulls && vector.isNull[row]) { + writer.setNull(element); + } else { + int start = writeMap((MapColumnVector) vector, row); + writer.setOffsetAndSize(element, start, holder.cursor - start); + } + } + } + + static Converter buildConverter(BufferHolder holder, TypeDescription schema) { + switch (schema.getCategory()) { + case BOOLEAN: + return new BooleanConverter(); + case BYTE: + return new ByteConverter(); + case SHORT: + return new ShortConverter(); + case DATE: + case 
INT: + return new IntConverter(); + case LONG: + return new LongConverter(); + case FLOAT: + return new FloatConverter(); + case DOUBLE: + return new DoubleConverter(); + case TIMESTAMP: + return new TimestampConverter(); + case DECIMAL: + if (schema.getPrecision() <= Decimal.MAX_LONG_DIGITS()) { + return new Decimal18Converter(schema.getPrecision(), schema.getScale()); + } else { + return new Decimal38Converter(schema.getPrecision(), schema.getScale()); + } + case BINARY: + case STRING: + case CHAR: + case VARCHAR: + return new BinaryConverter(holder); + case STRUCT: + return new StructConverter(holder, schema); + case LIST: + return new ListConverter(holder, schema); + case MAP: + return new MapConverter(holder, schema); + default: + throw new IllegalArgumentException("Unhandled type " + schema); + } + } +} diff --git a/spark/src/main/java/com/netflix/iceberg/spark/data/SparkOrcWriter.java b/spark/src/main/java/com/netflix/iceberg/spark/data/SparkOrcWriter.java new file mode 100644 index 000000000..e8eb4da96 --- /dev/null +++ b/spark/src/main/java/com/netflix/iceberg/spark/data/SparkOrcWriter.java @@ -0,0 +1,444 @@ +/* + * Copyright 2018 Hortonworks + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.netflix.iceberg.spark.data; + +import com.netflix.iceberg.Metrics; +import com.netflix.iceberg.io.FileAppender; +import com.netflix.iceberg.orc.OrcFileAppender; +import org.apache.orc.TypeDescription; +import org.apache.orc.storage.common.type.HiveDecimal; +import org.apache.orc.storage.ql.exec.vector.BytesColumnVector; +import org.apache.orc.storage.ql.exec.vector.ColumnVector; +import org.apache.orc.storage.ql.exec.vector.DecimalColumnVector; +import org.apache.orc.storage.ql.exec.vector.DoubleColumnVector; +import org.apache.orc.storage.ql.exec.vector.ListColumnVector; +import org.apache.orc.storage.ql.exec.vector.LongColumnVector; +import org.apache.orc.storage.ql.exec.vector.MapColumnVector; +import org.apache.orc.storage.ql.exec.vector.StructColumnVector; +import org.apache.orc.storage.ql.exec.vector.TimestampColumnVector; +import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.unsafe.types.UTF8String; + +import java.io.IOException; +import java.sql.Timestamp; +import java.util.List; + +/** + * This class acts as an adaptor from an OrcFileAppender to a + * FileAppender<InternalRow>. + */ +public class SparkOrcWriter implements FileAppender { + private final static int BATCH_SIZE = 1024; + private final VectorizedRowBatch batch; + private final OrcFileAppender writer; + private final Converter[] converters; + + public SparkOrcWriter(OrcFileAppender writer) { + TypeDescription schema = writer.getSchema(); + batch = schema.createRowBatch(BATCH_SIZE); + this.writer = writer; + converters = buildConverters(schema); + } + + /** + * The interface for the conversion from Spark's SpecializedGetters to + * ORC's ColumnVectors. 
+ */ + interface Converter { + /** + * Take a value from the Spark data value and add it to the ORC output. + * @param rowId the row in the ColumnVector + * @param column either the column number or element number + * @param data either an InternalRow or ArrayData + * @param output the ColumnVector to put the value into + */ + void addValue(int rowId, int column, SpecializedGetters data, + ColumnVector output); + } + + static class BooleanConverter implements Converter { + public void addValue(int rowId, int column, SpecializedGetters data, + ColumnVector output) { + if (data.isNullAt(column)) { + output.noNulls = false; + output.isNull[rowId] = true; + } else { + output.isNull[rowId] = false; + ((LongColumnVector) output).vector[rowId] = data.getBoolean(column) ? 1 : 0; + } + } + } + + static class ByteConverter implements Converter { + public void addValue(int rowId, int column, SpecializedGetters data, + ColumnVector output) { + if (data.isNullAt(column)) { + output.noNulls = false; + output.isNull[rowId] = true; + } else { + output.isNull[rowId] = false; + ((LongColumnVector) output).vector[rowId] = data.getByte(column); + } + } + } + + static class ShortConverter implements Converter { + public void addValue(int rowId, int column, SpecializedGetters data, + ColumnVector output) { + if (data.isNullAt(column)) { + output.noNulls = false; + output.isNull[rowId] = true; + } else { + output.isNull[rowId] = false; + ((LongColumnVector) output).vector[rowId] = data.getShort(column); + } + } + } + + static class IntConverter implements Converter { + public void addValue(int rowId, int column, SpecializedGetters data, + ColumnVector output) { + if (data.isNullAt(column)) { + output.noNulls = false; + output.isNull[rowId] = true; + } else { + output.isNull[rowId] = false; + ((LongColumnVector) output).vector[rowId] = data.getInt(column); + } + } + } + + static class LongConverter implements Converter { + public void addValue(int rowId, int column, SpecializedGetters 
data, + ColumnVector output) { + if (data.isNullAt(column)) { + output.noNulls = false; + output.isNull[rowId] = true; + } else { + output.isNull[rowId] = false; + ((LongColumnVector) output).vector[rowId] = data.getLong(column); + } + } + } + + static class FloatConverter implements Converter { + public void addValue(int rowId, int column, SpecializedGetters data, + ColumnVector output) { + if (data.isNullAt(column)) { + output.noNulls = false; + output.isNull[rowId] = true; + } else { + output.isNull[rowId] = false; + ((DoubleColumnVector) output).vector[rowId] = data.getFloat(column); + } + } + } + + static class DoubleConverter implements Converter { + public void addValue(int rowId, int column, SpecializedGetters data, + ColumnVector output) { + if (data.isNullAt(column)) { + output.noNulls = false; + output.isNull[rowId] = true; + } else { + output.isNull[rowId] = false; + ((DoubleColumnVector) output).vector[rowId] = data.getDouble(column); + } + } + } + + static class StringConverter implements Converter { + public void addValue(int rowId, int column, SpecializedGetters data, + ColumnVector output) { + if (data.isNullAt(column)) { + output.noNulls = false; + output.isNull[rowId] = true; + } else { + output.isNull[rowId] = false; + byte[] value = data.getUTF8String(column).getBytes(); + ((BytesColumnVector) output).setRef(rowId, value, 0, value.length); + } + } + } + + static class BytesConverter implements Converter { + public void addValue(int rowId, int column, SpecializedGetters data, + ColumnVector output) { + if (data.isNullAt(column)) { + output.noNulls = false; + output.isNull[rowId] = true; + } else { + output.isNull[rowId] = false; + // getBinary always makes a copy, so we don't need to worry about it + // being changed behind our back. 
+ byte[] value = data.getBinary(column); + ((BytesColumnVector) output).setRef(rowId, value, 0, value.length); + } + } + } + + static class TimestampConverter implements Converter { + // The JDK has a bug where timestamps before 1970 with times like: + // HH:MM:SS.000XXX with non-zero XXX are off by 1 second. + private static final boolean NO_TIMESTAMP_BUG; + static { + Timestamp ts1 = Timestamp.valueOf("1969-12-25 12:34:56.000234"); + Timestamp ts2 = Timestamp.valueOf("1969-12-25 12:34:56.001234"); + NO_TIMESTAMP_BUG = ts1.getTime()/1000 == ts2.getTime()/1000; + } + + public void addValue(int rowId, int column, SpecializedGetters data, + ColumnVector output) { + if (data.isNullAt(column)) { + output.noNulls = false; + output.isNull[rowId] = true; + } else { + output.isNull[rowId] = false; + TimestampColumnVector cv = (TimestampColumnVector) output; + long micros = data.getLong(column); + cv.time[rowId] = (micros / 1_000_000) * 1000; + int nanos = (int) (micros % 1_000_000) * 1000; + if (nanos < 0) { + nanos += 1_000_000_000; + if (NO_TIMESTAMP_BUG || nanos >= 1_000_000) { + cv.time[rowId] -= 1000; + } + } + cv.nanos[rowId] = nanos; + } + } + } + + static class Decimal18Converter implements Converter { + private final int precision; + private final int scale; + + Decimal18Converter(TypeDescription schema) { + precision = schema.getPrecision(); + scale = schema.getScale(); + } + + public void addValue(int rowId, int column, SpecializedGetters data, + ColumnVector output) { + if (data.isNullAt(column)) { + output.noNulls = false; + output.isNull[rowId] = true; + } else { + output.isNull[rowId] = false; + ((DecimalColumnVector) output).vector[rowId].setFromLongAndScale( + data.getDecimal(column, precision, scale).toUnscaledLong(), scale); + } + } + } + + static class Decimal38Converter implements Converter { + private final int precision; + private final int scale; + + Decimal38Converter(TypeDescription schema) { + precision = schema.getPrecision(); + scale = 
schema.getScale(); + } + public void addValue(int rowId, int column, SpecializedGetters data, + ColumnVector output) { + if (data.isNullAt(column)) { + output.noNulls = false; + output.isNull[rowId] = true; + } else { + output.isNull[rowId] = false; + ((DecimalColumnVector) output).vector[rowId].set( + HiveDecimal.create(data.getDecimal(column, precision, scale) + .toJavaBigDecimal())); + } + } + } + + static class StructConverter implements Converter { + private final Converter[] children; + + StructConverter(TypeDescription schema) { + children = new Converter[schema.getChildren().size()]; + for(int c=0; c < children.length; ++c) { + children[c] = buildConverter(schema.getChildren().get(c)); + } + } + + public void addValue(int rowId, int column, SpecializedGetters data, + ColumnVector output) { + if (data.isNullAt(column)) { + output.noNulls = false; + output.isNull[rowId] = true; + } else { + output.isNull[rowId] = false; + InternalRow value = data.getStruct(column, children.length); + StructColumnVector cv = (StructColumnVector) output; + for(int c=0; c < children.length; ++c) { + children[c].addValue(rowId, c, value, cv.fields[c]); + } + } + } + } + + static class ListConverter implements Converter { + private final Converter children; + + ListConverter(TypeDescription schema) { + children = buildConverter(schema.getChildren().get(0)); + } + + public void addValue(int rowId, int column, SpecializedGetters data, + ColumnVector output) { + if (data.isNullAt(column)) { + output.noNulls = false; + output.isNull[rowId] = true; + } else { + output.isNull[rowId] = false; + ArrayData value = data.getArray(column); + ListColumnVector cv = (ListColumnVector) output; + // record the length and start of the list elements + cv.lengths[rowId] = value.numElements(); + cv.offsets[rowId] = cv.childCount; + cv.childCount += cv.lengths[rowId]; + // make sure the child is big enough + cv.child.ensureSize(cv.childCount, true); + // Add each element + for(int e=0; e < 
cv.lengths[rowId]; ++e) { + children.addValue((int) (e + cv.offsets[rowId]), e, value, cv.child); + } + } + } + } + + static class MapConverter implements Converter { + private final Converter keyConverter; + private final Converter valueConverter; + + MapConverter(TypeDescription schema) { + keyConverter = buildConverter(schema.getChildren().get(0)); + valueConverter = buildConverter(schema.getChildren().get(1)); + } + + public void addValue(int rowId, int column, SpecializedGetters data, + ColumnVector output) { + if (data.isNullAt(column)) { + output.noNulls = false; + output.isNull[rowId] = true; + } else { + output.isNull[rowId] = false; + MapData map = data.getMap(column); + ArrayData key = map.keyArray(); + ArrayData value = map.valueArray(); + MapColumnVector cv = (MapColumnVector) output; + // record the length and start of the list elements + cv.lengths[rowId] = value.numElements(); + cv.offsets[rowId] = cv.childCount; + cv.childCount += cv.lengths[rowId]; + // make sure the child is big enough + cv.keys.ensureSize(cv.childCount, true); + cv.values.ensureSize(cv.childCount, true); + // Add each element + for(int e=0; e < cv.lengths[rowId]; ++e) { + int pos = (int)(e + cv.offsets[rowId]); + keyConverter.addValue(pos, e, key, cv.keys); + valueConverter.addValue(pos, e, value, cv.values); + } + } + } + } + + private static Converter buildConverter(TypeDescription schema) { + switch (schema.getCategory()) { + case BOOLEAN: + return new BooleanConverter(); + case BYTE: + return new ByteConverter(); + case SHORT: + return new ShortConverter(); + case DATE: + case INT: + return new IntConverter(); + case LONG: + return new LongConverter(); + case FLOAT: + return new FloatConverter(); + case DOUBLE: + return new DoubleConverter(); + case BINARY: + return new BytesConverter(); + case STRING: + case CHAR: + case VARCHAR: + return new StringConverter(); + case DECIMAL: + return schema.getPrecision() <= 18 + ? 
new Decimal18Converter(schema) + : new Decimal38Converter(schema); + case TIMESTAMP: + return new TimestampConverter(); + case STRUCT: + return new StructConverter(schema); + case LIST: + return new ListConverter(schema); + case MAP: + return new MapConverter(schema); + } + throw new IllegalArgumentException("Unhandled type " + schema); + } + + private static Converter[] buildConverters(TypeDescription schema) { + if (schema.getCategory() != TypeDescription.Category.STRUCT) { + throw new IllegalArgumentException("Top level must be a struct " + schema); + } + List children = schema.getChildren(); + Converter[] result = new Converter[children.size()]; + for(int c=0; c < children.size(); ++c) { + result[c] = buildConverter(children.get(c)); + } + return result; + } + + @Override + public void add(InternalRow datum) { + int row = batch.size++; + for(int c=0; c < converters.length; ++c) { + converters[c].addValue(row, c, datum, batch.cols[c]); + } + if (batch.size == BATCH_SIZE) { + writer.add(batch); + batch.reset(); + } + } + + @Override + public Metrics metrics() { + return writer.metrics(); + } + + @Override + public void close() throws IOException { + if (batch.size > 0) { + writer.add(batch); + batch.reset(); + } + writer.close(); + } +} diff --git a/spark/src/main/java/com/netflix/iceberg/spark/source/Reader.java b/spark/src/main/java/com/netflix/iceberg/spark/source/Reader.java index 5624446f0..fa6cd8c8b 100644 --- a/spark/src/main/java/com/netflix/iceberg/spark/source/Reader.java +++ b/spark/src/main/java/com/netflix/iceberg/spark/source/Reader.java @@ -41,6 +41,7 @@ import com.netflix.iceberg.spark.SparkFilters; import com.netflix.iceberg.spark.SparkSchemaUtil; import com.netflix.iceberg.spark.data.SparkAvroReader; +import com.netflix.iceberg.spark.data.SparkOrcReader; import com.netflix.iceberg.types.TypeUtil; import com.netflix.iceberg.types.Types; import org.apache.hadoop.conf.Configuration; @@ -342,6 +343,10 @@ public DataReader createDataReader() { 
APPLY_PROJECTION.bind(projection(finalSchema, iterSchema))::invoke); break; + case ORC: + unsafeRowIterator = new SparkOrcReader(location, task, finalSchema, conf); + break; + default: throw new UnsupportedOperationException("Cannot read unknown format: " + file.format()); } diff --git a/spark/src/main/java/com/netflix/iceberg/spark/source/Writer.java b/spark/src/main/java/com/netflix/iceberg/spark/source/Writer.java index 941cd89f4..588e2b5c4 100644 --- a/spark/src/main/java/com/netflix/iceberg/spark/source/Writer.java +++ b/spark/src/main/java/com/netflix/iceberg/spark/source/Writer.java @@ -36,8 +36,10 @@ import com.netflix.iceberg.io.FileAppender; import com.netflix.iceberg.io.InputFile; import com.netflix.iceberg.io.OutputFile; +import com.netflix.iceberg.orc.ORC; import com.netflix.iceberg.parquet.Parquet; import com.netflix.iceberg.spark.data.SparkAvroWriter; +import com.netflix.iceberg.spark.data.SparkOrcWriter; import com.netflix.iceberg.util.Tasks; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; @@ -242,6 +244,13 @@ public FileAppender newAppender(OutputFile file, FileFormat format) { .named("table") .build(); + case ORC: { + @SuppressWarnings("unchecked") + SparkOrcWriter writer = new SparkOrcWriter(ORC.write(file) + .schema(schema) + .build()); + return (FileAppender) writer; + } default: throw new UnsupportedOperationException("Cannot write unknown format: " + format); } diff --git a/spark/src/test/java/com/netflix/iceberg/spark/data/RandomData.java b/spark/src/test/java/com/netflix/iceberg/spark/data/RandomData.java index fde6a22fe..8cfd1a70a 100644 --- a/spark/src/test/java/com/netflix/iceberg/spark/data/RandomData.java +++ b/spark/src/test/java/com/netflix/iceberg/spark/data/RandomData.java @@ -25,9 +25,19 @@ import com.netflix.iceberg.types.Types; import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericData.Record; +import org.apache.spark.sql.catalyst.InternalRow; +import 
org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Random; @@ -103,7 +113,7 @@ public Object map(Types.MapType map, Supplier valueResult) { Map result = Maps.newLinkedHashMap(); for (int i = 0; i < numEntries; i += 1) { - String key = randomString(random) + i; // add i to ensure no collisions + String key = randomString(random).toString() + i; // add i to ensure no collisions // return null 5% of the time when the value is optional if (map.isValueOptional() && random.nextInt(20) == 1) { result.put(key, null); @@ -117,138 +127,162 @@ public Object map(Types.MapType map, Supplier valueResult) { @Override public Object primitive(Type.PrimitiveType primitive) { - int choice = random.nextInt(20); - + Object result = generatePrimitive(primitive, random); + // For the primitives that Avro needs a different type than Spark, fix + // them here. 
switch (primitive.typeId()) { - case BOOLEAN: - return choice < 10; - - case INTEGER: - switch (choice) { - case 1: - return Integer.MIN_VALUE; - case 2: - return Integer.MAX_VALUE; - case 3: - return 0; - default: - return random.nextInt(); - } - - case LONG: - switch (choice) { - case 1: - return Long.MIN_VALUE; - case 2: - return Long.MAX_VALUE; - case 3: - return 0L; - default: - return random.nextLong(); - } - - case FLOAT: - switch (choice) { - case 1: - return Float.MIN_VALUE; - case 2: - return -Float.MIN_VALUE; - case 3: - return Float.MAX_VALUE; - case 4: - return -Float.MAX_VALUE; - case 5: - return Float.NEGATIVE_INFINITY; - case 6: - return Float.POSITIVE_INFINITY; - case 7: - return 0.0F; - case 8: - return Float.NaN; - default: - return random.nextFloat(); - } - - case DOUBLE: - switch (choice) { - case 1: - return Double.MIN_VALUE; - case 2: - return -Double.MIN_VALUE; - case 3: - return Double.MAX_VALUE; - case 4: - return -Double.MAX_VALUE; - case 5: - return Double.NEGATIVE_INFINITY; - case 6: - return Double.POSITIVE_INFINITY; - case 7: - return 0.0D; - case 8: - return Double.NaN; - default: - return random.nextDouble(); - } - - case DATE: - // this will include negative values (dates before 1970-01-01) - return random.nextInt() % ABOUT_380_YEARS_IN_DAYS; - - case TIME: - return (random.nextLong() & Integer.MAX_VALUE) % ONE_DAY_IN_MICROS; - - case TIMESTAMP: - return random.nextLong(); - case STRING: - return randomString(random); - - case UUID: - byte[] uuidBytes = new byte[16]; - random.nextBytes(uuidBytes); - // this will hash the uuidBytes - return UUID.nameUUIDFromBytes(uuidBytes); - + return ((UTF8String) result).toString(); case FIXED: - byte[] fixed = new byte[((Types.FixedType) primitive).length()]; - random.nextBytes(fixed); - return new GenericData.Fixed(typeToSchema.get(primitive), fixed); - + return new GenericData.Fixed(typeToSchema.get(primitive), + (byte[]) result); case BINARY: - int length = random.nextInt(50); - ByteBuffer 
buffer = ByteBuffer.allocate(length); - random.nextBytes(buffer.array()); - return buffer; - + return ByteBuffer.wrap((byte[]) result); + case UUID: + return UUID.nameUUIDFromBytes((byte[]) result); case DECIMAL: - Types.DecimalType decimal = (Types.DecimalType) primitive; - return new BigDecimal(randomUnscaled(decimal.precision(), random), decimal.scale()); - + return ((Decimal) result).toJavaBigDecimal(); default: - throw new IllegalArgumentException( - "Cannot generate random value for unknown type: " + primitive); + return result; } } } - private static int ABOUT_380_YEARS_IN_DAYS = 380 * 365; - private static long ONE_DAY_IN_MICROS = 24 * 60 * 60 * 1_000_000; - private static String CHARS = + public static Object generatePrimitive(Type.PrimitiveType primitive, + Random random) { + int choice = random.nextInt(20); + + switch (primitive.typeId()) { + case BOOLEAN: + return choice < 10; + + case INTEGER: + switch (choice) { + case 1: + return Integer.MIN_VALUE; + case 2: + return Integer.MAX_VALUE; + case 3: + return 0; + default: + return random.nextInt(); + } + + case LONG: + switch (choice) { + case 1: + return Long.MIN_VALUE; + case 2: + return Long.MAX_VALUE; + case 3: + return 0L; + default: + return random.nextLong(); + } + + case FLOAT: + switch (choice) { + case 1: + return Float.MIN_VALUE; + case 2: + return -Float.MIN_VALUE; + case 3: + return Float.MAX_VALUE; + case 4: + return -Float.MAX_VALUE; + case 5: + return Float.NEGATIVE_INFINITY; + case 6: + return Float.POSITIVE_INFINITY; + case 7: + return 0.0F; + case 8: + return Float.NaN; + default: + return random.nextFloat(); + } + + case DOUBLE: + switch (choice) { + case 1: + return Double.MIN_VALUE; + case 2: + return -Double.MIN_VALUE; + case 3: + return Double.MAX_VALUE; + case 4: + return -Double.MAX_VALUE; + case 5: + return Double.NEGATIVE_INFINITY; + case 6: + return Double.POSITIVE_INFINITY; + case 7: + return 0.0D; + case 8: + return Double.NaN; + default: + return random.nextDouble(); + } + 
+ case DATE: + // this will include negative values (dates before 1970-01-01) + return random.nextInt() % ABOUT_380_YEARS_IN_DAYS; + + case TIME: + return (random.nextLong() & Integer.MAX_VALUE) % ONE_DAY_IN_MICROS; + + case TIMESTAMP: + return random.nextLong() % FIFTY_YEARS_IN_MICROS; + + case STRING: + return randomString(random); + + case UUID: + byte[] uuidBytes = new byte[16]; + random.nextBytes(uuidBytes); + // this will hash the uuidBytes + return uuidBytes; + + case FIXED: + byte[] fixed = new byte[((Types.FixedType) primitive).length()]; + random.nextBytes(fixed); + return fixed; + + case BINARY: + byte[] binary = new byte[random.nextInt(50)]; + random.nextBytes(binary); + return binary; + + case DECIMAL: + Types.DecimalType type = (Types.DecimalType) primitive; + BigInteger unscaled = randomUnscaled(type.precision(), random); + return Decimal.apply(new BigDecimal(unscaled, type.scale())); + + default: + throw new IllegalArgumentException( + "Cannot generate random value for unknown type: " + primitive); + } + } + + private static final long FIFTY_YEARS_IN_MICROS = + (50L * (365 * 3 + 366) * 24 * 60 * 60 * 1_000_000) / 4; + private static final int ABOUT_380_YEARS_IN_DAYS = 380 * 365; + private static final long ONE_DAY_IN_MICROS = 24 * 60 * 60 * 1_000_000L; + private static final String CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-.!?"; - private static String randomString(Random random) { + private static UTF8String randomString(Random random) { int length = random.nextInt(50); - StringBuilder sb = new StringBuilder(); + byte[] buffer = new byte[length]; for (int i = 0; i < length; i += 1) { - sb.append(CHARS.charAt(random.nextInt(CHARS.length()))); + buffer[i] = (byte) CHARS.charAt(random.nextInt(CHARS.length())); } - return sb.toString(); + return UTF8String.fromBytes(buffer); } - private static String DIGITS = "0123456789"; + private static final String DIGITS = "0123456789"; private static BigInteger randomUnscaled(int 
precision, Random random) { int length = random.nextInt(precision); if (length == 0) { @@ -262,4 +296,176 @@ private static BigInteger randomUnscaled(int precision, Random random) { return new BigInteger(sb.toString()); } + + public static Iterator generateSpark(Schema schema, + int rows, + long seed) { + return new Iterator(){ + private int rowsLeft = rows; + SparkGenerator generator = buildGenerator(schema.asStruct(), + new Random(seed), false); + + @Override + public boolean hasNext() { + return rowsLeft > 0; + } + + @Override + public InternalRow next() { + rowsLeft -= 1; + return (InternalRow) generator.next(); + } + }; + } + + interface SparkGenerator { + /** + * Generate the next object for Spark. + * @return InternalRow, MapData, ArrayData, etc. + */ + Object next(); + } + + /** + * A filter that generates a null 5% of the time. + */ + static class OptionalSparkGenerator implements SparkGenerator { + private final SparkGenerator child; + private final Random random; + + OptionalSparkGenerator(SparkGenerator child, Random random) { + this.child = child; + this.random = random; + } + + @Override + public Object next() { + return random.nextInt(100) < 5 ? 
null : child.next(); + } + } + + static class PrimitiveSparkGenerator implements SparkGenerator { + private final Type.PrimitiveType type; + private final Random random; + + PrimitiveSparkGenerator(Type type, Random random) { + this.type = (Type.PrimitiveType) type; + this.random = random; + } + + @Override + public Object next() { + return generatePrimitive(type, random); + } + } + + static class StructSparkGenerator implements SparkGenerator { + private final SparkGenerator[] children; + + StructSparkGenerator(Type type, Random random) { + Types.StructType t = (Types.StructType) type; + List fields = t.fields(); + children = new SparkGenerator[fields.size()]; + for(int c=0; c < children.length; ++c) { + Types.NestedField field = fields.get(c); + children[c] = buildGenerator(field.type(), random, field.isOptional()); + } + } + + @Override + public Object next() { + GenericInternalRow row = new GenericInternalRow(children.length); + for(int c=0; c < children.length; ++c) { + row.update(c, children[c].next()); + } + return row; + } + } + + static class ListSparkGenerator implements SparkGenerator { + private final Random random; + private final SparkGenerator child; + + ListSparkGenerator(Type type, Random random) { + this.random = random; + Types.ListType t = (Types.ListType) type; + child = buildGenerator(t.elementType(), random, t.isElementOptional()); + } + + @Override + public Object next() { + int len = random.nextInt(20); + GenericArrayData result = new GenericArrayData(new Object[len]); + for(int e=0; e < len; ++e) { + result.update(e, child.next()); + } + return result; + } + } + + static class MapSparkGenerator implements SparkGenerator { + private final Random random; + private final SparkGenerator keyGenerator; + private final SparkGenerator valueGenerator; + + MapSparkGenerator(Type type, Random random) { + this.random = random; + Types.MapType t = (Types.MapType) type; + keyGenerator = buildGenerator(t.keyType(), random, false); + valueGenerator = 
buildGenerator(t.valueType(), random, t.isValueOptional()); + } + + @Override + public Object next() { + int len = random.nextInt(20); + GenericArrayData keys = new GenericArrayData(new Object[len]); + GenericArrayData values = new GenericArrayData(new Object[len]); + ArrayBasedMapData result = new ArrayBasedMapData(keys, values); + List alreadyUsed = new ArrayList(len); + for(int e=0; e < len; ++e) { + Object key; + do { + key = keyGenerator.next(); + } while (alreadyUsed.contains(key)); + alreadyUsed.add(key); + keys.update(e, key); + values.update(e, valueGenerator.next()); + } + return result; + } + } + + static SparkGenerator buildGenerator(Type type, Random random, + boolean isOptional) { + SparkGenerator result; + switch (type.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + case DATE: + case TIME: + case TIMESTAMP: + case STRING: + case UUID: + case FIXED: + case BINARY: + case DECIMAL: + result = new PrimitiveSparkGenerator(type, random); + break; + case STRUCT: + result = new StructSparkGenerator(type, random); + break; + case LIST: + result = new ListSparkGenerator(type, random); + break; + case MAP: + result = new MapSparkGenerator(type, random); + break; + default: + throw new IllegalArgumentException("Unhandled type " + type); + } + return isOptional ? 
new OptionalSparkGenerator(result, random) : result; + } } diff --git a/spark/src/test/java/com/netflix/iceberg/spark/data/TestHelpers.java b/spark/src/test/java/com/netflix/iceberg/spark/data/TestHelpers.java index 14d22d590..a0a1dec8a 100644 --- a/spark/src/test/java/com/netflix/iceberg/spark/data/TestHelpers.java +++ b/spark/src/test/java/com/netflix/iceberg/spark/data/TestHelpers.java @@ -23,10 +23,14 @@ import com.netflix.iceberg.types.Types; import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericData.Record; +import org.apache.orc.storage.serde2.io.DateWritable; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters; import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.types.UTF8String; import org.junit.Assert; @@ -280,4 +284,290 @@ private static void assertEqualsUnsafe(Type type, Object expected, Object actual throw new IllegalArgumentException("Not a supported type: " + type); } } + + /** + * Check that the given InternalRow is equivalent to the Row. 
+ * @param prefix context for error messages + * @param type the type of the row + * @param expected the expected value of the row + * @param actual the actual value of the row + */ + public static void assertEquals(String prefix, Types.StructType type, + InternalRow expected, Row actual) { + if (expected == null || actual == null) { + Assert.assertEquals(prefix, expected, actual); + } else { + List fields = type.fields(); + for (int c = 0; c < fields.size(); ++c) { + String fieldName = fields.get(c).name(); + Type childType = fields.get(c).type(); + switch (childType.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case STRING: + case DECIMAL: + case DATE: + case TIMESTAMP: + Assert.assertEquals(prefix + "." + fieldName + " - " + childType, + getPrimativeValue(expected, c, childType), + getPrimativeValue(actual, c, childType)); + break; + case FLOAT: + case DOUBLE: + assertEqualsDouble(prefix + "." + fieldName, + (Double) getPrimativeValue(expected, c, childType), + (Double) getPrimativeValue(actual, c, childType)); + break; + case UUID: + case FIXED: + case BINARY: + assertEqualBytes(prefix + "." + fieldName, + (byte[]) getPrimativeValue(expected, c, childType), + (byte[]) actual.get(c)); + break; + case STRUCT: { + Types.StructType st = (Types.StructType) childType; + assertEquals(prefix + "." + fieldName, st, + expected.getStruct(c, st.fields().size()), actual.getStruct(c)); + break; + } + case LIST: + assertEqualsLists(prefix + "." + fieldName, childType.asListType(), + expected.getArray(c), + toList((Seq) actual.get(c))); + break; + case MAP: + assertEqualsMaps(prefix + "." 
+ fieldName, childType.asMapType(), expected.getMap(c), + actual.getMap(c)); + break; + default: + throw new IllegalArgumentException("Unhandled type " + childType); + } + } + } + } + + private static void assertEqualsLists(String prefix, Types.ListType type, + ArrayData expected, List actual) { + if (expected == null || actual == null) { + Assert.assertEquals(prefix, expected, actual); + } else { + Assert.assertEquals(prefix + " length", expected.numElements(), actual.size()); + Type childType = type.elementType(); + for (int e = 0; e < expected.numElements(); ++e) { + switch (childType.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case STRING: + case DECIMAL: + case DATE: + case TIMESTAMP: + Assert.assertEquals(prefix + ".elem " + e + " - " + childType, + getPrimativeValue(expected, e, childType), + actual.get(e)); + break; + case FLOAT: + case DOUBLE: + assertEqualsDouble(prefix + ".elem " + e, + (Double) getPrimativeValue(expected, e, childType), + toDouble(actual.get(e))); + break; + case UUID: + case FIXED: + case BINARY: + assertEqualBytes(prefix + ".elem " + e, + (byte[]) getPrimativeValue(expected, e, childType), + (byte[]) actual.get(e)); + break; + case STRUCT: { + Types.StructType st = (Types.StructType) childType; + assertEquals(prefix + ".elem " + e, st, + expected.getStruct(e, st.fields().size()), (Row) actual.get(e)); + break; + } + case LIST: + assertEqualsLists(prefix + ".elem " + e, childType.asListType(), + expected.getArray(e), + toList((Seq) actual.get(e))); + break; + case MAP: + assertEqualsMaps(prefix + ".elem " + e, childType.asMapType(), + expected.getMap(e), (scala.collection.Map) actual.get(e)); + break; + default: + throw new IllegalArgumentException("Unhandled type " + childType); + } + } + } + } + + private static void assertEqualsMaps(String prefix, Types.MapType type, + MapData expected, scala.collection.Map actual) { + if (expected == null || actual == null) { + Assert.assertEquals(prefix, expected, actual); + } else 
{ + Map javaMap = mapAsJavaMapConverter((scala.collection.Map) actual).asJava(); + Type keyType = type.keyType(); + Type valueType = type.valueType(); + ArrayData expectedKeyArray = expected.keyArray(); + ArrayData expectedValueArray = expected.valueArray(); + Assert.assertEquals(prefix + " length", expected.numElements(), javaMap.size()); + for (int e = 0; e < expected.numElements(); ++e) { + Object expectedKey = getPrimativeValue(expectedKeyArray, e, keyType); + Object actualValue = javaMap.get(expectedKey); + if (actualValue == null) { + Assert.assertEquals(prefix + ".key=" + expectedKey + " has null", true, + expected.valueArray().isNullAt(e)); + } else { + switch (valueType.typeId()) { + case BOOLEAN: + case INTEGER: + case LONG: + case STRING: + case DECIMAL: + case DATE: + case TIMESTAMP: + Assert.assertEquals(prefix + ".key=" + expectedKey + " - " + valueType, + getPrimativeValue(expectedValueArray, e, valueType), + javaMap.get(expectedKey)); + break; + case FLOAT: + case DOUBLE: + assertEqualsDouble(prefix + ".key=" + expectedKey, + (Double) getPrimativeValue(expectedValueArray, e, valueType), + (Double) javaMap.get(expectedKey)); + break; + case UUID: + case FIXED: + case BINARY: + assertEqualBytes(prefix + ".key=" + expectedKey, + (byte[]) getPrimativeValue(expectedValueArray, e, valueType), + (byte[]) javaMap.get(expectedKey)); + break; + case STRUCT: { + Types.StructType st = (Types.StructType) valueType; + assertEquals(prefix + ".key=" + expectedKey, st, + expectedValueArray.getStruct(e, st.fields().size()), + (Row) javaMap.get(expectedKey)); + break; + } + case LIST: + assertEqualsLists(prefix + ".key=" + expectedKey, + valueType.asListType(), + expectedValueArray.getArray(e), + toList((Seq) javaMap.get(expectedKey))); + break; + case MAP: + assertEqualsMaps(prefix + ".key=" + expectedKey, valueType.asMapType(), + expectedValueArray.getMap(e), + (scala.collection.Map) javaMap.get(expectedKey)); + break; + default: + throw new 
IllegalArgumentException("Unhandled type " + valueType); + } + } + } + } + } + + private static void assertEqualsDouble(String context, Double expect, + Double actual) { + if (expect == null && actual == null) { + return; + } + Assert.assertEquals(context, expect, actual, 0.00001); + } + + private static Object getPrimativeValue(SpecializedGetters container, int ord, + Type type) { + if (container.isNullAt(ord)) { + return null; + } + switch (type.typeId()) { + case BOOLEAN: + return container.getBoolean(ord); + case INTEGER: + return container.getInt(ord); + case LONG: + return container.getLong(ord); + case FLOAT: + return (double) container.getFloat(ord); + case DOUBLE: + return container.getDouble(ord); + case STRING: + return container.getUTF8String(ord).toString(); + case BINARY: + case FIXED: + case UUID: + return container.getBinary(ord); + case DATE: + return new DateWritable(container.getInt(ord)).get(); + case TIMESTAMP: + return DateTimeUtils.toJavaTimestamp(container.getLong(ord)); + case DECIMAL: { + Types.DecimalType dt = (Types.DecimalType) type; + return container.getDecimal(ord, dt.precision(), dt.scale()).toJavaBigDecimal(); + } + default: + throw new IllegalArgumentException("Unhandled type " + type); + } + } + + private static Object getPrimativeValue(Row row, int ord, Type type) { + if (row.isNullAt(ord)) { + return null; + } + switch (type.typeId()) { + case BOOLEAN: + return row.getBoolean(ord); + case INTEGER: + return row.getInt(ord); + case LONG: + return row.getLong(ord); + case FLOAT: + return (double) row.getFloat(ord); + case DOUBLE: + return row.getDouble(ord); + case STRING: + return row.getString(ord); + case BINARY: + case FIXED: + case UUID: + return row.get(ord); + case DATE: + return row.getDate(ord); + case TIMESTAMP: + return row.getTimestamp(ord); + case DECIMAL: + return row.getDecimal(ord); + default: + throw new IllegalArgumentException("Unhandled type " + type); + } + } + + private static Double toDouble(Object val) { + 
if (val == null) { + return null; + } else if (val instanceof Float) { + return (double) ((Float)val); + } else if (val instanceof Double) { + return (Double) val; + } + throw new IllegalArgumentException("Can't convert " + val + " to Double."); + } + + private static List toList(Seq val) { + return val == null ? null : seqAsJavaListConverter(val).asJava(); + } + + private static void assertEqualBytes(String context, byte[] expected, + byte[] actual) { + if (expected == null || actual == null) { + Assert.assertEquals(context, expected, actual); + } else { + Assert.assertArrayEquals(context, expected, actual); + } + } } diff --git a/spark/src/test/java/com/netflix/iceberg/spark/source/SimpleRecord.java b/spark/src/test/java/com/netflix/iceberg/spark/source/SimpleRecord.java new file mode 100644 index 000000000..fb401c72d --- /dev/null +++ b/spark/src/test/java/com/netflix/iceberg/spark/source/SimpleRecord.java @@ -0,0 +1,77 @@ +/* + * Copyright 2018 Hortonworks + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.netflix.iceberg.spark.source; + +import com.google.common.base.Objects; + +public class SimpleRecord { + private Integer id; + private String data; + + public SimpleRecord() { + } + + SimpleRecord(Integer id, String data) { + this.id = id; + this.data = data; + } + + public Integer getId() { + return id; + } + + public void setId(Integer id) { + this.id = id; + } + + public String getData() { + return data; + } + + public void setData(String data) { + this.data = data; + } + + @Override + public boolean equals(Object o) { + if (this == o){ + return true; + } + if (o == null || getClass() != o.getClass()){ + return false; + } + + SimpleRecord record = (SimpleRecord) o; + return Objects.equal(id, record.id) && Objects.equal(data, record.data); + } + + @Override + public int hashCode() { + return Objects.hashCode(id, data); + } + + @Override + public String toString() { + StringBuilder buffer = new StringBuilder(); + buffer.append("{\"id\"="); + buffer.append(id); + buffer.append(",\"data\"=\""); + buffer.append(data); + buffer.append("\"}"); + return buffer.toString(); + } +} diff --git a/spark/src/test/java/com/netflix/iceberg/spark/source/TestDataFrameWrites.java b/spark/src/test/java/com/netflix/iceberg/spark/source/TestDataFrameWrites.java index 9ddd03f21..5454d01ba 100644 --- a/spark/src/test/java/com/netflix/iceberg/spark/source/TestDataFrameWrites.java +++ b/spark/src/test/java/com/netflix/iceberg/spark/source/TestDataFrameWrites.java @@ -60,6 +60,7 @@ public class TestDataFrameWrites extends AvroDataTest { public static Object[][] parameters() { return new Object[][] { new Object[] { "parquet" }, + new Object[] { "orc" }, new Object[] { "avro" } }; } diff --git a/spark/src/test/java/com/netflix/iceberg/spark/source/TestOrcScan.java b/spark/src/test/java/com/netflix/iceberg/spark/source/TestOrcScan.java new file mode 100644 index 000000000..c7c551323 --- /dev/null +++ b/spark/src/test/java/com/netflix/iceberg/spark/source/TestOrcScan.java 
@@ -0,0 +1,134 @@ +/* + * Copyright 2018 Hortonworks + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.netflix.iceberg.spark.source; + +import com.netflix.iceberg.DataFile; +import com.netflix.iceberg.DataFiles; +import com.netflix.iceberg.FileFormat; +import com.netflix.iceberg.Metrics; +import com.netflix.iceberg.PartitionSpec; +import com.netflix.iceberg.Schema; +import com.netflix.iceberg.Table; +import com.netflix.iceberg.hadoop.HadoopTables; +import com.netflix.iceberg.io.FileAppender; +import com.netflix.iceberg.orc.ORC; +import com.netflix.iceberg.orc.OrcFileAppender; +import com.netflix.iceberg.spark.data.AvroDataTest; +import com.netflix.iceberg.spark.data.RandomData; +import com.netflix.iceberg.spark.data.SparkOrcWriter; +import com.netflix.iceberg.spark.data.TestHelpers; +import com.netflix.iceberg.types.Type; +import com.netflix.iceberg.types.Types; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.serde2.io.TimestampWritable; +import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; +import org.apache.orc.storage.serde2.io.DateWritable; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import 
org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.IOException; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.Iterator; +import java.util.List; +import java.util.UUID; + +import static com.netflix.iceberg.Files.localOutput; + +public class TestOrcScan extends AvroDataTest { + private static final Configuration CONF = new Configuration(); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestOrcScan.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession spark = TestOrcScan.spark; + TestOrcScan.spark = null; + spark.stop(); + } + + @Override + protected void writeAndValidate(Schema schema) throws IOException { + System.out.println("Starting ORC test with " + schema); + final int ROW_COUNT = 100; + final long SEED = 1; + File parent = temp.newFolder("orc"); + File location = new File(parent, "test"); + File dataFolder = new File(location, "data"); + dataFolder.mkdirs(); + + File orcFile = new File(dataFolder, + FileFormat.ORC.addExtension(UUID.randomUUID().toString())); + + HadoopTables tables = new HadoopTables(CONF); + Table table = tables.create(schema, PartitionSpec.unpartitioned(), + location.toString()); + + // Important: use the table's schema for the rest of the test + // When tables are created, the column ids are reassigned. 
+ Schema tableSchema = table.schema(); + + Metrics metrics; + try (SparkOrcWriter writer = + new SparkOrcWriter(ORC.write(localOutput(orcFile)) + .schema(tableSchema) + .build())) { + writer.addAll(RandomData.generateSpark(tableSchema, ROW_COUNT, SEED)); + metrics = writer.metrics(); + } + + DataFile file = DataFiles.builder(PartitionSpec.unpartitioned()) + .withFileSizeInBytes(orcFile.length()) + .withPath(orcFile.toString()) + .withMetrics(metrics) + .build(); + + table.newAppend().appendFile(file).commit(); + + Dataset df = spark.read() + .format("iceberg") + .load(location.toString()); + + List rows = df.collectAsList(); + Assert.assertEquals("Wrong number of rows", ROW_COUNT, rows.size()); + Iterator expected = RandomData.generateSpark(tableSchema, + ROW_COUNT, SEED); + for(int i=0; i < ROW_COUNT; ++i) { + TestHelpers.assertEquals("row " + i, schema.asStruct(), expected.next(), + rows.get(i)); + } + } +} diff --git a/spark/src/test/java/com/netflix/iceberg/spark/source/TestOrcWrite.java b/spark/src/test/java/com/netflix/iceberg/spark/source/TestOrcWrite.java new file mode 100644 index 000000000..bc126709a --- /dev/null +++ b/spark/src/test/java/com/netflix/iceberg/spark/source/TestOrcWrite.java @@ -0,0 +1,110 @@ +/* + * Copyright 2018 Hortonworks + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.netflix.iceberg.spark.source; + +import com.google.common.collect.Lists; +import com.netflix.iceberg.FileFormat; +import com.netflix.iceberg.PartitionSpec; +import com.netflix.iceberg.Schema; +import com.netflix.iceberg.Table; +import com.netflix.iceberg.hadoop.HadoopTables; +import com.netflix.iceberg.types.Types; +import org.apache.hadoop.conf.Configuration; +import org.apache.orc.CompressionKind; +import org.apache.orc.OrcConf; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.IOException; +import java.util.List; + +import static com.netflix.iceberg.types.Types.NestedField.optional; + +public class TestOrcWrite { + private static final Configuration CONF = new Configuration(); + private static final Schema SCHEMA = new Schema( + optional(1, "id", Types.IntegerType.get()), + optional(2, "data", Types.StringType.get()) + ); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private static SparkSession spark = null; + + @BeforeClass + public static void startSpark() { + TestOrcWrite.spark = SparkSession.builder().master("local[2]").getOrCreate(); + } + + @AfterClass + public static void stopSpark() { + SparkSession spark = TestOrcWrite.spark; + TestOrcWrite.spark = null; + spark.stop(); + } + + @Test + public void testBasicWrite() throws IOException { + File parent = temp.newFolder("orc"); + File location = new File(parent, "test"); + location.mkdirs(); + + HadoopTables tables = new HadoopTables(CONF); + PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); + Table table = tables.create(SCHEMA, spec, location.toString()); + table.updateProperties() + .defaultFormat(FileFormat.ORC) + 
.set(OrcConf.COMPRESS.getAttribute(), CompressionKind.NONE.name()) + .commit(); + + List expected = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") + ); + + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); + + // TODO: incoming columns must be ordered according to the table's schema + df.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(location.toString()); + + table.refresh(); + + Dataset result = spark.read() + .format("iceberg") + .load(location.toString()); + + List actual = result.orderBy("id").as( + Encoders.bean(SimpleRecord.class)).collectAsList(); + + Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); + Assert.assertEquals("Result rows should match", expected, actual); + } +} diff --git a/spark/src/test/java/com/netflix/iceberg/spark/source/TestParquetWrite.java b/spark/src/test/java/com/netflix/iceberg/spark/source/TestParquetWrite.java index 4c0986ccd..c8cb93f30 100644 --- a/spark/src/test/java/com/netflix/iceberg/spark/source/TestParquetWrite.java +++ b/spark/src/test/java/com/netflix/iceberg/spark/source/TestParquetWrite.java @@ -16,7 +16,6 @@ package com.netflix.iceberg.spark.source; -import com.google.common.base.Objects; import com.google.common.collect.Lists; import com.netflix.iceberg.PartitionSpec; import com.netflix.iceberg.Schema; @@ -65,53 +64,6 @@ public static void stopSpark() { spark.stop(); } - public static class Record { - private Integer id; - private String data; - - public Record() { - } - - private Record(Integer id, String data) { - this.id = id; - this.data = data; - } - - public Integer getId() { - return id; - } - - public void setId(Integer id) { - this.id = id; - } - - public String getData() { - return data; - } - - public void setData(String data) { - this.data = data; - } - - @Override - public boolean equals(Object o) { - if (this == o){ - return true; - } - if (o == null || getClass() != 
o.getClass()){ - return false; - } - - Record record = (Record) o; - return Objects.equal(id, record.id) && Objects.equal(data, record.data); - } - - @Override - public int hashCode() { - return Objects.hashCode(id, data); - } - } - @Test public void testBasicWrite() throws IOException { File parent = temp.newFolder("parquet"); @@ -122,13 +74,13 @@ public void testBasicWrite() throws IOException { PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); Table table = tables.create(SCHEMA, spec, location.toString()); - List expected = Lists.newArrayList( - new Record(1, "a"), - new Record(2, "b"), - new Record(3, "c") + List expected = Lists.newArrayList( + new SimpleRecord(1, "a"), + new SimpleRecord(2, "b"), + new SimpleRecord(3, "c") ); - Dataset df = spark.createDataFrame(expected, Record.class); + Dataset df = spark.createDataFrame(expected, SimpleRecord.class); // TODO: incoming columns must be ordered according to the table's schema df.select("id", "data").write() @@ -142,7 +94,7 @@ public void testBasicWrite() throws IOException { .format("iceberg") .load(location.toString()); - List actual = result.orderBy("id").as(Encoders.bean(Record.class)).collectAsList(); + List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); Assert.assertEquals("Result rows should match", expected, actual);