ViT实现
Patch to Embedding
首先,对Patch to Embedding这个行为做一个简单理解。
![patch to embedding示意](ViT实现/Patch to Embedding示意图.png)
这一动作是对标Transformer中Word Embedding的,但是是把一张图像分成多个Embedding。Patch to Embedding的过程是可学习的。
具体理解可以参考https://blog.csdn.net/lzzzzzzm/article/details/122902777这篇文章,接下来只说代码实现。
在Patch to Embedding中,一般会分为两种实现方法。
其一,分块+权重矩阵得到Embedding的结果
具体思路如下:
# 每个patch会被映射为一个embedding |
运用unfold进行分块,为什么最后要交换呢?
单纯unfold得到patch的shape会是(batch_size,patch_size*patch_size*channels,patch_nums)
,对图像分块,其实没有了channels的维度,只会有像素数量,交换后可以与权重矩阵做进一步的乘法。
其二,卷积方法得到Embedding的结果。目前使用的就是这种
具体思路如下:
def image2emb_conv(image, kernel, stride): |
总结:从Patch to Embedding的过程能够知道,一张图像分块送入ViT中有一个前提条件,即image的长宽必须能除尽patch_size,在ViT的代码中,长和宽可以使用一个元组来说明,这时image的长必须除尽元组中代表的长,而image的宽必须除尽元组中代表的宽;另外,model_dim即输入到ViT中的Embedding的维数,个人理解model_dim的维数越多,能代表一个patch块中的信息量越大;为了更好的任务效果,这个映射过程在代码实现中实际上是可学习的。
Cls token & Positional Encoding
人为地添加一个与Patch Embedding相同维度的Embedding,该Embedding用于最终分类。
Positional Encoding用于标注patch的位置,具体需要额外了解。
参考文章:https://blog.csdn.net/chumingqian/article/details/124660657