diff --git a/src/main/scala/models/schemas/ArcaneSchema.scala b/src/main/scala/models/schemas/ArcaneSchema.scala index 619d0ab2..9c2ed74d 100644 --- a/src/main/scala/models/schemas/ArcaneSchema.scala +++ b/src/main/scala/models/schemas/ArcaneSchema.scala @@ -5,6 +5,7 @@ import ArcaneType.{StringType, StructType} import models.* import services.base.CanAdd +import scala.annotation.tailrec import scala.language.implicitConversions /** Types of fields in ArcaneSchema. @@ -27,12 +28,15 @@ enum ArcaneType: case ObjectType case StructType(schema: ArcaneSchema) - override def equals(obj: Any): Boolean = (this, obj) match { - case (t1: ListType, t2: ListType) => t1.elementType == t2.elementType + @tailrec + final def typeEquals(other: ArcaneType): Boolean = (this, other) match { + case (IntType, ShortType) => true + case (ShortType, IntType) => true + case (t1: ListType, t2: ListType) => t1.elementType.typeEquals(t2.elementType) case (ListType, _) => false case (t1: StructType, t2: StructType) => t1.schema.getMissingFields(t2.schema).isEmpty && t2.schema.getMissingFields(t1.schema).isEmpty - case _ => this.toString == obj.toString + case _ => this.toString == other.toString } /** A field in the schema definition that will require indexing when converting to Iceberg @@ -50,7 +54,7 @@ trait ArcaneSchemaField: other.name.toLowerCase() == name.toLowerCase() && thisType.schema .getMissingFields(otherType.schema) .isEmpty && otherType.schema.getMissingFields(thisType.schema).isEmpty - case _ => other.name.toLowerCase() == name.toLowerCase() && other.fieldType == fieldType + case _ => other.name.toLowerCase() == name.toLowerCase() && other.fieldType.typeEquals(fieldType) } /** A field in the schema definition that carries index information from the source that can be re-applied when @@ -63,14 +67,14 @@ trait IndexedArcaneSchemaField extends ArcaneSchemaField: */ final case class Field(name: String, fieldType: ArcaneType) extends ArcaneSchemaField: override def equals(obj: Any): Boolean = obj match - case Field(n, t) => n.toLowerCase() == name.toLowerCase() && t == fieldType + case Field(n, t) => n.toLowerCase() == name.toLowerCase() && t.typeEquals(fieldType) case _ => false /** Field is a case class that represents a field in ArcaneSchema */ final case class IndexedField(name: String, fieldType: ArcaneType, fieldId: Int) extends IndexedArcaneSchemaField: override def equals(obj: Any): Boolean = obj match - case IndexedField(n, t, id) => n.toLowerCase() == name.toLowerCase() && t == fieldType && id == fieldId + case IndexedField(n, t, id) => n.toLowerCase() == name.toLowerCase() && t.typeEquals(fieldType) && id == fieldId case _ => false /** MergeKeyField represents a field used for batch merges diff --git a/src/test/scala/tests/iceberg/SchemaConversionsTests.scala b/src/test/scala/tests/iceberg/SchemaConversionsTests.scala index 972ba937..bb4e23fd 100644 --- a/src/test/scala/tests/iceberg/SchemaConversionsTests.scala +++ b/src/test/scala/tests/iceberg/SchemaConversionsTests.scala @@ -1,8 +1,24 @@ package com.sneaksanddata.arcane.framework package tests.iceberg -import models.schemas.ArcaneType.{BigDecimalType, ListType, StringType, StructType} -import models.schemas.{ArcaneSchema, IndexedArcaneSchemaField, IndexedMergeKeyField, MergeKeyField} +import models.schemas.ArcaneType.{ + BigDecimalType, + BooleanType, + ByteArrayType, + DateTimeOffsetType, + DateType, + DoubleType, + FloatType, + IntType, + ListType, + LongType, + ShortType, + StringType, + StructType, + TimeType, + TimestampType +} +import models.schemas.{ArcaneSchema, IndexedArcaneSchemaField, IndexedField, IndexedMergeKeyField, MergeKeyField} import services.iceberg.{given_Conversion_ArcaneSchema_Schema, given_Conversion_Schema_ArcaneSchema, inferMergeKeyIndex} import org.apache.iceberg.Schema @@ -52,7 +68,7 @@ class SchemaConversionsTests extends AnyFlatSpec with Matchers { arcaneSchema.reverse.head should be(IndexedMergeKeyField(mergeKeyIndex)), arcaneSchema .find(f => f.name == "event_value_bigdecimal") - .map(f => f.fieldType == BigDecimalType(16, 4)) should be(Some(true)) + .map(f => f.fieldType.typeEquals(BigDecimalType(16, 4))) should be(Some(true)) ) } @@ -71,10 +87,14 @@ class SchemaConversionsTests extends AnyFlatSpec with Matchers { ( arcaneSchema.length should be(iceberg.columns().size() + 1), arcaneSchema.reverse.head should be(IndexedMergeKeyField(mergeKeyIndex)), - arcaneSchema.find(f => f.name == "call_stack_1").map(f => f.fieldType == ListType(StringType, 3)) should be( + arcaneSchema + .find(f => f.name == "call_stack_1") + .map(f => f.fieldType.typeEquals(ListType(StringType, 3))) should be( Some(true) ), - arcaneSchema.find(f => f.name == "call_stack_2").map(f => f.fieldType == ListType(StringType, 5)) should be( + arcaneSchema + .find(f => f.name == "call_stack_2") + .map(f => f.fieldType.typeEquals(ListType(StringType, 5))) should be( Some(true) ) ) @@ -140,4 +160,39 @@ class SchemaConversionsTests extends AnyFlatSpec with Matchers { ) } } + + it should "convert from ArcaneSchema to Iceberg and back" in { + forAll( + Seq( + ArcaneSchema( + List( + IndexedMergeKeyField(1), + IndexedField("col1", IntType, 2), + IndexedField("col2", LongType, 3), + IndexedField("col3", ByteArrayType, 4), + IndexedField("col4", BooleanType, 5), + IndexedField("col5", StringType, 6), + IndexedField("col6", DateType, 7), + IndexedField("col7", TimestampType, 8), + IndexedField("col8", DateTimeOffsetType, 9), + IndexedField("col9", BigDecimalType(16, 4), 10), + IndexedField("col10", DoubleType, 11), + IndexedField("col11", FloatType, 12), + IndexedField("col12", ShortType, 13), + IndexedField("col13", TimeType, 14) + ) + ), + ArcaneSchema( + List( + IndexedMergeKeyField(1) + ) + ) + ) + ) { arcaneSchema => + val iceberg1: Schema = implicitly(using arcaneSchema) + val arcaneSchema2: ArcaneSchema = implicitly(using iceberg1) + + arcaneSchema2.getMissingFields(arcaneSchema).isEmpty should be(true) + } + } } diff --git a/src/test/scala/tests/models/ArcaneSchemaTests.scala b/src/test/scala/tests/models/ArcaneSchemaTests.scala index 99477335..d8b194a9 100644 --- a/src/test/scala/tests/models/ArcaneSchemaTests.scala +++ b/src/test/scala/tests/models/ArcaneSchemaTests.scala @@ -50,7 +50,7 @@ class ArcaneSchemaTests extends AnyFlatSpec with Matchers { ) ) ) { case (typeA, typeB, expectedResult) => - (typeA == typeB) should be(expectedResult) + typeA.typeEquals(typeB) should be(expectedResult) } }