.. role:: hidden
    :class: hidden-section

Tensor Parallelism - torch.distributed.tensor.parallel
======================================================

Tensor Parallelism (TP) is built on top of the PyTorch DistributedTensor
(`DTensor <https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/README.md>`__)
and provides different parallelism styles: Colwise, Rowwise, and Sequence Parallelism.

.. warning ::
    Tensor Parallelism APIs are experimental and subject to change.

The entrypoint to parallelize your ``nn.Module`` using Tensor Parallelism is:

.. automodule:: torch.distributed.tensor.parallel

.. currentmodule:: torch.distributed.tensor.parallel

.. autofunction::  parallelize_module

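As a quick orientation, below is a minimal sketch of calling ``parallelize_module`` on a toy
module. The submodule name ``fc``, the 1-D device mesh, and launching via ``torchrun`` on CUDA
devices are illustrative assumptions; the parallel styles that can appear in the plan are
documented below.

.. code-block:: python

    # Minimal sketch (assumes launching with ``torchrun --nproc-per-node=N``).
    import os

    import torch
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module


    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(16, 32)

        def forward(self, x):
            return self.fc(x)


    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # 1-D mesh over all ranks; TP shards parameters across this mesh.
    mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
    model = Toy().cuda()
    # Map submodule names to parallel styles; here ``fc`` is sharded column-wise.
    model = parallelize_module(model, mesh, {"fc": ColwiseParallel()})
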
Tensor Parallelism supports the following parallel styles:

.. autoclass:: torch.distributed.tensor.parallel.ColwiseParallel
  :members:
  :undoc-members:

.. autoclass:: torch.distributed.tensor.parallel.RowwiseParallel
  :members:
  :undoc-members:

.. autoclass:: torch.distributed.tensor.parallel.SequenceParallel
  :members:
  :undoc-members:

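Continuing the sketch above (reusing its ``mesh`` and imports; not the canonical example),
``ColwiseParallel`` turns an ``nn.Linear`` weight into a DTensor sharded on the output-feature
dimension, while ``RowwiseParallel`` shards it on the input-feature dimension, so pairing the
two avoids an extra collective between the layers:

.. code-block:: python

    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    # ``nn.Sequential`` names its children "0" and "1".
    mlp = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 16)).cuda()
    mlp = parallelize_module(
        mlp, mesh, {"0": ColwiseParallel(), "1": RowwiseParallel()}
    )
    # The weights are now DTensors: "0" is sharded on dim 0 (output features),
    # "1" on dim 1 (input features).
    print(mlp[0].weight.placements, mlp[1].weight.placements)
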
To simply configure the ``nn.Module``'s inputs and outputs with DTensor layouts
and perform the necessary layout redistributions, without distributing the module
parameters to DTensors, the following ``ParallelStyle`` s can be used in
the ``parallelize_plan`` when calling ``parallelize_module``:

.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleInput
  :members:
  :undoc-members:

.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleOutput
  :members:
  :undoc-members:

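For example, the sketch below (with hypothetical submodule names ``attn`` and ``mlp``) uses
``PrepareModuleInput`` to redistribute a sequence-sharded runtime input to a replicated layout
before ``attn`` runs, and ``PrepareModuleOutput`` to hand ``mlp``'s replicated output back
sharded on the sequence dimension:

.. code-block:: python

    from torch.distributed._tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import (
        PrepareModuleInput,
        PrepareModuleOutput,
        parallelize_module,
    )

    plan = {
        # The input to ``attn`` arrives sharded on the sequence dim (dim 1),
        # e.g. produced by a ``SequenceParallel`` layer; replicate it first.
        "attn": PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
        # ``mlp`` computes a replicated output; re-shard it on the sequence dim
        # for the next sequence-parallel region.
        "mlp": PrepareModuleOutput(
            output_layouts=(Replicate(),),
            desired_output_layouts=(Shard(1),),
        ),
    }
    model = parallelize_module(model, mesh, plan)
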
.. note:: When using ``Shard(dim)`` as the input/output layouts for the above
  ``ParallelStyle`` s, we assume the input/output activation tensors are evenly sharded on
  the tensor dimension ``dim`` on the ``DeviceMesh`` that TP operates on. For instance,
  since ``RowwiseParallel`` accepts input that is sharded on the last dimension, it assumes
  the input tensor has already been evenly sharded on the last dimension. For unevenly
  sharded activation tensors, one can pass DTensors directly into the partitioned modules
  and use ``use_local_output=False`` to return a DTensor after each ``ParallelStyle``,
  so that the DTensor can track the uneven sharding information.

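As a brief sketch of that DTensor-in/DTensor-out path (the submodule names ``w1`` and ``w2``
are assumptions), setting ``use_local_output=False`` keeps the sharding metadata attached to
the activations instead of converting them back to plain ``torch.Tensor`` s:

.. code-block:: python

    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    # Return DTensors from each style so downstream modules can see the
    # (possibly uneven) sharding instead of a plain local ``torch.Tensor``.
    plan = {
        "w1": ColwiseParallel(use_local_output=False),
        "w2": RowwiseParallel(use_local_output=False),
    }
    model = parallelize_module(model, mesh, plan)
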
For models like Transformer, we recommend that users use ``ColwiseParallel``
and ``RowwiseParallel`` together in the ``parallelize_plan`` to achieve the desired
sharding for the entire model (i.e., Attention and MLP).

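The sketch below illustrates that recommendation on a single Transformer block; the submodule
names (``attention.wq``, ``feed_forward.w1``, etc.) follow a Llama-style block and are
assumptions to adapt to your model's actual FQNs:

.. code-block:: python

    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    # Column-wise for the projections that fan out, row-wise for those that fan
    # back in, so Attention and MLP each need only one all-reduce in forward.
    layer_plan = {
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(),
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(),
    }
    for transformer_block in model.layers:
        parallelize_module(transformer_block, mesh, layer_plan)
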
Parallelized cross-entropy loss computation (loss parallelism) is supported via the following context manager:

.. autofunction:: torch.distributed.tensor.parallel.loss_parallel

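A minimal sketch of the context manager in use, assuming ``pred`` is a DTensor sharded on the
class dimension (for example, the output of a final ``ColwiseParallel`` projection with
``use_local_output=False``) and ``labels`` is an ordinary replicated tensor:

.. code-block:: python

    import torch.nn.functional as F
    from torch.distributed.tensor.parallel import loss_parallel

    # Both the loss computation and its backward must run under the context manager.
    with loss_parallel():
        # Cross-entropy is computed without gathering the full logits on any rank.
        loss = F.cross_entropy(pred, labels)
        loss.backward()
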
.. warning ::
    The loss_parallel API is experimental and subject to change.