一、论文
and Linear Attention
二、概要总结
代理注意力动机的示意图。(a) 在Softmax注意力中,每个查询聚合来自所有特征的信息,导致二次复杂度。(b) 利用注意力权重之间的冗余,代理注意力使用少量代理标记作为查询的“代理”,捕获来自所有特征的多样化语义信息,然后将其呈现给每个查询。注意力权重来自DeiT-T和Agent-DeiT-T。
-
现代Transformer模型通常采用Softmax自注意力机制,其计算复杂度与token数量的平方成正比,这在视觉任务中可能导致计算需求过高。
-
为了解决这一问题,研究者们尝试通过设计高效的注意力模式来降低计算复杂度,但这些方法往往牺牲了模型的全局感受野和长距离关系建模能力。
-
为了在计算效率和表示能力之间取得平衡,论文提出了一种新的注意力范式——Agent Attention。
三、方法
1. Agent Attention的提出:提出了一种新的注意力范式——Agent Attention,通过引入一组代理令牌(agent tokens)A,使得查询令牌(query tokens)Q能够通过代理令牌聚合来自键(keys)和值(values)的信息,然后将信息广播回每个查询令牌。
2. Agent Attention的计算过程:Agent Attention可以表示为四元组(Q, A, K, V),其中A作为查询进行第一次Softmax注意力计算以聚合代理特征VA,然后作为键进行第二次Softmax注意力计算以将全局信息广播给每个查询令牌,形成最终输出。
3. Agent Attention与Softmax和线性注意力的关系: Agent Attention实际上是Softmax注意力和线性注意力的一种优雅整合,它既保留了Softmax注意力的全局上下文建模能力,又具有线性注意力的高效计算复杂度。
四、实验分析
1. 图像分类结果:Agent Attention在ImageNet-1K分类任务中显著提升了各种视觉Transformer模型的性能,特别是在高分辨率场景中表现突出。
2. 目标检测结果: 在COCO数据集上的目标检测实验表明,Agent Attention能够一致地增强不同配置下的性能。
3. 语义分割结果: 在ADE 20K数据集上的语义分割实验显示,Agent Attention与各种分割模型兼容,并且能够持续提升性能。
4. 稳定扩散结果: Agent Attention在稳定扩散模型中的应用显著加速了图像生成过程,并在无需额外训练的情况下提高了生成质量。
5. 大感受野和高分辨率结果: Agent Attention能够利用全局感受野,同时保持与Softmax注意力相同的计算复杂度,这在高分辨率场景中尤为重要。
五、代码
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
# 论文:Agent Attention: On the Integration of Softmax and Linear Attention
# 论文地址:https:///pdf/2312.08874
# 微信公众号:AI缝合术
'''
2024年全网最全即插即用模块,全部免费!包含各种卷积变种、最新注意力机制、特征融合模块、上下采样模块,
适用于人工智能(AI)、深度学习、计算机视觉(CV)领域,适用于图像分类、目标检测、实例分割、语义分割、
单目标跟踪(SOT)、多目标跟踪(MOT)、红外与可见光图像融合跟踪(RGBT)、图像去噪、去雨、去雾、去模糊、超分等任务,
模块库持续更新中......
https://github.com/AIFengheshu/Plug-play-modules
'''
class AgentAttention(nn.Module):
def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
sr_ratio=1, agent_num=49, **kwargs):
super().__init__()
assert dim % num_heads == 0, f'dim {dim} should be divided by num_heads {num_heads}.'
self.dim = dim
self.num_patches = num_patches
window_size = (int(num_patches ** 0.5), int(num_patches ** 0.5))
self.window_size = window_size
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
self.agent_num = agent_num
self.dwc = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(3, 3), padding=1, groups=dim)
self.an_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))
self.na_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))
self.ah_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, window_size[0] // sr_ratio, 1))
self.aw_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, 1, window_size[1] // sr_ratio))
self.ha_bias = nn.Parameter(torch.zeros(1, num_heads, window_size[0], 1, agent_num))
self.wa_bias = nn.Parameter(torch.zeros(1, num_heads, 1, window_size[1], agent_num))
trunc_normal_(self.an_bias, std=.02)
trunc_normal_(self.na_bias, std=.02)
trunc_normal_(self.ah_bias, std=.02)
trunc_normal_(self.aw_bias, std=.02)
trunc_normal_(self.ha_bias, std=.02)
trunc_normal_(self.wa_bias, std=.02)
pool_size = int(agent_num ** 0.5)
self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, H, W):
b, n, c = x.shape
num_heads = self.num_heads
head_dim = c // num_heads
q = self.q(x)
if self.sr_ratio > 1:
x_ = x.permute(0, 2, 1).reshape(b, c, H, W)
x_ = self.sr(x_).reshape(b, c, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(b, -1, 2, c).permute(2, 0, 1, 3)
else:
kv = self.kv(x).reshape(b, -1, 2, c).permute(2, 0, 1, 3)
k, v = kv[0], kv[1]
agent_tokens = self.pool(q.reshape(b, H, W, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1)
q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
k = k.reshape(b, n // self.sr_ratio ** 2, num_heads, head_dim).permute(0, 2, 1, 3)
v = v.reshape(b, n // self.sr_ratio ** 2, num_heads, head_dim).permute(0, 2, 1, 3)
agent_tokens = agent_tokens.reshape(b, self.agent_num, num_heads, head_dim).permute(0, 2, 1, 3)
kv_size = (self.window_size[0] // self.sr_ratio, self.window_size[1] // self.sr_ratio)
position_bias1 = nn.functional.interpolate(self.an_bias, size=kv_size, mode='bilinear')
position_bias1 = position_bias1.reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)
position_bias2 = (self.ah_bias + self.aw_bias).reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)
position_bias = position_bias1 + position_bias2
agent_attn = self.softmax((agent_tokens * self.scale) @ k.transpose(-2, -1) + position_bias)
agent_attn = self.attn_drop(agent_attn)
agent_v = agent_attn @ v
agent_bias1 = nn.functional.interpolate(self.na_bias, size=self.window_size, mode='bilinear')
agent_bias1 = agent_bias1.reshape(1, num_heads, self.agent_num, -1).permute(0, 1, 3, 2).repeat(b, 1, 1, 1)
agent_bias2 = (self.ha_bias + self.wa_bias).reshape(1, num_heads, -1, self.agent_num).repeat(b, 1, 1, 1)
agent_bias = agent_bias1 + agent_bias2
q_attn = self.softmax((q * self.scale) @ agent_tokens.transpose(-2, -1) + agent_bias)
q_attn = self.attn_drop(q_attn)
x = q_attn @ agent_v
x = x.transpose(1, 2).reshape(b, n, c)
v = v.transpose(1, 2).reshape(b, H // self.sr_ratio, W // self.sr_ratio, c).permute(0, 3, 1, 2)
if self.sr_ratio > 1:
v = nn.functional.interpolate(v, size=(H, W), mode='bilinear')
x = x + self.dwc(v).permute(0, 2, 3, 1).reshape(b, n, c)
x = self.proj(x)
x = self.proj_drop(x)
return x
if __name__ == '__main__':
dim = 64
num_patches = 49
block = AgentAttention(dim=dim, num_patches=num_patches)
H, W = 7, 7
x = torch.rand(1, num_patches, dim)
# Forward pass
output = block(x, H, W)
print(f'Input size: {x.size()}')
print(f'Output size: {output.size()}')
便捷下载
https://github.com/AIFengheshu/Plug-play-modules
AI缝合术