目录
Python实例题
题目
基于 TensorFlow 的图像识别与分类系统
问题描述
开发一个基于 TensorFlow 的图像识别与分类系统,包含以下功能:
- 图像分类模型:基于预训练模型的图像分类器
- 数据处理与增强:图像预处理和数据增强
- 模型训练与评估:自定义数据集上的模型训练
- API 服务:提供图像识别的 RESTful API
- 前端界面:用户上传图像并获取分类结果
解题思路
- 使用 TensorFlow 和 Keras 构建深度学习模型
- 基于预训练模型(如 ResNet、VGG、EfficientNet)进行迁移学习
- 设计数据处理和增强管道
- 使用 Flask 或 FastAPI 构建 API 服务
- 开发前端界面实现图像上传和结果展示
关键代码框架
import tensorflow as tf
from tensorflow.keras.applications import ResNet50, EfficientNetB0
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import numpy as np
import os
from flask import Flask, request, jsonify, render_template
from PIL import Image
import io
import base64
# Configuration parameters shared by the training code and the API below.
# IMAGE_SIZE is the size images are resized to (data generators and predict_image).
IMAGE_SIZE = (224, 224)
# Batch size handed to both directory iterators.
BATCH_SIZE = 32
NUM_CLASSES = 10 # adjust to the actual dataset; not referenced by the visible code, which uses train_generator.num_classes
# Default number of training epochs.
EPOCHS = 50
BASE_MODEL = 'efficientnet' # either 'resnet' or 'efficientnet'
# Build the data-augmentation and preprocessing pipelines
def create_data_generators(train_dir, val_dir):
    """Create the training and validation directory iterators.

    Both iterators scale pixel values by 1/255, resize images to
    ``IMAGE_SIZE``, batch them by ``BATCH_SIZE`` and use
    ``class_mode='categorical'``.  Random augmentation is applied to the
    training data only.

    Args:
        train_dir: Directory passed to ``flow_from_directory`` for training.
        val_dir: Directory passed to ``flow_from_directory`` for validation.

    Returns:
        A ``(train_generator, val_generator)`` tuple.
    """
    # Options common to both directory iterators.
    flow_options = dict(
        target_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
    )
    # Random transforms used for the training images only.
    augmentation = dict(
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
    )

    augmenting_datagen = ImageDataGenerator(rescale=1./255, **augmentation)
    # Validation images are only rescaled.
    plain_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = augmenting_datagen.flow_from_directory(train_dir, **flow_options)
    val_generator = plain_datagen.flow_from_directory(val_dir, **flow_options)
    return train_generator, val_generator
# Build the model
def build_model(input_shape, num_classes, base_model_type='efficientnet'):
    """Build and compile a classifier on top of a frozen ImageNet backbone.

    The rest of this file feeds the model pixels scaled to [0, 1]
    (``rescale=1./255`` in the data generators, ``/ 255.0`` in
    ``predict_image``).  Neither backbone is trained on that range: the Keras
    EfficientNet models rescale internally and expect [0, 255] inputs, and
    ResNet50 expects ``resnet50.preprocess_input`` applied to [0, 255] pixels.
    To keep the callers' [0, 1] contract while giving the pretrained weights
    the inputs they were trained on, the model starts with a small input
    adapter that restores [0, 255] and, for ResNet, applies its preprocessing.

    The backbone is attached through ``input_tensor`` so its layers stay
    flat in ``model.layers`` (``fine_tune_model`` slices that list).

    Args:
        input_shape: Shape of one input image, e.g. ``(224, 224, 3)``.
        num_classes: Number of output classes (size of the softmax layer).
        base_model_type: ``'resnet'`` selects ResNet50; any other value
            selects EfficientNetB0.

    Returns:
        The compiled Keras model (Adam, categorical cross-entropy, accuracy).
    """
    inputs = tf.keras.Input(shape=input_shape)
    # Undo the 1/255 scaling applied by the callers: [0, 1] -> [0, 255].
    pixels = tf.keras.layers.Rescaling(255.0)(inputs)

    # Select the backbone and give it the input range it expects.
    if base_model_type == 'resnet':
        backbone_input = tf.keras.applications.resnet50.preprocess_input(pixels)
        base_model = ResNet50(
            weights='imagenet',
            include_top=False,
            input_shape=input_shape,
            input_tensor=backbone_input,
        )
    else:  # efficientnet (does its own normalisation on [0, 255] pixels)
        base_model = EfficientNetB0(
            weights='imagenet',
            include_top=False,
            input_shape=input_shape,
            input_tensor=pixels,
        )

    # Freeze every backbone layer; only the new head is trained initially.
    for layer in base_model.layers:
        layer.trainable = False

    # Classification head.
    x = GlobalAveragePooling2D()(base_model.output)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs=inputs, outputs=predictions)
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model
# Fine-tune the model (unfreeze part of the network)
def fine_tune_model(model, num_layers_to_unfreeze=20):
    """Unfreeze the last layers of ``model`` and recompile with a low learning rate.

    Args:
        model: A Keras model, typically the one returned by ``build_model``.
        num_layers_to_unfreeze: How many trailing layers of ``model.layers``
            to make trainable.  Zero or a negative value unfreezes nothing.

    Returns:
        The same model instance, recompiled with Adam at ``1e-5``.
    """
    # Guard the slice: ``model.layers[-0:]`` is the whole list, so a count of
    # 0 would otherwise unfreeze every layer instead of none.
    if num_layers_to_unfreeze > 0:
        for layer in model.layers[-num_layers_to_unfreeze:]:
            # BatchNormalization layers stay frozen: updating their statistics
            # on a small dataset tends to destroy the pretrained features.
            if isinstance(layer, tf.keras.layers.BatchNormalization):
                continue
            layer.trainable = True

    # Changes to ``trainable`` only take effect after recompiling; the low
    # learning rate keeps the pretrained weights from being overwritten.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model
# Train the model
def train_model(model, train_generator, val_generator, epochs=EPOCHS, model_path='model.h5'):
    """Fit ``model`` on the generators with early stopping and checkpointing.

    Args:
        model: Compiled Keras model.
        train_generator: Training iterator (as returned by
            ``create_data_generators``).
        val_generator: Validation iterator.
        epochs: Maximum number of epochs.
        model_path: File the best model (by ``val_accuracy``) is saved to.

    Returns:
        A ``(history, model)`` tuple.
    """
    callbacks = [
        # Stop once val_loss has not improved for 5 epochs.
        EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
        # Keep only the checkpoint with the best validation accuracy.
        ModelCheckpoint(model_path, monitor='val_accuracy', save_best_only=True),
    ]

    # Take the step counts from the iterators themselves.  The previous
    # ``samples // BATCH_SIZE`` ignored each iterator's own batch size,
    # dropped the final partial batch, and evaluated to 0 steps whenever a
    # dataset held fewer than BATCH_SIZE images.
    history = model.fit(
        train_generator,
        steps_per_epoch=len(train_generator),
        validation_data=val_generator,
        validation_steps=len(val_generator),
        epochs=epochs,
        callbacks=callbacks,
    )
    return history, model
# Prediction helper
def predict_image(model, image_path=None, image_bytes=None, class_names=None):
    """Classify one image given either a file path or raw encoded bytes.

    The image is converted to RGB, resized to ``IMAGE_SIZE``, scaled to
    [0, 1] and passed to ``model.predict`` as a batch of one.

    Args:
        model: Object exposing ``predict(batch)``.
        image_path: Path of the image file; takes precedence when truthy.
        image_bytes: Encoded image bytes, used when ``image_path`` is falsy.
        class_names: Optional sequence mapping class index to label.

    Returns:
        Dict with ``"class"`` (label, or ``"Class <index>"`` when no name is
        available), ``"confidence"`` (float) and ``"all_predictions"``
        (list of per-class scores).

    Raises:
        ValueError: If neither ``image_path`` nor ``image_bytes`` is given.
    """
    # Pick whichever source was supplied.
    if image_path:
        source = image_path
    elif image_bytes:
        source = io.BytesIO(image_bytes)
    else:
        raise ValueError("必须提供图像路径或图像字节数据")

    img = Image.open(source).convert('RGB').resize(IMAGE_SIZE)

    # Scale to [0, 1] and add the leading batch dimension.
    batch = np.expand_dims(np.array(img) / 255.0, axis=0)

    predictions = model.predict(batch)
    scores = predictions[0]
    predicted_class = np.argmax(scores)
    confidence = np.max(scores)

    # Fall back to a generic label when no usable name exists for the index.
    if class_names and predicted_class < len(class_names):
        class_name = class_names[predicted_class]
    else:
        class_name = f"Class {predicted_class}"

    return {
        "class": class_name,
        "confidence": float(confidence),
        "all_predictions": predictions.tolist()[0],
    }
# Create the Flask application
app = Flask(__name__)
# Model and class names used by the /predict route; both start as None and
# are filled in by load_model_and_classes.
model = None
class_names = None
@app.before_request
def load_model_and_classes():
    """Load the trained model and class names once, before serving requests.

    ``Flask.before_first_request`` was removed in Flask 2.3, so this hook is
    registered with ``before_request`` and guards itself: after the model has
    been loaded it returns immediately.  As with the old hook, a failed load
    is retried on the next request.

    Returns:
        None, so Flask continues with normal request handling.
    """
    global model, class_names

    if model is not None:
        return

    # Load the trained model.
    loaded_model = tf.keras.models.load_model('model.h5')

    # Load the class names (written by the training code), if present.
    if os.path.exists('class_names.txt'):
        with open('class_names.txt', 'r') as f:
            class_names = [line.strip() for line in f]

    # Publish the model last so a failure above leaves it unset and the
    # whole load is retried.
    model = loaded_model
@app.route('/')
def index():
    """Serve the upload page rendered from the ``index.html`` template."""
    page = render_template('index.html')
    return page
@app.route('/predict', methods=['POST'])
def predict():
    """Classify the image uploaded in the ``image`` form field.

    Returns:
        The JSON result of ``predict_image`` on success; a JSON error with
        status 400 when no (or an empty) image was uploaded; a JSON error
        with status 500 when prediction itself fails.
    """
    # Use .get(): indexing request.files raises BadRequestKeyError for a
    # missing field, which the broad handler below used to report as a 500
    # instead of the intended 400.
    file = request.files.get('image')
    if not file:
        return jsonify({"error": "未提供图像文件"}), 400

    try:
        # Read the uploaded image data.
        image_bytes = file.read()
        # An empty upload is a client error, not a server failure.
        if not image_bytes:
            return jsonify({"error": "未提供图像文件"}), 400

        result = predict_image(model, image_bytes=image_bytes, class_names=class_names)
        return jsonify(result)
    except Exception as e:
        # Route boundary: report any prediction failure as JSON.
        return jsonify({"error": str(e)}), 500
@app.route('/train', methods=['POST'])
def train():
    """Train a new model from an optional JSON configuration.

    Recognised JSON keys (all optional): ``train_dir``, ``val_dir``,
    ``epochs`` and ``base_model``.  A missing or non-JSON body means "use the
    defaults".  On success the class names are written to
    ``class_names.txt`` and the module-level ``model`` / ``class_names`` are
    replaced, so ``/predict`` serves the newly trained model.

    Returns:
        JSON with a completion message and the class list; a JSON error with
        status 400 for a non-object JSON body, or 500 if training fails.
    """
    # Without this declaration the assignments below created locals and the
    # model served by /predict was never updated.
    global model, class_names
    try:
        # silent=True: an absent or non-JSON body yields None rather than an
        # error, so the defaults below are actually reachable.
        data = request.get_json(silent=True)
        if data is None:
            data = {}
        if not isinstance(data, dict):
            return jsonify({"error": "请求体必须是 JSON 对象"}), 400

        # NOTE(review): these directories come straight from the request and
        # are used as filesystem paths; restrict them if the endpoint is
        # exposed to untrusted clients.
        train_dir = data.get('train_dir', 'data/train')
        val_dir = data.get('val_dir', 'data/val')
        epochs = data.get('epochs', EPOCHS)
        base_model = data.get('base_model', BASE_MODEL)

        train_generator, val_generator = create_data_generators(train_dir, val_dir)

        # Train under local names and publish only once everything has
        # succeeded, so a failed run leaves the served model untouched.
        new_model = build_model((*IMAGE_SIZE, 3), train_generator.num_classes, base_model)
        _, new_model = train_model(new_model, train_generator, val_generator, epochs)

        # Persist the class names.
        new_class_names = list(train_generator.class_indices.keys())
        with open('class_names.txt', 'w') as f:
            f.write('\n'.join(new_class_names))

        model, class_names = new_model, new_class_names
        return jsonify({
            "message": "模型训练完成",
            "classes": new_class_names
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
# 前端模板 (保存为 templates/index.html, Flask 的 render_template 默认从 templates/ 目录加载模板)
<!DOCTYPE html>
<html>
<head>
<!-- Single-page upload UI: choose an image, preview it, then request a prediction. -->
<title>图像分类系统</title>
<style>
/* Page layout: centered column, at most 800px wide. */
body {
font-family: Arial, sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 20px;
text-align: center;
}
/* Card that wraps the whole UI. */
.container {
background-color: #f5f5f5;
padding: 20px;
border-radius: 10px;
box-shadow: 0 0 10px rgba(0,0,0,0.1);
}
h1 {
color: #333;
}
.upload-area {
margin: 20px 0;
}
/* Shared style of both action buttons. */
.upload-btn {
background-color: #4CAF50;
color: white;
padding: 10px 20px;
border: none;
border-radius: 5px;
cursor: pointer;
}
.upload-btn:hover {
background-color: #45a049;
}
/* Box that shows status messages and the prediction result. */
.result-area {
margin-top: 20px;
padding: 15px;
background-color: #fff;
border-radius: 5px;
min-height: 100px;
}
/* Preview image scales down to the container width. */
.image-preview {
max-width: 100%;
height: auto;
margin: 20px 0;
border-radius: 5px;
}
</style>
</head>
<body>
<div class="container">
<h1>图像分类系统</h1>
<div class="upload-area">
<!-- Hidden native file input; the first button forwards its click to it. -->
<input type="file" id="imageUpload" accept="image/*" style="display: none;">
<button class="upload-btn" onclick="document.getElementById('imageUpload').click()">
选择图像
</button>
<!-- Starts disabled; calls predictImage() when clicked. -->
<button class="upload-btn" id="predictBtn" onclick="predictImage()" disabled>
预测
</button>
</div>
<div>
<!-- The page script sets src to a data URL of the selected file. -->
<img id="imagePreview" class="image-preview" src="" alt="图像预览">
</div>
<!-- The page script writes status and result text into this element. -->
<div class="result-area" id="resultArea">
<p>请上传一张图像进行分类</p>
</div>
</div>
<script>
// File chosen by the user; read by predictImage().
let selectedImage = null;

// When a file is picked: show a preview, enable the predict button and
// remember the file for the upload.
document.getElementById('imageUpload').addEventListener('change', (event) => {
    // Capture the input now; event.currentTarget is cleared after dispatch.
    const input = event.currentTarget;
    if (!input.files || !input.files[0]) {
        return;
    }

    const reader = new FileReader();
    reader.onload = () => {
        // reader.result is the data URL produced by readAsDataURL below.
        document.getElementById('imagePreview').src = reader.result;
        document.getElementById('resultArea').innerHTML = '<p>图像已加载,请点击预测按钮</p>';
        document.getElementById('predictBtn').disabled = false;
        selectedImage = input.files[0];
    };
    reader.readAsDataURL(input.files[0]);
});
/**
 * Upload the selected image to /predict and show the result.
 *
 * Text returned by the server (class names, error messages) is inserted with
 * textContent rather than innerHTML, so markup inside those strings is shown
 * literally instead of being interpreted by the browser.
 */
function predictImage() {
    if (!selectedImage) {
        alert('请先选择一张图像');
        return;
    }

    const resultArea = document.getElementById('resultArea');

    // Replace the result area with one element per [tagName, text] pair.
    const render = (entries) => {
        resultArea.innerHTML = '';
        for (const [tagName, text] of entries) {
            const node = document.createElement(tagName);
            node.textContent = text;
            resultArea.appendChild(node);
        }
    };

    const formData = new FormData();
    formData.append('image', selectedImage);

    render([['p', '正在预测,请稍候...']]);

    fetch('/predict', {
        method: 'POST',
        body: formData
    })
        .then(response => response.json())
        .then(data => {
            if (data.error) {
                render([['p', `错误: ${data.error}`]]);
            } else {
                render([
                    ['h3', '预测结果'],
                    ['p', `类别: ${data.class}`],
                    ['p', `置信度: ${(data.confidence * 100).toFixed(2)}%`]
                ]);
            }
        })
        .catch(error => {
            // Network failure or a non-JSON response body.
            render([['p', `错误: ${error.message}`]]);
        });
}
</script>
</body>
</html>
# Example training script
if __name__ == "__main__":
    # NOTE(review): this script writes 'base_model.h5' and
    # 'fine_tuned_model.h5', while load_model_and_classes reads 'model.h5' --
    # confirm which file the API is meant to serve.

    # Stage 0: data pipelines and a model with a frozen backbone.
    train_generator, val_generator = create_data_generators('data/train', 'data/val')
    model = build_model(IMAGE_SIZE + (3,), train_generator.num_classes, BASE_MODEL)

    # Stage 1: train the new classification head.
    print("开始训练基础模型...")
    history, model = train_model(
        model,
        train_generator,
        val_generator,
        epochs=EPOCHS,
        model_path='base_model.h5',
    )

    # Stage 2: unfreeze the last layers and continue for half as many epochs.
    print("开始微调模型...")
    model = fine_tune_model(model)
    history, model = train_model(
        model,
        train_generator,
        val_generator,
        epochs=EPOCHS // 2,
        model_path='fine_tuned_model.h5',
    )

    # Save the class names, one per line, in class_indices key order.
    class_names = list(train_generator.class_indices)
    with open('class_names.txt', 'w') as f:
        f.write('\n'.join(class_names))

    print("模型训练完成!")
难点分析
- 数据预处理:设计合理的图像增强和预处理策略
- 模型选择与调优:选择合适的预训练模型并进行有效微调
- 计算资源优化:在有限资源下高效训练大型模型
- API 设计:设计稳定可靠的图像识别 API 接口
- 前端交互:实现流畅的图像上传和结果展示界面
扩展方向
- 添加多标签分类支持(一张图像可同时属于多个类别; 当前模型已是单标签的多类别分类)
- 实现目标检测功能
- 添加模型解释和可视化
- 集成摄像头实时识别
- 部署到云服务平台