tensorflow学习笔记(2)线性回归-20240507

发布于:2024-05-08 ⋅ 阅读:(27) ⋅ 点赞:(0)

通过调用Tensorflow计算梯度下降的函数tf.train.GradientDescentOptimizer来实现优化。

代码如下:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#程序作用:
#线性回归:通过调用Tensorflow计算梯度下降的函数tr.train.GradientDescentOptimizer来实现优化。

import os
import  tensorflow as tf
import numpy as np

tf.compat.v1.disable_eager_execution()  #为了tensorflow2.x支持placeholder等

trX=np.linspace(-1,1,101)
trY=2*trX+np.random.randn(*trX.shape)*0.33  #创建一些线性值附近的随机值

#X = tf.placeholder("float") 
#Y = tf.placeholder("float")

X = tf.compat.v1.placeholder("float") 
Y = tf.compat.v1.placeholder("float")
#X = tf.Variable("float")
#Y = tf.Variable("float")

def model(X,w):
	return tf.multiply(X, w)   # X*w线性求值,非常简单

w=tf.Variable(0.0,name="weights")
y_model=model(X,w)

cost=tf.square(Y-y_model)  #用平方误差作为优化目标

#train_op=tf.train.GradientDescentOptimizer(0.01).minimize(cost)  #梯度下降优化
train_op=tf.compat.v1.train.GradientDescentOptimizer(0.01).minimize(cost)  #梯度下降优化


# 开始创建session干活!
with tf.compat.v1.Session() as sess:
	# 首先需要初始化全局变量,这是Tensorflow的要求
	tf.compat.v1.global_variables_initializer().run()

	for i in range(100):
		for(x,y) in zip(trX,trY):
			sess.run(train_op,feed_dict={X:x,Y:y})

	print(sess.run(w))

运行报错:

提示有几个函数只支持tensorflow1.x的,我的tensorflow版本是2.x版本,所以几个函数都用不了。

2.x废弃的函数:

  1. Placeholder
  2. GradientDescentOptimizer
  3. Session
  4. global_variables_initializer

解决办法:

Import后面加入:

tf.comtmpat.v1.disable_eager_execution() 

同时在这几个函数前面加上限制tensorflow1的限制

(1)X = tf.placeholder("float") 改为

X = tf.compat.v1.placeholder("float")

(2)tf.train.GradientDescentOptimizer(0.01) 改为

tf.compat.v1.train.GradientDescentOptimizer(0.01)

(3)tf.Session()改为

tf.compat.v1.Session()

(4)tf.global_variables_initializer().run()改为

tf.compat.v1.global_variables_initializer().run()

修改后运行正确,运行结果为接近2的值,因为随机数产生的,所以每次运行结果不一样

xxx tensorflow % ./tf_test2.py

2024-05-07 15:37:48.829383: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:388] MLIR V1 optimization pass is not enabled

2.0609348

xxx tensorflow % ./tf_test2.py

2024-05-07 15:46:52.645850: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:388] MLIR V1 optimization pass is not enabled

1.8316954


网站公告

今日签到

点亮在社区的每一天
去签到