摘要:从机制上论证:Mamba(SSM)更适合“长序列 + 自回归”的任务,而大多数视觉识别并不同时满足这两点,因此在 ImageNet 分类中不需要 SSM;同时提出对检测/分割这类长序列但非自回归任务,SSM 仍可能有价值。实证上验证:把视觉 Mamba 的核心
题目: MambaOut: Do WeReally Need Mamba for Vision?
论文地址:https://arxiv.org/pdf/2405.07992v1
从机制上论证:Mamba(SSM)更适合“长序列 + 自回归”的任务,而大多数视觉识别并不同时满足这两点,因此在 ImageNet 分类中不需要 SSM;同时提出对检测/分割这类长序列但非自回归任务,SSM 仍可能有价值。
实证上验证:把视觉 Mamba 的核心 SSM 拿掉,堆叠成“Gated CNN”块的系列模型——MambaOut,在 ImageNet 上全面超过多种视觉 Mamba 模型;而在 COCO/ADE20K 上略逊于最强的视觉 Mamba,印证上述假设
整体结构
MambaOut 是一个四阶段分层主干网络:输入先经轻量 Stem 与下采样进入各层级,每层由若干 Gated CNN 块组成;块内先做归一化与两支全连接门控,再在一部分通道上施加 7×7 深度卷积完成 token mixing,最后残差相加;跨阶段通过 3×3/stride=2 进行分辨率递降,网络末端用 GAP+Norm+MLP 输出分类结果MambaOut 模型基于 Gated CNN block 构建,而不是 Mamba block。
整体采用 ResNet 风格的 4-stage 分层架构 :
Stem:3×3 卷积 + Norm + 激活
Stage 1–4:堆叠 Gated CNN block,并在阶段间下采样
Classifier head:全局平均池化 + Norm + MLP
区别点 :Mamba block = Gated CNN + SSM,而 MambaOut 移除了 SSM,仅保留 Gated CNN
GatedCNNBlock 作为即插即用模块,适用于图像分类等非长序列任务,它通过卷积 + 门控机制实现高效的特征混合,在保证性能的同时减少复杂性,是 SSM 的轻量化替代方案
适用场景:
图像分类 :在不需要长序列建模和自回归的任务(如 ImageNet 分类)中,GatedCNNBlock 足以提供优秀的性能,甚至超过引入 SSM 的 Vision Mamba。
轻量化/高效模型 :相比引入 SSM 的 Mamba block,GatedCNNBlock 更简单、计算更高效,适合在移动端或算力受限环境中作为 backbone。
卷积替代注意力的结构 :在希望避免注意力机制高计算复杂度的情况下,GatedCNNBlock 可作为高效 token mixer 使用
模块作用:
核心作用 :作为 token mixer,用深度可分离卷积来建模局部与部分全局依赖关系。
结构优势 :
保留了卷积的归纳偏置(locality、平移不变性),适合视觉任务。
通过 gating(门控机制)增强特征选择能力,使得模型能够自动控制不同通道/区域的信息流。
避免了 SSM 带来的 因果约束 & 长序列建模开销 ,在短序列视觉任务(如分类)中更高效
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_, DropPath
class GatedCNNBlock(nn.Module):
r""" Our implementation of Gated CNN Block: https://arxiv.org/pdf/1612.08083
Args:
conv_ratio: control the number of channels to conduct depthwise convolution.
Conduct convolution on partial channels can improve practical efficiency.
The idea of partial channels is from ShuffleNet V2 (https://arxiv.org/abs/1807.11164) and
also used by InceptionNeXt (https://arxiv.org/abs/2303.16900) and FasterNet (https://arxiv.org/abs/2303.03667)
"""
def __init__(self, dim, expansion_ratio=8/3, kernel_size=7, conv_ratio=1.0,
norm_layer=partial(nn.LayerNorm,eps=1e-6),
act_layer=nn.GELU,
drop_path=0.,
**kwargs):
super.__init__
self.norm = norm_layer(dim)
hidden = int(expansion_ratio * dim)
self.fc1 = nn.Linear(dim, hidden * 2)
self.act = act_layer
conv_channels = int(conv_ratio * dim)
self.split_indices = (hidden, hidden - conv_channels, conv_channels)
self.conv = nn.Conv2d(conv_channels, conv_channels, kernel_size=kernel_size, padding=kernel_size//2, groups=conv_channels)
self.fc2 = nn.Linear(hidden, dim)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity
def forward(self, x):
shortcut = x # [B, H, W, C]
x = self.norm(x)
g, i, c = torch.split(self.fc1(x), self.split_indices, dim=-1)
c = c.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
c = self.conv(c)
0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C]
x = self.fc2(self.act(g) * torch.cat((i, c), dim=-1))
x = self.drop_path(x)
return x + shortcut
if __name__ == "__main__":
batch_size = 1 # Batch size
channels = 32 # 输入通道数
height = 256 # 输入图像高度
width = 256 # 输入图像宽度
# 创建一个模拟输入张量,形状为 (batch_size, height, width, channels)
x = torch.randn(batch_size, height, width, channels)
# 初始化 GatedCNNBlock 模块
model = GatedCNNBlock(dim=channels, expansion_ratio=8/3, kernel_size=7, conv_ratio=1.0, drop_path=0.1)
print(model)
# 前向传播
output = model(x)
# 打印输入和输出张量的形状
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
更多分析可见原文
来源:寂寞的咖啡