15 - 多模态大语言模型 — 图文 “牵线” 系统 “成长记”:借 CLIP 练本领,从图像与文字里精准 “搭鹊桥” 的全过程 (呆瓜版 - 2 号)

发布于:2025-08-01 ⋅ 阅读:(23) ⋅ 点赞:(0)

目录

1、基础:它到底是个啥?

1. 1、一句话理解核心

1.2、 为啥厉害?

1.3、怎么发展来的?

2、架构:它的 “身体构造” 是啥样的?

2.1、视觉语言模型架构:让 AI “看懂” 世界的核心系统

2.1.1、双塔模型(如 CLIP)   

2.1.2、交叉注意力模型(如 BLIP-2)  

2.1.3、端到端模型(如 Flamingo)

2.1.4、轻量级模型(如 Fuyu-8B)

2.2、语音语言模型架构:让 AI “听懂” 声音的核心系统

2.2.1、语音特征提取(MFCC)

2.2.2、序列对齐(CTC 损失)

2.2.3、端到端模型(如 Whisper)

2.3、多模态大语言模型架构:让 AI “感知” 世界的超级系统

2.3.1、模态编码器

2.3.2、连接器(Connector)

2.3.3、大型语言模型(LLM)

3、多模态架构的核心公式与训练流程

3.1、跨模态对齐公式

3.2、训练流程

3.3、训练:怎么让它变聪明的?

4、架构对比与选择指南

5、应用:它能帮我们做啥?

6、多模态LLM模型(图像-文本生成)(简化版)

7、多模态LLM模型(图像-文本生成+问答系统)(简化版)


1、基础:它到底是个啥?

1. 1、一句话理解核心

普通大模型(比如 ChatGPT)只能处理文字,而多模态大语言模型(简称 “多模态 LLM”)能同时 “看懂图、听懂声、读得懂字”,还能用文字回答你所有问题。比如你给它一张电路图,它能直接告诉你 “这里接反了会短路”;给它一段机器运转的声音,它能说 “轴承快坏了,得换”。

1.2、 为啥厉害?

以前的 AI 是 “偏科生”:有的只能看图(比如识别图片里的猫),有的只能处理文字(比如写作文),但多模态 LLM 是 “全能选手”—— 它用语言把所有信息打通了。就像人既会看路标(图像),又会读路牌(文字),还能跟人打听路(语言),最后找到目的地,而不是只认其中一种。

1.3、怎么发展来的?

  • 先有 “文字学霸”:比如 GPT-3、Llama,只会处理文字,逻辑推理超强但 “看不见东西”。
  • 再加上 “图像 / 声音翻译官”:比如 CLIP 能把图片转成文字能懂的 “密码”,让文字学霸能 “间接看图”。
  • 最后合体:把 “翻译官” 和 “文字学霸” 绑在一起,就成了多模态 LLM,比如 GPT-4V、Llava 这些。

2、架构:它的 “身体构造” 是啥样的?

2.1、视觉语言模型架构:让 AI “看懂” 世界的核心系统

视觉语言模型(如 CLIP、BLIP-2)的核心是将图像和文字映射到同一语义空间,实现跨模态理解。其架构通常包含三个模块:

2.1.1、双塔模型(如 CLIP)   

  • 架构原理: 独立的图像编码器(如 ResNet)和文本编码器(如 Transformer)分别处理图片和文字,通过对比学习将两者特征投影到同一向量空间。
    • 关键公式:对比损失函数L_{\text{CLIP}} = -\frac{1}{2N}\sum_{i=1}^N \left[ \log\frac{e^{\text{sim}(v_i,t_i)/\tau}}{\sum_{j=1}^N e^{\text{sim}(v_i,t_j)/\tau}} + \log\frac{e^{\text{sim}(v_i,t_i)/\tau}}{\sum_{j=1}^N e^{\text{sim}(v_j,t_i)/\tau}} \right]

    • 其中,\text{sim}为余弦相似度,\tau是温度参数,控制相似度分布的平滑程度。

  • 训练步骤
    1. 输入图片和对应文本,分别编码为特征向量v_it_i
    2. 计算所有图片 - 文本对的相似度矩阵,最大化正确对的相似度,最小化错误对的相似度。
  • 应用场景:图片检索(如从百万张图中找出 “戴红帽子的猫”)

2.1.2、交叉注意力模型(如 BLIP-2)  

  • 架构原理: 引入Query-Former 模块,通过跨模态注意力机制让图像和文本特征直接交互。
    • 关键公式:跨模态注意力\text{Attention}(Q,K,V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V其中,Q 来自文本,K 和 V 来自图像,通过多头注意力实现深度融合。
  • 训练步骤
    1. 冻结图像编码器(如 CLIP 的 ViT),仅微调 Query-Former 和 LLM。
    2. 输入图像和文本,Query-Former 生成融合特征,LLM 生成回答(如 “图中猫在做什么”)。
  • 应用场景:视觉问答(VQA)、图文生成(如根据图片写故事)。

2.1.3、端到端模型(如 Flamingo)

  • 架构原理: 冻结视觉编码器,仅微调语言模型(如 Llama),通过视觉提示引导语言模型生成答案。
    • 关键设计
      • 视觉特征直接输入 LLM 的 Transformer 层,无需独立编码器。
      • 采用 “视觉 token”(如<image>标签)标记输入中的图像部分。
  • 训练步骤
    1. 预训练阶段:在海量图文数据上对齐视觉和语言特征。
    2. 微调阶段:针对特定任务(如医学影像分析),用领域数据训练语言模型。
  • 应用场景:实时图像标注(如直播中自动生成字幕)。

2.1.4、轻量级模型(如 Fuyu-8B)

  • 架构原理: 摒弃传统图像编码器,直接将图像分块后通过线性投影输入 Transformer 解码器。
    • 关键公式:图像分块投影\text{patch}_i = \text{Linear}(x_{i \times patch\_size}) 其中,x是原始图像,\text{patch}_i是第 i 个图像块的特征向量。
  • 训练步骤
    1. 将图像切分为 16x16 的小块,每个块线性投影到文本模型的维度。
    2. 与文本 token 混合输入解码器,联合训练生成回答。
  • 应用场景:边缘设备(如手机)的实时图像问答。

2.2、语音语言模型架构:让 AI “听懂” 声音的核心系统

语音语言模型(如 Whisper、DeepSpeech)的核心是将语音信号转化为文字序列,其架构通常包含三个模块:

2.2.1、语音特征提取(MFCC)

  • 步骤解析
    1. 预加重:提升高频信号,公式为y[n] = x[n] - \mu x[n-1] , \mu \approx 0.97
    2. 分帧加窗:将语音切分为 20-30ms 的帧,加汉明窗减少边界效应。
    3. FFT 变换:将时域信号转为频域,得到功率谱。
    4. 梅尔滤波:通过三角形滤波器组提取人耳敏感的频率特征。
    5. DCT 变换:将梅尔谱转换为倒谱系数(MFCC),去除冗余信息。
  • 输出结果:每帧生成 12-16 维 MFCC 特征,叠加能量、一阶 / 二阶差分,共 40 维左右。

2.2.2、序列对齐(CTC 损失)

  • 架构原理: 解决语音和文本的时序不对齐问题,通过动态规划计算路径概率。
    • 关键公式:CTC 损失函数L = -\log \sum_{\pi \in \text{Align}(y)} \prod_{t=1}^T p_t(\pi_t) 其中,\pi是对齐路径,p_t(\pi_t)是时刻 t 输出字符\pi_t的概率。
  • 训练步骤
    1. 输入 MFCC 特征序列,通过 RNN 或 CNN 生成预测概率矩阵。
    2. 使用 CTC 算法计算所有可能对齐路径的概率之和,最大化正确路径的概率。

2.2.3、端到端模型(如 Whisper)

  • 架构原理: 基于 Transformer 的编码器 - 解码器架构,直接输入音频波形生成文本。
    • 关键设计
      • 编码器:将 30 秒音频转为 80 维 log-Mel 频谱,输入多层 Transformer。
      • 解码器:在文本生成时引入交叉注意力,融合音频编码和历史文本。
  • 训练步骤
    1. 预训练:在 68 万小时多语言音频上训练,支持 99 种语言。
    2. 微调:针对特定领域(如医疗)优化转录准确率。
  • 应用场景:实时语音转写(如会议记录)、跨语言翻译(如法语→英语)。

2.3、多模态大语言模型架构:让 AI “感知” 世界的超级系统

多模态大语言模型(如 GPT-4V、Llama 4 Maverick)的核心是整合视觉、语音、文本多模态信息,实现复杂推理。其架构通常包含四个模块:

2.3.1、模态编码器

  • 功能:将图像、语音等非文本信息转化为特征向量。
  • 技术方案
    • 图像:CLIP、Swin Transformer(如 GPT-4V)。
    • 语音:MFCC+Transformer(如 Whisper)。
    • 文本:Llama、Qwen(如 Qwen-VL)。

2.3.2、连接器(Connector)

  • 功能:统一不同模态的特征格式,便于 LLM 处理。
  • 技术方案
    • 线性投影:将图像 / 语音特征调整为与文本 token 相同维度(如 Fuyu-8B)。
    • 跨模态注意力:在 Transformer 层引入图像 - 文本交互(如 BLIP-2)。

2.3.3、大型语言模型(LLM)

  • 功能:作为 “大脑” 进行跨模态推理和生成。
  • 技术方案
    • 参数量:通常为 7B-400B(如 Llama 4 Maverick 的 400B 参数)。
    • 架构:混合专家(MoE)、稀疏注意力(如 DeepSeek-V3)。
  1. 生成器(可选)

    • 功能:输出非文本模态(如图像、视频)。
    • 技术方案
      • 图像生成:扩散模型(如 Stable Diffusion),基于 LLM 输出的文本描述生成图片。
      • 视频生成:Transformer + 时空注意力,生成连贯视频序列。

3、多模态架构的核心公式与训练流程

3.1、跨模态对齐公式

  • 对比学习(CLIP):
  • 掩码语言建模

3.2、训练流程

  1. 预训练阶段
    • 多模态数据构建:爬取图文对、语音 - 文本对(如 SBU 数据集的 50 万图文对)。
    • 特征对齐:通过对比学习或掩码建模,让模型理解跨模态关联。
  2. 微调阶段
    • 领域数据注入:如医疗影像 + 诊断报告,提升特定任务准确率(如 BakLLaVA-1 的 92% 诊断率)。
    • 指令微调:设计多模态指令(如 “根据 X 光片诊断肺炎风险”),引导模型生成符合人类逻辑的回答。
  3. 优化技术
    • 混合专家(MoE):减少训练成本,如 Llama 4 通过 MoE 实现 400B 参数高效训练。
    • 模型量化:将参数压缩至 4-bit/8-bit,如 Llama 4 Scout 支持单卡部署。

3.3、训练:怎么让它变聪明的?

就像教一个小孩 “认识世界”,分三步:

1. 先学 “基础知识”(预训练)
给它喂海量 “图文配对” 的资料:比如 “猫的图片 +‘这是一只猫’”“汽车图片 +‘四个轮子的交通工具’”。
目的是让它知道 “图片里的内容和文字说的是一回事”,就像小孩看绘本,把图画和文字对应起来。

2. 再练 “具体技能”(微调)
针对具体任务 “补课”:比如想让它看懂 X 光片,就专门喂 “X 光片 + 医生诊断文字” 的资料;想让它讲题,就喂 “数学题图片 + 解题步骤”。
这一步就像学生上完基础课,再去学 “物理、化学” 等专业课。

3. 关键技巧:让它 “不瞎猜”
训练时故意 “藏起一部分信息” 让它猜:比如盖住图片的一半让它补全,或者遮住文字的几个字让它填。这样能逼它更认真地 “看” 和 “想”,减少胡说八道(专业叫 “减少幻觉”)。

4、架构对比与选择指南

架构类型 代表模型 核心优势 适用场景 参数量范围
双塔模型 CLIP 轻量、高检索效率 图片 / 文本匹配 400M-10B
交叉注意力 BLIP-2 复杂推理、多模态生成 视觉问答、图文生成 13B-65B
端到端 Flamingo 高效适配、低延迟 实时交互、边缘设备 7B-30B
混合专家(MoE) Llama 4 Maverick 高性能、稀疏计算 科学研究、工业级推理 100B-400B
轻量级 Fuyu-8B 低功耗、单卡部署 手机、物联网设备 8B-16B

5、应用:它能帮我们做啥?

生活里到处都能用,举几个接地气的例子:

  • 看病:给医生当助手,拍张 X 光片,它能立刻标出 “这里可能有炎症”,再结合病历文字,提醒医生重点检查。
  • 学习:学生拍一张数学题图片,它不光给答案,还能用文字讲 “第一步为什么要这么算”,比课本好懂。
  • 干活:工厂里拍张零件照片,它能说 “这个螺丝松了,会导致机器异响”,工人不用自己盯着看半天。
  • 日常:旅游时拍张外语路标,它能翻译文字,还能告诉你 “往前走 300 米有地铁站”(结合图片里的箭头)。

特别说明:训练度和数据集不够,结果存在问题,主要用于理解知识

6、多模态LLM模型(图像-文本生成)(简化版)

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import GPT2LMHeadModel, GPT2Tokenizer, get_linear_schedule_with_warmup
import os
from PIL import Image
import json
import requests
from io import BytesIO
import random
import numpy as np
from tqdm import tqdm
import warnings
import matplotlib.pyplot as plt  # 用于绘图
from matplotlib.table import Table  # 用于生成对比表格
import seaborn as sns  # 美化图表
sns.set_style("whitegrid")

# ---------------------------- 核心修复1:设置Matplotlib支持中文显示 ----------------------------
plt.rcParams["font.family"] = ["SimHei"]  # 支持中文的字体
plt.rcParams["axes.unicode_minus"] = False  # 解决负号显示问题


# ---------------------------- 消除其他警告配置 ----------------------------
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
warnings.filterwarnings("ignore", category=UserWarning, message="The parameter 'pretrained' is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, message="Arguments other than a weight enum or `None` for 'weights' are deprecated")


# ---------------------------- 数据集类定义(平衡样本分布) ----------------------------
class BalancedDemoDataset(Dataset):
    """平衡的演示数据集(确保各类别样本数量均等)"""
    def __init__(self, img_size=224, max_text_length=512):
        self.categories = [
            {"name": "cat", "url": "https://picsum.photos/seed/cat1/500/300"},
            {"name": "dog", "url": "https://picsum.photos/seed/dog1/500/300"},
            {"name": "bird", "url": "https://picsum.photos/seed/bird1/500/300"},
            {"name": "city", "url": "https://picsum.photos/seed/city1/500/300"},
            {"name": "mountain", "url": "https://picsum.photos/seed/mountains1/500/300"},
            {"name": "beach", "url": "https://picsum.photos/seed/beach1/500/300"},
            {"name": "forest", "url": "https://picsum.photos/seed/forest1/500/300"},
            {"name": "library", "url": "https://picsum.photos/seed/library1/500/300"},
            {"name": "restaurant", "url": "https://picsum.photos/seed/restaurant1/500/300"},
            {"name": "airport", "url": "https://picsum.photos/seed/airport1/500/300"}
        ]

        # 为每个类别生成5个不同描述(更具体,避免模糊)
        self.data = []
        for cat in self.categories:
            base_descriptions = [
                f"A {cat['name']} scene with typical features",
                f"The {cat['name']} showing natural details",
                f"An image of {cat['name']} with clear views",
                f"View of {cat['name']} in daylight",
                f"Close-up of {cat['name']} key elements"
            ]
            for desc in base_descriptions:
                self.data.append({
                    "image_url": cat["url"],
                    "text": desc,
                    "category": cat["name"]
                })

        self.img_size = img_size
        self.max_text_length = max_text_length

        # 图像预处理
        self.image_transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.RandomHorizontalFlip(p=0.3),
            transforms.RandomRotation(5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # 初始化tokenizer(左padding,明确设置pad_token)
        self.text_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        self.text_tokenizer.pad_token = self.text_tokenizer.eos_token
        self.text_tokenizer.padding_side = "left"

        self.cached_images = {}


    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        if item["image_url"] not in self.cached_images:
            try:
                response = requests.get(item["image_url"], timeout=10)
                image = Image.open(BytesIO(response.content)).convert("RGB")
                self.cached_images[item["image_url"]] = image
            except Exception as e:
                print(f"Image download failed for {item['category']}: {e}, using random image")
                color_map = {
                    "cat": (255, 200, 200), "dog": (200, 255, 200), "bird": (200, 200, 255),
                    "city": (255, 255, 200), "mountain": (200, 255, 255), "beach": (255, 220, 180),
                    "forest": (180, 255, 180), "library": (220, 220, 220),
                    "restaurant": (255, 180, 180), "airport": (200, 200, 200)
                }
                color = color_map.get(item["category"], (255, 255, 255))
                image = Image.new('RGB', (self.img_size, self.img_size), color=color)
                self.cached_images[item["image_url"]] = image

        image = self.cached_images[item["image_url"]]
        image_tensor = self.image_transform(image)
        text_tokens = self.text_tokenizer(
            item["text"],
            padding="max_length",
            truncation=True,
            max_length=self.max_text_length,
            return_tensors="pt"
        )

        return {
            "image": image_tensor,
            "input_ids": text_tokens["input_ids"].squeeze(0),
            "attention_mask": text_tokens["attention_mask"].squeeze(0),
            "text": item["text"],
            "category": item["category"]
        }


# ---------------------------- 模型架构(保持不变) ----------------------------
class ImageEncoder(nn.Module):
    def __init__(self, output_dim=768):
        super().__init__()
        self.base_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        modules = list(self.base_model.children())[:-1]
        self.feature_extractor = nn.Sequential(*modules)
        self.projection = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(1024, output_dim)
        )

    def forward(self, images):
        features = self.feature_extractor(images).squeeze(-1).squeeze(-1)
        return self.projection(features)


class CrossModalFusion(nn.Module):
    def __init__(self, hidden_dim=768, num_heads=8):
        super().__init__()
        self.text_norm = nn.LayerNorm(hidden_dim)
        self.image_norm = nn.LayerNorm(hidden_dim)
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, batch_first=True
        )
        self.fusion = nn.Linear(hidden_dim * 2, hidden_dim)
        self.activation = nn.GELU()

    def forward(self, text_features, image_features):
        batch_size, seq_len, _ = text_features.shape
        image_features = self.image_norm(image_features)
        image_expanded = image_features.unsqueeze(1).expand(-1, seq_len, -1)

        text_attn, _ = self.attention(
            query=text_features, key=image_expanded, value=image_expanded
        )
        text_attn = self.text_norm(text_features + text_attn)
        fused = self.activation(self.fusion(torch.cat([text_features, text_attn], dim=-1)))
        return fused


class MultimodalLLM(nn.Module):
    def __init__(self, hidden_dim=768):
        super().__init__()
        self.text_encoder = GPT2LMHeadModel.from_pretrained("gpt2")
        for param in list(self.text_encoder.parameters())[:3]:
            param.requires_grad = False

        self.image_encoder = ImageEncoder(output_dim=hidden_dim)
        self.cross_modal_fusion = CrossModalFusion(hidden_dim=hidden_dim)
        self.final_norm = nn.LayerNorm(hidden_dim)

    def forward(self, images, input_ids, attention_mask=None):
        image_features = self.image_encoder(images)
        text_outputs = self.text_encoder.transformer(
            input_ids=input_ids, attention_mask=attention_mask
        )
        text_features = text_outputs.last_hidden_state

        fused_features = self.cross_modal_fusion(text_features, image_features)
        fused_features = self.final_norm(fused_features)
        return self.text_encoder.lm_head(fused_features)


# ---------------------------- 训练与生成函数(优化生成策略) ----------------------------
def train_model(model, train_loader, val_loader, epochs=10, lr=5e-5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=50256)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=len(train_loader), num_training_steps=epochs*len(train_loader)
    )

    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in progress_bar:
            images = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            outputs = model(images, input_ids, attention_mask)
            shift_logits = outputs[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()
            loss = criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            total_train_loss += loss.item()
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")

        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                images = batch["image"].to(device)
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                outputs = model(images, input_ids, attention_mask)
                shift_logits = outputs[..., :-1, :].contiguous()
                shift_labels = input_ids[..., 1:].contiguous()
                total_val_loss += criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).item()

        avg_train = total_train_loss / len(train_loader)
        avg_val = total_val_loss / len(val_loader)
        train_losses.append(avg_train)
        val_losses.append(avg_val)
        print(f"Epoch {epoch+1} | Train Loss: {avg_train:.4f} | Val Loss: {avg_val:.4f}")

    return model, train_losses, val_losses


def generate_text(model, image, tokenizer, category, max_length=60):
    """核心修复2:传递attention_mask,优化生成策略"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    # 更具体的引导提示(避免模型生成列表)
    prompt = f"Describe the {category} image in detail: "
    # 生成input_ids和attention_mask(解决警告)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=len(prompt) + 5,  # 足够容纳提示词
        truncation=True
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)  # 传递注意力掩码

    # 提取图像特征
    image = image.unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.image_encoder(image)

    # 核心修复3:优化生成参数(减少重复,提高相关性)
    output = model.text_encoder.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,  # 传入掩码,消除警告
        max_length=max_length,
        temperature=0.5,  # 降低随机性,更聚焦输入
        num_beams=3,
        no_repeat_ngram_size=3,  # 避免3字词重复
        early_stopping=True,
        encoder_hidden_states=image_features.unsqueeze(1)
    )
    # 解码并移除提示词
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text.replace(prompt, "").strip()


# ---------------------------- 对比图生成(修复中文显示) ----------------------------
def plot_loss_curves(train_losses, val_losses):
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, len(train_losses)+1), train_losses, label="训练损失", marker='o')
    plt.plot(range(1, len(val_losses)+1), val_losses, label="验证损失", marker='s')
    plt.xlabel("轮次")
    plt.ylabel("损失值")
    plt.title("训练与验证损失对比")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.savefig("loss_curves.png")
    print("Loss对比图已保存为 loss_curves.png")
    plt.close()


def generate_results_table(results):
    plt.figure(figsize=(12, 9))
    ax = plt.gca()
    ax.axis('off')

    table = Table(ax, bbox=[0, 0, 1, 1])
    # 表头(中文显示正常)
    table.add_cell(0, 0, 0.1, 0.1, text="类别", loc='center', facecolor='lightgray')
    table.add_cell(0, 1, 0.3, 0.1, text="原始文本", loc='center', facecolor='lightgray')
    table.add_cell(0, 2, 0.6, 0.1, text="生成文本", loc='center', facecolor='lightgray')

    # 添加内容
    for i, res in enumerate(results[:8]):
        table.add_cell(i+1, 0, 0.1, 0.15, text=res["category"], loc='center')  # 增加行高,避免文本溢出
        table.add_cell(i+1, 1, 0.3, 0.15, text=res["original"], loc='left')
        table.add_cell(i+1, 2, 0.6, 0.15, text=res["generated"], loc='left')

    ax.add_table(table)
    plt.savefig("results_table.png", bbox_inches='tight')
    print("结果对比表已保存为 results_table.png")
    plt.close()


# ---------------------------- 主函数 ----------------------------
if __name__ == "__main__":
    print("准备平衡数据集...")
    full_dataset = BalancedDemoDataset()
    train_size = int(0.8 * len(full_dataset))
    train_dataset, val_dataset = random_split(full_dataset, [train_size, len(full_dataset)-train_size])

    batch_size = 4 if torch.cuda.is_available() else 1
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    print("训练模型...")
    model = MultimodalLLM()
    model, train_losses, val_losses = train_model(
        model, train_loader, val_loader, epochs=12
    )
    torch.save(model.state_dict(), "optimized_model.pth")

    plot_loss_curves(train_losses, val_losses)

    print("生成测试结果...")
    tokenizer = full_dataset.text_tokenizer
    results = []
    for category in [cat["name"] for cat in full_dataset.categories]:
        sample_idx = next(i for i, item in enumerate(full_dataset.data) if item["category"] == category)
        sample = full_dataset[sample_idx]
        generated = generate_text(model, sample["image"], tokenizer, category)
        results.append({
            "category": category,
            "original": sample["text"],
            "generated": generated
        })

    generate_results_table(results)

    print("\n部分生成结果:")
    for res in results[:5]:
        print(f"\n类别: {res['category']}")
        print(f"原始文本: {res['original']}")
        print(f"生成文本: {res['generated']}")

7、多模态LLM模型(图像-文本生成+问答系统)(简化版)

# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms  # 视觉模型与图像预处理工具
from torch.utils.data import Dataset, DataLoader, random_split  # 数据加载与划分
from transformers import GPT2LMHeadModel, GPT2Tokenizer, get_linear_schedule_with_warmup  # 文本模型与工具
import os
from PIL import Image  # 图像处理库
import requests  # 网络请求(下载图像)
from io import BytesIO  # 内存中处理二进制数据
import numpy as np
from tqdm import tqdm  # 进度条显示
import warnings  # 警告处理
import matplotlib.pyplot as plt  # 可视化工具
from matplotlib.table import Table  # 生成结果表格
import seaborn as sns  # 美化图表

sns.set_style("whitegrid")  # 设置图表风格

# ---------------------------- 配置环境(解决中文显示与警告问题) ----------------------------
# 设置支持中文的字体,解决图表中文乱码
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC", "Arial Unicode MS"]
plt.rcParams["axes.unicode_minus"] = False  # 解决负号显示异常

# 过滤无关警告,保持输出简洁
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
warnings.filterwarnings("ignore", category=UserWarning, message="The parameter 'pretrained' is deprecated")
warnings.filterwarnings("ignore", category=UserWarning,
                        message="Arguments other than a weight enum or `None` for 'weights' are deprecated")


# ---------------------------- 多模态数据集类(核心数据处理) ----------------------------
class BalancedMultimodalDataset(Dataset):
    """
    平衡的多模态数据集:包含图像、描述文本和英文问答对
    特点:10个类别,每个类别样本数量均等,避免模型偏向某类数据
    """

    def __init__(self, img_size=224, max_text_length=128):
        # 定义10个类别及其对应的图像URL(使用picsum生成可重复的随机图像)
        self.categories = [
            {"name": "cat", "url": "https://picsum.photos/seed/cat1/500/300"},
            {"name": "dog", "url": "https://picsum.photos/seed/dog1/500/300"},
            {"name": "bird", "url": "https://picsum.photos/seed/bird1/500/300"},
            {"name": "city", "url": "https://picsum.photos/seed/city1/500/300"},
            {"name": "mountain", "url": "https://picsum.photos/seed/mountains1/500/300"},
            {"name": "beach", "url": "https://picsum.photos/seed/beach1/500/300"},
            {"name": "forest", "url": "https://picsum.photos/seed/forest1/500/300"},
            {"name": "library", "url": "https://picsum.photos/seed/library1/500/300"},
            {"name": "restaurant", "url": "https://picsum.photos/seed/restaurant1/500/300"},
            {"name": "airport", "url": "https://picsum.photos/seed/airport1/500/300"}
        ]

        # 构建数据集:每个类别包含5种描述和3组问答对
        self.data = []
        for cat in self.categories:
            # 为每个类别生成5种不同的图像描述(增强数据多样性)
            descriptions = [
                f"A {cat['name']} scene with typical features",
                f"The {cat['name']} showing natural details",
                f"An image of {cat['name']} with clear views",
                f"View of {cat['name']} in daylight",
                f"Close-up of {cat['name']} key elements"
            ]

            # 为每个类别设计3组英文问答对(覆盖不同类型的问题)
            qa_pairs = [
                {
                    "question": f"What is the main subject of this {cat['name']} image?",  # 主体识别
                    "answer": f"The main subject is a {cat['name']}."
                },
                {
                    "question": f"What features are typical of this {cat['name']}?",  # 特征描述
                    "answer": f"Typical features include {cat['name']}-specific characteristics."
                },
                {
                    "question": f"Where might this {cat['name']} be located?",  # 位置推测
                    "answer": f"This {cat['name']} might be located in its natural environment."
                }
            ]

            # 组合描述和问答对,生成最终数据集
            for desc in descriptions:
                for qa in qa_pairs:
                    self.data.append({
                        "image_url": cat["url"],  # 图像URL
                        "description": desc,  # 图像描述
                        "question": qa["question"],  # 问题
                        "answer": qa["answer"],  # 答案
                        "category": cat["name"]  # 类别标签
                    })

        self.img_size = img_size  # 图像统一尺寸
        self.max_text_length = max_text_length  # 文本最大长度(防止输入过长)

        # 图像预处理管道(含数据增强)
        self.image_transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),  # 缩放至指定尺寸
            transforms.RandomHorizontalFlip(p=0.3),  # 30%概率水平翻转(数据增强)
            transforms.RandomRotation(5),  # 随机旋转±5度(增强视角鲁棒性)
            transforms.ColorJitter(brightness=0.1, contrast=0.1),  # 微调亮度和对比度
            transforms.ToTensor(),  # 转换为Tensor格式(通道×高度×宽度)
            transforms.Normalize(  # 标准化(使用ImageNet的均值和标准差)
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        # 初始化文本分词器(适配GPT2模型)
        self.text_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        self.text_tokenizer.pad_token = self.text_tokenizer.eos_token  # 使用终止符作为填充符
        self.text_tokenizer.padding_side = "left"  # 左填充(适合自回归模型)

        self.cached_images = {}  # 缓存已下载的图像(避免重复网络请求)

    def __len__(self):
        """返回数据集样本总数"""
        return len(self.data)

    def __getitem__(self, idx):
        """获取单个样本:图像+文本+标签"""
        item = self.data[idx]

        # 下载并缓存图像(若未缓存)
        if item["image_url"] not in self.cached_images:
            try:
                # 尝试下载图像
                response = requests.get(item["image_url"], timeout=10)
                image = Image.open(BytesIO(response.content)).convert("RGB")  # 转换为RGB格式
                self.cached_images[item["image_url"]] = image
            except Exception as e:
                # 下载失败时,生成与类别相关的纯色图(避免程序崩溃)
                print(f"图像下载失败({item['category']}): {e},使用替代图")
                # 为每个类别分配独特颜色(便于调试)
                color_map = {
                    "cat": (255, 200, 200), "dog": (200, 255, 200), "bird": (200, 200, 255),
                    "city": (255, 255, 200), "mountain": (200, 255, 255), "beach": (255, 220, 180),
                    "forest": (180, 255, 180), "library": (220, 220, 220),
                    "restaurant": (255, 180, 180), "airport": (200, 200, 200)
                }
                color = color_map.get(item["category"], (255, 255, 255))  # 默认白色
                image = Image.new('RGB', (self.img_size, self.img_size), color=color)
                self.cached_images[item["image_url"]] = image

        # 预处理图像
        image = self.cached_images[item["image_url"]]
        image_tensor = self.image_transform(image)

        # 预处理文本(将问答对转换为模型输入格式)
        input_text = f"Question: {item['question']} Answer: {item['answer']}"  # 拼接问题和答案
        text_tokens = self.text_tokenizer(
            input_text,
            padding="max_length",  # 填充至最大长度
            truncation=True,  # 超长则截断
            max_length=self.max_text_length,
            return_tensors="pt"  # 返回PyTorch张量
        )

        return {
            "image": image_tensor,  # 预处理后的图像张量
            "input_ids": text_tokens["input_ids"].squeeze(0),  # 文本ID序列(去除batch维度)
            "attention_mask": text_tokens["attention_mask"].squeeze(0),  # 注意力掩码(1表示有效token)
            "question": item["question"],  # 原始问题(用于测试)
            "answer": item["answer"],  # 原始答案(用于对比)
            "category": item["category"]  # 类别标签
        }


# ---------------------------- 多模态模型架构(核心组件) ----------------------------
class ImageEncoder(nn.Module):
    """
    图像编码器:将图像转换为与文本兼容的特征向量
    输入:图像(3×224×224)
    输出:特征向量(768维)
    """

    def __init__(self, output_dim=768):
        super().__init__()
        # 使用预训练的ResNet50作为基础特征提取器
        self.base_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

        # 移除最后一层全连接层(保留卷积特征提取部分)
        # ResNet50的最后一层是fc层,输出1000类,这里只保留前面的特征提取部分
        feature_extractor_modules = list(self.base_model.children())[:-1]
        self.feature_extractor = nn.Sequential(*feature_extractor_modules)

        # 投影层:将ResNet输出的2048维特征映射到768维(与文本特征维度一致)
        self.projection = nn.Sequential(
            nn.Linear(2048, 1024),  # 降维至1024
            nn.GELU(),  # 高斯误差线性单元(比ReLU更平滑)
            nn.Dropout(0.2),  # Dropout层(防止过拟合)
            nn.Linear(1024, output_dim)  # 最终投影至768维
        )

    def forward(self, images):
        """前向传播:图像→特征向量"""
        # 提取卷积特征:ResNet50输出为[batch_size, 2048, 1, 1]
        conv_features = self.feature_extractor(images)
        # 展平为[batch_size, 2048]
        flattened_features = conv_features.squeeze(-1).squeeze(-1)
        # 投影至768维
        return self.projection(flattened_features)


class CrossModalFusion(nn.Module):
    """
    跨模态融合模块:实现文本特征与图像特征的交互
    核心:通过注意力机制让文本关注图像的关键信息
    """

    def __init__(self, hidden_dim=768, num_heads=8):
        super().__init__()
        self.text_norm = nn.LayerNorm(hidden_dim)  # 文本特征归一化
        self.image_norm = nn.LayerNorm(hidden_dim)  # 图像特征归一化

        # 多头注意力机制(并行处理多个特征子空间)
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,  # 特征维度(768)
            num_heads=num_heads,  # 注意力头数(8,768/8=96,每个头处理96维)
            batch_first=True  # 输入格式为[batch, seq_len, dim]
        )

        # 特征融合层:将文本特征与注意力输出融合
        self.fusion = nn.Linear(hidden_dim * 2, hidden_dim)  # 1536→768
        self.activation = nn.GELU()  # 激活函数

    def forward(self, text_features, image_features):
        """
        输入:
            text_features: [batch_size, seq_len, hidden_dim](文本序列特征)
            image_features: [batch_size, hidden_dim](图像全局特征)
        输出:
            fused_features: [batch_size, seq_len, hidden_dim](融合特征)
        """
        batch_size, seq_len, _ = text_features.shape  # 获取文本序列长度

        # 图像特征归一化并扩展至序列长度(每个文本token都能关注图像)
        image_features = self.image_norm(image_features)
        # 扩展为[batch_size, seq_len, hidden_dim]
        image_expanded = image_features.unsqueeze(1).expand(-1, seq_len, -1)

        # 文本特征通过注意力关注图像特征(交叉注意力)
        # query=文本特征,key=图像特征,value=图像特征
        text_attn, _ = self.attention(
            query=text_features,
            key=image_expanded,
            value=image_expanded
        )

        # 残差连接+层归一化(缓解梯度消失,加速训练)
        text_attn = self.text_norm(text_features + text_attn)

        # 融合原始文本特征和注意力增强特征
        fused = self.activation(self.fusion(torch.cat([text_features, text_attn], dim=-1)))
        return fused


class MultimodalLLM(nn.Module):
    """
    多模态大语言模型:整合图像编码器、文本编码器和跨模态融合模块
    功能:根据图像生成描述文本,或回答与图像相关的问题
    """

    def __init__(self, hidden_dim=768):
        super().__init__()
        # 文本编码器(基于GPT2,预训练语言模型)
        self.text_encoder = GPT2LMHeadModel.from_pretrained("gpt2")
        # 冻结前3层参数(减少训练量,保留预训练语言知识)
        for param in list(self.text_encoder.parameters())[:3]:
            param.requires_grad = False

        # 图像编码器(见上文)
        self.image_encoder = ImageEncoder(output_dim=hidden_dim)
        # 跨模态融合模块(见上文)
        self.cross_modal_fusion = CrossModalFusion(hidden_dim=hidden_dim)
        self.final_norm = nn.LayerNorm(hidden_dim)  # 最终归一化层

    def forward(self, images, input_ids, attention_mask=None):
        """
        前向传播流程:
        1. 提取图像特征
        2. 提取文本特征
        3. 跨模态融合
        4. 生成预测结果
        """
        # 1. 图像特征提取
        image_features = self.image_encoder(images)  # [batch_size, hidden_dim]

        # 2. 文本特征提取(通过GPT2的Transformer层)
        text_outputs = self.text_encoder.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        text_features = text_outputs.last_hidden_state  # [batch_size, seq_len, hidden_dim]

        # 3. 跨模态融合(文本特征+图像特征)
        fused_features = self.cross_modal_fusion(text_features, image_features)
        fused_features = self.final_norm(fused_features)  # 归一化

        # 4. 通过GPT2的语言模型头生成下一个token的概率分布
        return self.text_encoder.lm_head(fused_features)  # [batch, seq_len, vocab_size]


# ---------------------------- 训练与生成函数(模型应用) ----------------------------
def train_model(model, train_loader, val_loader, epochs=10, lr=5e-5):
    """
    训练多模态模型
    参数:
        model: 待训练的模型
        train_loader: 训练数据加载器
        val_loader: 验证数据加载器
        epochs: 训练轮次
        lr: 学习率
    返回:
        model: 训练好的模型
        train_losses: 训练损失曲线
        val_losses: 验证损失曲线
    """
    # 选择计算设备(GPU优先)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)  # 模型移至设备

    # 优化器(AdamW:带权重衰减的Adam,减轻过拟合)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    # 损失函数(交叉熵损失,忽略填充token的损失)
    # 50256是GPT2的eos_token_id(即填充符)
    criterion = nn.CrossEntropyLoss(ignore_index=50256)
    # 学习率调度器(线性预热+衰减)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=len(train_loader),  # 预热步数=1个epoch的迭代次数
        num_training_steps=epochs * len(train_loader)  # 总训练步数
    )

    train_losses = []  # 记录训练损失
    val_losses = []  # 记录验证损失

    for epoch in range(epochs):
        # 训练阶段
        model.train()  # 开启训练模式(启用dropout等)
        total_train_loss = 0
        # 进度条显示训练过程
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")

        for batch in progress_bar:
            # 数据移至设备
            images = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            # 前向传播:获取模型输出
            outputs = model(images, input_ids, attention_mask)

            # 计算损失(预测下一个token)
            # 输出和标签都偏移一位(预测第i+1个token,基于第1..i个token)
            shift_logits = outputs[..., :-1, :].contiguous()  # 预测序列
            shift_labels = input_ids[..., 1:].contiguous()  # 目标序列
            loss = criterion(
                shift_logits.view(-1, shift_logits.size(-1)),  # 展平为[batch*(seq_len-1), vocab_size]
                shift_labels.view(-1)  # 展平为[batch*(seq_len-1)]
            )

            # 反向传播与参数更新
            optimizer.zero_grad()  # 梯度清零
            loss.backward()  # 计算梯度
            # 梯度裁剪(防止梯度爆炸,大模型训练必备)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()  # 更新参数
            scheduler.step()  # 更新学习率

            total_train_loss += loss.item()
            # 显示当前批次损失
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")

        # 计算平均训练损失
        avg_train_loss = total_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # 验证阶段(不更新参数)
        model.eval()  # 开启评估模式(关闭dropout等)
        total_val_loss = 0
        with torch.no_grad():  # 禁用梯度计算(节省内存)
            for batch in val_loader:
                # 数据移至设备
                images = batch["image"].to(device)
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)

                # 前向传播
                outputs = model(images, input_ids, attention_mask)
                # 计算损失(同训练阶段)
                shift_logits = outputs[..., :-1, :].contiguous()
                shift_labels = input_ids[..., 1:].contiguous()
                total_val_loss += criterion(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                ).item()

        # 计算平均验证损失
        avg_val_loss = total_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        print(f"Epoch {epoch + 1} | 训练损失: {avg_train_loss:.4f} | 验证损失: {avg_val_loss:.4f}")

    return model, train_losses, val_losses


def generate_description(model, image, tokenizer, category, max_new_tokens=40):
    """
    根据图像生成描述文本
    参数:
        model: 训练好的模型
        image: 预处理后的图像
        tokenizer: 文本分词器
        category: 图像类别(用于提示词)
        max_new_tokens: 最大新增token数
    返回:
        生成的描述文本
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()  # 评估模式

    # 构建提示词(引导模型生成与类别相关的描述)
    prompt = f"Describe the {category} image in detail: "

    # 编码提示词(不填充,保留原始长度)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding="do_not_pad",
        truncation=False
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # 提取图像特征
    image = image.unsqueeze(0).to(device)  # 增加batch维度
    with torch.no_grad():
        image_features = model.image_encoder(image)  # [1, hidden_dim]

    # 生成文本(核心步骤)
    output = model.text_encoder.generate(
        input_ids=input_ids,  # 提示词ID
        attention_mask=attention_mask,  # 注意力掩码
        max_new_tokens=max_new_tokens,  # 最多生成40个新token
        temperature=0.6,  # 温度参数(控制随机性,值越小越确定)
        num_beams=3,  # Beam搜索宽度(保留3个最优候选)
        no_repeat_ngram_size=2,  # 禁止2-gram重复(减少冗余)
        early_stopping=True,  # 生成终止符时停止
        encoder_hidden_states=image_features.unsqueeze(1)  # 传入图像特征(关键)
    )

    # 解码并清理生成的文本(去除特殊字符和提示词)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # 替换非-breaking空格为普通空格,去除提示词
    return generated_text.replace(u'\xa0', ' ').replace(prompt, "").strip()


def answer_question(model, image, question, tokenizer, max_new_tokens=40):
    """
    基于图像回答英文问题
    参数:
        model: 训练好的模型
        image: 预处理后的图像
        question: 英文问题
        tokenizer: 文本分词器
        max_new_tokens: 最大新增token数
    返回:
        生成的答案
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()  # 评估模式

    # 构建问答格式的提示词
    prompt = f"Question: {question} Answer: "

    # 编码提示词
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding="do_not_pad",
        truncation=False
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # 提取图像特征
    image = image.unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.image_encoder(image)

    # 生成答案
    output = model.text_encoder.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=0.5,  # 更低的温度(回答需更确定)
        num_beams=3,
        no_repeat_ngram_size=3,  # 禁止3-gram重复(进一步减少冗余)
        early_stopping=True,
        encoder_hidden_states=image_features.unsqueeze(1)
    )

    # 解码并清理答案
    generated_answer = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_answer.replace(u'\xa0', ' ').replace(prompt, "").strip()


# ---------------------------- 可视化与评估函数 ----------------------------
def plot_loss_curves(train_losses, val_losses):
    """绘制训练和验证损失曲线,评估模型训练效果"""
    plt.figure(figsize=(10, 5))
    # 绘制训练损失
    plt.plot(range(1, len(train_losses) + 1), train_losses, label="训练损失", marker='o')
    # 绘制验证损失
    plt.plot(range(1, len(val_losses) + 1), val_losses, label="验证损失", marker='s')
    plt.xlabel("训练轮次")
    plt.ylabel("损失值")
    plt.title("训练与验证损失对比")
    plt.legend()
    plt.grid(alpha=0.3)  # 网格线(增强可读性)
    plt.savefig("loss_curves.png", bbox_inches='tight')  # 保存图像
    print("损失对比图已保存为 loss_curves.png")
    plt.close()


def generate_results_table(desc_results, qa_results):
    """生成结果对比表格(包含描述和问答结果)"""
    plt.figure(figsize=(14, 10))
    ax = plt.gca()
    ax.axis('off')  # 关闭坐标轴

    # 创建表格
    table = Table(ax, bbox=[0, 0, 1, 1])  # 表格占满整个图
    # 添加表头
    table.add_cell(0, 0, 0.1, 0.1, text="类别", loc='center', facecolor='lightgray')
    table.add_cell(0, 1, 0.25, 0.1, text="生成描述", loc='center', facecolor='lightgray')
    table.add_cell(0, 2, 0.3, 0.1, text="问题", loc='center', facecolor='lightgray')
    table.add_cell(0, 3, 0.35, 0.1, text="生成答案", loc='center', facecolor='lightgray')

    # 填充表格内容(前6个类别)
    for i in range(min(6, len(desc_results))):
        desc = desc_results[i]
        qa = qa_results[i]

        # 清理文本中的特殊字符
        clean_desc = desc["generated"].replace(u'\xa0', ' ')
        clean_question = qa["question"].replace(u'\xa0', ' ')
        clean_answer = qa["answer"].replace(u'\xa0', ' ')

        # 添加单元格内容
        table.add_cell(i + 1, 0, 0.1, 0.15, text=desc["category"], loc='center')
        table.add_cell(i + 1, 1, 0.25, 0.15, text=clean_desc, loc='left')
        table.add_cell(i + 1, 2, 0.3, 0.15, text=clean_question, loc='left')
        table.add_cell(i + 1, 3, 0.35, 0.15, text=clean_answer, loc='left')

    ax.add_table(table)
    plt.savefig("results_table.png", bbox_inches='tight')  # 保存表格
    print("结果对比表已保存为 results_table.png")
    plt.close()


# ---------------------------- 主函数(完整流程执行) ----------------------------
if __name__ == "__main__":
    # 1. 准备数据集
    print("准备多模态数据集(含问答)...")
    full_dataset = BalancedMultimodalDataset()

    # 划分训练集(80%)和验证集(20%)
    train_size = int(0.8 * len(full_dataset))
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, len(full_dataset) - train_size]
    )

    # 创建数据加载器(批量加载数据)
    batch_size = 4 if torch.cuda.is_available() else 1  # GPU可用时使用更大批量
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=0
    )

    # 2. 训练模型
    print("开始训练多模态模型(支持问答)...")
    model = MultimodalLLM()
    model, train_losses, val_losses = train_model(
        model, train_loader, val_loader, epochs=10  # 训练10个轮次
    )
    # 保存训练好的模型权重
    torch.save(model.state_dict(), "multimodal_qa_model.pth")

    # 3. 绘制损失曲线(评估训练效果)
    plot_loss_curves(train_losses, val_losses)

    # 4. 生成测试结果(描述+问答)
    print("生成测试结果...")
    tokenizer = full_dataset.text_tokenizer
    desc_results = []  # 存储描述生成结果
    qa_results = []  # 存储问答结果

    # 每个类别选1个样本测试
    categories = list(set(item["category"] for item in full_dataset.data))  # 去重类别
    for category in categories[:6]:  # 测试前6个类别
        # 找到该类别的样本索引
        sample_idx = next(
            i for i, item in enumerate(full_dataset.data) if item["category"] == category
        )
        sample = full_dataset[sample_idx]  # 获取样本

        # 生成图像描述
        description = generate_description(
            model, sample["image"], tokenizer, category
        )
        desc_results.append({"category": category, "generated": description})

        # 生成问答结果
        question = sample["question"]  # 原始问题
        answer = answer_question(model, sample["image"], question, tokenizer)
        qa_results.append({
            "category": category,
            "question": question,
            "answer": answer
        })

    # 5. 生成结果对比表
    generate_results_table(desc_results, qa_results)

    # 6. 打印部分结果(展示效果)
    print("\n英文问答示例:")
    for i in range(3):
        print(f"\n示例 {i + 1}:")
        print(f"类别: {qa_results[i]['category']}")
        print(f"问题: {qa_results[i]['question']}")
        print(f"生成答案: {qa_results[i]['answer']}")

 


网站公告

今日签到

点亮在社区的每一天
去签到