Wan2.1 图生视频模型内部协作流程
flyfish
Wan2.1作为一个多模态生成模型,其内部涉及多个子模型的协同工作。
1. 模型架构概览
Wan2.1主要由以下核心组件构成:
- 文本编码器:基于T5的文本理解模型,将prompt转换为语义向量
- 图像编码器:基于DiT的图像理解模型,提取输入图像的视觉特征
- 时空UNet:核心生成模型,基于文本和图像条件生成视频序列
- 帧间对齐模块:确保生成的视频帧之间具有时间连贯性
- 上采样模块:将低分辨率视频提升到目标分辨率
2. 数据流向与协作流程
输入文本(prompt) 输入图像(img)
│ │
▼ ▼
┌─────────────┐ ┌─────────────┐
│ T5文本编码器 │ │ DiT图像编码器│
│ 生成文本嵌入 │ │ 生成图像特征 │
└──────┬──────┘ └──────┬──────┘
│ │
▼ ▼
┌─────────────────────────────┐
│ 交叉注意力融合 │
└─────────────────────────────┘
│
▼
┌─────────────────────────────┐
│ 时空UNet生成器 │
│ ┌───────────────────────┐ │
│ │ 1. 噪声调度器 │ │
│ │ 2. 多帧联合去噪 │ │
│ │ 3. 帧间对齐 │ │
│ └───────────────────────┘ │
└─────────────────────────────┘
│
▼
┌─────────────────────────────┐
│ 超分辨率模块 │
│ ┌───────────────────────┐ │
│ │ 1. 低分辨率视频 │ │
│ │ 2. 渐进式上采样 │ │
│ │ 3. 细节增强 │ │
│ └───────────────────────┘ │
└─────────────────────────────┘
│
▼
┌─────────────────────────────┐
│ 输出视频张量 │
└─────────────────────────────┘
3. 伪代码
3.1 文本编码阶段
# T5文本编码器处理流程
text_embeddings = t5_model.encode(prompt, max_length=77) # 编码为77×768的文本嵌入
text_embeddings = text_projection(text_embeddings) # 投影到模型内部维度
- 功能:将自然语言描述转换为模型可理解的语义向量
- 优化:使用CLIP文本编码器的变种,增强多模态对齐能力
3.2 图像编码阶段
# DiT图像编码器处理流程
image_features = dit_model.encode(img) # 提取图像特征
image_features = image_projection(image_features) # 投影到模型内部维度
image_features = spatial_pooling(image_features) # 空间池化,获取全局特征
- 功能:提取输入图像的视觉特征,作为视频生成的基础
- 优化:使用预训练的DiT-XL/2模型,增强图像理解能力
3.3 多模态融合阶段
# 文本和图像特征融合
conditioning = cross_attention(text_embeddings, image_features) # 交叉注意力机制
conditioning = time_embedding(conditioning, timestep) # 结合时间步嵌入
- 技术:使用Transformer架构的交叉注意力机制
- 作用:将文本语义和图像特征融合为统一的条件表示
3.4 视频生成阶段
# 时空UNet生成过程
noise = torch.randn(batch_size, channels, frames, height, width).to(device) # 初始噪声
# 扩散过程(反向去噪)
for t in reversed(range(num_timesteps)):
timestep_emb = get_timestep_embedding(t) # 当前时间步嵌入
# 预测噪声
noise_pred = unet(
x=noise,
timestep=timestep_emb,
encoder_hidden_states=conditioning
)
# 应用噪声预测更新样本
noise = p_sample(noise, noise_pred, t) # 基于预测噪声更新样本
video = noise # 最终去噪结果即为生成的视频
- 核心技术:
- 时空UNet架构:同时处理空间和时间维度
- 扩散模型:通过逐步去噪生成高质量视频
- 帧间注意力机制:确保视频帧之间的连贯性
3.5 超分辨率阶段
# 视频超分辨率过程
low_res_video = video # 从UNet输出的低分辨率视频
# 渐进式上采样
for i in range(num_upscale_steps):
low_res_video = upsampler_module[i](low_res_video) # 逐级上采样
high_res_video = detail_enhancer(low_res_video) # 细节增强
技术:
- 级联上采样模块:逐步提升视频分辨率
- 残差连接:保留细节信息
- 对抗训练:增强视觉真实性
级联上采样模块中的残差连接与对抗训练
一、级联上采样模块的核心作用
级联上采样模块是视频超分辨率(Video Super-Resolution)的关键组件,其设计目标是将低分辨率视频(如256×256)逐步提升至高分辨率(如1024×1024),同时保持时间维度的连贯性。
核心逻辑:通过多个上采样层的级联(如4级联),每次将分辨率翻倍(×2),最终达到目标尺寸。
二、残差连接(Residual Connection)
1. 技术原理
残差连接是深度学习中的一种架构设计,允许输入直接跳过若干层到达输出,数学表达为:
输出 = 输入 + 非线性变换(输入)
这种设计解决了深层网络的“梯度消失”问题,并能保留原始输入的细节信息。
2. 在视频超分辨率中的作用
细节保留机制:
低分辨率视频中包含的高频细节(如边缘、纹理)在传统上采样中容易丢失,残差连接通过直接传递原始特征,让网络专注于学习“残差信息”(即低分辨率到高分辨率的差异),从而保留原始细节。网络优化:
级联上采样模块通常包含多层卷积,残差连接使梯度能更有效地反向传播,支持更深的网络结构,提升超分辨率质量。
3. 典型结构示例
低分辨率特征 ──┐
▼
卷积层1 ──┐
▼ ┌──────────┐
卷积层2 ───→─┤ 加法操作 │─→ 高分辨率特征
▲ └──────────┘
┘
低分辨率特征 ──┘
三、对抗训练(Adversarial Training)
1. 技术原理
对抗训练源于生成对抗网络(GAN),通过生成器(Generator)和判别器(Discriminator)的博弈提升生成质量:
- 生成器:尝试生成逼真的高分辨率视频
- 判别器:区分生成视频与真实视频
两者相互对抗,最终生成器的输出趋近真实。
2. 在视频超分辨率中的作用
视觉真实性增强:
传统上采样方法(如双三次插值)生成的视频可能模糊或出现伪影,对抗训练通过判别器的监督,迫使生成器学习真实视频的纹理、色彩分布和动态特性,提升视觉真实性。感知质量优化:
判别器通常基于预训练的视觉模型(如VGG)设计,关注人类感知层面的质量(如语义一致性、结构合理性),而非单纯的像素级匹配。
3. 损失函数设计
- 对抗损失:
L_GAN = -E[log(D(G(z)))]
(生成器试图让判别器误判) - 内容损失:
L_content = ||VGG(G(z)) - VGG(x_real)||_2
(特征空间匹配真实视频) - 总损失:
L_total = L_content + λ*L_GAN
(λ为权衡系数)
四、级联上采样+残差连接+对抗训练的协同效应
低分辨率视频 ──→ 级联上采样模块(含残差连接) ──→ 高分辨率视频 ──→ 对抗训练优化
(逐步提升分辨率并保留细节) (增强视觉真实性)
分辨率提升路径:
64×64 → 128×128 → 256×256 → 512×512 → 1024×1024
(4级联)细节保留机制:
每级上采样中的残差连接确保前一级的细节不丢失,例如在从256×256到512×512的过程中,残差连接传递256×256的边缘信息,避免放大后模糊。真实性增强:
对抗训练使最终的1024×1024视频在纹理(如发丝、布料)和动态(如水流、烟雾)上更接近真实世界视频。