# Owner(s): ["module: dynamo"]
import sys
import unittest
from typing import Dict, List

import torch
import torch._dynamo.config
import torch._dynamo.test_case
from torch import nn
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import NoTest

try:
    from torchrec.datasets.random import RandomRecDataset
    from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

    HAS_TORCHREC = True
except ImportError:
    HAS_TORCHREC = False


@torch._dynamo.config.patch(force_unspec_int_unbacked_size_like_on_torchrec_kjt=True)
class BucketizeMod(torch.nn.Module):
    def __init__(self, feature_boundaries: Dict[str, List[float]]):
        super().__init__()
        self.bucket_w = torch.nn.ParameterDict()
        self.boundaries_dict = {}
        for key, boundaries in feature_boundaries.items():
            self.bucket_w[key] = torch.nn.Parameter(
                torch.empty([len(boundaries) + 1]).fill_(1.0),
                requires_grad=True,
            )
            buf = torch.tensor(boundaries, requires_grad=False)
            self.register_buffer(
                f"{key}_boundaries",
                buf,
                persistent=False,
            )
            self.boundaries_dict[key] = buf

    def forward(self, features: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
        weights_list = []
        for key, boundaries in self.boundaries_dict.items():
            jt = features[key]
            bucketized = torch.bucketize(jt.weights(), boundaries)
            # doesn't super matter I guess
            # hashed = torch.ops.fb.index_hash(bucketized, seed=0, modulo=len(boundaries))
            hashed = bucketized
            weights = torch.gather(self.bucket_w[key], dim=0, index=hashed)
            weights_list.append(weights)
        return KeyedJaggedTensor(
            keys=features.keys(),
            values=features.values(),
            weights=torch.cat(weights_list),
            lengths=features.lengths(),
            offsets=features.offsets(),
            stride=features.stride(),
            length_per_key=features.length_per_key(),
        )


if not HAS_TORCHREC:
    print("torchrec not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


@unittest.skipIf(not HAS_TORCHREC, "these tests require torchrec")
class TorchRecTests(TestCase):
    def test_pooled(self):
        tables = [
            (nn.EmbeddingBag(2000, 8), ["a0", "b0"]),
            (nn.EmbeddingBag(2000, 8), ["a1", "b1"]),
            (nn.EmbeddingBag(2000, 8), ["b2"]),
        ]
        embedding_groups = {
            "a": ["a0", "a1"],
            "b": ["b0", "b1", "b2"],
        }

        counter = CompileCounter()

        @torch.compile(backend=counter, fullgraph=True, dynamic=True)
        def f(id_list_features: KeyedJaggedTensor):
            id_list_jt_dict: Dict[str, JaggedTensor] = id_list_features.to_dict()
            pooled_embeddings = {}
            # TODO: run feature processor
            for emb_module, feature_names in tables:
                features_dict = id_list_jt_dict
                for feature_name in feature_names:
                    f = features_dict[feature_name]
                    pooled_embeddings[feature_name] = emb_module(
                        f.values(), f.offsets()
                    )

            pooled_embeddings_by_group = {}
            for group_name, group_embedding_names in embedding_groups.items():
                group_embeddings = [
                    pooled_embeddings[name] for name in group_embedding_names
                ]
                pooled_embeddings_by_group[group_name] = torch.cat(
                    group_embeddings, dim=1
                )

            return pooled_embeddings_by_group

        dataset = RandomRecDataset(
            keys=["a0", "a1", "b0", "b1", "b2"],
            batch_size=4,
            hash_size=2000,
            ids_per_feature=3,
            num_dense=0,
        )
        di = iter(dataset)

        # unsync should work
        d1 = next(di).sparse_features.unsync()
        d2 = next(di).sparse_features.unsync()
        d3 = next(di).sparse_features.unsync()

        r1 = f(d1)
        r2 = f(d2)
        r3 = f(d3)

        self.assertEqual(counter.frame_count, 1)
        counter.frame_count = 0

        # sync should work too
        d1 = next(di).sparse_features.sync()
        d2 = next(di).sparse_features.sync()
        d3 = next(di).sparse_features.sync()

        r1 = f(d1)
        r2 = f(d2)
        r3 = f(d3)

        self.assertEqual(counter.frame_count, 1)

        # export only works with unsync
        gm = torch._dynamo.export(f)(next(di).sparse_features.unsync()).graph_module
        gm.print_readable()

        self.assertEqual(gm(d1), r1)
        self.assertEqual(gm(d2), r2)
        self.assertEqual(gm(d3), r3)

    def test_bucketize(self):
        mod = BucketizeMod({"f1": [0.0, 0.5, 1.0]})
        features = KeyedJaggedTensor.from_lengths_sync(
            keys=["f1"],
            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
            weights=torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]),
        ).unsync()

        def f(x):
            # This is a trick to populate the computed cache and instruct
            # ShapeEnv that they're all sizey
            x.to_dict()
            return mod(x)

        torch._dynamo.export(f, aten_graph=True)(features).graph_module.print_readable()

    @unittest.expectedFailure
    def test_simple(self):
        jag_tensor1 = KeyedJaggedTensor(
            values=torch.tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            keys=["index_0", "index_1"],
            lengths=torch.tensor([0, 0, 1, 1, 1, 3]),
        ).sync()

        # ordinarily, this would trigger one specialization
        self.assertEqual(jag_tensor1.length_per_key(), [1, 5])

        counter = CompileCounter()

        @torch._dynamo.optimize(counter, nopython=True)
        def f(jag_tensor):
            # The indexing here requires more symbolic reasoning
            # and doesn't work right now
            return jag_tensor["index_0"].values().sum()

        f(jag_tensor1)
        self.assertEqual(counter.frame_count, 1)

        jag_tensor2 = KeyedJaggedTensor(
            values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            keys=["index_0", "index_1"],
            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
        ).sync()

        f(jag_tensor2)
        self.assertEqual(counter.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()