# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from model_registry import ModelWithKwargs

import torch
from torch.distributed.pipelining import pipeline
from torch.distributed.pipelining.microbatch import (
    merge_chunks,
    split_args_kwargs_into_chunks,
    TensorChunkSpec,
)
from torch.testing._internal.common_utils import run_tests, TestCase


d_hid = 512
torch.manual_seed(0)


class MicrobatchTests(TestCase):
    """Tests for splitting inputs into microbatches and merging them back."""

    def test_split_and_merge(self):
        """Split positional and keyword tensors into chunks, then merge them.

        With the default chunking (dim 0), every chunk should hold
        ``size(0) // num_chunks`` rows of each input. Merging the chunks with
        a dim-0 ``TensorChunkSpec`` should reproduce the original inputs.
        """
        num_chunks = 2
        x0 = torch.randn(128, d_hid)
        x1 = torch.randn(256, d_hid)
        x2 = torch.randn(512, d_hid)
        args = (x0, x1, x2)
        kwargs = {"x0": x0, "x1": x1, "x2": x2}

        # Default chunking: dim 0
        arg_chunks, kwarg_chunks = split_args_kwargs_into_chunks(
            args, kwargs, num_chunks
        )

        # Use TestCase assertions instead of bare ``assert``: bare asserts are
        # stripped under ``python -O`` and give no diagnostic on failure.
        self.assertEqual(len(arg_chunks), num_chunks)
        self.assertEqual(len(kwarg_chunks), num_chunks)

        # Check every chunk of every input, not just a sample of them.
        for arg_chunk, kwarg_chunk in zip(arg_chunks, kwarg_chunks):
            # Guard against zip() silently truncating a short chunk.
            self.assertEqual(len(arg_chunk), len(args))
            for full, piece in zip(args, arg_chunk):
                self.assertEqual(
                    piece.shape,
                    torch.Size([full.size(0) // num_chunks, d_hid]),
                )

            self.assertEqual(sorted(kwarg_chunk), sorted(kwargs))
            for key, full in kwargs.items():
                self.assertEqual(
                    kwarg_chunk[key].shape,
                    torch.Size([full.size(0) // num_chunks, d_hid]),
                )

        # Merge chunks back together
        merged_args = merge_chunks(
            arg_chunks,
            (TensorChunkSpec(0), TensorChunkSpec(0), TensorChunkSpec(0)),
        )
        torch.testing.assert_close(merged_args, args)

        merged_kwargs = merge_chunks(
            kwarg_chunks,
            {
                "x0": TensorChunkSpec(0),
                "x1": TensorChunkSpec(0),
                "x2": TensorChunkSpec(0),
            },
        )
        torch.testing.assert_close(merged_kwargs, kwargs)
        print("Microbatch test passed")

    def test_chunk_spec(self):
        """Build a pipeline from one microbatch and compare it to the module.

        The positional input ``x`` and keyword input ``y`` are each split
        along dim 0 using explicit chunk specs. The first chunk is passed to
        ``pipeline`` as example microbatch inputs. Calling the pipeline on
        the full batch must then match calling ``mod`` directly.
        """
        mod = ModelWithKwargs()
        batch_size = ModelWithKwargs.DEFAULT_BATCH_SIZE

        x = torch.randn(batch_size, d_hid)
        y = torch.randn(batch_size, d_hid)

        num_chunks = 4

        # Chunk the single positional arg and the "y" kwarg along dim 0.
        args_chunk_spec = TensorChunkSpec.from_tuple((0,))
        kwargs_chunk_spec = TensorChunkSpec.from_dict({"y": 0})

        args_split, kwargs_split = split_args_kwargs_into_chunks(
            (x,),
            {"y": y},
            num_chunks,
            args_chunk_spec,
            kwargs_chunk_spec,
        )

        pipe = pipeline(
            mod,
            mb_args=args_split[0],
            mb_kwargs=kwargs_split[0],
        )

        ref = mod(x, y)
        # The pipeline's call result is indexed with [0]; the first element is
        # compared against the eager module output.
        out = pipe(x, y)[0]
        torch.testing.assert_close(out, ref)
        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")


if __name__ == "__main__":
    run_tests()