通过调用Tensorflow计算梯度下降的函数tf.train.GradientDescentOptimizer来实现优化。
代码如下:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#程序作用:
#线性回归:通过调用Tensorflow计算梯度下降的函数tr.train.GradientDescentOptimizer来实现优化。
import os
import tensorflow as tf
import numpy as np
tf.compat.v1.disable_eager_execution() #为了tensorflow2.x支持placeholder等
trX=np.linspace(-1,1,101)
trY=2*trX+np.random.randn(*trX.shape)*0.33 #创建一些线性值附近的随机值
#X = tf.placeholder("float")
#Y = tf.placeholder("float")
X = tf.compat.v1.placeholder("float")
Y = tf.compat.v1.placeholder("float")
#X = tf.Variable("float")
#Y = tf.Variable("float")
def model(X,w):
return tf.multiply(X, w) # X*w线性求值,非常简单
w=tf.Variable(0.0,name="weights")
y_model=model(X,w)
cost=tf.square(Y-y_model) #用平方误差作为优化目标
#train_op=tf.train.GradientDescentOptimizer(0.01).minimize(cost) #梯度下降优化
train_op=tf.compat.v1.train.GradientDescentOptimizer(0.01).minimize(cost) #梯度下降优化
# 开始创建session干活!
with tf.compat.v1.Session() as sess:
# 首先需要初始化全局变量,这是Tensorflow的要求
tf.compat.v1.global_variables_initializer().run()
for i in range(100):
for(x,y) in zip(trX,trY):
sess.run(train_op,feed_dict={X:x,Y:y})
print(sess.run(w))
运行报错:
提示有几个函数只支持tensorflow1.x的,我的tensorflow版本是2.x版本,所以几个函数都用不了。
2.x废弃的函数:
- Placeholder
- GradientDescentOptimizer
- Session
- global_variables_initializer
解决办法:
Import后面加入:
tf.comtmpat.v1.disable_eager_execution()
同时在这几个函数前面加上限制tensorflow1的限制
(1)X = tf.placeholder("float") 改为
X = tf.compat.v1.placeholder("float")
(2)tf.train.GradientDescentOptimizer(0.01) 改为
tf.compat.v1.train.GradientDescentOptimizer(0.01)
(3)tf.Session()改为
tf.compat.v1.Session()
(4)tf.global_variables_initializer().run()改为
tf.compat.v1.global_variables_initializer().run()
修改后运行正确,运行结果为接近2的值,因为随机数产生的,所以每次运行结果不一样
xxx tensorflow % ./tf_test2.py
2024-05-07 15:37:48.829383: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:388] MLIR V1 optimization pass is not enabled
2.0609348
xxx tensorflow % ./tf_test2.py
2024-05-07 15:46:52.645850: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:388] MLIR V1 optimization pass is not enabled
1.8316954