
Tensor parallelism in transformers

This page documents transformers v4.52.1.


Tensor parallelism shards a model onto multiple GPUs and parallelizes computations such as matrix multiplication. It makes it possible to fit larger models into memory, and it is faster because each GPU processes only a slice of each tensor. This document assumes that you are already familiar with the basics of tensor parallelism. If you are not, please refer to the Ultra-Scale Playbook section on tensor parallelism.

Tensor parallelism is very communication intensive, so it is recommended to use it on a single machine with multiple GPUs, taking advantage of fast intra-node communication. For multi-node training, methods such as pipeline or data parallelism are more efficient (depending on your use case).

Tensor parallelism requires slight changes to the model parameters, so transformers supports some of the popular models out of the box.

Expand the list below to see which models support tensor parallelism. Open a GitHub issue or pull request to add support for a model that is not currently listed.

Supported models

Using 🤗 transformers

Transformers provides a simple interface for tensor parallelism. We provide multiple classes that implement different partitioning strategies, and a simple entrypoint to parallelize an nn.Module instance. You won't have to interact with this interface directly, because the PreTrainedModel.from_pretrained method does everything for you. This section first covers the partitioning strategies we support, then the user interface you will interact with, and finally how to extend it with your own partitioning strategies.

Partitioning strategies

In transformers, partitioning strategies live in the ParallelInterface class, which works like a mapping from a strategy name (a string) to the strategy implementation.

class ParallelInterface(MutableMapping):
    """
    Dict-like object keeping track of allowed tensor parallel strategies. You can easily add a new strategy
    with a call to `register()`. If a model needs to locally overwrite an existing strategy, say `colwise`,
    it needs to declare a new instance of this class inside the `modeling_<model>.py`, and declare it on that instance.
    """
    _global_mapping = {
        "colwise": ColwiseParallel(),
        "rowwise": RowwiseParallel(),
        "colwise_rep": ColwiseParallel(output_layouts=Replicate()),
        "rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
        "local_colwise": ColwiseParallel(use_dtensor=False),
        "local_rowwise": RowwiseParallel(use_dtensor=False),
        "local": IsolatedParallel(),
        "gather": GatherParallel(),
        "local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
        "sequence_parallel": SequenceParallel(),
        "replicate": ReplicateParallel(),
    }
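To make the "mapping from string to strategy" idea concrete, here is a minimal, hypothetical registry in the same spirit: a class-level global mapping, a `register()` classmethod, and per-instance local overrides, as the docstring above describes. The `StrategyRegistry` name and the placeholder `Colwise`/`Rowwise` classes are illustrative assumptions; this is a sketch, not the transformers implementation.

```python
from collections.abc import MutableMapping


class Colwise:  # hypothetical placeholder standing in for ColwiseParallel
    pass


class Rowwise:  # hypothetical placeholder standing in for RowwiseParallel
    pass


class StrategyRegistry(MutableMapping):
    """Dict-like mapping from a strategy name to a strategy instance."""

    # Shared by all instances; register() writes here.
    _global_mapping = {"colwise": Colwise(), "rowwise": Rowwise()}

    def __init__(self):
        # Per-instance overrides, e.g. for a model that needs its own "colwise".
        self._local_mapping = {}

    def __getitem__(self, key):
        # Local overrides take precedence over the global defaults.
        if key in self._local_mapping:
            return self._local_mapping[key]
        return self._global_mapping[key]

    def __setitem__(self, key, value):
        self._local_mapping[key] = value

    def __delitem__(self, key):
        del self._local_mapping[key]

    def __iter__(self):
        return iter({**self._global_mapping, **self._local_mapping})

    def __len__(self):
        return len({**self._global_mapping, **self._local_mapping})

    @classmethod
    def register(cls, key, value):
        # Visible to every instance, existing or future.
        cls._global_mapping[key] = value
```

Keeping the global mapping on the class means a single `register()` call is visible everywhere, while assigning on an instance only affects that instance, which is the local-override behavior the docstring mentions.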

We support the following strategies:

  • ColwiseParallel - Simple column-wise partitioning that handles both weights and biases, exactly as discussed before.
  • RowwiseParallel - Row-wise partitioning, as discussed before. It supports weights and biases, and on top of that it also supports nn.Embedding modules.
  • SequenceParallel - Sequence parallel implementation that supports LayerNorm and Dropout layers. It also supports the Python implementation of RMSNorm (see this).
  • PackedColwiseParallel - A variant of column-wise partitioning that works on packed weights (for example, up_proj and gate_proj packed together). For more details, see this comment.
  • PackedRowwiseParallel - A variant of row-wise partitioning that works on packed weights. For more details, check the comment linked above.
  • GatherParallel - A very simple class that only gathers the outputs of the module across devices.
  • IsolatedParallel - A special case that isolates the module from the rest of the devices (the world). It is used for experts in MoE layers, effectively creating a form of expert parallelism.
  • ReplicateParallel - Many torch.distributed APIs break if the model is only partially sharded, so this class replicates the module across all devices.
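As a single-process illustration of the math behind column-wise and row-wise partitioning (not the ColwiseParallel or RowwiseParallel implementation), the sketch below simulates two ranks with `torch.chunk`. Column-wise shards each compute a slice of the output features, which are concatenated; row-wise shards each compute a partial sum, which must be added together. On real GPUs the concatenation and the sum correspond to an all-gather and an all-reduce.

```python
import torch

torch.manual_seed(0)
world_size = 2  # number of simulated GPUs

linear = torch.nn.Linear(8, 6)
W, b = linear.weight.detach(), linear.bias.detach()  # W: (out_features, in_features)
x = torch.randn(3, 8)  # (batch, in_features), replicated on every rank
full = x @ W.T + b

# Column-wise: each rank holds a slice of the output features
# (rows of W and the matching slice of b) and computes part of the output.
col_outputs = [
    x @ W_r.T + b_r
    for W_r, b_r in zip(W.chunk(world_size, dim=0), b.chunk(world_size, dim=0))
]
colwise = torch.cat(col_outputs, dim=-1)  # stands in for an all-gather

# Row-wise: each rank holds a slice of the input features (columns of W)
# and the matching slice of x; the partial results are summed.
row_partials = [
    x_r @ W_r.T
    for x_r, W_r in zip(x.chunk(world_size, dim=-1), W.chunk(world_size, dim=1))
]
rowwise = sum(row_partials) + b  # the sum stands in for an all-reduce; bias added once
```

Note that the bias is sharded along with the output features in the column-wise case, but must be added only once, after the reduction, in the row-wise case.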

Sharding a model

We provide two ways to shard a model. The first is to use the auto tensor parallelism plan, which automatically shards the model based on our predefined configuration. This requires the model to have a predefined tensor parallel plan in transformers.

import torch
from transformers import AutoModelForCausalLM

# model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # better for smaller number of GPUs
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" # better to visualize all the possible strategies

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan="auto")

print(model._tp_plan)

For a list of models that support tensor parallelism, see the Supported models section above.

The second way is to manually specify your own partitioning plan.

import torch
from transformers import AutoModelForCausalLM

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.k_proj": "colwise",
    "model.layers.*.self_attn.v_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
    # ... entries for the remaining modules
}

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan=tp_plan)

print(model._tp_plan)
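The `*` in the plan keys stands for the layer index, so a single entry covers every decoder layer. The following hypothetical helper, `resolve_strategy`, sketches how such a key could be matched against module names with a regular expression. It is only an illustration under that assumption; the matching logic inside transformers may differ.

```python
import re


def resolve_strategy(module_name, tp_plan):
    """Return the strategy name for module_name, or None if no plan key matches."""
    for pattern, strategy in tp_plan.items():
        # Escape the dots, then let "*" match a numeric layer index.
        regex = re.escape(pattern).replace(r"\*", r"\d+")
        if re.fullmatch(regex, module_name):
            return strategy
    return None


tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
}

print(resolve_strategy("model.layers.12.self_attn.q_proj", tp_plan))
```

In this sketch, a module without a matching key simply gets `None`; what transformers does with modules that are absent from the plan is not covered here.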

You might have noticed that there are some special cases in the ParallelInterface mapping. Let's now talk about them; this will help you understand their purpose and extend the interface with other strategies.

PackedRowwiseParallel

This class is a special case of RowwiseParallel, used to shard packed weights. Weight packing is a common technique in which multiple linear layers are packed into a single, bigger one.

For example, in the Llama4 model, we pack up_proj and gate_proj into a single gate_up_proj module.

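import torch
from torch import nn
class Llama4TextExperts(nn.Module):
    def __init__(self, config):
        ...  # super().__init__() and setup of num_experts, hidden_size and expert_dim (elided)
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))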

Then in forward, we can use batch matrix multiplication to compute the output of the gate_up_proj module.

def forward(self, hidden_states):
    ...
    gate_up = torch.bmm(hidden_states, self.gate_up_proj) # Compute the output of the gate_up_proj module
    gate, up = gate_up.chunk(2, dim=-1) # Split the output into gate and up

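In this case, we need to use the PackedRowwiseParallel strategy to shard the gate_up_proj module. A simple RowwiseParallel would split the packed dimension into contiguous blocks, so one GPU could end up with only gate columns and another with only up columns, and gate_up.chunk(2, dim=-1) would no longer separate gate from up correctly.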

If this is a bit difficult to wrap your head around, check out this comment for an amazing visual representation of why Packed* needs to be used.
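The difference can also be reproduced with plain tensors. The single-process sketch below simulates rank 0 of 2 GPUs for one expert (the `num_experts` dimension is dropped for brevity) and compares a naive contiguous split of the packed dimension with a packed-aware split. It illustrates the idea only; it does not call transformers' `get_packed_weights`.

```python
import torch

world_size, rank = 2, 0  # simulate rank 0 of 2 GPUs
hidden_size, expert_dim = 4, 6

gate = torch.arange(hidden_size * expert_dim, dtype=torch.float32).reshape(hidden_size, expert_dim)
up = gate + 100.0  # distinct values so the two halves are easy to tell apart
gate_up = torch.cat([gate, up], dim=-1)  # packed weight: (hidden_size, 2 * expert_dim)

# Naive shard: contiguous split of the packed dim, so rank 0 gets only gate columns.
naive = gate_up.chunk(world_size, dim=-1)[rank]

# Packed-aware shard: unpack, take this rank's slice of gate and of up, re-pack.
g, u = gate_up.chunk(2, dim=-1)
packed = torch.cat(
    [g.chunk(world_size, dim=-1)[rank], u.chunk(world_size, dim=-1)[rank]], dim=-1
)

# What the forward pass does on each rank: chunk(2) into gate and up.
local_gate, local_up = packed.chunk(2, dim=-1)
naive_gate, naive_up = naive.chunk(2, dim=-1)
```

With the packed-aware shard, `local_up` holds real `up` columns. With the naive shard, `naive_up` actually contains the second half of `gate`, so the forward pass would silently compute the wrong thing.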

local* strategies

You might have noticed that there are local* strategies, which use the same layers as the corresponding * strategy but don't use DTensor at all. This is because DTensor does not support some operations, such as torch.chunk. Therefore, we sometimes need the local* strategies, which use vanilla torch.Tensor and implement some of the distributed logic manually.

Manually specifying your own partitioning plan requires a good understanding of the model architecture and of how the partitioning strategies interact. If you get this wrong, the resulting model can be very slow, fail to run, or produce incorrect results. Again, refer to the Ultra-Scale Playbook, which covers everything required.

Extending the interface with your own partitioning strategies

This is a very advanced topic that requires a good understanding of distributed collectives and the model architecture. Your custom partitioning strategy should inherit from TensorParallelLayer, defined in integrations/tensor_parallel.py, and implement partition_tensor, _prepare_input_fn and _prepare_output_fn. Then it should be registered in the ParallelInterface mapping, so that our dispatching logic can find it when it is specified in the tp_plan.

Let’s go through this workflow step by step, on an already existing example: ColwiseParallel.

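1. Inherit from TensorParallelLayer and define __init__
from typing import Optional

from torch.distributed.tensor import Placement, Replicate, Shard
from transformers.integrations.tensor_parallel import TensorParallelLayer

class ColwiseParallel(TensorParallelLayer):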
    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None, # The input layout coming from the previous layer
        output_layouts: Optional[Placement] = None, # The output layout we want to achieve
        use_local_output: bool = True, # Whether to use local output or not
        use_dtensor=True, # Whether to use DTensor or not
    ):
        self.input_layouts = (input_layouts or Replicate(),) # The input sharding coming from the previous layer
        self.output_layouts = (output_layouts or Shard(-1),) # Desired output sharding
        self.desired_input_layouts = (Replicate(),) # Desired input sharding, inputs should be replicated across GPUs
        self.use_local_output = use_local_output
        self.use_dtensor = use_dtensor

In the __init__ method, we define these attributes: input_layouts describes how the input arrives from the previous layer, output_layouts describes how the output should be placed on the devices, and desired_input_layouts specifies how the input SHOULD be placed on the devices before the layer runs.

2a. Implement partition_tensor method

def partition_tensor(
    self,
    param, # Full tensor of the parameter
    empty_param, # Empty tensor of the parameter, will be filled with the partitioned tensor
    param_type, # Type of the parameter, `bias` or `weight`
    param_casting_dtype, # The type to cast the parameter to
    to_contiguous, # Whether to convert the tensor to a contiguous memory layout
    rank, # The rank of the current device
    device_mesh, # The device mesh
) -> nn.Parameter: # Return the partitioned parameter
    ...

This method partitions the tensor and fills empty_param with the partitioned tensor. We provide some utility functions to help you with this, such as get_tensor_shard, which gets you the correct shard of the original parameter for this rank, or get_packed_weights, which helps with packed weights.
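As a hypothetical sketch of the core of partition_tensor for a column-wise strategy, the function below uses plain `torch.chunk` rather than `get_tensor_shard` (whose exact signature is not shown in this guide). Unlike the real method, it takes `world_size` directly instead of a `device_mesh` and omits `empty_param`; the name `colwise_partition` is illustrative only.

```python
import torch


def colwise_partition(param, param_type, rank, world_size, casting_dtype=None, to_contiguous=True):
    """Hypothetical column-wise partitioning of an nn.Linear weight or bias.

    nn.Linear stores weight as (out_features, in_features) and bias as
    (out_features,), so column-wise sharding splits dim 0 of the weight and
    the only dim of the bias.
    """
    dim = 0 if param_type == "weight" else -1
    shard = param.detach().chunk(world_size, dim=dim)[rank]  # this rank's slice
    if casting_dtype is not None:
        shard = shard.to(casting_dtype)
    if to_contiguous:
        shard = shard.contiguous()
    # Only floating-point parameters can require gradients.
    return torch.nn.Parameter(shard, requires_grad=shard.is_floating_point())


w0 = colwise_partition(torch.randn(6, 8), "weight", rank=0, world_size=2, casting_dtype=torch.bfloat16)
print(w0.shape, w0.dtype)
```

A row-wise variant would differ mainly in the split dimension (dim 1 of the weight) and in replicating the bias instead of sharding it.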

2b. Implement _prepare_input_fn and _prepare_output_fn methods

These methods are used as pre-forward and forward hooks, respectively. Their purpose is to redistribute the inputs and outputs to the desired layouts passed to the __init__ method.

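from torch.distributed.tensor import DTensor

@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
    input_tensor = inputs[0]  # pre-forward hooks receive the positional inputs as a tuple
    if not isinstance(input_tensor, DTensor):
        # Wrap the local tensor as a DTensor with the layout it arrives in
        input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
    # Do some custom logic here
    return input_tensor.redistribute(placements=desired_input_layouts, device_mesh=device_mesh)

@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
    # Do some custom logic here
    outputs = outputs.redistribute(placements=output_layouts, device_mesh=device_mesh)
    # Optionally convert back to a plain torch.Tensor
    return outputs.to_local() if use_local_output else outputs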
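To see how functions with these signatures end up as hooks, the single-process sketch below roughly mirrors how `torch.distributed.tensor.distribute_module` attaches them. The layout arguments are bound with `functools.partial`, and lambdas adapt torch's `(module, inputs)` and `(module, inputs, output)` hook signatures. This is a simplified illustration under several assumptions: plain tensors stand in for DTensors, a hypothetical `mesh` dict stands in for the DeviceMesh, and the layout strings `"replicate"`, `"shard_last"` and `"partial"` are placeholders for Placement objects. The input hook performs a Replicate to Shard(-1) conversion by slicing, as a row-wise layer needs. The all-reduce is simulated by summing the per-rank outputs afterwards.

```python
from functools import partial

import torch

world_size = 2
torch.manual_seed(0)
full_linear = torch.nn.Linear(8, 4, bias=False)


def prepare_input(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
    x = inputs[0]  # pre-forward hooks receive the positional inputs as a tuple
    if input_layouts == "replicate" and desired_input_layouts == "shard_last":
        # Replicate -> Shard(-1): keep only this rank's slice of the features
        x = x.chunk(device_mesh["size"], dim=-1)[device_mesh["rank"]]
    return (x,)


def prepare_output(output_layouts, use_local_output, mod, outputs, device_mesh):
    # A real row-wise layer would all-reduce the partial output here
    return outputs


def build_rank_module(rank):
    mesh = {"rank": rank, "size": world_size}  # hypothetical stand-in for a DeviceMesh
    shard = torch.nn.Linear(8 // world_size, 4, bias=False)
    with torch.no_grad():
        shard.weight.copy_(full_linear.weight.chunk(world_size, dim=1)[rank])
    # Bind the layout arguments, then adapt to torch's hook signatures
    input_fn = partial(prepare_input, "replicate", "shard_last")
    output_fn = partial(prepare_output, "partial", False)
    shard.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, mesh))
    shard.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, mesh))
    return shard


x = torch.randn(3, 8)  # replicated input: every rank sees the full tensor
partial_outputs = [build_rank_module(r)(x) for r in range(world_size)]
```

Summing `partial_outputs` reproduces `full_linear(x)`. This is why the output of a row-wise layer is a pending reduction until it is all-reduced.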
3. Register the strategy

Congratulations! You've implemented your own partitioning strategy. Now, to use it with your own tp_plan, you need to register it in the ParallelInterface mapping.
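from transformers.integrations.tensor_parallel import ColwiseParallel, ParallelInterface

# The mapping stores strategy instances (see _global_mapping above), so register an instance
ParallelInterface.register("colwise_custom", ColwiseParallel())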

And now you can use it in your tp_plan as such:

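import torch
from transformers import AutoModelForCausalLM

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise_custom",
    # ... entries for the remaining modules
}

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan=tp_plan)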

Full example

Let’s go through a full example of inference with tensor parallelism.

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


# enable tensor parallelism
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    tp_plan="auto",
)

# prepare input tokens
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

# distributed run
outputs = model(inputs)

Save the inference script above as demo.py and launch it with torchrun using 4 processes, one per GPU.

torchrun --nproc-per-node 4 demo.py

You can benefit from considerable speedups for inference, especially for inputs with a large batch size or long sequences.

For a single forward pass on Llama with a sequence length of 512 and various batch sizes, you can expect the following speedups.

Tensor parallelism in-depth

Our implementation of tensor parallelism is framework-agnostic in design, but the specific implementations we’ve developed rely on the torch.distributed package. We heavily utilize abstractions such as DeviceMesh or DTensor to provide a simple and extensible interface to the user.

DeviceMesh

Imagine DeviceMesh as a multi-dimensional grid of devices that communicate together. Different parallelization strategies require different types of communication patterns, therefore we can create a DeviceMesh with multiple submeshes:

from torch.distributed.device_mesh import init_device_mesh

# Create a 1D mesh of 4 GPUs (run under torchrun with 4 processes)
device_mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("tp",))

Most parallelization strategies defined in torch.distributed can then be applied to the mesh itself or to one of its submeshes, and they handle the communication patterns automatically.
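For intuition about submeshes, the sketch below reproduces the rank layout of a 2D mesh using plain tensors. It assumes 8 GPUs arranged as `init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))` would arrange them, with global ranks laid out row-major as `torch.arange(8).reshape(2, 4)`. In that setup, `device_mesh["tp"]` gives each rank its row and `device_mesh["dp"]` gives it its column. The sketch creates no process groups, and the helper names are hypothetical.

```python
import torch

# Global ranks of a (dp=2, tp=4) mesh, laid out row-major.
mesh = torch.arange(8).reshape(2, 4)


def coords(rank):
    """(dp_index, tp_index) of `rank` in the mesh."""
    row, col = (mesh == rank).nonzero()[0].tolist()
    return row, col


def tp_group(rank):
    """Ranks sharing a tensor-parallel group with `rank`: same row."""
    row, _ = coords(rank)
    return mesh[row].tolist()


def dp_group(rank):
    """Ranks sharing a data-parallel group with `rank`: same column."""
    _, col = coords(rank)
    return mesh[:, col].tolist()


print(tp_group(5), dp_group(5))
```

If each node hosts 4 GPUs with consecutive ranks, every tensor-parallel group stays within one node. This matches the earlier advice to keep communication-heavy tensor parallelism intra-node.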

DTensor

Short for Distributed Tensor, DTensor is a tensor subclass that handles the distributed logic on top of the usual tensor operations. In tensor parallelism, most of the model weights are stored as DTensors (with some exceptions, more on that later). The most important part of DTensor to understand is the placement attribute, which tells PyTorch how the tensor is placed on the devices of the DeviceMesh.

It can have the following values:

  • Shard(dimension) - Annotates that this DTensor is sharded across a given dimension, over the DeviceMesh it was constructed under. For example, to shard weights for column-wise partitioning, we would do:
from torch.distributed.tensor import DTensor, Replicate, Shard
weight = ...
weight = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(0)]) # Shard across the 1st (column-wise) dimension
bias = ...
bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Shard(-1)]) # Shard across the ONLY dimension

To give another example, for row-wise partitioning, we would do:

weight = ...
weight = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(1)]) # Shard across the 2nd (row-wise) dimension
bias = ...
bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Replicate()]) # Replicate bias across all GPUs
  • Replicate() - Annotates that this DTensor is replicated across the DeviceMesh. This is straightforward: each device holds a full copy of the tensor.
  • Partial() - This placement is mostly of no interest to us. It annotates that the tensor is pending a reduction operation, for example the per-GPU partial sums produced by a row-wise partitioned layer before they are all-reduced.
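The three placements can be illustrated by simulating what each of 2 ranks would hold locally, using plain tensors. This sketch only shows the semantics; it does not create DTensors, and the uneven Shard(1) split simply follows `torch.chunk`.

```python
import torch

world_size = 2
full = torch.arange(12, dtype=torch.float32).reshape(4, 3)

# Shard(0): rank r holds the r-th chunk along dim 0.
shard0 = list(full.chunk(world_size, dim=0))
# Shard(1): rank r holds the r-th chunk along dim 1 (torch.chunk gives 2 and 1 columns here).
shard1 = list(full.chunk(world_size, dim=1))
# Replicate(): every rank holds a full copy.
replicated = [full.clone() for _ in range(world_size)]
# Partial(): every rank holds a tensor of the full shape; the logical value is
# their sum, pending a reduction (all-reduce) across ranks.
partial = [full * 0.25, full * 0.75]
```

Concatenating the Shard pieces along their dimension, taking any Replicate copy, or summing the Partial pieces all recover the same logical tensor.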