上节我们讲到all-reduce通信kernel在实际执行过程中会被拆解成多个send/recv步骤, 也就是下面这张图:
这里的Send, recvReduceSend等操作在NCCL里面就是基本的数据传输Op(这里列的只有一部分, 后面会讲到所有的), 这些op都是primitives这个类的成员函数, 而primitives这个泛型类的实例里保存着实际通信过程使用到的资源, 通信的行为; 所以我们今天主要讲解这个类.
这个类的代码位于src/collectives/device/
目录下, 一共有三个, 分别是simple, ll(low lantency), ll128三种协议的实现, 这里我们以simple协议入手(其他两种大同小异), 看下prims_simple.h
这个文件
我们先看下代码, 参数及变量的含义已经标注了:
template<typename T, typename RedOp, typename Fan, int Direct,
int SlicePerChunk, int StepPerSlice, int Unroll>
class Primitives<
T, // 通信的数据类型
RedOp, // 规约操作(如Sum ,Max 等), 用于all-reduce这些需要计算的原语
Fan, // 类似"树"结构中"度"的概念, 表示GPU的发送和接收的数量,
//对于ring算法, 只有一对sender/recver, 所以Fan::MaxSend=Fan::MaxRecv = 1 , 而对于tree算法则为2
Direct, // 表示数据写入缓冲区, 还是直接写到输入/输出;
ProtoSimple<SlicePerChunk, // 初始化时确定, 表示每个chunk的slice数量
StepPerSlice, // 初始化确定, 每个slice的step数量
Unroll> // 是否循环展开
> {
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend; // 与上面Fan的概念类似
static constexpr int Input=0, Output=1; // 表示是否是有输入输出
static constexpr int RoleInput = 0x01, // 这些都是flag, 其中几个Role是个重要的概念, 后面会详细讲;
RoleOutput = 0x02,
RoleWaitRecv = 0x04,
RoleWaitSend = 0x08,
RolePostSend = 0x10,
RolePostRecv = 0x20,
Aborted = 0x40,
PtrsFifoEnabled = 0x80,
SizesFifoEnabled = 0x100,
DirectWrite = 0x200,
DirectRead = 0x400,
ThreadsSynced = 0x800;
const int tid; // 每个线程都会创建Primitives实例, 所以这里表示线程id
int nthreads; // 总线程数
int nworkers; // 工作线程数, 由于有同步线程的存在, 所以nworkers <= nthreads
const int stepSize; // 每个step的大小,
Fan fan;
int index; // Peer index I'm responsible for
int flags;
int group; // 比较重要的概念, 后面和Role一起讲
uint64_t step;
union {
void **connPtrsFifoPtr; // (flags & PtrsFifoEnabled)
T *userBuff; // (flags & (RoleInput|RoleOutput))
T *connEltsFifo; // !(flags & (PtrsFifoEnabled|RoleInput|RoleOutput))
}; // 根据flag决定获取的数据地址是什么类型的
union {
int volatile *connSizesFifoPtr; // (flags & SizesFifoEnabled)
T *directBuff; // !(flags & SizesFifoEnabled)
};
uint64_t volatile *connStepPtr; // 指向peer的buffer的head, 所以类型是volatile
uint64_t connStepCache; // Cache last seen value of (*connStepPtr)
上面的标注做一个简单的了解, 后面遇到也会结合上下文讲解, 更容易理解
我们了解了大概含义就可以看看all-reduce中是怎么用的了:
Primitives<T, RedOp, FanSymmetric<1>, 1, Proto> prims
这里有两个信息: 因为是ring算法, 所以Fan设置为1; 启用了Direct
接下来看下初始化
__device__ Primitives(
int tid, int nthreads, int const *recvPeers, int const *sendPeers,
void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint32_t group=0, struct ncclWorkElem* e = nullptr
):
tid(tid),
stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T)) {
// For send operations, we need an extra warp to overlap the threadfence and the copy
this->nthreads = nthreads;
this->nworkers = nthreads - (MaxSend > 0 && nthreads-WARP_SIZE >= 64 ? WARP_SIZE : 0);
this->group = group & (uint16_t)0xFFFF;
int connIndex = group >> 16;
int nrecv=0, nsend=0;
while (nrecv < MaxRecv && recvPeers[nrecv] != -1) nrecv++;
while (nsend < MaxSend && sendPeers[nsend] != -1) nsend++;
this->fan = Fan(nrecv, nsend);
recvPeers
和sendPeers
表示当前rank的接收rank及发送rank的buffer的地址; 而inputBuf
和outputBuf
则表示用户给的输入输出, 是在当前rank上的地址.group
和connIndex
不用管, all-reduce中=0, 暂时用不到constexpr int ThreadPerSync = 8;
int g = tid / ThreadPerSync;
int ng = nthreads / ThreadPerSync;
index = tid % ThreadPerSync;
flags = 0;
if (g == 0) {
if (index < nrecv) flags |= RoleWaitRecv;
if (index == nrecv) flags |= RoleInput;
} else if (g == 1) {
if (index < nsend) flags |= RoleWaitSend;
if (index == nsend) flags |= RoleOutput;
} else if (g == ng - 2) {
if (index < nrecv) flags |= RolePostRecv;
} else if (g == ng - 1) {
if (index < nsend) flags |= RolePostSend;
}
int peer = 0;
if (flags & (RoleWaitRecv|RolePostRecv)) peer = recvPeers[index];
if (flags & (RoleWaitSend|RolePostSend)) peer = sendPeers[index];