《昇思 25 天学习打卡营第 12 天 | Vision Transformer(ViT)图像分类 》
活动地址:https://xihe.mindspore.cn/events/mindspore-training-camp
签名:Sam9029
今天有点忙,内容不能仔细看了,但是上次的 ResNet网络图像分类和这次的
Vision Transformer 是不同的技术栈, 对同一类问题的解决方案
Vision Transformer(ViT)简介
Vision Transformer(ViT)是一种结合了自然语言处理中流行的Transformer模型和计算机视觉任务的深度学习模型。
它不依赖于传统的卷积神经网络(CNN)结构,而是直接使用Transformer的Encoder部分来处理图像数据。
环境准备与数据读取
在开始之前,确保安装了Python和MindSpore。MindSpore是一个开源的深度学习框架,由华为提供。
!pip install mindspore==2.2.14 -i https://pypi.mirrors.ustc.edu.cn/simple
接着,下载并准备数据集。本案例使用的数据集是ImageNet的一个子集。
from download import download
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = download(dataset_url, "./", kind="zip", replace=True)
数据增强
使用MindSpore的数据增强功能来准备训练数据。
from mindspore.dataset.vision import transforms
trans_train = [
transforms.RandomCropDecodeResize(size=224),
transforms.RandomHorizontalFlip(prob=0.5),
transforms.Normalize(mean=mean, std=std),
transforms.HWC2CHW()
]
ViT模型结构
ViT模型的核心是Transformer Encoder,它通过将图像分割成多个小块(Patch),然后将这些小块视为序列元素进行处理。
Patch Embedding
图像首先被分割成小块,每个小块被拉伸成一维向量。
class PatchEmbedding(nn.Cell):
def __init__(self, image_size=224, patch_size=16, embed_dim=768, input_channels=3):
super(PatchEmbedding, self).__init__()
self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
Transformer Encoder
Transformer Encoder由多头注意力(Multi-Head Attention)和Feed Forward网络组成。
class TransformerEncoder(nn.Cell):
def __init__(self, dim, num_layers, num_heads, mlp_dim, keep_prob=1.0):
super(TransformerEncoder, self).__init__()
self.layers = nn.SequentialCell([
ResidualCell(Attention(dim, num_heads, keep_prob)),
ResidualCell(FeedForward(dim, mlp_dim, keep_prob))
])
模型训练
设置损失函数、优化器和回调函数,然后开始训练模型。
from mindspore.train import Model, LossMonitor, CheckpointConfig, ModelCheckpoint
# define model, loss function, optimizer, and callbacks
# ...
# start training
model.train(epoch_size, dataset_train, callbacks=[ckpt_callback, LossMonitor(), TimeMonitor()])
模型验证与推理
使用验证集评估模型性能,并通过推理来测试模型对新数据的预测能力。
# evaluate model
result = model.eval(dataset_val)
# inference
for image in dataset_infer.create_dict_iterator(output_numpy=True):
prob = model.predict(ms.Tensor(image["image"]))
label = np.argmax(prob.asnumpy(), axis=1)
print(f"Predicted label: {label}")
思考
通过学习ViT模型,我了解到深度学习中对于图像分类这一类问题是由不同的解决方案的。
ResNet 是 采用 全卷积网络
ViT 是 采用 Transformer模型
ViT模型突破了传统CNN的限制,展示了Transformer结构在图像处理上的巨大潜力。
然而,ViT模型的训练和推理需要大量的计算资源。对于资源有限的小白来说,可能需要寻找简化模型或者使用预训练模型的方法来降低门槛。