用Tensorflow进行线性回归和逻辑回归(七)

发布于:2025-06-29 ⋅ 阅读:(22) ⋅ 点赞:(0)

7.用Minibatches有效的计算梯度

有个问题是计算∇W会很慢。隐式地, ∇W取决于损失函数ℒ。因为 ℒ取决于整个数据集,对于大数据集计算∇W 会很慢。实际上人们使用数据集分部评估的∇W ,分部的数据集称为minibatch。每个minibatch的大小为 50–100。在深度学习里minibatch的大小是超参数。 步长α是另一个超参数。深度学习算法有很多超参数。这些超参数不是通过随机梯度下降学习到的。可学习参数与超参数的关系是深度学习的弱点和优点。超参数的出现可以更多的利用专家的直觉,而可学习参数则数据本身说话。然而这种灵活性很快会成为缺点。对超参数的行为不够理解会阻碍初学习使用深度学习。后面我们会花时间来学习超参数优化。

这一节我们介绍一下epoch。一个epoch是完整的将数据x传递给梯度下降算法一次。更特别的是,一个epoch由许多梯度下降步组成,每个步要审核minibatch给出的数据。例如,假如数据集有1000个数据点,训练的minibatch大小为 50。则一个epoch包含20个梯度下降更新。每个训练的epoch增加模型获得的有用的知道。 数学上,这对应于训练集的损失函数的减少。早的 epochs使损失函数下降很大。这个过程称为learning the prior on that dataset。 后期的 epochs 对应于很小的损失下降,但是通常是后期的epochs才出现有意义的学习。模型通常训练 10–1,000 epochs或直致收敛。但是需要的epochs变不会与数据集的大小成正比。复杂的机器学习算法只需要传递一次数据集。

跟踪损失函数随epochs的下降是理解机器学习过程很有用的捷径。这个图通常称为损失曲线。 (见图3-27)。有经验的实践者只要看一眼损失曲线就能疹断出学习的失败。我们要注意损曲线。特别是,我们会介绍TensorBoard,跟踪损失函数的强大的可视化工具。

                

图 3-27. 模型的损失曲线的例子。注意,你遇到的损失曲线可能没有这么平滑。

8.学习使用TensorFlow

这章一余下部分,我们覆盖使用TensorFlow进行学习机器学习建模型需的必要的概念。我们从介绍玩具数据集开始,并解释如何用python库创建有意义的玩具数据集。 接着我们讨论 TensorFlow的思想,如placeholders,feed dictionaries, name scopes, optimizers, 以及 gradients。下一节告诉你如何用这些概念来训练简间的回归和分类模型。

创建玩具数据集Creating Toy Datasets

这一节我们讨论如何创建有意义的玩具数据集,用来训练监督分类和回归模型。

简单的介绍NumPy

我们重度使用NumPy以定义有意义的玩具数据集。 NumPy 是个Python包可以用来操作张量(在NumPy里称为ndarray)List3-19展示一些基础。

#List3-19. Some examples of basic NumPy usage

>>> import numpy as np

>>> np.zeros((2,2))

array([[ 0., 0.],

[ 0., 0.]])

>>> np.eye(3)

array([[ 1., 0., 0.],

[ 0., 1., 0.],

[ 0., 0., 1.]])

你可以看到 NumPy的 ndarray操作与TensorFlow的张量操作非常相似。这种相似是Tensor‐Flow架构有意设计的。许多TensorFlow工具函数与相似的NumPy函数有相似的参数和形式。 所以我们不深入的介绍NumPy,让读者通过实验使用NumPy。网上有很多的教程资源介绍 NumPy。

为什么玩具数据集很重要?

真实的数据需要经过清洗和预处理变换才能进行学习。

用正态分布添加噪章Adding noise with Gaussians

正态分布广泛的用于噪声模型。如图 Figure 3-28所示,正态分布有不同的均值 μ和标准差σ。

图3-28. 有不同均值 和标准差的正态分布

正态分布记为 N(μ, σ)。

玩具回归数据集

最简单的线性回归是一维的直线。假如我们的数据点 x 是一维的,然后假定真实值标签y通过线性规则产生

y = wx + b

这里 w, b 是可学习参数,必须通过梯度下降从数据中评估得到。为了测试我们可以用TensorFlow学习这些参数,我们产生一个包含直线上的点的人工数据集。为了让学习更有挑战一些,我们给数据集添加一些高斯噪声。

我们书写一下被高斯噪声干扰的直线方程:

y = wx + b + N(0,1)

这里使用标准正态分布噪音。我们可以从这种分布得到人工数据集,使用NumPy。如 List3-20所示。

#List3-20. Using NumPy to sample an artificial dataset

import numpy as np

np.random.seed(456)

import  tensorflow as tf

#tf.set_random_seed(456)

#from matplotlib import rc

#rc('text', usetex=True)

import matplotlib.pyplot as plt

from scipy.stats import pearsonr

from sklearn.metrics import mean_squared_error

# Generate synthetic data

N = 100

w_true = 5

b_true = 2

noise_scale = .1

x_np = np.random.rand(N, 1)

noise = np.random.normal(scale=noise_scale, size=(N, 1))

# Convert shape of y_np to (N,)

y_np = np.reshape(w_true * x_np  + b_true + noise, (-1))

# Save image of the data distribution

plt.scatter(x_np, y_np)

plt.xlabel(r"$x$")

plt.ylabel(r"$y$")

plt.xlim(0, 1)

plt.title("Toy Linear Regression Data, "r"$y = 5x + 2 + N(0, 1)$")

plt.savefig("lr_data.png")

我们用Matplotlib绘制数据集于图Figure 3-29。正如期望的,数据分布在直线上,有一小部分测量误差。

        

图3-29. 玩具回归数据分布作图

List3-21 里的NumPy代码有些技巧。我们使用np.vstack来组合两种不同类型的数据点并与不同的标签关联。(我们用 np.concatenate来组合一维标签)

#List3-21. Sample a toy classification dataset with NumPy

import numpy as np

np.random.seed(456)

import tensorflow as tf

#tf.set_random_seed(456)

import matplotlib.pyplot as plt

from sklearn.metrics import accuracy_score

from scipy.special import logit

# Generate synthetic data

N = 100

# Zeros form a Gaussian centered at (-1, -1)

# epsilon is .1

x_zeros = np.random.multivariate_normal(mean=np.array((-1, -1)), cov=.1*np.eye(2), size=(N//2,))

y_zeros = np.zeros((N//2,))

# Ones form a Gaussian centered at (1, 1)

# epsilon is .1

x_ones = np.random.multivariate_normal(mean=np.array((1, 1)), cov=.1*np.eye(2), size=(N//2,))

y_ones = np.ones((N//2,))

x_np = np.vstack([x_zeros, x_ones])

y_np = np.concatenate([y_zeros, y_ones])

plt.xlabel(r"$x_1$")

plt.ylabel(r"$x_2$")

plt.title("Toy Logistic Regression Data")

plt.scatter(x_zeros[:, 0], x_zeros[:, 1], color="blue")

plt.scatter(x_ones[:, 0], x_ones[:, 1], color="red")

3-30Matplotlib 绘制产生的数据以证实分布是我们期望的。我们看到二个分类的数据几乎完全分开。

图3-30. 玩具分类数据集作图


网站公告

今日签到

点亮在社区的每一天
去签到