scheduler_class = globals()[scheduler_class_name]
Python 中一种 动态获取类对象 的常用技巧,属于 反射(reflection) 编程的范畴
globals() |
Python 内置函数,返回一个 字典(dict),包含当前模块(文件)中所有全局变量、函数、类、导入的模块等的名称和对象。 |
[scheduler_class_name] |
从 globals() 返回的字典中,以字符串 scheduler_class_name 为键,查找对应的值。 |
scheduler_class_name |
一个字符串变量,比如 "DDPMScheduler" 或 "DPMSolverMultistepScheduler" 。 |
用法示例
from diffusers import DDPMScheduler, DDIMScheduler, DPMSolverMultistepScheduler
# 假设这是从配置文件读取的类名
scheduler_class_name = "DPMSolverMultistepScheduler"
# 动态获取类
scheduler_class = globals()[scheduler_class_name]
# 现在 scheduler_class 就是 DPMSolverMultistepScheduler 这个类
print(scheduler_class) # <class 'diffusers.DPMSolverMultistepScheduler'>
# 可以用来创建实例
scheduler = scheduler_class.from_config(config)
如果不用global():
if scheduler_class_name == "DDPMScheduler":
scheduler_class = DDPMScheduler
elif scheduler_class_name == "DDIMScheduler":
scheduler_class = DDIMScheduler
elif scheduler_class_name == "DPMSolverMultistepScheduler":
scheduler_class = DPMSolverMultistepScheduler
else:
raise ValueError(f"Unknown scheduler: {scheduler_class_name}")
这样写冗长、难维护,每新增一个调度器就要改代码