diff --git a/internal/dbtest/orm_test.go b/internal/dbtest/orm_test.go index 407dd2459..7802d2e31 100644 --- a/internal/dbtest/orm_test.go +++ b/internal/dbtest/orm_test.go @@ -37,6 +37,7 @@ func TestORM(t *testing.T) { {testCompositeM2M}, {testHasOneRelationWithOpts}, {testHasManyRelationWithOpts}, + {testM2MRelationOnEmbeddedBaseModel}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -667,6 +668,73 @@ func testHasManyRelationWithOpts(t *testing.T, db *bun.DB) { }, outUsers2) } +func testM2MRelationOnEmbeddedBaseModel(t *testing.T, db *bun.DB) { + type Item struct { + ID int64 `bun:",pk"` + } + + type Order struct { + bun.BaseModel `bun:"table:orders"` + + ID int64 `bun:",pk"` + Items []Item `bun:"m2m:order_to_items,join:Order=Item"` + } + + type OrderToItem struct { + bun.BaseModel `bun:"table:order_to_items"` + + OrderID int64 `bun:",pk"` + Order *Order `bun:"rel:belongs-to,join:order_id=id"` + ItemID int64 `bun:",pk"` + Item *Item `bun:"rel:belongs-to,join:item_id=id"` + } + + type OrderWrap struct { + bun.BaseModel `bun:"table:orders,alias:orders"` + *Order + + Extra bool `bun:"extra,scanonly"` + } + + db.RegisterModel((*OrderToItem)(nil)) + mustResetModel(t, ctx, db, (*Item)(nil), (*Order)(nil), (*OrderToItem)(nil)) + + items := []Item{ + {ID: 1}, + {ID: 2}, + } + _, err := db.NewInsert().Model(&items).Exec(ctx) + require.NoError(t, err) + + orders := []Order{ + {ID: 10}, + {ID: 11}, + } + _, err = db.NewInsert().Model(&orders).Exec(ctx) + require.NoError(t, err) + + orderItems := []OrderToItem{ + {OrderID: 10, ItemID: 1}, + {OrderID: 10, ItemID: 2}, + {OrderID: 11, ItemID: 2}, + } + _, err = db.NewInsert().Model(&orderItems).Exec(ctx) + require.NoError(t, err) + + var bare []Order + err = db.NewSelect().Model(&bare).Relation("Items").Order("id").Scan(ctx) + require.NoError(t, err) + require.Len(t, bare, 2) + require.Len(t, bare[0].Items, 2) + require.Len(t, bare[1].Items, 1) + + var wrapped []OrderWrap + err = db.NewSelect().Model(&wrapped).Relation("Items").Order("id").Scan(ctx) + require.NoError(t, err) + require.Len(t, wrapped, 2) + require.Len(t, wrapped[1].Items, 1) +} + type Genre struct { ID int `bun:",pk"` Name string diff --git a/schema/table.go b/schema/table.go index 8af8dd541..736d01f32 100644 --- a/schema/table.go +++ b/schema/table.go @@ -935,7 +935,7 @@ func (t *Table) m2mRelation(field *Field) *Relation { } leftRel := m2mTable.belongsToRelation(leftField) - rel.BasePKs = leftRel.JoinPKs + rel.BasePKs = baseTablePKs(t, leftRel.JoinPKs) rel.M2MBasePKs = leftRel.BasePKs rightRel := m2mTable.belongsToRelation(rightField) @@ -1057,6 +1057,21 @@ func parseRelationJoin(join []string) ([]string, []string) { return baseColumns, joinColumns } +func baseTablePKs(t *Table, pks []*Field) []*Field { + out := make([]*Field, len(pks)) + + for i, f := range pks { + if local, ok := t.FieldMap[f.Name]; ok { + out[i] = local + continue + } + + out[i] = f + } + + return out +} + //------------------------------------------------------------------------------ func softDeleteFieldUpdater(field *Field) func(fv reflect.Value, tm time.Time) error { diff --git a/schema/table_test.go b/schema/table_test.go index c298a2856..455b8348b 100644 --- a/schema/table_test.go +++ b/schema/table_test.go @@ -335,4 +335,48 @@ func TestTable(t *testing.T) { require.True(t, counter.AutoIncrement, "autoincrement") require.True(t, counter.NotNull, "not null") }) + + t.Run("m2m on embedded base", func(t *testing.T) { + type Item struct { + ID int64 `bun:",pk"` + } + + type Order struct { + BaseModel `bun:"orders"` + + ID int64 `bun:",pk"` + Items []Item `bun:"m2m:order_to_items,join:Order=Item"` + } + + type OrderToItem struct { + BaseModel `bun:"order_to_items"` + + OrderID int64 `bun:",pk"` + Order *Order `bun:"rel:belongs-to,join:order_id=id"` + ItemID int64 `bun:",pk"` + Item *Item `bun:"rel:belongs-to,join:item_id=id"` + } + + type OrderWrap struct { + BaseModel `bun:"orders,alias:orders"` + *Order + + Extra bool `bun:"extra"` + } + + dialect := newNopDialect() + dialect.Tables().Register((*OrderToItem)(nil)) + + outer := dialect.Tables().Get(reflect.TypeOf((*OrderWrap)(nil)).Elem()) + + id, ok := outer.FieldMap["id"] + require.True(t, ok) + require.Equal(t, []int{1, 1}, id.Index) + + rel, ok := outer.Relations["Items"] + require.True(t, ok) + require.Equal(t, ManyToManyRelation, rel.Type) + require.Len(t, rel.BasePKs, 1) + require.Same(t, id, rel.BasePKs[0]) + }) }