diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index dd3cd8aa..a6e9d4e4 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -30,8 +30,7 @@ CKPT_SOURCE_ID, Batch, RecordBatchTensor, - process_hstu_neg_sample, - process_hstu_seq_data, + combine_neg_as_candidate_sequence, remove_nullable, ) from tzrec.features.feature import BaseFeature @@ -177,7 +176,6 @@ def __init__( self._reserved_columns = reserved_columns or [] self._mode = mode self._debug_level = debug_level - self._enable_hstu = data_config.enable_hstu self.sampler_type = ( self._data_config.WhichOneof("sampler") if self._data_config.HasField("sampler") @@ -208,6 +206,8 @@ def __init__( self._selected_input_names |= set(data_config.sample_weight_fields) if data_config.HasField("sample_cost_field"): self._selected_input_names.add(data_config.sample_cost_field) + self._sampler_item_id_field: Optional[str] = None + self._sampler_user_id_field: Optional[str] = None if self._data_config.HasField("sampler") and self._mode != Mode.PREDICT: sampler_type = self._data_config.WhichOneof("sampler") sampler_config = getattr(self._data_config, sampler_type) @@ -215,10 +215,12 @@ def __init__( "item_id_field" ): self._selected_input_names.add(sampler_config.item_id_field) + self._sampler_item_id_field = sampler_config.item_id_field if hasattr(sampler_config, "user_id_field") and sampler_config.HasField( "user_id_field" ): self._selected_input_names.add(sampler_config.user_id_field) + self._sampler_user_id_field = sampler_config.user_id_field # if set selected_input_names to None, # all columns will be reserved. if ( @@ -238,6 +240,14 @@ def __init__( self._sampler = None self._sampler_inited = False + # Build mapping of field_name → sequence_delim for candidate sequence + # auto-detection during negative sampling. + self._seq_field_delims: Dict[str, str] = {} + for feature in features: + if hasattr(feature, "sequence_delim") and feature.sequence_delim: + for input_name in feature.inputs: + self._seq_field_delims[input_name] = feature.sequence_delim + self._reader = None def launch_sampler_cluster( @@ -367,49 +377,67 @@ def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch: input_data = _expand_tdm_sample( input_data, pos_sampled, neg_sampled, self._data_config ) - elif self._enable_hstu: - seq_attr = self._sampler._item_id_field - - ( - input_data_k_split, - input_data_k_split_slice, - pre_seq_filter_reshaped_joined, - ) = process_hstu_seq_data( - input_data=input_data, - seq_attr=seq_attr, - seq_str_delim=self._sampler.item_id_delim, - ) - if self._mode == Mode.TRAIN: - # Training using all possible target items - input_data[seq_attr] = input_data_k_split_slice - elif self._mode == Mode.EVAL: - # Evaluation using the last item for previous sequence - input_data[seq_attr] = input_data_k_split.values.take( - pa.array(input_data_k_split.offsets.to_numpy()[1:] - 1) - ) - sampled = self._sampler.get(input_data) - # To keep consistent with other process, use two functions - for k, v in sampled.items(): - if k in input_data: - combined = process_hstu_neg_sample( - input_data, - v, - self._sampler._num_sample, - self._sampler.item_id_delim, - seq_attr, - ) - # Combine here to make embddings of both user sequence - # and target item are the same - input_data[k] = pa.concat_arrays( - [pre_seq_filter_reshaped_joined, combined] - ) - else: - input_data[k] = v else: + # If item_id_field is a sequence feature, flatten per-row + # positives into a 1D array so the sampler treats each + # positive as a separate query. Expand user_id in the same + # way for samplers that rely on it (V2, HardNeg, HardNegV2). + # The sampler computes expand_factor dynamically from the + # actual input length, so no padding or batch_size limit. + multi_pos_lists: Optional[pa.Array] = None + saved_user_ids: Optional[pa.Array] = None + if ( + self._sampler_item_id_field is not None + and self._sampler_item_id_field in self._seq_field_delims + ): + seq_delim = self._seq_field_delims[self._sampler_item_id_field] + raw = input_data[self._sampler_item_id_field] + if pa.types.is_string(raw.type): + multi_pos_lists = pc.split_pattern(raw, seq_delim) + flat_pos = pc.list_flatten(multi_pos_lists) + input_data[self._sampler_item_id_field] = flat_pos + + # Expand user_id to match flat positives (step 4). + if self._sampler_user_id_field is not None and ( + self._sampler_user_id_field in input_data + ): + counts_np = pc.list_value_length(multi_pos_lists).to_numpy() + row_indices = pa.array( + np.repeat(np.arange(len(counts_np)), counts_np) + ) + saved_user_ids = input_data[self._sampler_user_id_field] + input_data[self._sampler_user_id_field] = pc.take( + saved_user_ids, row_indices + ) + sampled = self._sampler.get(input_data) + + # Restore multi-value form so combine_neg sees original grouping. + if multi_pos_lists is not None: + input_data[self._sampler_item_id_field] = multi_pos_lists + if saved_user_ids is not None: + input_data[self._sampler_user_id_field] = saved_user_ids + for k, v in sampled.items(): if k in input_data: - input_data[k] = pa.concat_arrays([input_data[k], v]) + seq_delim = self._seq_field_delims.get(k) + if seq_delim is not None: + pos_for_combine = input_data[k] + if pa.types.is_list(pos_for_combine.type): + num_pos = len(pc.list_flatten(pos_for_combine)) + else: + num_pos = len(pos_for_combine) + # Sampler returns num_pos * expand_factor negs. + neg_per_pos = max(1, len(v) // max(1, num_pos)) + valid_negs = v.slice(0, num_pos * neg_per_pos) + input_data[k] = combine_neg_as_candidate_sequence( + pos_for_combine, + valid_negs, + neg_per_pos, + seq_delim, + ) + else: + input_data[k] = pa.concat_arrays([input_data[k], v]) else: input_data[k] = v diff --git a/tzrec/datasets/sampler.py b/tzrec/datasets/sampler.py index 30ae5862..02451c84 100644 --- a/tzrec/datasets/sampler.py +++ b/tzrec/datasets/sampler.py @@ -14,7 +14,7 @@ import random import socket import time -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import graphlearn as gl import numpy as np @@ -428,20 +428,28 @@ def __init__( ), ) self._item_id_field = config.item_id_field - self._sampler = None + self._sampler_cache: Dict[int, Any] = {} self.item_id_delim = config.item_id_delim def init(self, client_id: int = -1) -> None: - """Init sampler client and samplers.""" + """Init sampler client; samplers are created lazily in get().""" super().init(client_id) - expand_factor = int(math.ceil(self._num_sample / self._batch_size)) - self._sampler = self._g.negative_sampler( - "item", expand_factor, strategy="node_weight" - ) + self._sampler_cache = {} + + def _get_sampler(self, expand_factor: int) -> Any: + if expand_factor not in self._sampler_cache: + self._sampler_cache[expand_factor] = self._g.negative_sampler( + "item", expand_factor, strategy="node_weight" + ) + return self._sampler_cache[expand_factor] def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]: """Sampling method. + Computes expand_factor dynamically from the actual input length so + that multi-value sequence positives (flattened upstream) are + supported without padding. + Args: input_data (dict): input data with item_id. @@ -449,8 +457,10 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]: Negative sampled feature dict. """ ids = _pa_ids_to_npy(input_data[self._item_id_field]) - ids = np.pad(ids, (0, self._batch_size - len(ids)), "edge") - nodes = self._sampler.get(ids) + num_pos = max(1, len(ids)) + expand_factor = int(math.ceil(self._num_sample / num_pos)) + sampler = self._get_sampler(expand_factor) + nodes = sampler.get(ids) features = self._parse_nodes(nodes) result_dict = dict(zip(self._valid_attr_names, features)) return result_dict @@ -509,15 +519,12 @@ def __init__( ) self._item_id_field = config.item_id_field self._user_id_field = config.user_id_field - self._sampler = None + self._sampler_cache: Dict[int, Any] = {} def init(self, client_id: int = -1) -> None: - """Init sampler client and samplers.""" + """Init sampler client; samplers are created lazily in get().""" super().init(client_id) - expand_factor = int(math.ceil(self._num_sample / self._batch_size)) - self._sampler = self._g.negative_sampler( - "edge", expand_factor, strategy="random", conditional=True - ) + self._sampler_cache = {} # prevent gl timeout worker_info = get_worker_info() @@ -528,20 +535,30 @@ def init(self, client_id: int = -1) -> None: {self._user_id_field: pa.array([0]), self._item_id_field: pa.array([0])} ) + def _get_sampler(self, expand_factor: int) -> Any: + if expand_factor not in self._sampler_cache: + self._sampler_cache[expand_factor] = self._g.negative_sampler( + "edge", expand_factor, strategy="random", conditional=True + ) + return self._sampler_cache[expand_factor] + def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]: """Sampling method. Args: - input_data (dict): input data with user_id and item_id. + input_data (dict): input data with user_id and item_id. When + item_id is a flattened multi-value sequence, user_id must + be expanded to match the flattened length. Returns: Negative sampled feature dict. """ src_ids = _pa_ids_to_npy(input_data[self._user_id_field]) dst_ids = _pa_ids_to_npy(input_data[self._item_id_field]) - src_ids = np.pad(src_ids, (0, self._batch_size - len(src_ids)), "edge") - dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), "edge") - nodes = self._sampler.get(src_ids, dst_ids) + num_pos = max(1, len(dst_ids)) + expand_factor = int(math.ceil(self._num_sample / num_pos)) + sampler = self._get_sampler(expand_factor) + nodes = sampler.get(src_ids, dst_ids) features = self._parse_nodes(nodes) result_dict = dict(zip(self._valid_attr_names, features)) return result_dict @@ -602,20 +619,24 @@ def __init__( ) self._item_id_field = config.item_id_field self._user_id_field = config.user_id_field - self._neg_sampler = None + self._neg_sampler_cache: Dict[int, Any] = {} self._hard_neg_sampler = None def init(self, client_id: int = -1) -> None: - """Init sampler client and samplers.""" + """Init sampler client; neg_sampler created lazily in get().""" super().init(client_id) - expand_factor = int(math.ceil(self._num_sample / self._batch_size)) - self._neg_sampler = self._g.negative_sampler( - "item", expand_factor, strategy="node_weight" - ) + self._neg_sampler_cache = {} self._hard_neg_sampler = self._g.neighbor_sampler( ["hard_neg_edge"], self._num_hard_sample, strategy="full" ) + def _get_neg_sampler(self, expand_factor: int) -> Any: + if expand_factor not in self._neg_sampler_cache: + self._neg_sampler_cache[expand_factor] = self._g.negative_sampler( + "item", expand_factor, strategy="node_weight" + ) + return self._neg_sampler_cache[expand_factor] + def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]: """Sampling method. @@ -628,8 +649,10 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]: """ src_ids = _pa_ids_to_npy(input_data[self._user_id_field]) dst_ids = _pa_ids_to_npy(input_data[self._item_id_field]) - dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), "edge") - nodes = self._neg_sampler.get(dst_ids) + num_pos = max(1, len(dst_ids)) + expand_factor = int(math.ceil(self._num_sample / num_pos)) + neg_sampler = self._get_neg_sampler(expand_factor) + nodes = neg_sampler.get(dst_ids) neg_features = self._parse_nodes(nodes) sparse_nodes = self._hard_neg_sampler.get(src_ids).layer_nodes(1) hard_neg_features, hard_neg_indices = self._parse_sparse_nodes(sparse_nodes) @@ -703,20 +726,24 @@ def __init__( ) self._item_id_field = config.item_id_field self._user_id_field = config.user_id_field - self._neg_sampler = None + self._neg_sampler_cache: Dict[int, Any] = {} self._hard_neg_sampler = None def init(self, client_id: int = -1) -> None: - """Init sampler client and samplers.""" + """Init sampler client; neg_sampler created lazily in get().""" super().init(client_id) - expand_factor = int(math.ceil(self._num_sample / self._batch_size)) - self._neg_sampler = self._g.negative_sampler( - "edge", expand_factor, strategy="random", conditional=True - ) + self._neg_sampler_cache = {} self._hard_neg_sampler = self._g.neighbor_sampler( ["hard_neg_edge"], self._num_hard_sample, strategy="full" ) + def _get_neg_sampler(self, expand_factor: int) -> Any: + if expand_factor not in self._neg_sampler_cache: + self._neg_sampler_cache[expand_factor] = self._g.negative_sampler( + "edge", expand_factor, strategy="random", conditional=True + ) + return self._neg_sampler_cache[expand_factor] + def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]: """Sampling method. @@ -729,9 +756,10 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]: """ src_ids = _pa_ids_to_npy(input_data[self._user_id_field]) dst_ids = _pa_ids_to_npy(input_data[self._item_id_field]) - padded_src_ids = np.pad(src_ids, (0, self._batch_size - len(src_ids)), "edge") - dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), "edge") - nodes = self._neg_sampler.get(padded_src_ids, dst_ids) + num_pos = max(1, len(dst_ids)) + expand_factor = int(math.ceil(self._num_sample / num_pos)) + neg_sampler = self._get_neg_sampler(expand_factor) + nodes = neg_sampler.get(src_ids, dst_ids) neg_features = self._parse_nodes(nodes) sparse_nodes = self._hard_neg_sampler.get(src_ids).layer_nodes(1) hard_neg_features, hard_neg_indices = self._parse_sparse_nodes(sparse_nodes) diff --git a/tzrec/datasets/utils.py b/tzrec/datasets/utils.py index 9009c762..27cca11a 100644 --- a/tzrec/datasets/utils.py +++ b/tzrec/datasets/utils.py @@ -13,7 +13,6 @@ from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple -import numpy as np import numpy.typing as npt import pyarrow as pa import pyarrow.compute as pc @@ -503,129 +502,62 @@ def to_dict( return tensor_dict -def process_hstu_seq_data( - input_data: Dict[str, pa.Array], - seq_attr: str, - seq_str_delim: str, -) -> Tuple[pa.Array, pa.Array, pa.Array]: - """Process sequence data for HSTU match model. - - Args: - input_data: Dictionary containing input arrays - seq_attr: Name of the sequence attribute field - seq_str_delim: Delimiter used to separate sequence items - - Returns: - Tuple containing: - - input_data_k_split: pa.Array, Original sequence items - - input_data_k_split_slice: pa.Array, Target items for autoregressive training - - pre_seq_filter_reshaped_joined: pa.Array, - Training sequence for autoregressive training - """ - # default sequence data is string - if pa.types.is_string(input_data[seq_attr].type): - input_data_k_split = pc.split_pattern(input_data[seq_attr], seq_str_delim) - # Get target items for training for autoregressive training - # Example: [1,2,3,4,5] -> [2,3,4,5] - input_data_k_split_slice = pc.list_flatten( - pc.list_slice(input_data_k_split, start=1) - ) - - # Directly extract the training sequence for autoregressive training - # Operation target example: [1,2,3,4,5] -> [1,2,3,4] - # (corresponding target: [2,3,4,5]) - # As this can not be achieved by pyarrow.compute, and for loop is costly - # we need to do this using pa.ListArray.from_arrays using offsets - # 1. transfer to numpy and filter out the last item - pre_seq = pc.list_flatten(input_data_k_split).to_numpy(zero_copy_only=False) - # Mark last items of each seq with '-1' - pre_seq[input_data_k_split.offsets.to_numpy()[1:] - 1] = "-1" - # Filter out -1 marker elements - mask = pre_seq != "-1" - pre_seq_filter = pre_seq[mask] - # 2. create offsets for reshaping filtered sequence - # The offsets should be created extract the training sequence - # Example: if the original offsets are [0,2,5,9], after filter, - # for the offsets should be [0, 1, 3, 6] - # that is, [0] + [2-1, 5-2, 9-3] - pre_seq_filter_offsets = pa.array( - np.concatenate( - [ - np.array([0]), - input_data_k_split.offsets[1:].to_numpy(zero_copy_only=False) - - np.arange( - 1, - len( - input_data_k_split.offsets[1:].to_numpy( - zero_copy_only=False - ) - ) - + 1, - ), - ] - ) - ) - pre_seq_filter_reshaped = pa.ListArray.from_arrays( - pre_seq_filter_offsets, pre_seq_filter - ) - # Join filtered sequence with delimiter - pre_seq_filter_reshaped_joined = pc.binary_join( - pre_seq_filter_reshaped, seq_str_delim - ) - - return ( - input_data_k_split, - input_data_k_split_slice, - pre_seq_filter_reshaped_joined, - ) - - -def process_hstu_neg_sample( - input_data: Dict[str, pa.Array], - v: pa.Array, +def combine_neg_as_candidate_sequence( + pos_data: pa.Array, + neg_data: pa.Array, neg_sample_num: int, - seq_str_delim: str, - seq_attr: str, + seq_delim: str, ) -> pa.Array: - """Process negative samples for HSTU match model. + """Combine positive and negative items into candidate sequences. + + Each row's positive items are flattened and each positive gets its own + set of `neg_sample_num` negatives. The result interleaves per position: + "pos1;neg1_1;...;neg1_n;pos2;neg2_1;...;posK;negK_n". + + Used when candidate features are sequence_id_features in a JAGGED_SEQUENCE + group. Supports both single-value positives ("123") and multi-value + positives ("1;2;3") per row. Args: - input_data: Dict[str, pa.Array], Dictionary containing input arrays - v: pa.Array, negative samples. - neg_sample_num: int, number of negative samples. - seq_str_delim: str, delimiter for sequence string. - seq_attr: str, attribute name of sequence. + pos_data: positive items per row. Either a 1D array of strings (each + row may contain a delimited sequence) or a list array. + neg_data: flat array of negative items, ordered by positive. + Length = total_positives * neg_sample_num. + neg_sample_num: number of negative samples per positive. + seq_delim: delimiter for joining items into a sequence string. Returns: - pa.Array: Processed negative samples + pa.Array of strings, each row's candidate sequence. + + Example: + pos_data = ["1", "2;3"] + neg_data = ["10", "20", "30"] # 1 neg per pos, total 3 positives + neg_sample_num = 1 + seq_delim = ";" + result = ["1;10", "2;20;3;30"] """ - # The goal is to make neg samples concat to the training sequence - # Example: - # input_data[seq_attr] = ["1;2;3"] - # neg_sample_num = 2 - # v = [4,5,6,7,8,9] - # then the output should be [[1,4,5], [2,6,7], [3,8,9]] - v_str = v.cast(pa.string()) - filtered_v_offsets = pa.array( - np.concatenate( - [ - np.array([0]), - np.arange(neg_sample_num, len(v_str) + 1, neg_sample_num), - ] - ) - ) - # Reshape v for each input_data[seq_attr] - # Example:[4,5,6,7,8,9] -> [[4,5], [6,7], [8,9]] - filtered_v_palist = pa.ListArray.from_arrays(filtered_v_offsets, v_str) - # Using string for join, as not found operation for ListArray achieving this - # Example: [[4,5], [6,7], [8,9]] -> ["4;5", "6;7", "8;9"] - sampled_joined = pc.binary_join(filtered_v_palist, seq_str_delim) - # Combine training sequence and target items - # Example: ["1;2;3"] + ["4;5", "6;7", "8;9"] - # -> ["1;4;5", "2;6;7", "3;8;9"] - return pc.binary_join_element_wise( - input_data[seq_attr], sampled_joined, seq_str_delim - ) + # Normalize pos_data to a list array (each row → list of positives) + if pa.types.is_list(pos_data.type) or pa.types.is_large_list(pos_data.type): + pos_lists = pos_data + else: + pos_lists = pc.split_pattern(pos_data.cast(pa.string()), seq_delim) + + counts = pc.list_value_length(pos_lists).to_pylist() + pos_flat = pc.list_flatten(pos_lists).cast(pa.string()).to_pylist() + neg_str_list = neg_data.cast(pa.string()).to_pylist() + + result: List[str] = [] + pos_offset = 0 + neg_offset = 0 + for k_i in counts: + parts: List[str] = [] + for _ in range(k_i): + parts.append(pos_flat[pos_offset]) + parts.extend(neg_str_list[neg_offset : neg_offset + neg_sample_num]) + pos_offset += 1 + neg_offset += neg_sample_num + result.append(seq_delim.join(parts)) + return pa.array(result) def calc_slice_position( diff --git a/tzrec/datasets/utils_test.py b/tzrec/datasets/utils_test.py index 46ddb8b0..1fb56180 100644 --- a/tzrec/datasets/utils_test.py +++ b/tzrec/datasets/utils_test.py @@ -14,16 +14,14 @@ import numpy as np import pyarrow as pa -import pyarrow.compute as pc from tzrec.datasets.utils import ( _normalize_type_str, calc_remaining_intervals, calc_slice_intervals, calc_slice_position, + combine_neg_as_candidate_sequence, get_input_fields_proto, - process_hstu_neg_sample, - process_hstu_seq_data, ) from tzrec.protos import data_pb2 from tzrec.protos.data_pb2 import FieldType @@ -206,80 +204,67 @@ def test_calc_slice_intervals_topology_change(self): self.assertEqual(total_rows, 400) # All remaining rows accounted for - def test_process_hstu_seq_data(self): - """Test processing sequence data for HSTU match model.""" - input_data = {"sequence": pa.array(["1;2;3;4", "5;6;7;8", "9;10;11;12"])} + def test_combine_neg_as_candidate_sequence(self): + """Test combining positive and negative items into candidate sequences.""" + pos_data = pa.array(["1", "2", "3"]) + neg_data = pa.array(["101", "102", "103", "104", "105", "106"]) - split, slice_result, training_seq = process_hstu_seq_data( - input_data=input_data, seq_attr="sequence", seq_str_delim=";" - ) - - # Verify results - # Test original split sequences - expected_split_values = [ - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", - "10", - "11", - "12", - ] - self.assertEqual(pc.list_flatten(split).to_pylist(), expected_split_values) - - # Test sliced sequences (target items) - expected_slice_values = ["2", "3", "4", "6", "7", "8", "10", "11", "12"] - self.assertEqual(slice_result.to_pylist(), expected_slice_values) - - # Test training sequences - expected_training_seqs = ["1;2;3", "5;6;7", "9;10;11"] - self.assertEqual(training_seq.to_pylist(), expected_training_seqs) - - def test_process_hstu_neg_sample(self): - """Test processing negative samples for HSTU match model.""" - input_data = {"sequence": pa.array(["1", "2", "3"])} - neg_samples = pa.array(["101", "102", "103", "104", "105", "106"]) - - result = process_hstu_neg_sample( - input_data=input_data, - v=neg_samples, + result = combine_neg_as_candidate_sequence( + pos_data=pos_data, + neg_data=neg_data, neg_sample_num=2, - seq_str_delim=";", - seq_attr="sequence", + seq_delim=";", ) + expected = ["1;101;102", "2;103;104", "3;105;106"] + self.assertEqual(result.to_pylist(), expected) - expected_results = [ - "1;101;102", - "2;103;104", - "3;105;106", - ] - self.assertEqual(result.to_pylist(), expected_results) - - def test_process_hstu_neg_sample_with_different_delim(self): - """Test negative sampling with different delimiter.""" - input_data = {"sequence": pa.array(["1", "2", "3"])} - - neg_samples = pa.array(["101", "102", "103", "104", "105", "106"]) + def test_combine_neg_as_candidate_sequence_different_delim(self): + """Test candidate sequence combination with different delimiter.""" + pos_data = pa.array(["1", "2"]) + neg_data = pa.array(["10", "20", "30", "40"]) - result = process_hstu_neg_sample( - input_data=input_data, - v=neg_samples, + result = combine_neg_as_candidate_sequence( + pos_data=pos_data, + neg_data=neg_data, neg_sample_num=2, - seq_str_delim="|", - seq_attr="sequence", + seq_delim="|", ) - - expected_results = [ - "1|101|102", - "2|103|104", - "3|105|106", - ] - self.assertEqual(result.to_pylist(), expected_results) + expected = ["1|10|20", "2|30|40"] + self.assertEqual(result.to_pylist(), expected) + + def test_combine_neg_as_candidate_sequence_multivalue_pos(self): + """Each item in multi-value pos sequence gets its own negatives.""" + pos_data = pa.array(["1;2", "3;4;5"]) # 5 positives total + neg_data = pa.array(["10", "20", "30", "40", "50"]) # 1 neg per pos + + result = combine_neg_as_candidate_sequence( + pos_data=pos_data, + neg_data=neg_data, + neg_sample_num=1, + seq_delim=";", + ) + # row1: pos=1 → "1;10", pos=2 → "2;20" → "1;10;2;20" + # row2: pos=3 → "3;30", pos=4 → "4;40", pos=5 → "5;50" + # → "3;30;4;40;5;50" + expected = ["1;10;2;20", "3;30;4;40;5;50"] + self.assertEqual(result.to_pylist(), expected) + + def test_combine_neg_as_candidate_sequence_multivalue_pos_multi_neg(self): + """Multi-value pos with multiple negatives per positive.""" + pos_data = pa.array(["1;2", "3"]) # 3 positives total + # neg_sample_num=2 → need 6 negatives + neg_data = pa.array(["10", "20", "30", "40", "50", "60"]) + + result = combine_neg_as_candidate_sequence( + pos_data=pos_data, + neg_data=neg_data, + neg_sample_num=2, + seq_delim=";", + ) + # row1: pos=1 → "1;10;20", pos=2 → "2;30;40" → "1;10;20;2;30;40" + # row2: pos=3 → "3;50;60" + expected = ["1;10;20;2;30;40", "3;50;60"] + self.assertEqual(result.to_pylist(), expected) def test_normalize_type_str_basic_types(self): """Test normalizing basic types.""" diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py index 1ef8877a..1c4933ab 100644 --- a/tzrec/models/hstu.py +++ b/tzrec/models/hstu.py @@ -13,38 +13,74 @@ import torch import torch.nn.functional as F -from torch._tensor import Tensor -from tzrec.datasets.utils import NEG_DATA_GROUP, Batch +from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.models.match_model import MatchModel, MatchTowerWoEG from tzrec.modules.embedding import EmbeddingGroup -from tzrec.modules.sequence import HSTUEncoder +from tzrec.modules.gr.positional_encoder import HSTUPositionalEncoder +from tzrec.modules.gr.postprocessors import ( + OutputPostprocessor, + create_output_postprocessor, +) +from tzrec.modules.gr.preprocessors import ( + InputPreprocessor, + create_input_preprocessor, +) +from tzrec.modules.gr.stu import STU, STULayer, STUStack +from tzrec.modules.norm import LayerNorm +from tzrec.modules.utils import init_linear_xavier_weights_zero_bias +from tzrec.ops.utils import set_static_max_seq_lens from tzrec.protos import model_pb2, simi_pb2, tower_pb2 -from tzrec.utils import config_util +from tzrec.protos.model_pb2 import ModelConfig +from tzrec.protos.models import match_model_pb2 +from tzrec.utils.config_util import config_to_kwargs +from tzrec.utils.fx_util import fx_int_item + +torch.fx.wrap(fx_int_item) @torch.fx.wrap -def _update_dict_tensor( - tensor_dict: Dict[str, torch.Tensor], - new_tensor_dict: Optional[Dict[str, Optional[torch.Tensor]]], -) -> None: - if new_tensor_dict: - for k, v in new_tensor_dict.items(): - if v is not None: - tensor_dict[k] = v +def _jagged_candidate_sim( + user_emb: torch.Tensor, item_emb: torch.Tensor +) -> torch.Tensor: + """Compute per-user similarity for JAGGED_SEQUENCE candidates. + + Each user has the same number of candidates (1 pos + num_neg). The item + embeddings are organized as: [pos_1, neg_1_1, ..., neg_1_k, pos_2, ...]. + + Args: + user_emb: (B, D) user embeddings. + item_emb: (B * (1 + num_neg), D) candidate embeddings. + + Returns: + similarity (B, 1 + num_neg), first column is positive. + """ + batch_size = user_emb.size(0) + num_cand = item_emb.size(0) // batch_size + item_emb = item_emb.view(batch_size, num_cand, -1) + return torch.bmm(item_emb, user_emb.unsqueeze(-1)).squeeze(-1) class HSTUMatchUserTower(MatchTowerWoEG): - """HSTU Match model user tower. + """HSTU Match model user tower using modern STU module. + + Processes UIH (User Interaction History) sequences through UIHPreprocessor, + HSTUPositionalEncoder, and STUStack to produce user embeddings. During training, + returns one embedding per UIH position (autoregressive). During inference, returns + the last position embedding per user for ANN retrieval. Args: - tower_config (Tower): user tower config. - output_dim (int): user output embedding dimension. - similarity (Similarity): when use COSINE similarity, - will norm the output embedding. - feature_group (FeatureGroupConfig): feature group config. + tower_config (HSTUMatchTower): user tower config with HSTU settings. + output_dim (int): user output embedding dimension (stu.embedding_dim). + similarity (Similarity): similarity method config. + feature_group (FeatureGroupConfig): uih feature group config. + feature_group_dims (list): per-feature embedding dims in the group. features (list): list of features. + model_config (ModelConfig): full model config. + contextual_feature_dim (int): dimension of each contextual feature. + max_contextual_seq_len (int): number of contextual features. + contextual_group_name (str): contextual group name in grouped features. """ def __init__( @@ -55,177 +91,269 @@ def __init__( feature_group: model_pb2.FeatureGroupConfig, feature_group_dims: List[int], features: List[BaseFeature], - model_config: model_pb2.ModelConfig, + model_config: ModelConfig, + contextual_feature_dim: int = 0, + max_contextual_seq_len: int = 0, + contextual_group_name: str = "contextual", ) -> None: super().__init__(tower_config, output_dim, similarity, feature_group, features) - self.tower_config = tower_config - encoder_config = tower_config.hstu_encoder - seq_config_dict = config_util.config_to_kwargs(encoder_config) - sequence_dim = sum(feature_group_dims) - seq_config_dict["sequence_dim"] = sequence_dim - self.seq_encoder = HSTUEncoder(**seq_config_dict) + self._pass_grouped_features = True + hstu_cfg = tower_config.hstu + uih_dim = sum(feature_group_dims) + stu_dim = hstu_cfg.stu.embedding_dim + + # Preprocessor: projects UIH, handles optional contextual/actions + self._input_preprocessor: InputPreprocessor = create_input_preprocessor( + hstu_cfg.input_preprocessor, + uih_embedding_dim=uih_dim, + output_embedding_dim=stu_dim, + contextual_feature_dim=contextual_feature_dim, + max_contextual_seq_len=max_contextual_seq_len, + contextual_group_name=contextual_group_name, + ) + + # Positional encoder + pos_kwargs = config_to_kwargs(hstu_cfg.positional_encoder) + self._positional_encoder: HSTUPositionalEncoder = HSTUPositionalEncoder( + embedding_dim=stu_dim, + contextual_seq_len=self._input_preprocessor.contextual_seq_len(), + **pos_kwargs, + ) + + # STU stack + stu_kwargs = config_to_kwargs(hstu_cfg.stu) + self._stu_module: STU = STUStack( + stu_list=[STULayer(**stu_kwargs) for _ in range(hstu_cfg.attn_num_layers)], + ) + + # Output postprocessor (L2 norm or layer norm) + self._output_postprocessor: OutputPostprocessor = create_output_postprocessor( + hstu_cfg.output_postprocessor, + embedding_dim=stu_dim, + ) + + self._input_dropout_ratio: float = hstu_cfg.input_dropout_ratio def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: - """Forward the tower. + """Forward the user tower. Args: - grouped_features: Dictionary containing grouped feature tensors + grouped_features: dictionary of embedded features from EmbeddingGroup. Returns: - torch.Tensor: The output tensor from the tower + user embeddings of shape (B, D), last position embedding per user. """ - output = self.seq_encoder(grouped_features) + # 1. Preprocess: project UIH + optional contextual/actions + ( + max_seq_len, + total_uih_len, + _, + seq_lengths, + seq_offsets, + seq_timestamps, + seq_embeddings, + num_targets, + ) = self._input_preprocessor(grouped_features) + + # 2. Positional encoding + seq_embeddings = self._positional_encoder( + max_seq_len=max_seq_len, + seq_lengths=seq_lengths, + seq_offsets=seq_offsets, + seq_timestamps=seq_timestamps, + seq_embeddings=seq_embeddings, + num_targets=num_targets, + ) - return output + # 3. Input dropout + STU + seq_embeddings = F.dropout( + seq_embeddings, p=self._input_dropout_ratio, training=self.training + ) + seq_embeddings = self._stu_module( + x=seq_embeddings, + x_offsets=seq_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets, + ) + + # 4. Output postprocessor + user_emb = self._output_postprocessor( + seq_embeddings=seq_embeddings, + seq_timestamps=seq_timestamps, + ) + + # Extract last position embedding per user → (B, D) + # Assumes all sequences are non-empty (guaranteed by EmbeddingGroup). + user_emb = user_emb[seq_offsets[1:] - 1] + + return user_emb class HSTUMatchItemTower(MatchTowerWoEG): """HSTU Match model item tower. + Projects candidate embeddings to STU embedding dimension. Applies L2 + normalization only when similarity method is COSINE. + Args: - tower_config (Tower): item tower config. - output_dim (int): item output embedding dimension. - similarity (Similarity): when use COSINE similarity, - will norm the output embedding. - feature_group (FeatureGroupConfig): feature group config. + tower_config (HSTUMatchTower): tower config. + output_dim (int): item output embedding dimension (stu.embedding_dim). + similarity (Similarity): similarity method config. + feature_group (FeatureGroupConfig): candidate feature group config. + feature_group_dims (list): per-feature embedding dims in the group. features (list): list of features. """ def __init__( self, - tower_config: tower_pb2.Tower, + tower_config: tower_pb2.HSTUMatchTower, output_dim: int, similarity: simi_pb2.Similarity, feature_group: model_pb2.FeatureGroupConfig, + feature_group_dims: List[int], features: List[BaseFeature], ) -> None: super().__init__(tower_config, output_dim, similarity, feature_group, features) - self.tower_config = tower_config + # Override _group_name: parent sets it from tower_config.input ("uih"), + # but item tower needs to read from the candidate feature group. + self._group_name = feature_group.group_name + self._pass_grouped_features = True + cand_dim = sum(feature_group_dims) + self._item_projection: torch.nn.Module = torch.nn.Sequential( + torch.nn.Linear(cand_dim, output_dim), + LayerNorm(output_dim), + ).apply(init_linear_xavier_weights_zero_bias) def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: - """Forward the tower. + """Forward the item tower. Args: - grouped_features: Dictionary containing grouped feature tensors + grouped_features: dictionary of embedded features from EmbeddingGroup. Returns: - torch.Tensor: The output tensor from the tower + item embeddings of shape (sum_candidates, D). """ - output = grouped_features[f"{self._group_name}.sequence"] - output = F.normalize(output, p=2.0, dim=1, eps=1e-6) - - return output + cand_emb = grouped_features[f"{self._group_name}.sequence"] + item_emb = self._item_projection(cand_emb) + if self._similarity == simi_pb2.Similarity.COSINE: + item_emb = F.normalize(item_emb, p=2.0, dim=-1, eps=1e-6) + return item_emb class HSTUMatch(MatchModel): - """HSTU Match model. + """HSTU Match model for two-tower retrieval. + + Uses modern STUStack for user sequence encoding with native jagged sequences. + User tower processes UIH through UIHPreprocessor + STU. Item tower + projects and normalizes candidate embeddings. Similarity via dot product. + + Feature groups: + - "uih" (JAGGED_SEQUENCE): user interaction history + - "candidate" (JAGGED_SEQUENCE): candidate items (pos + neg) + - "contextual" (optional, DEEP/SEQUENCE): user contextual features + + Training: autoregressive — each UIH position produces a user embedding, + compared against candidates via dot product similarity. + Inference: last UIH position → user embedding for ANN retrieval. Args: model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( self, - model_config: model_pb2.ModelConfig, + model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) - assert len(model_config.feature_groups) == 1 + assert model_config.WhichOneof("model") == "hstu_match", ( + "invalid model config: %s" % model_config.WhichOneof("model") + ) + assert isinstance(self._model_config, match_model_pb2.HSTUMatch) + + tower_cfg = self._model_config.hstu_tower + set_static_max_seq_lens([tower_cfg.max_seq_len]) + self.embedding_group = EmbeddingGroup( features, list(model_config.feature_groups) ) - name_to_feature_group = {x.group_name: x for x in model_config.feature_groups} - feature_group = name_to_feature_group[self._model_config.hstu_tower.input] - - used_features = self.get_features_in_feature_groups([feature_group]) + stu_dim = tower_cfg.hstu.stu.embedding_dim + + # Resolve feature groups + name_to_fg = {x.group_name: x for x in model_config.feature_groups} + uih_fg = name_to_fg[tower_cfg.input] + cand_fg = name_to_fg.get("candidate") + assert cand_fg is not None, "HSTUMatch requires a 'candidate' feature group." + + uih_features = self.get_features_in_feature_groups([uih_fg]) + cand_features = self.get_features_in_feature_groups([cand_fg]) + + uih_dims = self.embedding_group.group_dims(tower_cfg.input + ".sequence") + cand_dims = self.embedding_group.group_dims("candidate.sequence") + + # Optional contextual features + contextual_feature_dim = 0 + max_contextual_seq_len = 0 + contextual_group_name = "contextual" + if "contextual" in name_to_fg: + ctx_group_type = self.embedding_group.group_type("contextual") + if ctx_group_type == model_pb2.SEQUENCE: + contextual_group_name = "contextual.query" + elif ctx_group_type == model_pb2.DEEP: + contextual_group_name = "contextual" + ctx_dims = self.embedding_group.group_dims(contextual_group_name) + if len(set(ctx_dims)) > 1: + raise ValueError( + "output_dim of features in contextual features_group " + f"must be same, but now {set(ctx_dims)}." + ) + contextual_feature_dim = ctx_dims[0] + max_contextual_seq_len = len(ctx_dims) self.user_tower = HSTUMatchUserTower( - self._model_config.hstu_tower, - self._model_config.output_dim, - self._model_config.similarity, - feature_group, - self.embedding_group.group_dims( - self._model_config.hstu_tower.input + ".sequence" - ), - used_features, - model_config, + tower_config=tower_cfg, + output_dim=stu_dim, + similarity=self._model_config.similarity, + feature_group=uih_fg, + feature_group_dims=uih_dims, + features=uih_features, + model_config=model_config, + contextual_feature_dim=contextual_feature_dim, + max_contextual_seq_len=max_contextual_seq_len, + contextual_group_name=contextual_group_name, ) self.item_tower = HSTUMatchItemTower( - self._model_config.hstu_tower, - self._model_config.output_dim, - self._model_config.similarity, - feature_group, - used_features, + tower_config=tower_cfg, + output_dim=stu_dim, + similarity=self._model_config.similarity, + feature_group=cand_fg, + feature_group_dims=cand_dims, + features=cand_features, ) - self.seq_tower_input = self._model_config.hstu_tower.input + self._temperature = self._model_config.temperature - def predict(self, batch: Batch) -> Dict[str, Tensor]: + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Forward the model. Args: batch (Batch): input batch data. Return: - predictions (dict): a dict of predicted result. + predictions (dict): a dict of predicted result with 'similarity' key. """ - batch_sparse_features = batch.sparse_features[NEG_DATA_GROUP] - # Get batch_size and neg_sample_size from batch_sparse_features - batch_size = batch.labels[self._label_name].shape[0] - neg_sample_size = batch_sparse_features.lengths()[batch_size] - 1 grouped_features = self.embedding_group(batch) - item_group_features = { - self.seq_tower_input + ".sequence": grouped_features[ - self.seq_tower_input + ".sequence" - ][batch_size:, : neg_sample_size + 1], - } - item_tower_emb = self.item_tower(item_group_features) - user_group_features = { - self.seq_tower_input + ".sequence": grouped_features[ - self.seq_tower_input + ".sequence" - ][:batch_size], - self.seq_tower_input + ".sequence_length": grouped_features[ - self.seq_tower_input + ".sequence_length" - ][:batch_size], - } - user_tower_emb = self.user_tower(user_group_features) - ui_sim = ( - self.simi(user_tower_emb, item_tower_emb, neg_for_each_sample=True) - / self._model_config.temperature - ) - return {"similarity": ui_sim} + user_emb = self.user_tower(grouped_features) + item_emb = self.item_tower(grouped_features) - def simi( - self, - user_emb: torch.Tensor, - item_emb: torch.Tensor, - neg_for_each_sample: bool = False, - ) -> torch.Tensor: - """Override the sim method in MatchModel to calculate similarity.""" - if self._in_batch_negative: - return torch.mm(user_emb, item_emb.T) - else: - batch_size = user_emb.size(0) - pos_item_emb = item_emb[:, 0] - neg_item_emb = item_emb[:, 1:].reshape(-1, item_emb.shape[-1]) - pos_ui_sim = torch.sum( - torch.multiply(user_emb, pos_item_emb), dim=-1, keepdim=True - ) - neg_ui_sim = None - if not neg_for_each_sample: - neg_ui_sim = torch.matmul(user_emb, neg_item_emb.transpose(0, 1)) - else: - num_neg_per_user = neg_item_emb.size(0) // batch_size - neg_size = batch_size * num_neg_per_user - neg_item_emb = neg_item_emb[:neg_size] - neg_item_emb = neg_item_emb.view(batch_size, num_neg_per_user, -1) - neg_ui_sim = torch.sum(user_emb.unsqueeze(1) * neg_item_emb, dim=-1) - return torch.cat([pos_ui_sim, neg_ui_sim], dim=-1) + ui_sim = _jagged_candidate_sim(user_emb, item_emb) / self._temperature + return {"similarity": ui_sim} diff --git a/tzrec/models/hstu_test.py b/tzrec/models/hstu_test.py index d74a1865..ef6de5a5 100644 --- a/tzrec/models/hstu_test.py +++ b/tzrec/models/hstu_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2024-2025, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,107 +12,197 @@ import unittest import torch -from parameterized import parameterized from torchrec import KeyedJaggedTensor -from tzrec.datasets.utils import BASE_DATA_GROUP, NEG_DATA_GROUP, Batch +from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.features.feature import create_features from tzrec.models.hstu import HSTUMatch +from tzrec.models.model import TrainWrapper +from tzrec.modules.utils import Kernel from tzrec.protos import ( feature_pb2, loss_pb2, + metric_pb2, model_pb2, - seq_encoder_pb2, + module_pb2, tower_pb2, ) from tzrec.protos.models import match_model_pb2 from tzrec.utils.state_dict_util import init_parameters -from tzrec.utils.test_util import TestGraphType, create_test_model - - -class HSTUTest(unittest.TestCase): - @parameterized.expand([[TestGraphType.NORMAL], [TestGraphType.FX_TRACE]]) - def test_hstu(self, graph_type) -> None: - feature_cfgs = [ - feature_pb2.FeatureConfig( - sequence_id_feature=feature_pb2.IdFeature( - feature_name="historical_ids", - sequence_length=210, - embedding_dim=48, - num_buckets=3953, - ) - ), - feature_pb2.FeatureConfig( - id_feature=feature_pb2.IdFeature( - feature_name="item_id", - embedding_dim=48, - num_buckets=1000, - embedding_name="item_id", - ) - ), - ] - features = create_features(feature_cfgs) - feature_groups = [ - model_pb2.FeatureGroupConfig( - group_name="sequence", - feature_names=["historical_ids"], - group_type=model_pb2.FeatureGroupType.SEQUENCE, - ), - ] - model_config = model_pb2.ModelConfig( - feature_groups=feature_groups, - hstu_match=match_model_pb2.HSTUMatch( - hstu_tower=tower_pb2.HSTUMatchTower( - input="sequence", - hstu_encoder=seq_encoder_pb2.HSTUEncoder( - sequence_dim=48, - attn_dim=48, - linear_dim=48, - input="sequence", - max_seq_length=210, - num_blocks=2, +from tzrec.utils.test_util import TestGraphType, create_test_model, gpu_unavailable + + +def _build_model_config(): + """Build HSTUMatch model config for tests.""" + feature_groups = [ + model_pb2.FeatureGroupConfig( + group_name="uih", + feature_names=["historical_ids"], + group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE, + ), + model_pb2.FeatureGroupConfig( + group_name="candidate", + feature_names=["item_id"], + group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE, + ), + ] + return model_pb2.ModelConfig( + feature_groups=feature_groups, + hstu_match=match_model_pb2.HSTUMatch( + hstu_tower=tower_pb2.HSTUMatchTower( + input="uih", + hstu=module_pb2.HSTU( + stu=module_pb2.STU( + embedding_dim=48, num_heads=1, - linear_activation="silu", - linear_config="uvqk", - max_output_len=0, + hidden_dim=48, + attention_dim=48, + output_dropout_ratio=0.2, + ), + attn_num_layers=2, + positional_encoder=module_pb2.GRPositionalEncoder( + num_position_buckets=512, + ), + input_preprocessor=module_pb2.GRInputPreprocessor( + uih_preprocessor=module_pb2.GRUIHPreprocessor(), + ), + output_postprocessor=module_pb2.GROutputPostprocessor( + l2norm_postprocessor=module_pb2.GRL2NormPostprocessor(), ), ), - temperature=0.05, + max_seq_len=210, ), - losses=[ - loss_pb2.LossConfig( - softmax_cross_entropy=loss_pb2.SoftmaxCrossEntropy() - ) - ], - ) - hstu = HSTUMatch( - model_config=model_config, - features=features, - labels=["label"], - sampler_type="negative_sampler", - ) - init_parameters(hstu, device=torch.device("cpu")) - hstu = create_test_model(hstu, graph_type) - - # Create test sequence data - sparse_feature = KeyedJaggedTensor.from_lengths_sync( - keys=["historical_ids"], - values=torch.tensor([1, 2, 3, 4, 5, 2, 7, 8, 9, 4, 11, 5, 13, 14, 15]), - # sequence length is - # 2, 3, 2 (neg_seq, first is pos), 2 (neg_seq, first is pos)... - lengths=torch.tensor([2, 3, 2, 2, 2, 2, 2]), - ) - - batch = Batch( - sparse_features={ - NEG_DATA_GROUP: sparse_feature, - BASE_DATA_GROUP: sparse_feature, - }, - labels={"label": torch.tensor([1, 1])}, - ) + temperature=0.05, + ), + losses=[ + loss_pb2.LossConfig(softmax_cross_entropy=loss_pb2.SoftmaxCrossEntropy()) + ], + metrics=[metric_pb2.MetricConfig(recall_at_k=metric_pb2.RecallAtK(top_k=1))], + ) + + +def _build_features(): + """Build features for HSTUMatch tests.""" + feature_cfgs = [ + feature_pb2.FeatureConfig( + sequence_id_feature=feature_pb2.IdFeature( + feature_name="historical_ids", + sequence_length=210, + embedding_dim=48, + num_buckets=3953, + ) + ), + feature_pb2.FeatureConfig( + sequence_id_feature=feature_pb2.IdFeature( + feature_name="item_id", + sequence_length=10, + sequence_delim=";", + embedding_dim=48, + num_buckets=1000, + ) + ), + ] + return create_features(feature_cfgs) + + +def _build_model(device): + """Build HSTUMatch model on device.""" + model_config = _build_model_config() + features = _build_features() + hstu = HSTUMatch( + model_config=model_config, + features=features, + labels=["label"], + sampler_type="negative_sampler", + ) + init_parameters(hstu, device=device) + hstu.to(device) + hstu.set_kernel(Kernel.PYTORCH) + return hstu + + +def _build_batch(device): + """Build test batch with 2 users. + UIH: user1 has 3 items, user2 has 4 items. + Candidates: each user has a sequence of [pos, neg] (after combine_neg). + """ + # BASE: UIH sequences + per-user candidate sequences (pos+neg) + sparse_feature = KeyedJaggedTensor.from_lengths_sync( + keys=["historical_ids", "item_id"], + values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 100, 200, 101, 201]), + lengths=torch.tensor([3, 4, 2, 2]), + ) + return Batch( + sparse_features={BASE_DATA_GROUP: sparse_feature}, + labels={"label": torch.tensor([1, 1])}, + ).to(device) + + +class HSTUMatchTest(unittest.TestCase): + """Tests for HSTUMatch model with STU and jagged sequences.""" + + @unittest.skipIf(*gpu_unavailable) + def test_hstu_match_train(self) -> None: + """Test HSTUMatch training: forward + loss + backward.""" + device = torch.device("cuda") + hstu = _build_model(device) + batch = _build_batch(device) + + train_model = TrainWrapper(hstu, device=device).to(device) + total_loss, (losses, predictions, batch) = train_model(batch) + + self.assertIn("similarity", predictions) + self.assertIn("softmax_cross_entropy", losses) + self.assertTrue(total_loss.requires_grad) + self.assertFalse(torch.isnan(total_loss)) + + @unittest.skipIf(*gpu_unavailable) + def test_hstu_match_eval(self) -> None: + """Test HSTUMatch evaluation: forward + metrics.""" + device = torch.device("cuda") + hstu = _build_model(device) + batch = _build_batch(device) + + train_model = TrainWrapper(hstu, device=device).to(device) + _, (_, predictions, batch) = train_model(batch) + + hstu.update_metric(predictions, batch) + metric_result = hstu.compute_metric() + self.assertIn("recall@1", metric_result) + + @unittest.skipIf(*gpu_unavailable) + def test_hstu_match_export(self) -> None: + """Test HSTUMatch export: FX trace for serving.""" + device = torch.device("cuda") + hstu = _build_model(device) + batch = _build_batch(device) + + hstu.eval() + hstu = create_test_model(hstu, TestGraphType.FX_TRACE) predictions = hstu(batch) - self.assertEqual(predictions["similarity"].size(), (5, 2)) + + self.assertIn("similarity", predictions) + sim = predictions["similarity"] + self.assertEqual(sim.dim(), 2) + self.assertEqual(sim.size(0), 2) + + @unittest.skipIf(*gpu_unavailable) + def test_hstu_match_predict(self) -> None: + """Test HSTUMatch predict: inference mode forward pass.""" + device = torch.device("cuda") + hstu = _build_model(device) + batch = _build_batch(device) + + hstu.eval() + with torch.no_grad(): + predictions = hstu.predict(batch) + + self.assertIn("similarity", predictions) + sim = predictions["similarity"] + self.assertEqual(sim.dim(), 2) + self.assertEqual(sim.size(0), 2) + self.assertFalse(torch.isnan(sim).any()) if __name__ == "__main__": diff --git a/tzrec/models/match_model.py b/tzrec/models/match_model.py index 22b86dc9..73130fef 100644 --- a/tzrec/models/match_model.py +++ b/tzrec/models/match_model.py @@ -220,6 +220,7 @@ def __init__( self._similarity = similarity self._feature_group = feature_group self._features = features + self._pass_grouped_features = False class MatchModel(BaseModel): @@ -492,8 +493,9 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: embedding (dict): tower output embedding. """ grouped_features = self.embedding_group(batch) - return { - f"{self._tower_name}_emb": getattr(self, self._tower_name)( - grouped_features[self._group_name] - ) - } + tower = getattr(self, self._tower_name) + if tower._pass_grouped_features: + tower_input = grouped_features + else: + tower_input = grouped_features[self._group_name] + return {f"{self._tower_name}_emb": tower(tower_input)} diff --git a/tzrec/modules/gr/postprocessors.py b/tzrec/modules/gr/postprocessors.py index 9c7891bf..5d6c6bca 100644 --- a/tzrec/modules/gr/postprocessors.py +++ b/tzrec/modules/gr/postprocessors.py @@ -229,7 +229,7 @@ def create_output_postprocessor( ) -> OutputPostprocessor: """Create OutputPostprocessor.""" if isinstance(postprocessor_cfg, module_pb2.GROutputPostprocessor): - postprocessor_type = postprocessor_cfg.WhichOneof("input_preprocessor") + postprocessor_type = postprocessor_cfg.WhichOneof("output_postprocessor") config_dict = config_to_kwargs(getattr(postprocessor_cfg, postprocessor_type)) else: assert len(postprocessor_cfg) == 1, ( diff --git a/tzrec/modules/gr/preprocessors.py b/tzrec/modules/gr/preprocessors.py index d00af706..caffdd58 100644 --- a/tzrec/modules/gr/preprocessors.py +++ b/tzrec/modules/gr/preprocessors.py @@ -53,15 +53,7 @@ class InputPreprocessor(BaseModule): @abc.abstractmethod def forward( self, - max_uih_len: int, - max_targets: int, - total_uih_len: int, - total_targets: int, - seq_lengths: torch.Tensor, - seq_timestamps: torch.Tensor, - seq_embeddings: torch.Tensor, - num_targets: torch.Tensor, - seq_payloads: Dict[str, torch.Tensor], + grouped_features: Dict[str, torch.Tensor], ) -> Tuple[ int, int, @@ -75,22 +67,14 @@ def forward( """Forward the module. Args: - max_uih_len (int): maximum user history sequence length. - max_targets (int): maximum candidates length. - total_uih_len (int): total user history sequence length. - total_targets (int): total candidates length. - seq_lengths (torch.Tensor): input sequence lengths. - seq_timestamps (torch.Tensor): input sequence timestamp tensor. - seq_embeddings (torch.Tensor): input sequence embedding tensor. - num_targets (torch.Tensor): number of targets. - seq_payloads (Dict[str, torch.Tensor]): sequence payload features. + grouped_features (Dict[str, torch.Tensor]): embedding group features. Returns: output_max_seq_len (int): output maximum sequence length. output_total_uih_len (int): output total user history sequence length. output_total_targets (int): output total candidates length. output_seq_lengths (torch.Tensor): output sequence lengths. - output_seq_offsets (torch.Tensor): output sequence lengths. + output_seq_offsets (torch.Tensor): output sequence offsets. output_seq_timestamps (torch.Tensor): output sequence timestamp tensor. output_seq_embeddings (torch.Tensor): output sequence embedding tensor. output_num_targets (torch.Tensor): output number of targets. @@ -486,6 +470,229 @@ def contextual_seq_len(self) -> int: return self._max_contextual_seq_len +class UIHPreprocessor(InputPreprocessor): + """Preprocessor for sequence-only models without candidate concatenation. + + Processes UIH (User Interaction History) sequences with optional contextual + features, timestamps, and actions. Projects UIH embeddings to the STU + embedding dimension. Suitable for two-tower match/retrieval models where + user and item towers are independent. + + Args: + uih_embedding_dim (int): dimension of UIH input embeddings. + output_embedding_dim (int): dimension of output embeddings (STU dim). + contextual_feature_dim (int): dimension of each contextual feature. + Inferred from embedding_group at model init time. + max_contextual_seq_len (int): number of contextual features. + Inferred from embedding_group at model init time. + contextual_group_name (str): name of contextual group in grouped features. + action_encoder (Dict[str, Any]): optional ActionEncoder config. + action_mlp (Dict[str, Any]): optional action MLP config. + is_inference (bool): whether to run in inference mode. + """ + + def __init__( + self, + uih_embedding_dim: int, + output_embedding_dim: int, + contextual_feature_dim: int = 0, + max_contextual_seq_len: int = 0, + contextual_group_name: str = "contextual", + action_encoder: Optional[Dict[str, Any]] = None, + action_mlp: Optional[Dict[str, Any]] = None, + is_inference: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(is_inference=is_inference) + self._uih_embedding_dim: int = uih_embedding_dim + self._output_embedding_dim: int = output_embedding_dim + self._contextual_feature_dim: int = contextual_feature_dim + self._max_contextual_seq_len: int = max_contextual_seq_len + self._contextual_group_name: str = contextual_group_name + + # Project UIH to STU embedding dim + self._input_projection: torch.nn.Module = torch.nn.Linear( + uih_embedding_dim, output_embedding_dim + ) + + # Optional contextual feature projection + if max_contextual_seq_len > 0: + std = 1.0 * sqrt(2.0 / float(contextual_feature_dim + output_embedding_dim)) + self._batched_contextual_linear_weights = torch.nn.Parameter( + torch.empty( + ( + max_contextual_seq_len, + contextual_feature_dim, + output_embedding_dim, + ) + ).normal_(0.0, std) + ) + self._batched_contextual_linear_bias = torch.nn.Parameter( + torch.empty((max_contextual_seq_len, 1, output_embedding_dim)).fill_( + 0.0 + ) + ) + + # Optional action encoder + self._action_encoder_cfg = action_encoder + if action_encoder is not None: + contextual_embedding_dim: int = ( + max_contextual_seq_len * contextual_feature_dim + ) + self._action_encoder: ActionEncoder = create_action_encoder( + action_encoder, + is_inference=is_inference, + ) + assert action_mlp is not None, ( + "action_mlp must be set when action_encoder is set." + ) + self._action_embedding_mlp: ContextualizedMLP = create_contextualized_mlp( + action_mlp, + contextual_embedding_dim=contextual_embedding_dim, + sequential_input_dim=self._action_encoder.output_dim, + sequential_output_dim=output_embedding_dim, + is_inference=is_inference, + ) + + def forward( + self, grouped_features: Dict[str, torch.Tensor] + ) -> Tuple[ + int, + int, + int, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + """Forward the module. + + Args: + grouped_features (Dict[str, torch.Tensor]): embedding group features. + + Returns: + output_max_seq_len (int): output maximum sequence length. + output_total_uih_len (int): output total user history sequence length. + output_total_targets (int): always 0 (no candidates). + output_seq_lengths (torch.Tensor): output sequence lengths. + output_seq_offsets (torch.Tensor): output sequence offsets. + output_seq_timestamps (torch.Tensor): output sequence timestamps. + output_seq_embeddings (torch.Tensor): output sequence embeddings. + output_num_targets (torch.Tensor): always zeros (no candidates). + """ + uih_embeddings = grouped_features["uih.sequence"] + uih_seq_lengths = grouped_features["uih.sequence_length"] + max_uih_len = fx_int_item(uih_seq_lengths.max()) + total_uih_len = fx_int_item(uih_seq_lengths.sum()) + uih_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(uih_seq_lengths) + + # Project UIH to output dim + seq_embeddings = self._input_projection(uih_embeddings) + + # Optional: contextual features + contextual_input_embeddings: Optional[torch.Tensor] = None + contextual_embeddings: Optional[torch.Tensor] = None + if self._max_contextual_seq_len > 0: + contextual_input_embeddings = grouped_features[self._contextual_group_name] + contextual_embeddings = torch.baddbmm( + self._batched_contextual_linear_bias.to( + contextual_input_embeddings.dtype + ), + contextual_input_embeddings.view( + -1, + self._max_contextual_seq_len, + self._contextual_feature_dim, + ).transpose(0, 1), + self._batched_contextual_linear_weights.to( + contextual_input_embeddings.dtype + ), + ).transpose(0, 1) + + # Optional: action embeddings + if self._action_encoder_cfg is not None: + # target_offsets is unused when total_targets=0 (no candidates in + # UIH-only mode), so we pass uih_offsets as a placeholder. + action_embeddings = self._action_encoder( + seq_actions=grouped_features["uih_action.sequence"].to(torch.int64), + max_uih_len=max_uih_len, + max_targets=0, + uih_offsets=uih_offsets, + target_offsets=uih_offsets, + total_uih_len=total_uih_len, + total_targets=0, + ).to(uih_embeddings.dtype) + action_embeddings = self._action_embedding_mlp( + seq_embeddings=action_embeddings, + seq_offsets=uih_offsets, + max_seq_len=max_uih_len, + contextual_embeddings=contextual_input_embeddings, + ) + seq_embeddings = seq_embeddings + action_embeddings + + # Timestamps: always use zeros (timestamps are optional for match models) + seq_timestamps = torch.zeros( + total_uih_len, dtype=torch.int64, device=uih_embeddings.device + ) + + output_max_seq_len = max_uih_len + output_seq_lengths = uih_seq_lengths + output_seq_offsets = uih_offsets + output_total_uih_len = total_uih_len + + # Concat contextual embeddings if present + if self._max_contextual_seq_len > 0: + seq_embeddings = concat_2D_jagged( + values_left=fx_unwrap_optional_tensor(contextual_embeddings).reshape( + -1, self._output_embedding_dim + ), + values_right=seq_embeddings, + max_len_left=self._max_contextual_seq_len, + max_len_right=max_uih_len, + offsets_left=None, + offsets_right=uih_offsets, + kernel=self.kernel(), + ) + seq_timestamps = concat_2D_jagged( + values_left=_fx_timestamp_contextual_zeros( + seq_timestamps, + uih_seq_lengths, + self._max_contextual_seq_len, + ), + values_right=seq_timestamps.unsqueeze(-1), + max_len_left=self._max_contextual_seq_len, + max_len_right=max_uih_len, + offsets_left=None, + offsets_right=uih_offsets, + kernel=self.kernel(), + ).squeeze(-1) + output_max_seq_len = max_uih_len + self._max_contextual_seq_len + output_total_uih_len = ( + total_uih_len + self._max_contextual_seq_len * uih_seq_lengths.size(0) + ) + output_seq_lengths = uih_seq_lengths + self._max_contextual_seq_len + output_seq_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + output_seq_lengths + ) + + num_targets = torch.zeros_like(output_seq_lengths) + + return ( + output_max_seq_len, + output_total_uih_len, + 0, + output_seq_lengths, + output_seq_offsets, + seq_timestamps, + seq_embeddings, + num_targets, + ) + + def contextual_seq_len(self) -> int: + """Contextual feature sequence length.""" + return self._max_contextual_seq_len + + def create_input_preprocessor( preprocessor_cfg: Union[module_pb2.GRInputPreprocessor, Dict[str, Any]], **kwargs: Any, @@ -508,5 +715,7 @@ def create_input_preprocessor( ) elif preprocessor_type == "contextual_interleave_preprocessor": return ContextualInterleavePreprocessor(**config_dict) + elif preprocessor_type == "uih_preprocessor": + return UIHPreprocessor(**config_dict) else: raise RuntimeError(f"Unknown preprocessor type: {preprocessor_type}") diff --git a/tzrec/modules/sequence.py b/tzrec/modules/sequence.py index 880402fb..fc316521 100644 --- a/tzrec/modules/sequence.py +++ b/tzrec/modules/sequence.py @@ -10,18 +10,13 @@ # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import numpy as np import torch from torch import nn from torch.nn import functional as F -from tzrec.modules.hstu import ( - HSTUCacheState, - RelativeBucketedTimeAndPositionBasedBias, - SequentialTransductionUnitJagged, -) from tzrec.modules.mlp import MLP from tzrec.protos.seq_encoder_pb2 import SeqEncoderConfig from tzrec.utils import config_util @@ -372,211 +367,6 @@ def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor: ) # [B, (L+1)*C] -class HSTUEncoder(SequenceEncoder): - """HSTU sequence encoder. - - Args: - sequence_dim (int): sequence tensor channel dimension. - query_dim (int): query tensor channel dimension. - input(str): input feature group name. - attn_mlp (dict): target attention MLP module parameters. - """ - - def __init__( - self, - sequence_dim: int, - attn_dim: int, - linear_dim: int, - input: str, - max_seq_length: int, - pos_dropout_rate: float = 0.2, - linear_dropout_rate: float = 0.2, - attn_dropout_rate: float = 0.0, - normalization: str = "rel_bias", - linear_activation: str = "silu", - linear_config: str = "uvqk", - num_heads: int = 1, - num_blocks: int = 2, - max_output_len: int = 10, - time_bucket_size: int = 128, - **kwargs: Optional[Dict[str, Any]], - ) -> None: - super().__init__(input) - self._sequence_dim = sequence_dim - self._attn_dim = attn_dim - self._linear_dim = linear_dim - self._max_seq_length = max_seq_length - self._query_name = f"{input}.query" - self._sequence_name = f"{input}.sequence" - self._sequence_length_name = f"{input}.sequence_length" - max_output_len = max_output_len + 1 # for target - self.position_embed = nn.Embedding( - self._max_seq_length + max_output_len, self._sequence_dim - ) - self.dropout_rate = pos_dropout_rate - self.enable_relative_attention_bias = True - self.autocast_dtype = None - self._attention_layers: nn.ModuleList = nn.ModuleList( - modules=[ - SequentialTransductionUnitJagged( - embedding_dim=self._sequence_dim, - linear_hidden_dim=self._linear_dim, - attention_dim=self._attn_dim, - normalization=normalization, - linear_config=linear_config, - linear_activation=linear_activation, - num_heads=num_heads, - relative_attention_bias_module=( - RelativeBucketedTimeAndPositionBasedBias( - max_seq_len=max_seq_length + max_output_len, - num_buckets=time_bucket_size, - bucketization_fn=lambda x: ( - torch.log(torch.abs(x).clamp(min=1)) / 0.301 - ).long(), - ) - if self.enable_relative_attention_bias - else None - ), - dropout_ratio=linear_dropout_rate, - attn_dropout_ratio=attn_dropout_rate, - concat_ua=False, - ) - for _ in range(num_blocks) - ] - ) - self.register_buffer( - "_attn_mask", - torch.triu( - torch.ones( - ( - self._max_seq_length + max_output_len, - self._max_seq_length + max_output_len, - ), - dtype=torch.bool, - ), - diagonal=1, - ), - ) - self._autocast_dtype = None - - def output_dim(self) -> int: - """Output dimension of the module.""" - return self._sequence_dim - - def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor: - """Forward the module.""" - sequence = sequence_embedded[self._sequence_name] # B, N, E - sequence_length = sequence_embedded[self._sequence_length_name] # N - # max_seq_length = sequence.size(1) - float_dtype = sequence.dtype - - # Add positional embeddings and apply dropout - positions = ( - fx_arange(sequence.size(1), device=sequence.device) - .unsqueeze(0) - .expand(sequence.size(0), -1) - ) - sequence = sequence * (self._sequence_dim**0.5) + self.position_embed(positions) - sequence = F.dropout(sequence, p=self.dropout_rate, training=self.training) - sequence_mask = fx_arange( - sequence.size(1), device=sequence_length.device - ).unsqueeze(0) < sequence_length.unsqueeze(1) - sequence = sequence * sequence_mask.unsqueeze(-1).to(float_dtype) - - invalid_attn_mask = 1.0 - self._attn_mask.to(float_dtype) - sequence_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( - sequence_length - ) - sequence = torch.ops.fbgemm.dense_to_jagged(sequence, [sequence_offsets])[0] - - all_timestamps = None - jagged_x, cache_states = self.jagged_forward( - x=sequence, - x_offsets=sequence_offsets, - all_timestamps=all_timestamps, - invalid_attn_mask=invalid_attn_mask, - delta_x_offsets=None, - cache=None, - return_cache_states=False, - ) - # post processing: L2 Normalization - output_embeddings = jagged_x - output_embeddings = output_embeddings[..., : self._sequence_dim] - output_embeddings = output_embeddings / torch.clamp( - torch.linalg.norm(output_embeddings, ord=None, dim=-1, keepdim=True), - min=1e-6, - ) - if not self.training: - output_embeddings = self.get_current_embeddings( - sequence_length, output_embeddings - ) - return output_embeddings - - def jagged_forward( - self, - x: torch.Tensor, - x_offsets: torch.Tensor, - all_timestamps: Optional[torch.Tensor], - invalid_attn_mask: torch.Tensor, - delta_x_offsets: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - cache: Optional[List[HSTUCacheState]] = None, - return_cache_states: bool = False, - ) -> Tuple[torch.Tensor, List[HSTUCacheState]]: - r"""Jagged forward. - - Args: - x: (\sum_i N_i, D) x float - x_offsets: (B + 1) x int32 - all_timestamps: (B, 1 + N) x int64 - invalid_attn_mask: (B, N, N) x float, each element in {0, 1} - delta_x_offsets: offsets for x - cache: cache contents - return_cache_states: bool. True if we should return cache states. - - Returns: - x' = f(x), (\sum_i N_i, D) x float - """ - cache_states: List[HSTUCacheState] = [] - - with torch.autocast( - "cuda", - enabled=self._autocast_dtype is not None, - dtype=self._autocast_dtype or torch.float16, - ): - for i, layer in enumerate(self._attention_layers): - x, cache_states_i = layer( - x=x, - x_offsets=x_offsets, - all_timestamps=all_timestamps, - invalid_attn_mask=invalid_attn_mask, - delta_x_offsets=delta_x_offsets, - cache=cache[i] if cache is not None else None, - return_cache_states=return_cache_states, - ) - if return_cache_states: - cache_states.append(cache_states_i) - - return x, cache_states - - def get_current_embeddings( - self, - lengths: torch.Tensor, - encoded_embeddings: torch.Tensor, - ) -> torch.Tensor: - """Get the embeddings of the last past_id as the current embeds. - - Args: - lengths: (B,) x int - encoded_embeddings: (B, N, D,) x float - - Returns: - (B, D,) x float, where [i, :] == encoded_embeddings[i, lengths[i] - 1, :] - """ - offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) - indices = offsets[1:] - 1 - return encoded_embeddings[indices] - - def create_seq_encoder( seq_encoder_config: SeqEncoderConfig, group_total_dim: Dict[str, int] ) -> SequenceEncoder: diff --git a/tzrec/modules/sequence_test.py b/tzrec/modules/sequence_test.py index 96559bc1..767e145e 100644 --- a/tzrec/modules/sequence_test.py +++ b/tzrec/modules/sequence_test.py @@ -16,7 +16,6 @@ from tzrec.modules.sequence import ( DINEncoder, - HSTUEncoder, MultiWindowDINEncoder, PoolingEncoder, SelfAttentionEncoder, @@ -102,54 +101,6 @@ def test_din_encoder_padding(self, graph_type) -> None: self.assertEqual(result.size(), (4, 16)) -class HSTUEncoderTest(unittest.TestCase): - @parameterized.expand( - [[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]] - ) - def test_hstu_encoder(self, graph_type) -> None: - hstu = HSTUEncoder( - sequence_dim=16, - input="click_seq", - max_seq_length=10, - attn_dim=16, - linear_dim=16, - ) - self.assertEqual(hstu.output_dim(), 16) - hstu = create_test_module(hstu, graph_type) - embedded = { - "click_seq.sequence": torch.randn(4, 10, 16), - "click_seq.sequence_length": torch.tensor([2, 3, 4, 5]), - } - result = hstu(embedded) - if hstu.training: - self.assertEqual(result.size(), (14, 16)) - else: - self.assertEqual(result.size(), (4, 16)) - - @parameterized.expand( - [[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]] - ) - def test_hstu_encoder_padding(self, graph_type) -> None: - hstu = HSTUEncoder( - sequence_dim=16, - input="click_seq", - max_seq_length=10, - attn_dim=16, - linear_dim=16, - ) - self.assertEqual(hstu.output_dim(), 16) - hstu = create_test_module(hstu, graph_type) - embedded = { - "click_seq.sequence": torch.randn(4, 10, 16), - "click_seq.sequence_length": torch.tensor([2, 3, 4, 5]), - } - result = hstu(embedded) - if hstu.training: - self.assertEqual(result.size(), (14, 16)) - else: - self.assertEqual(result.size(), (4, 16)) - - class SimpleAttentionTest(unittest.TestCase): @parameterized.expand( [ diff --git a/tzrec/protos/data.proto b/tzrec/protos/data.proto index 6012e5c2..1db03d62 100644 --- a/tzrec/protos/data.proto +++ b/tzrec/protos/data.proto @@ -132,9 +132,6 @@ message DataConfig { // fg run mode. optional FgMode fg_mode = 20 [default = FG_NONE]; - // hstu enable - optional bool enable_hstu = 21 [default = false]; - // whether to shuffle data optional bool shuffle = 22 [default = false]; diff --git a/tzrec/protos/models/match_model.proto b/tzrec/protos/models/match_model.proto index 11024318..95846c1f 100644 --- a/tzrec/protos/models/match_model.proto +++ b/tzrec/protos/models/match_model.proto @@ -21,8 +21,6 @@ message DSSM { message HSTUMatch { required HSTUMatchTower hstu_tower = 1; - // user and item tower output dimension - required int32 output_dim = 2; // similarity method optional Similarity similarity = 3 [default=INNER_PRODUCT]; // similarity scaling factor diff --git a/tzrec/protos/module.proto b/tzrec/protos/module.proto index 7a1a5982..6a3ca89e 100644 --- a/tzrec/protos/module.proto +++ b/tzrec/protos/module.proto @@ -195,12 +195,21 @@ message GRContextualInterleavePreprocessor { required GRContextualizedMLP content_mlp = 7; } +message GRUIHPreprocessor { + // action encoder config (optional - for models with action info) + optional GRActionEncoder action_encoder = 1; + // action embedding mlp config (required if action_encoder is set) + optional GRContextualizedMLP action_mlp = 2; +} + message GRInputPreprocessor { oneof input_preprocessor { // input preprocessor with contextual features GRContextualPreprocessor contextual_preprocessor = 20; // input preprocessor with interleave targets GRContextualInterleavePreprocessor contextual_interleave_preprocessor = 21; + // input preprocessor for sequence-only models (no candidate concat) + GRUIHPreprocessor uih_preprocessor = 22; } } diff --git a/tzrec/protos/tower.proto b/tzrec/protos/tower.proto index 024cbb71..07a668a0 100644 --- a/tzrec/protos/tower.proto +++ b/tzrec/protos/tower.proto @@ -4,7 +4,6 @@ package tzrec.protos; import "tzrec/protos/module.proto"; import "tzrec/protos/loss.proto"; import "tzrec/protos/metric.proto"; -import "tzrec/protos/seq_encoder.proto"; message Tower { // input feature group name required string input = 1; @@ -13,10 +12,12 @@ message Tower { }; message HSTUMatchTower { - // input feature group name + // input feature group name (uih group) required string input = 1; - // hstu_encoder config - required HSTUEncoder hstu_encoder = 2; + // HSTU config (STU, positional_encoder, input_preprocessor, output_postprocessor) + required HSTU hstu = 2; + // max sequence length + optional uint32 max_seq_len = 3 [default = 100]; } message DINTower { diff --git a/tzrec/tests/configs/hstu_fg_mock.config b/tzrec/tests/configs/hstu_fg_mock.config index 3eb55d18..41e52a30 100644 --- a/tzrec/tests/configs/hstu_fg_mock.config +++ b/tzrec/tests/configs/hstu_fg_mock.config @@ -31,15 +31,13 @@ data_config { dataset_type: ParquetDataset fg_mode: FG_DAG label_fields: "clk" - enable_hstu: true num_workers: 8 negative_sampler { input_path: "odps://{PROJECT}/tables/taobao_ad_feature_gl_bucketized_v1" num_sample: 128 - attr_fields: "historical_ids" - item_id_field: "historical_ids" + attr_fields: "item_id" + item_id_field: "item_id" attr_delimiter: "\t" - item_id_delim: ';' } } feature_configs { @@ -61,36 +59,52 @@ feature_configs { } } feature_configs { - id_feature { + sequence_id_feature { feature_name: "item_id" expression: "item:item_id" + sequence_length: 10 + sequence_delim: ";" num_buckets: 1000 embedding_dim: 48 - embedding_name: "item_id" } } model_config { feature_groups { - group_name: "sequence" + group_name: "uih" feature_names: "historical_ids" - group_type: SEQUENCE + group_type: JAGGED_SEQUENCE + } + feature_groups { + group_name: "candidate" + feature_names: "item_id" + group_type: JAGGED_SEQUENCE } hstu_match { hstu_tower { - input: 'sequence' - hstu_encoder { - sequence_dim: 48 - attn_dim: 48 - linear_dim: 48 - input: "sequence" - max_seq_length: 210 - num_blocks: 2 - num_heads: 1 - linear_activation: "silu" - linear_config: "uvqk" - max_output_len: 0 + input: "uih" + hstu { + stu { + embedding_dim: 48 + num_heads: 1 + hidden_dim: 48 + attention_dim: 48 + output_dropout_ratio: 0.2 + } + attn_num_layers: 2 + positional_encoder { + num_position_buckets: 512 + } + input_preprocessor { + uih_preprocessor { + } + } + output_postprocessor { + l2norm_postprocessor { + } + } } + max_seq_len: 210 } temperature: 0.05 } diff --git a/tzrec/tests/match_integration_test.py b/tzrec/tests/match_integration_test.py index 9333ee41..31138e2c 100644 --- a/tzrec/tests/match_integration_test.py +++ b/tzrec/tests/match_integration_test.py @@ -15,6 +15,7 @@ import unittest from tzrec.tests import utils +from tzrec.utils.test_util import gpu_unavailable class MatchIntegrationTest(unittest.TestCase): @@ -431,14 +432,13 @@ def test_mind_train_eval_export(self): os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) ) - @unittest.skip("skip hstu match test") + @unittest.skipIf(*gpu_unavailable) def test_hstu_with_fg_train_eval_export(self): self.success = utils.test_train_eval( "tzrec/tests/configs/hstu_fg_mock.config", self.test_dir, user_id="user_id", item_id="item_id", - is_hstu=True, ) if self.success: self.success = utils.test_eval( @@ -448,6 +448,15 @@ def test_hstu_with_fg_train_eval_export(self): self.success = utils.test_export( os.path.join(self.test_dir, "pipeline.config"), self.test_dir ) + if self.success: + self.success = utils.test_predict( + scripted_model_path=os.path.join(self.test_dir, "export/item"), + predict_input_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"), + predict_output_path=os.path.join(self.test_dir, "predict_result"), + reserved_columns="item_id", + output_columns="item_tower_emb", + test_dir=self.test_dir, + ) self.assertTrue(self.success) self.assertTrue( os.path.exists(os.path.join(self.test_dir, "export/user/scripted_model.pt")) diff --git a/tzrec/tests/utils.py b/tzrec/tests/utils.py index 77022083..d31d8be8 100644 --- a/tzrec/tests/utils.py +++ b/tzrec/tests/utils.py @@ -14,7 +14,7 @@ import os import random from collections import OrderedDict, defaultdict -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Set, Tuple import numpy as np import numpy.typing as npt @@ -82,17 +82,19 @@ def __init__( num_ids: Optional[int] = None, vocab_list: Optional[List[str]] = None, multival_sep: str = chr(3), + as_string: bool = False, ) -> None: super().__init__(name) self.is_multi = is_multi self.num_ids = num_ids self.vocab_list = vocab_list self.multival_sep = multival_sep + self.as_string = as_string def create_data(self, num_rows: int, has_null: bool = True) -> pa.Array: """Create mock data.""" if not self.is_multi: - # int64 + # int64 (or string if as_string=True) num_valid_rows = ( random.randint(num_rows // 2, num_rows) if has_null else num_rows ) @@ -101,6 +103,8 @@ def create_data(self, num_rows: int, has_null: bool = True) -> pa.Array: ) data = data + [None] * (num_rows - num_valid_rows) random.shuffle(data) + if self.as_string: + data = [str(x) if x is not None else None for x in data] else: # string num_multi_rows = random.randint(num_rows // 3, 2 * num_rows // 3) @@ -128,36 +132,6 @@ def create_data(self, num_rows: int, has_null: bool = True) -> pa.Array: return pa.array(data) -class HSTUIdMockInput(MockInput): - """Mock sparse id input data class.""" - - def __init__( - self, - name: str, - is_multi: bool = False, - num_ids: Optional[int] = None, - vocab_list: Optional[List[str]] = None, - multival_sep: str = chr(3), - ) -> None: - super().__init__(name) - self.is_multi = is_multi - self.num_ids = num_ids - self.vocab_list = vocab_list - self.multival_sep = multival_sep - - def create_data(self, num_rows: int, has_null: bool = True) -> pa.Array: - """Create mock data.""" - # string - # num_multi_rows = random.randint(num_rows // 3, 2 * num_rows // 3) - num_multi_id = 3 - data_multi = _create_random_id_data( - (num_rows, num_multi_id), self.num_ids, self.vocab_list - ).astype(str) - data_multi = list(map(lambda x: self.multival_sep.join(x), data_multi)) - random.shuffle(data_multi) - return pa.array(data_multi) - - class SeqIdMockInput(MockInput): """Mock sparse id sequence input data class.""" @@ -513,7 +487,13 @@ def create_mock_data( input_data = {} for inp in inputs.values(): if inp.name == unique_id: - input_data[inp.name] = pa.array(list(range(num_rows))) + ids = pa.array(list(range(num_rows))) + # If the input is configured for string output (e.g., a + # sequence_id_feature used as sampler item_id), cast to string + # so joins on this key remain type-consistent. + if isinstance(inp, IdMockInput) and getattr(inp, "as_string", False): + ids = ids.cast(pa.string()) + input_data[inp.name] = ids elif isinstance(inp, SeqMockInput): input_data.update(inp.create_sequence_data(num_rows, join_t)) else: @@ -694,7 +674,7 @@ def build_mock_input_with_fg( features: List[BaseFeature], user_id: str = "", item_id: str = "", - is_hstu: bool = False, + neg_fields: Optional[Set[str]] = None, ) -> Dict[str, MockInput]: """Build mock input instance list with fg from features.""" inputs = defaultdict(dict) @@ -818,13 +798,21 @@ def build_mock_input_with_fg( if isinstance(inputs[side][sub_name], IdMockInput): inputs[side][sub_name].is_multi = False else: - if is_hstu: - # hstu require number of sequence item is over 2 - inputs[side][input_name] = HSTUIdMockInput( + if neg_fields and input_name in neg_fields: + # Sampler-targeted sequence features must have + # single-value mock data because the sampler casts + # the field to int64. The candidate sequence is + # created at runtime by combine_neg_as_candidate_sequence. + # Generate as string so the sequence_id_feature parser + # can read it (with delimiter; single value parses as + # one-element sequence). + inputs[side][input_name] = IdMockInput( input_name, - is_multi=True, - num_ids=feature.num_embeddings, - multival_sep=feature.sequence_delim, + is_multi=False, + num_ids=10 + if isinstance(feature, CustomFeature) + else feature.num_embeddings, + as_string=True, ) else: inputs[side][input_name] = IdMockInput( @@ -844,7 +832,6 @@ def load_config_for_test( user_id: str = "", item_id: str = "", cate_id: str = "", - is_hstu: bool = False, num_rows: Optional[int] = None, ) -> EasyRecConfig: """Modify pipeline config for integration tests.""" @@ -881,8 +868,17 @@ def load_config_for_test( num_parts=num_parts, ) else: + # Determine sampler attr_fields so mock generator can produce + # single-value data for sequence features that the sampler will + # use as item_id (sampler casts field values to int64). + sampler_neg_fields: Set[str] = set() + if data_config.HasField("sampler"): + sampler_type_name = data_config.WhichOneof("sampler") + sampler_cfg = getattr(data_config, sampler_type_name) + if hasattr(sampler_cfg, "attr_fields"): + sampler_neg_fields = set(sampler_cfg.attr_fields) user_inputs, item_inputs = build_mock_input_with_fg( - features, user_id, item_id, is_hstu + features, user_id, item_id, neg_fields=sampler_neg_fields ) _, item_t = create_mock_data( os.path.join(test_dir, "item_data"), @@ -979,7 +975,7 @@ def load_config_for_test( item_id, # hstu only uses item_id as negative sample, \ # as sampler_config.attr_fields is sequence - neg_fields=[item_id] if is_hstu else list(sampler_config.attr_fields), + neg_fields=list(sampler_config.attr_fields), attr_delimiter=sampler_config.attr_delimiter, num_rows=data_config.batch_size * num_parts * 4, ) @@ -992,7 +988,7 @@ def load_config_for_test( os.path.join(test_dir, "item_gl"), item_inputs, item_id, - neg_fields=[item_id] if is_hstu else list(sampler_config.attr_fields), + neg_fields=list(sampler_config.attr_fields), attr_delimiter=sampler_config.attr_delimiter, num_rows=data_config.batch_size * num_parts * 4, ) @@ -1032,7 +1028,6 @@ def test_train_eval( user_id: str = "", item_id: str = "", cate_id: str = "", - is_hstu: bool = False, env_str: str = "", num_rows: Optional[int] = None, ) -> bool: @@ -1043,7 +1038,6 @@ def test_train_eval( user_id, item_id, cate_id, - is_hstu, num_rows=num_rows, )