一、论文 

 1
题目Agent Attention: On the Integration of Softmax
and Linear Attention
地址:https:///pdf/2312.08874

二、概要总结 

 1

代理注意力动机的示意图。(a) 在Softmax注意力中,每个查询聚合来自所有特征的信息,导致二次复杂度。(b) 利用注意力权重之间的冗余,代理注意力使用少量代理标记作为查询的“代理”,捕获来自所有特征的多样化语义信息,然后将其呈现给每个查询。注意力权重来自DeiT-T和Agent-DeiT-T。

  • 现代Transformer模型通常采用Softmax自注意力机制,其计算复杂度与token数量的平方成正比,这在视觉任务中可能导致计算需求过高。

  •  为了解决这一问题,研究者们尝试通过设计高效的注意力模式来降低计算复杂度,但这些方法往往牺牲了模型的全局感受野和长距离关系建模能力。

  • 为了在计算效率和表示能力之间取得平衡,论文提出了一种新的注意力范式——Agent Attention。

三、方法 

 1

1. Agent Attention的提出:提出了一种新的注意力范式——Agent Attention,通过引入一组代理令牌(agent tokens)A,使得查询令牌(query tokens)Q能够通过代理令牌聚合来自键(keys)和值(values)的信息,然后将信息广播回每个查询令牌。

2. Agent Attention的计算过程:Agent Attention可以表示为四元组(Q, A, K, V),其中A作为查询进行第一次Softmax注意力计算以聚合代理特征VA,然后作为键进行第二次Softmax注意力计算以将全局信息广播给每个查询令牌,形成最终输出。

3. Agent Attention与Softmax和线性注意力的关系: Agent Attention实际上是Softmax注意力和线性注意力的一种优雅整合,它既保留了Softmax注意力的全局上下文建模能力,又具有线性注意力的高效计算复杂度。

四、实验分析 

1. 图像分类结果:Agent Attention在ImageNet-1K分类任务中显著提升了各种视觉Transformer模型的性能,特别是在高分辨率场景中表现突出。

2. 目标检测结果: 在COCO数据集上的目标检测实验表明,Agent Attention能够一致地增强不同配置下的性能。

3. 语义分割结果: 在ADE 20K数据集上的语义分割实验显示,Agent Attention与各种分割模型兼容,并且能够持续提升性能。

【ECCV 2024】新注意力范式——Agent Attention,整合Softmax与线性注意力

4. 稳定扩散结果: Agent Attention在稳定扩散模型中的应用显著加速了图像生成过程,并在无需额外训练的情况下提高了生成质量。

5. 大感受野和高分辨率结果: Agent Attention能够利用全局感受野,同时保持与Softmax注意力相同的计算复杂度,这在高分辨率场景中尤为重要。

五、代码 

 1
import torchimport torch.nn as nnfrom timm.models.layers import trunc_normal_
# 论文:Agent Attention: On the Integration of Softmax and Linear Attention# 论文地址:https:///pdf/2312.08874# 微信公众号:AI缝合术'''2024年全网最全即插即用模块,全部免费!包含各种卷积变种、最新注意力机制、特征融合模块、上下采样模块,适用于人工智能(AI)、深度学习、计算机视觉(CV)领域,适用于图像分类、目标检测、实例分割、语义分割、单目标跟踪(SOT)、多目标跟踪(MOT)、红外与可见光图像融合跟踪(RGBT)、图像去噪、去雨、去雾、去模糊、超分等任务,模块库持续更新中......https://github.com/AIFengheshu/Plug-play-modules'''class AgentAttention(nn.Module): def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, agent_num=49, **kwargs): super().__init__() assert dim % num_heads == 0, f'dim {dim} should be divided by num_heads {num_heads}.'
self.dim = dim self.num_patches = num_patches window_size = (int(num_patches ** 0.5), int(num_patches ** 0.5)) self.window_size = window_size self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias) self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio if sr_ratio > 1: self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) self.norm = nn.LayerNorm(dim)
self.agent_num = agent_num self.dwc = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(3, 3), padding=1, groups=dim) self.an_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7)) self.na_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7)) self.ah_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, window_size[0] // sr_ratio, 1)) self.aw_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, 1, window_size[1] // sr_ratio)) self.ha_bias = nn.Parameter(torch.zeros(1, num_heads, window_size[0], 1, agent_num)) self.wa_bias = nn.Parameter(torch.zeros(1, num_heads, 1, window_size[1], agent_num)) trunc_normal_(self.an_bias, std=.02) trunc_normal_(self.na_bias, std=.02) trunc_normal_(self.ah_bias, std=.02) trunc_normal_(self.aw_bias, std=.02) trunc_normal_(self.ha_bias, std=.02) trunc_normal_(self.wa_bias, std=.02) pool_size = int(agent_num ** 0.5) self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)) self.softmax = nn.Softmax(dim=-1)
def forward(self, x, H, W): b, n, c = x.shape num_heads = self.num_heads head_dim = c // num_heads q = self.q(x)
if self.sr_ratio > 1: x_ = x.permute(0, 2, 1).reshape(b, c, H, W) x_ = self.sr(x_).reshape(b, c, -1).permute(0, 2, 1) x_ = self.norm(x_) kv = self.kv(x_).reshape(b, -1, 2, c).permute(2, 0, 1, 3) else: kv = self.kv(x).reshape(b, -1, 2, c).permute(2, 0, 1, 3) k, v = kv[0], kv[1]
agent_tokens = self.pool(q.reshape(b, H, W, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1) q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) k = k.reshape(b, n // self.sr_ratio ** 2, num_heads, head_dim).permute(0, 2, 1, 3) v = v.reshape(b, n // self.sr_ratio ** 2, num_heads, head_dim).permute(0, 2, 1, 3) agent_tokens = agent_tokens.reshape(b, self.agent_num, num_heads, head_dim).permute(0, 2, 1, 3)
kv_size = (self.window_size[0] // self.sr_ratio, self.window_size[1] // self.sr_ratio) position_bias1 = nn.functional.interpolate(self.an_bias, size=kv_size, mode='bilinear') position_bias1 = position_bias1.reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1) position_bias2 = (self.ah_bias + self.aw_bias).reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1) position_bias = position_bias1 + position_bias2 agent_attn = self.softmax((agent_tokens * self.scale) @ k.transpose(-2, -1) + position_bias) agent_attn = self.attn_drop(agent_attn) agent_v = agent_attn @ v
agent_bias1 = nn.functional.interpolate(self.na_bias, size=self.window_size, mode='bilinear') agent_bias1 = agent_bias1.reshape(1, num_heads, self.agent_num, -1).permute(0, 1, 3, 2).repeat(b, 1, 1, 1) agent_bias2 = (self.ha_bias + self.wa_bias).reshape(1, num_heads, -1, self.agent_num).repeat(b, 1, 1, 1) agent_bias = agent_bias1 + agent_bias2 q_attn = self.softmax((q * self.scale) @ agent_tokens.transpose(-2, -1) + agent_bias) q_attn = self.attn_drop(q_attn) x = q_attn @ agent_v
x = x.transpose(1, 2).reshape(b, n, c) v = v.transpose(1, 2).reshape(b, H // self.sr_ratio, W // self.sr_ratio, c).permute(0, 3, 1, 2) if self.sr_ratio > 1: v = nn.functional.interpolate(v, size=(H, W), mode='bilinear') x = x + self.dwc(v).permute(0, 2, 3, 1).reshape(b, n, c)
x = self.proj(x) x = self.proj_drop(x) return x
if __name__ == '__main__': dim = 64 num_patches = 49
block = AgentAttention(dim=dim, num_patches=num_patches)
H, W = 7, 7 x = torch.rand(1, num_patches, dim)
# Forward pass output = block(x, H, W) print(f'Input size: {x.size()}')    print(f'Output size: {output.size()}')

便捷下载

https://github.com/AIFengheshu/Plug-play-modules

AI缝合术