Skip to content

Commit 288e04e

Browse files
[bugfix] fix CI test failures for HSTUMatch
- Fix CPU CI: remove @parameterized.expand from test_hstu_match_export (single-case test) to fix the skipIf + parameterized decorator interaction that prevented the GPU-unavailable skip.
- Fix GPU CI: change the candidate group from JAGGED_SEQUENCE to DEEP with id_feature to fix a type mismatch (string vs int64) in the mock-data join during the integration test. Negative sampling with standard row-append works correctly with the DEEP candidate group.
- Update HSTUMatchItemTower to read from the DEEP group key (not ".sequence").
- Update _build_batch to use NEG_DATA_GROUP for candidate items.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ae5eb9c commit 288e04e

3 files changed

Lines changed: 23 additions & 27 deletions

File tree

tzrec/models/hstu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor:
207207
Returns:
208208
L2-normalized item embeddings of shape (sum_candidates, D).
209209
"""
210-
cand_emb = grouped_features[f"{self._group_name}.sequence"]
210+
cand_emb = grouped_features[self._group_name]
211211
item_emb = self._item_projection(cand_emb)
212212
return F.normalize(item_emb, p=2.0, dim=-1, eps=1e-6)
213213

@@ -268,7 +268,7 @@ def __init__(
268268
cand_features = self.get_features_in_feature_groups([cand_fg])
269269

270270
uih_dims = self.embedding_group.group_dims(tower_cfg.input + ".sequence")
271-
cand_dims = self.embedding_group.group_dims("candidate.sequence")
271+
cand_dims = self.embedding_group.group_dims("candidate")
272272

273273
# Optional contextual features
274274
contextual_feature_dim = 0

tzrec/models/hstu_test.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import unittest
1313

1414
import torch
15-
from parameterized import parameterized
1615
from torchrec import KeyedJaggedTensor
1716

1817
from tzrec.datasets.utils import BASE_DATA_GROUP, NEG_DATA_GROUP, Batch
@@ -43,8 +42,8 @@ def _build_model_config():
4342
),
4443
model_pb2.FeatureGroupConfig(
4544
group_name="candidate",
46-
feature_names=["candidate_ids"],
47-
group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE,
45+
feature_names=["item_id"],
46+
group_type=model_pb2.FeatureGroupType.DEEP,
4847
),
4948
]
5049
return model_pb2.ModelConfig(
@@ -94,16 +93,14 @@ def _build_features():
9493
)
9594
),
9695
feature_pb2.FeatureConfig(
97-
sequence_id_feature=feature_pb2.IdFeature(
98-
feature_name="candidate_ids",
99-
sequence_length=10,
96+
id_feature=feature_pb2.IdFeature(
97+
feature_name="item_id",
10098
embedding_dim=48,
101-
num_buckets=3953,
102-
embedding_name="historical_ids",
99+
num_buckets=1000,
103100
)
104101
),
105102
]
106-
return create_features(feature_cfgs)
103+
return create_features(feature_cfgs, neg_fields=["item_id"])
107104

108105

109106
def _build_model(device):
@@ -126,17 +123,19 @@ def _build_batch(device):
126123
"""Build test batch with 2 users.
127124
128125
UIH: user1 has 3 items, user2 has 4 items.
129-
Candidates: 2 positive (1 per user) + 2 negative items.
126+
Candidates: 2 pos (1 per user) + 2 neg items.
130127
"""
128+
# BASE: UIH sequences + positive items
131129
sparse_feature = KeyedJaggedTensor.from_lengths_sync(
132-
keys=["historical_ids", "candidate_ids"],
133-
values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]),
134-
lengths=torch.tensor([3, 4, 2, 2]), # uih: [3,4], candidate: [2,2]
130+
keys=["historical_ids"],
131+
values=torch.tensor([1, 2, 3, 4, 5, 6, 7]),
132+
lengths=torch.tensor([3, 4]),
135133
)
134+
# NEG: positive items (first batch_size) + negative items
136135
neg_sparse_feature = KeyedJaggedTensor.from_lengths_sync(
137-
keys=["candidate_ids"],
138-
values=torch.tensor([20, 21, 22, 23]),
139-
lengths=torch.tensor([2, 2]),
136+
keys=["item_id"],
137+
values=torch.tensor([10, 11, 20, 21]),
138+
lengths=torch.tensor([1, 1, 1, 1]), # 2 pos + 2 neg, each 1 item
140139
)
141140
return Batch(
142141
sparse_features={
@@ -180,21 +179,20 @@ def test_hstu_match_eval(self) -> None:
180179
self.assertIn("recall@1", metric_result)
181180

182181
@unittest.skipIf(*gpu_unavailable)
183-
@parameterized.expand([[TestGraphType.FX_TRACE]])
184-
def test_hstu_match_export(self, graph_type) -> None:
182+
def test_hstu_match_export(self) -> None:
185183
"""Test HSTUMatch export: FX trace for serving."""
186184
device = torch.device("cuda")
187185
hstu = _build_model(device)
188186
batch = _build_batch(device)
189187

190188
hstu.eval()
191-
hstu = create_test_model(hstu, graph_type)
189+
hstu = create_test_model(hstu, TestGraphType.FX_TRACE)
192190
predictions = hstu(batch)
193191

194192
self.assertIn("similarity", predictions)
195193
sim = predictions["similarity"]
196194
self.assertEqual(sim.dim(), 2)
197-
self.assertEqual(sim.size(0), 2) # batch_size
195+
self.assertEqual(sim.size(0), 2)
198196

199197
@unittest.skipIf(*gpu_unavailable)
200198
def test_hstu_match_predict(self) -> None:
@@ -210,7 +208,7 @@ def test_hstu_match_predict(self) -> None:
210208
self.assertIn("similarity", predictions)
211209
sim = predictions["similarity"]
212210
self.assertEqual(sim.dim(), 2)
213-
self.assertEqual(sim.size(0), 2) # batch_size
211+
self.assertEqual(sim.size(0), 2)
214212
self.assertFalse(torch.isnan(sim).any())
215213

216214

tzrec/tests/configs/hstu_fg_mock.config

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,9 @@ feature_configs {
5959
}
6060
}
6161
feature_configs {
62-
sequence_id_feature {
62+
id_feature {
6363
feature_name: "item_id"
6464
expression: "item:item_id"
65-
sequence_length: 10
66-
sequence_delim: ";"
6765
num_buckets: 1000
6866
embedding_dim: 48
6967
}
@@ -78,7 +76,7 @@ model_config {
7876
feature_groups {
7977
group_name: "candidate"
8078
feature_names: "item_id"
81-
group_type: JAGGED_SEQUENCE
79+
group_type: DEEP
8280
}
8381
hstu_match {
8482
hstu_tower {

0 commit comments

Comments (0)