[refactor] Refactor HSTUMatch to use STU module with jagged sequences#462

Open
tiankongdeguiji wants to merge 12 commits into alibaba:master from tiankongdeguiji:refactor/hstu-match-stu-jagged

Conversation

@tiankongdeguiji
Collaborator

Summary

  • Replace old HSTUEncoder/SequentialTransductionUnitJagged with modern STUStack/STULayer from gr/stu.py in HSTUMatch model
  • Add new SequencePreprocessor for match models — handles UIH-only sequences with optional contextual features, timestamps, and actions (no candidate concatenation)
  • Switch to JAGGED_SEQUENCE feature groups for native jagged tensor support with separate UIH and candidate groups
  • Update proto definitions: HSTUMatchTower uses HSTU config, new GRSequencePreprocessor in GRInputPreprocessor oneof
  • Fix WhichOneof field name bug in create_output_postprocessor
  • Remove unused HSTUEncoder class and tests

Changes

  • tzrec/protos/module.proto — Add GRSequencePreprocessor, add to GRInputPreprocessor oneof
  • tzrec/protos/tower.proto — HSTUMatchTower uses HSTU + max_seq_len instead of HSTUEncoder
  • tzrec/protos/models/match_model.proto — Remove output_dim from HSTUMatch
  • tzrec/modules/gr/preprocessors.py — Add SequencePreprocessor class + register in factory
  • tzrec/modules/gr/postprocessors.py — Fix WhichOneof("input_preprocessor") → WhichOneof("output_postprocessor")
  • tzrec/models/hstu.py — Rewrite with STUStack, SequencePreprocessor, jagged two-tower architecture
  • tzrec/modules/sequence.py — Remove HSTUEncoder (replaced by SequencePreprocessor + STUStack)
  • tzrec/modules/sequence_test.py — Remove HSTUEncoderTest
  • tzrec/models/hstu_test.py — Updated test with JAGGED_SEQUENCE groups and new proto config

Test plan

  • python -m tzrec.models.hstu_test — NORMAL and FX_TRACE tests pass
  • python -m tzrec.models.dlrm_hstu_test — DlrmHSTU regression test passes
  • python -m tzrec.modules.sequence_test — Sequence module tests pass
  • pre-commit run -a — All linting/formatting checks pass

🤖 Generated with Claude Code

tiankongdeguiji and others added 5 commits April 1, 2026 16:41
…quences

- Replace old HSTUEncoder/SequentialTransductionUnitJagged with STUStack/STULayer
- Add SequencePreprocessor for match models (UIH-only, no candidate concat)
  with optional contextual features, timestamps, and actions
- Use JAGGED_SEQUENCE feature groups for native jagged tensor support
- Separate UIH and candidate feature groups (two-tower architecture)
- Update HSTUMatchTower proto to use HSTU config instead of HSTUEncoder
- Add GRSequencePreprocessor to GRInputPreprocessor oneof
- Fix WhichOneof field name bug in create_output_postprocessor
- Remove unused HSTUEncoder class and tests from sequence.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…quences

- Replace old HSTUEncoder/SequentialTransductionUnitJagged with STUStack/STULayer
- Add SequencePreprocessor for match models (UIH-only, no candidate concat)
  with optional contextual features, timestamps, and actions
- Use JAGGED_SEQUENCE feature groups for native jagged tensor support
- Separate UIH and candidate feature groups (two-tower architecture)
- Update HSTUMatchTower proto to use HSTU config instead of HSTUEncoder
- Add GRSequencePreprocessor to GRInputPreprocessor oneof
- Fix WhichOneof field name bug in create_output_postprocessor
- Fix item tower reading wrong feature group (_group_name override)
- Remove unused HSTUEncoder class and tests from sequence.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- User tower always returns (B, D) last-position embeddings (not jagged)
- Add unit tests: test_hstu_match_train (forward + loss + backward),
  test_hstu_match_eval (metrics), test_hstu_match_export (FX trace),
  test_hstu_match_predict (inference)
- Update hstu_fg_mock.config with new proto format (HSTU, STU,
  JAGGED_SEQUENCE, SequencePreprocessor)
- Enable integration test_hstu_with_fg_train_eval_export (was skipped)
- Add predict step to integration test

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Change candidate feature from id_feature to sequence_id_feature (list of
  items per user: pos + negs) in hstu_fg_mock.config
- Add combine_neg_as_candidate_sequence() to auto-detect sequence fields in
  the negative sampling path and combine pos+neg into per-user candidate
  sequence strings, replacing the old enable_hstu branch
- Remove enable_hstu field from data.proto
- Remove enable_hstu branch and process_hstu_seq_data/process_hstu_neg_sample
  from dataset.py and utils.py
- Remove is_hstu parameter and HSTUIdMockInput from tests/utils.py
- Add @unittest.skipIf(*gpu_unavailable) to HSTU integration test
- Add tests for combine_neg_as_candidate_sequence in utils_test.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Fix CPU CI: remove @parameterized.expand from test_hstu_match_export
  (single-case test) to fix skipIf+parameterized decorator interaction
  that prevented GPU unavailable skip
- Fix GPU CI: change candidate group from JAGGED_SEQUENCE to DEEP with
  id_feature to fix type mismatch (string vs int64) in mock data join
  during integration test. Negative sampling with standard row-append
  works correctly with DEEP candidate group.
- Update HSTUMatchItemTower to read from DEEP group key (not .sequence)
- Update _build_batch to use NEG_DATA_GROUP for candidate items

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji force-pushed the refactor/hstu-match-stu-jagged branch from 288e04e to 24c8d5f on April 3, 2026 03:09
tiankongdeguiji and others added 2 commits April 3, 2026 11:32
Rename to better reflect the preprocessor's purpose: it processes UIH
(User Interaction History) sequences only, without candidate concat.
This parallels the codebase convention where "uih" is the standard
group name for user history features.

- SequencePreprocessor → UIHPreprocessor
- GRSequencePreprocessor → GRUIHPreprocessor
- sequence_preprocessor → uih_preprocessor (proto oneof field)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
HSTUMatchItemTower should only L2-normalize when similarity method is
COSINE, matching the pattern in DSSMTower. INNER_PRODUCT similarity
does not require normalization.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji added the claude-review Let Claude Review label Apr 3, 2026
@github-actions github-actions bot removed the claude-review Let Claude Review label Apr 3, 2026
def process_hstu_neg_sample(
    input_data: Dict[str, pa.Array],
    v: pa.Array,
def combine_neg_as_candidate_sequence(

Bug risk: missing input validation. If len(neg_data) != len(pos_data) * neg_sample_num (e.g., sampler returns unexpected count), the offset array silently produces wrong results — corrupt training data with no error signal.

Suggested change

    def combine_neg_as_candidate_sequence(
        pos_data: pa.Array,
        neg_data: pa.Array,
        neg_sample_num: int,
        seq_delim: str,
    ) -> pa.Array:
        """Combine positive and negative items into candidate sequences.

        For each sample, joins the positive item with its negative items using the
        sequence delimiter. Used when candidate features are sequence_id_features
        in a JAGGED_SEQUENCE group.

        Args:
            pos_data: positive item IDs, one per sample. Shape: (B,).
            neg_data: negative item IDs. Shape: (B * neg_sample_num,).
            neg_sample_num: number of negative samples per positive.
            seq_delim: delimiter for joining items into a sequence string.

        Returns:
            pa.Array of strings, each containing "pos;neg1;neg2;..." per sample.

        Example:
            pos_data = ["1", "2"]
            neg_data = ["3", "4", "5", "6"]
            neg_sample_num = 2
            seq_delim = ";"
            result = ["1;3;4", "2;5;6"]
        """
        assert len(neg_data) == len(pos_data) * neg_sample_num, (
            f"neg_data length ({len(neg_data)}) must equal "
            f"pos_data length ({len(pos_data)}) * neg_sample_num ({neg_sample_num})"
        )

Comment on lines +46 to +49
Processes UIH (User Interaction History) sequences through UIHPreprocessor,
HSTUPositionalEncoder, and STUStack to produce user embeddings. During training,
returns one embedding per UIH position (autoregressive). During inference, returns
the last position embedding per user for ANN retrieval.

Inaccurate docstring. The class docstring says "During training, returns one embedding per UIH position (autoregressive)", but the forward method at line 164 always extracts the last-position embedding via user_emb[seq_offsets[1:] - 1] regardless of training mode. The return shape is always (B, D).

Consider simplifying to: "Produces user embeddings by extracting the last-position embedding per user."
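The indexing pattern in question, and the zero-length pitfall the review flags, can be illustrated with a small NumPy sketch (hypothetical data; the real code operates on jagged torch tensors):

```python
import numpy as np

# Jagged user embeddings for B=2 users with lengths [3, 2], flattened to (sum_len, D).
D = 4
user_emb = np.arange(5 * D, dtype=np.float32).reshape(5, D)
seq_offsets = np.array([0, 3, 5])  # per-user boundaries into user_emb

# Last-position embedding per user: index offsets[1:] - 1 -> shape (B, D).
last = user_emb[seq_offsets[1:] - 1]

# Pitfall noted in review: a zero-length sequence makes offsets[i+1] == offsets[i],
# so offsets[1:] - 1 points at the PREVIOUS user's last row (or wraps to index -1
# for user 0) instead of raising. Validating lengths up front surfaces the error:
lengths = np.diff(seq_offsets)
assert (lengths > 0).all(), "zero-length UIH sequence would silently yield a wrong embedding"
```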

class HSTUMatchItemTower(MatchTowerWoEG):
    """HSTU Match model item tower.

    Projects candidate embeddings to STU embedding dimension and L2 normalizes.

Nit: this says "and L2 normalizes" unconditionally, but normalization only happens for COSINE similarity (line 215). Consider: "Projects candidate embeddings to STU embedding dimension. L2 normalizes when using COSINE similarity."
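A sketch of the conditional normalization the nit suggests, following the DSSMTower pattern the review cites — NumPy stand-ins, and `project_candidates` is a hypothetical helper name, not the tower's real method:

```python
import numpy as np

def project_candidates(emb: np.ndarray, similarity: str) -> np.ndarray:
    """L2-normalize only for COSINE similarity; INNER_PRODUCT uses raw embeddings."""
    if similarity == "COSINE":
        norm = np.linalg.norm(emb, axis=-1, keepdims=True)
        return emb / np.clip(norm, 1e-12, None)  # clip guards against zero vectors
    # INNER_PRODUCT: no normalization required.
    return emb
```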

Comment on lines +497 to +501
if tower._pass_grouped_features:
    tower_input = grouped_features
else:
    tower_input = grouped_features[self._group_name]
return {f"{self._tower_name}_emb": tower(tower_input)}

Design note: This boolean flag toggles between two fundamentally different calling conventions (Dict[str, Tensor] vs single Tensor). It works but is type-unsafe — tower_input is implicitly Union[Dict, Tensor]. Consider whether overriding predict in subclasses or accepting group_name: Optional[str] would be cleaner long-term. Not blocking, but worth noting for maintainability.
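One shape the reviewer's `group_name: Optional[str]` alternative could take — a hypothetical helper with plain dicts standing in for the tensor containers:

```python
from typing import Dict, Optional, TypeVar, Union

T = TypeVar("T")

def tower_input_for(
    grouped_features: Dict[str, T], group_name: Optional[str]
) -> Union[Dict[str, T], T]:
    """Explicit Optional group name instead of a boolean flag toggling
    Dict-vs-Tensor calling conventions: None means the tower consumes the
    whole feature dict; a name means it consumes that single group."""
    if group_name is None:
        return grouped_features
    return grouped_features[group_name]
```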

seq_embeddings = seq_embeddings + action_embeddings

# Timestamps: always use zeros (timestamps are optional for match models)
seq_timestamps = torch.zeros(

Performance: This allocates a zeros tensor on GPU every forward call, and when max_contextual_seq_len > 0, a second zeros tensor is also allocated via _fx_timestamp_contextual_zeros. Since timestamps are always zero for this preprocessor, consider pre-allocating a buffer in __init__ (via register_buffer) and slicing it, or checking whether downstream consumers (e.g., L2NormPostprocessor) can skip timestamp processing entirely when unused.
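A sketch of the pre-allocation idea, with NumPy standing in for torch — in the real module this would be a `register_buffer` so the zeros follow the module's device, and `ZeroTimestampCache` is a hypothetical name:

```python
import numpy as np

class ZeroTimestampCache:
    """Pre-allocate one zeros buffer once and slice it per batch, instead of
    allocating a fresh zeros tensor on every forward call."""

    def __init__(self, max_len: int) -> None:
        self._zeros = np.zeros(max_len, dtype=np.int64)

    def get(self, n: int) -> np.ndarray:
        # Grow rarely; amortized no allocation on the steady-state path.
        if n > len(self._zeros):
            self._zeros = np.zeros(n, dtype=np.int64)
        return self._zeros[:n]
```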

@github-actions

github-actions bot commented Apr 3, 2026

Code Review Summary

Well-structured refactor that cleanly replaces the old HSTUEncoder with modern STUStack/STULayer, adds UIHPreprocessor for match models, and simplifies negative sampling. The two-tower separation is clear, proto definitions are well-organized, and the WhichOneof bugfix is a good catch. Overall a solid improvement.

Issues to Address

Correctness (2 items — see inline comments):

  • combine_neg_as_candidate_sequence has no validation that len(neg_data) == len(pos_data) * neg_sample_num. Mismatches silently corrupt training data.
  • user_emb[seq_offsets[1:] - 1] produces index -1 (wrapping to last element) if any user has zero-length UIH sequence — silently returns wrong embedding.

Test Coverage (2 items):

  • UIHPreprocessor is ~220 lines of new code with conditional branches (contextual features, action encoder) but has no unit test. It's only exercised indirectly through integration tests. A dedicated test in preprocessors_test.py (following the existing hypothesis-based pattern) would be high-value.
  • All hstu_test.py tests are GPU-only (@unittest.skipIf(*gpu_unavailable)). CPU-only CI gets zero coverage for the entire HSTUMatch model. At minimum, a test verifying model instantiation and config validation without a forward pass could run on CPU.

Documentation (2 items — see inline comments):

  • HSTUMatchUserTower docstring claims different train/inference behavior, but forward always extracts last-position embedding.
  • HSTUMatchItemTower docstring says "L2 normalizes" unconditionally, but it's conditional on COSINE similarity.

Design Notes (non-blocking)

  • The _pass_grouped_features boolean dispatch in match_model.py toggles between Dict and Tensor calling conventions — pragmatic but type-unsafe. See inline comment.
  • HSTUMatchItemTower overrides _group_name immediately after the parent sets it — a "set then override" pattern that could be cleaner with an explicit parameter.
  • UIHPreprocessor and ContextualInterleavePreprocessor share substantial duplicated logic (contextual projection, concat_2D_jagged, action encoder setup). Extracting shared methods would reduce maintenance burden.
  • Hardcoded zero timestamps allocated every forward call in UIHPreprocessor — could be pre-allocated or skipped when downstream postprocessors don't use them.

tiankongdeguiji and others added 5 commits April 3, 2026 13:49
- Fix InputPreprocessor abstract forward() signature to match actual
  implementations (grouped_features dict, not individual params)
- Add comment documenting non-empty sequence assumption at user_emb extraction
- Add comment explaining target_offsets placeholder in UIHPreprocessor
- Update HSTUMatchItemTower docstring for conditional normalization

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Use ONE feature item_id (sequence_id_feature) for both negative sampler
and candidate group (JAGGED_SEQUENCE). The sampler reads single-value
item_id from raw data; combine_neg_as_candidate_sequence creates the
candidate sequence "pos;neg1;neg2" at runtime.

- hstu_fg_mock.config: item_id as sequence_id_feature, candidate group
  as JAGGED_SEQUENCE
- HSTUMatchItemTower reads from "candidate.sequence" (jagged)
- HSTUMatch uses _jagged_candidate_sim() to reshape (sum_cand, D) into
  (B, 1+num_neg, D) and compute per-user similarity
- combine_neg_as_candidate_sequence: compute neg_per_pos from data length
  instead of using sampler._num_sample directly
- IdMockInput: add as_string parameter for sequence features used as
  sampler item_id (single value but as string)
- build_mock_input_with_fg: accept neg_fields, generate single-value
  string mock data for sequence features in neg_fields
- create_mock_data: cast unique_id field to string when as_string=True
  for type-consistent joins

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The negative sampler casts input_data[item_id_field] to int64 via
_pa_ids_to_npy and pads to batch_size. It cannot process multi-value
sequence strings like "1;2;3".

Fix: in _build_batch, if item_id_field is a sequence field and the raw
data is a string, extract the first item from each (possibly multi-value)
row before calling the sampler. The original (possibly multi-value) data
is preserved and used by combine_neg_as_candidate_sequence to build the
final candidate sequence "pos_seq;neg1;neg2".

- Save sampler item_id_field on dataset init for fast lookup
- Pre-process input_data[item_id_field]: split by delim, take first item
- Restore original after sampler.get() so combine sees full sequence
- Add unit test for multi-value pos in combine_neg_as_candidate_sequence

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The previous fix took only the FIRST item from a multi-value sequence
and discarded the rest. This is incorrect — each item in the candidate
sequence should be treated as a positive that gets its own negatives.

Fix:
- dataset.py: flatten per-row positives into a single 1D array, pass
  the flat array to the sampler so each positive gets its own
  expand_factor negatives. Then restore the list-array form so the
  combine function knows the per-row grouping.
- combine_neg_as_candidate_sequence: rewrite to handle both single and
  multi-value pos. For multi-value, interleave per-position negatives:
  "pos1;neg1_1;...;neg1_n;pos2;neg2_1;...;posK;negK_n".
- Add assertion: total positives across rows must fit in sampler
  batch_size (chunking can be added later if needed).
- Add unit tests for multi-value pos with 1 and 2 negatives per position.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
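The per-position interleaving described above ("pos1;neg1_1;...;neg1_n;pos2;...") can be sketched as follows — `interleave_pos_neg` is a hypothetical helper name; the real logic lives inside combine_neg_as_candidate_sequence:

```python
from typing import List

def interleave_pos_neg(
    pos_items: List[str], neg_items: List[str], neg_per_pos: int, delim: str
) -> str:
    """Each positive in a multi-value candidate sequence is immediately
    followed by its own block of negatives."""
    assert len(neg_items) == len(pos_items) * neg_per_pos
    parts: List[str] = []
    for i, pos in enumerate(pos_items):
        parts.append(pos)
        parts.extend(neg_items[i * neg_per_pos : (i + 1) * neg_per_pos])
    return delim.join(parts)
```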
Previous approach took the first item from multi-value item_id for
sampling, which was incorrect — each positive in the candidate sequence
should get its own negatives.

Changes:
1. Sampler: move expand_factor computation from init() to get(). Compute
   dynamically from len(ids) and num_sample. Cache samplers by
   expand_factor. Remove np.pad (sampler supports any input size now).
   Applies to NegativeSampler, NegativeSamplerV2, HardNegativeSampler,
   HardNegativeSamplerV2.

2. Dataset: flatten multi-value item_id before sampler.get(). For
   samplers that use user_id_field (V2, HardNeg variants), also expand
   user_id to match the flattened length via pc.take with per-row
   repeat indices. Restore multi-value form after sampling.

3. Remove the len(flat_pos) <= sampler_bs assertion — no longer needed
   since the sampler handles any size.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>