Patch to Embedding

首先,对Patch to Embedding这个行为做一个简单理解。

![patch to embedding示意](ViT实现/Patch to Embedding示意图.png)

这一动作是对标Transformer中Word Embedding的,但是是把一张图像分成多个Embedding。Patch to Embedding的过程是可学习的。

具体理解可以参考https://blog.csdn.net/lzzzzzzm/article/details/122902777这篇文章,接下来只说代码实现。

在Patch to Embedding中,一般会分为两种实现方法。

其一,分块+权重矩阵得到Embedding的结果

具体思路如下:

# 每个patch会被映射为一个embedding
def image2emb_naive(image, patch_size, weight):
# image shape : [batch_size,channels,height,width]
patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
patch_embedding = patch @ weight
return patch_embedding

bs, input_channel, image_h, image_w = 1, 3, 8, 8
image = torch.randn((bs, input_channel, image_h, image_w))
patch_size = 4
model_dim = 8
patch_depth = patch_size * patch_size * input_channel
# model_dim是输出通道数目,patch_depth是卷积核面积*输入通道数
weight = torch.randn(patch_depth, model_dim)
# 分块方法得到embedding
patch_embedding_naive = image2emb_naive(image, patch_size, weight)

运用unfold进行分块,为什么最后要交换呢?

单纯unfold得到patch的shape会是(batch_size,patch_size*patch_size*channels,patch_nums),对图像分块,其实没有了channels的维度,只会有像素数量,交换后可以与权重矩阵做进一步的乘法。

其二,卷积方法得到Embedding的结果。目前使用的就是这种

具体思路如下:

def image2emb_conv(image, kernel, stride):
# conv_output shape = [batch_size,model_dim,output_height,output_width],最后把output_height和output_width使用flatten拉伸
conv_output = F.conv2d(image, kernel, stride=stride)
print(conv_output.shape) # shape[1,8,2,2]
bs, oc, oh, ow = conv_output.shape
patch_embedding = conv_output.reshape((bs, oc, oh * ow)).transpose(-1, -2)
return patch_embedding

bs, input_channel, image_h, image_w = 1, 3, 8, 8
image = torch.randn((bs, input_channel, image_h, image_w))
patch_size = 4
model_dim = 8
patch_depth = patch_size * patch_size * input_channel
# model_dim是输出通道数目,patch_depth是卷积核面积*输入通道数
weight = torch.randn(patch_depth, model_dim)

# kernel shape = [model_dim,input_channel,kernel_height,kernel_width]
kernel = weight.transpose(0, 1).reshape((-1, input_channel, patch_size, patch_size))
# 得到kernel的shape应该是[8,3,4,4],image shape为[1,3,8,8] stride=4
patch_embedding_conv = image2emb_conv(image, kernel=kernel, stride=patch_size)

总结:从Patch to Embedding的过程能够知道,一张图像分块送入ViT中有一个前提条件,即image的长宽必须能除尽patch_size,在ViT的代码中,长和宽可以使用一个元组来说明,这时image的长必须除尽元组中代表的长,而image的宽必须除尽元组中代表的宽;另外,model_dim即输入到ViT中的Embedding的维数,个人理解model_dim的维数越多,能代表一个patch块中的信息量越大;为了更好的任务效果,这个映射过程在代码实现中实际上是可学习的。

Cls token & Positional Encoding

人为地添加一个与Patch Embedding相同维度的Embedding,该Embedding用于最终分类。

Positional Encoding用于标注patch的位置,具体需要额外了解。

参考文章:https://blog.csdn.net/chumingqian/article/details/124660657