Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions internal/dbtest/orm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func TestORM(t *testing.T) {
{testCompositeM2M},
{testHasOneRelationWithOpts},
{testHasManyRelationWithOpts},
{testM2MRelationOnEmbeddedBaseModel},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -667,6 +668,73 @@ func testHasManyRelationWithOpts(t *testing.T, db *bun.DB) {
}, outUsers2)
}

func testM2MRelationOnEmbeddedBaseModel(t *testing.T, db *bun.DB) {
type Item struct {
ID int64 `bun:",pk"`
}

type Order struct {
bun.BaseModel `bun:"table:orders"`

ID int64 `bun:",pk"`
Items []Item `bun:"m2m:order_to_items,join:Order=Item"`
}

type OrderToItem struct {
bun.BaseModel `bun:"table:order_to_items"`

OrderID int64 `bun:",pk"`
Order *Order `bun:"rel:belongs-to,join:order_id=id"`
ItemID int64 `bun:",pk"`
Item *Item `bun:"rel:belongs-to,join:item_id=id"`
}

type OrderWrap struct {
bun.BaseModel `bun:"table:orders,alias:orders"`
*Order

Extra bool `bun:"extra,scanonly"`
}

db.RegisterModel((*OrderToItem)(nil))
mustResetModel(t, ctx, db, (*Item)(nil), (*Order)(nil), (*OrderToItem)(nil))

items := []Item{
{ID: 1},
{ID: 2},
}
_, err := db.NewInsert().Model(&items).Exec(ctx)
require.NoError(t, err)

orders := []Order{
{ID: 10},
{ID: 11},
}
_, err = db.NewInsert().Model(&orders).Exec(ctx)
require.NoError(t, err)

orderItems := []OrderToItem{
{OrderID: 10, ItemID: 1},
{OrderID: 10, ItemID: 2},
{OrderID: 11, ItemID: 2},
}
_, err = db.NewInsert().Model(&orderItems).Exec(ctx)
require.NoError(t, err)

var bare []Order
err = db.NewSelect().Model(&bare).Relation("Items").Order("id").Scan(ctx)
require.NoError(t, err)
require.Len(t, bare, 2)
require.Len(t, bare[0].Items, 2)
require.Len(t, bare[1].Items, 1)

var wrapped []OrderWrap
err = db.NewSelect().Model(&wrapped).Relation("Items").Order("id").Scan(ctx)
require.NoError(t, err)
require.Len(t, wrapped, 2)
require.Len(t, wrapped[1].Items, 1)
}

type Genre struct {
ID int `bun:",pk"`
Name string
Expand Down
17 changes: 16 additions & 1 deletion schema/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ func (t *Table) m2mRelation(field *Field) *Relation {
}

leftRel := m2mTable.belongsToRelation(leftField)
rel.BasePKs = leftRel.JoinPKs
rel.BasePKs = baseTablePKs(t, leftRel.JoinPKs)
rel.M2MBasePKs = leftRel.BasePKs

rightRel := m2mTable.belongsToRelation(rightField)
Expand Down Expand Up @@ -1057,6 +1057,21 @@ func parseRelationJoin(join []string) ([]string, []string) {
return baseColumns, joinColumns
}

func baseTablePKs(t *Table, pks []*Field) []*Field {
out := make([]*Field, len(pks))

for i, f := range pks {
if local, ok := t.FieldMap[f.Name]; ok {
out[i] = local
continue
}

out[i] = f
}

return out
}

//------------------------------------------------------------------------------

func softDeleteFieldUpdater(field *Field) func(fv reflect.Value, tm time.Time) error {
Expand Down
44 changes: 44 additions & 0 deletions schema/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,4 +335,48 @@ func TestTable(t *testing.T) {
require.True(t, counter.AutoIncrement, "autoincrement")
require.True(t, counter.NotNull, "not null")
})

t.Run("m2m on embedded base", func(t *testing.T) {
type Item struct {
ID int64 `bun:",pk"`
}

type Order struct {
BaseModel `bun:"orders"`

ID int64 `bun:",pk"`
Items []Item `bun:"m2m:order_to_items,join:Order=Item"`
}

type OrderToItem struct {
BaseModel `bun:"order_to_items"`

OrderID int64 `bun:",pk"`
Order *Order `bun:"rel:belongs-to,join:order_id=id"`
ItemID int64 `bun:",pk"`
Item *Item `bun:"rel:belongs-to,join:item_id=id"`
}

type OrderWrap struct {
BaseModel `bun:"orders,alias:orders"`
*Order

Extra bool `bun:"extra"`
}

dialect := newNopDialect()
dialect.Tables().Register((*OrderToItem)(nil))

outer := dialect.Tables().Get(reflect.TypeOf((*OrderWrap)(nil)).Elem())

id, ok := outer.FieldMap["id"]
require.True(t, ok)
require.Equal(t, []int{1, 1}, id.Index)

rel, ok := outer.Relations["Items"]
require.True(t, ok)
require.Equal(t, ManyToManyRelation, rel.Type)
require.Len(t, rel.BasePKs, 1)
require.Same(t, id, rel.BasePKs[0])
})
}
Loading