-
Notifications
You must be signed in to change notification settings - Fork 63
[feat] integrate dynamicemb table fusion (wheel 20260407.97b80bf) #466
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260331.bea6b4b.cu129-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" | ||
| dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260331.bea6b4b.cu129-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" | ||
| dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260331.bea6b4b.cu129-cp312-cp312-linux_x86_64.whl ; python_version=="3.12" | ||
| dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260407.97b80bf.cu129-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" | ||
| dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260407.97b80bf.cu129-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" | ||
| dynamicemb @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/dynamicemb/cu129/dynamicemb-0.0.1%2B20260407.97b80bf.cu129-cp312-cp312-linux_x86_64.whl ; python_version=="3.12" | ||
| torch_fx_tool @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/rtp/torch_fx_tool-0.0.1%2B20251201.8c109c4-py3-none-any.whl |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -20,7 +20,6 @@ | |||||||||||||||||
| from torchrec.distributed.embedding_types import ( | ||||||||||||||||||
| EmbeddingComputeKernel, | ||||||||||||||||||
| GroupedEmbeddingConfig, | ||||||||||||||||||
| ShardedEmbeddingTable, | ||||||||||||||||||
| ) | ||||||||||||||||||
| from torchrec.distributed.planner import ( | ||||||||||||||||||
| constants, | ||||||||||||||||||
|
|
@@ -47,7 +46,7 @@ | |||||||||||||||||
| ShardingType, | ||||||||||||||||||
| ShardMetadata, | ||||||||||||||||||
| ) | ||||||||||||||||||
| from torchrec.modules.embedding_configs import BaseEmbeddingConfig, DataType | ||||||||||||||||||
| from torchrec.modules.embedding_configs import BaseEmbeddingConfig | ||||||||||||||||||
|
|
||||||||||||||||||
| from tzrec.protos import feature_pb2 | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -61,7 +60,6 @@ | |||||||||||||||||
| FrequencyAdmissionStrategy, | ||||||||||||||||||
| KVCounter, | ||||||||||||||||||
| align_to_table_size, | ||||||||||||||||||
| batched_dynamicemb_compute_kernel, | ||||||||||||||||||
| ) | ||||||||||||||||||
| from dynamicemb.batched_dynamicemb_compute_kernel import ( | ||||||||||||||||||
| BatchedDynamicEmbedding, | ||||||||||||||||||
|
|
@@ -191,6 +189,10 @@ def build_dynamicemb_constraints( | |||||||||||||||||
| else: | ||||||||||||||||||
| raise ValueError(f"Unknown AdmissionStrategy: {admission_strategy_type}") | ||||||||||||||||||
|
|
||||||||||||||||||
| demb_opt_kwargs = {} | ||||||||||||||||||
| if dynamicemb_cfg.HasField("bucket_capacity"): | ||||||||||||||||||
| demb_opt_kwargs["bucket_capacity"] = dynamicemb_cfg.bucket_capacity | ||||||||||||||||||
|
Comment on lines
+192
to
+194
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: validate A user could set
Suggested change
|
||||||||||||||||||
|
|
||||||||||||||||||
| dynamicemb_options = dynamicemb.DynamicEmbTableOptions( | ||||||||||||||||||
| max_capacity=dynamicemb_cfg.max_capacity, | ||||||||||||||||||
| init_capacity=init_capacity, | ||||||||||||||||||
|
|
@@ -207,6 +209,7 @@ def build_dynamicemb_constraints( | |||||||||||||||||
| score_strategy=score_strategy, | ||||||||||||||||||
| admit_strategy=admit_strategy, | ||||||||||||||||||
| admission_counter=admission_counter, | ||||||||||||||||||
| **demb_opt_kwargs, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| constraints_kwargs = {} | ||||||||||||||||||
|
|
@@ -352,18 +355,18 @@ def _to_sharding_plan( | |||||||||||||||||
| bucket_capacity=dynamicemb_options.bucket_capacity, | ||||||||||||||||||
| ) | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # align to DEMB_TABLE_ALIGN_SIZE | ||||||||||||||||||
| num_aligned_embedding_per_rank = align_to_table_size(shards[0].size[0]) | ||||||||||||||||||
| num_embeddings_per_shard = shards[0].size[0] | ||||||||||||||||||
| if num_aligned_embedding_per_rank < dynamicemb_options.bucket_capacity: | ||||||||||||||||||
| num_aligned_embedding_per_rank = align_to_table_size( | ||||||||||||||||||
| dynamicemb_options.bucket_capacity | ||||||||||||||||||
| ) | ||||||||||||||||||
| if num_embeddings_per_shard != num_aligned_embedding_per_rank: | ||||||||||||||||||
| dynamicemb_options.num_aligned_embedding_per_rank = ( | ||||||||||||||||||
| num_aligned_embedding_per_rank | ||||||||||||||||||
| ) | ||||||||||||||||||
| # Fill in per-shard fields that used to be populated by | ||||||||||||||||||
| # dynamicemb's internal ``_get_dynamicemb_options_per_table``. | ||||||||||||||||||
| # After the fused-storage refactor (NVIDIA recsys-examples | ||||||||||||||||||
| # PR #343) that upstream function became a pass-through | ||||||||||||||||||
| # validator, so the caller must set ``dim``, ``max_capacity`` | ||||||||||||||||||
| # (per-shard row count) and ``embedding_dtype`` directly. | ||||||||||||||||||
| dynamicemb_options.dim = shards[0].size[1] | ||||||||||||||||||
| dynamicemb_options.max_capacity = shards[0].size[0] | ||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Question: should The storage estimator ( Should this be |
||||||||||||||||||
| if dynamicemb_options.embedding_dtype is None: | ||||||||||||||||||
| dynamicemb_options.embedding_dtype = tensor.dtype | ||||||||||||||||||
| if dynamicemb_options.index_type is None: | ||||||||||||||||||
| dynamicemb_options.index_type = torch.int64 | ||||||||||||||||||
|
|
||||||||||||||||||
| module_plan[sharding_option.name] = DynamicEmbParameterSharding( | ||||||||||||||||||
| sharding_spec=sharding_spec, | ||||||||||||||||||
|
|
@@ -614,42 +617,6 @@ def dynamicemb_calculate_shard_storages( | |||||||||||||||||
| for hbm_size, ddr_size in zip(hbm_sizes, ddr_sizes) | ||||||||||||||||||
| ] | ||||||||||||||||||
|
|
||||||||||||||||||
| _dynamicemb_get_dynamicemb_options_per_table = ( | ||||||||||||||||||
| batched_dynamicemb_compute_kernel._get_dynamicemb_options_per_table | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| def _get_dynamicemb_options_per_table( | ||||||||||||||||||
| local_row: int, | ||||||||||||||||||
| local_col: int, | ||||||||||||||||||
| data_type: DataType, | ||||||||||||||||||
| optimizer: dynamicemb.EmbOptimType, | ||||||||||||||||||
| table: ShardedEmbeddingTable, | ||||||||||||||||||
| ) -> dynamicemb.DynamicEmbTableOptions: | ||||||||||||||||||
| # pyre-ignore [16] | ||||||||||||||||||
| dynamicemb_options = table.fused_params["dynamicemb_options"] | ||||||||||||||||||
| bak_local_hbm_for_values = None | ||||||||||||||||||
| if dynamicemb_options.num_aligned_embedding_per_rank is not None: | ||||||||||||||||||
| bak_local_hbm_for_values = dynamicemb_options.local_hbm_for_values | ||||||||||||||||||
|
|
||||||||||||||||||
| dynamicemb_options = _dynamicemb_get_dynamicemb_options_per_table( | ||||||||||||||||||
| local_row=local_row, | ||||||||||||||||||
| local_col=local_col, | ||||||||||||||||||
| data_type=data_type, | ||||||||||||||||||
| optimizer=optimizer, | ||||||||||||||||||
| table=table, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # do not improve the HBM budget, already aligned in planner. | ||||||||||||||||||
| if bak_local_hbm_for_values is not None: | ||||||||||||||||||
| dynamicemb_options.local_hbm_for_values = bak_local_hbm_for_values | ||||||||||||||||||
|
|
||||||||||||||||||
| return dynamicemb_options | ||||||||||||||||||
|
|
||||||||||||||||||
| # pyre-ignore [9] | ||||||||||||||||||
| batched_dynamicemb_compute_kernel._get_dynamicemb_options_per_table = ( | ||||||||||||||||||
| _get_dynamicemb_options_per_table | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Monkey-patch for torchrec 1.5.0 compatibility | ||||||||||||||||||
| # The base class now passes 'env' parameter to _create_embedding_kernel | ||||||||||||||||||
| def _grouped_embeddings_lookup_create_embedding_kernel( | ||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: The comment lists
CUSTOMIZEDas a validscore_strategy, but the user-facing docs intentionally omit it (it's an internal dynamicemb enum value, not meant for end users). Consider removingCUSTOMIZEDfrom this comment to avoid confusion for anyone reading the proto directly.