Wan2.1 图生视频模型内部协作流程

发布于:2025-05-27 ⋅ 阅读:(27) ⋅ 点赞:(0)

Wan2.1 图生视频模型内部协作流程

flyfish

Wan2.1作为一个多模态生成模型,其内部涉及多个子模型的协同工作。

1. 模型架构概览

Wan2.1主要由以下核心组件构成:

  1. 文本编码器:基于T5的文本理解模型,将prompt转换为语义向量
  2. 图像编码器:基于DiT的图像理解模型,提取输入图像的视觉特征
  3. 时空UNet:核心生成模型,基于文本和图像条件生成视频序列
  4. 帧间对齐模块:确保生成的视频帧之间具有时间连贯性
  5. 上采样模块:将低分辨率视频提升到目标分辨率

2. 数据流向与协作流程

输入文本(prompt)     输入图像(img)
      │                  │
      ▼                  ▼
┌─────────────┐    ┌─────────────┐
│ T5文本编码器 │    │ DiT图像编码器│
│ 生成文本嵌入  │    │ 生成图像特征 │
└──────┬──────┘    └──────┬──────┘
       │                  │
       ▼                  ▼
┌─────────────────────────────┐
│        交叉注意力融合        │
└─────────────────────────────┘
                │
                ▼
┌─────────────────────────────┐
│       时空UNet生成器        │
│  ┌───────────────────────┐  │
│  │ 1. 噪声调度器         │  │
│  │ 2. 多帧联合去噪       │  │
│  │ 3. 帧间对齐           │  │
│  └───────────────────────┘  │
└─────────────────────────────┘
                │
                ▼
┌─────────────────────────────┐
│       超分辨率模块          │
│  ┌───────────────────────┐  │
│  │ 1. 低分辨率视频       │  │
│  │ 2. 渐进式上采样       │  │
│  │ 3. 细节增强           │  │
│  └───────────────────────┘  │
└─────────────────────────────┘
                │
                ▼
┌─────────────────────────────┐
│        输出视频张量         │
└─────────────────────────────┘

3. 伪代码

3.1 文本编码阶段
# T5文本编码器处理流程
text_embeddings = t5_model.encode(prompt, max_length=77)  # 编码为77×768的文本嵌入
text_embeddings = text_projection(text_embeddings)  # 投影到模型内部维度
  • 功能:将自然语言描述转换为模型可理解的语义向量
  • 优化:使用CLIP文本编码器的变种,增强多模态对齐能力
3.2 图像编码阶段
# DiT图像编码器处理流程
image_features = dit_model.encode(img)  # 提取图像特征
image_features = image_projection(image_features)  # 投影到模型内部维度
image_features = spatial_pooling(image_features)  # 空间池化,获取全局特征
  • 功能:提取输入图像的视觉特征,作为视频生成的基础
  • 优化:使用预训练的DiT-XL/2模型,增强图像理解能力
3.3 多模态融合阶段
# 文本和图像特征融合
conditioning = cross_attention(text_embeddings, image_features)  # 交叉注意力机制
conditioning = time_embedding(conditioning, timestep)  # 结合时间步嵌入
  • 技术:使用Transformer架构的交叉注意力机制
  • 作用:将文本语义和图像特征融合为统一的条件表示
3.4 视频生成阶段
# 时空UNet生成过程
noise = torch.randn(batch_size, channels, frames, height, width).to(device)  # 初始噪声

# 扩散过程(反向去噪)
for t in reversed(range(num_timesteps)):
    timestep_emb = get_timestep_embedding(t)  # 当前时间步嵌入
    
    # 预测噪声
    noise_pred = unet(
        x=noise,
        timestep=timestep_emb,
        encoder_hidden_states=conditioning
    )
    
    # 应用噪声预测更新样本
    noise = p_sample(noise, noise_pred, t)  # 基于预测噪声更新样本
    
video = noise  # 最终去噪结果即为生成的视频
  • 核心技术
    • 时空UNet架构:同时处理空间和时间维度
    • 扩散模型:通过逐步去噪生成高质量视频
    • 帧间注意力机制:确保视频帧之间的连贯性
3.5 超分辨率阶段
# 视频超分辨率过程
low_res_video = video  # 从UNet输出的低分辨率视频

# 渐进式上采样
for i in range(num_upscale_steps):
    low_res_video = upsampler_module[i](low_res_video)  # 逐级上采样
    
high_res_video = detail_enhancer(low_res_video)  # 细节增强

技术

  • 级联上采样模块:逐步提升视频分辨率
  • 残差连接:保留细节信息
  • 对抗训练:增强视觉真实性

级联上采样模块中的残差连接与对抗训练

一、级联上采样模块的核心作用

级联上采样模块是视频超分辨率(Video Super-Resolution)的关键组件,其设计目标是将低分辨率视频(如256×256)逐步提升至高分辨率(如1024×1024),同时保持时间维度的连贯性。
核心逻辑:通过多个上采样层的级联(如4级联),每次将分辨率翻倍(×2),最终达到目标尺寸。

二、残差连接(Residual Connection)
1. 技术原理

残差连接是深度学习中的一种架构设计,允许输入直接跳过若干层到达输出,数学表达为:
输出 = 输入 + 非线性变换(输入)
这种设计解决了深层网络的“梯度消失”问题,并能保留原始输入的细节信息。

2. 在视频超分辨率中的作用
  • 细节保留机制
    低分辨率视频中包含的高频细节(如边缘、纹理)在传统上采样中容易丢失,残差连接通过直接传递原始特征,让网络专注于学习“残差信息”(即低分辨率到高分辨率的差异),从而保留原始细节。

  • 网络优化
    级联上采样模块通常包含多层卷积,残差连接使梯度能更有效地反向传播,支持更深的网络结构,提升超分辨率质量。

3. 典型结构示例
低分辨率特征 ──┐
              ▼
      卷积层1 ──┐
              ▼    ┌──────────┐
      卷积层2 ───→─┤ 加法操作 │─→ 高分辨率特征
              ▲    └──────────┘
              ┘
低分辨率特征 ──┘
三、对抗训练(Adversarial Training)
1. 技术原理

对抗训练源于生成对抗网络(GAN),通过生成器(Generator)和判别器(Discriminator)的博弈提升生成质量:

  • 生成器:尝试生成逼真的高分辨率视频
  • 判别器:区分生成视频与真实视频
    两者相互对抗,最终生成器的输出趋近真实。
2. 在视频超分辨率中的作用
  • 视觉真实性增强
    传统上采样方法(如双三次插值)生成的视频可能模糊或出现伪影,对抗训练通过判别器的监督,迫使生成器学习真实视频的纹理、色彩分布和动态特性,提升视觉真实性。

  • 感知质量优化
    判别器通常基于预训练的视觉模型(如VGG)设计,关注人类感知层面的质量(如语义一致性、结构合理性),而非单纯的像素级匹配。

3. 损失函数设计
  • 对抗损失
    L_GAN = -E[log(D(G(z)))](生成器试图让判别器误判)
  • 内容损失
    L_content = ||VGG(G(z)) - VGG(x_real)||_2(特征空间匹配真实视频)
  • 总损失
    L_total = L_content + λ*L_GAN(λ为权衡系数)
四、级联上采样+残差连接+对抗训练的协同效应
低分辨率视频 ──→ 级联上采样模块(含残差连接) ──→ 高分辨率视频 ──→ 对抗训练优化
               (逐步提升分辨率并保留细节)        (增强视觉真实性)
  1. 分辨率提升路径
    64×64 → 128×128 → 256×256 → 512×512 → 1024×1024(4级联)

  2. 细节保留机制
    每级上采样中的残差连接确保前一级的细节不丢失,例如在从256×256到512×512的过程中,残差连接传递256×256的边缘信息,避免放大后模糊。

  3. 真实性增强
    对抗训练使最终的1024×1024视频在纹理(如发丝、布料)和动态(如水流、烟雾)上更接近真实世界视频。


网站公告

今日签到

点亮在社区的每一天
去签到