Gradient backpropagation with torch.distributed.all_gather

Jianfeng Wang
4 min read · Feb 7, 2021

This post presents some tips on using torch.distributed.all_gather while making sure the gradient is calculated correctly when training deep learning models.

No gradient backpropagated through torch.distributed.all_gather

First of all, torch.distributed.all_gather itself does not propagate gradients back. To verify this, we can run the following code.

```python
import os

import torch
from torch import nn

batch_size = 16
# Rank information comes from the environment variables set by Open MPI's mpirun.
rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', '0'))
world_size = int(os.environ.get('OMPI_COMM_WORLD_SIZE', '1'))
bs_each = batch_size // world_size
device_id = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', '0'))
torch.cuda.set_device(device_id)
torch.distributed.init_process_group(
    backend='nccl',
    init_method='tcp://localhost:12345',
    rank=rank,
    world_size=world_size,
)

model = nn.Linear(1, 1, bias=False)
model.weight.data[:] = 1.
model = model.cuda()
x = torch.ones((bs_each, 1), requires_grad=True).cuda()
y = model(x)
# One receive buffer per rank.
ys = [torch.zeros_like(y) for _ in range(world_size)]
torch.distributed.all_gather(ys, y)
print(y.grad_fn)
# <MmBackward object at 0x7f2073fc3ba8>
for sub_y in ys:
    print(sub_y.grad_fn)
# None
```

Run the code with python a.py (or with mpirun to launch several processes; Open MPI sets the OMPI_COMM_WORLD_* variables the script reads). It first prints the real grad function of y, which is computed without all_gather. After all_gather, however, none of the tensors in ys has a grad_fn, which means no gradient is propagated back through them.

In practice, it is therefore recommended to wrap the all_gather call in torch.no_grad() to make it explicit that no gradient is propagated back.
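To see this recommendation in action without GPUs or an MPI launcher, here is a minimal single-process sketch. It assumes the CPU gloo backend with an in-memory torch.distributed.HashStore and a world size of 1, which differs from the NCCL/MPI setup above; the helper names are hypothetical. The point it illustrates is the same: the local output keeps its grad_fn, while the gathered copies do not.

```python
import torch
import torch.distributed as dist


def init_single_process_group() -> None:
    # Assumption: a 1-process "cluster" on CPU using gloo and an in-memory store,
    # so the example runs without GPUs or mpirun.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo", store=dist.HashStore(), rank=0, world_size=1
        )


def gather_outputs(y: torch.Tensor) -> list:
    world_size = dist.get_world_size()
    ys = [torch.zeros_like(y) for _ in range(world_size)]
    # all_gather records nothing for autograd, so make that explicit.
    with torch.no_grad():
        dist.all_gather(ys, y)
    return ys


init_single_process_group()
model = torch.nn.Linear(1, 1, bias=False)
x = torch.ones(4, 1)
y = model(x)
ys = gather_outputs(y)
print(y.grad_fn is not None)     # the local output is tracked by autograd
print([t.grad_fn for t in ys])   # the gathered copies are not
```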

Do we need the propagation for all_gather?

Since all_gather propagates no gradient, do we actually need it to? And if we do, how can we work around its absence? A typical setting is that each GPU computes some output, and the loss is calculated from the outputs of all GPUs rather than from each GPU's own output alone. In this setting, we can proceed as follows so that 1) we don't need a gradient through all_gather, and 2) the loss is still easy to calculate.

Specifically, let's say the output from the i-th GPU is x_i and the loss depends on all x_i. That is, we can write the loss as

$$L = f(x_1, x_2, \ldots, x_N),$$

where N is the number of GPUs.

For the traditional separable case, we have

$$f(x_1, x_2, \ldots, x_N) = \frac{1}{N}\sum_{i=1}^{N} g(x_i).$$

So each GPU can compute its own loss g(x_i), and autograd does the job of calculating the gradient for all parameters. This is normally paired with DistributedDataParallel, which averages the gradients across GPUs automatically. In this case, we don't need to gather the other GPUs' outputs. But what if the loss is not separable?
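To check the separable case numerically, the sketch below simulates N GPUs in one CPU process: each simulated rank computes g(x_i) on its own data shard with shared parameters, and the per-rank gradients are averaged the way DistributedDataParallel would average them. The choice of g (mean of squared outputs), the linear model, and the random data are illustrative assumptions, not part of the original setup.

```python
import torch

torch.manual_seed(0)
N = 4                                           # simulated world size
theta = torch.randn(3, requires_grad=True)      # shared model parameters
inputs = [torch.randn(5, 3) for _ in range(N)]  # one data shard per "GPU"


def g(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical per-GPU loss: depends only on the local output x_i.
    return (x ** 2).mean()


# What each rank computes locally, followed by DDP-style averaging.
local_grads = []
for inp in inputs:
    x_i = inp @ theta                           # local network output
    (grad,) = torch.autograd.grad(g(x_i), theta)
    local_grads.append(grad)
ddp_grad = torch.stack(local_grads).mean(dim=0)

# Gradient of the full objective f = (1/N) * sum_i g(x_i).
f = sum(g(inp @ theta) for inp in inputs) / N
(true_grad,) = torch.autograd.grad(f, theta)

print(torch.allclose(ddp_grad, true_grad, atol=1e-6))
```

Because the averaged local gradients already match the gradient of f, no gathering is needed in this case.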

In that case, each GPU can gather the outputs of all GPUs. Since the gathered tensors have no grad_fn, we replace the entry for the current GPU with the current network output, which does have one. That is,

```python
with torch.no_grad():
    all_x = [torch.zeros_like(x) for _ in range(world_size)]
    torch.distributed.all_gather(all_x, x)
# Put back the live output so that gradients flow into this GPU's network.
all_x[rank] = x
```

Now all_x contains the outputs x from all GPUs. None of them has a grad_fn except the entry for the current GPU, thanks to the final all_x[rank] = x.
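The gather-and-replace step can be packaged as a small helper. The name gather_with_local_grad is hypothetical, and the demo assumes a single CPU process with the gloo backend and world size 1 so that it runs anywhere; with NCCL and several GPUs the helper body would stay the same. The demo checks that a gradient reaches a parameter through the local entry.

```python
import torch
import torch.distributed as dist


def gather_with_local_grad(x: torch.Tensor) -> list:
    """Gather x from every rank; only this rank's entry keeps its autograd graph."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    all_x = [torch.zeros_like(x) for _ in range(world_size)]
    with torch.no_grad():
        dist.all_gather(all_x, x)
    # Swap in the live tensor so gradients flow back into this rank's network.
    all_x[rank] = x
    return all_x


if not dist.is_initialized():
    # Assumption for the demo: 1-process gloo group with an in-memory store.
    dist.init_process_group(backend="gloo", store=dist.HashStore(), rank=0, world_size=1)

w = torch.tensor([2.0], requires_grad=True)
x = torch.ones(3, 1) * w          # local output that carries a grad_fn
all_x = gather_with_local_grad(x)
features = torch.cat(all_x)       # shape: (world_size * 3, 1)
features.sum().backward()
print(w.grad)                     # the gradient reached w through the local entry
```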

Once we have all_x, we can simply calculate the loss f on it. On the i-th GPU, autograd then computes

$$\frac{\partial f}{\partial x_i}\,\frac{\partial x_i}{\partial \theta}.$$

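Here, θ denotes the network parameters. Note that the gradient flows only through x_i; the other outputs x_j (j ≠ i) contribute no gradient because all_gather does not propagate it. With DistributedDataParallel, which averages the gradients across GPUs, the gradient for each parameter will be

$$\frac{1}{N}\sum_{i=1}^{N}\frac{\partial f}{\partial x_i}\,\frac{\partial x_i}{\partial \theta}.$$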

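Here, N is the world size. Is this what we want? No. By the chain rule, the gradient of our objective f is

$$\frac{\partial f}{\partial \theta} = \sum_{i=1}^{N}\frac{\partial f}{\partial x_i}\,\frac{\partial x_i}{\partial \theta},$$

so the averaged gradient carries an extra divisor N. To fix this, we simply calculate the loss on each GPU as Nf rather than f.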
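The factor-of-N argument can be checked without a cluster. The sketch below simulates N ranks in one CPU process: on each simulated rank, the outputs of the other ranks are detached (standing in for all_gather), only the local x_i keeps its graph, and the per-rank gradients are averaged as DistributedDataParallel would average them. The non-separable objective (a logsumexp over all outputs), the linear model, and the data are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
N = 4                                           # simulated world size
theta = torch.randn(3, requires_grad=True)      # shared parameters (kept in sync by DDP)
inputs = [torch.randn(5, 3) for _ in range(N)]  # one data shard per "GPU"


def f(all_x: list) -> torch.Tensor:
    # Hypothetical non-separable loss: couples the outputs of every rank.
    return torch.logsumexp(torch.cat(all_x), dim=0)


def ddp_gradient(scale: float) -> torch.Tensor:
    grads = []
    for rank in range(N):
        # Outputs from other ranks arrive through all_gather: no graph attached.
        all_x = [(inp @ theta).detach() for inp in inputs]
        all_x[rank] = inputs[rank] @ theta      # the local output keeps its graph
        (grad,) = torch.autograd.grad(scale * f(all_x), theta)
        grads.append(grad)
    return torch.stack(grads).mean(dim=0)       # DDP averages across ranks


(true_grad,) = torch.autograd.grad(f([inp @ theta for inp in inputs]), theta)
naive = ddp_gradient(scale=1.0)                 # loss f on every rank
fixed = ddp_gradient(scale=float(N))            # loss N * f on every rank

print(torch.allclose(naive * N, true_grad, atol=1e-6))
print(torch.allclose(fixed, true_grad, atol=1e-6))
```

The naive version is off by exactly the factor 1/N derived above, and scaling the per-GPU loss by N recovers the true gradient.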

That's it. In conclusion, although no gradient flows through all_gather, we have effectively calculated the true gradient. Let's summarize the steps.

  1. Gather all network outputs with all_gather, then replace the current GPU's entry with its own output so that this entry carries gradients.
  2. Calculate your loss function on the gathered outputs and multiply it by the world size.
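The two steps above can be combined into a training step, shown here as a minimal sketch. It assumes DistributedDataParallel on CPU with the gloo backend and a single process (world size 1) so that the demo is runnable; gathered_loss and the logsumexp objective are hypothetical stand-ins for your own non-separable loss f.

```python
import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def gathered_loss(local_out: torch.Tensor, loss_fn) -> torch.Tensor:
    world_size, rank = dist.get_world_size(), dist.get_rank()
    # Step 1: gather every rank's output, then restore the local entry's graph.
    all_out = [torch.zeros_like(local_out) for _ in range(world_size)]
    with torch.no_grad():
        dist.all_gather(all_out, local_out)
    all_out[rank] = local_out
    # Step 2: compute the loss on all outputs and scale it by the world size
    # to cancel DDP's gradient averaging.
    return loss_fn(torch.cat(all_out)) * world_size


if not dist.is_initialized():
    # Assumption for the demo: 1-process gloo group with an in-memory store.
    dist.init_process_group(backend="gloo", store=dist.HashStore(), rank=0, world_size=1)

model = DDP(nn.Linear(3, 1, bias=False))        # CPU module, so no device_ids
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = torch.randn(8, 3)

optimizer.zero_grad()
loss = gathered_loss(model(batch), lambda z: torch.logsumexp(z.flatten(), dim=0))
loss.backward()
optimizer.step()
print(loss.item())
```

On a real multi-GPU job, the same gathered_loss would be called with the NCCL backend and one process per GPU; only the process-group setup and device placement would change.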

Written by Jianfeng Wang

Research Engineer at Microsoft for Computer Vision
