1、绪论
在TensorFlow中,fit() 方法是一个用于训练模型的便捷函数,它封装了训练循环(training loop)的许多常见步骤,如前向传播(forward pass)、计算损失(loss)、反向传播(backward pass)以及模型权重的更新。然而,有时候你可能想要对训练过程进行更细粒度的控制,或者添加一些自定义的步骤。
 当进行监督学习时,可以使用fit()方法,并且一切都会顺利进行。
但是,当需要控制每一个小细节时,就可以完全从头开始编写自己的训练循环。
但如果需要一个自定义的训练算法,但又想从fit()的便捷功能中受益,比如回调(callbacks)、内置的分布支持(built-in distribution support)或步骤融合(step fusing)时,又该如何呢?
Keras的一个核心原则是复杂性的逐步展现。总是能够逐步深入到更低级别的工作流程中。如果高级功能不完全符合你的测试用例,也不会突然陷入困境。我们可以在保留相应级别的高级便利性的同时,对细节获得更多的控制权。
当需要自定义fit()的行为时,我们应该重写Model类的训练步骤函数。这是fit()在处理每一批数据时调用的函数。然后你就可以像平常一样调用fit()——而它将会运行你自己的学习算法。
请注意,这种模式并不会阻止你使用函数式API构建模型。无论你是构建Sequential模型、函数式API模型还是子类化模型,都可以采用这种方法。
2、准备工作
#2.1 基础设置
开始操作前请按照如下进行基础设置
import os
# This guide can only be run with the TF backend.
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import keras
from keras import layers
import numpy as np
2.2 操作示例
以下是一个使用TensorFlow来自定义fit()的示例:
首先需要创建一个新的类,该类继承自keras.Model。
 然后重写train_step(self, data)这个方法。
 之后我们返回数据字典,该字典将度量指标名称(包括损失)映射到它们的当前值。
 输入参数data是传递给fit方法的训练数据:
- 如果你通过调用fit(x, y, ...)传递NumPy数组,那么data将是元组(x, y)
- 如果你通过调用fit(dataset, ...)传递一个tf.data.Dataset,那么data将是dataset在每个批次中产生的数据。
在train_step()方法的主体中,我们实现了一个常规的训练更新过程,类似于你已经熟悉的过程。重要的是,我们通过self.compute_loss()计算损失,该方法封装了在compile()方法中传递的损失函数。
类似地,我们调用metric.update_state(y, y_pred)来更新在compile()方法中传递的度量指标的状态,并在最后通过self.metrics查询结果来获取它们的当前值。
class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compute_loss(y=y, y_pred=y_pred)
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply(gradients, trainable_vars)
        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
运行代码,看看输出结果
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
Epoch 1/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - mae: 0.5089 - loss: 0.3778   
Epoch 2/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 318us/step - mae: 0.3986 - loss: 0.2466
Epoch 3/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 372us/step - mae: 0.3848 - loss: 0.2319
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1699222602.443035       1 device_compiler.h:187] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
<keras.src.callbacks.history.History at 0x2a5599f00>
3、底层操作方法
在操作过程中,可以在compile()方法中省略损失函数的传递,而是在train_step中手动完成所有操作。对于度量指标(metrics)也是如此。
以下是一个更加底层操作的示例,它仅使用compile()方法来配置优化器:
首先,我们在__init__()方法中创建度量指标实例来跟踪损失和平均绝对误差(MAE)分数。
然后,我们实现一个自定义的train_step(),更新这些度量指标的状态(通过调用它们的update_state()方法),接着查询它们(通过result()方法)来返回当前平均值,以便进度条显示并传递给任何回调函数。
请注意,在每个epoch之间,我们需要调用度量指标的reset_states()方法!否则,调用result()会返回从训练开始以来的平均值,而我们通常处理的是每个epoch的平均值。幸运的是,框架可以为我们完成这一操作:只需在模型的metrics属性中列出你希望重置的任何度量指标对象。在每个fit() epoch的开始或调用evaluate()时,模型会自动调用这些对象上的reset_states()方法。
class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute our own loss
            loss = self.loss_fn(y, y_pred)
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply(gradients, trainable_vars)
        # Compute our own metrics
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {
            "loss": self.loss_tracker.result(),
            "mae": self.mae_metric.result(),
        }
    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        return [self.loss_tracker, self.mae_metric]
# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
# We don't pass a loss or metrics here.
model.compile(optimizer="adam")
# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
Epoch 1/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 4.0292 - mae: 1.9270
Epoch 2/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 385us/step - loss: 2.2155 - mae: 1.3920
Epoch 3/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 336us/step - loss: 1.1863 - mae: 0.9700
Epoch 4/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 373us/step - loss: 0.6510 - mae: 0.6811
Epoch 5/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 330us/step - loss: 0.4059 - mae: 0.5094
<keras.src.callbacks.history.History at 0x2a7a02860>
3.1 样本权重(sample_weight)和分类权(class_weight)
你可能已经注意到我们之前的基本示例并没有提到样本权重。如果你想要在模型训练时支持 fit() 方法的 sample_weight 和 class_weight 参数,你只需要做以下事情:
- 从数据参数中解包 sample_weight。
- 将其传递给 compute_loss和update_state方法(当然,如果你没有依赖compile()方法来处理损失和度量指标,你也可以手动应用它)。
class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value.
            # The loss function is configured in `compile()`.
            loss = self.compute_loss(
                y=y,
                y_pred=y_pred,
                sample_weight=sample_weight,
            )
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply(gradients, trainable_vars)
        # Update the metrics.
        # Metrics are configured in `compile()`.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
# You can now use sample_weight argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)
Epoch 1/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - mae: 0.4228 - loss: 0.1420
Epoch 2/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 449us/step - mae: 0.3751 - loss: 0.1058
Epoch 3/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 337us/step - mae: 0.3478 - loss: 0.0951
<keras.src.callbacks.history.History at 0x2a7491780>
3.2 提供自定义的评估步骤
在构建自定义的Keras模型时,如果你希望为model.evaluate()方法提供自己的评估步骤,你可以通过覆盖test_step方法来实现。test_step定义了模型在评估模式(即不更新模型权重)下如何处理一批数据。通过覆盖这个方法,你可以自定义模型在评估时的行为,比如计算特定的度量指标或执行额外的数据检查。
class CustomModel(keras.Model):
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_pred = self(x, training=False)
        # Updates the metrics tracking the loss
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # Update the metrics.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}
# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])
# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 927us/step - mae: 0.8518 - loss: 0.9166
[0.912325382232666, 0.8567370176315308]
4、总结:一个端到端的GAN示例
本示例考虑以下网络:
一个生成器网络,用于生成28x28x1的图像。
 一个判别器网络,用于将28x28x1的图像分为两类(“假的”和“真的”)。
 每个网络各有一个优化器。
 一个损失函数来训练判别器。
# Create the discriminator
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)
# Create the generator
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)
以下是一个功能完整的GAN类,它重写了compile()方法以使用自己的签名,并在train_step中仅用17行代码实现了整个GAN算法:
class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.seed_generator = keras.random.SeedGenerator(1337)
    @property
    def metrics(self):
        return [self.d_loss_tracker, self.g_loss_tracker]
    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)
        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)
        # Assemble labels discriminating real from fake images
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * keras.random.uniform(
            tf.shape(labels), seed=self.seed_generator
        )
        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply(grads, self.discriminator.trainable_weights)
        # Sample random points in the latent space
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))
        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply(grads, self.generator.trainable_weights)
        # Update metrics and return their value.
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
        }
测试输出结果
# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)
# To limit the execution time, we only train on 100 batches. You can train on
# the entire dataset. You will need about 20 epochs to get nice results.
gan.fit(dataset.take(100), epochs=1)
100/100 ━━━━━━━━━━━━━━━━━━━━ 51s 500ms/step - d_loss: 0.5645 - g_loss: 0.7434
<keras.src.callbacks.history.History at 0x14a4f1b10>
TensorFlow 提供了高级的API,如 fit() 方法,用于简化模型的训练过程。然而,当标准的 fit() 方法不满足特定需求时,例如需要自定义训练循环、调整学习率、实施早期停止等,TensorFlow 也允许用户编写自定义的 fit() 函数。
 自定义 fit() 函数在TensorFlow中意味着完全控制模型的训练过程。这通常涉及到以下步骤:
- 初始化模型:首先,你需要定义并初始化你的模型,这通常包括生成器和判别器(在GAN的情况下)或其他类型的神经网络。 
- 准备数据:你需要准备训练数据,并确保它以TensorFlow可以处理的格式进行加载和批处理。 
- 设置优化器和损失函数:选择适合你的模型和学习任务的优化器和损失函数。 
- 编写训练循环:这是自定义 - fit()函数的核心部分。你需要编写一个循环,该循环会迭代训练数据,计算损失,反向传播梯度,并更新模型权重。
- 监控和评估:在训练过程中,你可能想要监控诸如损失、准确率等指标,并在每个epoch或一批数据后评估模型的性能。 
- 实现早期停止和其他回调:你可以添加自定义的回调函数,如早期停止,以防止过拟合。 
- 保存和加载模型:在训练完成后,你可能想要保存模型以供将来使用,并在需要时加载它。