BERT的中文问答系统18

发布于:2024-10-16 ⋅ 阅读:(143) ⋅ 点赞:(0)
import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import re
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Directory containing this script; every data/model/log path is built from it.
PROJECT_ROOT = os.path.split(os.path.abspath(__file__))[0]

# Root directory for log output, created eagerly at import time so that
# setup_logging() can write its per-run files beneath it.
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)

def setup_logging():
    """Configure root logging to a per-run file and to the console.

    The log file lives at ``LOGS_DIR/<YYYY-MM-DD>/<HH-MM-SS>/羲和.txt``; its
    parent directories are created on demand. The file is written as UTF-8 so
    the Chinese log messages are stored correctly regardless of the platform's
    default locale encoding.

    Note: ``logging.basicConfig`` does nothing if the root logger already has
    handlers, so only the first call in a process takes effect.
    """
    now = datetime.now()
    log_file = os.path.join(
        LOGS_DIR,
        now.strftime('%Y-%m-%d'),
        now.strftime('%H-%M-%S'),
        '羲和.txt',
    )
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # Explicit encoding: the default would be the locale encoding.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),
        ],
    )

# Configure root logging (file + console handlers) once, at import time.
setup_logging()

# Text normalization helper shared by training and inference.
def clean_text(text):
    r"""Return *text* with punctuation/symbols removed and letters lowercased.

    Every character that is neither a word character (``\w``) nor whitespace
    (``\s``) is dropped first; the remainder is then lowercased. Because
    ``\w`` is Unicode-aware for ``str`` patterns, CJK characters survive.
    """
    stripped = re.sub(r'[^\w\s]', '', text)
    return stripped.lower()

# Dataset of (question, xihe answer, ling answer) records.
class XihuaDataset(Dataset):
    """Torch dataset over question / Xihe-answer / Ling-answer records.

    Each record must be a dict containing the keys ``question``,
    ``xihe_answers`` and ``ling_answers``. The answer fields are indexed with
    ``[0]`` below; presumably they are lists of strings — TODO confirm
    against the data files.
    """

    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        """Load and validate records from a ``.jsonl`` or ``.json`` file.

        ``.jsonl``: one JSON object per line. Blank lines are skipped, and
        lines that fail to parse are logged and skipped instead of aborting
        the whole load.
        ``.json``: a single top-level list of records.
        Files are decoded as UTF-8; a leading BOM is tolerated.
        Any other extension yields an empty list (with a warning).

        Raises:
            OSError: if the file cannot be opened.
        """
        data = []
        if file_path.endswith('.jsonl'):
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                for line_no, line in enumerate(f, start=1):
                    if not line.strip():
                        continue
                    # Parse per line so a single bad line is skipped, not fatal.
                    try:
                        item = json.loads(line)
                    except json.JSONDecodeError as e:
                        logging.warning(f"跳过无效行 {line_no}: {e}")
                        continue
                    if self.validate_item(item):
                        data.append(item)
        elif file_path.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                try:
                    items = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
                    items = []
            if not isinstance(items, list):
                logging.warning(f"跳过无效文件 {file_path}: 顶层不是列表")
                items = []
            data = [item for item in items if self.validate_item(item)]
        else:
            logging.warning(f"不支持的文件类型: {file_path}")
        return data

    def validate_item(self, item):
        """Return True if *item* is a dict holding all required keys."""
        required_keys = ['question', 'xihe_answers', 'ling_answers']
        if isinstance(item, dict) and all(key in item for key in required_keys):
            return True
        logging.warning(f"跳过无效项: 缺少必要键 {required_keys}")
        return False

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _extract_texts(item):
        """Return cleaned (question, xihe_answer, ling_answer) strings for *item*.

        When one side has no answers it borrows the other's; if both are
        empty, both answers become "".
        """
        question = clean_text(item['question'])
        xihe_answers = item.get('xihe_answers', [])
        ling_answers = item.get('ling_answers', [])
        if not xihe_answers and ling_answers:
            xihe_answers = ling_answers
        elif not ling_answers and xihe_answers:
            ling_answers = xihe_answers
        xihe_answer = clean_text(xihe_answers[0]) if xihe_answers else ""
        ling_answer = clean_text(ling_answers[0]) if ling_answers else ""
        return question, xihe_answer, ling_answer

    def _tokenize(self, text):
        """Tokenize *text*, padded/truncated to ``self.max_length``, as PyTorch tensors."""
        return self.tokenizer(text, return_tensors='pt', padding='max_length',
                              truncation=True, max_length=self.max_length)

    def __getitem__(self, idx):
        """Return the tokenized sample at *idx*.

        If tokenization fails, the following items are tried in turn
        (wrapping around). After one full pass with no success, a
        RuntimeError is raised. This bounded loop replaces the previous
        unbounded recursion.

        Raises:
            IndexError: for an out-of-range *idx*. This also terminates
                plain ``for x in dataset`` iteration.
            RuntimeError: if no item in the dataset can be tokenized.
        """
        item = self.data[idx]
        num_items = len(self.data)
        for _ in range(num_items):
            question, xihe_answer, ling_answer = self._extract_texts(item)
            try:
                inputs = self._tokenize(question)
                xihe_inputs = self._tokenize(xihe_answer)
                ling_inputs = self._tokenize(ling_answer)
            except Exception as e:
                logging.warning(f"跳过无效项 {idx}: {e}")
                idx = (idx + 1) % num_items
                item = self.data[idx]
                continue

            # squeeze(0) drops only the batch dim added by return_tensors='pt',
            # so the sequence dim survives even when max_length == 1.
            return {
                'input_ids': inputs['input_ids'].squeeze(0),
                'attention_mask': inputs['attention_mask'].squeeze(0),
                'xihe_input_ids': xihe_inputs['input_ids'].squeeze(0),
                'xihe_attention_mask': xihe_inputs['attention_mask'].squeeze(0),
                'ling_input_ids': ling_inputs['input_ids'].squeeze(0),
                'ling_attention_mask': ling_inputs['attention_mask'].squeeze(0),
                'xihe_answer': xihe_answer,
                'ling_answer': ling_answer
            }
        raise RuntimeError("no item in the dataset could be tokenized")

# DataLoader factory for training.
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    """Wrap an XihuaDataset read from *file_path* in a shuffling DataLoader."""
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    return loader

# Model: BERT encoder plus a single-logit linear head.
class XihuaModel(torch.nn.Module):
    """BERT encoder whose ``pooler_output`` feeds a one-unit linear classifier."""

    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        hidden_size = self.bert.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return raw (pre-sigmoid) logits; training pairs them with BCEWithLogitsLoss."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(encoded.pooler_output)

# One training epoch.
def train(model, data_loader, optimizer, criterion, device, scheduler=None, scaler=None):
    """Train *model* for one pass over *data_loader*.

    Xihe answers are labelled 1 and Ling answers 0; the loss is the sum of
    *criterion* on both halves. Batches that raise are logged and skipped.

    NOTE(review): only the answer tensors (``xihe_*`` / ``ling_*``) are fed
    to the model here; the question tensors in the batch are unused.
    Inference (``XihuaChatbotGUI.get_answer``) classifies the question text
    instead — confirm this train/inference mismatch is intended.

    Args:
        scheduler: optional LR scheduler. It is stepped once per call with the
            epoch's mean loss, matching ``ReduceLROnPlateau``'s per-epoch
            usage; the previous per-batch stepping reacted to batch noise.
        scaler: optional ``GradScaler``. When None, a plain fp32
            backward/step is used. Previously None silently skipped every
            batch.

    Returns:
        ``(mean_loss, losses)``: the mean over successfully processed batches
        (0.0 if none succeeded) and the list of per-batch losses.
    """
    model.train()
    # torch.cuda.amp autocast only helps on CUDA; on CPU it would just warn.
    use_amp = torch.device(device).type == 'cuda'
    losses = []
    for batch in data_loader:
        try:
            xihe_input_ids = batch['xihe_input_ids'].to(device)
            xihe_attention_mask = batch['xihe_attention_mask'].to(device)
            ling_input_ids = batch['ling_input_ids'].to(device)
            ling_attention_mask = batch['ling_attention_mask'].to(device)

            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                xihe_logits = model(xihe_input_ids, xihe_attention_mask)
                ling_logits = model(ling_input_ids, ling_attention_mask)

                xihe_labels = torch.ones(xihe_logits.shape, device=xihe_logits.device)
                ling_labels = torch.zeros(ling_logits.shape, device=ling_logits.device)

                loss = criterion(xihe_logits, xihe_labels) + criterion(ling_logits, ling_labels)

            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            losses.append(loss.item())
        except Exception as e:
            logging.warning(f"跳过无效批次: {e}")

    mean_loss = sum(losses) / len(losses) if losses else 0.0
    # Only step on a real measurement; a fake 0.0 would poison the plateau tracker.
    if scheduler is not None and losses:
        scheduler.step(mean_loss)
    return mean_loss, losses

# Headless training entry point (not called from the visible GUI code).
def main_train(retrain=False):
    """Train XihuaModel on ``data/train_data.jsonl`` and save it to ``models/xihua_model.pth``.

    Args:
        retrain: when True, start from the previously saved weights
            instead of the pretrained BERT checkpoint.

    NOTE(review): plot_losses() is called once per epoch and draws into a Tk
    root window; confirm a GUI is available when this function is used.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device: {device}')

    model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')

    tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
    model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)

    if retrain:
        model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))

    optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)
    criterion = torch.nn.BCEWithLogitsLoss()
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1, verbose=True)
    scaler = GradScaler()

    train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=8, max_length=128)

    num_epochs = 3
    for epoch in range(num_epochs):
        train_loss, losses = train(model, train_data_loader, optimizer, criterion, device, scheduler, scaler)
        logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.8f}')
        plot_losses(losses)

    # The models/ directory may not exist on a fresh checkout; without this
    # the save would fail only after all epochs had run.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.state_dict(), model_path)
    logging.info("模型训练完成并保存")

# Per-batch loss chart embedded in the Tk window.
def plot_losses(losses, master=None):
    """Draw *losses* (one value per batch) as a line chart inside a Tk container.

    Args:
        losses: sequence of per-batch loss values.
        master: Tk widget to pack the chart into. Defaults to the module-level
            ``root`` created in the ``__main__`` block. When no such window
            exists, the plot is skipped with a warning instead of raising
            ``NameError``.

    Any chart previously drawn into the same *master* is destroyed first, so
    repeated calls (one per epoch) replace the chart rather than stacking new
    ones. A standalone ``Figure`` is used instead of ``plt.subplots`` so that
    pyplot does not accumulate open figures.
    """
    if master is None:
        master = globals().get('root')
    if master is None:
        logging.warning("没有可用的 Tk 根窗口,跳过损失图绘制")
        return None

    # Imported lazily: only needed once we know there is a window to draw into.
    from matplotlib.figure import Figure

    previous = getattr(master, '_xihua_loss_canvas', None)
    if previous is not None:
        previous.get_tk_widget().destroy()

    fig = Figure()
    ax = fig.add_subplot(111)
    ax.plot(losses)
    ax.set_xlabel('Batch')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss')
    canvas = FigureCanvasTkAgg(fig, master=master)
    canvas.draw()
    canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)
    master._xihua_loss_canvas = canvas
    return None

# Tkinter GUI
class XihuaChatbotGUI:
    """Tkinter front-end for the Xihe / Ling chatbot.

    The sign of the model's logit on the (cleaned) question picks the
    answer type: positive means "羲和回答", otherwise "零回答". The concrete
    reply is then taken from the training record whose question best
    fuzzy-matches the input.
    """

    def __init__(self, root):
        self.root = root
        self.root.title("羲和聊天机器人")

        self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)

        self.history = []  # (question, answer) pairs shown in the history pane

        # Build widgets first: train_model() below reads train_mode_var and
        # updates status_label / progress, which must already exist.
        self.create_widgets()

        # Training records, used by get_specific_answer() for fuzzy matching.
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))

        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if not os.path.exists(model_path):
            messagebox.showinfo("模型未找到", "未找到现有模型,将开始训练新的模型")
            self.train_model()
        else:
            self.load_model()
        # Inference mode (disables dropout) whichever branch ran.
        self.model.eval()

    def create_widgets(self):
        """Build the question row, answer/history panes, training controls, progress bar and status line."""
        # Question entry row
        self.question_frame = tk.Frame(self.root)
        self.question_frame.pack(pady=10)

        self.question_label = tk.Label(self.question_frame, text="问题:")
        self.question_label.pack(side=tk.LEFT)

        self.question_entry = tk.Entry(self.question_frame, width=50)
        self.question_entry.pack(side=tk.LEFT)

        self.answer_button = tk.Button(self.question_frame, text="获取回答", command=self.get_answer)
        self.answer_button.pack(side=tk.LEFT)

        # Answer pane
        self.answer_frame = tk.Frame(self.root)
        self.answer_frame.pack(pady=10)

        self.answer_label = tk.Label(self.answer_frame, text="回答:")
        self.answer_label.pack()

        self.answer_text = tk.Text(self.answer_frame, height=10, width=50)
        self.answer_text.pack()

        # History pane
        self.history_frame = tk.Frame(self.root)
        self.history_frame.pack(pady=10)

        self.history_label = tk.Label(self.history_frame, text="历史记录:")
        self.history_label.pack()

        self.history_text = tk.Text(self.history_frame, height=10, width=50)
        self.history_text.pack()

        # Training controls
        self.train_mode_frame = tk.Frame(self.root)
        self.train_mode_frame.pack(pady=10)

        self.train_mode_var = tk.BooleanVar()
        self.train_mode_checkbutton = tk.Checkbutton(self.train_mode_frame, text="继续训练现有模型", variable=self.train_mode_var)
        self.train_mode_checkbutton.pack(side=tk.LEFT)

        self.train_button = tk.Button(self.train_mode_frame, text="训练模型", command=self.train_model)
        self.train_button.pack(side=tk.LEFT)

        self.retrain_button = tk.Button(self.train_mode_frame, text="重新训练模型", command=lambda: self.train_model(retrain=True))
        self.retrain_button.pack(side=tk.LEFT)

        # Icons appended after each answer
        self.xihe_icon = tk.PhotoImage(file=os.path.join(PROJECT_ROOT, 'icons/xihe.png'))
        self.ling_icon = tk.PhotoImage(file=os.path.join(PROJECT_ROOT, 'icons/ling.png'))

        # Training progress bar (0-100, one step per epoch)
        self.progress = ttk.Progressbar(self.root, orient='horizontal', mode='determinate')
        self.progress.pack(fill=tk.X, padx=10, pady=10)

        # Status line
        self.status_label = tk.Label(self.root, text="")
        self.status_label.pack()

    def get_answer(self):
        """Classify the entered question, show the matching answer with its icon, and record it in history."""
        question = self.question_entry.get()
        # Whitespace-only input is treated the same as empty input.
        if not question.strip():
            messagebox.showwarning("输入错误", "请输入问题")
            return

        inputs = self.tokenizer(clean_text(question), return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)

        # Positive logit -> Xihe (label 1 in train()); otherwise Ling (label 0).
        if logits.item() > 0:
            answer_type = "羲和回答"
            icon = self.xihe_icon
        else:
            answer_type = "零回答"
            icon = self.ling_icon

        specific_answer = self.get_specific_answer(question, answer_type)

        self.answer_text.delete(1.0, tk.END)
        self.answer_text.insert(tk.END, f"{answer_type}\n{specific_answer}")
        self.answer_text.image_create(tk.END, image=icon)

        # Record the exchange in history
        self.history.append((question, specific_answer))
        self.update_history()

    def update_history(self):
        """Re-render the history pane from self.history."""
        self.history_text.delete(1.0, tk.END)
        for q, a in self.history:
            self.history_text.insert(tk.END, f"问题: {q}\n回答: {a}\n\n")

    def get_specific_answer(self, question, answer_type):
        """Return an answer from the training record whose question best matches *question*.

        Similarity is ``difflib.SequenceMatcher.ratio`` on cleaned text. The
        preferred side (Xihe or Ling, per *answer_type*) is used when it has
        answers; otherwise the other side's answers are used. When there is no
        match, or the matched record has no answers at all, a fixed fallback
        reply is returned. Previously the no-answers case raised IndexError.
        """
        fallback = "这个我也不清楚,你问问零吧"
        # Clean the query once instead of once per record.
        cleaned_question = clean_text(question)
        best_match = None
        best_ratio = 0.0
        for item in self.data:
            ratio = SequenceMatcher(None, cleaned_question, clean_text(item['question'])).ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_match = item

        if best_match is None:
            return fallback

        xihe_answers = best_match.get('xihe_answers') or []
        ling_answers = best_match.get('ling_answers') or []
        if answer_type == "羲和回答":
            candidates = xihe_answers or ling_answers
        else:
            candidates = ling_answers or xihe_answers
        return candidates[0] if candidates else fallback

    def load_data(self, file_path):
        """Load and validate records from a ``.jsonl`` or ``.json`` file.

        Missing files and unsupported extensions yield an empty list with a
        warning, so the GUI can still start. ``.jsonl`` lines that fail to
        parse are logged and skipped instead of aborting the load, and blank
        lines are ignored. Files are decoded as UTF-8 (BOM tolerated).
        """
        if not os.path.exists(file_path):
            logging.warning(f"数据文件不存在: {file_path}")
            return []

        data = []
        if file_path.endswith('.jsonl'):
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                for line_no, line in enumerate(f, start=1):
                    if not line.strip():
                        continue
                    try:
                        item = json.loads(line)
                    except json.JSONDecodeError as e:
                        logging.warning(f"跳过无效行 {line_no}: {e}")
                        continue
                    if self.validate_item(item):
                        data.append(item)
        elif file_path.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                try:
                    items = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
                    items = []
            if not isinstance(items, list):
                logging.warning(f"跳过无效文件 {file_path}: 顶层不是列表")
                items = []
            data = [item for item in items if self.validate_item(item)]
        else:
            logging.warning(f"不支持的文件类型: {file_path}")
        return data

    def validate_item(self, item):
        """Return True if *item* is a dict holding all required keys."""
        required_keys = ['question', 'xihe_answers', 'ling_answers']
        if isinstance(item, dict) and all(key in item for key in required_keys):
            return True
        logging.warning(f"跳过无效项: 缺少必要键 {required_keys}")
        return False

    def load_model(self):
        """Load saved weights into self.model if the checkpoint exists; report failures in a dialog."""
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if os.path.exists(model_path):
            try:
                self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=True))
                logging.info("加载现有模型")
            except Exception as e:
                logging.error(f"加载模型失败: {e}")
                messagebox.showerror("加载失败", f"加载模型失败: {e}")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    def train_model(self, retrain=False):
        """Train on a user-selected data file for 3 epochs and save the weights.

        Args:
            retrain: when True (or when the "continue training" checkbox is
                ticked), first reload the saved checkpoint.

        The model is always returned to eval mode afterwards, even on
        failure, so later get_answer() calls run without dropout.
        """
        file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
        if not file_path:
            messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")
            return

        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        self.progress['value'] = 0  # a previous run may have left it at 100
        try:
            dataset = XihuaDataset(file_path, self.tokenizer)
            if len(dataset) == 0:
                messagebox.showwarning("数据错误", "所选文件中没有有效的训练数据")
                self.status_label.config(text="训练失败")
                return
            data_loader = DataLoader(dataset, batch_size=8, shuffle=True)

            # Optionally continue from the saved weights.
            if retrain or self.train_mode_var.get():
                self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=True))
                self.model.to(self.device)

            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5, weight_decay=1e-5)
            criterion = torch.nn.BCEWithLogitsLoss()
            scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1, verbose=True)
            scaler = GradScaler()
            num_epochs = 3
            for epoch in range(num_epochs):
                self.status_label.config(text=f"正在训练 Epoch {epoch+1}/{num_epochs}")
                self.root.update_idletasks()
                train_loss, losses = train(self.model, data_loader, optimizer, criterion, self.device, scheduler, scaler)
                logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')
                plot_losses(losses)
                self.progress['value'] = (epoch + 1) / num_epochs * 100
                self.root.update_idletasks()

            # models/ may not exist yet on a fresh checkout.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            torch.save(self.model.state_dict(), model_path)
            logging.info("模型训练完成并保存")
            messagebox.showinfo("训练完成", "模型训练完成并保存")
            self.status_label.config(text="训练完成")
        except Exception as e:
            logging.error(f"模型训练失败: {e}")
            messagebox.showerror("训练失败", f"模型训练失败: {e}")
            self.status_label.config(text="训练失败")
        finally:
            self.model.eval()

# Entry point: launch the GUI.
if __name__ == "__main__":
    # `root` is deliberately a module-level global: plot_losses() packs its
    # loss chart into this window.
    root = tk.Tk()
    app = XihuaChatbotGUI(root)
    root.mainloop()

优化点
数据预处理:

文本清洗: 去除特殊字符、统一文本格式、转换为小写等。
数据增强: 通过同义词替换、句子重组等方式增加数据多样性(这部分可以在后续版本中实现)。
数据标准化: 统一文本长度、词汇表等。
模型训练:

学习率调度: 使用 ReduceLROnPlateau 策略,当监控的损失不再下降时降低学习率(当前代码没有验证集,监控的是训练损失)。
早停机制: 当验证损失不再下降时提前停止训练(可以通过添加验证集来实现)。
正则化: 通过 Adam 优化器的 weight_decay=1e-5(即 L2 正则化)防止过拟合;L1 正则化尚未实现。
混合精度训练: 使用 torch.cuda.amp(autocast 与 GradScaler)提高训练速度和效率;仅在 CUDA 设备上生效。
GUI界面:

历史记录: 记录用户的输入和模型的回答,方便用户查看历史对话。
多语言支持: 目前代码主要支持中文,可以在后续版本中增加多语言支持。
现代GUI库: 考虑使用更现代的GUI库,如 PyQt 或 Kivy,以提高界面的美观性和易用性。
进一步优化建议
数据增强:

同义词替换: 使用同义词词典或词向量模型进行同义词替换。
句子重组: 使用句法树或序列到序列模型进行句子重组。
多语言支持:

多语言模型: 使用多语言预训练模型,如 bert-base-multilingual-cased。
语言检测: 使用语言检测模型,自动识别用户输入的语言并选择合适的模型进行处理。
现代GUI库:

PyQt: 使用 PyQt 库构建更现代、更美观的用户界面。
Kivy: 使用 Kivy 库构建跨平台的用户界面。
希望这些优化和建议对你有帮助!如果有任何问题或需要进一步的解释,请随时提问。


网站公告

今日签到

点亮在社区的每一天
去签到