基于改进扩散模型与注意力机制的影像到转基因数据预测系统
1. 项目概述
本系统结合改进的扩散模型与注意力机制，从医学影像中预测转基因表达数据。系统基于 PyTorch 框架实现，包含数据预处理、模型架构、训练流程和评估指标等完整模块。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import mean_squared_error, r2_score
from tqdm import tqdm
import os
import math
import random
from datetime import datetime
import argparse
# Seed every random number generator in use so that runs are reproducible.
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) and pin cuDNN.

    Args:
        seed (int): value passed to every generator. Defaults to 42.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    # Give up cuDNN autotuning in exchange for deterministic kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed()
2. 数据加载与预处理
2.1 数据集类实现
class GeneExpressionDataset(Dataset):
    """Dataset pairing each image with its transgene-expression vector.

    Images live in ``img_dir`` as ``<sample_id><ext>`` with ``ext`` in
    ``IMG_EXTENSIONS``. Expression values come from a CSV file with a
    ``sample_id`` column plus one numeric column per gene. Only samples that
    have both an image and a CSV row are kept.

    Each item is ``(image, gene_values, sample_id)``:
    ``image`` is an RGB PIL image (or whatever ``transform`` returns),
    ``gene_values`` is a 1-D float32 array of every non-``sample_id`` column
    (or whatever ``target_transform`` returns), and ``sample_id`` is a str.
    """

    # Lower-case file extensions recognised as images in ``img_dir``.
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, img_dir, gene_file, transform=None, target_transform=None):
        """
        Args:
            img_dir (str): directory containing the image files.
            gene_file (str): path to the CSV file with the expression data.
            transform (callable, optional): applied to each RGB PIL image.
            target_transform (callable, optional): applied to each expression vector.
        """
        self.img_dir = img_dir
        # Read IDs as strings so numeric-looking IDs (e.g. "001") keep their
        # leading zeros and compare equal to the image file stems.
        self.gene_data = pd.read_csv(gene_file, dtype={'sample_id': str})
        self.transform = transform
        self.target_transform = target_transform
        # Check that images and expression rows correspond; drop unmatched rows.
        self._validate_data()

    def _validate_data(self):
        """Index image files by stem and keep only samples present in both sources.

        Fills ``self._img_files`` (sample_id -> file name inside ``img_dir``)
        and filters ``self.gene_data`` in place. Prints a warning for each
        kind of mismatch.
        """
        self._img_files = {}
        for fname in sorted(os.listdir(self.img_dir)):
            # splitext keeps dotted stems intact ("a.b.png" -> "a.b").
            stem, ext = os.path.splitext(fname)
            ext = ext.lower()
            if ext not in self.IMG_EXTENSIONS:
                continue
            # When one sample exists in several formats, prefer .png; otherwise
            # keep the first name in sorted order.
            if stem not in self._img_files or ext == '.png':
                self._img_files[stem] = fname

        img_ids = set(self._img_files)
        gene_ids = set(self.gene_data['sample_id'])

        missing_img = gene_ids - img_ids
        missing_gene = img_ids - gene_ids
        if missing_img:
            print(f"警告: {len(missing_img)}个基因数据缺少对应影像")
        if missing_gene:
            print(f"警告: {len(missing_gene)}个影像缺少对应基因数据")

        # Keep only samples that have both an image and expression data.
        valid_ids = gene_ids & img_ids
        self.gene_data = self.gene_data[self.gene_data['sample_id'].isin(valid_ids)]

    def __len__(self):
        return len(self.gene_data)

    def __getitem__(self, idx):
        row = self.gene_data.iloc[idx]
        sample_id = row['sample_id']
        # Use the file actually found on disk rather than assuming ".png".
        img_path = os.path.join(self.img_dir, self._img_files[sample_id])
        # The context manager releases the file handle as soon as the pixels are loaded.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # Expression vector = every column except sample_id, regardless of column order.
        gene_values = row.drop(labels='sample_id').to_numpy(dtype=np.float32)

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            gene_values = self.target_transform(gene_values)

        return image, gene_values, sample_id
# Data-augmentation pipelines
def get_transforms(img_size=224):
    """Build the training and validation image pipelines.

    Both pipelines resize to ``img_size x img_size``, convert to a tensor and
    normalise with the ImageNet channel statistics. The training pipeline
    additionally applies random flips, a +/-15 degree rotation and mild colour
    jitter between the resize and the tensor conversion.

    Args:
        img_size (int): side length of the square output image.

    Returns:
        tuple: ``(train_transform, val_transform)``.
    """
    size = (img_size, img_size)

    def _to_normalized_tensor():
        # Fresh instances (and fresh mean/std lists) for each pipeline.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

    augmentations = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    ]

    train_transform = transforms.Compose(
        [transforms.Resize(size), *augmentations, *_to_normalized_tensor()]
    )
    val_transform = transforms.Compose([transforms.Resize(size), *_to_normalized_tensor()])
    return train_transform, val_transform
# Target-value standardisation
class GeneNormalizer:
    """Column-wise z-score normaliser for transgene-expression vectors.

    Means and standard deviations are computed over axis 0 of ``gene_data``
    (one value per column) and stored as float32 NumPy arrays. Normalising a
    float32 expression vector therefore returns a float32 ``np.ndarray``,
    whether the statistics came from a NumPy array or a pandas DataFrame.
    """

    def __init__(self, gene_data):
        """
        Args:
            gene_data: numeric array-like exposing ``.mean(axis=0)`` and
                ``.std(axis=0)`` (e.g. ``np.ndarray`` or an all-numeric
                ``pd.DataFrame``). The object's own ``std`` convention is kept
                (NumPy defaults to ddof=0, pandas to ddof=1).
        """
        means = np.asarray(gene_data.mean(axis=0), dtype=np.float64)
        stds = np.asarray(gene_data.std(axis=0), dtype=np.float64)
        # Constant columns (std == 0) and undefined stds (NaN, e.g. a single row
        # under pandas' ddof=1) would otherwise divide by zero or propagate NaN.
        # Use 1.0 so those genes are only mean-centred.
        stds = np.where(np.isfinite(stds) & (stds > 0), stds, 1.0)
        # Store as float32 to match the float32 expression vectors produced by
        # the dataset and avoid silent promotion to float64.
        self.means = means.astype(np.float32)
        self.stds = stds.astype(np.float32)

    def __call__(self, gene_values):
        """Return ``(gene_values - mean) / std`` element-wise."""
        return (gene_values - self.means) / self.stds

    def inverse_transform(self, normalized_values):
        """Undo ``__call__``: return ``normalized_values * std + mean``."""
        return normalized_values * self.stds + self.means
3. 模型架构
3.1 改进扩散模型核心组件
# Timestep embedding layer
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal timestep embedding for the diffusion model.

    Maps a 1-D tensor of ``B`` timesteps to a ``(B, dim)`` tensor. The first
    ``dim // 2`` columns are sines and the next ``dim // 2`` are cosines of the
    timesteps times log-spaced frequencies (from 1 down to 1/10000 when
    ``dim >= 4``). For odd ``dim`` one zero column is appended, so the output
    width always equals ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Args:
            time (Tensor): 1-D tensor of timesteps (integer or float).

        Returns:
            Tensor: float32 tensor of shape ``(len(time), dim)``.
        """
        half_dim = self.dim // 2
        # Log-spaced frequency step. max(...) avoids ZeroDivisionError when
        # half_dim == 1 (dim of 2 or 3).
        step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, device=time.device, dtype=torch.float32) * -step
        )
        args = time.float()[:, None] * freqs[None, :]
        emb = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2:
            # Pad so the width matches the `dim` that downstream Linear layers expect.
            emb = F.pad(emb, (0, 1))
        return emb
# Residual block with an attention gate
class AttentionResidualBlock(nn.Module):
    """Time-conditioned two-conv residual block whose output is gated by a sigmoid map.

    Data flow in ``forward(x, t)``:
        out = block2(block1(x) + time_mlp(t)) + residual(x)
        return out * attention(out)

    ``x`` is a Conv2d input, e.g. ``(B, in_channels, H, W)``. ``t`` is a time
    embedding whose last dimension is ``time_emb_dim``. The spatial size is
    preserved (3x3 convs with padding=1, 1x1 convs). Because every GroupNorm
    uses 8 groups, ``in_channels`` and ``out_channels`` must be divisible by 8.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.1):
        super().__init__()
        # Projects the time embedding to one additive bias per output channel.
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)

        # Pre-activation conv #1: GroupNorm -> SiLU -> 3x3 conv (in -> out channels).
        self.block1 = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        )
        # Pre-activation conv #2, with dropout applied just before the conv.
        self.block2 = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )
        # Skip path: a 1x1 projection only when the channel count changes.
        self.residual = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )
        # Bottlenecked conv stack (C -> C//8 -> C) ending in a sigmoid, giving a
        # per-channel, per-position gate in (0, 1).
        reduced = out_channels // 8
        self.attention = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, reduced, 1),
            nn.Conv2d(reduced, reduced, 3, padding=1),
            nn.Conv2d(reduced, out_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, t):
        # Append two singleton dims so the time bias broadcasts over H and W.
        time_bias = self.time_mlp(t)[..., None, None]
        out = self.block2(self.block1(x) + time_bias)
        out = out + self.residual(x)
        gate = self.attention(out)
        return out * gate
# 改进的UNet架构
class ImprovedUNet(nn.Module):
"""带有注意力机制和改进结构的UNet模型"""
def __init__(self, in_channels=3, out_channels=1, time_emb_dim=128,
init_channels=64, num_blocks=2, channel_mults=(1, 2, 4, 8),
dropout=0.1):
super().__init__()
# 时间嵌入层
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(time_emb_dim),
nn.Linear(time_emb_dim, time_emb_dim),
nn.SiLU(),
nn.Linear(time_emb_dim, time_emb_dim)
)
# 初始卷积
self.init_conv = nn.Conv2d(in_channels, init_channels, kernel_size=3, padding=1)
# 下采样路径
self.down_blocks = nn.ModuleList()
self.down_attentions = nn.ModuleList()
self.down_pools = nn.ModuleList()
channels = [init_channels]
now_channels = init_channels
# 构建下采样路径
for i, mult in enumerate(channel_mults):
out_channels = init_channels * mult
for _ in range(num_blocks):
self.down_blocks.append(
AttentionResidualBlock(now_channels, out_channels, time_emb_dim, dropout)
)
self.down_attentions.append(SelfAttention(out_channels))
now_channels = out_channels
channels.append(now_channels)
if i != len(channel_mults) - 1:
self.down_pools.append(nn.Conv2d(now_channels, now_channels, kernel_size=3, stride=2, padding=1))
# 瓶颈层
self.bottleneck = nn.ModuleList([
AttentionResidualBlock(now_channels, now_channels, time_emb_dim, dropout),
SelfAttention(now_channels),
AttentionResidualBlock(now_channels, now_channels, time_emb_dim, dropout)
])
# 上采样路径
self.up_blocks = nn.ModuleList()
self.up_attentions = nn.ModuleList()
for i, mult in reversed(list(enumerate(channel_mults))):
out_channels = init_channels * mult
for j in range(num_blocks + 1):
self.up_blocks.append(
AttentionResidualBlock(channels.pop() + now_channels, out_channels, time_emb_dim, dropout)
)
self.up_attentions.append(SelfAttention(out_channels))
now_channels = out_channels
if i != 0:
self.up_blocks.append(nn.ConvTranspose2d(now_channels, now_channels, kernel_size=2, stride=2))
# 最终输出层
self.final_conv = nn.Sequential(
nn.GroupNorm(8, now_channels),
nn.SiLU(),
nn.Conv2d(now_channels, out_channels, kernel_size=1)
)
def forward(self, x, t):
# 时间嵌入
t_emb = self.time_mlp(t)
# 初始卷积
x