# Owner(s): ["oncall: distributed"]

import sys

from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    NestedWrappedModule,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestTraversal(FSDPTest):
    @property
    def world_size(self):
        return 2

    @skip_if_lt_x_gpu(2)
    def test_fsdp_modules(self):
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
        )
        # By default, ``fsdp_modules()`` returns every FSDP instance in the
        # module tree, including nested ones.
        modules = FSDP.fsdp_modules(nested_wrapped_module)
        self.assertEqual(
            modules,
            [
                nested_wrapped_module.module.get_submodule("1"),
                nested_wrapped_module.module.get_submodule("1").get_submodule("0"),
                nested_wrapped_module.module.get_submodule("2"),
            ],
        )
        # With ``root_only=True``, only FSDP instances that are not nested
        # inside another FSDP instance are returned, so the instance at
        # submodule "1.0" is excluded.
        modules = FSDP.fsdp_modules(nested_wrapped_module, root_only=True)
        self.assertEqual(
            modules,
            [
                nested_wrapped_module.module.get_submodule("1"),
                nested_wrapped_module.module.get_submodule("2"),
            ],
        )


if __name__ == "__main__":
    run_tests()