迁移学习和微调(tensorflow)

发布于:2023-01-22 ⋅ 阅读:(12) ⋅ 点赞:(0) ⋅ 评论:(0)
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
import matplotlib.pyplot as plt
from tensorflow.keras import layers



# Load the tf_flowers dataset as (image, label) pairs.
tfds.disable_progress_bar()

# tf_flowers only provides a "train" split, so train / validation / test are
# carved out of it with TFDS slice syntax. The three slices are contiguous and
# together cover the whole split (70% / 15% / 15%), so no example is left
# unused and none is shared between the subsets.
train_ds, validation_ds, test_ds = tfds.load(
    "tf_flowers",
    split=["train[:70%]", "train[70%:85%]", "train[85%:]"],
    as_supervised=True,  # Include labels
)

# Resize all images to 160x160; this must match the model's input shape.
size = (160, 160)
AUTOTUNE = tf.data.AUTOTUNE


def _resize(image, label):
    """Resize *image* to ``size`` and pass *label* through unchanged."""
    return tf.image.resize(image, size), label


train_ds = train_ds.map(_resize, num_parallel_calls=AUTOTUNE)
validation_ds = validation_ds.map(_resize, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.map(_resize, num_parallel_calls=AUTOTUNE)

# Batch the data and use caching and prefetching to speed up loading.
# The training set is additionally shuffled *after* `cache()`: the resized
# images are cached once, but the example order is re-randomized every epoch.
# Without shuffling, every epoch would see identical batches in identical order.
batch_size = 32
shuffle_buffer = 1000

train_ds = (
    train_ds.cache()
    .shuffle(shuffle_buffer)
    .batch(batch_size)
    .prefetch(buffer_size=AUTOTUNE)
)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=AUTOTUNE)

# When you do not have a large image dataset, it is good practice to
# artificially introduce sample diversity by applying random yet realistic
# transformations (e.g. random horizontal flips or small random rotations)
# to the training images. This exposes the model to different aspects of the
# training data and slows down overfitting.
#
# `layers.RandomFlip` / `layers.RandomRotation` are the stable names of these
# preprocessing layers; the `layers.experimental.preprocessing.*` aliases are
# deprecated and no longer exist in newer Keras releases.
data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        # Factor 0.1 is a fraction of a full turn (2*pi), per the Keras docs.
        layers.RandomRotation(0.1),
    ]
)

# MobileNetV2 backbone with ImageNet weights and without its ImageNet
# classifier head. The input shape is derived from `size` so that it always
# agrees with the resize step of the data pipeline.
base_model = keras.applications.MobileNetV2(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=size + (3,),
    include_top=False,  # Do not include the ImageNet classifier at the top.
)

# Freeze the base model: in the first stage only the new head is trained.
base_model.trainable = False

# One output logit per tf_flowers class.
num_classes = 5

# Create the new model on top of the frozen base.
inputs = keras.Input(shape=size + (3,))
x = data_augmentation(inputs)  # Apply random data augmentation
# MobileNetV2's preprocess_input rescales [0, 255] pixel values to [-1, 1].
x = keras.applications.mobilenet_v2.preprocess_input(x)

# The base model contains BatchNormalization layers. Calling it with
# training=False keeps them in inference mode, which remains true later
# when the base model is unfrozen for fine-tuning.
x = base_model(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.3)(x)  # Regularize with dropout
# Raw logits (no softmax); the loss below is configured with from_logits=True.
outputs = layers.Dense(num_classes)(x)

model = keras.Model(inputs, outputs)
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Stage 1: train only the new classification head on top of the frozen base.
initial_epochs = 10
history = model.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=initial_epochs,
)

# Plot the accuracy and loss curves of the first (frozen-base) training stage.
#
# NOTE: `%config InlineBackend.figure_format = 'retina'` is an IPython magic
# and a SyntaxError in a plain .py script; run it in a notebook cell if
# high-DPI inline figures are wanted. matplotlib.pyplot is already imported
# as `plt` at the top of the file.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))

plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()), 1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
# Take the upper limit from the autoscaled axis instead of a fixed window,
# which would clip curves whose losses fall outside it.
plt.ylim([0, max(plt.ylim())])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()


# Stage 2: fine-tune the whole model end-to-end with a low learning rate.
#
# Unfreeze the base model. It keeps running in inference mode because it was
# called with `training=False` when the model was built, so its
# BatchNormalization layers do not update their batch statistics. Updating
# them would undo the representations the model has learned so far.
base_model.trainable = True
model.summary()

# Recompile so the change to `trainable` takes effect.
model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

fine_tune_epochs = 10
total_epochs = initial_epochs + fine_tune_epochs
# In Keras, `epochs` is the index of the final epoch when combined with
# `initial_epoch`. Resuming at len(history.epoch), the first epoch index not
# yet run, trains exactly `fine_tune_epochs` more epochs. Resuming at
# history.epoch[-1] would repeat the last epoch index and run one extra epoch.
history_fine = model.fit(
    train_ds,
    epochs=total_epochs,
    initial_epoch=len(history.epoch),
    validation_data=validation_ds,
)

# Final evaluation on the held-out test split.
loss, accuracy = model.evaluate(test_ds)
print('Test accuracy :', accuracy)

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_4 (InputLayer)        [(None, 160, 160, 3)]     0         
                                                                 
 sequential_1 (Sequential)   (None, 160, 160, 3)       0         
                                                                 
 tf.math.truediv_1 (TFOpLamb  (None, 160, 160, 3)      0         
 da)                                                             
                                                                 
 tf.math.subtract_1 (TFOpLam  (None, 160, 160, 3)      0         
 bda)                                                            
                                                                 
 mobilenetv2_1.00_160 (Funct  (None, 5, 5, 1280)       2257984   
 ional)                                                          
                                                                 
 global_average_pooling2d_1   (None, 1280)             0         
 (GlobalAveragePooling2D)                                        
                                                                 
 dropout_1 (Dropout)         (None, 1280)              0         
                                                                 
 dense_1 (Dense)             (None, 5)                 6405      
                                                                 
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________
Epoch 1/10
81/81 [==============================] - 44s 501ms/step - loss: 1.0165 - accuracy: 0.6185 - val_loss: 0.5285 - val_accuracy: 0.8004
Epoch 2/10
81/81 [==============================] - 36s 441ms/step - loss: 0.5426 - accuracy: 0.8062 - val_loss: 0.4269 - val_accuracy: 0.8439
Epoch 3/10
81/81 [==============================] - 35s 437ms/step - loss: 0.4431 - accuracy: 0.8392 - val_loss: 0.3853 - val_accuracy: 0.8748
Epoch 4/10
81/81 [==============================] - 36s 445ms/step - loss: 0.4129 - accuracy: 0.8517 - val_loss: 0.3681 - val_accuracy: 0.8730
Epoch 5/10
81/81 [==============================] - 35s 437ms/step - loss: 0.3719 - accuracy: 0.8630 - val_loss: 0.3583 - val_accuracy: 0.8693
Epoch 6/10
81/81 [==============================] - 35s 438ms/step - loss: 0.3412 - accuracy: 0.8762 - val_loss: 0.3498 - val_accuracy: 0.8802
Epoch 7/10
81/81 [==============================] - 35s 434ms/step - loss: 0.3313 - accuracy: 0.8817 - val_loss: 0.3421 - val_accuracy: 0.8784
Epoch 8/10
81/81 [==============================] - 35s 429ms/step - loss: 0.3177 - accuracy: 0.8817 - val_loss: 0.3413 - val_accuracy: 0.8766
Epoch 9/10
81/81 [==============================] - 36s 442ms/step - loss: 0.2999 - accuracy: 0.8879 - val_loss: 0.3379 - val_accuracy: 0.8820
Epoch 10/10
81/81 [==============================] - 35s 434ms/step - loss: 0.2849 - accuracy: 0.8930 - val_loss: 0.3368 - val_accuracy: 0.8838

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_4 (InputLayer)        [(None, 160, 160, 3)]     0         
                                                                 
 sequential_1 (Sequential)   (None, 160, 160, 3)       0         
                                                                 
 tf.math.truediv_1 (TFOpLamb  (None, 160, 160, 3)      0         
 da)                                                             
                                                                 
 tf.math.subtract_1 (TFOpLam  (None, 160, 160, 3)      0         
 bda)                                                            
                                                                 
 mobilenetv2_1.00_160 (Funct  (None, 5, 5, 1280)       2257984   
 ional)                                                          
                                                                 
 global_average_pooling2d_1   (None, 1280)             0         
 (GlobalAveragePooling2D)                                        
                                                                 
 dropout_1 (Dropout)         (None, 1280)              0         
                                                                 
 dense_1 (Dense)             (None, 5)                 6405      
                                                                 
=================================================================
Total params: 2,264,389
Trainable params: 2,230,277
Non-trainable params: 34,112
_________________________________________________________________
Epoch 10/20
81/81 [==============================] - 123s 1s/step - loss: 0.2722 - accuracy: 0.8926 - val_loss: 0.3170 - val_accuracy: 0.8802
Epoch 11/20
81/81 [==============================] - 107s 1s/step - loss: 0.2504 - accuracy: 0.9066 - val_loss: 0.2973 - val_accuracy: 0.8857
Epoch 12/20
81/81 [==============================] - 106s 1s/step - loss: 0.2193 - accuracy: 0.9206 - val_loss: 0.3110 - val_accuracy: 0.8838
Epoch 13/20
81/81 [==============================] - 107s 1s/step - loss: 0.2034 - accuracy: 0.9237 - val_loss: 0.2772 - val_accuracy: 0.8911
Epoch 14/20
81/81 [==============================] - 107s 1s/step - loss: 0.1771 - accuracy: 0.9401 - val_loss: 0.2645 - val_accuracy: 0.9111
Epoch 15/20
81/81 [==============================] - 106s 1s/step - loss: 0.1667 - accuracy: 0.9424 - val_loss: 0.2609 - val_accuracy: 0.9111
Epoch 16/20
81/81 [==============================] - 107s 1s/step - loss: 0.1602 - accuracy: 0.9408 - val_loss: 0.2663 - val_accuracy: 0.9165
Epoch 17/20
81/81 [==============================] - 106s 1s/step - loss: 0.1391 - accuracy: 0.9475 - val_loss: 0.2469 - val_accuracy: 0.9147
Epoch 18/20
81/81 [==============================] - 107s 1s/step - loss: 0.1203 - accuracy: 0.9603 - val_loss: 0.2470 - val_accuracy: 0.9129
Epoch 19/20
81/81 [==============================] - 106s 1s/step - loss: 0.1134 - accuracy: 0.9548 - val_loss: 0.2566 - val_accuracy: 0.9038
Epoch 20/20
81/81 [==============================] - 107s 1s/step - loss: 0.1055 - accuracy: 0.9611 - val_loss: 0.2639 - val_accuracy: 0.9074

18/18 [==============================] - 6s 279ms/step - loss: 0.2469 - accuracy: 0.9164
Test accuracy : 0.9163636565208435