SkyReels-V2 视频生成
flyfish
扩散强制(DF)模型:专为无限长度视频生成设计,提供1.3B-540P和14B-720P等版本
文本到视频(T2V)模型:专注于从文本提示生成高质量视频
图像到视频(I2V)模型:能够从输入图像生成连贯的视频序列
import os
# 设置TOKENIZERS_PARALLELISM为false,避免分词器并行化可能带来的问题
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import argparse
import gc
import os
import random
import time
import json
import imageio
import torch
from diffusers.utils import load_image
from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.pipelines import PromptEnhancer
from skyreels_v2_infer.pipelines import resizecrop
# 单例模式元类,确保类只有一个实例
class Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]
# 配置解析类,用于解析命令行参数
class ConfigParser:
def __init__(self):
self.parser = argparse.ArgumentParser()
self._add_arguments()
def _add_arguments(self):
# 输出目录
self.parser.add_argument("--outdir", type=str, default="diffusion_forcing")
# 模型ID
self.parser.add_argument(
"--model_id",
type=str,
default="/media/models/Skywork/SkyReels-V2-DF-1___3B-540P/",
)
# 分辨率
self.parser.add_argument(
"--resolution", type=str, default="540P", choices=["540P", "720P"]
)
# 帧数
self.parser.add_argument("--num_frames", type=int, default=97)
# 图像路径
self.parser.add_argument("--image", type=str, default=None)
# AR步骤
self.parser.add_argument("--ar_step", type=int, default=0)
# 是否使用因果注意力
self.parser.add_argument("--causal_attention", action="store_true")
# 因果块大小
self.parser.add_argument("--causal_block_size", type=int, default=1)
# 基础帧数
self.parser.add_argument("--base_num_frames", type=int, default=97)
# 重叠历史
self.parser.add_argument("--overlap_history", type=int, default=None)
# 添加噪声条件
self.parser.add_argument("--addnoise_condition", type=int, default=0)
# 引导比例
self.parser.add_argument("--guidance_scale", type=float, default=6.0)
# 偏移量
self.parser.add_argument("--shift", type=float, default=8.0)
# 推理步骤
self.parser.add_argument("--inference_steps", type=int, default=30) # 30
# 是否使用USP
self.parser.add_argument("--use_usp", action="store_true")
# 是否卸载
self.parser.add_argument("--offload", action="store_true")
# 帧率
self.parser.add_argument("--fps", type=int, default=24)
# 随机种子
self.parser.add_argument("--seed", type=int, default=None)
# 提示文件
self.parser.add_argument("--prompt", type=str, default="prompt.json")
# 是否使用提示增强器
self.parser.add_argument("--prompt_enhancer", action="store_true")
# 是否使用TEA缓存
self.parser.add_argument("--teacache", action="store_true")
# TEA缓存阈值
self.parser.add_argument(
"--teacache_thresh",
type=float,
default=0.2,
help="Higher speedup will cause to worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup",
)
# 是否使用保留步骤
self.parser.add_argument(
"--use_ret_steps",
action="store_true",
help="Using Retention Steps will result in faster generation speed and better generation quality.",
)
def parse(self):
return self.parser.parse_args()
# 环境设置类,用于设置运行环境
class EnvironmentSetup:
def __init__(self, args):
self.args = args
self._validate_seed()
self._set_resolution()
self._validate_num_frames()
self._validate_addnoise_condition()
# 负提示词
self.negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
self._create_save_dir()
self._setup_usp()
def _validate_seed(self):
# 验证种子是否有效,USP模式需要种子
assert (self.args.use_usp and self.args.seed is not None) or (
not self.args.use_usp
), "usp mode need seed"
if self.args.seed is None:
random.seed(time.time())
self.args.seed = int(random.randrange(4294967294))
def _set_resolution(self):
# 根据分辨率参数设置高度和宽度
if self.args.resolution == "540P":
self.height = 544
self.width = 960
elif self.args.resolution == "720P":
self.height = 720
self.width = 1280
else:
raise ValueError(f"Invalid resolution: {self.args.resolution}")
def _validate_num_frames(self):
# 验证帧数是否有效,长视频生成需要指定重叠历史
if self.args.num_frames > self.args.base_num_frames:
assert (
self.args.overlap_history is not None
), 'You are supposed to specify the "overlap_history" to support the long video generation. 17 and 37 are recommanded to set.'
def _validate_addnoise_condition(self):
# 验证添加噪声条件是否有效,值过大可能导致长视频生成不一致
if self.args.addnoise_condition > 60:
print(
f'You have set "addnoise_condition" as {self.args.addnoise_condition}. The value is too large which can cause inconsistency in long video generation. The value is recommanded to set 20.'
)
def _create_save_dir(self):
# 创建保存目录
self.save_dir = os.path.join("result", self.args.outdir)
os.makedirs(self.save_dir, exist_ok=True)
def _setup_usp(self):
self.local_rank = 0
if self.args.use_usp:
# USP模式下不允许使用提示增强器
assert (
not self.args.prompt_enhancer
), "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate enhanced prompt before enabling the `--use_usp` parameter."
from xfuser.core.distributed import (
initialize_model_parallel,
init_distributed_environment,
)
import torch.distributed as dist
# 初始化分布式环境
dist.init_process_group("nccl")
self.local_rank = dist.get_rank()
torch.cuda.set_device(dist.get_rank())
self.device = "cuda"
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size()
)
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=1,
ulysses_degree=dist.get_world_size(),
)
# 管道设置类,用于创建和配置DiffusionForcingPipeline
class PipelineSetup(metaclass=Singleton):
def __init__(self, args):
self.pipe = DiffusionForcingPipeline(
args.model_id,
dit_path=args.model_id,
device=torch.device("cuda"),
weight_dtype=torch.bfloat16,
use_usp=args.use_usp,
offload=args.offload,
)
if args.causal_attention:
# 设置因果注意力
self.pipe.transformer.set_ar_attention(args.causal_block_size)
if args.teacache:
if args.ar_step > 0:
# 计算推理步骤数
num_steps = (
args.inference_steps
+ (((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size - 1)
* args.ar_step
)
print("num_steps:", num_steps)
else:
num_steps = args.inference_steps
# 初始化TEA缓存
self.pipe.transformer.initialize_teacache(
enable_teacache=True,
num_steps=num_steps,
teacache_thresh=args.teacache_thresh,
use_ret_steps=args.use_ret_steps,
ckpt_dir=args.model_id,
)
# 提示加载类,用于加载提示信息
class PromptLoader:
def __init__(self, args):
self.args = args
def load(self):
# 加载提示文件,如果文件存在且为JSON格式,则解析JSON文件,否则返回默认提示
if os.path.exists(self.args.prompt) and self.args.prompt.endswith(".json"):
with open(self.args.prompt, "r", encoding="utf-8") as f:
return json.load(f)
return [{"prompt": self.args.prompt}]
# 提示增强包装类,用于增强提示信息
class PromptEnhancerWrapper:
def __init__(self, args):
self.args = args
def enhance(self, prompt_input, image):
if self.args.prompt_enhancer and image is None:
print(f"init prompt enhancer")
prompt_enhancer = PromptEnhancer()
# 增强提示信息
prompt_input = prompt_enhancer(prompt_input)
print(f"enhanced prompt: {prompt_input}")
del prompt_enhancer
gc.collect()
torch.cuda.empty_cache()
return prompt_input
# 视频生成类,用于生成视频帧
class VideoGenerator:
def __init__(self, pipe):
self.pipe = pipe
def generate(self, prompt_input, negative_prompt, image, height, width, num_frames,
num_inference_steps, shift, guidance_scale, seed, overlap_history,
addnoise_condition, base_num_frames, ar_step, causal_block_size, fps):
with torch.cuda.amp.autocast(dtype=self.pipe.transformer.dtype), torch.no_grad():
# 生成视频帧
return self.pipe(
prompt=prompt_input,
negative_prompt=negative_prompt,
image=image,
height=height,
width=width,
num_frames=num_frames,
num_inference_steps=num_inference_steps,
shift=shift,
guidance_scale=guidance_scale,
generator=torch.Generator(device="cuda").manual_seed(seed),
overlap_history=overlap_history,
addnoise_condition=addnoise_condition,
base_num_frames=base_num_frames,
ar_step=ar_step,
causal_block_size=causal_block_size,
fps=fps,
)[0]
# 视频保存类,用于保存生成的视频
class VideoSaver:
def save(self, video_frames, save_dir, prompt_input, seed, fps):
# 生成视频文件名
current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
video_out_file = f"{prompt_input[:100].replace('/','')}_{seed}_{current_time}.mp4"
output_path = os.path.join(save_dir, video_out_file)
# 保存视频
imageio.mimwrite(
output_path,
video_frames,
fps=fps,
quality=8,
output_params=["-loglevel", "error"],
)
# 视频生成应用类,作为应用程序的入口点,协调各个类的工作
class VideoGenerationApp:
def __init__(self):
self.config_parser = ConfigParser()
self.args = self.config_parser.parse()
self.env_setup = EnvironmentSetup(self.args)
self.pipeline_setup = PipelineSetup(self.args)
self.prompt_loader = PromptLoader(self.args)
self.prompt_enhancer = PromptEnhancerWrapper(self.args)
self.video_generator = VideoGenerator(self.pipeline_setup.pipe)
self.video_saver = VideoSaver()
# 新增统计相关属性
self.video_count = 0
self.current_video = 0
self.time_records = []
self.total_time = 0.0
def run(self):
prompts = self.prompt_loader.load()
self.video_count = len(prompts)
self.current_video = 0
self.time_records.clear()
self.total_time = 0.0
for prompt_info in prompts:
self.current_video += 1
start_time = time.perf_counter() # 记录开始时间
prompt_input = prompt_info["prompt"]
image = None
if "image_paths" in prompt_info:
image_path = prompt_info["image_paths"][0]
image = load_image(image_path)
image_width, image_height = image.size
if image_height > image_width:
self.env_setup.height, self.env_setup.width = self.env_setup.width, self.env_setup.height
image = resizecrop(image, self.env_setup.height, self.env_setup.width)
image = image.convert("RGB")
prompt_input = self.prompt_enhancer.enhance(prompt_input, image)
print(f"\n=== Video {self.current_video}/{self.video_count} ===")
print(f"Prompt: {prompt_input[:100]}...")
print(f"Guidance Scale: {self.env_setup.args.guidance_scale}")
video_frames = self.video_generator.generate(
prompt_input,
self.env_setup.negative_prompt,
image,
self.env_setup.height,
self.env_setup.width,
self.env_setup.args.num_frames,
self.env_setup.args.inference_steps,
self.env_setup.args.shift,
self.env_setup.args.guidance_scale,
self.env_setup.args.seed,
self.env_setup.args.overlap_history,
self.env_setup.args.addnoise_condition,
self.env_setup.args.base_num_frames,
self.env_setup.args.ar_step,
self.env_setup.args.causal_block_size,
self.env_setup.args.fps,
)
# 计算本次推理时间
duration = time.perf_counter() - start_time
self.time_records.append(duration)
self.total_time += duration
print(f"Generation completed in {duration:.2f} seconds")
if self.env_setup.local_rank == 0:
self.video_saver.save(video_frames, self.env_setup.save_dir, prompt_input, self.env_setup.args.seed,
self.env_setup.args.fps)
print(f"Video saved to {self.env_setup.save_dir}")
# 新增统计结果汇总
if self.env_setup.local_rank == 0 and self.video_count > 0:
avg_time = self.total_time / self.video_count
max_time = max(self.time_records) if self.time_records else 0
print("\n=== Generation Statistics ===")
print(f"Total Videos: {self.video_count}")
print(f"Total Time: {self.total_time:.2f} seconds")
print(f"Average Time per Video: {avg_time:.2f} seconds")
print(f"Max Time per Video: {max_time:.2f} seconds")
if __name__ == "__main__":
app = VideoGenerationApp()
app.run()
使用说明
1. 环境准备
- 模型路径:将
--model_id
参数指向正确的模型目录(默认路径为示例路径,需根据实际情况修改)。 - 提示文件:准备提示文件(默认
prompt.json
),格式为JSON,支持多提示输入:[ {"prompt": "your first prompt", "image_paths": ["image1.jpg"]}, {"prompt": "your second prompt", "image_paths": ["image2.jpg"]} ]
2. 关键参数说明
参数名 | 功能描述 |
---|---|
--outdir |
输出目录,视频将保存在result/{outdir} 下。 |
--resolution |
视频分辨率,支持540P (544x960)和720P (720x1280)。 |
--num_frames |
生成视频的总帧数(长视频需配合--overlap_history 参数)。 |
--prompt |
提示文件路径(JSON格式)或直接输入提示词(非JSON时默认使用单提示)。 |
--seed |
随机种子(固定种子可复现结果,--use_usp 模式下必须设置)。 |
--guidance_scale |
生成质量控制参数(值越大越贴近提示,建议6.0-8.0)。 |
--image |
初始图像路径(可选,用于图像生成视频)。 |
--use_usp |
启用分布式模式(需多GPU支持,需提前初始化分布式环境)。 |
3. 运行命令
python script_name.py [参数列表]
- 示例:生成一个720P、97帧、使用默认提示的视频:
使用默认python video_generation_refactored.py --resolution 720P --num_frames 97
python video_generation_refactored.py --prompt prompt.json
思路
1. 模块化设计(单一职责原则)
将复杂功能拆解为独立类,每个类专注于单一职责:
ConfigParser
:解析命令行参数,统一管理输入配置。EnvironmentSetup
:验证参数合法性、设置运行环境(分辨率、保存目录、分布式配置等)。PipelineSetup
:初始化模型管道(DiffusionForcingPipeline
),配置推理参数(因果注意力、TEA缓存等)。PromptLoader/Enhancer
:加载提示文件并按需增强提示词(提升生成效果)。VideoGenerator/Saver
:分离视频生成和保存逻辑,解耦核心功能与IO操作。
2. 单例模式(资源优化)
PipelineSetup
使用单例模式:确保模型管道全局唯一,避免重复加载模型浪费内存,提升效率。- 适用场景:模型体积大、初始化耗时,单例模式保证内存中仅存在一个实例。
3. 分布式支持(扩展性)
--use_usp
参数:支持多GPU分布式推理,通过xfuser
库初始化分布式环境,提升大规模生成效率。- 约束机制:分布式模式下禁止使用提示增强器(
--prompt_enhancer
),确保逻辑一致性。
4. 鲁棒性设计(参数验证与异常处理)
- 强参数校验:
- USP模式强制要求种子(
--seed
),避免随机初始化导致的分布式不一致。 - 长视频生成(
--num_frames > base_num_frames
)强制要求--overlap_history
,确保时序连贯性。 - 分辨率严格限制为
540P
/720P
,避免无效输入。
- USP模式强制要求种子(
- 内存管理:提示增强器使用后手动释放资源(
gc.collect()
),避免内存泄漏。
5. 流程解耦与协作
VideoGenerationApp
作为协调者:串联各模块,按“加载配置→初始化环境→生成视频→保存结果”的流程执行。- 依赖注入:通过类构造函数传递依赖(如管道对象、配置参数),降低模块间耦合度。
6. 输出与可追溯性
- 自动生成文件名:包含提示词(前100字)、种子、时间戳,便于区分不同生成任务。
- 保存目录结构:统一输出到
result/{outdir}
,支持断点续传(exist_ok=True
)。