PyTorch DDP / FSDP Overview

BACKGROUND: What is stored in GPU memory?

a) Parameter weights

b) Their gradients (first derivatives)

c) Optimizer state: e.g. for the Adam optimizer, the momentum and variance estimates (plus an FP32 copy of the parameters)




So how much GPU memory does this actually take?

Assuming FP16 and x parameters:

  • Parameter memory : 2*x bytes (2 bytes per param)
  • Gradient memory : 2*x bytes (2 bytes per gradient)
  • Optimizer state : All stored as FP32s
    • Parameter copy : 4*x (4 bytes per param - optimizer always stores in full precision) 
    • Momentum copy : 4*x (Adam's first moment)
    • Variance copy : 4*x (Adam's second moment)
    • So the optimizer state totals 12*x bytes; in general, K*x

The picture below captures this as an equation. For 𝛙 = # of params and K = 12, total memory =

     [ 2 (weights) + 2 (gradients) + K ] * 𝛙 bytes = 16𝛙 bytes.




The figure above also shows how the memory goes down as we shard the different state components into N_d shards (e.g. N_d = 64 above).
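
As a quick sanity check of the formula, here is a back-of-the-envelope calculation in the same spirit (a minimal sketch; the 7.5B parameter count is an illustrative number, not one taken from the figure):

    # Memory estimate for mixed-precision training with Adam:
    # (2 + 2 + K) * psi bytes, with each component optionally sharded N_d ways.
    GB = 1e9

    def training_memory_bytes(psi, K=12, n_d=1,
                              shard_params=False, shard_grads=False, shard_optim=False):
        params = 2 * psi / (n_d if shard_params else 1)   # FP16 weights
        grads  = 2 * psi / (n_d if shard_grads  else 1)   # FP16 gradients
        optim  = K * psi / (n_d if shard_optim  else 1)   # FP32 optimizer state
        return params + grads + optim

    psi, n_d = 7_500_000_000, 64   # illustrative model size, N_d = 64 as above
    print(training_memory_bytes(psi) / GB)                                  # 120.0  (no sharding)
    print(training_memory_bytes(psi, n_d=n_d, shard_optim=True) / GB)       # ~31.4  (optimizer sharded)
    print(training_memory_bytes(psi, n_d=n_d, shard_optim=True,
                                shard_grads=True) / GB)                     # ~16.6  (+ gradients)
    print(training_memory_bytes(psi, n_d=n_d, shard_optim=True,
                                shard_grads=True, shard_params=True) / GB)  # ~1.9   (fully sharded)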
 

NCCL 

A word on NCCL, the NVIDIA Collective Communications Library:
-- it is like map-reduce for NVIDIA GPUs
-- it supports primitives like all-gather, all-reduce, reduce, reduce-scatter

How does NCCL reduce/all-reduce work?

Assume we run the same program on a different split of the data at each node. A reduce lets us aggregate / combine the values of the weights onto a single node (the root).
                                                 

In contrast, an all-reduce aggregates the values of the weights at all nodes (ensuring that all GPUs end up with the same weights).



The reduction can use a variety of operations instead of sum: max, product, min, avg, etc.

These primitives are what DDP and FSDP build on to keep model state consistent across GPUs.
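
Here is a hedged sketch of reduce and all-reduce with PyTorch's torch.distributed on top of NCCL (the script name and tensor values are illustrative; it assumes a torchrun launch with one GPU per process):

    # launch with: torchrun --nproc_per_node=<num_gpus> nccl_demo.py
    import os
    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")            # rendezvous set up by torchrun
    rank = dist.get_rank()
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    t = torch.full((4,), float(rank), device="cuda")   # each rank starts with different values

    r = t.clone()
    dist.reduce(r, dst=0, op=dist.ReduceOp.SUM)        # the sum lands only on rank 0

    dist.all_reduce(t, op=dist.ReduceOp.SUM)           # every rank gets the same summed tensor
    # other ops: dist.ReduceOp.MAX, MIN, PRODUCT (and AVG with recent NCCL)

    print(f"rank {rank}: all-reduced value {t.tolist()}")
    dist.destroy_process_group()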

 How DDP works : 

We split the data into mini-batches across the GPUs, but every GPU holds the entire model.



So we use an all-reduce at each mini-batch boundary to keep the weights in sync: DDP all-reduces (averages) the gradients, so every replica applies the same update, maintaining the invariant that at the beginning of each mini-batch all weights are the same.
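
A minimal DDP training-loop sketch (the model, data, and hyperparameters are placeholders; it assumes the same torchrun/NCCL setup as the reduce example above):

    import os
    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = nn.Linear(1024, 10).cuda()               # every rank holds the full model
    ddp_model = DDP(model, device_ids=[local_rank])
    opt = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)

    for step in range(10):                           # each rank reads a different slice of the data
        x = torch.randn(32, 1024, device="cuda")
        y = torch.randint(0, 10, (32,), device="cuda")
        loss = nn.functional.cross_entropy(ddp_model(x), y)
        loss.backward()                              # DDP all-reduces (averages) gradients here
        opt.step()                                   # identical update on every rank keeps weights in sync
        opt.zero_grad()

    dist.destroy_process_group()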

The same synchronization can be represented spatially as follows:






What about FSDP?


First, what FSDP is not! It is not:
- model parallelism: split the layers of the model vertically across GPUs; the data is replicated

Note: in the forward pass we have to manually move activations from GPU 1 to GPU 2, and back again in the backprop phase. It is inefficient because while GPU 1 is busy, GPU 2 is doing nothing (shown below, and in the code sketch after this list).



- tensor parallelism: split the layers horizontally across GPUs; the data is replicated

  
Here we show the data (in white boxes) and the tensor state (in purple). The tensor state is sharded, and the same data is applied to each shard in parallel, with the partial results then combined.


- pipeline parallelism 

We know the problem with model parallelism: GPUs sit idle while waiting for each other.

Pipeline parallelism tries to ameliorate this by using the waiting time for other useful work, e.g. starting on the next (micro-)batch of data.
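
To make the naive model-parallel data movement from the first bullet concrete, here is a minimal two-GPU sketch (the layer sizes and device placement are illustrative; it assumes two CUDA devices are available):

    import torch
    import torch.nn as nn

    # Half the layers on cuda:0, half on cuda:1, activations moved by hand.
    # While cuda:1 runs stage1, cuda:0 sits idle (and vice versa in backprop).
    class TwoStageModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.stage0 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()).to("cuda:0")
            self.stage1 = nn.Linear(1024, 10).to("cuda:1")

        def forward(self, x):
            h = self.stage0(x.to("cuda:0"))
            return self.stage1(h.to("cuda:1"))       # manual hop from GPU 0 to GPU 1

    out = TwoStageModel()(torch.randn(32, 1024))
    print(out.device)                                # cuda:1

Pipeline parallelism keeps both stages busy by feeding them successive micro-batches instead of one large batch.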




Now we're good to get into FSDP

FSDP Phase 1 

FSDP works first by splitting the model vertically.
The unit of splitting can be a layer, a group of layers, or a stage; the programmer defines these units.


 
               Here there are 3 units:
                        unit 1 = layers 0, 3
                        unit 2 = layers 1, 2
                        unit 3 = layers 4, 5
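
A minimal sketch of defining these units by hand (the layer sizes are made up; it assumes the process group is already initialized and the current CUDA device is set, as in the DDP sketch above):

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    layers = [nn.Linear(1024, 1024).cuda() for _ in range(6)]   # "layer 0" .. "layer 5"

    # Each inner FSDP(...) call defines one FSDP unit.
    unit_layers_1_2 = FSDP(nn.Sequential(layers[1], layers[2]))
    unit_layers_4_5 = FSDP(nn.Sequential(layers[4], layers[5]))

    # The outermost wrap becomes the root unit and picks up the remaining
    # parameters (layers 0 and 3), matching "unit 1 = layers 0, 3" above.
    model = FSDP(nn.Sequential(layers[0], unit_layers_1_2, layers[3], unit_layers_4_5))

In practice one usually lets FSDP create the units via an auto_wrap_policy instead of wrapping submodules by hand.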








FSDP Phase 2 


Next is the horizontal sharding phase. 

First, all the parameters of an FSDP unit are flattened and stored in a single FlatParameter.
Then this FlatParameter is sharded across the GPUs.


Here there are 12 weight parameters and 3 bias parameters, 15 values in total. This is not a multiple of the number of GPUs (16), so we pad and then shard. In this toy example each GPU ends up responsible for one element.
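
A toy sketch of this flatten / pad / shard step (plain tensor arithmetic, not FSDP internals; the parameter values are made up):

    import torch

    world_size = 16
    weights = torch.arange(12, dtype=torch.float32)           # 12 weight parameters
    biases  = torch.arange(3,  dtype=torch.float32)           # 3 bias parameters

    flat = torch.cat([weights.flatten(), biases.flatten()])   # 15 elements, FlatParameter-style
    pad = (-flat.numel()) % world_size                         # 1 element of padding
    flat = torch.cat([flat, torch.zeros(pad)])                 # now 16 elements

    shards = flat.chunk(world_size)    # shard i lives on GPU/rank i: one element each
    print(flat.numel(), shards[0].tolist(), shards[-1].tolist())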


FSDP Phase 3 


Now that we have sharded our model across GPUs, we are ready to start the training process.

A note on NCCL's all-gather function:


This is how all-gather is used in FSDP:

We start out with the FSDP unit's FlatParameter, shard it as per Phase 2, and then all-gather the shards whenever the full unit is needed.




In order to run the forward pass on a mini-batch, a node needs all the weights in all shards of that unit. So we call all-gather before each forward / backward pass. The all-gather essentially assembles an FSDP unit (a group of layers) so that the forward/backward pass can be executed on it.

IMPORTANT: we are NOT gathering the entire model; the entire model is too big to fit in memory. We are gathering a single FSDP unit.

Once the unit's computation is complete, each node discards the shards it is not responsible for.
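
A hand-rolled sketch of this gather, compute, discard cycle for one unit (illustrative only, not FSDP's internal code; unit_forward is a hypothetical function that runs the unit's math given its full flat weights):

    import torch
    import torch.distributed as dist

    def run_unit_forward(shard: torch.Tensor, unit_forward, x: torch.Tensor):
        world_size = dist.get_world_size()
        # all-gather: materialize the unit's full flat parameter on every rank
        parts = [torch.empty_like(shard) for _ in range(world_size)]
        dist.all_gather(parts, shard)
        full_flat_param = torch.cat(parts)

        out = unit_forward(full_flat_param, x)   # forward pass of just this FSDP unit

        # discard the portions this rank does not own; only `shard` stays resident
        del parts, full_flat_param
        return out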

The memory requirement of FSDP is therefore proportional to the size of the sharded model plus the size of the largest materialized FSDP unit.

FSDP overlaps communication with computation by prefetching intelligently. Here we show how, in the forward pass, we first issue the all-gather AG0 and only then execute FWD0; while FWD0 executes, we already prefetch AG1.



FSDP Phase 4 


Another NCCL primitive we will need in the backward pass is reduce-scatter.



Here we want to sum across each position of the list/array. If we used an all-reduce, it would send every result to every node, which is more than we need. Instead, we want to scatter the different reduced results to different nodes.

We show how FSDP uses reduce-scatter in the backward pass:

Here we have completed the forward pass and follow up with an all-gather step and then the backward (gradient) computation. Note that the gradients produced in the middle blocks are all different, because the mini-batches each node operates on are different. We then execute a reduce-scatter so that each node ends up with the aggregated gradients for its own shard.
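
A matching hand-rolled sketch of the backward-side reduce-scatter (again illustrative, not FSDP internals; full_grad stands for the full, unsharded gradient of a unit's flat parameter computed from this rank's mini-batch):

    import torch
    import torch.distributed as dist

    def shard_gradients(full_grad: torch.Tensor) -> torch.Tensor:
        world_size = dist.get_world_size()
        # the Phase 2 padding guarantees the flat gradient divides evenly
        grad_chunks = list(full_grad.chunk(world_size))
        grad_shard = torch.empty_like(grad_chunks[0])
        # sum chunk i across all ranks and deliver the result only to rank i
        dist.reduce_scatter(grad_shard, grad_chunks, op=dist.ReduceOp.SUM)
        return grad_shard    # each rank updates only its own parameter shard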




The corresponding computation / communication overlap is depicted below; AG2 corresponds to the all-gather step:


Note: FSDP unit 0 is not freed after the forward pass. It is freed only at the end of the backward pass, unlike the other FSDP units, which are also freed during the forward pass.



Reasons to use FSDP

a) When the model is too big to fit on a single GPU.

b) But there is no free lunch: you have to pay increased communication overhead.

c) In effect, you trade memory for time.


Reasons to not use FSDP 

a) Small models (< 100M params) that are not memory constrained.

