上节我们讲到all-reduce通信kernel在实际执行过程中会被拆解成多个send/recv步骤, 也就是下面这张图:

image.png

这里的Send, recvReduceSend等操作在NCCL里面就是基本的数据传输Op(这里列的只有一部分, 后面会讲到所有的), 这些op都是primitives这个类的成员函数, 而primitives这个泛型类的实例里保存着实际通信过程使用到的资源, 通信的行为; 所以我们今天主要讲解这个类.

这个类的代码位于src/collectives/device/ 目录下, 一共有三个, 分别是simple, ll(low lantency), ll128三种协议的实现, 这里我们以simple协议入手(其他两种大同小异), 看下prims_simple.h 这个文件

类成员变量

我们先看下代码, 参数及变量的含义已经标注了:

template<typename T, typename RedOp, typename Fan, int Direct,
         int SlicePerChunk, int StepPerSlice, int Unroll>
class Primitives<
    T,  // 通信的数据类型
    RedOp, // 规约操作(如Sum ,Max 等), 用于all-reduce这些需要计算的原语
    Fan, // 类似"树"结构中"度"的概念, 表示GPU的发送和接收的数量, 
		     //对于ring算法, 只有一对sender/recver, 所以Fan::MaxSend=Fan::MaxRecv = 1 , 而对于tree算法则为2
    Direct, // 表示数据写入缓冲区, 还是直接写到输入/输出;
    ProtoSimple<SlicePerChunk, // 初始化时确定, 表示每个chunk的slice数量
    StepPerSlice,  // 初始化确定, 每个slice的step数量
    Unroll>  // 是否循环展开
  > {
  static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;  // 与上面Fan的概念类似
  static constexpr int Input=0, Output=1;  // 表示是否是有输入输出
  static constexpr int RoleInput = 0x01,  // 这些都是flag, 其中几个Role是个重要的概念, 后面会详细讲; 
                       RoleOutput = 0x02,
                       RoleWaitRecv = 0x04,
                       RoleWaitSend = 0x08,
                       RolePostSend = 0x10,
                       RolePostRecv = 0x20,
                       Aborted = 0x40,
                       PtrsFifoEnabled = 0x80,
                       SizesFifoEnabled = 0x100,
                       DirectWrite = 0x200,
                       DirectRead = 0x400, 
                       ThreadsSynced = 0x800;
  const int tid;  // 每个线程都会创建Primitives实例, 所以这里表示线程id
  int nthreads;  // 总线程数
  int nworkers;  // 工作线程数, 由于有同步线程的存在, 所以nworkers <= nthreads
  const int stepSize;  // 每个step的大小, 
  Fan fan;
  int index; // Peer index I'm responsible for
  int flags;
  int group; // 比较重要的概念, 后面和Role一起讲
  uint64_t step;
  union {
    void **connPtrsFifoPtr; // (flags & PtrsFifoEnabled)
    T *userBuff;            // (flags & (RoleInput|RoleOutput))
    T *connEltsFifo;        // !(flags & (PtrsFifoEnabled|RoleInput|RoleOutput))
  };  // 根据flag决定获取的数据地址是什么类型的
  union {
    int volatile *connSizesFifoPtr; //  (flags & SizesFifoEnabled)
    T *directBuff;                  // !(flags & SizesFifoEnabled)
  };
  uint64_t volatile *connStepPtr; // 指向peer的buffer的head, 所以类型是volatile
  uint64_t connStepCache; // Cache last seen value of (*connStepPtr)

上面的标注做一个简单的了解, 后面遇到也会结合上下文讲解, 更容易理解

我们了解了大概含义就可以看看all-reduce中是怎么用的了:

Primitives<T, RedOp, FanSymmetric<1>, 1, Proto> prims

这里有两个信息: 因为是ring算法, 所以Fan设置为1; 启用了Direct

接下来看下初始化

初始化

代码片段1:

__device__ Primitives(
      int tid, int nthreads, int const *recvPeers, int const *sendPeers,
      void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint32_t group=0, struct ncclWorkElem* e = nullptr
    ):
    tid(tid),
    stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T)) {

    // For send operations, we need an extra warp to overlap the threadfence and the copy
    this->nthreads = nthreads;
    this->nworkers = nthreads - (MaxSend > 0 && nthreads-WARP_SIZE >= 64 ? WARP_SIZE : 0);
    this->group = group & (uint16_t)0xFFFF;
    int connIndex = group >> 16;

    int nrecv=0, nsend=0;
    while (nrecv < MaxRecv && recvPeers[nrecv] != -1) nrecv++;
    while (nsend < MaxSend && sendPeers[nsend] != -1) nsend++;
    this->fan = Fan(nrecv, nsend);
  1. 首先看下函数声明里的变量, 这里的recvPeerssendPeers 表示当前rank的接收rank及发送rank的buffer的地址; 而inputBufoutputBuf 则表示用户给的输入输出, 是在当前rank上的地址.
  2. 这里的groupconnIndex不用管, all-reduce中=0, 暂时用不到
  3. 这里计算了stepSize, 也就是一个step要处理的数据个数

代码片段2

constexpr int ThreadPerSync = 8;
int g = tid / ThreadPerSync;
int ng = nthreads / ThreadPerSync;
index = tid % ThreadPerSync;
flags = 0;
if (g == 0) {
  if (index < nrecv) flags |= RoleWaitRecv;
  if (index == nrecv) flags |= RoleInput;
} else if (g == 1) {
  if (index < nsend) flags |= RoleWaitSend;
  if (index == nsend) flags |= RoleOutput;
} else if (g == ng - 2) {
  if (index < nrecv) flags |= RolePostRecv;
} else if (g == ng - 1) {
  if (index < nsend) flags |= RolePostSend;
}

int peer = 0;
if (flags & (RoleWaitRecv|RolePostRecv)) peer = recvPeers[index];
if (flags & (RoleWaitSend|RolePostSend)) peer = sendPeers[index];