We know that Flash Attention performs block-level operator fusion over blocks of a given size.
Given the sparsity of the attention map, we can usually skip the parts that attention itself already ignores. Those parts are typically quite fine-grained, yet they still tend to cluster, so we can approximate them with a coarser block-sparse mask.
Approximate that mask well and you can ride on Flash Attention's coattails and crank out a whole batch of papers (just kidding).
Starting from the official Triton example, we can write this out ourselves.
My implementation is not a Top-K Block Sparse Mask (BSM) but a BSM over a custom concatenation of heterogeneous sequences; I call it the Interlaced Mask (IM).
You can picture it as a group of multimodal sequences with text (T), vision (V), and audio (A), each of shape [B, H, S_m, D] with m ∈ {T, V, A}. After concatenating along the sequence dimension, the three modalities form a 3 × 3 coarse topology. This topology can be defined externally; we call it the modal-wise topology (Modal-wise Topology). A fairly common shape looks like this:
test_topology = [
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
]
The sequence lengths are then defined as follows:
mm_seq_seg = (50, 375, 500)  # assume this is (T, V, A)
From the modal-wise topology and the sequence lengths we can generate a dense mask (Dense Mask, DM). Why generate a DM at all? Because a custom topology with custom sequence lengths, unlike a Top-K estimated sparse graph, very often fails to divide evenly into the compute block size. If the DM is not folded into the computation, the attention map is bound to come out wrong. Of course, this wastes some computation, but that is the price of customization.
from typing import List

import torch


def generate_interlaced_mask(topology: List[List[int]],
                             mm_seq_seg: List[int]) -> torch.Tensor:
    seq_len = sum(mm_seq_seg)
    mask = torch.zeros((seq_len, seq_len), dtype=torch.bool)
    for row_idx, b_row in enumerate(topology):
        for col_idx, b_col_elem in enumerate(b_row):
            if b_col_elem == 1:
                # Each 1 in the modal-wise topology maps to a rectangular
                # block of the dense mask.
                row_start = sum(mm_seq_seg[:row_idx])
                row_end = row_start + mm_seq_seg[row_idx]
                col_start = sum(mm_seq_seg[:col_idx])
                col_end = col_start + mm_seq_seg[col_idx]
                mask[row_start:row_end, col_start:col_end] = True
    return mask
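As a quick sanity check, here is a minimal sketch reusing test_topology and mm_seq_seg from above; the printed values follow directly from the block ranges of the example topology:
# Minimal sanity check of generate_interlaced_mask, reusing the example
# topology and segmentation from this post.
test_topology = [
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
]
mm_seq_seg = (50, 375, 500)  # (T, V, A)

dense_mask = generate_interlaced_mask(test_topology, list(mm_seq_seg))
print(dense_mask.shape)                      # torch.Size([925, 925])
print(dense_mask[:50, 50:425].all().item())  # True  -> T attends to V
print(dense_mask[:50, :50].any().item())     # False -> T does not attend to T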
Correspondingly, we need a BSM built on top of this DM; in the final kernel the BSM and the DM work together to keep the computation both efficient and correct. Here I give the algorithm class directly so it is easier to follow.
class TritonMBSAttnKernel(BaseMBSAttnKernel):
    def __init__(self, topology: List[List[int]], mm_seq_seg: Tuple[int, int, int],
                 causal_mask: bool, dropout: float):
        super().__init__(topology, mm_seq_seg, causal_mask, dropout)
        self.dense_mask = generate_interlaced_mask(self.topology, self.mm_seq_seg)
        self.block_sparse_config = BlockSparseConfig(min_seq_len=min(self.mm_seq_seg))
        self.block_sparse_mask = self.__generate_block_sparse_mask()
        self.block_sparse_attn_fn = block_sparse_attention.apply

    def __generate_block_sparse_mask(self):
        # NOTE: Not using a Top-K sparse mask
        block_size = self.block_sparse_config.block_size
        block_sparse_mask = self.dense_mask.float()
        seq_len = self.dense_mask.shape[0]
        # Pad so the sequence length divides the block size.
        pad_len = (block_size - seq_len % block_size) % block_size
        if pad_len > 0:
            block_sparse_mask = F.pad(
                block_sparse_mask, (0, pad_len, 0, pad_len),
                mode="constant", value=False)
        seq_len = block_sparse_mask.shape[0]
        num_blocks = seq_len // block_size
        # A block is kept if any of its elements is visible in the dense mask.
        block_sparse_mask = block_sparse_mask.reshape(
            num_blocks, block_size, num_blocks, block_size
        ).any(dim=(1, 3))
        return block_sparse_mask

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        q, k, v = check_maybe_fix_all_inputs(q, k, v, self.block_sparse_config)
        self.block_sparse_mask = self.block_sparse_mask.to(q.device)
        self.dense_mask = self.dense_mask.to(q.device)
        return self.block_sparse_attn_fn(
            q, k, v, self.block_sparse_mask, self.dense_mask, 1.0 / (q.shape[-1] ** 0.5),
            self.block_sparse_config.block_size
        )
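A hypothetical usage sketch follows; the batch size, head count, head dim, and the fp16/CUDA setup are my own assumptions, and BaseMBSAttnKernel, BlockSparseConfig, and check_maybe_fix_all_inputs are assumed to be importable from the same module:
# Hypothetical usage sketch; B, H, D are illustrative values, and a CUDA
# device is assumed because the underlying kernel is written in Triton.
import torch

kernel = TritonMBSAttnKernel(
    topology=test_topology,
    mm_seq_seg=(50, 375, 500),
    causal_mask=False,
    dropout=0.0,
)

B, H, S, D = 2, 8, sum((50, 375, 500)), 64  # S = 925
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

out = kernel.forward(q, k, v)
print(out.shape)  # (B, H, S, D), up to any padding done by check_maybe_fix_all_inputs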
The corresponding low-level operator function class is as follows:
class block_sparse_attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                block_sparse_mask: torch.Tensor, dense_mask: torch.Tensor,
                softmax_scale: float, block_size: int):
        return block_sparse_attn_fwd(
            ctx, q, k, v, block_sparse_mask, dense_mask, softmax_scale, block_size)

    @staticmethod
    def backward(ctx, do):
        # TODO: Implement backward pass
        raise NotImplementedError("It does not support gradient propagation yet")
For now only the forward pass is implemented; backward will have to wait. Inside the forward path (block_sparse_attn_fwd), the kernel launch parameters are derived from the input shapes:
BLOCK_M = BLOCK_N = block_size
GRID = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])  # (ceil(S / BLOCK_M), B * H)
HEAD_DIM = q.shape[-1]
N_CTX = k.shape[2]
PAST_LEN = N_CTX - q.shape[2]  # nonzero only if k/v carry extra (past) context
N_HEADS = q.shape[1]
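Plugging in the example shapes from the sketch above (and a hypothetical block_size of 64, my assumption since BlockSparseConfig is not shown here), the numbers work out as:
# Worked numbers for the assumed shapes B=2, H=8, S=925, D=64 and a
# hypothetical block_size of 64; plain self-attention, so q and k share the length.
import math

B, H, S, D, block_size = 2, 8, 925, 64, 64

BLOCK_M = BLOCK_N = block_size           # 64
GRID = (math.ceil(S / BLOCK_M), B * H)   # (15, 16)
HEAD_DIM = D                             # 64
N_CTX = S                                # 925
PAST_LEN = N_CTX - S                     # 0, no KV cache in this example
N_HEADS = H                              # 8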