基于改进扩散模型与注意力机制的影像到转基因数据预测系统

发布于:2025-07-02 ⋅ 阅读:(26) ⋅ 点赞:(0)

基于改进扩散模型与注意力机制的影像到转基因数据预测系统

1. 项目概述

本系统利用改进的扩散模型结合注意力机制，从医学影像中预测转基因数据。系统采用PyTorch框架实现，包含数据预处理、模型架构、训练流程和评估指标等完整模块。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import mean_squared_error, r2_score
from tqdm import tqdm
import os
import math
import random
from datetime import datetime
import argparse

# Seed every random number generator so that runs are reproducible.
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    The PyTorch CPU generator and the generators of all CUDA devices are
    seeded. cuDNN is switched to deterministic algorithms and its benchmark
    autotuner is disabled, trading some speed for repeatability.

    Args:
        seed (int): value passed to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


set_seed()

2. 数据加载与预处理

2.1 数据集类实现

class GeneExpressionDataset(Dataset):
    """Dataset pairing each ``<sample_id>.png`` image with its gene-expression row.

    The CSV must contain a ``sample_id`` column; every other column is
    treated as a numeric target. Only samples that have both an image file
    and a CSV row are kept.
    """

    def __init__(self, img_dir, gene_file, transform=None, target_transform=None):
        """
        Args:
            img_dir (string): directory containing the ``<sample_id>.png`` images.
            gene_file (string): path to a CSV with a ``sample_id`` column plus
                one numeric column per gene.
            transform (callable, optional): applied to the RGB PIL image.
            target_transform (callable, optional): applied to the float32
                gene-value vector.

        Raises:
            KeyError: if the CSV has no ``sample_id`` column.
        """
        self.img_dir = img_dir
        # Read IDs as strings so they compare equal to file-name stems.
        # Otherwise numeric IDs are parsed as ints and never match, and
        # zero-padded IDs such as "007" lose their padding.
        self.gene_data = pd.read_csv(gene_file, dtype={'sample_id': str})
        if 'sample_id' not in self.gene_data.columns:
            raise KeyError(f"{gene_file} has no 'sample_id' column")
        self.transform = transform
        self.target_transform = target_transform

        # Target columns are selected by name, not by position, so the ID
        # column does not have to be first in the CSV.
        self.gene_columns = [c for c in self.gene_data.columns if c != 'sample_id']

        # Check that images and gene rows correspond, and drop unmatched samples.
        self._validate_data()

    def _validate_data(self):
        """Keep only samples with both an image and a gene row; warn about the rest."""
        # Only ``.png`` files count, because __getitem__ loads ``<sample_id>.png``.
        # Other files (e.g. .DS_Store, notes) are ignored. splitext keeps IDs
        # that contain dots intact ("a.b.png" -> "a.b").
        img_ids = {
            os.path.splitext(name)[0]
            for name in os.listdir(self.img_dir)
            if name.endswith('.png')
        }
        gene_ids = set(self.gene_data['sample_id'])

        missing_img = gene_ids - img_ids
        missing_gene = img_ids - gene_ids

        if missing_img:
            print(f"警告: {len(missing_img)}个基因数据缺少对应影像")
        if missing_gene:
            print(f"警告: {len(missing_gene)}个影像缺少对应基因数据")

        # Keep only samples that have both an image and gene data. The index is
        # reset so that row labels match positions after filtering.
        valid_ids = gene_ids & img_ids
        keep = self.gene_data['sample_id'].isin(valid_ids)
        self.gene_data = self.gene_data[keep].reset_index(drop=True)

    def __len__(self):
        return len(self.gene_data)

    def __getitem__(self, idx):
        """Return ``(image, gene_values, sample_id)`` for row ``idx``.

        ``gene_values`` is a float32 NumPy vector (before ``target_transform``).
        """
        row = self.gene_data.iloc[idx]
        sample_id = row['sample_id']
        img_path = os.path.join(self.img_dir, f"{sample_id}.png")
        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # Gene-expression targets: every column except sample_id.
        gene_values = row[self.gene_columns].to_numpy(dtype=np.float32)

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            gene_values = self.target_transform(gene_values)

        return image, gene_values, sample_id

# Data-augmentation pipelines.
def get_transforms(img_size=224):
    """Build the training and validation image transform pipelines.

    Both pipelines resize to ``img_size`` x ``img_size``, convert to a tensor
    and normalise with the commonly used ImageNet channel mean/std. Between
    the resize and tensor conversion, the training pipeline also applies:
    - random horizontal and vertical flips;
    - a random rotation of up to 15 degrees;
    - mild colour jitter.

    Args:
        img_size (int): side length of the square resized image.

    Returns:
        tuple: ``(train_transform, val_transform)``.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    def assemble(middle_steps):
        # Shared prefix (resize) and suffix (tensor + normalise) around the
        # pipeline-specific steps.
        steps = [transforms.Resize((img_size, img_size))]
        steps.extend(middle_steps)
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(mean=channel_mean, std=channel_std))
        return transforms.Compose(steps)

    augmentations = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    ]
    train_pipeline = assemble(augmentations)
    val_pipeline = assemble([])

    return train_pipeline, val_pipeline

# Target standardisation.
class GeneNormalizer:
    """Column-wise z-score normaliser for gene-expression targets.

    Means and standard deviations are computed once per column (``axis=0``)
    from ``gene_data``. Columns whose std is zero or not finite get a scale
    of 1.0, so they are only mean-centred.
    """

    def __init__(self, gene_data):
        """
        Args:
            gene_data: numeric array-like exposing ``.mean(axis=0)`` and
                ``.std(axis=0)`` (e.g. ``np.ndarray`` or a numeric-only
                ``pd.DataFrame``), presumably rows = samples and
                columns = genes. The std uses the container's default ddof
                (NumPy: 0, pandas: 1).
        """
        # Store plain float64 ndarrays. With a DataFrame input the stats would
        # otherwise be pandas Series, and normalising an ndarray would then
        # return a Series instead of an array.
        self.means = np.asarray(gene_data.mean(axis=0), dtype=np.float64)
        self.stds = np.asarray(gene_data.std(axis=0), dtype=np.float64)
        # A zero std (constant gene) would divide by zero. A NaN std (e.g.
        # pandas' ddof=1 std of a single row) would turn every value into NaN.
        self.stds[~np.isfinite(self.stds) | (self.stds == 0)] = 1.0

    @staticmethod
    def _keep_float_dtype(result, reference):
        """Cast ``result`` back to ``reference``'s floating dtype when both are ndarrays."""
        if (isinstance(result, np.ndarray)
                and isinstance(reference, np.ndarray)
                and np.issubdtype(reference.dtype, np.floating)):
            return result.astype(reference.dtype, copy=False)
        return result

    def __call__(self, gene_values):
        """Return ``(gene_values - means) / stds``, keeping a float ndarray input's dtype."""
        # Preserving the dtype keeps float32 targets (as produced by the
        # dataset) from being silently upcast to float64 by the float64 stats.
        normalized = (gene_values - self.means) / self.stds
        return self._keep_float_dtype(normalized, gene_values)

    def inverse_transform(self, normalized_values):
        """Undo ``__call__``: ``normalized_values * stds + means``."""
        restored = normalized_values * self.stds + self.means
        return self._keep_float_dtype(restored, normalized_values)

3. 模型架构

3.1 改进扩散模型核心组件

# Timestep embedding layer.
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of diffusion timesteps.

    Maps a 1-D tensor of ``B`` timesteps to a ``(B, dim)`` tensor. The first
    ``dim // 2`` columns are ``sin(t * f_k)`` and the next ``dim // 2`` are
    ``cos(t * f_k)``, with ``f_k = 10000 ** (-k / max(dim // 2 - 1, 1))``.
    When ``dim`` is odd, a zero column is appended so the output width is
    always exactly ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Args:
            time: 1-D tensor of timesteps (integer or float), shape ``(B,)``.

        Returns:
            Tensor of shape ``(B, dim)`` on ``time``'s device.
        """
        half_dim = self.dim // 2
        # The exponent denominator (half_dim - 1) is 0 for dim 2 or 3. Clamp it
        # so those sizes still work; they get a single frequency of 1.0.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -scale)
        args = time[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2:
            # Pad odd sizes so the width matches downstream nn.Linear(dim, ...) layers.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings

# Residual block with an attention (gating) branch.
class AttentionResidualBlock(nn.Module):
    """Time-conditioned residual conv block followed by a sigmoid gating branch.

    Computes ``h = block2(block1(x) + time_mlp(t)) + residual(x)``, then
    rescales ``h`` element-wise by ``attention(h)``, whose values lie in
    (0, 1) because of the final Sigmoid. Spatial size is preserved (3x3
    convs with padding 1); channels go from ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.1):
        """
        Args:
            in_channels (int): channels of the input feature map ``x``.
            out_channels (int): channels of the output feature map.
            time_emb_dim (int): size of the last dimension of ``t``.
            dropout (float): dropout probability inside ``block2``.
        """
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)

        # GroupNorm requires num_channels % num_groups == 0. Use 8 groups
        # whenever C is divisible by 8, otherwise gcd(8, C).
        in_groups = self._norm_groups(in_channels)
        out_groups = self._norm_groups(out_channels)
        # Width of the gating bottleneck; never allowed to collapse to 0 channels.
        hidden = max(out_channels // 8, 1)

        # First conv block.
        self.block1 = nn.Sequential(
            nn.GroupNorm(in_groups, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        )

        # Second conv block.
        self.block2 = nn.Sequential(
            nn.GroupNorm(out_groups, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )

        # Skip connection: a 1x1 conv is needed only when the channel count changes.
        if in_channels != out_channels:
            self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.residual = nn.Identity()

        # Gating branch: produces per-element weights in (0, 1).
        self.attention = nn.Sequential(
            nn.GroupNorm(out_groups, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, hidden, 1),
            nn.Conv2d(hidden, hidden, 3, padding=1),
            nn.Conv2d(hidden, out_channels, 1),
            nn.Sigmoid()
        )

    @staticmethod
    def _norm_groups(channels, max_groups=8):
        """Return a GroupNorm group count that divides ``channels`` (gcd with ``max_groups``)."""
        return math.gcd(max_groups, channels)

    def forward(self, x, t):
        """
        Args:
            x: feature map for Conv2d, presumably ``(B, in_channels, H, W)``.
            t: time embedding whose last dim is ``time_emb_dim``, presumably
                ``(B, time_emb_dim)``. TODO confirm against the caller.

        Returns:
            Tensor with ``out_channels`` channels and the same spatial size as ``x``.
        """
        # Project the time embedding and append two singleton dims so it
        # broadcasts over the spatial dimensions of the feature map.
        t_emb = self.time_mlp(t)
        t_emb = t_emb[(..., ) + (None, ) * 2]

        h = self.block1(x) + t_emb
        h = self.block2(h)
        h = h + self.residual(x)

        # Element-wise gating.
        return h * self.attention(h)

# 改进的UNet架构
class ImprovedUNet(nn.Module):
    """带有注意力机制和改进结构的UNet模型"""
    
    def __init__(self, in_channels=3, out_channels=1, time_emb_dim=128, 
                 init_channels=64, num_blocks=2, channel_mults=(1, 2, 4, 8), 
                 dropout=0.1):
        super().__init__()
        
        # 时间嵌入层
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )
        
        # 初始卷积
        self.init_conv = nn.Conv2d(in_channels, init_channels, kernel_size=3, padding=1)
        
        # 下采样路径
        self.down_blocks = nn.ModuleList()
        self.down_attentions = nn.ModuleList()
        self.down_pools = nn.ModuleList()
        
        channels = [init_channels]
        now_channels = init_channels
        
        # 构建下采样路径
        for i, mult in enumerate(channel_mults):
            out_channels = init_channels * mult
            for _ in range(num_blocks):
                self.down_blocks.append(
                    AttentionResidualBlock(now_channels, out_channels, time_emb_dim, dropout)
                )
                self.down_attentions.append(SelfAttention(out_channels))
                now_channels = out_channels
                channels.append(now_channels)
            if i != len(channel_mults) - 1:
                self.down_pools.append(nn.Conv2d(now_channels, now_channels, kernel_size=3, stride=2, padding=1))
        
        # 瓶颈层
        self.bottleneck = nn.ModuleList([
            AttentionResidualBlock(now_channels, now_channels, time_emb_dim, dropout),
            SelfAttention(now_channels),
            AttentionResidualBlock(now_channels, now_channels, time_emb_dim, dropout)
        ])
        
        # 上采样路径
        self.up_blocks = nn.ModuleList()
        self.up_attentions = nn.ModuleList()
        
        for i, mult in reversed(list(enumerate(channel_mults))):
            out_channels = init_channels * mult
            for j in range(num_blocks + 1):
                self.up_blocks.append(
                    AttentionResidualBlock(channels.pop() + now_channels, out_channels, time_emb_dim, dropout)
                )
                self.up_attentions.append(SelfAttention(out_channels))
                now_channels = out_channels
            if i != 0:
                self.up_blocks.append(nn.ConvTranspose2d(now_channels, now_channels, kernel_size=2, stride=2))
        
        # 最终输出层
        self.final_conv = nn.Sequential(
            nn.GroupNorm(8, now_channels),
            nn.SiLU(),
            nn.Conv2d(now_channels, out_channels, kernel_size=1)
        )
    
    def forward(self, x, t):
        # 时间嵌入
        t_emb = self.time_mlp(t)
        
        # 初始卷积
        x