Keras/TensorFlow 中 `fit()` 方法参数详细说明

发布于:2025-09-04 ⋅ 阅读:(25) ⋅ 点赞:(0)

Keras/TensorFlow 中 fit() 方法参数详细说明

Keras/TensorFlow 中的 fit() 方法是训练神经网络的核心API,提供了丰富的参数来控制训练过程。以下是所有参数的详细说明:

一、基础参数

1. x/y

  • 作用:输入数据和目标数据
  • 类型
    • NumPy数组
    • TensorFlow张量
    • 字典(用于具名输入)
    • tf.data数据集
  • 示例
    model.fit(x=train_images, y=train_labels)
    

2. batch_size

  • 作用:每个梯度更新的样本数
  • 类型:整数或None
  • 默认值:32
  • 注意
    • 如果使用数据集对象并且指定了steps_per_epoch,则不需要设置
    • 典型值:16/32/64/128/256

3. epochs

  • 作用:训练轮次数
  • 类型:整数
  • 默认值:1
  • 示例
    model.fit(..., epochs=50)
    

4. verbose

  • 作用:控制训练过程输出的详细程度
  • 类型:整数
  • 可选值
    • 0:静默模式
    • 1:进度条(默认)
    • 2:每个epoch一行输出

二、验证相关参数

5. validation_split

  • 作用:从训练数据中分出部分作为验证集的比例
  • 类型:0-1之间的浮点数
  • 默认值:0.0(不使用)
  • 示例
    model.fit(..., validation_split=0.2)  # 使用20%数据作为验证集
    

6. validation_data

  • 作用:手动指定验证数据集
  • 类型:与x/y相同的格式
  • 优先级:高于validation_split
  • 示例
    model.fit(..., validation_data=(val_images, val_labels))
    

7. validation_freq

  • 作用:指定每隔多少epoch进行一次验证
  • 类型:整数或列表
  • 默认值:1(每个epoch都验证)
  • 示例
    model.fit(..., validation_freq=3)  # 每3个epoch验证一次
    

三、数据相关参数

8. shuffle

  • 作用:是否在每个epoch前打乱数据
  • 类型:布尔值
  • 默认值True
  • 注意:使用tf.data数据集时优先使用数据集自身的shuffle操作

9. class_weight

  • 作用:为不同类别分配权重(用于不平衡数据集)
  • 类型:字典
  • 示例
    model.fit(..., class_weight={0: 1., 1: 0.5})  # 类别1的权重是类别0的一半
    

10. sample_weight

  • 作用:为每个样本分配权重
  • 类型:NumPy数组
  • 示例
    weights = np.array([1.0, 1.5])  # 第二个样本权重更大
    model.fit(..., sample_weight=weights)
    

11. initial_epoch

  • 作用:从指定epoch开始训练(用于恢复训练)
  • 类型:整数
  • 默认值:0
  • 示例
    model.fit(..., initial_epoch=10)  # 从第10个epoch开始
    

四、回调与控制参数

12. callbacks

  • 作用:训练过程中执行的回调函数列表
  • 类型:列表
  • 常见回调
    • EarlyStopping - 早停
    • ModelCheckpoint - 保存模型
    • TensorBoard - 可视化
    • LearningRateScheduler - 学习率调整
  • 示例
    callbacks = [
        tf.keras.callbacks.EarlyStopping(patience=3),
        tf.keras.callbacks.ModelCheckpoint('model.h5')
    ]
    model.fit(..., callbacks=callbacks)
    

五、高级参数

13. steps_per_epoch

  • 作用:每个epoch执行的batch步数
  • 类型:整数
  • 默认值None(自动计算:样本数/batch_size)
  • 适用场景
    • 使用无限数据集时必需指定
    • 部分数据集训练

14. validation_steps

  • 作用:验证时使用的batch步数
  • 类型:整数
  • 适用场景
    • 验证数据为无限数据集时必需指定

15. max_queue_size

  • 作用:生成器队列的最大大小
  • 类型:整数
  • 默认值:10
  • 适用场景:使用Python生成器作为输入时

16. workers

  • 作用:生成器预处理的最大进程数
  • 类型:整数
  • 默认值:1

17. use_multiprocessing

  • 作用:是否使用多进程处理数据
  • 类型:布尔值
  • 默认值False
  • 注意:设置True可能导致性能下降

六、实际使用示例

# 完整参数示例
history = model.fit(
    x=train_images,
    y=train_labels,
    batch_size=64,
    epochs=100,
    verbose=1,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(patience=5),
        tf.keras.callbacks.ReduceLROnPlateau(factor=0.1, patience=3)
    ],
    validation_data=(val_images, val_labels),
    validation_freq=2,
    shuffle=True,
    class_weight={0: 1.0, 1: 2.0},  # 假设类别1更重要
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None,
    max_queue_size=10,
    workers=4,
    use_multiprocessing=False
)

七、返回值

fit() 方法返回 History 对象,包含:

  • history.history:字典,包含训练过程中的loss和metrics记录
  • history.epoch:完成的epoch列表
  • history.params:训练参数
  • history.model:对应的模型对象
# 使用训练历史
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Val'], loc='upper left')
plt.show()

网站公告

今日签到

点亮在社区的每一天
去签到