Deep Learning: Flower Recognition with AlexNet



1. Building the Model

model.py

import torch
import torch.nn as nn

class AlexNet(nn.Module):
    def __init__(self, num_classes=5, init_weights=False):
        super(AlexNet, self).__init__()
        self._config = {'num_classes': num_classes}
        self.features = nn.Sequential(
            # convolutional feature extractor: five conv + ReLU blocks with three max-pooling stages
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # adaptive average pooling to a fixed 6x6 output
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        # fully connected classifier
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, self._config['num_classes']),
        )
        if init_weights:
            self._initialize_weights()

    # forward pass
    def forward(self, x):
        x = self.features(x)  # convolutional feature extractor
        x = self.avgpool(x)  # adaptive average pooling
        x = torch.flatten(x, 1)  # flatten to (N, 256 * 6 * 6)
        x = self.classifier(x)  # fully connected classifier
        return x
        
    # weight initialization
    def _initialize_weights(self):
        for m in self.modules():  # iterate over all submodules
            if isinstance(m, nn.Conv2d):  # convolutional layers
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)  # zero bias
            elif isinstance(m, nn.Linear):  # fully connected layers
                nn.init.normal_(m.weight, 0, 0.01)  # weights from a normal distribution N(0, 0.01)
                nn.init.constant_(m.bias, 0)  # zero bias
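
Before training, it is worth confirming that the forward pass produces the expected output shape. The snippet below is a minimal sanity check, assuming the model.py above is on the import path:

import torch
from model import AlexNet

net = AlexNet(num_classes=5, init_weights=True)
dummy = torch.randn(1, 3, 224, 224)  # one dummy RGB image at 224x224
out = net(dummy)
print(out.shape)  # expected: torch.Size([1, 5])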


      

2. Training the Model

train.py

import os
import sys
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from model import AlexNet

def main():
    # use the GPU if available, otherwise fall back to the CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train":
            transforms.Compose([
                transforms.RandomResizedCrop(224),  # random crop and resize to 224x224
                transforms.RandomHorizontalFlip(),  # random horizontal flip for data augmentation
                transforms.ToTensor(),  # convert to a tensor
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]),
        "val":
            transforms.Compose([
                transforms.Resize((224, 224)),  # must be (224, 224); Resize(224) would only scale the shorter side
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
    }

    # dataset root path
    data_root = os.path.abspath(os.path.join(os.getcwd(), "dataset"))
    image_path = os.path.join(data_root, "flower_photos")  # flower data set path

    # Load the whole dataset. Note: random_split below shares this single ImageFolder,
    # so the validation subset also uses the training transform.
    dataset = datasets.ImageFolder(root=image_path, transform=data_transform["train"])
    train_num = len(dataset)  # total number of images

    # class name -> index, i.e. {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = dataset.class_to_idx  # get the index for each class name
    cla_dict = dict((val, key) for key, val in flower_list.items())  # invert to index -> class name
    # write the mapping to a JSON file
    json_str = json.dumps(cla_dict, indent=4)
    os.makedirs('flower_AlexNet', exist_ok=True)  # make sure the output directory exists
    with open('flower_AlexNet/class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    # split the dataset into training (80%) and validation (20%) subsets
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, validate_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    batch_size = 32

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=4, shuffle=False, num_workers=0)

    print("using {} images for training, {} images for validation.".format(train_size, val_size))

    net = AlexNet(num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = './flower_AlexNet/AlexNet.pth'  # where the best weights are saved
    best_acc = 0.0  # track the best validation accuracy so only the best model is kept
    train_steps = len(train_loader)

    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        # validate
        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_size
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' % (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')

if __name__ == '__main__':
    main()
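
matplotlib and numpy are imported above but never used. One natural extension is to collect the per-epoch training loss and validation accuracy into two lists inside the epoch loop and plot them after training. The helper below is a minimal sketch; the list names train_losses and val_accuracies are assumptions, not variables in the script above:

import matplotlib.pyplot as plt

def plot_curves(train_losses, val_accuracies):
    # train_losses / val_accuracies: one value appended per epoch (assumed lists)
    epochs_range = range(1, len(train_losses) + 1)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    ax1.plot(epochs_range, train_losses)
    ax1.set_xlabel("epoch")
    ax1.set_ylabel("train loss")
    ax2.plot(epochs_range, val_accuracies)
    ax2.set_xlabel("epoch")
    ax2.set_ylabel("validation accuracy")
    plt.tight_layout()
    plt.show()

Appending running_loss / train_steps and val_accurate once per epoch and calling plot_curves after the loop reproduces the usual loss and accuracy curves.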

3. Model Prediction

predict.py

import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import AlexNet

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(  # same preprocessing as the validation transform
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "./dataset/flower_photos/dandelion/8223949_2928d3f6f6_n.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)  # load the image with PIL

    plt.imshow(img)  # show the image
    # [N, C, H, W]
    img = data_transform(img)  # apply the preprocessing pipeline
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)  # add a batch dimension: [C, H, W] -> [1, C, H, W]

    # read the index -> class name mapping written by train.py
    json_path = './flower_AlexNet/class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)  # decode the JSON file into a dict

    # create model
    model = AlexNet(num_classes=5).to(device)  # build the network

    # load model weights
    weights_path = "./flower_AlexNet/AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))  # load the trained weights

    model.eval()  # eval mode disables dropout
    with torch.no_grad():  # no gradient tracking during inference
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()  # forward pass through the model
        # squeeze removes the batch dimension from the output
        predict = torch.softmax(output, dim=0)  # softmax gives a probability distribution
        predict_cla = torch.argmax(predict).numpy()  # index of the highest probability

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    # show the predicted class and its probability in the figure title
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
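
predict.py scores a single hard-coded image. For quick experiments it can be handy to score every image in a folder at once; the function below is a minimal sketch that reuses the weights and class_indices.json produced by train.py. The folder argument img_dir is hypothetical and is assumed to contain only image files:

import os
import json
import torch
from PIL import Image
from torchvision import transforms
from model import AlexNet

def batch_predict(img_dir, weights_path="./flower_AlexNet/AlexNet.pth",
                  json_path="./flower_AlexNet/class_indices.json"):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tf = transforms.Compose([transforms.Resize((224, 224)),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    with open(json_path) as f:
        class_indict = json.load(f)
    model = AlexNet(num_classes=5).to(device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    with torch.no_grad():
        for name in sorted(os.listdir(img_dir)):
            img = tf(Image.open(os.path.join(img_dir, name)).convert("RGB"))
            prob = torch.softmax(model(img.unsqueeze(0).to(device)).squeeze(0), dim=0)
            idx = int(torch.argmax(prob))
            print("{}: {} ({:.3f})".format(name, class_indict[str(idx)], prob[idx].item()))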

4. Project Directory and Dataset

The project directory and the dataset folder sit at the same level.

4.1 Directory

(screenshot of the project directory)
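
The layout can also be checked programmatically; the sketch below only tests the paths that train.py and predict.py actually use, relative to the working directory:

import os

# paths referenced by train.py and predict.py
expected_dirs = [
    "dataset/flower_photos",   # one subfolder per flower class
    "flower_AlexNet",          # holds class_indices.json and AlexNet.pth after training
]

for path in expected_dirs:
    status = "found" if os.path.isdir(path) else "missing"
    print("{}: {}".format(path, status))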

4.2 Dataset

File shared via Baidu Netdisk:
Link: https://pan.baidu.com/s/11K54YfxGA0GOeC4G5FtfTg?pwd=rgun

