Skip to content
16 changes: 10 additions & 6 deletions src/main/scala/models/schemas/ArcaneSchema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import ArcaneType.{StringType, StructType}
import models.*
import services.base.CanAdd

import scala.annotation.tailrec
import scala.language.implicitConversions

/** Types of fields in ArcaneSchema.
Expand All @@ -27,12 +28,15 @@ enum ArcaneType:
case ObjectType
case StructType(schema: ArcaneSchema)

override def equals(obj: Any): Boolean = (this, obj) match {
case (t1: ListType, t2: ListType) => t1.elementType == t2.elementType
@tailrec
final def typeEquals(other: ArcaneType): Boolean = (this, other) match {
case (IntType, ShortType) => true
case (ShortType, IntType) => true
case (t1: ListType, t2: ListType) => t1.elementType.typeEquals(t2.elementType)
case (ListType, _) => false
case (t1: StructType, t2: StructType) =>
t1.schema.getMissingFields(t2.schema).isEmpty && t2.schema.getMissingFields(t1.schema).isEmpty
case _ => this.toString == obj.toString
case _ => this.toString == other.toString
}

/** A field in the schema definition that will require indexing when converting to Iceberg
Expand All @@ -50,7 +54,7 @@ trait ArcaneSchemaField:
other.name.toLowerCase() == name.toLowerCase() && thisType.schema
.getMissingFields(otherType.schema)
.isEmpty && otherType.schema.getMissingFields(thisType.schema).isEmpty
case _ => other.name.toLowerCase() == name.toLowerCase() && other.fieldType == fieldType
case _ => other.name.toLowerCase() == name.toLowerCase() && other.fieldType.typeEquals(fieldType)
}

/** A field in the schema definition that carries index information from the source that can be re-applied when
Expand All @@ -63,14 +67,14 @@ trait IndexedArcaneSchemaField extends ArcaneSchemaField:
*/
final case class Field(name: String, fieldType: ArcaneType) extends ArcaneSchemaField:
override def equals(obj: Any): Boolean = obj match
case Field(n, t) => n.toLowerCase() == name.toLowerCase() && t == fieldType
case Field(n, t) => n.toLowerCase() == name.toLowerCase() && t.typeEquals(fieldType)
case _ => false

/** Field is a case class that represents a field in ArcaneSchema
*/
final case class IndexedField(name: String, fieldType: ArcaneType, fieldId: Int) extends IndexedArcaneSchemaField:
override def equals(obj: Any): Boolean = obj match
case IndexedField(n, t, id) => n.toLowerCase() == name.toLowerCase() && t == fieldType && id == fieldId
case IndexedField(n, t, id) => n.toLowerCase() == name.toLowerCase() && t.typeEquals(fieldType) && id == fieldId
case _ => false

/** MergeKeyField represents a field used for batch merges
Expand Down
65 changes: 60 additions & 5 deletions src/test/scala/tests/iceberg/SchemaConversionsTests.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
package com.sneaksanddata.arcane.framework
package tests.iceberg

import models.schemas.ArcaneType.{BigDecimalType, ListType, StringType, StructType}
import models.schemas.{ArcaneSchema, IndexedArcaneSchemaField, IndexedMergeKeyField, MergeKeyField}
import models.schemas.ArcaneType.{
BigDecimalType,
BooleanType,
ByteArrayType,
DateTimeOffsetType,
DateType,
DoubleType,
FloatType,
IntType,
ListType,
LongType,
ShortType,
StringType,
StructType,
TimeType,
TimestampType
}
import models.schemas.{ArcaneSchema, IndexedArcaneSchemaField, IndexedField, IndexedMergeKeyField, MergeKeyField}
import services.iceberg.{given_Conversion_ArcaneSchema_Schema, given_Conversion_Schema_ArcaneSchema, inferMergeKeyIndex}

import org.apache.iceberg.Schema
Expand Down Expand Up @@ -52,7 +68,7 @@ class SchemaConversionsTests extends AnyFlatSpec with Matchers {
arcaneSchema.reverse.head should be(IndexedMergeKeyField(mergeKeyIndex)),
arcaneSchema
.find(f => f.name == "event_value_bigdecimal")
.map(f => f.fieldType == BigDecimalType(16, 4)) should be(Some(true))
.map(f => f.fieldType.typeEquals(BigDecimalType(16, 4))) should be(Some(true))
)
}

Expand All @@ -71,10 +87,14 @@ class SchemaConversionsTests extends AnyFlatSpec with Matchers {
(
arcaneSchema.length should be(iceberg.columns().size() + 1),
arcaneSchema.reverse.head should be(IndexedMergeKeyField(mergeKeyIndex)),
arcaneSchema.find(f => f.name == "call_stack_1").map(f => f.fieldType == ListType(StringType, 3)) should be(
arcaneSchema
.find(f => f.name == "call_stack_1")
.map(f => f.fieldType.typeEquals(ListType(StringType, 3))) should be(
Some(true)
),
arcaneSchema.find(f => f.name == "call_stack_2").map(f => f.fieldType == ListType(StringType, 5)) should be(
arcaneSchema
.find(f => f.name == "call_stack_2")
.map(f => f.fieldType.typeEquals(ListType(StringType, 5))) should be(
Some(true)
)
)
Expand Down Expand Up @@ -140,4 +160,39 @@ class SchemaConversionsTests extends AnyFlatSpec with Matchers {
)
}
}

it should "convert from ArcaneSchema to Iceberg and back" in {
forAll(
Seq(
ArcaneSchema(
List(
IndexedMergeKeyField(1),
IndexedField("col1", IntType, 2),
IndexedField("col2", LongType, 3),
IndexedField("col3", ByteArrayType, 4),
IndexedField("col4", BooleanType, 5),
IndexedField("col5", StringType, 6),
IndexedField("col6", DateType, 7),
IndexedField("col7", TimestampType, 8),
IndexedField("col8", DateTimeOffsetType, 9),
IndexedField("col9", BigDecimalType(16, 4), 10),
IndexedField("col10", DoubleType, 11),
IndexedField("col11", FloatType, 12),
IndexedField("col12", ShortType, 13),
IndexedField("col13", TimeType, 14)
)
),
ArcaneSchema(
List(
IndexedMergeKeyField(1)
)
)
)
) { arcaneSchema =>
val iceberg1: Schema = implicitly(using arcaneSchema)
val arcaneSchema2: ArcaneSchema = implicitly(using iceberg1)

arcaneSchema2.getMissingFields(arcaneSchema).isEmpty should be(true)
}
}
}
2 changes: 1 addition & 1 deletion src/test/scala/tests/models/ArcaneSchemaTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class ArcaneSchemaTests extends AnyFlatSpec with Matchers {
)
)
) { case (typeA, typeB, expectedResult) =>
(typeA == typeB) should be(expectedResult)
typeA.typeEquals(typeB) should be(expectedResult)
}
}

Expand Down
Loading