SkyReels-V2 视频生成

发布于:2025-05-14 ⋅ 阅读:(9) ⋅ 点赞:(0)

SkyReels-V2 视频生成

flyfish

扩散强制(DF)模型:专为无限长度视频生成设计,提供1.3B-540P和14B-720P等版本
文本到视频(T2V)模型:专注于从文本提示生成高质量视频
图像到视频(I2V)模型:能够从输入图像生成连贯的视频序列

import os
# 设置TOKENIZERS_PARALLELISM为false,避免分词器并行化可能带来的问题
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import argparse
import gc
import os
import random
import time
import json

import imageio
import torch
from diffusers.utils import load_image
from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.pipelines import PromptEnhancer
from skyreels_v2_infer.pipelines import resizecrop


# 单例模式元类,确保类只有一个实例
class Singleton(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


# 配置解析类,用于解析命令行参数
class ConfigParser:
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self._add_arguments()

    def _add_arguments(self):
        # 输出目录
        self.parser.add_argument("--outdir", type=str, default="diffusion_forcing")
        # 模型ID
        self.parser.add_argument(
            "--model_id",
            type=str,
            default="/media/models/Skywork/SkyReels-V2-DF-1___3B-540P/",
        )
        # 分辨率
        self.parser.add_argument(
            "--resolution", type=str, default="540P", choices=["540P", "720P"]
        )
        # 帧数
        self.parser.add_argument("--num_frames", type=int, default=97)
        # 图像路径
        self.parser.add_argument("--image", type=str, default=None)
        # AR步骤
        self.parser.add_argument("--ar_step", type=int, default=0)
        # 是否使用因果注意力
        self.parser.add_argument("--causal_attention", action="store_true")
        # 因果块大小
        self.parser.add_argument("--causal_block_size", type=int, default=1)
        # 基础帧数
        self.parser.add_argument("--base_num_frames", type=int, default=97)
        # 重叠历史
        self.parser.add_argument("--overlap_history", type=int, default=None)
        # 添加噪声条件
        self.parser.add_argument("--addnoise_condition", type=int, default=0)
        # 引导比例
        self.parser.add_argument("--guidance_scale", type=float, default=6.0)
        # 偏移量
        self.parser.add_argument("--shift", type=float, default=8.0)
        # 推理步骤
        self.parser.add_argument("--inference_steps", type=int, default=30)  # 30
        # 是否使用USP
        self.parser.add_argument("--use_usp", action="store_true")
        # 是否卸载
        self.parser.add_argument("--offload", action="store_true")
        # 帧率
        self.parser.add_argument("--fps", type=int, default=24)
        # 随机种子
        self.parser.add_argument("--seed", type=int, default=None)
        # 提示文件
        self.parser.add_argument("--prompt", type=str, default="prompt.json")
        # 是否使用提示增强器
        self.parser.add_argument("--prompt_enhancer", action="store_true")
        # 是否使用TEA缓存
        self.parser.add_argument("--teacache", action="store_true")
        # TEA缓存阈值
        self.parser.add_argument(
            "--teacache_thresh",
            type=float,
            default=0.2,
            help="Higher speedup will cause to worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup",
        )
        # 是否使用保留步骤
        self.parser.add_argument(
            "--use_ret_steps",
            action="store_true",
            help="Using Retention Steps will result in faster generation speed and better generation quality.",
        )

    def parse(self):
        return self.parser.parse_args()


# 环境设置类,用于设置运行环境
class EnvironmentSetup:
    def __init__(self, args):
        self.args = args
        self._validate_seed()
        self._set_resolution()
        self._validate_num_frames()
        self._validate_addnoise_condition()
        # 负提示词
        self.negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
        self._create_save_dir()
        self._setup_usp()

    def _validate_seed(self):
        # 验证种子是否有效,USP模式需要种子
        assert (self.args.use_usp and self.args.seed is not None) or (
            not self.args.use_usp
        ), "usp mode need seed"
        if self.args.seed is None:
            random.seed(time.time())
            self.args.seed = int(random.randrange(4294967294))

    def _set_resolution(self):
        # 根据分辨率参数设置高度和宽度
        if self.args.resolution == "540P":
            self.height = 544
            self.width = 960
        elif self.args.resolution == "720P":
            self.height = 720
            self.width = 1280
        else:
            raise ValueError(f"Invalid resolution: {self.args.resolution}")

    def _validate_num_frames(self):
        # 验证帧数是否有效,长视频生成需要指定重叠历史
        if self.args.num_frames > self.args.base_num_frames:
            assert (
                self.args.overlap_history is not None
            ), 'You are supposed to specify the "overlap_history" to support the long video generation. 17 and 37 are recommanded to set.'

    def _validate_addnoise_condition(self):
        # 验证添加噪声条件是否有效,值过大可能导致长视频生成不一致
        if self.args.addnoise_condition > 60:
            print(
                f'You have set "addnoise_condition" as {self.args.addnoise_condition}. The value is too large which can cause inconsistency in long video generation. The value is recommanded to set 20.'
            )

    def _create_save_dir(self):
        # 创建保存目录
        self.save_dir = os.path.join("result", self.args.outdir)
        os.makedirs(self.save_dir, exist_ok=True)

    def _setup_usp(self):
        self.local_rank = 0
        if self.args.use_usp:
            # USP模式下不允许使用提示增强器
            assert (
                not self.args.prompt_enhancer
            ), "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate enhanced prompt before enabling the `--use_usp` parameter."
            from xfuser.core.distributed import (
                initialize_model_parallel,
                init_distributed_environment,
            )
            import torch.distributed as dist

            # 初始化分布式环境
            dist.init_process_group("nccl")
            self.local_rank = dist.get_rank()
            torch.cuda.set_device(dist.get_rank())
            self.device = "cuda"

            init_distributed_environment(
                rank=dist.get_rank(), world_size=dist.get_world_size()
            )

            initialize_model_parallel(
                sequence_parallel_degree=dist.get_world_size(),
                ring_degree=1,
                ulysses_degree=dist.get_world_size(),
            )


# 管道设置类,用于创建和配置DiffusionForcingPipeline
class PipelineSetup(metaclass=Singleton):
    def __init__(self, args):
        self.pipe = DiffusionForcingPipeline(
            args.model_id,
            dit_path=args.model_id,
            device=torch.device("cuda"),
            weight_dtype=torch.bfloat16,
            use_usp=args.use_usp,
            offload=args.offload,
        )

        if args.causal_attention:
            # 设置因果注意力
            self.pipe.transformer.set_ar_attention(args.causal_block_size)

        if args.teacache:
            if args.ar_step > 0:
                # 计算推理步骤数
                num_steps = (
                    args.inference_steps
                    + (((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size - 1)
                    * args.ar_step
                )
                print("num_steps:", num_steps)
            else:
                num_steps = args.inference_steps
            # 初始化TEA缓存
            self.pipe.transformer.initialize_teacache(
                enable_teacache=True,
                num_steps=num_steps,
                teacache_thresh=args.teacache_thresh,
                use_ret_steps=args.use_ret_steps,
                ckpt_dir=args.model_id,
            )


# 提示加载类,用于加载提示信息
class PromptLoader:
    def __init__(self, args):
        self.args = args

    def load(self):
        # 加载提示文件,如果文件存在且为JSON格式,则解析JSON文件,否则返回默认提示
        if os.path.exists(self.args.prompt) and self.args.prompt.endswith(".json"):
            with open(self.args.prompt, "r", encoding="utf-8") as f:
                return json.load(f)
        return [{"prompt": self.args.prompt}]


# 提示增强包装类,用于增强提示信息
class PromptEnhancerWrapper:
    def __init__(self, args):
        self.args = args

    def enhance(self, prompt_input, image):
        if self.args.prompt_enhancer and image is None:
            print(f"init prompt enhancer")
            prompt_enhancer = PromptEnhancer()
            # 增强提示信息
            prompt_input = prompt_enhancer(prompt_input)
            print(f"enhanced prompt: {prompt_input}")
            del prompt_enhancer
            gc.collect()
            torch.cuda.empty_cache()
        return prompt_input


# 视频生成类,用于生成视频帧
class VideoGenerator:
    def __init__(self, pipe):
        self.pipe = pipe

    def generate(self, prompt_input, negative_prompt, image, height, width, num_frames,
                 num_inference_steps, shift, guidance_scale, seed, overlap_history,
                 addnoise_condition, base_num_frames, ar_step, causal_block_size, fps):
        with torch.cuda.amp.autocast(dtype=self.pipe.transformer.dtype), torch.no_grad():
            # 生成视频帧
            return self.pipe(
                prompt=prompt_input,
                negative_prompt=negative_prompt,
                image=image,
                height=height,
                width=width,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                shift=shift,
                guidance_scale=guidance_scale,
                generator=torch.Generator(device="cuda").manual_seed(seed),
                overlap_history=overlap_history,
                addnoise_condition=addnoise_condition,
                base_num_frames=base_num_frames,
                ar_step=ar_step,
                causal_block_size=causal_block_size,
                fps=fps,
            )[0]


# 视频保存类,用于保存生成的视频
class VideoSaver:
    def save(self, video_frames, save_dir, prompt_input, seed, fps):
        # 生成视频文件名
        current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
        video_out_file = f"{prompt_input[:100].replace('/','')}_{seed}_{current_time}.mp4"
        output_path = os.path.join(save_dir, video_out_file)
        # 保存视频
        imageio.mimwrite(
            output_path,
            video_frames,
            fps=fps,
            quality=8,
            output_params=["-loglevel", "error"],
        )


# 视频生成应用类,作为应用程序的入口点,协调各个类的工作


class VideoGenerationApp:
    def __init__(self):
        self.config_parser = ConfigParser()
        self.args = self.config_parser.parse()
        self.env_setup = EnvironmentSetup(self.args)
        self.pipeline_setup = PipelineSetup(self.args)
        self.prompt_loader = PromptLoader(self.args)
        self.prompt_enhancer = PromptEnhancerWrapper(self.args)
        self.video_generator = VideoGenerator(self.pipeline_setup.pipe)
        self.video_saver = VideoSaver()
        
        # 新增统计相关属性
        self.video_count = 0
        self.current_video = 0
        self.time_records = []
        self.total_time = 0.0

    def run(self):
        prompts = self.prompt_loader.load()
        self.video_count = len(prompts)
        self.current_video = 0
        self.time_records.clear()
        self.total_time = 0.0

        for prompt_info in prompts:
            self.current_video += 1
            start_time = time.perf_counter()  # 记录开始时间
            
            prompt_input = prompt_info["prompt"]
            image = None
            if "image_paths" in prompt_info:
                image_path = prompt_info["image_paths"][0]
                image = load_image(image_path)
                image_width, image_height = image.size
                if image_height > image_width:
                    self.env_setup.height, self.env_setup.width = self.env_setup.width, self.env_setup.height
                image = resizecrop(image, self.env_setup.height, self.env_setup.width)
                image = image.convert("RGB")

            prompt_input = self.prompt_enhancer.enhance(prompt_input, image)

            print(f"\n=== Video {self.current_video}/{self.video_count} ===")
            print(f"Prompt: {prompt_input[:100]}...")
            print(f"Guidance Scale: {self.env_setup.args.guidance_scale}")

            video_frames = self.video_generator.generate(
                prompt_input,
                self.env_setup.negative_prompt,
                image,
                self.env_setup.height,
                self.env_setup.width,
                self.env_setup.args.num_frames,
                self.env_setup.args.inference_steps,
                self.env_setup.args.shift,
                self.env_setup.args.guidance_scale,
                self.env_setup.args.seed,
                self.env_setup.args.overlap_history,
                self.env_setup.args.addnoise_condition,
                self.env_setup.args.base_num_frames,
                self.env_setup.args.ar_step,
                self.env_setup.args.causal_block_size,
                self.env_setup.args.fps,
            )

            # 计算本次推理时间
            duration = time.perf_counter() - start_time
            self.time_records.append(duration)
            self.total_time += duration

            print(f"Generation completed in {duration:.2f} seconds")
            
            if self.env_setup.local_rank == 0:
                self.video_saver.save(video_frames, self.env_setup.save_dir, prompt_input, self.env_setup.args.seed,
                                      self.env_setup.args.fps)
                print(f"Video saved to {self.env_setup.save_dir}")

        # 新增统计结果汇总
        if self.env_setup.local_rank == 0 and self.video_count > 0:
            avg_time = self.total_time / self.video_count
            max_time = max(self.time_records) if self.time_records else 0
            
            print("\n=== Generation Statistics ===")
            print(f"Total Videos: {self.video_count}")
            print(f"Total Time: {self.total_time:.2f} seconds")
            print(f"Average Time per Video: {avg_time:.2f} seconds")
            print(f"Max Time per Video: {max_time:.2f} seconds")


if __name__ == "__main__":
    app = VideoGenerationApp()
    app.run()

使用说明

1. 环境准备
  • 模型路径:将--model_id参数指向正确的模型目录(默认路径为示例路径,需根据实际情况修改)。
  • 提示文件:准备提示文件(默认prompt.json),格式为JSON,支持多提示输入:
    [
        {"prompt": "your first prompt", "image_paths": ["image1.jpg"]},
        {"prompt": "your second prompt", "image_paths": ["image2.jpg"]}
    ]
    
2. 关键参数说明
参数名 功能描述
--outdir 输出目录,视频将保存在result/{outdir}下。
--resolution 视频分辨率,支持540P(544x960)和720P(720x1280)。
--num_frames 生成视频的总帧数(长视频需配合--overlap_history参数)。
--prompt 提示文件路径(JSON格式)或直接输入提示词(非JSON时默认使用单提示)。
--seed 随机种子(固定种子可复现结果,--use_usp模式下必须设置)。
--guidance_scale 生成质量控制参数(值越大越贴近提示,建议6.0-8.0)。
--image 初始图像路径(可选,用于图像生成视频)。
--use_usp 启用分布式模式(需多GPU支持,需提前初始化分布式环境)。
3. 运行命令
python script_name.py [参数列表]
  • 示例:生成一个720P、97帧、使用默认提示的视频:
    python video_generation_refactored.py --resolution 720P --num_frames 97
    
    使用默认
  python video_generation_refactored.py  --prompt  prompt.json

思路

1. 模块化设计(单一职责原则)

将复杂功能拆解为独立类,每个类专注于单一职责:

  • ConfigParser:解析命令行参数,统一管理输入配置。
  • EnvironmentSetup:验证参数合法性、设置运行环境(分辨率、保存目录、分布式配置等)。
  • PipelineSetup:初始化模型管道(DiffusionForcingPipeline),配置推理参数(因果注意力、TEA缓存等)。
  • PromptLoader/Enhancer:加载提示文件并按需增强提示词(提升生成效果)。
  • VideoGenerator/Saver:分离视频生成和保存逻辑,解耦核心功能与IO操作。
2. 单例模式(资源优化)
  • PipelineSetup使用单例模式:确保模型管道全局唯一,避免重复加载模型浪费内存,提升效率。
  • 适用场景:模型体积大、初始化耗时,单例模式保证内存中仅存在一个实例。
3. 分布式支持(扩展性)
  • --use_usp参数:支持多GPU分布式推理,通过xfuser库初始化分布式环境,提升大规模生成效率。
  • 约束机制:分布式模式下禁止使用提示增强器(--prompt_enhancer),确保逻辑一致性。
4. 鲁棒性设计(参数验证与异常处理)
  • 强参数校验
    • USP模式强制要求种子(--seed),避免随机初始化导致的分布式不一致。
    • 长视频生成(--num_frames > base_num_frames)强制要求--overlap_history,确保时序连贯性。
    • 分辨率严格限制为540P/720P,避免无效输入。
  • 内存管理:提示增强器使用后手动释放资源(gc.collect()),避免内存泄漏。
5. 流程解耦与协作
  • VideoGenerationApp作为协调者:串联各模块,按“加载配置→初始化环境→生成视频→保存结果”的流程执行。
  • 依赖注入:通过类构造函数传递依赖(如管道对象、配置参数),降低模块间耦合度。
6. 输出与可追溯性
  • 自动生成文件名:包含提示词(前100字)、种子、时间戳,便于区分不同生成任务。
  • 保存目录结构:统一输出到result/{outdir},支持断点续传(exist_ok=True)。

网站公告

今日签到

点亮在社区的每一天
去签到