0x0. Preface

We know that Flash Attention performs block-level operator fusion over a series of blocks of a given size.

Given the sparsity of the attention map, we can usually skip the parts that attention itself already ignores. These parts are generally quite fine-grained, but they still tend to cluster, so we can approximate them with a coarser-grained block-sparse mask.

Once you have approximated that mask, you can ride on FA's coattails and crank out a pile of papers (just kidding).

Based on the official Triton example, we can write this out ourselves.

0x1. Problem Definition

My implementation is not a Top-K Block Sparse Mask (BSM); it is a BSM over a custom concatenation of heterogeneous sequences, which I call the Interlaced Mask (IM).

Think of it as a group of multimodal sequences with text (T), vision (V), and audio (A), each of shape [B, H, S_m, D] where m ∈ {T, V, A}. After concatenation along the sequence dimension, the three modalities form a 3 × 3 coarse topology. This topology can be defined externally; we call it the Modal-wise Topology. A common shape looks like this:

test_topology = [
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0]
]

Row i of the topology marks which modalities the queries of modality i may attend to (topology[i][j] == 1 means modality i attends to modality j); with the topology above, T attends to V, V attends to A, and A attends to T. The sequence lengths are defined as follows:

mm_seq_seg = (50, 375, 500)  # assume these are (T, V, A)

From the modal-wise topology and the sequence lengths we can generate a Dense Mask (DM). Why bother generating a DM? Because a custom topology with custom sequence lengths, unlike a Top-K-estimated sparse graph, produces many segment boundaries that are not divisible by the compute block size. If the DM is not folded into the computation, the attention map will inevitably come out wrong at those boundary blocks. Of course, this incurs some wasted computation, but that is the price of full customization.
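To make the divisibility issue concrete, here is a minimal check on the running example; block_size = 64 is a hypothetical choice, not something fixed by this post. None of the modality boundaries lands on a block boundary, so every boundary block mixes allowed and forbidden positions and has to be resolved element-wise with the DM.

block_size = 64                      # hypothetical compute block size
mm_seq_seg = (50, 375, 500)          # the (T, V, A) lengths from above
boundaries = [sum(mm_seq_seg[:i + 1]) for i in range(len(mm_seq_seg))]
print(boundaries)                            # [50, 425, 925]
print([b % block_size for b in boundaries])  # [50, 41, 29] -> none is block-aligned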

from typing import List

import torch


def generate_interlaced_mask(topology: List[List[int]],
                             mm_seq_seg: List[int]) -> torch.Tensor:
    """Build the dense [S, S] boolean attention mask from the modal-wise topology."""
    seq_len = sum(mm_seq_seg)
    mask = torch.zeros((seq_len, seq_len), dtype=torch.bool)
    for row_idx, b_row in enumerate(topology):
        for col_idx, b_col_elem in enumerate(b_row):
            if b_col_elem == 1:
                # Rows index the query modality, columns the key modality.
                row_start = sum(mm_seq_seg[:row_idx])
                row_end = row_start + mm_seq_seg[row_idx]
                col_start = sum(mm_seq_seg[:col_idx])
                col_end = col_start + mm_seq_seg[col_idx]
                mask[row_start:row_end, col_start:col_end] = True
    return mask
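A quick sanity check with the topology and lengths defined above; the expected values follow directly from the block layout of the mask.

dense_mask = generate_interlaced_mask(test_topology, list(mm_seq_seg))
print(dense_mask.shape)           # torch.Size([925, 925])
print(dense_mask[0, 50].item())   # True  -> T (row 0) attends to V (cols 50:425)
print(dense_mask[0, 0].item())    # False -> T does not attend to T
print(dense_mask[500, 0].item())  # True  -> A (rows 425:925) attends to T (cols 0:50)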

Correspondingly, we need to derive the BSM from the DM; the BSM and the DM then work together in the final kernel to make the computation both efficient and correct. Here I give the algorithm class directly so it is easy to follow.

class TritonMBSAttnKernel(BaseMBSAttnKernel):
    def __init__(self, topology: List[List[int]], mm_seq_seg: Tuple[int, int, int], 
                 causal_mask: bool, dropout: float):
        super().__init__(topology, mm_seq_seg, causal_mask, dropout)
        self.dense_mask = generate_interlaced_mask(self.topology, self.mm_seq_seg)
        self.block_sparse_config = BlockSparseConfig(min_seq_len=min(self.mm_seq_seg))
        self.block_sparse_mask = self.__generate_block_sparse_mask()
        self.block_sparse_attn_fn = block_sparse_attention.apply
        
    def __generate_block_sparse_mask(self):
        # NOTE: Not using a Top-K sparse mask; the block mask is derived
        # deterministically from the dense interlaced mask.
        block_size = self.block_sparse_config.block_size
        block_sparse_mask = self.dense_mask.float()

        # Pad so the side length becomes a multiple of block_size.
        seq_len = self.dense_mask.shape[0]
        pad_len = (block_size - seq_len % block_size) % block_size
        if pad_len > 0:
            block_sparse_mask = F.pad(
                block_sparse_mask, (0, pad_len, 0, pad_len),
                mode="constant", value=False)

        # Keep a block iff any element inside it is allowed by the dense mask.
        seq_len = block_sparse_mask.shape[0]
        num_blocks = seq_len // block_size
        block_sparse_mask = block_sparse_mask.reshape(
            num_blocks, block_size, num_blocks, block_size
        ).any(dim=(1, 3))
        return block_sparse_mask

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        q, k, v = check_maybe_fix_all_inputs(q, k, v, self.block_sparse_config)
        self.block_sparse_mask = self.block_sparse_mask.to(q.device)
        self.dense_mask = self.dense_mask.to(q.device)
        return self.block_sparse_attn_fn(
            q, k, v, self.block_sparse_mask, self.dense_mask, 1.0 / (q.shape[-1] ** 0.5), 
            self.block_sparse_config.block_size
        )
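A usage sketch under assumptions: the batch size, head count, and head dim (B = 1, H = 8, D = 64) are made up, a CUDA device is assumed for the Triton kernel, and `BaseMBSAttnKernel`, `BlockSparseConfig`, and `check_maybe_fix_all_inputs` come from the surrounding codebase and are not reproduced here.

import torch

# Hypothetical shapes: B = 1, H = 8, D = 64; S = 50 + 375 + 500 = 925.
kernel = TritonMBSAttnKernel(test_topology, mm_seq_seg, causal_mask=False, dropout=0.0)
q = torch.randn(1, 8, 925, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = kernel.forward(q, k, v)  # same layout as q, up to any padding the input check applies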

The corresponding low-level operator function class looks like this:

class block_sparse_attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, 
                block_sparse_mask: torch.Tensor, dense_mask: torch.Tensor, 
                softmax_scale: float, block_size: int):
        return block_sparse_attn_fwd(
            ctx, q, k, v, block_sparse_mask, dense_mask, softmax_scale, block_size)

    @staticmethod
    def backward(ctx, do):
        # TODO: Implement backward pass
        raise NotImplementedError("It does not support gradient propagation yet")

For now only the forward pass is implemented; the backward pass will have to wait.

0x2. Data Preparation

BLOCK_M = BLOCK_N = block_size
GRID = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])  # (ceil(S / BLOCK_M), B * H)
HEAD_DIM = q.shape[-1]
N_CTX = k.shape[2]             # key/value context length
PAST_LEN = N_CTX - q.shape[2]  # extra key/value length beyond the queries (0 when S_q == S_kv)
N_HEADS = q.shape[1]
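Plugging the running example into the grid computation (again with the hypothetical block_size = 64 and the made-up B = 1, H = 8) makes the launch shape concrete:

import triton

S, B, H = 925, 1, 8      # S from the example; B and H are made up
block_size = 64          # hypothetical
print(triton.cdiv(S, block_size))            # 15 query blocks
print((triton.cdiv(S, block_size), B * H))   # (15, 8) -> the launch grid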