From 92526cf0ca3580df1eca45dcc5f86998de841a36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 1 Apr 2026 16:41:01 +0800 Subject: [PATCH 01/12] [refactor] refactor HSTUMatch to use modern STU module with jagged sequences - Replace old HSTUEncoder/SequentialTransductionUnitJagged with STUStack/STULayer - Add SequencePreprocessor for match models (UIH-only, no candidate concat) with optional contextual features, timestamps, and actions - Use JAGGED_SEQUENCE feature groups for native jagged tensor support - Separate UIH and candidate feature groups (two-tower architecture) - Update HSTUMatchTower proto to use HSTU config instead of HSTUEncoder - Add GRSequencePreprocessor to GRInputPreprocessor oneof - Fix WhichOneof field name bug in create_output_postprocessor - Remove unused HSTUEncoder class and tests from sequence.py Co-Authored-By: Claude Opus 4.6 (1M context) --- tzrec/models/hstu.py | 339 +++++++++++++++++--------- tzrec/models/hstu_test.py | 105 +++++--- tzrec/modules/gr/postprocessors.py | 2 +- tzrec/modules/gr/preprocessors.py | 223 +++++++++++++++++ tzrec/modules/sequence.py | 212 +--------------- tzrec/modules/sequence_test.py | 49 ---- tzrec/protos/models/match_model.proto | 2 - tzrec/protos/module.proto | 9 + tzrec/protos/tower.proto | 9 +- 9 files changed, 525 insertions(+), 425 deletions(-) diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py index 1ef8877a..7efc4607 100644 --- a/tzrec/models/hstu.py +++ b/tzrec/models/hstu.py @@ -13,38 +13,52 @@ import torch import torch.nn.functional as F -from torch._tensor import Tensor -from tzrec.datasets.utils import NEG_DATA_GROUP, Batch +from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.models.match_model import MatchModel, MatchTowerWoEG from tzrec.modules.embedding import EmbeddingGroup -from tzrec.modules.sequence import HSTUEncoder +from tzrec.modules.gr.positional_encoder import HSTUPositionalEncoder +from tzrec.modules.gr.postprocessors import ( + OutputPostprocessor, + create_output_postprocessor, +) +from tzrec.modules.gr.preprocessors import ( + InputPreprocessor, + create_input_preprocessor, +) +from tzrec.modules.gr.stu import STU, STULayer, STUStack +from tzrec.modules.norm import LayerNorm +from tzrec.modules.utils import init_linear_xavier_weights_zero_bias +from tzrec.ops.utils import set_static_max_seq_lens from tzrec.protos import model_pb2, simi_pb2, tower_pb2 -from tzrec.utils import config_util +from tzrec.protos.model_pb2 import ModelConfig +from tzrec.protos.models import match_model_pb2 +from tzrec.utils.config_util import config_to_kwargs +from tzrec.utils.fx_util import fx_int_item - -@torch.fx.wrap -def _update_dict_tensor( - tensor_dict: Dict[str, torch.Tensor], - new_tensor_dict: Optional[Dict[str, Optional[torch.Tensor]]], -) -> None: - if new_tensor_dict: - for k, v in new_tensor_dict.items(): - if v is not None: - tensor_dict[k] = v +torch.fx.wrap(fx_int_item) class HSTUMatchUserTower(MatchTowerWoEG): - """HSTU Match model user tower. + """HSTU Match model user tower using modern STU module. + + Processes UIH (User Interaction History) sequences through SequencePreprocessor, + HSTUPositionalEncoder, and STUStack to produce user embeddings. During training, + returns one embedding per UIH position (autoregressive). During inference, returns + the last position embedding per user for ANN retrieval. Args: - tower_config (Tower): user tower config. - output_dim (int): user output embedding dimension. - similarity (Similarity): when use COSINE similarity, - will norm the output embedding. - feature_group (FeatureGroupConfig): feature group config. + tower_config (HSTUMatchTower): user tower config with HSTU settings. + output_dim (int): user output embedding dimension (stu.embedding_dim). + similarity (Similarity): similarity method config. + feature_group (FeatureGroupConfig): uih feature group config. + feature_group_dims (list): per-feature embedding dims in the group. features (list): list of features. + model_config (ModelConfig): full model config. + contextual_feature_dim (int): dimension of each contextual feature. + max_contextual_seq_len (int): number of contextual features. + contextual_group_name (str): contextual group name in grouped features. """ def __init__( @@ -55,177 +69,262 @@ def __init__( feature_group: model_pb2.FeatureGroupConfig, feature_group_dims: List[int], features: List[BaseFeature], - model_config: model_pb2.ModelConfig, + model_config: ModelConfig, + contextual_feature_dim: int = 0, + max_contextual_seq_len: int = 0, + contextual_group_name: str = "contextual", ) -> None: super().__init__(tower_config, output_dim, similarity, feature_group, features) - self.tower_config = tower_config - encoder_config = tower_config.hstu_encoder - seq_config_dict = config_util.config_to_kwargs(encoder_config) - sequence_dim = sum(feature_group_dims) - seq_config_dict["sequence_dim"] = sequence_dim - self.seq_encoder = HSTUEncoder(**seq_config_dict) + hstu_cfg = tower_config.hstu + uih_dim = sum(feature_group_dims) + stu_dim = hstu_cfg.stu.embedding_dim + + # Preprocessor: projects UIH, handles optional contextual/actions + self._input_preprocessor: InputPreprocessor = create_input_preprocessor( + hstu_cfg.input_preprocessor, + uih_embedding_dim=uih_dim, + output_embedding_dim=stu_dim, + contextual_feature_dim=contextual_feature_dim, + max_contextual_seq_len=max_contextual_seq_len, + contextual_group_name=contextual_group_name, + ) + + # Positional encoder + pos_kwargs = config_to_kwargs(hstu_cfg.positional_encoder) + self._positional_encoder: HSTUPositionalEncoder = HSTUPositionalEncoder( + embedding_dim=stu_dim, + contextual_seq_len=self._input_preprocessor.contextual_seq_len(), + **pos_kwargs, + ) + + # STU stack + stu_kwargs = config_to_kwargs(hstu_cfg.stu) + self._stu_module: STU = STUStack( + stu_list=[STULayer(**stu_kwargs) for _ in range(hstu_cfg.attn_num_layers)], + ) + + # Output postprocessor (L2 norm or layer norm) + self._output_postprocessor: OutputPostprocessor = create_output_postprocessor( + hstu_cfg.output_postprocessor, + embedding_dim=stu_dim, + ) + + self._input_dropout_ratio: float = hstu_cfg.input_dropout_ratio def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: - """Forward the tower. + """Forward the user tower. Args: - grouped_features: Dictionary containing grouped feature tensors + grouped_features: dictionary of embedded features from EmbeddingGroup. Returns: - torch.Tensor: The output tensor from the tower + user embeddings. Training: (sum_uih_len, D) jagged, one per position. + Inference: (B, D), last position embedding per user. """ - output = self.seq_encoder(grouped_features) + # 1. Preprocess: project UIH + optional contextual/actions + ( + max_seq_len, + total_uih_len, + _, + seq_lengths, + seq_offsets, + seq_timestamps, + seq_embeddings, + num_targets, + ) = self._input_preprocessor(grouped_features) + + # 2. Positional encoding + seq_embeddings = self._positional_encoder( + max_seq_len=max_seq_len, + seq_lengths=seq_lengths, + seq_offsets=seq_offsets, + seq_timestamps=seq_timestamps, + seq_embeddings=seq_embeddings, + num_targets=num_targets, + ) + + # 3. Input dropout + STU + seq_embeddings = F.dropout( + seq_embeddings, p=self._input_dropout_ratio, training=self.training + ) + seq_embeddings = self._stu_module( + x=seq_embeddings, + x_offsets=seq_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets, + ) + + # 4. Output postprocessor + user_emb = self._output_postprocessor( + seq_embeddings=seq_embeddings, + seq_timestamps=seq_timestamps, + ) - return output + if not self.training: + # Inference: extract last position embedding per user → (B, D) + user_emb = user_emb[seq_offsets[1:] - 1] + + return user_emb class HSTUMatchItemTower(MatchTowerWoEG): """HSTU Match model item tower. + Projects candidate embeddings to STU embedding dimension and L2 normalizes. + Args: - tower_config (Tower): item tower config. - output_dim (int): item output embedding dimension. - similarity (Similarity): when use COSINE similarity, - will norm the output embedding. - feature_group (FeatureGroupConfig): feature group config. + tower_config (HSTUMatchTower): tower config. + output_dim (int): item output embedding dimension (stu.embedding_dim). + similarity (Similarity): similarity method config. + feature_group (FeatureGroupConfig): candidate feature group config. + feature_group_dims (list): per-feature embedding dims in the group. features (list): list of features. """ def __init__( self, - tower_config: tower_pb2.Tower, + tower_config: tower_pb2.HSTUMatchTower, output_dim: int, similarity: simi_pb2.Similarity, feature_group: model_pb2.FeatureGroupConfig, + feature_group_dims: List[int], features: List[BaseFeature], ) -> None: super().__init__(tower_config, output_dim, similarity, feature_group, features) - self.tower_config = tower_config + cand_dim = sum(feature_group_dims) + self._item_projection: torch.nn.Module = torch.nn.Sequential( + torch.nn.Linear(cand_dim, output_dim), + LayerNorm(output_dim), + ).apply(init_linear_xavier_weights_zero_bias) def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: - """Forward the tower. + """Forward the item tower. Args: - grouped_features: Dictionary containing grouped feature tensors + grouped_features: dictionary of embedded features from EmbeddingGroup. Returns: - torch.Tensor: The output tensor from the tower + L2-normalized item embeddings of shape (sum_candidates, D). """ - output = grouped_features[f"{self._group_name}.sequence"] - output = F.normalize(output, p=2.0, dim=1, eps=1e-6) - - return output + cand_emb = grouped_features[f"{self._group_name}.sequence"] + item_emb = self._item_projection(cand_emb) + return F.normalize(item_emb, p=2.0, dim=-1, eps=1e-6) class HSTUMatch(MatchModel): - """HSTU Match model. + """HSTU Match model for two-tower retrieval. + + Uses modern STUStack for user sequence encoding with native jagged sequences. + User tower processes UIH through SequencePreprocessor + STU. Item tower + projects and normalizes candidate embeddings. Similarity via dot product. + + Feature groups: + - "uih" (JAGGED_SEQUENCE): user interaction history + - "candidate" (JAGGED_SEQUENCE): candidate items (pos + neg) + - "contextual" (optional, DEEP/SEQUENCE): user contextual features + + Training: autoregressive — each UIH position produces a user embedding, + compared against candidates via dot product similarity. + Inference: last UIH position → user embedding for ANN retrieval. Args: model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( self, - model_config: model_pb2.ModelConfig, + model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) - assert len(model_config.feature_groups) == 1 + assert model_config.WhichOneof("model") == "hstu_match", ( + "invalid model config: %s" % self._base_model_config.WhichOneof("model") + ) + assert isinstance(self._model_config, match_model_pb2.HSTUMatch) + + tower_cfg = self._model_config.hstu_tower + set_static_max_seq_lens([tower_cfg.max_seq_len]) + self.embedding_group = EmbeddingGroup( features, list(model_config.feature_groups) ) - name_to_feature_group = {x.group_name: x for x in model_config.feature_groups} - feature_group = name_to_feature_group[self._model_config.hstu_tower.input] - - used_features = self.get_features_in_feature_groups([feature_group]) + stu_dim = tower_cfg.hstu.stu.embedding_dim + + # Resolve feature groups + name_to_fg = {x.group_name: x for x in model_config.feature_groups} + uih_fg = name_to_fg[tower_cfg.input] + cand_fg = name_to_fg.get("candidate") + assert cand_fg is not None, "HSTUMatch requires a 'candidate' feature group." + + uih_features = self.get_features_in_feature_groups([uih_fg]) + cand_features = self.get_features_in_feature_groups([cand_fg]) + + uih_dims = self.embedding_group.group_dims(tower_cfg.input + ".sequence") + cand_dims = self.embedding_group.group_dims("candidate.sequence") + + # Optional contextual features + contextual_feature_dim = 0 + max_contextual_seq_len = 0 + contextual_group_name = "contextual" + if "contextual" in name_to_fg: + ctx_group_type = self.embedding_group.group_type("contextual") + if ctx_group_type == model_pb2.SEQUENCE: + contextual_group_name = "contextual.query" + elif ctx_group_type == model_pb2.DEEP: + contextual_group_name = "contextual" + ctx_dims = self.embedding_group.group_dims(contextual_group_name) + if len(set(ctx_dims)) > 1: + raise ValueError( + "output_dim of features in contextual features_group " + f"must be same, but now {set(ctx_dims)}." + ) + contextual_feature_dim = ctx_dims[0] + max_contextual_seq_len = len(ctx_dims) self.user_tower = HSTUMatchUserTower( - self._model_config.hstu_tower, - self._model_config.output_dim, - self._model_config.similarity, - feature_group, - self.embedding_group.group_dims( - self._model_config.hstu_tower.input + ".sequence" - ), - used_features, - model_config, + tower_config=tower_cfg, + output_dim=stu_dim, + similarity=self._model_config.similarity, + feature_group=uih_fg, + feature_group_dims=uih_dims, + features=uih_features, + model_config=model_config, + contextual_feature_dim=contextual_feature_dim, + max_contextual_seq_len=max_contextual_seq_len, + contextual_group_name=contextual_group_name, ) self.item_tower = HSTUMatchItemTower( - self._model_config.hstu_tower, - self._model_config.output_dim, - self._model_config.similarity, - feature_group, - used_features, + tower_config=tower_cfg, + output_dim=stu_dim, + similarity=self._model_config.similarity, + feature_group=cand_fg, + feature_group_dims=cand_dims, + features=cand_features, ) - self.seq_tower_input = self._model_config.hstu_tower.input + self._temperature = self._model_config.temperature - def predict(self, batch: Batch) -> Dict[str, Tensor]: + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Forward the model. Args: batch (Batch): input batch data. Return: - predictions (dict): a dict of predicted result. + predictions (dict): a dict of predicted result with 'similarity' key. """ - batch_sparse_features = batch.sparse_features[NEG_DATA_GROUP] - # Get batch_size and neg_sample_size from batch_sparse_features - batch_size = batch.labels[self._label_name].shape[0] - neg_sample_size = batch_sparse_features.lengths()[batch_size] - 1 grouped_features = self.embedding_group(batch) - item_group_features = { - self.seq_tower_input + ".sequence": grouped_features[ - self.seq_tower_input + ".sequence" - ][batch_size:, : neg_sample_size + 1], - } - item_tower_emb = self.item_tower(item_group_features) - user_group_features = { - self.seq_tower_input + ".sequence": grouped_features[ - self.seq_tower_input + ".sequence" - ][:batch_size], - self.seq_tower_input + ".sequence_length": grouped_features[ - self.seq_tower_input + ".sequence_length" - ][:batch_size], - } - user_tower_emb = self.user_tower(user_group_features) - ui_sim = ( - self.simi(user_tower_emb, item_tower_emb, neg_for_each_sample=True) - / self._model_config.temperature - ) - return {"similarity": ui_sim} + user_emb = self.user_tower(grouped_features) + item_emb = self.item_tower(grouped_features) - def simi( - self, - user_emb: torch.Tensor, - item_emb: torch.Tensor, - neg_for_each_sample: bool = False, - ) -> torch.Tensor: - """Override the sim method in MatchModel to calculate similarity.""" - if self._in_batch_negative: - return torch.mm(user_emb, item_emb.T) - else: - batch_size = user_emb.size(0) - pos_item_emb = item_emb[:, 0] - neg_item_emb = item_emb[:, 1:].reshape(-1, item_emb.shape[-1]) - pos_ui_sim = torch.sum( - torch.multiply(user_emb, pos_item_emb), dim=-1, keepdim=True - ) - neg_ui_sim = None - if not neg_for_each_sample: - neg_ui_sim = torch.matmul(user_emb, neg_item_emb.transpose(0, 1)) - else: - num_neg_per_user = neg_item_emb.size(0) // batch_size - neg_size = batch_size * num_neg_per_user - neg_item_emb = neg_item_emb[:neg_size] - neg_item_emb = neg_item_emb.view(batch_size, num_neg_per_user, -1) - neg_ui_sim = torch.sum(user_emb.unsqueeze(1) * neg_item_emb, dim=-1) - return torch.cat([pos_ui_sim, neg_ui_sim], dim=-1) + ui_sim = self.sim(user_emb, item_emb) / self._temperature + return {"similarity": ui_sim} diff --git a/tzrec/models/hstu_test.py b/tzrec/models/hstu_test.py index d74a1865..d620ec19 100644 --- a/tzrec/models/hstu_test.py +++ b/tzrec/models/hstu_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2024-2025, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,24 +15,31 @@ from parameterized import parameterized from torchrec import KeyedJaggedTensor -from tzrec.datasets.utils import BASE_DATA_GROUP, NEG_DATA_GROUP, Batch +from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.features.feature import create_features from tzrec.models.hstu import HSTUMatch +from tzrec.modules.utils import Kernel from tzrec.protos import ( feature_pb2, loss_pb2, model_pb2, - seq_encoder_pb2, + module_pb2, tower_pb2, ) from tzrec.protos.models import match_model_pb2 from tzrec.utils.state_dict_util import init_parameters -from tzrec.utils.test_util import TestGraphType, create_test_model +from tzrec.utils.test_util import TestGraphType, create_test_model, gpu_unavailable -class HSTUTest(unittest.TestCase): +class HSTUMatchTest(unittest.TestCase): + """Tests for the refactored HSTUMatch model with STU and jagged sequences.""" + + @unittest.skipIf(*gpu_unavailable) @parameterized.expand([[TestGraphType.NORMAL], [TestGraphType.FX_TRACE]]) - def test_hstu(self, graph_type) -> None: + def test_hstu_match(self, graph_type) -> None: + """Test HSTUMatch with separate uih/candidate JAGGED_SEQUENCE groups.""" + device = torch.device("cuda") + feature_cfgs = [ feature_pb2.FeatureConfig( sequence_id_feature=feature_pb2.IdFeature( @@ -43,39 +50,55 @@ def test_hstu(self, graph_type) -> None: ) ), feature_pb2.FeatureConfig( - id_feature=feature_pb2.IdFeature( - feature_name="item_id", + sequence_id_feature=feature_pb2.IdFeature( + feature_name="candidate_ids", + sequence_length=10, embedding_dim=48, - num_buckets=1000, - embedding_name="item_id", + num_buckets=3953, + embedding_name="historical_ids", ) ), ] features = create_features(feature_cfgs) + feature_groups = [ model_pb2.FeatureGroupConfig( - group_name="sequence", + group_name="uih", feature_names=["historical_ids"], - group_type=model_pb2.FeatureGroupType.SEQUENCE, + group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE, + ), + model_pb2.FeatureGroupConfig( + group_name="candidate", + feature_names=["candidate_ids"], + group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE, ), ] + model_config = model_pb2.ModelConfig( feature_groups=feature_groups, hstu_match=match_model_pb2.HSTUMatch( hstu_tower=tower_pb2.HSTUMatchTower( - input="sequence", - hstu_encoder=seq_encoder_pb2.HSTUEncoder( - sequence_dim=48, - attn_dim=48, - linear_dim=48, - input="sequence", - max_seq_length=210, - num_blocks=2, - num_heads=1, - linear_activation="silu", - linear_config="uvqk", - max_output_len=0, + input="uih", + hstu=module_pb2.HSTU( + stu=module_pb2.STU( + embedding_dim=48, + num_heads=1, + hidden_dim=48, + attention_dim=48, + output_dropout_ratio=0.2, + ), + attn_num_layers=2, + positional_encoder=module_pb2.GRPositionalEncoder( + num_position_buckets=512, + ), + input_preprocessor=module_pb2.GRInputPreprocessor( + sequence_preprocessor=(module_pb2.GRSequencePreprocessor()), + ), + output_postprocessor=module_pb2.GROutputPostprocessor( + l2norm_postprocessor=(module_pb2.GRL2NormPostprocessor()), + ), ), + max_seq_len=210, ), temperature=0.05, ), @@ -85,34 +108,40 @@ def test_hstu(self, graph_type) -> None: ) ], ) + hstu = HSTUMatch( model_config=model_config, features=features, labels=["label"], - sampler_type="negative_sampler", ) - init_parameters(hstu, device=torch.device("cpu")) + init_parameters(hstu, device=device) + hstu.to(device) + hstu.set_kernel(Kernel.PYTORCH) + hstu.eval() hstu = create_test_model(hstu, graph_type) - # Create test sequence data + # Build test batch: 2 users + # UIH: user1 has 3 history items, user2 has 4 + # Candidates: user1 has 2 candidates (1 pos + 1 neg), + # user2 has 2 candidates sparse_feature = KeyedJaggedTensor.from_lengths_sync( - keys=["historical_ids"], - values=torch.tensor([1, 2, 3, 4, 5, 2, 7, 8, 9, 4, 11, 5, 13, 14, 15]), - # sequence length is - # 2, 3, 2 (neg_seq, first is pos), 2 (neg_seq, first is pos)... - lengths=torch.tensor([2, 3, 2, 2, 2, 2, 2]), + keys=["historical_ids", "candidate_ids"], + values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]), + lengths=torch.tensor( + [3, 4, 2, 2] # uih: [3,4], candidate: [2,2] + ), ) batch = Batch( - sparse_features={ - NEG_DATA_GROUP: sparse_feature, - BASE_DATA_GROUP: sparse_feature, - }, + sparse_features={BASE_DATA_GROUP: sparse_feature}, labels={"label": torch.tensor([1, 1])}, - ) + ).to(device) predictions = hstu(batch) - self.assertEqual(predictions["similarity"].size(), (5, 2)) + self.assertIn("similarity", predictions) + sim = predictions["similarity"] + self.assertEqual(sim.dim(), 2) + self.assertEqual(sim.size(0), 2) # batch_size if __name__ == "__main__": diff --git a/tzrec/modules/gr/postprocessors.py b/tzrec/modules/gr/postprocessors.py index 9c7891bf..5d6c6bca 100644 --- a/tzrec/modules/gr/postprocessors.py +++ b/tzrec/modules/gr/postprocessors.py @@ -229,7 +229,7 @@ def create_output_postprocessor( ) -> OutputPostprocessor: """Create OutputPostprocessor.""" if isinstance(postprocessor_cfg, module_pb2.GROutputPostprocessor): - postprocessor_type = postprocessor_cfg.WhichOneof("input_preprocessor") + postprocessor_type = postprocessor_cfg.WhichOneof("output_postprocessor") config_dict = config_to_kwargs(getattr(postprocessor_cfg, postprocessor_type)) else: assert len(postprocessor_cfg) == 1, ( diff --git a/tzrec/modules/gr/preprocessors.py b/tzrec/modules/gr/preprocessors.py index d00af706..aaa2928e 100644 --- a/tzrec/modules/gr/preprocessors.py +++ b/tzrec/modules/gr/preprocessors.py @@ -486,6 +486,227 @@ def contextual_seq_len(self) -> int: return self._max_contextual_seq_len +class SequencePreprocessor(InputPreprocessor): + """Preprocessor for sequence-only models without candidate concatenation. + + Processes UIH (User Interaction History) sequences with optional contextual + features, timestamps, and actions. Projects UIH embeddings to the STU + embedding dimension. Suitable for two-tower match/retrieval models where + user and item towers are independent. + + Args: + uih_embedding_dim (int): dimension of UIH input embeddings. + output_embedding_dim (int): dimension of output embeddings (STU dim). + contextual_feature_dim (int): dimension of each contextual feature. + Inferred from embedding_group at model init time. + max_contextual_seq_len (int): number of contextual features. + Inferred from embedding_group at model init time. + contextual_group_name (str): name of contextual group in grouped features. + action_encoder (Dict[str, Any]): optional ActionEncoder config. + action_mlp (Dict[str, Any]): optional action MLP config. + is_inference (bool): whether to run in inference mode. + """ + + def __init__( + self, + uih_embedding_dim: int, + output_embedding_dim: int, + contextual_feature_dim: int = 0, + max_contextual_seq_len: int = 0, + contextual_group_name: str = "contextual", + action_encoder: Optional[Dict[str, Any]] = None, + action_mlp: Optional[Dict[str, Any]] = None, + is_inference: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(is_inference=is_inference) + self._uih_embedding_dim: int = uih_embedding_dim + self._output_embedding_dim: int = output_embedding_dim + self._contextual_feature_dim: int = contextual_feature_dim + self._max_contextual_seq_len: int = max_contextual_seq_len + self._contextual_group_name: str = contextual_group_name + + # Project UIH to STU embedding dim + self._input_projection: torch.nn.Module = torch.nn.Linear( + uih_embedding_dim, output_embedding_dim + ) + + # Optional contextual feature projection + if max_contextual_seq_len > 0: + std = 1.0 * sqrt(2.0 / float(contextual_feature_dim + output_embedding_dim)) + self._batched_contextual_linear_weights = torch.nn.Parameter( + torch.empty( + ( + max_contextual_seq_len, + contextual_feature_dim, + output_embedding_dim, + ) + ).normal_(0.0, std) + ) + self._batched_contextual_linear_bias = torch.nn.Parameter( + torch.empty((max_contextual_seq_len, 1, output_embedding_dim)).fill_( + 0.0 + ) + ) + + # Optional action encoder + self._action_encoder_cfg = action_encoder + if action_encoder is not None: + contextual_embedding_dim: int = ( + max_contextual_seq_len * contextual_feature_dim + ) + self._action_encoder: ActionEncoder = create_action_encoder( + action_encoder, + is_inference=is_inference, + ) + assert action_mlp is not None, ( + "action_mlp must be set when action_encoder is set." + ) + self._action_embedding_mlp: ContextualizedMLP = create_contextualized_mlp( + action_mlp, + contextual_embedding_dim=contextual_embedding_dim, + sequential_input_dim=self._action_encoder.output_dim, + sequential_output_dim=output_embedding_dim, + is_inference=is_inference, + ) + + def forward( + self, grouped_features: Dict[str, torch.Tensor] + ) -> Tuple[ + int, + int, + int, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + """Forward the module. + + Args: + grouped_features (Dict[str, torch.Tensor]): embedding group features. + + Returns: + output_max_seq_len (int): output maximum sequence length. + output_total_uih_len (int): output total user history sequence length. + output_total_targets (int): always 0 (no candidates). + output_seq_lengths (torch.Tensor): output sequence lengths. + output_seq_offsets (torch.Tensor): output sequence offsets. + output_seq_timestamps (torch.Tensor): output sequence timestamps. + output_seq_embeddings (torch.Tensor): output sequence embeddings. + output_num_targets (torch.Tensor): always zeros (no candidates). + """ + uih_embeddings = grouped_features["uih.sequence"] + uih_seq_lengths = grouped_features["uih.sequence_length"] + max_uih_len = fx_int_item(uih_seq_lengths.max()) + total_uih_len = fx_int_item(uih_seq_lengths.sum()) + uih_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(uih_seq_lengths) + + # Project UIH to output dim + seq_embeddings = self._input_projection(uih_embeddings) + + # Optional: contextual features + contextual_input_embeddings: Optional[torch.Tensor] = None + contextual_embeddings: Optional[torch.Tensor] = None + if self._max_contextual_seq_len > 0: + contextual_input_embeddings = grouped_features[self._contextual_group_name] + contextual_embeddings = torch.baddbmm( + self._batched_contextual_linear_bias.to( + contextual_input_embeddings.dtype + ), + contextual_input_embeddings.view( + -1, + self._max_contextual_seq_len, + self._contextual_feature_dim, + ).transpose(0, 1), + self._batched_contextual_linear_weights.to( + contextual_input_embeddings.dtype + ), + ).transpose(0, 1) + + # Optional: action embeddings + if self._action_encoder_cfg is not None: + action_embeddings = self._action_encoder( + seq_actions=grouped_features["uih_action.sequence"].to(torch.int64), + max_uih_len=max_uih_len, + max_targets=0, + uih_offsets=uih_offsets, + target_offsets=uih_offsets, + total_uih_len=total_uih_len, + total_targets=0, + ).to(uih_embeddings.dtype) + action_embeddings = self._action_embedding_mlp( + seq_embeddings=action_embeddings, + seq_offsets=uih_offsets, + max_seq_len=max_uih_len, + contextual_embeddings=contextual_input_embeddings, + ) + seq_embeddings = seq_embeddings + action_embeddings + + # Timestamps: always use zeros (timestamps are optional for match models) + seq_timestamps = torch.zeros( + total_uih_len, dtype=torch.int64, device=uih_embeddings.device + ) + + output_max_seq_len = max_uih_len + output_seq_lengths = uih_seq_lengths + output_seq_offsets = uih_offsets + output_total_uih_len = total_uih_len + + # Concat contextual embeddings if present + if self._max_contextual_seq_len > 0: + seq_embeddings = concat_2D_jagged( + values_left=fx_unwrap_optional_tensor(contextual_embeddings).reshape( + -1, self._output_embedding_dim + ), + values_right=seq_embeddings, + max_len_left=self._max_contextual_seq_len, + max_len_right=max_uih_len, + offsets_left=None, + offsets_right=uih_offsets, + kernel=self.kernel(), + ) + seq_timestamps = concat_2D_jagged( + values_left=_fx_timestamp_contextual_zeros( + seq_timestamps, + uih_seq_lengths, + self._max_contextual_seq_len, + ), + values_right=seq_timestamps.unsqueeze(-1), + max_len_left=self._max_contextual_seq_len, + max_len_right=max_uih_len, + offsets_left=None, + offsets_right=uih_offsets, + kernel=self.kernel(), + ).squeeze(-1) + output_max_seq_len = max_uih_len + self._max_contextual_seq_len + output_total_uih_len = ( + total_uih_len + self._max_contextual_seq_len * uih_seq_lengths.size(0) + ) + output_seq_lengths = uih_seq_lengths + self._max_contextual_seq_len + output_seq_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + output_seq_lengths + ) + + num_targets = torch.zeros_like(output_seq_lengths) + + return ( + output_max_seq_len, + output_total_uih_len, + 0, + output_seq_lengths, + output_seq_offsets, + seq_timestamps, + seq_embeddings, + num_targets, + ) + + def contextual_seq_len(self) -> int: + """Contextual feature sequence length.""" + return self._max_contextual_seq_len + + def create_input_preprocessor( preprocessor_cfg: Union[module_pb2.GRInputPreprocessor, Dict[str, Any]], **kwargs: Any, @@ -508,5 +729,7 @@ def create_input_preprocessor( ) elif preprocessor_type == "contextual_interleave_preprocessor": return ContextualInterleavePreprocessor(**config_dict) + elif preprocessor_type == "sequence_preprocessor": + return SequencePreprocessor(**config_dict) else: raise RuntimeError(f"Unknown preprocessor type: {preprocessor_type}") diff --git a/tzrec/modules/sequence.py b/tzrec/modules/sequence.py index 880402fb..fc316521 100644 --- a/tzrec/modules/sequence.py +++ b/tzrec/modules/sequence.py @@ -10,18 +10,13 @@ # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import numpy as np import torch from torch import nn from torch.nn import functional as F -from tzrec.modules.hstu import ( - HSTUCacheState, - RelativeBucketedTimeAndPositionBasedBias, - SequentialTransductionUnitJagged, -) from tzrec.modules.mlp import MLP from tzrec.protos.seq_encoder_pb2 import SeqEncoderConfig from tzrec.utils import config_util @@ -372,211 +367,6 @@ def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor: ) # [B, (L+1)*C] -class HSTUEncoder(SequenceEncoder): - """HSTU sequence encoder. - - Args: - sequence_dim (int): sequence tensor channel dimension. - query_dim (int): query tensor channel dimension. - input(str): input feature group name. - attn_mlp (dict): target attention MLP module parameters. - """ - - def __init__( - self, - sequence_dim: int, - attn_dim: int, - linear_dim: int, - input: str, - max_seq_length: int, - pos_dropout_rate: float = 0.2, - linear_dropout_rate: float = 0.2, - attn_dropout_rate: float = 0.0, - normalization: str = "rel_bias", - linear_activation: str = "silu", - linear_config: str = "uvqk", - num_heads: int = 1, - num_blocks: int = 2, - max_output_len: int = 10, - time_bucket_size: int = 128, - **kwargs: Optional[Dict[str, Any]], - ) -> None: - super().__init__(input) - self._sequence_dim = sequence_dim - self._attn_dim = attn_dim - self._linear_dim = linear_dim - self._max_seq_length = max_seq_length - self._query_name = f"{input}.query" - self._sequence_name = f"{input}.sequence" - self._sequence_length_name = f"{input}.sequence_length" - max_output_len = max_output_len + 1 # for target - self.position_embed = nn.Embedding( - self._max_seq_length + max_output_len, self._sequence_dim - ) - self.dropout_rate = pos_dropout_rate - self.enable_relative_attention_bias = True - self.autocast_dtype = None - self._attention_layers: nn.ModuleList = nn.ModuleList( - modules=[ - SequentialTransductionUnitJagged( - embedding_dim=self._sequence_dim, - linear_hidden_dim=self._linear_dim, - attention_dim=self._attn_dim, - normalization=normalization, - linear_config=linear_config, - linear_activation=linear_activation, - num_heads=num_heads, - relative_attention_bias_module=( - RelativeBucketedTimeAndPositionBasedBias( - max_seq_len=max_seq_length + max_output_len, - num_buckets=time_bucket_size, - bucketization_fn=lambda x: ( - torch.log(torch.abs(x).clamp(min=1)) / 0.301 - ).long(), - ) - if self.enable_relative_attention_bias - else None - ), - dropout_ratio=linear_dropout_rate, - attn_dropout_ratio=attn_dropout_rate, - concat_ua=False, - ) - for _ in range(num_blocks) - ] - ) - self.register_buffer( - "_attn_mask", - torch.triu( - torch.ones( - ( - self._max_seq_length + max_output_len, - self._max_seq_length + max_output_len, - ), - dtype=torch.bool, - ), - diagonal=1, - ), - ) - self._autocast_dtype = None - - def output_dim(self) -> int: - """Output dimension of the module.""" - return self._sequence_dim - - def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor: - """Forward the module.""" - sequence = sequence_embedded[self._sequence_name] # B, N, E - sequence_length = sequence_embedded[self._sequence_length_name] # N - # max_seq_length = sequence.size(1) - float_dtype = sequence.dtype - - # Add positional embeddings and apply dropout - positions = ( - fx_arange(sequence.size(1), device=sequence.device) - .unsqueeze(0) - .expand(sequence.size(0), -1) - ) - sequence = sequence * (self._sequence_dim**0.5) + self.position_embed(positions) - sequence = F.dropout(sequence, p=self.dropout_rate, training=self.training) - sequence_mask = fx_arange( - sequence.size(1), device=sequence_length.device - ).unsqueeze(0) < sequence_length.unsqueeze(1) - sequence = sequence * sequence_mask.unsqueeze(-1).to(float_dtype) - - invalid_attn_mask = 1.0 - self._attn_mask.to(float_dtype) - sequence_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( - sequence_length - ) - sequence = torch.ops.fbgemm.dense_to_jagged(sequence, [sequence_offsets])[0] - - all_timestamps = None - jagged_x, cache_states = self.jagged_forward( - x=sequence, - x_offsets=sequence_offsets, - all_timestamps=all_timestamps, - invalid_attn_mask=invalid_attn_mask, - delta_x_offsets=None, - cache=None, - return_cache_states=False, - ) - # post processing: L2 Normalization - output_embeddings = jagged_x - output_embeddings = output_embeddings[..., : self._sequence_dim] - output_embeddings = output_embeddings / torch.clamp( - torch.linalg.norm(output_embeddings, ord=None, dim=-1, keepdim=True), - min=1e-6, - ) - if not self.training: - output_embeddings = self.get_current_embeddings( - sequence_length, output_embeddings - ) - return output_embeddings - - def jagged_forward( - self, - x: torch.Tensor, - x_offsets: torch.Tensor, - all_timestamps: Optional[torch.Tensor], - invalid_attn_mask: torch.Tensor, - delta_x_offsets: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - cache: Optional[List[HSTUCacheState]] = None, - return_cache_states: bool = False, - ) -> Tuple[torch.Tensor, List[HSTUCacheState]]: - r"""Jagged forward. - - Args: - x: (\sum_i N_i, D) x float - x_offsets: (B + 1) x int32 - all_timestamps: (B, 1 + N) x int64 - invalid_attn_mask: (B, N, N) x float, each element in {0, 1} - delta_x_offsets: offsets for x - cache: cache contents - return_cache_states: bool. True if we should return cache states. - - Returns: - x' = f(x), (\sum_i N_i, D) x float - """ - cache_states: List[HSTUCacheState] = [] - - with torch.autocast( - "cuda", - enabled=self._autocast_dtype is not None, - dtype=self._autocast_dtype or torch.float16, - ): - for i, layer in enumerate(self._attention_layers): - x, cache_states_i = layer( - x=x, - x_offsets=x_offsets, - all_timestamps=all_timestamps, - invalid_attn_mask=invalid_attn_mask, - delta_x_offsets=delta_x_offsets, - cache=cache[i] if cache is not None else None, - return_cache_states=return_cache_states, - ) - if return_cache_states: - cache_states.append(cache_states_i) - - return x, cache_states - - def get_current_embeddings( - self, - lengths: torch.Tensor, - encoded_embeddings: torch.Tensor, - ) -> torch.Tensor: - """Get the embeddings of the last past_id as the current embeds. - - Args: - lengths: (B,) x int - encoded_embeddings: (B, N, D,) x float - - Returns: - (B, D,) x float, where [i, :] == encoded_embeddings[i, lengths[i] - 1, :] - """ - offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) - indices = offsets[1:] - 1 - return encoded_embeddings[indices] - - def create_seq_encoder( seq_encoder_config: SeqEncoderConfig, group_total_dim: Dict[str, int] ) -> SequenceEncoder: diff --git a/tzrec/modules/sequence_test.py b/tzrec/modules/sequence_test.py index 96559bc1..767e145e 100644 --- a/tzrec/modules/sequence_test.py +++ b/tzrec/modules/sequence_test.py @@ -16,7 +16,6 @@ from tzrec.modules.sequence import ( DINEncoder, - HSTUEncoder, MultiWindowDINEncoder, PoolingEncoder, SelfAttentionEncoder, @@ -102,54 +101,6 @@ def test_din_encoder_padding(self, graph_type) -> None: self.assertEqual(result.size(), (4, 16)) -class HSTUEncoderTest(unittest.TestCase): - @parameterized.expand( - [[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]] - ) - def test_hstu_encoder(self, graph_type) -> None: - hstu = HSTUEncoder( - sequence_dim=16, - input="click_seq", - max_seq_length=10, - attn_dim=16, - linear_dim=16, - ) - self.assertEqual(hstu.output_dim(), 16) - hstu = create_test_module(hstu, graph_type) - embedded = { - "click_seq.sequence": torch.randn(4, 10, 16), - "click_seq.sequence_length": torch.tensor([2, 3, 4, 5]), - } - result = hstu(embedded) - if hstu.training: - self.assertEqual(result.size(), (14, 16)) - else: - self.assertEqual(result.size(), (4, 16)) - - @parameterized.expand( - [[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]] - ) - def test_hstu_encoder_padding(self, graph_type) -> None: - hstu = HSTUEncoder( - sequence_dim=16, - input="click_seq", - max_seq_length=10, - attn_dim=16, - linear_dim=16, - ) - self.assertEqual(hstu.output_dim(), 16) - hstu = create_test_module(hstu, graph_type) - embedded = { - "click_seq.sequence": torch.randn(4, 10, 16), - "click_seq.sequence_length": torch.tensor([2, 3, 4, 5]), - } - result = hstu(embedded) - if hstu.training: - self.assertEqual(result.size(), (14, 16)) - else: - self.assertEqual(result.size(), (4, 16)) - - class SimpleAttentionTest(unittest.TestCase): @parameterized.expand( [ diff --git a/tzrec/protos/models/match_model.proto b/tzrec/protos/models/match_model.proto index 11024318..95846c1f 100644 --- a/tzrec/protos/models/match_model.proto +++ b/tzrec/protos/models/match_model.proto @@ -21,8 +21,6 @@ message DSSM { message HSTUMatch { required HSTUMatchTower hstu_tower = 1; - // user and item tower output dimension - required int32 output_dim = 2; // similarity method optional Similarity similarity = 3 [default=INNER_PRODUCT]; // similarity scaling factor diff --git a/tzrec/protos/module.proto b/tzrec/protos/module.proto index 7a1a5982..08601141 100644 --- a/tzrec/protos/module.proto +++ b/tzrec/protos/module.proto @@ -195,12 +195,21 @@ message GRContextualInterleavePreprocessor { required GRContextualizedMLP content_mlp = 7; } +message GRSequencePreprocessor { + // action encoder config (optional - for models with action info) + optional GRActionEncoder action_encoder = 1; + // action embedding mlp config (required if action_encoder is set) + optional GRContextualizedMLP action_mlp = 2; +} + message GRInputPreprocessor { oneof input_preprocessor { // input preprocessor with contextual features GRContextualPreprocessor contextual_preprocessor = 20; // input preprocessor with interleave targets GRContextualInterleavePreprocessor contextual_interleave_preprocessor = 21; + // input preprocessor for sequence-only models (no candidate concat) + GRSequencePreprocessor sequence_preprocessor = 22; } } diff --git a/tzrec/protos/tower.proto b/tzrec/protos/tower.proto index 024cbb71..07a668a0 100644 --- a/tzrec/protos/tower.proto +++ b/tzrec/protos/tower.proto @@ -4,7 +4,6 @@ package tzrec.protos; import "tzrec/protos/module.proto"; import "tzrec/protos/loss.proto"; import "tzrec/protos/metric.proto"; -import "tzrec/protos/seq_encoder.proto"; message Tower { // input feature group name required string input = 1; @@ -13,10 +12,12 @@ message Tower { }; message HSTUMatchTower { - // input feature group name + // input feature group name (uih group) required string input = 1; - // hstu_encoder config - required HSTUEncoder hstu_encoder = 2; + // HSTU config (STU, positional_encoder, input_preprocessor, output_postprocessor) + required HSTU hstu = 2; + // max sequence length + optional uint32 max_seq_len = 3 [default = 100]; } message DINTower { From 4143b7285e5d2ec803499a232f570781f7fd89e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 1 Apr 2026 16:43:42 +0800 Subject: [PATCH 02/12] [refactor] refactor HSTUMatch to use modern STU module with jagged sequences - Replace old HSTUEncoder/SequentialTransductionUnitJagged with STUStack/STULayer - Add SequencePreprocessor for match models (UIH-only, no candidate concat) with optional contextual features, timestamps, and actions - Use JAGGED_SEQUENCE feature groups for native jagged tensor support - Separate UIH and candidate feature groups (two-tower architecture) - Update HSTUMatchTower proto to use HSTU config instead of HSTUEncoder - Add GRSequencePreprocessor to GRInputPreprocessor oneof - Fix WhichOneof field name bug in create_output_postprocessor - Fix item tower reading wrong feature group (_group_name override) - Remove unused HSTUEncoder class and tests from sequence.py Co-Authored-By: Claude Opus 4.6 (1M context) --- tzrec/models/hstu.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py index 7efc4607..f018769a 100644 --- a/tzrec/models/hstu.py +++ b/tzrec/models/hstu.py @@ -191,6 +191,9 @@ def __init__( features: List[BaseFeature], ) -> None: super().__init__(tower_config, output_dim, similarity, feature_group, features) + # Override _group_name: parent sets it from tower_config.input ("uih"), + # but item tower needs to read from the candidate feature group. + self._group_name = feature_group.group_name cand_dim = sum(feature_group_dims) self._item_projection: torch.nn.Module = torch.nn.Sequential( torch.nn.Linear(cand_dim, output_dim), @@ -244,7 +247,7 @@ def __init__( ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) assert model_config.WhichOneof("model") == "hstu_match", ( - "invalid model config: %s" % self._base_model_config.WhichOneof("model") + "invalid model config: %s" % model_config.WhichOneof("model") ) assert isinstance(self._model_config, match_model_pb2.HSTUMatch) From 9e0f4a79f3c38a5af75fa5211a4871fdff08a3ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 1 Apr 2026 17:56:32 +0800 Subject: [PATCH 03/12] [test] refine HSTUMatch tests with train/eval/export/predict coverage - User tower always returns (B, D) last-position embeddings (not jagged) - Add unit tests: test_hstu_match_train (forward + loss + backward), test_hstu_match_eval (metrics), test_hstu_match_export (FX trace), test_hstu_match_predict (inference) - Update hstu_fg_mock.config with new proto format (HSTU, STU, JAGGED_SEQUENCE, SequencePreprocessor) - Enable integration test_hstu_with_fg_train_eval_export (was skipped) - Add predict step to integration test Co-Authored-By: Claude Opus 4.6 (1M context) --- tzrec/models/hstu.py | 8 +- tzrec/models/hstu_test.py | 266 +++++++++++++++--------- tzrec/tests/configs/hstu_fg_mock.config | 49 +++-- tzrec/tests/match_integration_test.py | 11 +- 4 files changed, 211 insertions(+), 123 deletions(-) diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py index f018769a..5b6ea364 100644 --- a/tzrec/models/hstu.py +++ b/tzrec/models/hstu.py @@ -118,8 +118,7 @@ def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: grouped_features: dictionary of embedded features from EmbeddingGroup. Returns: - user embeddings. Training: (sum_uih_len, D) jagged, one per position. - Inference: (B, D), last position embedding per user. + user embeddings of shape (B, D), last position embedding per user. """ # 1. Preprocess: project UIH + optional contextual/actions ( @@ -160,9 +159,8 @@ def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: seq_timestamps=seq_timestamps, ) - if not self.training: - # Inference: extract last position embedding per user → (B, D) - user_emb = user_emb[seq_offsets[1:] - 1] + # Extract last position embedding per user → (B, D) + user_emb = user_emb[seq_offsets[1:] - 1] return user_emb diff --git a/tzrec/models/hstu_test.py b/tzrec/models/hstu_test.py index d620ec19..f5e48a80 100644 --- a/tzrec/models/hstu_test.py +++ b/tzrec/models/hstu_test.py @@ -15,13 +15,15 @@ from parameterized import parameterized from torchrec import KeyedJaggedTensor -from tzrec.datasets.utils import BASE_DATA_GROUP, Batch +from tzrec.datasets.utils import BASE_DATA_GROUP, NEG_DATA_GROUP, Batch from tzrec.features.feature import create_features from tzrec.models.hstu import HSTUMatch +from tzrec.models.model import TrainWrapper from tzrec.modules.utils import Kernel from tzrec.protos import ( feature_pb2, loss_pb2, + metric_pb2, model_pb2, module_pb2, tower_pb2, @@ -31,117 +33,185 @@ from tzrec.utils.test_util import TestGraphType, create_test_model, gpu_unavailable +def _build_model_config(): + """Build HSTUMatch model config for tests.""" + feature_groups = [ + model_pb2.FeatureGroupConfig( + group_name="uih", + feature_names=["historical_ids"], + group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE, + ), + model_pb2.FeatureGroupConfig( + group_name="candidate", + feature_names=["candidate_ids"], + group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE, + ), + ] + return model_pb2.ModelConfig( + feature_groups=feature_groups, + hstu_match=match_model_pb2.HSTUMatch( + hstu_tower=tower_pb2.HSTUMatchTower( + input="uih", + hstu=module_pb2.HSTU( + stu=module_pb2.STU( + embedding_dim=48, + num_heads=1, + hidden_dim=48, + attention_dim=48, + output_dropout_ratio=0.2, + ), + attn_num_layers=2, + positional_encoder=module_pb2.GRPositionalEncoder( + num_position_buckets=512, + ), + input_preprocessor=module_pb2.GRInputPreprocessor( + sequence_preprocessor=module_pb2.GRSequencePreprocessor(), + ), + output_postprocessor=module_pb2.GROutputPostprocessor( + l2norm_postprocessor=module_pb2.GRL2NormPostprocessor(), + ), + ), + max_seq_len=210, + ), + temperature=0.05, + ), + losses=[ + loss_pb2.LossConfig(softmax_cross_entropy=loss_pb2.SoftmaxCrossEntropy()) + ], + metrics=[metric_pb2.MetricConfig(recall_at_k=metric_pb2.RecallAtK(top_k=1))], + ) + + +def _build_features(): + """Build features for HSTUMatch tests.""" + feature_cfgs = [ + feature_pb2.FeatureConfig( + sequence_id_feature=feature_pb2.IdFeature( + feature_name="historical_ids", + sequence_length=210, + embedding_dim=48, + num_buckets=3953, + ) + ), + feature_pb2.FeatureConfig( + sequence_id_feature=feature_pb2.IdFeature( + feature_name="candidate_ids", + sequence_length=10, + embedding_dim=48, + num_buckets=3953, + embedding_name="historical_ids", + ) + ), + ] + return create_features(feature_cfgs) + + +def _build_model(device): + """Build HSTUMatch model on device.""" + model_config = _build_model_config() + features = _build_features() + hstu = HSTUMatch( + model_config=model_config, + features=features, + labels=["label"], + sampler_type="negative_sampler", + ) + init_parameters(hstu, device=device) + hstu.to(device) + hstu.set_kernel(Kernel.PYTORCH) + return hstu + + +def _build_batch(device): + """Build test batch with 2 users. + + UIH: user1 has 3 items, user2 has 4 items. + Candidates: 2 positive (1 per user) + 2 negative items. + """ + sparse_feature = KeyedJaggedTensor.from_lengths_sync( + keys=["historical_ids", "candidate_ids"], + values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]), + lengths=torch.tensor([3, 4, 2, 2]), # uih: [3,4], candidate: [2,2] + ) + neg_sparse_feature = KeyedJaggedTensor.from_lengths_sync( + keys=["candidate_ids"], + values=torch.tensor([20, 21, 22, 23]), + lengths=torch.tensor([2, 2]), + ) + return Batch( + sparse_features={ + BASE_DATA_GROUP: sparse_feature, + NEG_DATA_GROUP: neg_sparse_feature, + }, + labels={"label": torch.tensor([1, 1])}, + ).to(device) + + class HSTUMatchTest(unittest.TestCase): - """Tests for the refactored HSTUMatch model with STU and jagged sequences.""" + """Tests for HSTUMatch model with STU and jagged sequences.""" @unittest.skipIf(*gpu_unavailable) - @parameterized.expand([[TestGraphType.NORMAL], [TestGraphType.FX_TRACE]]) - def test_hstu_match(self, graph_type) -> None: - """Test HSTUMatch with separate uih/candidate JAGGED_SEQUENCE groups.""" + def test_hstu_match_train(self) -> None: + """Test HSTUMatch training: forward + loss + backward.""" device = torch.device("cuda") + hstu = _build_model(device) + batch = _build_batch(device) + + train_model = TrainWrapper(hstu, device=device).to(device) + total_loss, (losses, predictions, batch) = train_model(batch) + + self.assertIn("similarity", predictions) + self.assertIn("softmax_cross_entropy", losses) + self.assertTrue(total_loss.requires_grad) + self.assertFalse(torch.isnan(total_loss)) + + @unittest.skipIf(*gpu_unavailable) + def test_hstu_match_eval(self) -> None: + """Test HSTUMatch evaluation: forward + metrics.""" + device = torch.device("cuda") + hstu = _build_model(device) + batch = _build_batch(device) + + train_model = TrainWrapper(hstu, device=device).to(device) + _, (_, predictions, batch) = train_model(batch) + + hstu.update_metric(predictions, batch) + metric_result = hstu.compute_metric() + self.assertIn("recall@1", metric_result) + + @unittest.skipIf(*gpu_unavailable) + @parameterized.expand([[TestGraphType.FX_TRACE]]) + def test_hstu_match_export(self, graph_type) -> None: + """Test HSTUMatch export: FX trace for serving.""" + device = torch.device("cuda") + hstu = _build_model(device) + batch = _build_batch(device) - feature_cfgs = [ - feature_pb2.FeatureConfig( - sequence_id_feature=feature_pb2.IdFeature( - feature_name="historical_ids", - sequence_length=210, - embedding_dim=48, - num_buckets=3953, - ) - ), - feature_pb2.FeatureConfig( - sequence_id_feature=feature_pb2.IdFeature( - feature_name="candidate_ids", - sequence_length=10, - embedding_dim=48, - num_buckets=3953, - embedding_name="historical_ids", - ) - ), - ] - features = create_features(feature_cfgs) - - feature_groups = [ - model_pb2.FeatureGroupConfig( - group_name="uih", - feature_names=["historical_ids"], - group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE, - ), - model_pb2.FeatureGroupConfig( - group_name="candidate", - feature_names=["candidate_ids"], - group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE, - ), - ] - - model_config = model_pb2.ModelConfig( - feature_groups=feature_groups, - hstu_match=match_model_pb2.HSTUMatch( - hstu_tower=tower_pb2.HSTUMatchTower( - input="uih", - hstu=module_pb2.HSTU( - stu=module_pb2.STU( - embedding_dim=48, - num_heads=1, - hidden_dim=48, - attention_dim=48, - output_dropout_ratio=0.2, - ), - attn_num_layers=2, - positional_encoder=module_pb2.GRPositionalEncoder( - num_position_buckets=512, - ), - input_preprocessor=module_pb2.GRInputPreprocessor( - sequence_preprocessor=(module_pb2.GRSequencePreprocessor()), - ), - output_postprocessor=module_pb2.GROutputPostprocessor( - l2norm_postprocessor=(module_pb2.GRL2NormPostprocessor()), - ), - ), - max_seq_len=210, - ), - temperature=0.05, - ), - losses=[ - loss_pb2.LossConfig( - softmax_cross_entropy=loss_pb2.SoftmaxCrossEntropy() - ) - ], - ) - - hstu = HSTUMatch( - model_config=model_config, - features=features, - labels=["label"], - ) - init_parameters(hstu, device=device) - hstu.to(device) - hstu.set_kernel(Kernel.PYTORCH) hstu.eval() hstu = create_test_model(hstu, graph_type) + predictions = hstu(batch) - # Build test batch: 2 users - # UIH: user1 has 3 history items, user2 has 4 - # Candidates: user1 has 2 candidates (1 pos + 1 neg), - # user2 has 2 candidates - sparse_feature = KeyedJaggedTensor.from_lengths_sync( - keys=["historical_ids", "candidate_ids"], - values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]), - lengths=torch.tensor( - [3, 4, 2, 2] # uih: [3,4], candidate: [2,2] - ), - ) + self.assertIn("similarity", predictions) + sim = predictions["similarity"] + self.assertEqual(sim.dim(), 2) + self.assertEqual(sim.size(0), 2) # batch_size - batch = Batch( - sparse_features={BASE_DATA_GROUP: sparse_feature}, - labels={"label": torch.tensor([1, 1])}, - ).to(device) + @unittest.skipIf(*gpu_unavailable) + def test_hstu_match_predict(self) -> None: + """Test HSTUMatch predict: inference mode forward pass.""" + device = torch.device("cuda") + hstu = _build_model(device) + batch = _build_batch(device) + + hstu.eval() + with torch.no_grad(): + predictions = hstu.predict(batch) - predictions = hstu(batch) self.assertIn("similarity", predictions) sim = predictions["similarity"] self.assertEqual(sim.dim(), 2) self.assertEqual(sim.size(0), 2) # batch_size + self.assertFalse(torch.isnan(sim).any()) if __name__ == "__main__": diff --git a/tzrec/tests/configs/hstu_fg_mock.config b/tzrec/tests/configs/hstu_fg_mock.config index 3eb55d18..f23b82c3 100644 --- a/tzrec/tests/configs/hstu_fg_mock.config +++ b/tzrec/tests/configs/hstu_fg_mock.config @@ -31,15 +31,13 @@ data_config { dataset_type: ParquetDataset fg_mode: FG_DAG label_fields: "clk" - enable_hstu: true num_workers: 8 negative_sampler { input_path: "odps://{PROJECT}/tables/taobao_ad_feature_gl_bucketized_v1" num_sample: 128 - attr_fields: "historical_ids" - item_id_field: "historical_ids" + attr_fields: "item_id" + item_id_field: "item_id" attr_delimiter: "\t" - item_id_delim: ';' } } feature_configs { @@ -72,25 +70,40 @@ feature_configs { model_config { feature_groups { - group_name: "sequence" + group_name: "uih" feature_names: "historical_ids" - group_type: SEQUENCE + group_type: JAGGED_SEQUENCE + } + feature_groups { + group_name: "candidate" + feature_names: "item_id" + group_type: JAGGED_SEQUENCE } hstu_match { hstu_tower { - input: 'sequence' - hstu_encoder { - sequence_dim: 48 - attn_dim: 48 - linear_dim: 48 - input: "sequence" - max_seq_length: 210 - num_blocks: 2 - num_heads: 1 - linear_activation: "silu" - linear_config: "uvqk" - max_output_len: 0 + input: "uih" + hstu { + stu { + embedding_dim: 48 + num_heads: 1 + hidden_dim: 48 + attention_dim: 48 + output_dropout_ratio: 0.2 + } + attn_num_layers: 2 + positional_encoder { + num_position_buckets: 512 + } + input_preprocessor { + sequence_preprocessor { + } + } + output_postprocessor { + l2norm_postprocessor { + } + } } + max_seq_len: 210 } temperature: 0.05 } diff --git a/tzrec/tests/match_integration_test.py b/tzrec/tests/match_integration_test.py index 9333ee41..320764f0 100644 --- a/tzrec/tests/match_integration_test.py +++ b/tzrec/tests/match_integration_test.py @@ -431,14 +431,12 @@ def test_mind_train_eval_export(self): os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) ) - @unittest.skip("skip hstu match test") def test_hstu_with_fg_train_eval_export(self): self.success = utils.test_train_eval( "tzrec/tests/configs/hstu_fg_mock.config", self.test_dir, user_id="user_id", item_id="item_id", - is_hstu=True, ) if self.success: self.success = utils.test_eval( @@ -448,6 +446,15 @@ def test_hstu_with_fg_train_eval_export(self): self.success = utils.test_export( os.path.join(self.test_dir, "pipeline.config"), self.test_dir ) + if self.success: + self.success = utils.test_predict( + scripted_model_path=os.path.join(self.test_dir, "export/item"), + predict_input_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"), + predict_output_path=os.path.join(self.test_dir, "predict_result"), + reserved_columns="item_id", + output_columns="item_tower_emb", + test_dir=self.test_dir, + ) self.assertTrue(self.success) self.assertTrue( os.path.exists(os.path.join(self.test_dir, "export/user/scripted_model.pt")) From ae5eb9c6ea48ed98bc3d1056a8272dfd262e2648 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Thu, 2 Apr 2026 13:11:18 +0800 Subject: [PATCH 04/12] [cleanup] candidate as item list, remove enable_hstu/is_hstu legacy code - Change candidate feature from id_feature to sequence_id_feature (list of items per user: pos + negs) in hstu_fg_mock.config - Add combine_neg_as_candidate_sequence() to auto-detect sequence fields in the negative sampling path and combine pos+neg into per-user candidate sequence strings, replacing the old enable_hstu branch - Remove enable_hstu field from data.proto - Remove enable_hstu branch and process_hstu_seq_data/process_hstu_neg_sample from dataset.py and utils.py - Remove is_hstu parameter and HSTUIdMockInput from tests/utils.py - Add @unittest.skipIf(*gpu_unavailable) to HSTU integration test - Add tests for combine_neg_as_candidate_sequence in utils_test.py Co-Authored-By: Claude Opus 4.6 (1M context) --- tzrec/datasets/dataset.py | 61 ++++------- tzrec/datasets/utils.py | 138 +++++------------------- tzrec/datasets/utils_test.py | 91 ++++------------ tzrec/protos/data.proto | 3 - tzrec/tests/configs/hstu_fg_mock.config | 5 +- tzrec/tests/match_integration_test.py | 2 + tzrec/tests/utils.py | 67 ++---------- 7 files changed, 84 insertions(+), 283 deletions(-) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index dd3cd8aa..604b89aa 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -30,8 +30,7 @@ CKPT_SOURCE_ID, Batch, RecordBatchTensor, - process_hstu_neg_sample, - process_hstu_seq_data, + combine_neg_as_candidate_sequence, remove_nullable, ) from tzrec.features.feature import BaseFeature @@ -177,7 +176,6 @@ def __init__( self._reserved_columns = reserved_columns or [] self._mode = mode self._debug_level = debug_level - self._enable_hstu = data_config.enable_hstu self.sampler_type = ( self._data_config.WhichOneof("sampler") if self._data_config.HasField("sampler") @@ -238,6 +236,14 @@ def __init__( self._sampler = None self._sampler_inited = False + # Build mapping of field_name → sequence_delim for candidate sequence + # auto-detection during negative sampling. + self._seq_field_delims: Dict[str, str] = {} + for feature in features: + if hasattr(feature, "sequence_delim") and feature.sequence_delim: + for input_name in feature.inputs: + self._seq_field_delims[input_name] = feature.sequence_delim + self._reader = None def launch_sampler_cluster( @@ -367,49 +373,20 @@ def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch: input_data = _expand_tdm_sample( input_data, pos_sampled, neg_sampled, self._data_config ) - elif self._enable_hstu: - seq_attr = self._sampler._item_id_field - - ( - input_data_k_split, - input_data_k_split_slice, - pre_seq_filter_reshaped_joined, - ) = process_hstu_seq_data( - input_data=input_data, - seq_attr=seq_attr, - seq_str_delim=self._sampler.item_id_delim, - ) - if self._mode == Mode.TRAIN: - # Training using all possible target items - input_data[seq_attr] = input_data_k_split_slice - elif self._mode == Mode.EVAL: - # Evaluation using the last item for previous sequence - input_data[seq_attr] = input_data_k_split.values.take( - pa.array(input_data_k_split.offsets.to_numpy()[1:] - 1) - ) - sampled = self._sampler.get(input_data) - # To keep consistent with other process, use two functions - for k, v in sampled.items(): - if k in input_data: - combined = process_hstu_neg_sample( - input_data, - v, - self._sampler._num_sample, - self._sampler.item_id_delim, - seq_attr, - ) - # Combine here to make embddings of both user sequence - # and target item are the same - input_data[k] = pa.concat_arrays( - [pre_seq_filter_reshaped_joined, combined] - ) - else: - input_data[k] = v else: sampled = self._sampler.get(input_data) for k, v in sampled.items(): if k in input_data: - input_data[k] = pa.concat_arrays([input_data[k], v]) + seq_delim = self._seq_field_delims.get(k) + if seq_delim is not None: + input_data[k] = combine_neg_as_candidate_sequence( + input_data[k], + v, + self._sampler._num_sample, + seq_delim, + ) + else: + input_data[k] = pa.concat_arrays([input_data[k], v]) else: input_data[k] = v diff --git a/tzrec/datasets/utils.py b/tzrec/datasets/utils.py index 9009c762..0ad5bff5 100644 --- a/tzrec/datasets/utils.py +++ b/tzrec/datasets/utils.py @@ -503,129 +503,47 @@ def to_dict( return tensor_dict -def process_hstu_seq_data( - input_data: Dict[str, pa.Array], - seq_attr: str, - seq_str_delim: str, -) -> Tuple[pa.Array, pa.Array, pa.Array]: - """Process sequence data for HSTU match model. - - Args: - input_data: Dictionary containing input arrays - seq_attr: Name of the sequence attribute field - seq_str_delim: Delimiter used to separate sequence items - - Returns: - Tuple containing: - - input_data_k_split: pa.Array, Original sequence items - - input_data_k_split_slice: pa.Array, Target items for autoregressive training - - pre_seq_filter_reshaped_joined: pa.Array, - Training sequence for autoregressive training - """ - # default sequence data is string - if pa.types.is_string(input_data[seq_attr].type): - input_data_k_split = pc.split_pattern(input_data[seq_attr], seq_str_delim) - # Get target items for training for autoregressive training - # Example: [1,2,3,4,5] -> [2,3,4,5] - input_data_k_split_slice = pc.list_flatten( - pc.list_slice(input_data_k_split, start=1) - ) - - # Directly extract the training sequence for autoregressive training - # Operation target example: [1,2,3,4,5] -> [1,2,3,4] - # (corresponding target: [2,3,4,5]) - # As this can not be achieved by pyarrow.compute, and for loop is costly - # we need to do this using pa.ListArray.from_arrays using offsets - # 1. transfer to numpy and filter out the last item - pre_seq = pc.list_flatten(input_data_k_split).to_numpy(zero_copy_only=False) - # Mark last items of each seq with '-1' - pre_seq[input_data_k_split.offsets.to_numpy()[1:] - 1] = "-1" - # Filter out -1 marker elements - mask = pre_seq != "-1" - pre_seq_filter = pre_seq[mask] - # 2. create offsets for reshaping filtered sequence - # The offsets should be created extract the training sequence - # Example: if the original offsets are [0,2,5,9], after filter, - # for the offsets should be [0, 1, 3, 6] - # that is, [0] + [2-1, 5-2, 9-3] - pre_seq_filter_offsets = pa.array( - np.concatenate( - [ - np.array([0]), - input_data_k_split.offsets[1:].to_numpy(zero_copy_only=False) - - np.arange( - 1, - len( - input_data_k_split.offsets[1:].to_numpy( - zero_copy_only=False - ) - ) - + 1, - ), - ] - ) - ) - pre_seq_filter_reshaped = pa.ListArray.from_arrays( - pre_seq_filter_offsets, pre_seq_filter - ) - # Join filtered sequence with delimiter - pre_seq_filter_reshaped_joined = pc.binary_join( - pre_seq_filter_reshaped, seq_str_delim - ) - - return ( - input_data_k_split, - input_data_k_split_slice, - pre_seq_filter_reshaped_joined, - ) - - -def process_hstu_neg_sample( - input_data: Dict[str, pa.Array], - v: pa.Array, +def combine_neg_as_candidate_sequence( + pos_data: pa.Array, + neg_data: pa.Array, neg_sample_num: int, - seq_str_delim: str, - seq_attr: str, + seq_delim: str, ) -> pa.Array: - """Process negative samples for HSTU match model. + """Combine positive and negative items into candidate sequences. + + For each sample, joins the positive item with its negative items using the + sequence delimiter. Used when candidate features are sequence_id_features + in a JAGGED_SEQUENCE group. Args: - input_data: Dict[str, pa.Array], Dictionary containing input arrays - v: pa.Array, negative samples. - neg_sample_num: int, number of negative samples. - seq_str_delim: str, delimiter for sequence string. - seq_attr: str, attribute name of sequence. + pos_data: positive item IDs, one per sample. Shape: (B,). + neg_data: negative item IDs. Shape: (B * neg_sample_num,). + neg_sample_num: number of negative samples per positive. + seq_delim: delimiter for joining items into a sequence string. Returns: - pa.Array: Processed negative samples + pa.Array of strings, each containing "pos;neg1;neg2;..." per sample. + + Example: + pos_data = ["1", "2"] + neg_data = ["3", "4", "5", "6"] + neg_sample_num = 2 + seq_delim = ";" + result = ["1;3;4", "2;5;6"] """ - # The goal is to make neg samples concat to the training sequence - # Example: - # input_data[seq_attr] = ["1;2;3"] - # neg_sample_num = 2 - # v = [4,5,6,7,8,9] - # then the output should be [[1,4,5], [2,6,7], [3,8,9]] - v_str = v.cast(pa.string()) - filtered_v_offsets = pa.array( + neg_str = neg_data.cast(pa.string()) + neg_offsets = pa.array( np.concatenate( [ np.array([0]), - np.arange(neg_sample_num, len(v_str) + 1, neg_sample_num), + np.arange(neg_sample_num, len(neg_str) + 1, neg_sample_num), ] ) ) - # Reshape v for each input_data[seq_attr] - # Example:[4,5,6,7,8,9] -> [[4,5], [6,7], [8,9]] - filtered_v_palist = pa.ListArray.from_arrays(filtered_v_offsets, v_str) - # Using string for join, as not found operation for ListArray achieving this - # Example: [[4,5], [6,7], [8,9]] -> ["4;5", "6;7", "8;9"] - sampled_joined = pc.binary_join(filtered_v_palist, seq_str_delim) - # Combine training sequence and target items - # Example: ["1;2;3"] + ["4;5", "6;7", "8;9"] - # -> ["1;4;5", "2;6;7", "3;8;9"] - return pc.binary_join_element_wise( - input_data[seq_attr], sampled_joined, seq_str_delim - ) + neg_lists = pa.ListArray.from_arrays(neg_offsets, neg_str) + neg_joined = pc.binary_join(neg_lists, seq_delim) + pos_str = pos_data.cast(pa.string()) + return pc.binary_join_element_wise(pos_str, neg_joined, seq_delim) def calc_slice_position( diff --git a/tzrec/datasets/utils_test.py b/tzrec/datasets/utils_test.py index 46ddb8b0..91125135 100644 --- a/tzrec/datasets/utils_test.py +++ b/tzrec/datasets/utils_test.py @@ -14,16 +14,14 @@ import numpy as np import pyarrow as pa -import pyarrow.compute as pc from tzrec.datasets.utils import ( _normalize_type_str, calc_remaining_intervals, calc_slice_intervals, calc_slice_position, + combine_neg_as_candidate_sequence, get_input_fields_proto, - process_hstu_neg_sample, - process_hstu_seq_data, ) from tzrec.protos import data_pb2 from tzrec.protos.data_pb2 import FieldType @@ -206,80 +204,33 @@ def test_calc_slice_intervals_topology_change(self): self.assertEqual(total_rows, 400) # All remaining rows accounted for - def test_process_hstu_seq_data(self): - """Test processing sequence data for HSTU match model.""" - input_data = {"sequence": pa.array(["1;2;3;4", "5;6;7;8", "9;10;11;12"])} + def test_combine_neg_as_candidate_sequence(self): + """Test combining positive and negative items into candidate sequences.""" + pos_data = pa.array(["1", "2", "3"]) + neg_data = pa.array(["101", "102", "103", "104", "105", "106"]) - split, slice_result, training_seq = process_hstu_seq_data( - input_data=input_data, seq_attr="sequence", seq_str_delim=";" - ) - - # Verify results - # Test original split sequences - expected_split_values = [ - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", - "10", - "11", - "12", - ] - self.assertEqual(pc.list_flatten(split).to_pylist(), expected_split_values) - - # Test sliced sequences (target items) - expected_slice_values = ["2", "3", "4", "6", "7", "8", "10", "11", "12"] - self.assertEqual(slice_result.to_pylist(), expected_slice_values) - - # Test training sequences - expected_training_seqs = ["1;2;3", "5;6;7", "9;10;11"] - self.assertEqual(training_seq.to_pylist(), expected_training_seqs) - - def test_process_hstu_neg_sample(self): - """Test processing negative samples for HSTU match model.""" - input_data = {"sequence": pa.array(["1", "2", "3"])} - neg_samples = pa.array(["101", "102", "103", "104", "105", "106"]) - - result = process_hstu_neg_sample( - input_data=input_data, - v=neg_samples, + result = combine_neg_as_candidate_sequence( + pos_data=pos_data, + neg_data=neg_data, neg_sample_num=2, - seq_str_delim=";", - seq_attr="sequence", + seq_delim=";", ) + expected = ["1;101;102", "2;103;104", "3;105;106"] + self.assertEqual(result.to_pylist(), expected) - expected_results = [ - "1;101;102", - "2;103;104", - "3;105;106", - ] - self.assertEqual(result.to_pylist(), expected_results) - - def test_process_hstu_neg_sample_with_different_delim(self): - """Test negative sampling with different delimiter.""" - input_data = {"sequence": pa.array(["1", "2", "3"])} + def test_combine_neg_as_candidate_sequence_different_delim(self): + """Test candidate sequence combination with different delimiter.""" + pos_data = pa.array(["1", "2"]) + neg_data = pa.array(["10", "20", "30", "40"]) - neg_samples = pa.array(["101", "102", "103", "104", "105", "106"]) - - result = process_hstu_neg_sample( - input_data=input_data, - v=neg_samples, + result = combine_neg_as_candidate_sequence( + pos_data=pos_data, + neg_data=neg_data, neg_sample_num=2, - seq_str_delim="|", - seq_attr="sequence", + seq_delim="|", ) - - expected_results = [ - "1|101|102", - "2|103|104", - "3|105|106", - ] - self.assertEqual(result.to_pylist(), expected_results) + expected = ["1|10|20", "2|30|40"] + self.assertEqual(result.to_pylist(), expected) def test_normalize_type_str_basic_types(self): """Test normalizing basic types.""" diff --git a/tzrec/protos/data.proto b/tzrec/protos/data.proto index 6012e5c2..1db03d62 100644 --- a/tzrec/protos/data.proto +++ b/tzrec/protos/data.proto @@ -132,9 +132,6 @@ message DataConfig { // fg run mode. optional FgMode fg_mode = 20 [default = FG_NONE]; - // hstu enable - optional bool enable_hstu = 21 [default = false]; - // whether to shuffle data optional bool shuffle = 22 [default = false]; diff --git a/tzrec/tests/configs/hstu_fg_mock.config b/tzrec/tests/configs/hstu_fg_mock.config index f23b82c3..a755f673 100644 --- a/tzrec/tests/configs/hstu_fg_mock.config +++ b/tzrec/tests/configs/hstu_fg_mock.config @@ -59,12 +59,13 @@ feature_configs { } } feature_configs { - id_feature { + sequence_id_feature { feature_name: "item_id" expression: "item:item_id" + sequence_length: 10 + sequence_delim: ";" num_buckets: 1000 embedding_dim: 48 - embedding_name: "item_id" } } diff --git a/tzrec/tests/match_integration_test.py b/tzrec/tests/match_integration_test.py index 320764f0..31138e2c 100644 --- a/tzrec/tests/match_integration_test.py +++ b/tzrec/tests/match_integration_test.py @@ -15,6 +15,7 @@ import unittest from tzrec.tests import utils +from tzrec.utils.test_util import gpu_unavailable class MatchIntegrationTest(unittest.TestCase): @@ -431,6 +432,7 @@ def test_mind_train_eval_export(self): os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) ) + @unittest.skipIf(*gpu_unavailable) def test_hstu_with_fg_train_eval_export(self): self.success = utils.test_train_eval( "tzrec/tests/configs/hstu_fg_mock.config", diff --git a/tzrec/tests/utils.py b/tzrec/tests/utils.py index 77022083..7b5cabd3 100644 --- a/tzrec/tests/utils.py +++ b/tzrec/tests/utils.py @@ -128,36 +128,6 @@ def create_data(self, num_rows: int, has_null: bool = True) -> pa.Array: return pa.array(data) -class HSTUIdMockInput(MockInput): - """Mock sparse id input data class.""" - - def __init__( - self, - name: str, - is_multi: bool = False, - num_ids: Optional[int] = None, - vocab_list: Optional[List[str]] = None, - multival_sep: str = chr(3), - ) -> None: - super().__init__(name) - self.is_multi = is_multi - self.num_ids = num_ids - self.vocab_list = vocab_list - self.multival_sep = multival_sep - - def create_data(self, num_rows: int, has_null: bool = True) -> pa.Array: - """Create mock data.""" - # string - # num_multi_rows = random.randint(num_rows // 3, 2 * num_rows // 3) - num_multi_id = 3 - data_multi = _create_random_id_data( - (num_rows, num_multi_id), self.num_ids, self.vocab_list - ).astype(str) - data_multi = list(map(lambda x: self.multival_sep.join(x), data_multi)) - random.shuffle(data_multi) - return pa.array(data_multi) - - class SeqIdMockInput(MockInput): """Mock sparse id sequence input data class.""" @@ -694,7 +664,6 @@ def build_mock_input_with_fg( features: List[BaseFeature], user_id: str = "", item_id: str = "", - is_hstu: bool = False, ) -> Dict[str, MockInput]: """Build mock input instance list with fg from features.""" inputs = defaultdict(dict) @@ -818,23 +787,14 @@ def build_mock_input_with_fg( if isinstance(inputs[side][sub_name], IdMockInput): inputs[side][sub_name].is_multi = False else: - if is_hstu: - # hstu require number of sequence item is over 2 - inputs[side][input_name] = HSTUIdMockInput( - input_name, - is_multi=True, - num_ids=feature.num_embeddings, - multival_sep=feature.sequence_delim, - ) - else: - inputs[side][input_name] = IdMockInput( - input_name, - is_multi=True, - num_ids=10 - if isinstance(feature, CustomFeature) - else feature.num_embeddings, - multival_sep=feature.sequence_delim, - ) + inputs[side][input_name] = IdMockInput( + input_name, + is_multi=True, + num_ids=10 + if isinstance(feature, CustomFeature) + else feature.num_embeddings, + multival_sep=feature.sequence_delim, + ) return inputs["user"], inputs["item"] @@ -844,7 +804,6 @@ def load_config_for_test( user_id: str = "", item_id: str = "", cate_id: str = "", - is_hstu: bool = False, num_rows: Optional[int] = None, ) -> EasyRecConfig: """Modify pipeline config for integration tests.""" @@ -881,9 +840,7 @@ def load_config_for_test( num_parts=num_parts, ) else: - user_inputs, item_inputs = build_mock_input_with_fg( - features, user_id, item_id, is_hstu - ) + user_inputs, item_inputs = build_mock_input_with_fg(features, user_id, item_id) _, item_t = create_mock_data( os.path.join(test_dir, "item_data"), item_inputs, @@ -979,7 +936,7 @@ def load_config_for_test( item_id, # hstu only uses item_id as negative sample, \ # as sampler_config.attr_fields is sequence - neg_fields=[item_id] if is_hstu else list(sampler_config.attr_fields), + neg_fields=list(sampler_config.attr_fields), attr_delimiter=sampler_config.attr_delimiter, num_rows=data_config.batch_size * num_parts * 4, ) @@ -992,7 +949,7 @@ def load_config_for_test( os.path.join(test_dir, "item_gl"), item_inputs, item_id, - neg_fields=[item_id] if is_hstu else list(sampler_config.attr_fields), + neg_fields=list(sampler_config.attr_fields), attr_delimiter=sampler_config.attr_delimiter, num_rows=data_config.batch_size * num_parts * 4, ) @@ -1032,7 +989,6 @@ def test_train_eval( user_id: str = "", item_id: str = "", cate_id: str = "", - is_hstu: bool = False, env_str: str = "", num_rows: Optional[int] = None, ) -> bool: @@ -1043,7 +999,6 @@ def test_train_eval( user_id, item_id, cate_id, - is_hstu, num_rows=num_rows, ) From 24c8d5f2e46c99f533de43604608536dd7316500 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Fri, 3 Apr 2026 10:40:40 +0800 Subject: [PATCH 05/12] [bugfix] fix CI test failures for HSTUMatch - Fix CPU CI: remove @parameterized.expand from test_hstu_match_export (single-case test) to fix skipIf+parameterized decorator interaction that prevented GPU unavailable skip - Fix GPU CI: change candidate group from JAGGED_SEQUENCE to DEEP with id_feature to fix type mismatch (string vs int64) in mock data join during integration test. Negative sampling with standard row-append works correctly with DEEP candidate group. - Update HSTUMatchItemTower to read from DEEP group key (not .sequence) - Update _build_batch to use NEG_DATA_GROUP for candidate items Co-Authored-By: Claude Opus 4.6 (1M context) --- tzrec/models/hstu.py | 6 ++-- tzrec/models/hstu_test.py | 40 ++++++++++++------------- tzrec/models/match_model.py | 12 ++++---- tzrec/tests/configs/hstu_fg_mock.config | 6 ++-- 4 files changed, 32 insertions(+), 32 deletions(-) diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py index 5b6ea364..c425a974 100644 --- a/tzrec/models/hstu.py +++ b/tzrec/models/hstu.py @@ -75,6 +75,7 @@ def __init__( contextual_group_name: str = "contextual", ) -> None: super().__init__(tower_config, output_dim, similarity, feature_group, features) + self._pass_grouped_features = True hstu_cfg = tower_config.hstu uih_dim = sum(feature_group_dims) stu_dim = hstu_cfg.stu.embedding_dim @@ -192,6 +193,7 @@ def __init__( # Override _group_name: parent sets it from tower_config.input ("uih"), # but item tower needs to read from the candidate feature group. self._group_name = feature_group.group_name + self._pass_grouped_features = True cand_dim = sum(feature_group_dims) self._item_projection: torch.nn.Module = torch.nn.Sequential( torch.nn.Linear(cand_dim, output_dim), @@ -207,7 +209,7 @@ def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: Returns: L2-normalized item embeddings of shape (sum_candidates, D). """ - cand_emb = grouped_features[f"{self._group_name}.sequence"] + cand_emb = grouped_features[self._group_name] item_emb = self._item_projection(cand_emb) return F.normalize(item_emb, p=2.0, dim=-1, eps=1e-6) @@ -268,7 +270,7 @@ def __init__( cand_features = self.get_features_in_feature_groups([cand_fg]) uih_dims = self.embedding_group.group_dims(tower_cfg.input + ".sequence") - cand_dims = self.embedding_group.group_dims("candidate.sequence") + cand_dims = self.embedding_group.group_dims("candidate") # Optional contextual features contextual_feature_dim = 0 diff --git a/tzrec/models/hstu_test.py b/tzrec/models/hstu_test.py index f5e48a80..840801de 100644 --- a/tzrec/models/hstu_test.py +++ b/tzrec/models/hstu_test.py @@ -12,7 +12,6 @@ import unittest import torch -from parameterized import parameterized from torchrec import KeyedJaggedTensor from tzrec.datasets.utils import BASE_DATA_GROUP, NEG_DATA_GROUP, Batch @@ -43,8 +42,8 @@ def _build_model_config(): ), model_pb2.FeatureGroupConfig( group_name="candidate", - feature_names=["candidate_ids"], - group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE, + feature_names=["item_id"], + group_type=model_pb2.FeatureGroupType.DEEP, ), ] return model_pb2.ModelConfig( @@ -94,16 +93,14 @@ def _build_features(): ) ), feature_pb2.FeatureConfig( - sequence_id_feature=feature_pb2.IdFeature( - feature_name="candidate_ids", - sequence_length=10, + id_feature=feature_pb2.IdFeature( + feature_name="item_id", embedding_dim=48, - num_buckets=3953, - embedding_name="historical_ids", + num_buckets=1000, ) ), ] - return create_features(feature_cfgs) + return create_features(feature_cfgs, neg_fields=["item_id"]) def _build_model(device): @@ -126,17 +123,19 @@ def _build_batch(device): """Build test batch with 2 users. UIH: user1 has 3 items, user2 has 4 items. - Candidates: 2 positive (1 per user) + 2 negative items. + Candidates: 2 pos (1 per user) + 2 neg items. """ + # BASE: UIH sequences + positive items sparse_feature = KeyedJaggedTensor.from_lengths_sync( - keys=["historical_ids", "candidate_ids"], - values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]), - lengths=torch.tensor([3, 4, 2, 2]), # uih: [3,4], candidate: [2,2] + keys=["historical_ids"], + values=torch.tensor([1, 2, 3, 4, 5, 6, 7]), + lengths=torch.tensor([3, 4]), ) + # NEG: positive items (first batch_size) + negative items neg_sparse_feature = KeyedJaggedTensor.from_lengths_sync( - keys=["candidate_ids"], - values=torch.tensor([20, 21, 22, 23]), - lengths=torch.tensor([2, 2]), + keys=["item_id"], + values=torch.tensor([10, 11, 20, 21]), + lengths=torch.tensor([1, 1, 1, 1]), # 2 pos + 2 neg, each 1 item ) return Batch( sparse_features={ @@ -180,21 +179,20 @@ def test_hstu_match_eval(self) -> None: self.assertIn("recall@1", metric_result) @unittest.skipIf(*gpu_unavailable) - @parameterized.expand([[TestGraphType.FX_TRACE]]) - def test_hstu_match_export(self, graph_type) -> None: + def test_hstu_match_export(self) -> None: """Test HSTUMatch export: FX trace for serving.""" device = torch.device("cuda") hstu = _build_model(device) batch = _build_batch(device) hstu.eval() - hstu = create_test_model(hstu, graph_type) + hstu = create_test_model(hstu, TestGraphType.FX_TRACE) predictions = hstu(batch) self.assertIn("similarity", predictions) sim = predictions["similarity"] self.assertEqual(sim.dim(), 2) - self.assertEqual(sim.size(0), 2) # batch_size + self.assertEqual(sim.size(0), 2) @unittest.skipIf(*gpu_unavailable) def test_hstu_match_predict(self) -> None: @@ -210,7 +208,7 @@ def test_hstu_match_predict(self) -> None: self.assertIn("similarity", predictions) sim = predictions["similarity"] self.assertEqual(sim.dim(), 2) - self.assertEqual(sim.size(0), 2) # batch_size + self.assertEqual(sim.size(0), 2) self.assertFalse(torch.isnan(sim).any()) diff --git a/tzrec/models/match_model.py b/tzrec/models/match_model.py index 22b86dc9..73130fef 100644 --- a/tzrec/models/match_model.py +++ b/tzrec/models/match_model.py @@ -220,6 +220,7 @@ def __init__( self._similarity = similarity self._feature_group = feature_group self._features = features + self._pass_grouped_features = False class MatchModel(BaseModel): @@ -492,8 +493,9 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: embedding (dict): tower output embedding. """ grouped_features = self.embedding_group(batch) - return { - f"{self._tower_name}_emb": getattr(self, self._tower_name)( - grouped_features[self._group_name] - ) - } + tower = getattr(self, self._tower_name) + if tower._pass_grouped_features: + tower_input = grouped_features + else: + tower_input = grouped_features[self._group_name] + return {f"{self._tower_name}_emb": tower(tower_input)} diff --git a/tzrec/tests/configs/hstu_fg_mock.config b/tzrec/tests/configs/hstu_fg_mock.config index a755f673..aa09d921 100644 --- a/tzrec/tests/configs/hstu_fg_mock.config +++ b/tzrec/tests/configs/hstu_fg_mock.config @@ -59,11 +59,9 @@ feature_configs { } } feature_configs { - sequence_id_feature { + id_feature { feature_name: "item_id" expression: "item:item_id" - sequence_length: 10 - sequence_delim: ";" num_buckets: 1000 embedding_dim: 48 } @@ -78,7 +76,7 @@ model_config { feature_groups { group_name: "candidate" feature_names: "item_id" - group_type: JAGGED_SEQUENCE + group_type: DEEP } hstu_match { hstu_tower { From bed429bbaddb8ab65c392fac547012012735b81c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Fri, 3 Apr 2026 11:32:34 +0800 Subject: [PATCH 06/12] [refactor] rename SequencePreprocessor to UIHPreprocessor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rename to better reflect the preprocessor's purpose: it processes UIH (User Interaction History) sequences only, without candidate concat. This parallels the codebase convention where "uih" is the standard group name for user history features. - SequencePreprocessor → UIHPreprocessor - GRSequencePreprocessor → GRUIHPreprocessor - sequence_preprocessor → uih_preprocessor (proto oneof field) Co-Authored-By: Claude Opus 4.6 (1M context) --- tzrec/models/hstu.py | 4 ++-- tzrec/models/hstu_test.py | 2 +- tzrec/modules/gr/preprocessors.py | 6 +++--- tzrec/protos/module.proto | 4 ++-- tzrec/tests/configs/hstu_fg_mock.config | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py index c425a974..64de54b7 100644 --- a/tzrec/models/hstu.py +++ b/tzrec/models/hstu.py @@ -43,7 +43,7 @@ class HSTUMatchUserTower(MatchTowerWoEG): """HSTU Match model user tower using modern STU module. - Processes UIH (User Interaction History) sequences through SequencePreprocessor, + Processes UIH (User Interaction History) sequences through UIHPreprocessor, HSTUPositionalEncoder, and STUStack to produce user embeddings. During training, returns one embedding per UIH position (autoregressive). During inference, returns the last position embedding per user for ANN retrieval. @@ -218,7 +218,7 @@ class HSTUMatch(MatchModel): """HSTU Match model for two-tower retrieval. Uses modern STUStack for user sequence encoding with native jagged sequences. - User tower processes UIH through SequencePreprocessor + STU. Item tower + User tower processes UIH through UIHPreprocessor + STU. Item tower projects and normalizes candidate embeddings. Similarity via dot product. Feature groups: diff --git a/tzrec/models/hstu_test.py b/tzrec/models/hstu_test.py index 840801de..6a88cf01 100644 --- a/tzrec/models/hstu_test.py +++ b/tzrec/models/hstu_test.py @@ -64,7 +64,7 @@ def _build_model_config(): num_position_buckets=512, ), input_preprocessor=module_pb2.GRInputPreprocessor( - sequence_preprocessor=module_pb2.GRSequencePreprocessor(), + uih_preprocessor=module_pb2.GRUIHPreprocessor(), ), output_postprocessor=module_pb2.GROutputPostprocessor( l2norm_postprocessor=module_pb2.GRL2NormPostprocessor(), diff --git a/tzrec/modules/gr/preprocessors.py b/tzrec/modules/gr/preprocessors.py index aaa2928e..d4992489 100644 --- a/tzrec/modules/gr/preprocessors.py +++ b/tzrec/modules/gr/preprocessors.py @@ -486,7 +486,7 @@ def contextual_seq_len(self) -> int: return self._max_contextual_seq_len -class SequencePreprocessor(InputPreprocessor): +class UIHPreprocessor(InputPreprocessor): """Preprocessor for sequence-only models without candidate concatenation. Processes UIH (User Interaction History) sequences with optional contextual @@ -729,7 +729,7 @@ def create_input_preprocessor( ) elif preprocessor_type == "contextual_interleave_preprocessor": return ContextualInterleavePreprocessor(**config_dict) - elif preprocessor_type == "sequence_preprocessor": - return SequencePreprocessor(**config_dict) + elif preprocessor_type == "uih_preprocessor": + return UIHPreprocessor(**config_dict) else: raise RuntimeError(f"Unknown preprocessor type: {preprocessor_type}") diff --git a/tzrec/protos/module.proto b/tzrec/protos/module.proto index 08601141..6a3ca89e 100644 --- a/tzrec/protos/module.proto +++ b/tzrec/protos/module.proto @@ -195,7 +195,7 @@ message GRContextualInterleavePreprocessor { required GRContextualizedMLP content_mlp = 7; } -message GRSequencePreprocessor { +message GRUIHPreprocessor { // action encoder config (optional - for models with action info) optional GRActionEncoder action_encoder = 1; // action embedding mlp config (required if action_encoder is set) @@ -209,7 +209,7 @@ message GRInputPreprocessor { // input preprocessor with interleave targets GRContextualInterleavePreprocessor contextual_interleave_preprocessor = 21; // input preprocessor for sequence-only models (no candidate concat) - GRSequencePreprocessor sequence_preprocessor = 22; + GRUIHPreprocessor uih_preprocessor = 22; } } diff --git a/tzrec/tests/configs/hstu_fg_mock.config b/tzrec/tests/configs/hstu_fg_mock.config index aa09d921..4a60ceb3 100644 --- a/tzrec/tests/configs/hstu_fg_mock.config +++ b/tzrec/tests/configs/hstu_fg_mock.config @@ -94,7 +94,7 @@ model_config { num_position_buckets: 512 } input_preprocessor { - sequence_preprocessor { + uih_preprocessor { } } output_postprocessor { From 2c8b93b7f1cae0c326f592fb83424afd4cc02166 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Fri, 3 Apr 2026 11:33:46 +0800 Subject: [PATCH 07/12] [fix] only normalize item embeddings for COSINE similarity HSTUMatchItemTower should only L2-normalize when similarity method is COSINE, matching the pattern in DSSMTower. INNER_PRODUCT similarity does not require normalization. Co-Authored-By: Claude Opus 4.6 (1M context) --- tzrec/models/hstu.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py index 64de54b7..7064e407 100644 --- a/tzrec/models/hstu.py +++ b/tzrec/models/hstu.py @@ -207,11 +207,13 @@ def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: grouped_features: dictionary of embedded features from EmbeddingGroup. Returns: - L2-normalized item embeddings of shape (sum_candidates, D). + item embeddings of shape (sum_candidates, D). """ cand_emb = grouped_features[self._group_name] item_emb = self._item_projection(cand_emb) - return F.normalize(item_emb, p=2.0, dim=-1, eps=1e-6) + if self._similarity == simi_pb2.Similarity.COSINE: + item_emb = F.normalize(item_emb, p=2.0, dim=-1, eps=1e-6) + return item_emb class HSTUMatch(MatchModel): From e7fb0c5614edc95fe908a0e68c56be4ba66b1558 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Fri, 3 Apr 2026 13:49:45 +0800 Subject: [PATCH 08/12] [refactor] address code review: fix abstract signature, add comments - Fix InputPreprocessor abstract forward() signature to match actual implementations (grouped_features dict, not individual params) - Add comment documenting non-empty sequence assumption at user_emb extraction - Add comment explaining target_offsets placeholder in UIHPreprocessor - Update HSTUMatchItemTower docstring for conditional normalization Co-Authored-By: Claude Opus 4.6 (1M context) --- tzrec/models/hstu.py | 4 +++- tzrec/modules/gr/preprocessors.py | 24 +++++------------------- 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py index 7064e407..93649713 100644 --- a/tzrec/models/hstu.py +++ b/tzrec/models/hstu.py @@ -161,6 +161,7 @@ def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: ) # Extract last position embedding per user → (B, D) + # Assumes all sequences are non-empty (guaranteed by EmbeddingGroup). user_emb = user_emb[seq_offsets[1:] - 1] return user_emb @@ -169,7 +170,8 @@ def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: class HSTUMatchItemTower(MatchTowerWoEG): """HSTU Match model item tower. - Projects candidate embeddings to STU embedding dimension and L2 normalizes. + Projects candidate embeddings to STU embedding dimension. Applies L2 + normalization only when similarity method is COSINE. Args: tower_config (HSTUMatchTower): tower config. diff --git a/tzrec/modules/gr/preprocessors.py b/tzrec/modules/gr/preprocessors.py index d4992489..caffdd58 100644 --- a/tzrec/modules/gr/preprocessors.py +++ b/tzrec/modules/gr/preprocessors.py @@ -53,15 +53,7 @@ class InputPreprocessor(BaseModule): @abc.abstractmethod def forward( self, - max_uih_len: int, - max_targets: int, - total_uih_len: int, - total_targets: int, - seq_lengths: torch.Tensor, - seq_timestamps: torch.Tensor, - seq_embeddings: torch.Tensor, - num_targets: torch.Tensor, - seq_payloads: Dict[str, torch.Tensor], + grouped_features: Dict[str, torch.Tensor], ) -> Tuple[ int, int, @@ -75,22 +67,14 @@ def forward( """Forward the module. Args: - max_uih_len (int): maximum user history sequence length. - max_targets (int): maximum candidates length. - total_uih_len (int): total user history sequence length. - total_targets (int): total candidates length. - seq_lengths (torch.Tensor): input sequence lengths. - seq_timestamps (torch.Tensor): input sequence timestamp tensor. - seq_embeddings (torch.Tensor): input sequence embedding tensor. - num_targets (torch.Tensor): number of targets. - seq_payloads (Dict[str, torch.Tensor]): sequence payload features. + grouped_features (Dict[str, torch.Tensor]): embedding group features. Returns: output_max_seq_len (int): output maximum sequence length. output_total_uih_len (int): output total user history sequence length. output_total_targets (int): output total candidates length. output_seq_lengths (torch.Tensor): output sequence lengths. - output_seq_offsets (torch.Tensor): output sequence lengths. + output_seq_offsets (torch.Tensor): output sequence offsets. output_seq_timestamps (torch.Tensor): output sequence timestamp tensor. output_seq_embeddings (torch.Tensor): output sequence embedding tensor. output_num_targets (torch.Tensor): output number of targets. @@ -627,6 +611,8 @@ def forward( # Optional: action embeddings if self._action_encoder_cfg is not None: + # target_offsets is unused when total_targets=0 (no candidates in + # UIH-only mode), so we pass uih_offsets as a placeholder. action_embeddings = self._action_encoder( seq_actions=grouped_features["uih_action.sequence"].to(torch.int64), max_uih_len=max_uih_len, From 377c6bda8cc76a9a0f927493ded6262d6fa9a389 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Tue, 7 Apr 2026 10:28:55 +0800 Subject: [PATCH 09/12] [refactor] candidate as JAGGED_SEQUENCE with sequence_id_feature Use ONE feature item_id (sequence_id_feature) for both negative sampler and candidate group (JAGGED_SEQUENCE). The sampler reads single-value item_id from raw data; combine_neg_as_candidate_sequence creates the candidate sequence "pos;neg1;neg2" at runtime. - hstu_fg_mock.config: item_id as sequence_id_feature, candidate group as JAGGED_SEQUENCE - HSTUMatchItemTower reads from "candidate.sequence" (jagged) - HSTUMatch uses _jagged_candidate_sim() to reshape (sum_cand, D) into (B, 1+num_neg, D) and compute per-user similarity - combine_neg_as_candidate_sequence: compute neg_per_pos from data length instead of using sampler._num_sample directly - IdMockInput: add as_string parameter for sequence features used as sampler item_id (single value but as string) - build_mock_input_with_fg: accept neg_fields, generate single-value string mock data for sequence features in neg_fields - create_mock_data: cast unique_id field to string when as_string=True for type-consistent joins Co-Authored-By: Claude Opus 4.6 (1M context) --- tzrec/datasets/dataset.py | 4 +- tzrec/models/hstu.py | 28 +++++++++-- tzrec/models/hstu_test.py | 31 +++++------- tzrec/tests/configs/hstu_fg_mock.config | 6 ++- tzrec/tests/utils.py | 63 ++++++++++++++++++++----- 5 files changed, 95 insertions(+), 37 deletions(-) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index 604b89aa..eb001be9 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -379,10 +379,12 @@ def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch: if k in input_data: seq_delim = self._seq_field_delims.get(k) if seq_delim is not None: + # neg_per_pos = total_negatives // batch_size + neg_per_pos = len(v) // len(input_data[k]) input_data[k] = combine_neg_as_candidate_sequence( input_data[k], v, - self._sampler._num_sample, + neg_per_pos, seq_delim, ) else: diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py index 93649713..1c4933ab 100644 --- a/tzrec/models/hstu.py +++ b/tzrec/models/hstu.py @@ -40,6 +40,28 @@ torch.fx.wrap(fx_int_item) +@torch.fx.wrap +def _jagged_candidate_sim( + user_emb: torch.Tensor, item_emb: torch.Tensor +) -> torch.Tensor: + """Compute per-user similarity for JAGGED_SEQUENCE candidates. + + Each user has the same number of candidates (1 pos + num_neg). The item + embeddings are organized as: [pos_1, neg_1_1, ..., neg_1_k, pos_2, ...]. + + Args: + user_emb: (B, D) user embeddings. + item_emb: (B * (1 + num_neg), D) candidate embeddings. + + Returns: + similarity (B, 1 + num_neg), first column is positive. + """ + batch_size = user_emb.size(0) + num_cand = item_emb.size(0) // batch_size + item_emb = item_emb.view(batch_size, num_cand, -1) + return torch.bmm(item_emb, user_emb.unsqueeze(-1)).squeeze(-1) + + class HSTUMatchUserTower(MatchTowerWoEG): """HSTU Match model user tower using modern STU module. @@ -211,7 +233,7 @@ def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: Returns: item embeddings of shape (sum_candidates, D). """ - cand_emb = grouped_features[self._group_name] + cand_emb = grouped_features[f"{self._group_name}.sequence"] item_emb = self._item_projection(cand_emb) if self._similarity == simi_pb2.Similarity.COSINE: item_emb = F.normalize(item_emb, p=2.0, dim=-1, eps=1e-6) @@ -274,7 +296,7 @@ def __init__( cand_features = self.get_features_in_feature_groups([cand_fg]) uih_dims = self.embedding_group.group_dims(tower_cfg.input + ".sequence") - cand_dims = self.embedding_group.group_dims("candidate") + cand_dims = self.embedding_group.group_dims("candidate.sequence") # Optional contextual features contextual_feature_dim = 0 @@ -333,5 +355,5 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: user_emb = self.user_tower(grouped_features) item_emb = self.item_tower(grouped_features) - ui_sim = self.sim(user_emb, item_emb) / self._temperature + ui_sim = _jagged_candidate_sim(user_emb, item_emb) / self._temperature return {"similarity": ui_sim} diff --git a/tzrec/models/hstu_test.py b/tzrec/models/hstu_test.py index 6a88cf01..ef6de5a5 100644 --- a/tzrec/models/hstu_test.py +++ b/tzrec/models/hstu_test.py @@ -14,7 +14,7 @@ import torch from torchrec import KeyedJaggedTensor -from tzrec.datasets.utils import BASE_DATA_GROUP, NEG_DATA_GROUP, Batch +from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.features.feature import create_features from tzrec.models.hstu import HSTUMatch from tzrec.models.model import TrainWrapper @@ -43,7 +43,7 @@ def _build_model_config(): model_pb2.FeatureGroupConfig( group_name="candidate", feature_names=["item_id"], - group_type=model_pb2.FeatureGroupType.DEEP, + group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE, ), ] return model_pb2.ModelConfig( @@ -93,14 +93,16 @@ def _build_features(): ) ), feature_pb2.FeatureConfig( - id_feature=feature_pb2.IdFeature( + sequence_id_feature=feature_pb2.IdFeature( feature_name="item_id", + sequence_length=10, + sequence_delim=";", embedding_dim=48, num_buckets=1000, ) ), ] - return create_features(feature_cfgs, neg_fields=["item_id"]) + return create_features(feature_cfgs) def _build_model(device): @@ -123,25 +125,16 @@ def _build_batch(device): """Build test batch with 2 users. UIH: user1 has 3 items, user2 has 4 items. - Candidates: 2 pos (1 per user) + 2 neg items. + Candidates: each user has a sequence of [pos, neg] (after combine_neg). """ - # BASE: UIH sequences + positive items + # BASE: UIH sequences + per-user candidate sequences (pos+neg) sparse_feature = KeyedJaggedTensor.from_lengths_sync( - keys=["historical_ids"], - values=torch.tensor([1, 2, 3, 4, 5, 6, 7]), - lengths=torch.tensor([3, 4]), - ) - # NEG: positive items (first batch_size) + negative items - neg_sparse_feature = KeyedJaggedTensor.from_lengths_sync( - keys=["item_id"], - values=torch.tensor([10, 11, 20, 21]), - lengths=torch.tensor([1, 1, 1, 1]), # 2 pos + 2 neg, each 1 item + keys=["historical_ids", "item_id"], + values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 100, 200, 101, 201]), + lengths=torch.tensor([3, 4, 2, 2]), ) return Batch( - sparse_features={ - BASE_DATA_GROUP: sparse_feature, - NEG_DATA_GROUP: neg_sparse_feature, - }, + sparse_features={BASE_DATA_GROUP: sparse_feature}, labels={"label": torch.tensor([1, 1])}, ).to(device) diff --git a/tzrec/tests/configs/hstu_fg_mock.config b/tzrec/tests/configs/hstu_fg_mock.config index 4a60ceb3..41e52a30 100644 --- a/tzrec/tests/configs/hstu_fg_mock.config +++ b/tzrec/tests/configs/hstu_fg_mock.config @@ -59,9 +59,11 @@ feature_configs { } } feature_configs { - id_feature { + sequence_id_feature { feature_name: "item_id" expression: "item:item_id" + sequence_length: 10 + sequence_delim: ";" num_buckets: 1000 embedding_dim: 48 } @@ -76,7 +78,7 @@ model_config { feature_groups { group_name: "candidate" feature_names: "item_id" - group_type: DEEP + group_type: JAGGED_SEQUENCE } hstu_match { hstu_tower { diff --git a/tzrec/tests/utils.py b/tzrec/tests/utils.py index 7b5cabd3..d31d8be8 100644 --- a/tzrec/tests/utils.py +++ b/tzrec/tests/utils.py @@ -14,7 +14,7 @@ import os import random from collections import OrderedDict, defaultdict -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Set, Tuple import numpy as np import numpy.typing as npt @@ -82,17 +82,19 @@ def __init__( num_ids: Optional[int] = None, vocab_list: Optional[List[str]] = None, multival_sep: str = chr(3), + as_string: bool = False, ) -> None: super().__init__(name) self.is_multi = is_multi self.num_ids = num_ids self.vocab_list = vocab_list self.multival_sep = multival_sep + self.as_string = as_string def create_data(self, num_rows: int, has_null: bool = True) -> pa.Array: """Create mock data.""" if not self.is_multi: - # int64 + # int64 (or string if as_string=True) num_valid_rows = ( random.randint(num_rows // 2, num_rows) if has_null else num_rows ) @@ -101,6 +103,8 @@ def create_data(self, num_rows: int, has_null: bool = True) -> pa.Array: ) data = data + [None] * (num_rows - num_valid_rows) random.shuffle(data) + if self.as_string: + data = [str(x) if x is not None else None for x in data] else: # string num_multi_rows = random.randint(num_rows // 3, 2 * num_rows // 3) @@ -483,7 +487,13 @@ def create_mock_data( input_data = {} for inp in inputs.values(): if inp.name == unique_id: - input_data[inp.name] = pa.array(list(range(num_rows))) + ids = pa.array(list(range(num_rows))) + # If the input is configured for string output (e.g., a + # sequence_id_feature used as sampler item_id), cast to string + # so joins on this key remain type-consistent. + if isinstance(inp, IdMockInput) and getattr(inp, "as_string", False): + ids = ids.cast(pa.string()) + input_data[inp.name] = ids elif isinstance(inp, SeqMockInput): input_data.update(inp.create_sequence_data(num_rows, join_t)) else: @@ -664,6 +674,7 @@ def build_mock_input_with_fg( features: List[BaseFeature], user_id: str = "", item_id: str = "", + neg_fields: Optional[Set[str]] = None, ) -> Dict[str, MockInput]: """Build mock input instance list with fg from features.""" inputs = defaultdict(dict) @@ -787,14 +798,31 @@ def build_mock_input_with_fg( if isinstance(inputs[side][sub_name], IdMockInput): inputs[side][sub_name].is_multi = False else: - inputs[side][input_name] = IdMockInput( - input_name, - is_multi=True, - num_ids=10 - if isinstance(feature, CustomFeature) - else feature.num_embeddings, - multival_sep=feature.sequence_delim, - ) + if neg_fields and input_name in neg_fields: + # Sampler-targeted sequence features must have + # single-value mock data because the sampler casts + # the field to int64. The candidate sequence is + # created at runtime by combine_neg_as_candidate_sequence. + # Generate as string so the sequence_id_feature parser + # can read it (with delimiter; single value parses as + # one-element sequence). + inputs[side][input_name] = IdMockInput( + input_name, + is_multi=False, + num_ids=10 + if isinstance(feature, CustomFeature) + else feature.num_embeddings, + as_string=True, + ) + else: + inputs[side][input_name] = IdMockInput( + input_name, + is_multi=True, + num_ids=10 + if isinstance(feature, CustomFeature) + else feature.num_embeddings, + multival_sep=feature.sequence_delim, + ) return inputs["user"], inputs["item"] @@ -840,7 +868,18 @@ def load_config_for_test( num_parts=num_parts, ) else: - user_inputs, item_inputs = build_mock_input_with_fg(features, user_id, item_id) + # Determine sampler attr_fields so mock generator can produce + # single-value data for sequence features that the sampler will + # use as item_id (sampler casts field values to int64). + sampler_neg_fields: Set[str] = set() + if data_config.HasField("sampler"): + sampler_type_name = data_config.WhichOneof("sampler") + sampler_cfg = getattr(data_config, sampler_type_name) + if hasattr(sampler_cfg, "attr_fields"): + sampler_neg_fields = set(sampler_cfg.attr_fields) + user_inputs, item_inputs = build_mock_input_with_fg( + features, user_id, item_id, neg_fields=sampler_neg_fields + ) _, item_t = create_mock_data( os.path.join(test_dir, "item_data"), item_inputs, From e09595cd3acd6e020a10e7ae2a2339f8f13ff8c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Tue, 7 Apr 2026 14:21:18 +0800 Subject: [PATCH 10/12] [fix] handle multi-value sequence in sampler item_id_field The negative sampler casts input_data[item_id_field] to int64 via _pa_ids_to_npy and pads to batch_size. It cannot process multi-value sequence strings like "1;2;3". Fix: in _build_batch, if item_id_field is a sequence field and the raw data is a string, extract the first item from each (possibly multi-value) row before calling the sampler. The original (possibly multi-value) data is preserved and used by combine_neg_as_candidate_sequence to build the final candidate sequence "pos_seq;neg1;neg2". - Save sampler item_id_field on dataset init for fast lookup - Pre-process input_data[item_id_field]: split by delim, take first item - Restore original after sampler.get() so combine sees full sequence - Add unit test for multi-value pos in combine_neg_as_candidate_sequence Co-Authored-By: Claude Opus 4.6 (1M context) --- tzrec/datasets/dataset.py | 27 +++++++++++++++++++++++++++ tzrec/datasets/utils_test.py | 14 ++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index eb001be9..4a2fb63d 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -206,6 +206,7 @@ def __init__( self._selected_input_names |= set(data_config.sample_weight_fields) if data_config.HasField("sample_cost_field"): self._selected_input_names.add(data_config.sample_cost_field) + self._sampler_item_id_field: Optional[str] = None if self._data_config.HasField("sampler") and self._mode != Mode.PREDICT: sampler_type = self._data_config.WhichOneof("sampler") sampler_config = getattr(self._data_config, sampler_type) @@ -213,6 +214,7 @@ def __init__( "item_id_field" ): self._selected_input_names.add(sampler_config.item_id_field) + self._sampler_item_id_field = sampler_config.item_id_field if hasattr(sampler_config, "user_id_field") and sampler_config.HasField( "user_id_field" ): @@ -374,7 +376,32 @@ def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch: input_data, pos_sampled, neg_sampled, self._data_config ) else: + # If item_id_field is a sequence feature, the sampler can't + # process multi-value strings (it casts the whole array to + # int64). Extract a single representative value (the first + # item) for the sampler. The original (possibly multi-value) + # data is preserved for combining with sampled negatives. + saved_pos: Dict[str, pa.Array] = {} + if ( + self._sampler_item_id_field is not None + and self._sampler_item_id_field in self._seq_field_delims + ): + seq_delim = self._seq_field_delims[self._sampler_item_id_field] + raw = input_data[self._sampler_item_id_field] + if pa.types.is_string(raw.type): + saved_pos[self._sampler_item_id_field] = raw + # Take first item from each (possibly multi-value) row + split = pc.split_pattern(raw, seq_delim) + input_data[self._sampler_item_id_field] = pc.list_element( + split, 0 + ) + sampled = self._sampler.get(input_data) + + # Restore original (possibly multi-value) data for combine + for k, original in saved_pos.items(): + input_data[k] = original + for k, v in sampled.items(): if k in input_data: seq_delim = self._seq_field_delims.get(k) diff --git a/tzrec/datasets/utils_test.py b/tzrec/datasets/utils_test.py index 91125135..f552bbe4 100644 --- a/tzrec/datasets/utils_test.py +++ b/tzrec/datasets/utils_test.py @@ -232,6 +232,20 @@ def test_combine_neg_as_candidate_sequence_different_delim(self): expected = ["1|10|20", "2|30|40"] self.assertEqual(result.to_pylist(), expected) + def test_combine_neg_as_candidate_sequence_multivalue_pos(self): + """Test combine when pos_data already has multi-value sequences.""" + pos_data = pa.array(["1;2", "3;4;5"]) + neg_data = pa.array(["10", "20", "30", "40"]) + + result = combine_neg_as_candidate_sequence( + pos_data=pos_data, + neg_data=neg_data, + neg_sample_num=2, + seq_delim=";", + ) + expected = ["1;2;10;20", "3;4;5;30;40"] + self.assertEqual(result.to_pylist(), expected) + def test_normalize_type_str_basic_types(self): """Test normalizing basic types.""" self.assertEqual(_normalize_type_str("int32"), "INT32") From cbf1bb610e219ef714e82be3cb93fc10d45b541b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Tue, 7 Apr 2026 14:41:42 +0800 Subject: [PATCH 11/12] [fix] each multi-value positive gets its own negatives MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous fix took only the FIRST item from a multi-value sequence and discarded the rest. This is incorrect — each item in the candidate sequence should be treated as a positive that gets its own negatives. Fix: - dataset.py: flatten per-row positives into a single 1D array, pass the flat array to the sampler so each positive gets its own expand_factor negatives. Then restore the list-array form so the combine function knows the per-row grouping. - combine_neg_as_candidate_sequence: rewrite to handle both single and multi-value pos. For multi-value, interleave per-position negatives: "pos1;neg1_1;...;neg1_n;pos2;neg2_1;...;posK;negK_n". - Add assertion: total positives across rows must fit in sampler batch_size (chunking can be added later if needed). - Add unit tests for multi-value pos with 1 and 2 negatives per position. Co-Authored-By: Claude Opus 4.6 (1M context) --- tzrec/datasets/dataset.py | 51 ++++++++++++++++++----------- tzrec/datasets/utils.py | 62 ++++++++++++++++++++++-------------- tzrec/datasets/utils_test.py | 28 +++++++++++++--- 3 files changed, 95 insertions(+), 46 deletions(-) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index 4a2fb63d..ef4d3547 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -376,12 +376,14 @@ def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch: input_data, pos_sampled, neg_sampled, self._data_config ) else: - # If item_id_field is a sequence feature, the sampler can't - # process multi-value strings (it casts the whole array to - # int64). Extract a single representative value (the first - # item) for the sampler. The original (possibly multi-value) - # data is preserved for combining with sampled negatives. - saved_pos: Dict[str, pa.Array] = {} + # If item_id_field is a sequence feature, each item in the + # row's sequence is a positive that should get its own + # negatives. Flatten per-row positives into a single 1D + # array, call the sampler so each positive is treated as a + # separate query, then combine_neg_as_candidate_sequence + # interleaves them per row: + # [pos1, neg1_1, ..., pos2, neg2_1, ..., posK, negK_n]. + multi_pos_lists: Optional[pa.Array] = None if ( self._sampler_item_id_field is not None and self._sampler_item_id_field in self._seq_field_delims @@ -389,28 +391,41 @@ def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch: seq_delim = self._seq_field_delims[self._sampler_item_id_field] raw = input_data[self._sampler_item_id_field] if pa.types.is_string(raw.type): - saved_pos[self._sampler_item_id_field] = raw - # Take first item from each (possibly multi-value) row - split = pc.split_pattern(raw, seq_delim) - input_data[self._sampler_item_id_field] = pc.list_element( - split, 0 + multi_pos_lists = pc.split_pattern(raw, seq_delim) + flat_pos = pc.list_flatten(multi_pos_lists) + sampler_bs = self._sampler._batch_size + assert len(flat_pos) <= sampler_bs, ( + f"Total positives {len(flat_pos)} across rows " + f"exceeds sampler batch_size {sampler_bs}. Reduce " + f"batch_size or per-row sequence length." ) + input_data[self._sampler_item_id_field] = flat_pos sampled = self._sampler.get(input_data) - # Restore original (possibly multi-value) data for combine - for k, original in saved_pos.items(): - input_data[k] = original + if multi_pos_lists is not None: + # Restore the list-array form so combine_neg sees the + # original per-row grouping of positives. + input_data[self._sampler_item_id_field] = multi_pos_lists for k, v in sampled.items(): if k in input_data: seq_delim = self._seq_field_delims.get(k) if seq_delim is not None: - # neg_per_pos = total_negatives // batch_size - neg_per_pos = len(v) // len(input_data[k]) + pos_for_combine = input_data[k] + if pa.types.is_list(pos_for_combine.type): + num_pos = len(pc.list_flatten(pos_for_combine)) + else: + num_pos = len(pos_for_combine) + # Sampler returns batch_size * expand_factor negs + # (with padding); slice to the valid prefix. + neg_per_pos = len(v) // self._sampler._batch_size + if neg_per_pos < 1: + neg_per_pos = 1 + valid_negs = v.slice(0, num_pos * neg_per_pos) input_data[k] = combine_neg_as_candidate_sequence( - input_data[k], - v, + pos_for_combine, + valid_negs, neg_per_pos, seq_delim, ) diff --git a/tzrec/datasets/utils.py b/tzrec/datasets/utils.py index 0ad5bff5..27cca11a 100644 --- a/tzrec/datasets/utils.py +++ b/tzrec/datasets/utils.py @@ -13,7 +13,6 @@ from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple -import numpy as np import numpy.typing as npt import pyarrow as pa import pyarrow.compute as pc @@ -511,39 +510,54 @@ def combine_neg_as_candidate_sequence( ) -> pa.Array: """Combine positive and negative items into candidate sequences. - For each sample, joins the positive item with its negative items using the - sequence delimiter. Used when candidate features are sequence_id_features - in a JAGGED_SEQUENCE group. + Each row's positive items are flattened and each positive gets its own + set of `neg_sample_num` negatives. The result interleaves per position: + "pos1;neg1_1;...;neg1_n;pos2;neg2_1;...;posK;negK_n". + + Used when candidate features are sequence_id_features in a JAGGED_SEQUENCE + group. Supports both single-value positives ("123") and multi-value + positives ("1;2;3") per row. Args: - pos_data: positive item IDs, one per sample. Shape: (B,). - neg_data: negative item IDs. Shape: (B * neg_sample_num,). + pos_data: positive items per row. Either a 1D array of strings (each + row may contain a delimited sequence) or a list array. + neg_data: flat array of negative items, ordered by positive. + Length = total_positives * neg_sample_num. neg_sample_num: number of negative samples per positive. seq_delim: delimiter for joining items into a sequence string. Returns: - pa.Array of strings, each containing "pos;neg1;neg2;..." per sample. + pa.Array of strings, each row's candidate sequence. Example: - pos_data = ["1", "2"] - neg_data = ["3", "4", "5", "6"] - neg_sample_num = 2 + pos_data = ["1", "2;3"] + neg_data = ["10", "20", "30"] # 1 neg per pos, total 3 positives + neg_sample_num = 1 seq_delim = ";" - result = ["1;3;4", "2;5;6"] + result = ["1;10", "2;20;3;30"] """ - neg_str = neg_data.cast(pa.string()) - neg_offsets = pa.array( - np.concatenate( - [ - np.array([0]), - np.arange(neg_sample_num, len(neg_str) + 1, neg_sample_num), - ] - ) - ) - neg_lists = pa.ListArray.from_arrays(neg_offsets, neg_str) - neg_joined = pc.binary_join(neg_lists, seq_delim) - pos_str = pos_data.cast(pa.string()) - return pc.binary_join_element_wise(pos_str, neg_joined, seq_delim) + # Normalize pos_data to a list array (each row → list of positives) + if pa.types.is_list(pos_data.type) or pa.types.is_large_list(pos_data.type): + pos_lists = pos_data + else: + pos_lists = pc.split_pattern(pos_data.cast(pa.string()), seq_delim) + + counts = pc.list_value_length(pos_lists).to_pylist() + pos_flat = pc.list_flatten(pos_lists).cast(pa.string()).to_pylist() + neg_str_list = neg_data.cast(pa.string()).to_pylist() + + result: List[str] = [] + pos_offset = 0 + neg_offset = 0 + for k_i in counts: + parts: List[str] = [] + for _ in range(k_i): + parts.append(pos_flat[pos_offset]) + parts.extend(neg_str_list[neg_offset : neg_offset + neg_sample_num]) + pos_offset += 1 + neg_offset += neg_sample_num + result.append(seq_delim.join(parts)) + return pa.array(result) def calc_slice_position( diff --git a/tzrec/datasets/utils_test.py b/tzrec/datasets/utils_test.py index f552bbe4..1fb56180 100644 --- a/tzrec/datasets/utils_test.py +++ b/tzrec/datasets/utils_test.py @@ -233,9 +233,27 @@ def test_combine_neg_as_candidate_sequence_different_delim(self): self.assertEqual(result.to_pylist(), expected) def test_combine_neg_as_candidate_sequence_multivalue_pos(self): - """Test combine when pos_data already has multi-value sequences.""" - pos_data = pa.array(["1;2", "3;4;5"]) - neg_data = pa.array(["10", "20", "30", "40"]) + """Each item in multi-value pos sequence gets its own negatives.""" + pos_data = pa.array(["1;2", "3;4;5"]) # 5 positives total + neg_data = pa.array(["10", "20", "30", "40", "50"]) # 1 neg per pos + + result = combine_neg_as_candidate_sequence( + pos_data=pos_data, + neg_data=neg_data, + neg_sample_num=1, + seq_delim=";", + ) + # row1: pos=1 → "1;10", pos=2 → "2;20" → "1;10;2;20" + # row2: pos=3 → "3;30", pos=4 → "4;40", pos=5 → "5;50" + # → "3;30;4;40;5;50" + expected = ["1;10;2;20", "3;30;4;40;5;50"] + self.assertEqual(result.to_pylist(), expected) + + def test_combine_neg_as_candidate_sequence_multivalue_pos_multi_neg(self): + """Multi-value pos with multiple negatives per positive.""" + pos_data = pa.array(["1;2", "3"]) # 3 positives total + # neg_sample_num=2 → need 6 negatives + neg_data = pa.array(["10", "20", "30", "40", "50", "60"]) result = combine_neg_as_candidate_sequence( pos_data=pos_data, @@ -243,7 +261,9 @@ def test_combine_neg_as_candidate_sequence_multivalue_pos(self): neg_sample_num=2, seq_delim=";", ) - expected = ["1;2;10;20", "3;4;5;30;40"] + # row1: pos=1 → "1;10;20", pos=2 → "2;30;40" → "1;10;20;2;30;40" + # row2: pos=3 → "3;50;60" + expected = ["1;10;20;2;30;40", "3;50;60"] self.assertEqual(result.to_pylist(), expected) def test_normalize_type_str_basic_types(self): From 7c9d4f73b6db56ef5f526a6b144565e56312bdbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 8 Apr 2026 20:57:24 +0800 Subject: [PATCH 12/12] [refactor] dynamic expand_factor in sampler, expand user_id for V2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previous approach took the first item from multi-value item_id for sampling, which was incorrect — each positive in the candidate sequence should get its own negatives. Changes: 1. Sampler: move expand_factor computation from init() to get(). Compute dynamically from len(ids) and num_sample. Cache samplers by expand_factor. Remove np.pad (sampler supports any input size now). Applies to NegativeSampler, NegativeSamplerV2, HardNegativeSampler, HardNegativeSamplerV2. 2. Dataset: flatten multi-value item_id before sampler.get(). For samplers that use user_id_field (V2, HardNeg variants), also expand user_id to match the flattened length via pc.take with per-row repeat indices. Restore multi-value form after sampling. 3. Remove the len(flat_pos) <= sampler_bs assertion — no longer needed since the sampler handles any size. Co-Authored-By: Claude Opus 4.6 (1M context) --- tzrec/datasets/dataset.py | 47 ++++++++++-------- tzrec/datasets/sampler.py | 100 ++++++++++++++++++++++++-------------- 2 files changed, 91 insertions(+), 56 deletions(-) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index ef4d3547..a6e9d4e4 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -207,6 +207,7 @@ def __init__( if data_config.HasField("sample_cost_field"): self._selected_input_names.add(data_config.sample_cost_field) self._sampler_item_id_field: Optional[str] = None + self._sampler_user_id_field: Optional[str] = None if self._data_config.HasField("sampler") and self._mode != Mode.PREDICT: sampler_type = self._data_config.WhichOneof("sampler") sampler_config = getattr(self._data_config, sampler_type) @@ -219,6 +220,7 @@ def __init__( "user_id_field" ): self._selected_input_names.add(sampler_config.user_id_field) + self._sampler_user_id_field = sampler_config.user_id_field # if set selected_input_names to None, # all columns will be reserved. if ( @@ -376,14 +378,14 @@ def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch: input_data, pos_sampled, neg_sampled, self._data_config ) else: - # If item_id_field is a sequence feature, each item in the - # row's sequence is a positive that should get its own - # negatives. Flatten per-row positives into a single 1D - # array, call the sampler so each positive is treated as a - # separate query, then combine_neg_as_candidate_sequence - # interleaves them per row: - # [pos1, neg1_1, ..., pos2, neg2_1, ..., posK, negK_n]. + # If item_id_field is a sequence feature, flatten per-row + # positives into a 1D array so the sampler treats each + # positive as a separate query. Expand user_id in the same + # way for samplers that rely on it (V2, HardNeg, HardNegV2). + # The sampler computes expand_factor dynamically from the + # actual input length, so no padding or batch_size limit. multi_pos_lists: Optional[pa.Array] = None + saved_user_ids: Optional[pa.Array] = None if ( self._sampler_item_id_field is not None and self._sampler_item_id_field in self._seq_field_delims @@ -393,20 +395,28 @@ def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch: if pa.types.is_string(raw.type): multi_pos_lists = pc.split_pattern(raw, seq_delim) flat_pos = pc.list_flatten(multi_pos_lists) - sampler_bs = self._sampler._batch_size - assert len(flat_pos) <= sampler_bs, ( - f"Total positives {len(flat_pos)} across rows " - f"exceeds sampler batch_size {sampler_bs}. Reduce " - f"batch_size or per-row sequence length." - ) input_data[self._sampler_item_id_field] = flat_pos + # Expand user_id to match flat positives (step 4). + if self._sampler_user_id_field is not None and ( + self._sampler_user_id_field in input_data + ): + counts_np = pc.list_value_length(multi_pos_lists).to_numpy() + row_indices = pa.array( + np.repeat(np.arange(len(counts_np)), counts_np) + ) + saved_user_ids = input_data[self._sampler_user_id_field] + input_data[self._sampler_user_id_field] = pc.take( + saved_user_ids, row_indices + ) + sampled = self._sampler.get(input_data) + # Restore multi-value form so combine_neg sees original grouping. if multi_pos_lists is not None: - # Restore the list-array form so combine_neg sees the - # original per-row grouping of positives. input_data[self._sampler_item_id_field] = multi_pos_lists + if saved_user_ids is not None: + input_data[self._sampler_user_id_field] = saved_user_ids for k, v in sampled.items(): if k in input_data: @@ -417,11 +427,8 @@ def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch: num_pos = len(pc.list_flatten(pos_for_combine)) else: num_pos = len(pos_for_combine) - # Sampler returns batch_size * expand_factor negs - # (with padding); slice to the valid prefix. - neg_per_pos = len(v) // self._sampler._batch_size - if neg_per_pos < 1: - neg_per_pos = 1 + # Sampler returns num_pos * expand_factor negs. + neg_per_pos = max(1, len(v) // max(1, num_pos)) valid_negs = v.slice(0, num_pos * neg_per_pos) input_data[k] = combine_neg_as_candidate_sequence( pos_for_combine, diff --git a/tzrec/datasets/sampler.py b/tzrec/datasets/sampler.py index 30ae5862..02451c84 100644 --- a/tzrec/datasets/sampler.py +++ b/tzrec/datasets/sampler.py @@ -14,7 +14,7 @@ import random import socket import time -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import graphlearn as gl import numpy as np @@ -428,20 +428,28 @@ def __init__( ), ) self._item_id_field = config.item_id_field - self._sampler = None + self._sampler_cache: Dict[int, Any] = {} self.item_id_delim = config.item_id_delim def init(self, client_id: int = -1) -> None: - """Init sampler client and samplers.""" + """Init sampler client; samplers are created lazily in get().""" super().init(client_id) - expand_factor = int(math.ceil(self._num_sample / self._batch_size)) - self._sampler = self._g.negative_sampler( - "item", expand_factor, strategy="node_weight" - ) + self._sampler_cache = {} + + def _get_sampler(self, expand_factor: int) -> Any: + if expand_factor not in self._sampler_cache: + self._sampler_cache[expand_factor] = self._g.negative_sampler( + "item", expand_factor, strategy="node_weight" + ) + return self._sampler_cache[expand_factor] def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]: """Sampling method. + Computes expand_factor dynamically from the actual input length so + that multi-value sequence positives (flattened upstream) are + supported without padding. + Args: input_data (dict): input data with item_id. @@ -449,8 +457,10 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]: Negative sampled feature dict. """ ids = _pa_ids_to_npy(input_data[self._item_id_field]) - ids = np.pad(ids, (0, self._batch_size - len(ids)), "edge") - nodes = self._sampler.get(ids) + num_pos = max(1, len(ids)) + expand_factor = int(math.ceil(self._num_sample / num_pos)) + sampler = self._get_sampler(expand_factor) + nodes = sampler.get(ids) features = self._parse_nodes(nodes) result_dict = dict(zip(self._valid_attr_names, features)) return result_dict @@ -509,15 +519,12 @@ def __init__( ) self._item_id_field = config.item_id_field self._user_id_field = config.user_id_field - self._sampler = None + self._sampler_cache: Dict[int, Any] = {} def init(self, client_id: int = -1) -> None: - """Init sampler client and samplers.""" + """Init sampler client; samplers are created lazily in get().""" super().init(client_id) - expand_factor = int(math.ceil(self._num_sample / self._batch_size)) - self._sampler = self._g.negative_sampler( - "edge", expand_factor, strategy="random", conditional=True - ) + self._sampler_cache = {} # prevent gl timeout worker_info = get_worker_info() @@ -528,20 +535,30 @@ def init(self, client_id: int = -1) -> None: {self._user_id_field: pa.array([0]), self._item_id_field: pa.array([0])} ) + def _get_sampler(self, expand_factor: int) -> Any: + if expand_factor not in self._sampler_cache: + self._sampler_cache[expand_factor] = self._g.negative_sampler( + "edge", expand_factor, strategy="random", conditional=True + ) + return self._sampler_cache[expand_factor] + def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]: """Sampling method. Args: - input_data (dict): input data with user_id and item_id. + input_data (dict): input data with user_id and item_id. When + item_id is a flattened multi-value sequence, user_id must + be expanded to match the flattened length. Returns: Negative sampled feature dict. """ src_ids = _pa_ids_to_npy(input_data[self._user_id_field]) dst_ids = _pa_ids_to_npy(input_data[self._item_id_field]) - src_ids = np.pad(src_ids, (0, self._batch_size - len(src_ids)), "edge") - dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), "edge") - nodes = self._sampler.get(src_ids, dst_ids) + num_pos = max(1, len(dst_ids)) + expand_factor = int(math.ceil(self._num_sample / num_pos)) + sampler = self._get_sampler(expand_factor) + nodes = sampler.get(src_ids, dst_ids) features = self._parse_nodes(nodes) result_dict = dict(zip(self._valid_attr_names, features)) return result_dict @@ -602,20 +619,24 @@ def __init__( ) self._item_id_field = config.item_id_field self._user_id_field = config.user_id_field - self._neg_sampler = None + self._neg_sampler_cache: Dict[int, Any] = {} self._hard_neg_sampler = None def init(self, client_id: int = -1) -> None: - """Init sampler client and samplers.""" + """Init sampler client; neg_sampler created lazily in get().""" super().init(client_id) - expand_factor = int(math.ceil(self._num_sample / self._batch_size)) - self._neg_sampler = self._g.negative_sampler( - "item", expand_factor, strategy="node_weight" - ) + self._neg_sampler_cache = {} self._hard_neg_sampler = self._g.neighbor_sampler( ["hard_neg_edge"], self._num_hard_sample, strategy="full" ) + def _get_neg_sampler(self, expand_factor: int) -> Any: + if expand_factor not in self._neg_sampler_cache: + self._neg_sampler_cache[expand_factor] = self._g.negative_sampler( + "item", expand_factor, strategy="node_weight" + ) + return self._neg_sampler_cache[expand_factor] + def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]: """Sampling method. @@ -628,8 +649,10 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]: """ src_ids = _pa_ids_to_npy(input_data[self._user_id_field]) dst_ids = _pa_ids_to_npy(input_data[self._item_id_field]) - dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), "edge") - nodes = self._neg_sampler.get(dst_ids) + num_pos = max(1, len(dst_ids)) + expand_factor = int(math.ceil(self._num_sample / num_pos)) + neg_sampler = self._get_neg_sampler(expand_factor) + nodes = neg_sampler.get(dst_ids) neg_features = self._parse_nodes(nodes) sparse_nodes = self._hard_neg_sampler.get(src_ids).layer_nodes(1) hard_neg_features, hard_neg_indices = self._parse_sparse_nodes(sparse_nodes) @@ -703,20 +726,24 @@ def __init__( ) self._item_id_field = config.item_id_field self._user_id_field = config.user_id_field - self._neg_sampler = None + self._neg_sampler_cache: Dict[int, Any] = {} self._hard_neg_sampler = None def init(self, client_id: int = -1) -> None: - """Init sampler client and samplers.""" + """Init sampler client; neg_sampler created lazily in get().""" super().init(client_id) - expand_factor = int(math.ceil(self._num_sample / self._batch_size)) - self._neg_sampler = self._g.negative_sampler( - "edge", expand_factor, strategy="random", conditional=True - ) + self._neg_sampler_cache = {} self._hard_neg_sampler = self._g.neighbor_sampler( ["hard_neg_edge"], self._num_hard_sample, strategy="full" ) + def _get_neg_sampler(self, expand_factor: int) -> Any: + if expand_factor not in self._neg_sampler_cache: + self._neg_sampler_cache[expand_factor] = self._g.negative_sampler( + "edge", expand_factor, strategy="random", conditional=True + ) + return self._neg_sampler_cache[expand_factor] + def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]: """Sampling method. @@ -729,9 +756,10 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]: """ src_ids = _pa_ids_to_npy(input_data[self._user_id_field]) dst_ids = _pa_ids_to_npy(input_data[self._item_id_field]) - padded_src_ids = np.pad(src_ids, (0, self._batch_size - len(src_ids)), "edge") - dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), "edge") - nodes = self._neg_sampler.get(padded_src_ids, dst_ids) + num_pos = max(1, len(dst_ids)) + expand_factor = int(math.ceil(self._num_sample / num_pos)) + neg_sampler = self._get_neg_sampler(expand_factor) + nodes = neg_sampler.get(src_ids, dst_ids) neg_features = self._parse_nodes(nodes) sparse_nodes = self._hard_neg_sampler.get(src_ids).layer_nodes(1) hard_neg_features, hard_neg_indices = self._parse_sparse_nodes(sparse_nodes)