Skip to content

Commit 7ecee99

Browse files
chaodenguscmeta-codesync[bot]
authored andcommitted
Remove the hard-coded value for bucket buffer (#3581)
Summary: Pull Request resolved: #3581 As title. Differential Revision: D87966872 fbshipit-source-id: fb8e26b3aa59d57b3015be9ff788882dc628576a
1 parent ca2f687 commit 7ecee99

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

torchrec/modules/hash_mc_modules.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ class HashZchManagedCollisionModule(ManagedCollisionModule):
185185

186186
IDENTITY_BUFFER: str = "_hash_zch_identities"
187187
METADATA_BUFFER: str = "_hash_zch_metadata"
188+
BUCKET_BUFFER: str = "_hash_zch_bucket"
188189

189190
table_name_on_device_remapped_ids_dict: Dict[
190191
str, torch.Tensor
@@ -308,7 +309,10 @@ def __init__(
308309

309310
self._max_probe = max_probe
310311
self._buckets = total_num_buckets
311-
self.register_buffer("_hash_zch_bucket", torch.tensor([[total_num_buckets]]))
312+
self.register_buffer(
313+
HashZchManagedCollisionModule.BUCKET_BUFFER,
314+
torch.tensor([[total_num_buckets]]),
315+
)
312316
# Do not need to store in buffer since this is created and consumed
313317
# at each step https://fburl.com/code/axzimmbx
314318
self._evicted_indices = []

0 commit comments

Comments
 (0)