Skip to content

Commit 24c8d5f

Browse files
[bugfix] fix CI test failures for HSTUMatch
- Fix CPU CI: remove @parameterized.expand from test_hstu_match_export (single-case test) to fix the skipIf+parameterized decorator interaction that prevented the GPU-unavailable skip.
- Fix GPU CI: change the candidate group from JAGGED_SEQUENCE to DEEP with id_feature to fix a type mismatch (string vs int64) in the mock data join during the integration test. Negative sampling with standard row-append works correctly with the DEEP candidate group.
- Update HSTUMatchItemTower to read from the DEEP group key (not ".sequence").
- Update _build_batch to use NEG_DATA_GROUP for candidate items.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ae5eb9c commit 24c8d5f

File tree

4 files changed

+32
-32
lines changed

4 files changed

+32
-32
lines changed

tzrec/models/hstu.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __init__(
7575
contextual_group_name: str = "contextual",
7676
) -> None:
7777
super().__init__(tower_config, output_dim, similarity, feature_group, features)
78+
self._pass_grouped_features = True
7879
hstu_cfg = tower_config.hstu
7980
uih_dim = sum(feature_group_dims)
8081
stu_dim = hstu_cfg.stu.embedding_dim
@@ -192,6 +193,7 @@ def __init__(
192193
# Override _group_name: parent sets it from tower_config.input ("uih"),
193194
# but item tower needs to read from the candidate feature group.
194195
self._group_name = feature_group.group_name
196+
self._pass_grouped_features = True
195197
cand_dim = sum(feature_group_dims)
196198
self._item_projection: torch.nn.Module = torch.nn.Sequential(
197199
torch.nn.Linear(cand_dim, output_dim),
@@ -207,7 +209,7 @@ def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor:
207209
Returns:
208210
L2-normalized item embeddings of shape (sum_candidates, D).
209211
"""
210-
cand_emb = grouped_features[f"{self._group_name}.sequence"]
212+
cand_emb = grouped_features[self._group_name]
211213
item_emb = self._item_projection(cand_emb)
212214
return F.normalize(item_emb, p=2.0, dim=-1, eps=1e-6)
213215

@@ -268,7 +270,7 @@ def __init__(
268270
cand_features = self.get_features_in_feature_groups([cand_fg])
269271

270272
uih_dims = self.embedding_group.group_dims(tower_cfg.input + ".sequence")
271-
cand_dims = self.embedding_group.group_dims("candidate.sequence")
273+
cand_dims = self.embedding_group.group_dims("candidate")
272274

273275
# Optional contextual features
274276
contextual_feature_dim = 0

tzrec/models/hstu_test.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import unittest
1313

1414
import torch
15-
from parameterized import parameterized
1615
from torchrec import KeyedJaggedTensor
1716

1817
from tzrec.datasets.utils import BASE_DATA_GROUP, NEG_DATA_GROUP, Batch
@@ -43,8 +42,8 @@ def _build_model_config():
4342
),
4443
model_pb2.FeatureGroupConfig(
4544
group_name="candidate",
46-
feature_names=["candidate_ids"],
47-
group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE,
45+
feature_names=["item_id"],
46+
group_type=model_pb2.FeatureGroupType.DEEP,
4847
),
4948
]
5049
return model_pb2.ModelConfig(
@@ -94,16 +93,14 @@ def _build_features():
9493
)
9594
),
9695
feature_pb2.FeatureConfig(
97-
sequence_id_feature=feature_pb2.IdFeature(
98-
feature_name="candidate_ids",
99-
sequence_length=10,
96+
id_feature=feature_pb2.IdFeature(
97+
feature_name="item_id",
10098
embedding_dim=48,
101-
num_buckets=3953,
102-
embedding_name="historical_ids",
99+
num_buckets=1000,
103100
)
104101
),
105102
]
106-
return create_features(feature_cfgs)
103+
return create_features(feature_cfgs, neg_fields=["item_id"])
107104

108105

109106
def _build_model(device):
@@ -126,17 +123,19 @@ def _build_batch(device):
126123
"""Build test batch with 2 users.
127124
128125
UIH: user1 has 3 items, user2 has 4 items.
129-
Candidates: 2 positive (1 per user) + 2 negative items.
126+
Candidates: 2 pos (1 per user) + 2 neg items.
130127
"""
128+
# BASE: UIH sequences + positive items
131129
sparse_feature = KeyedJaggedTensor.from_lengths_sync(
132-
keys=["historical_ids", "candidate_ids"],
133-
values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]),
134-
lengths=torch.tensor([3, 4, 2, 2]), # uih: [3,4], candidate: [2,2]
130+
keys=["historical_ids"],
131+
values=torch.tensor([1, 2, 3, 4, 5, 6, 7]),
132+
lengths=torch.tensor([3, 4]),
135133
)
134+
# NEG: positive items (first batch_size) + negative items
136135
neg_sparse_feature = KeyedJaggedTensor.from_lengths_sync(
137-
keys=["candidate_ids"],
138-
values=torch.tensor([20, 21, 22, 23]),
139-
lengths=torch.tensor([2, 2]),
136+
keys=["item_id"],
137+
values=torch.tensor([10, 11, 20, 21]),
138+
lengths=torch.tensor([1, 1, 1, 1]), # 2 pos + 2 neg, each 1 item
140139
)
141140
return Batch(
142141
sparse_features={
@@ -180,21 +179,20 @@ def test_hstu_match_eval(self) -> None:
180179
self.assertIn("recall@1", metric_result)
181180

182181
@unittest.skipIf(*gpu_unavailable)
183-
@parameterized.expand([[TestGraphType.FX_TRACE]])
184-
def test_hstu_match_export(self, graph_type) -> None:
182+
def test_hstu_match_export(self) -> None:
185183
"""Test HSTUMatch export: FX trace for serving."""
186184
device = torch.device("cuda")
187185
hstu = _build_model(device)
188186
batch = _build_batch(device)
189187

190188
hstu.eval()
191-
hstu = create_test_model(hstu, graph_type)
189+
hstu = create_test_model(hstu, TestGraphType.FX_TRACE)
192190
predictions = hstu(batch)
193191

194192
self.assertIn("similarity", predictions)
195193
sim = predictions["similarity"]
196194
self.assertEqual(sim.dim(), 2)
197-
self.assertEqual(sim.size(0), 2) # batch_size
195+
self.assertEqual(sim.size(0), 2)
198196

199197
@unittest.skipIf(*gpu_unavailable)
200198
def test_hstu_match_predict(self) -> None:
@@ -210,7 +208,7 @@ def test_hstu_match_predict(self) -> None:
210208
self.assertIn("similarity", predictions)
211209
sim = predictions["similarity"]
212210
self.assertEqual(sim.dim(), 2)
213-
self.assertEqual(sim.size(0), 2) # batch_size
211+
self.assertEqual(sim.size(0), 2)
214212
self.assertFalse(torch.isnan(sim).any())
215213

216214

tzrec/models/match_model.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ def __init__(
220220
self._similarity = similarity
221221
self._feature_group = feature_group
222222
self._features = features
223+
self._pass_grouped_features = False
223224

224225

225226
class MatchModel(BaseModel):
@@ -492,8 +493,9 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
492493
embedding (dict): tower output embedding.
493494
"""
494495
grouped_features = self.embedding_group(batch)
495-
return {
496-
f"{self._tower_name}_emb": getattr(self, self._tower_name)(
497-
grouped_features[self._group_name]
498-
)
499-
}
496+
tower = getattr(self, self._tower_name)
497+
if tower._pass_grouped_features:
498+
tower_input = grouped_features
499+
else:
500+
tower_input = grouped_features[self._group_name]
501+
return {f"{self._tower_name}_emb": tower(tower_input)}

tzrec/tests/configs/hstu_fg_mock.config

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,9 @@ feature_configs {
5959
}
6060
}
6161
feature_configs {
62-
sequence_id_feature {
62+
id_feature {
6363
feature_name: "item_id"
6464
expression: "item:item_id"
65-
sequence_length: 10
66-
sequence_delim: ";"
6765
num_buckets: 1000
6866
embedding_dim: 48
6967
}
@@ -78,7 +76,7 @@ model_config {
7876
feature_groups {
7977
group_name: "candidate"
8078
feature_names: "item_id"
81-
group_type: JAGGED_SEQUENCE
79+
group_type: DEEP
8280
}
8381
hstu_match {
8482
hstu_tower {

0 commit comments

Comments (0)