Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 70 additions & 42 deletions tzrec/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
CKPT_SOURCE_ID,
Batch,
RecordBatchTensor,
process_hstu_neg_sample,
process_hstu_seq_data,
combine_neg_as_candidate_sequence,
remove_nullable,
)
from tzrec.features.feature import BaseFeature
Expand Down Expand Up @@ -177,7 +176,6 @@ def __init__(
self._reserved_columns = reserved_columns or []
self._mode = mode
self._debug_level = debug_level
self._enable_hstu = data_config.enable_hstu
self.sampler_type = (
self._data_config.WhichOneof("sampler")
if self._data_config.HasField("sampler")
Expand Down Expand Up @@ -208,17 +206,21 @@ def __init__(
self._selected_input_names |= set(data_config.sample_weight_fields)
if data_config.HasField("sample_cost_field"):
self._selected_input_names.add(data_config.sample_cost_field)
self._sampler_item_id_field: Optional[str] = None
self._sampler_user_id_field: Optional[str] = None
if self._data_config.HasField("sampler") and self._mode != Mode.PREDICT:
sampler_type = self._data_config.WhichOneof("sampler")
sampler_config = getattr(self._data_config, sampler_type)
if hasattr(sampler_config, "item_id_field") and sampler_config.HasField(
"item_id_field"
):
self._selected_input_names.add(sampler_config.item_id_field)
self._sampler_item_id_field = sampler_config.item_id_field
if hasattr(sampler_config, "user_id_field") and sampler_config.HasField(
"user_id_field"
):
self._selected_input_names.add(sampler_config.user_id_field)
self._sampler_user_id_field = sampler_config.user_id_field
# if set selected_input_names to None,
# all columns will be reserved.
if (
Expand All @@ -238,6 +240,14 @@ def __init__(
self._sampler = None
self._sampler_inited = False

# Build mapping of field_name → sequence_delim for candidate sequence
# auto-detection during negative sampling.
self._seq_field_delims: Dict[str, str] = {}
for feature in features:
if hasattr(feature, "sequence_delim") and feature.sequence_delim:
for input_name in feature.inputs:
self._seq_field_delims[input_name] = feature.sequence_delim

self._reader = None

def launch_sampler_cluster(
Expand Down Expand Up @@ -367,49 +377,67 @@ def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch:
input_data = _expand_tdm_sample(
input_data, pos_sampled, neg_sampled, self._data_config
)
elif self._enable_hstu:
seq_attr = self._sampler._item_id_field

(
input_data_k_split,
input_data_k_split_slice,
pre_seq_filter_reshaped_joined,
) = process_hstu_seq_data(
input_data=input_data,
seq_attr=seq_attr,
seq_str_delim=self._sampler.item_id_delim,
)
if self._mode == Mode.TRAIN:
# Training using all possible target items
input_data[seq_attr] = input_data_k_split_slice
elif self._mode == Mode.EVAL:
# Evaluation using the last item for previous sequence
input_data[seq_attr] = input_data_k_split.values.take(
pa.array(input_data_k_split.offsets.to_numpy()[1:] - 1)
)
sampled = self._sampler.get(input_data)
# To keep consistent with other process, use two functions
for k, v in sampled.items():
if k in input_data:
combined = process_hstu_neg_sample(
input_data,
v,
self._sampler._num_sample,
self._sampler.item_id_delim,
seq_attr,
)
# Combine here to make embddings of both user sequence
# and target item are the same
input_data[k] = pa.concat_arrays(
[pre_seq_filter_reshaped_joined, combined]
)
else:
input_data[k] = v
else:
# If item_id_field is a sequence feature, flatten per-row
# positives into a 1D array so the sampler treats each
# positive as a separate query. Expand user_id in the same
# way for samplers that rely on it (V2, HardNeg, HardNegV2).
# The sampler computes expand_factor dynamically from the
# actual input length, so no padding or batch_size limit.
multi_pos_lists: Optional[pa.Array] = None
saved_user_ids: Optional[pa.Array] = None
if (
self._sampler_item_id_field is not None
and self._sampler_item_id_field in self._seq_field_delims
):
seq_delim = self._seq_field_delims[self._sampler_item_id_field]
raw = input_data[self._sampler_item_id_field]
if pa.types.is_string(raw.type):
multi_pos_lists = pc.split_pattern(raw, seq_delim)
flat_pos = pc.list_flatten(multi_pos_lists)
input_data[self._sampler_item_id_field] = flat_pos

# Expand user_id to match flat positives (step 4).
if self._sampler_user_id_field is not None and (
self._sampler_user_id_field in input_data
):
counts_np = pc.list_value_length(multi_pos_lists).to_numpy()
row_indices = pa.array(
np.repeat(np.arange(len(counts_np)), counts_np)
)
saved_user_ids = input_data[self._sampler_user_id_field]
input_data[self._sampler_user_id_field] = pc.take(
saved_user_ids, row_indices
)

sampled = self._sampler.get(input_data)

# Restore multi-value form so combine_neg sees original grouping.
if multi_pos_lists is not None:
input_data[self._sampler_item_id_field] = multi_pos_lists
if saved_user_ids is not None:
input_data[self._sampler_user_id_field] = saved_user_ids

for k, v in sampled.items():
if k in input_data:
input_data[k] = pa.concat_arrays([input_data[k], v])
seq_delim = self._seq_field_delims.get(k)
if seq_delim is not None:
pos_for_combine = input_data[k]
if pa.types.is_list(pos_for_combine.type):
num_pos = len(pc.list_flatten(pos_for_combine))
else:
num_pos = len(pos_for_combine)
# Sampler returns num_pos * expand_factor negs.
neg_per_pos = max(1, len(v) // max(1, num_pos))
valid_negs = v.slice(0, num_pos * neg_per_pos)
input_data[k] = combine_neg_as_candidate_sequence(
pos_for_combine,
valid_negs,
neg_per_pos,
seq_delim,
)
else:
input_data[k] = pa.concat_arrays([input_data[k], v])
else:
input_data[k] = v

Expand Down
100 changes: 64 additions & 36 deletions tzrec/datasets/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import random
import socket
import time
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import graphlearn as gl
import numpy as np
Expand Down Expand Up @@ -428,29 +428,39 @@ def __init__(
),
)
self._item_id_field = config.item_id_field
self._sampler = None
self._sampler_cache: Dict[int, Any] = {}
self.item_id_delim = config.item_id_delim

def init(self, client_id: int = -1) -> None:
    """Init sampler client; samplers are created lazily in get().

    Only the per-expand_factor sampler cache is reset here; the actual
    graphlearn samplers are built on demand by _get_sampler(), keyed by
    the expand_factor computed from each batch's real length.

    Args:
        client_id (int): sampler client id.
    """
    super().init(client_id)
    # NOTE: no eager self._g.negative_sampler(...) here — creating it with
    # a batch_size-derived expand_factor contradicts the lazy design and
    # was leftover from the pre-refactor code path.
    self._sampler_cache = {}

def _get_sampler(self, expand_factor: int) -> Any:
if expand_factor not in self._sampler_cache:
self._sampler_cache[expand_factor] = self._g.negative_sampler(
"item", expand_factor, strategy="node_weight"
)
return self._sampler_cache[expand_factor]

def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
    """Sampling method.

    Computes expand_factor dynamically from the actual input length so
    that multi-value sequence positives (flattened upstream) are
    supported without padding.

    Args:
        input_data (dict): input data with item_id.

    Returns:
        Negative sampled feature dict.
    """
    ids = _pa_ids_to_npy(input_data[self._item_id_field])
    # expand_factor * num_pos >= num_sample; the old fixed-batch_size
    # padding (np.pad + eager self._sampler) is no longer needed.
    num_pos = max(1, len(ids))
    expand_factor = int(math.ceil(self._num_sample / num_pos))
    sampler = self._get_sampler(expand_factor)
    nodes = sampler.get(ids)
    features = self._parse_nodes(nodes)
    result_dict = dict(zip(self._valid_attr_names, features))
    return result_dict
Expand Down Expand Up @@ -509,15 +519,12 @@ def __init__(
)
self._item_id_field = config.item_id_field
self._user_id_field = config.user_id_field
self._sampler = None
self._sampler_cache: Dict[int, Any] = {}

def init(self, client_id: int = -1) -> None:
"""Init sampler client and samplers."""
"""Init sampler client; samplers are created lazily in get()."""
super().init(client_id)
expand_factor = int(math.ceil(self._num_sample / self._batch_size))
self._sampler = self._g.negative_sampler(
"edge", expand_factor, strategy="random", conditional=True
)
self._sampler_cache = {}

# prevent gl timeout
worker_info = get_worker_info()
Expand All @@ -528,20 +535,30 @@ def init(self, client_id: int = -1) -> None:
{self._user_id_field: pa.array([0]), self._item_id_field: pa.array([0])}
)

def _get_sampler(self, expand_factor: int) -> Any:
if expand_factor not in self._sampler_cache:
self._sampler_cache[expand_factor] = self._g.negative_sampler(
"edge", expand_factor, strategy="random", conditional=True
)
return self._sampler_cache[expand_factor]

def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
    """Sampling method.

    Args:
        input_data (dict): input data with user_id and item_id. When
            item_id is a flattened multi-value sequence, user_id must
            be expanded to match the flattened length.

    Returns:
        Negative sampled feature dict.
    """
    src_ids = _pa_ids_to_npy(input_data[self._user_id_field])
    dst_ids = _pa_ids_to_npy(input_data[self._item_id_field])
    # expand_factor is derived from the real input length; the old
    # fixed-batch_size np.pad lines and eager self._sampler are gone.
    num_pos = max(1, len(dst_ids))
    expand_factor = int(math.ceil(self._num_sample / num_pos))
    sampler = self._get_sampler(expand_factor)
    nodes = sampler.get(src_ids, dst_ids)
    features = self._parse_nodes(nodes)
    result_dict = dict(zip(self._valid_attr_names, features))
    return result_dict
Expand Down Expand Up @@ -602,20 +619,24 @@ def __init__(
)
self._item_id_field = config.item_id_field
self._user_id_field = config.user_id_field
self._neg_sampler = None
self._neg_sampler_cache: Dict[int, Any] = {}
self._hard_neg_sampler = None

def init(self, client_id: int = -1) -> None:
    """Init sampler client; neg_sampler created lazily in get().

    The easy-negative sampler is built on demand by _get_neg_sampler()
    keyed by expand_factor; only the hard-negative neighbor sampler is
    created eagerly since its fan-out (num_hard_sample) is fixed.

    Args:
        client_id (int): sampler client id.
    """
    super().init(client_id)
    # NOTE: no eager self._g.negative_sampler(...) — leftover pre-refactor
    # lines that created it from batch_size were removed.
    self._neg_sampler_cache = {}
    self._hard_neg_sampler = self._g.neighbor_sampler(
        ["hard_neg_edge"], self._num_hard_sample, strategy="full"
    )

def _get_neg_sampler(self, expand_factor: int) -> Any:
if expand_factor not in self._neg_sampler_cache:
self._neg_sampler_cache[expand_factor] = self._g.negative_sampler(
"item", expand_factor, strategy="node_weight"
)
return self._neg_sampler_cache[expand_factor]

def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
"""Sampling method.

Expand All @@ -628,8 +649,10 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
"""
src_ids = _pa_ids_to_npy(input_data[self._user_id_field])
dst_ids = _pa_ids_to_npy(input_data[self._item_id_field])
dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), "edge")
nodes = self._neg_sampler.get(dst_ids)
num_pos = max(1, len(dst_ids))
expand_factor = int(math.ceil(self._num_sample / num_pos))
neg_sampler = self._get_neg_sampler(expand_factor)
nodes = neg_sampler.get(dst_ids)
neg_features = self._parse_nodes(nodes)
sparse_nodes = self._hard_neg_sampler.get(src_ids).layer_nodes(1)
hard_neg_features, hard_neg_indices = self._parse_sparse_nodes(sparse_nodes)
Expand Down Expand Up @@ -703,20 +726,24 @@ def __init__(
)
self._item_id_field = config.item_id_field
self._user_id_field = config.user_id_field
self._neg_sampler = None
self._neg_sampler_cache: Dict[int, Any] = {}
self._hard_neg_sampler = None

def init(self, client_id: int = -1) -> None:
    """Init sampler client; neg_sampler created lazily in get().

    The conditional edge negative sampler is built on demand by
    _get_neg_sampler() keyed by expand_factor; the hard-negative neighbor
    sampler has a fixed fan-out and is created eagerly.

    Args:
        client_id (int): sampler client id.
    """
    super().init(client_id)
    # NOTE: stale eager self._g.negative_sampler(...) creation (computed
    # from batch_size) was removed in favor of the lazy cache.
    self._neg_sampler_cache = {}
    self._hard_neg_sampler = self._g.neighbor_sampler(
        ["hard_neg_edge"], self._num_hard_sample, strategy="full"
    )

def _get_neg_sampler(self, expand_factor: int) -> Any:
if expand_factor not in self._neg_sampler_cache:
self._neg_sampler_cache[expand_factor] = self._g.negative_sampler(
"edge", expand_factor, strategy="random", conditional=True
)
return self._neg_sampler_cache[expand_factor]

def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
"""Sampling method.

Expand All @@ -729,9 +756,10 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
"""
src_ids = _pa_ids_to_npy(input_data[self._user_id_field])
dst_ids = _pa_ids_to_npy(input_data[self._item_id_field])
padded_src_ids = np.pad(src_ids, (0, self._batch_size - len(src_ids)), "edge")
dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), "edge")
nodes = self._neg_sampler.get(padded_src_ids, dst_ids)
num_pos = max(1, len(dst_ids))
expand_factor = int(math.ceil(self._num_sample / num_pos))
neg_sampler = self._get_neg_sampler(expand_factor)
nodes = neg_sampler.get(src_ids, dst_ids)
neg_features = self._parse_nodes(nodes)
sparse_nodes = self._hard_neg_sampler.get(src_ids).layer_nodes(1)
hard_neg_features, hard_neg_indices = self._parse_sparse_nodes(sparse_nodes)
Expand Down
Loading
Loading