.. role:: hidden
    :class: hidden-section

Tensor Parallelism - torch.distributed.tensor.parallel
=======================================================

Tensor Parallelism (TP) is built on top of the PyTorch DistributedTensor
(`DTensor <https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/README.md>`__)
and provides different parallelism styles: Colwise, Rowwise, and Sequence Parallelism.

.. warning ::
    Tensor Parallelism APIs are experimental and subject to change.

The entrypoint to parallelize your ``nn.Module`` using Tensor Parallelism is:

.. automodule:: torch.distributed.tensor.parallel

.. currentmodule:: torch.distributed.tensor.parallel

.. autofunction:: parallelize_module

Tensor Parallelism supports the following parallel styles:

.. autoclass:: torch.distributed.tensor.parallel.ColwiseParallel
  :members:
  :undoc-members:

.. autoclass:: torch.distributed.tensor.parallel.RowwiseParallel
  :members:
  :undoc-members:

.. autoclass:: torch.distributed.tensor.parallel.SequenceParallel
  :members:
  :undoc-members:

To simply configure the ``nn.Module``'s inputs and outputs with DTensor layouts
and perform necessary layout redistributions, without distributing the module
parameters to DTensors, the following ``ParallelStyle`` s can be used in
the ``parallelize_plan`` when calling ``parallelize_module``:

.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleInput
  :members:
  :undoc-members:

.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleOutput
  :members:
  :undoc-members:

.. note:: When using ``Shard(dim)`` as the input/output layouts for the above
  ``ParallelStyle`` s, we assume the input/output activation tensors are evenly sharded on
  the tensor dimension ``dim`` on the ``DeviceMesh`` that TP operates on. For instance,
  since ``RowwiseParallel`` accepts input that is sharded on the last dimension, it assumes
  the input tensor has already been evenly sharded on the last dimension. For unevenly
  sharded activation tensors, one could pass DTensors directly to the partitioned modules
  and use ``use_local_output=False`` to return a DTensor after each ``ParallelStyle``, so
  that the DTensor can track the uneven sharding information.

For models like the Transformer, we recommend using ``ColwiseParallel``
and ``RowwiseParallel`` together in the ``parallelize_plan`` to achieve the desired
sharding for the entire model (i.e., Attention and MLP), as in the sketch below.
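
The following is a minimal sketch of such a plan applied to a toy feed-forward block. It
assumes a ``torchrun`` launch on 8 GPUs so that a process group can be initialized; the
``FeedForward`` module and its submodule names (``w1``, ``w2``) are illustrative and not
part of the API.

.. code-block:: python

    # A minimal sketch, assuming a torchrun launch on 8 GPUs; the FeedForward
    # module and its submodule names ("w1", "w2") are illustrative only.
    import os

    import torch
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )


    class FeedForward(nn.Module):
        def __init__(self, dim: int = 1024, hidden_dim: int = 4096):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden_dim)
            self.w2 = nn.Linear(hidden_dim, dim)

        def forward(self, x):
            return self.w2(torch.relu(self.w1(x)))


    # torchrun sets LOCAL_RANK; pin each process to its own GPU.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # 1-D DeviceMesh spanning 8 ranks; TP shards parameters along this mesh.
    tp_mesh = init_device_mesh("cuda", (8,))

    model = FeedForward().cuda()

    # Colwise shards w1 on its output dimension and Rowwise shards w2 on its
    # input dimension, so the intermediate activation stays sharded and only
    # one collective is needed to produce the (replicated) block output.
    model = parallelize_module(
        model,
        tp_mesh,
        {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
        },
    )

    out = model(torch.randn(4, 1024, device="cuda"))
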
Parallelized cross-entropy loss computation (loss parallelism) is supported via the following context manager:

.. autofunction:: torch.distributed.tensor.parallel.loss_parallel

.. warning ::
    The loss_parallel API is experimental and subject to change.
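
Below is a minimal sketch of loss parallelism, again assuming a ``torchrun`` launch on
8 GPUs; the shapes are illustrative. In a real model the sharded logits would typically
come from a final projection parallelized with ``ColwiseParallel(use_local_output=False)``,
rather than from ``distribute_tensor`` as done here to keep the example self-contained.

.. code-block:: python

    # A minimal sketch, assuming a torchrun launch on 8 GPUs; shapes are illustrative.
    import os

    import torch
    import torch.nn.functional as F
    from torch.distributed._tensor import Shard, distribute_tensor
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import loss_parallel

    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    tp_mesh = init_device_mesh("cuda", (8,))

    # Logits sharded on the class (last) dimension, which is what loss_parallel
    # expects; here they are built directly with distribute_tensor for brevity.
    logits = distribute_tensor(
        torch.randn(16, 32, device="cuda", requires_grad=True),
        tp_mesh,
        placements=[Shard(1)],
    )
    labels = torch.randint(0, 32, (16,), device="cuda")

    with loss_parallel():
        # cross_entropy computes the loss without gathering the full logits ...
        loss = F.cross_entropy(logits, labels, reduction="mean")
        # ... and backward must also be called inside the context manager.
        loss.backward()
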