Wan2.1 Image-to-Video: Batch Video Generation with Multi-GPU Inference

Published: 2025-05-27


flyfish

Related articles:

Showcase of video generation results

Phantom video generation in practice
The Phantom video generation pipeline
Phantom video generation commands

Wan2.1 image-to-video with batch generation support
Wan2.1 text-to-video with batch generation, parameterized configuration, and multilingual prompt management
Wan2.1 inference acceleration methods
Wan2.1 video generation from first and last frames

AnyText2: editing text inside images, with what-you-envision-is-what-you-get results
Extracting a fixed number of evenly spaced frames from an MP4 video in Python

config.json

{
    "task": "i2v-14B",
    "size": "832*480",
    "frame_num": null,
    "ckpt_dir": "/media/models/Wan-AI/Wan2___1-I2V-14B-480P/",
    "offload_model": null,
    "ulysses_size": 2,
    "ring_size": 1,
    "t5_fsdp": false,
    "t5_cpu": true,
    "dit_fsdp": true,
    "save_file": null,
    "prompt": null,
    "use_prompt_extend": false,
    "prompt_extend_method": "local_qwen",
    "prompt_extend_model": null,
    "prompt_extend_target_lang": "zh",
    "base_seed": -1,
    "image": null,
    "first_frame": null,
    "last_frame": null,
    "sample_solver": "unipc",
    "sample_steps": null,
    "sample_shift": null,
    "sample_guide_scale": 5.0
}
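
With ulysses_size set to 2 and ring_size set to 1, the script expects exactly two processes, since their product must equal the world size. Assuming the code below is saved as generate_i2v.py (a hypothetical filename), a two-GPU run could be launched with torchrun, which sets the RANK, WORLD_SIZE, and LOCAL_RANK environment variables that the script reads:

torchrun --nproc_per_node=2 generate_i2v.py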

prompt.json

[
  {
    "prompt": "Dragon Playing with Pearl: A warrior wields a red-tasseled spear, summoning seven dragon-like phantom spear tips amid swirling ink shadows that twist air into a shredding vortex; visuals include ink-black shadows, molten fire-red tassel, and a violent air vortex.",
    "image_paths": ["images/1.png"]
  },
  {
    "prompt": "Slicing the Sky, Chopping the Moon: The warrior leaps, slashing the spear diagonally like lightning to create a glowing vacuum rift with azure electricity, then traces a lunar arc that solidifies space to trap enemies; visuals feature a billowing black cape, crackling rift, and frozen lunar arc.",
    "image_paths": ["images/1.png"]
  }
]
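
Before starting a long multi-GPU job, it can pay off to sanity-check prompt.json. The sketch below is not part of the original script; it simply verifies that every entry has a non-empty prompt and that each referenced image exists on disk, so a malformed entry fails fast instead of mid-batch:

import json
import os

# Pre-flight check (optional sketch): validate prompt.json before launch.
with open('prompt.json', 'r') as f:
    entries = json.load(f)

for i, entry in enumerate(entries):
    assert entry.get("prompt"), f"Entry {i} has an empty prompt."
    for path in entry.get("image_paths", []):
        assert os.path.isfile(path), f"Entry {i}: missing image {path}"

print(f"{len(entries)} prompt entries validated.")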

Flow

WanI2VApp.run()
├─ Main application starts
├─ Load config / validate args
│  ├─ Set defaults such as frame_num=81
│  └─ Check that task and resolution are valid
├─ Initialize distributed environment
│  ├─ Start the process group when multi-GPU
│  ├─ Synchronize the random seed
│  └─ Validate distributed arguments
├─ Singleton model loading (the key optimization)
│  ├─ Create the WanI2V model instance
│  ├─ Load the checkpoint onto the GPU
│  └─ Log: "Creating WanI2V pipeline (first time)."
└─ Batch image loop (N images)
   ├─ Read prompt and image_paths
   ├─ For each image:
   │  ├─ Open the image and convert its format
   │  ├─ Prompt extension
   │  │  ├─ Call the DashScope/Qwen expander
   │  │  ├─ Broadcast the extended result across ranks
   │  │  └─ Fall back to the original prompt on failure
   │  ├─ Reuse the model for inference
   │  │  ├─ Call model.generate()
   │  │  ├─ Pass resolution, frame count, etc.
   │  │  └─ Log: "Generating video with existing model."
   │  └─ Save the video
   │     ├─ Build a default filename (timestamp + prompt)
   │     └─ Call cache_video to write the MP4
   └─ Clean up model resources (main process only)
      ├─ Delete the model instance (del self.model)
      ├─ Free the GPU cache (torch.cuda.empty_cache())
      └─ Log: "Model resources cleaned up."

Model loading sequence diagram

┌────────────────────────────────────────────────────────────────┐
│                        WanI2VApp.run()                         │
│  ┌──────────────────┐  ┌──────────────────┐  ┌─────────────┐   │
│  │ Load config /    │  │ Init distributed │  │ Load model  │   │
│  │ validate args    │  │ environment      │  └──────┬──────┘   │
│  └──────────────────┘  └──────────────────┘         │          │
│                                                     │          │
│  ┌───────────────────────────────────────────────┐  │          │
│  │ For each prompt and image in prompt.json      │  │          │
│  ├───────────────────────────────────────────────┤  │          │
│  │ ┌───────────┐  ┌────────────────┐  ┌───────┐  │  │          │
│  │ │ Process   │  │ Run inference, │  │ Save  │  │  │          │
│  │ │ prompt    │  │ generate video │  │ video │  │  │          │
│  │ └───────────┘  └────────────────┘  └───────┘  │  │          │
│  └───────────────────────────────────────────────┘  │          │
│                                                     │          │
│  ┌────────────────────────────────────┐             │          │
│  │ Clean up model (rank 0 only)       │             │          │
│  └────────────────────────────────────┘             │          │
└────────────────────────────────────────────────────────────────┘
          ↑                      ↑                      ↑
          │                      │                      │
   model loaded once      model reused for        model resources
                          each inference          released

Code

import argparse
from datetime import datetime
import logging
import os
import sys
import warnings
import json

warnings.filterwarnings('ignore')

import torch, random
import torch.distributed as dist
from PIL import Image

import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video, cache_image, str2bool

class ArgsValidator:
    @staticmethod
    def validate(args):
        # Basic check
        assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
        assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"

        # The default sampling steps are 40 for image-to-video tasks.
        if args.sample_steps is None:
            args.sample_steps = 40

        if args.sample_shift is None:
            args.sample_shift = 3.0 if args.size in ["832*480", "480*832"] else 5.0

        # The default number of frames is 81 (Wan2.1 expects frame_num of the form 4n+1).
        if args.frame_num is None:
            args.frame_num = 81

        args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
            0, sys.maxsize)
        # Size check
        assert args.size in SUPPORTED_SIZES[
            args.task], f"Unsupported size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
        return args

class ConfigLoader:
    @staticmethod
    def load_config():
        # Read parameters from the config file
        with open('config.json', 'r') as f:
            config = json.load(f)

        # Create a simple namespace to hold the parameters
        class ArgsNamespace:
            def __init__(self, **kwargs):
                self.__dict__.update(kwargs)

        args = ArgsNamespace(**config)
        return args

class LoggerInitializer:
    @staticmethod
    def initialize(rank):
        # logging
        if rank == 0:
            # set format
            logging.basicConfig(
                level=logging.INFO,
                format="[%(asctime)s] %(levelname)s: %(message)s",
                handlers=[logging.StreamHandler(stream=sys.stdout)])
        else:
            logging.basicConfig(level=logging.ERROR)

class DistributedEnv:
    def __init__(self, args):
        self.args = args
        self.rank = int(os.getenv("RANK", 0))
        self.world_size = int(os.getenv("WORLD_SIZE", 1))
        self.local_rank = int(os.getenv("LOCAL_RANK", 0))
        self.device = self.local_rank
        
    def initialize(self):
        if self.args.offload_model is None:
            self.args.offload_model = False if self.world_size > 1 else True
            logging.info(
                f"offload_model is not specified, set to {self.args.offload_model}.")
                
        if self.world_size > 1:
            torch.cuda.set_device(self.local_rank)
            dist.init_process_group(
                backend="nccl",
                init_method="env://",
                rank=self.rank,
                world_size=self.world_size)
        else:
            assert not (
                self.args.t5_fsdp or self.args.dit_fsdp
            ), "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
            assert not (
                self.args.ulysses_size > 1 or self.args.ring_size > 1
            ), "Context parallelism is not supported in non-distributed environments."

        if self.args.ulysses_size > 1 or self.args.ring_size > 1:
            assert self.args.ulysses_size * self.args.ring_size == self.world_size, "The product of ulysses_size and ring_size must equal the world size."
            from xfuser.core.distributed import (initialize_model_parallel,
                                                 init_distributed_environment)
            init_distributed_environment(
                rank=dist.get_rank(), world_size=dist.get_world_size())

            initialize_model_parallel(
                sequence_parallel_degree=dist.get_world_size(),
                ring_degree=self.args.ring_size,
                ulysses_degree=self.args.ulysses_size,
            )
            
        if dist.is_initialized():
            base_seed = [self.args.base_seed] if self.rank == 0 else [None]
            dist.broadcast_object_list(base_seed, src=0)
            self.args.base_seed = base_seed[0]
            
        return self.args, self.rank, self.device

class PromptProcessor:
    def __init__(self, args, rank, device):
        self.args = args
        self.rank = rank
        self.device = device
        
    def process(self, img):
        if not self.args.use_prompt_extend:
            return self.args.prompt
            
        logging.info("Extending prompt ...")
        if self.rank == 0:
            if self.args.prompt_extend_method == "dashscope":
                prompt_expander = DashScopePromptExpander(
                    model_name=self.args.prompt_extend_model, is_vl=True)
            elif self.args.prompt_extend_method == "local_qwen":
                prompt_expander = QwenPromptExpander(
                    model_name=self.args.prompt_extend_model,
                    is_vl=True,
                    device=self.rank)
            else:
                raise NotImplementedError(
                    f"Unsupported prompt_extend_method: {self.args.prompt_extend_method}")
            
            prompt_output = prompt_expander(
                self.args.prompt,
                tar_lang=self.args.prompt_extend_target_lang,
                image=img,
                seed=self.args.base_seed)
                
            if not prompt_output.status:
                logging.info(
                    f"Extending prompt failed: {prompt_output.message}")
                logging.info("Falling back to original prompt.")
                input_prompt = self.args.prompt
            else:
                input_prompt = prompt_output.prompt
            input_prompt = [input_prompt]
        else:
            input_prompt = [None]
            
        if dist.is_initialized():
            dist.broadcast_object_list(input_prompt, src=0)
            
        self.args.prompt = input_prompt[0]
        logging.info(f"Extended prompt: {self.args.prompt}")
        return self.args.prompt

class VideoGenerator:
    _instance = None  # singleton instance
    
    @classmethod
    def get_instance(cls, args, rank, device):
        # Create the instance on the first call; reuse it afterwards
        if cls._instance is None:
            cls._instance = cls(args, rank, device)
        return cls._instance
    
    def __init__(self, args, rank, device):
        # This initialization runs only once for the whole batch
        self.args = args
        self.rank = rank
        self.device = device
        self.cfg = WAN_CONFIGS[args.task]
        
        # Load the model
        logging.info("Creating WanI2V pipeline (first time).")
        self.model = wan.WanI2V(
            config=self.cfg,
            checkpoint_dir=self.args.ckpt_dir,
            device_id=self.device,
            rank=self.rank,
            t5_fsdp=self.args.t5_fsdp,
            dit_fsdp=self.args.dit_fsdp,
            use_usp=(self.args.ulysses_size > 1 or self.args.ring_size > 1),
            t5_cpu=self.args.t5_cpu,
        )
    
    def generate(self, prompt, img):
        # Reuse the already-loaded model for inference
        logging.info("Generating video with existing model.")
        video = self.model.generate(
            prompt,
            img,
            max_area=MAX_AREA_CONFIGS[self.args.size],
            frame_num=self.args.frame_num,
            shift=self.args.sample_shift,
            sample_solver=self.args.sample_solver,
            sampling_steps=self.args.sample_steps,
            guide_scale=self.args.sample_guide_scale,
            seed=self.args.base_seed,
            offload_model=self.args.offload_model)
            
        return video
    
    @classmethod
    def cleanup(cls):
        # Release model resources (call once all inference is finished)
        if cls._instance and hasattr(cls._instance, 'model'):
            del cls._instance.model
            torch.cuda.empty_cache()
            logging.info("Model resources cleaned up.")
        cls._instance = None

class VideoSaver:
    def __init__(self, args, rank):
        self.args = args
        self.rank = rank
        self.cfg = WAN_CONFIGS[args.task]
        
    def save(self, video):
        if self.rank != 0:
            return

        save_file = self.args.save_file
        if save_file is None:
            formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
            formatted_prompt = self.args.prompt.replace(" ", "_").replace("/", "_")[:50]
            suffix = '.mp4'
            # Build the filename locally instead of writing it back to
            # args.save_file; otherwise the second and later videos in the
            # batch would all reuse (and overwrite) the first filename.
            save_file = f"{self.args.task}_{self.args.size.replace('*','x') if sys.platform=='win32' else self.args.size}_{self.args.ulysses_size}_{self.args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix

        logging.info(f"Saving generated video to {save_file}")
        cache_video(
            tensor=video[None],
            save_file=save_file,
            fps=self.cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))

class WanI2VApp:
    def __init__(self):
        self.args = None
        self.rank = 0
        self.device = 0
        
    def run(self):
        # Load configuration
        config_loader = ConfigLoader()
        self.args = config_loader.load_config()
        
        # Validate arguments
        validator = ArgsValidator()
        self.args = validator.validate(self.args)
        
        # Initialize logging; read the rank from the environment so that
        # non-zero ranks are configured at ERROR level (self.rank is still 0 here)
        LoggerInitializer.initialize(int(os.getenv("RANK", 0)))
        
        # Initialize the distributed environment
        dist_env = DistributedEnv(self.args)
        self.args, self.rank, self.device = dist_env.initialize()
        
        logging.info(f"Generation job args: {self.args}")
        logging.info(f"Generation model config: {WAN_CONFIGS[self.args.task]}")
        
        # Get the singleton generator (the model is loaded only once)
        generator = VideoGenerator.get_instance(self.args, self.rank, self.device)
        
        # Read prompts and image paths from prompt.json
        with open('prompt.json', 'r') as f:
            prompt_list = json.load(f)

        for prompt_info in prompt_list:
            self.args.prompt = prompt_info["prompt"]
            image_paths = prompt_info["image_paths"]

            for image_path in image_paths:
                logging.info(f"Input prompt: {self.args.prompt}")
                logging.info(f"Input image: {image_path}")

                img = Image.open(image_path).convert("RGB")
                
                # Process the prompt (with optional extension)
                prompt_processor = PromptProcessor(self.args, self.rank, self.device)
                prompt = prompt_processor.process(img)
                
                # Reuse the loaded model to generate the video
                video = generator.generate(prompt, img)
                
                # Save the video
                saver = VideoSaver(self.args, self.rank)
                saver.save(video)
        
        # Release model resources (optional; after all inference is done)
        if self.rank == 0:
            VideoGenerator.cleanup()

        logging.info("Finished.")

if __name__ == "__main__":
    app = WanI2VApp()
    app.run()

Execution flow

1. Main application initialization and configuration loading

WanI2VApp startup: create the main application instance and call its run() method.
ConfigLoader loads the configuration: read the parameters (task, ckpt_dir, etc.) from config.json; see the sketch after this list.
ArgsValidator validates the arguments: fill in defaults (e.g. frame_num=81) and check that task and size are valid.
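
As an aside, the hand-rolled ArgsNamespace in ConfigLoader behaves like the standard library's argparse.Namespace, which also turns keyword arguments into attributes. A minimal equivalent sketch:

import argparse
import json

# Load config.json straight into a Namespace; attribute access
# (args.task, args.size, ...) works exactly as with ArgsNamespace.
with open('config.json', 'r') as f:
    args = argparse.Namespace(**json.load(f))

print(args.task, args.size)  # i2v-14B 832*480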

2. Environment and resource initialization

LoggerInitializer sets up logging: the main process (rank 0) logs at INFO; all other ranks log only ERROR.
DistributedEnv initializes the distributed environment:

  • Start the process group when running on multiple GPUs (dist.init_process_group).
  • Broadcast the random seed (base_seed) from rank 0 so every rank samples identical noise, keeping results reproducible.
3. Singleton model loading (the key optimization)

VideoGenerator.get_instance() call:

  • On the first call, create the singleton instance and load the model (wan.WanI2V).
  • Log message: Creating WanI2V pipeline (first time).
  • After loading, the instance is kept in VideoGenerator._instance.
4. Batch processing of prompts and images

Read prompt.json: iterate over every prompt and its image_paths.
PromptProcessor extends the prompt:

  • For each image, extend the prompt with DashScope or a local Qwen model.
  • On failure, fall back to the original prompt.

VideoGenerator.generate() runs inference:

  • Reuses the already-loaded model instance (self.model).
  • Log message: Generating video with existing model.
  • Each call only runs the computation; the model is never reloaded.
5. Saving results and cleaning up resources

VideoSaver saves the video: the main process writes the result to an MP4 file.
VideoGenerator.cleanup() releases resources; see the teardown sketch below for a fuller variant:

  • Delete the model instance when the application finishes (del self.model).
  • Call torch.cuda.empty_cache() to free cached GPU memory.
  • Log message: Model resources cleaned up.
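
One caveat: the cleanup above runs only on rank 0 and never tears down the NCCL process group, so multi-GPU runs rely on process exit for that. A sketch of a fuller teardown (not part of the original script) that every rank could call after the batch loop:

import torch
import torch.distributed as dist

def teardown(model):
    # Free the model and cached GPU memory on this rank.
    del model
    torch.cuda.empty_cache()
    # Wait for all ranks, then dispose of the process group so that
    # torchrun exits cleanly.
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()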

