文章目录
torch.hub
PyTorch Hub 是一个预训练模型仓库,旨在促进研究可复现性。
发布模型
PyTorch Hub 支持通过添加简单的 hubconf.py
文件,将预训练模型(模型定义和预训练权重)发布到 GitHub 仓库。
hubconf.py
可以包含多个入口点。每个入口点都定义为 Python 函数(例如:您想发布的预训练模型)。
def entrypoint_name(args, *kwargs):
# args & kwargs are optional, for models which take positional/keyword arguments.
...
如何实现入口点?
以下代码片段展示了如果我们扩展 pytorch/vision/hubconf.py
中的实现,如何为 resnet18
模型指定入口点。在大多数情况下,只需在 hubconf.py
中导入正确的函数就足够了。这里我们使用扩展版本作为示例来说明其工作原理。
完整脚本可查看 pytorch/vision 代码库
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18
# resnet18 is the name of entrypoint
def resnet18(pretrained=False, *kwargs):
""" # This docstring shows up in hub.help()
Resnet18 model
pretrained (bool): kwargs, load pretrained weights into the model
"""
# Call the model, load pretrained weights
model = _resnet18(pretrained=pretrained, *kwargs)
return model
dependencies
变量是一个列表,包含加载模型所需的包名。注意这里可能与训练模型所需的依赖项略有不同。args
和kwargs
会传递给实际的可调用函数。- 函数的文档字符串(docstring)将作为帮助信息。它需要说明模型的功能以及允许的位置参数/关键字参数。强烈建议在此处添加几个示例。
- 入口函数可以返回一个模型(nn.Module),也可以返回辅助工具(如分词器)来优化用户工作流程。
- 以下划线开头的可调用对象被视为辅助函数,不会出现在
torch.hub.list()
的返回结果中。 - 预训练权重可以存储在GitHub仓库本地,也可以通过
torch.hub.load_state_dict_from_url()
加载。如果小于2GB,建议将其附加到项目发布版并使用发布版的URL。
在上面的示例中,torchvision.models.resnet.resnet18
处理了 pretrained
参数,你也可以将以下逻辑放在入口函数定义中。
# For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
dirname = os.path.dirname(__file__)
checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
state_dict = torch.load(checkpoint)
model.load_state_dict(state_dict)
# For checkpoint saved elsewhere
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
重要通知
- 发布的模型至少应位于分支/标签中,不能是随机提交。
从Hub加载模型
PyTorch Hub 提供了一系列便捷的API,帮助开发者探索Hub中所有可用模型:
- 通过
torch.hub.list()
查看所有模型 - 使用
torch.hub.help()
显示文档说明和示例 - 调用
torch.hub.load()
加载预训练模型
torch.hub.list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True)
列出由 github
指定的代码仓库中所有可调用的入口点。
参数
github (str)
– 格式为“repo_owner/repo_name[:ref]”的字符串,其中 ref(标签或分支)为可选。如果未指定ref
,则默认分支为main
(如果存在),否则为master
。
示例:‘pytorch/vision:0.10’force_reload ([bool], 可选)
– 是否丢弃现有缓存并强制重新下载。默认为False
。skip_validation ([bool], 可选)
– 如果为False
,torchhub 会检查github
参数指定的分支或提交是否确实属于该仓库所有者。此操作会向 GitHub API 发起请求;可通过设置GITHUB_TOKEN
环境变量指定非默认的 GitHub 令牌。默认为False
。trust_repo ([bool],* str 或 *None)
–
"check"
、True
、False
或None
。
此参数在 v1.12 版本引入,用于确保用户仅运行信任的仓库代码。
- 如果为
False
,会提示用户确认是否信任该仓库。 - 如果为
True
,仓库将被添加到信任列表并直接加载,无需明确确认。 - 如果为
"check"
,会检查该仓库是否在缓存的信任列表中。若不存在,则回退到trust_repo=False
的行为。 - 如果为
None
:会发出警告,提示用户将trust_repo
设为False
、True
或"check"
。此选项仅为向后兼容保留,将在 v2.0 版本移除。默认为None
,未来 v2.0 版本将改为默认"check"
。
verbose ([bool], 可选)
– 如果为False
,则屏蔽关于命中本地缓存的消息。注意首次下载的消息无法屏蔽。默认为True
。
返回
可用的可调用入口点列表
返回类型
list
示例
>>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True)
torch.hub.help(github, model, force_reload=False, skip_validation=False, trust_repo=None)
显示入口点 model
的文档字符串。
参数
github (str)
– 格式为 <repo_owner/repo_name[:ref]> 的字符串,其中 ref(标签或分支)是可选的。如果未指定ref
,则默认分支为main
(如果存在),否则为master
。
示例:‘pytorch/vision:0.10’
model (str)
– 仓库hubconf.py
中定义的入口点名称字符串force_reload ([bool], 可选)
– 是否丢弃现有缓存并强制重新下载。默认为False
。skip_validation ([bool], 可选)
– 如果为False
,torchhub 将检查github
参数指定的 ref 是否确实属于该仓库所有者。这将向 GitHub API 发出请求;您可以通过设置GITHUB_TOKEN
环境变量来指定非默认的 GitHub 令牌。默认为False
。trust_repo ([bool], *str 或 *None)
–
"check"
、True
、False
或 None
。
此参数在 v1.12 版本引入,用于确保用户仅运行来自受信任仓库的代码。
如果为
False
,将提示用户确认是否信任该仓库。如果为
True
,该仓库将被添加到受信任列表并直接加载,无需明确确认。如果为
"check"
,将检查该仓库是否在缓存的受信任仓库列表中。如果不在列表中,行为将回退到trust_repo=False
选项。如果为
None
:将发出警告,提示用户将trust_repo
设置为False
、True
或"check"
。此选项仅用于向后兼容,将在 v2.0 版本移除。默认为None
,最终将在 v2.0 版本更改为"check"
。
示例
>>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True))
torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)
从 GitHub 仓库或本地目录加载模型。
注意:加载模型是典型用例,但此功能也可用于加载其他对象,如分词器、损失函数等。
如果 source
为 ‘github’,则 repo_or_dir
应为 repo_owner/repo_name[:ref]
格式,其中 ref(标签或分支)为可选项。
如果 source
为 ‘local’,则 repo_or_dir
应为本地目录路径。
参数
repo_or_dir (str)
– 如果source
为 ‘github’,则应为 GitHub 仓库,格式为repo_owner/repo_name[:ref]
(ref 为可选的标签或分支),例如 ‘pytorch/vision:0.10’。如果未指定ref
,则默认分支为main
(如果存在),否则为master
。
如果 source
为 ‘local’,则应为本地目录路径。
model (str)
– 仓库/目录中hubconf.py
文件定义的可调用对象(入口点)名称。*args (可选)
– 可调用对象model
的对应参数。source (str, 可选)
– ‘github’ 或 ‘local’。指定如何解释repo_or_dir
。默认为 ‘github’。trust_repo ([bool],* str 或 *None)
–
"check"
、True
、False
或 None
。
此参数在 v1.12 中引入,用于确保用户仅运行信任仓库中的代码。
如果为
False
,将提示用户确认是否信任该仓库。如果为
True
,该仓库将被添加到信任列表并直接加载,无需明确确认。如果为
"check"
,将检查该仓库是否在缓存信任列表中。如果不在,则回退到trust_repo=False
的行为。如果为
None
:将发出警告,提示用户将trust_repo
设置为False
、True
或"check"
。此选项仅用于向后兼容,将在 v2.0 中移除。默认为None
,最终将在 v2.0 中改为"check"
。
force_reload ([bool], 可选)
– 是否无条件强制重新下载 GitHub 仓库。如果source = 'local'
则无效。默认为False
。verbose ([bool], 可选)
– 如果为False
,则屏蔽有关命中本地缓存的消息。注意首次下载的消息无法屏蔽。如果source = 'local'
则无效。默认为True
。skip_validation ([bool], 可选)
– 如果为False
,torchhub 将检查github
参数指定的分支或提交是否属于该仓库所有者。这将向 GitHub API 发出请求;您可通过设置GITHUB_TOKEN
环境变量指定非默认的 GitHub 令牌。默认为False
。**kwargs (可选)
– 可调用对象model
的对应关键字参数。
返回
调用 model
可调用对象时,传入给定 *args
和 **kwargs
的输出。
示例
>>> # from a github repo
>>> repo = "pytorch/vision"
>>> model = torch.hub.load(
... repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1"
... )
>>> # from a local directory
>>> path = "/some/local/path/pytorch/vision"
>>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT")
torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)
将给定URL的对象下载到本地路径。
参数
url (str)
- 要下载对象的URLdst (str)
- 对象将被保存的完整路径,例如/tmp/temporary_file
hash_prefix (str, 可选)
- 如果不为None,下载文件的SHA256哈希值应以hash_prefix
开头。默认值:Noneprogress ([bool], 可选)
- 是否在标准错误输出中显示进度条。默认值:True
示例
>>> torch.hub.download_url_to_file(
... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth",
... "/tmp/temporary_file",
... )
torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)
从给定的URL加载Torch序列化对象。
如果下载的文件是zip压缩包,系统会自动解压。
如果对象已存在于model_dir目录中,则直接反序列化并返回。
model_dir
的默认值为<hub_dir>/checkpoints
,其中hub_dir
是由get_dir()
返回的目录路径。
参数说明
url (str)
- 需要下载对象的URL地址model_dir (str, 可选)
- 保存对象的目录路径map_location (可选)
- 指定存储位置重映射的函数或字典(参见torch.load)progress ([bool], 可选)
- 是否在标准错误输出中显示进度条。默认值:Truecheck_hash ([bool], 可选)
- 若为True,则URL中的文件名部分需遵循命名规范:
filename-<sha256>.ext
,其中<sha256>
是文件内容SHA256哈希值的前8位或更多位数字。该哈希值用于确保唯一文件名并验证文件内容。默认值:Falsefile_name (str, 可选)
- 下载文件的名称。若未设置,则使用URL中的文件名weights_only ([bool], 可选)
- 若为True,则仅加载权重而不加载复杂的pickle对象。建议用于不可信来源。详见load()
说明
返回类型:dict[str, Any]
使用示例:
>>> state_dict = torch.hub.load_state_dict_from_url(
... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth"
... )
运行加载的模型:
请注意,torch.hub.load()
中的 *args
和 **kwargs
用于实例化模型。加载模型后,如何了解该模型的功能?
建议的工作流程如下:
- 使用
dir(model)
查看模型所有可用的方法 - 通过
help(model.foo)
查看model.foo
运行所需的参数
为了帮助用户无需反复查阅文档即可探索功能,我们强烈建议仓库维护者确保函数帮助信息清晰简洁。同时,提供一个最小可运行示例也非常有帮助。
下载的模型保存在哪里?
模型保存路径按以下顺序确定:
1、调用 hub.set_dir(<PATH_TO_HUB_DIR>)
设置的路径
2、若设置了环境变量 TORCH_HOME
,则使用 $TORCH_HOME/hub
3、若设置了环境变量 XDG_CACHE_HOME
,则使用 $XDG_CACHE_HOME/torch/hub
4、默认路径为 ~/.cache/torch/hub
可通过 torch.hub.get_dir()
获取当前保存路径
***
Get the Torch Hub cache directory used for storing downloaded models \& weights.
If [`set_dir()`](https://pytorch.org/docs/stable/data.html#torch.hub.set_dir "torch.hub.set_dir") is not called, default path is `$TORCH_HOME/hub` where
environment variable `$TORCH_HOME` defaults to `$XDG_CACHE_HOME/torch`.
`$XDG_CACHE_HOME` follows the X Design Group specification of the Linux
filesystem layout, with a default value `~/.cache` if the environment
variable is not set.
Return type
str
torch.hub.set_dir(d)
可选设置用于保存下载模型和权重的 Torch Hub 目录。
参数
d (str)
– 用于保存下载模型和权重的本地文件夹路径。
缓存逻辑
默认情况下,我们在加载文件后不会进行清理。如果文件已存在于 get_dir()
返回的目录中,Hub 会默认使用缓存。
用户可以通过调用 hub.load(..., force_reload=True)
强制重新加载。这将删除现有的 GitHub 文件夹和下载的权重文件,并重新初始化全新下载。当同一分支发布更新时,这个功能非常有用,用户可以及时获取最新版本。
已知限制:
Torch hub 的工作原理是将包当作已安装的包进行导入。Python 导入机制会带来一些副作用,例如你可能会在 Python 缓存 sys.modules
和 sys.path_importer_cache
中看到新增条目,这是 Python 的正常行为。这也意味着,如果不同代码仓库包含同名的子包(通常是名为 model
的子包),在从不同仓库导入不同模型时可能会遇到导入错误。针对这类导入错误的解决方案是从 sys.modules
字典中移除冲突的子包,更多细节可参考 这个 GitHub issue。
需要特别说明的一个已知限制:用户无法在同一个 Python 进程中加载同一代码仓库的两个不同分支。这就像在 Python 中安装两个同名的包一样,是不合理的做法。如果强行尝试,缓存机制可能会介入并带来意外结果。当然,在独立的进程中分别加载它们是完全可行的。
TorchScript
TorchScript 是一种将 PyTorch 代码转换为可序列化和可优化模型的方法。任何 TorchScript 程序都可以从 Python 进程中保存,并在不依赖 Python 的环境中加载运行。
我们提供了一系列工具,帮助开发者逐步将纯 Python 模型转换为可独立于 Python 运行的 TorchScript 程序,例如在独立的 C++ 程序中运行。这使得开发者能够继续使用熟悉的 Python 工具训练 PyTorch 模型,然后通过 TorchScript 将模型导出到生产环境。在生产环境中,由于性能和多线程方面的考虑,使用 Python 程序可能并不合适。
如需了解 TorchScript 的入门指南,请参阅 TorchScript 简介教程。
若想查看将 PyTorch 模型转换为 TorchScript 并在 C++ 中运行的完整示例,请参考 在 C++ 中加载 PyTorch 模型 教程。
创建 TorchScript 代码
script |
将函数转换为脚本 |
---|---|
trace |
追踪函数并返回一个可执行对象或ScriptFunction ,该对象将通过即时编译进行优化 |
script_if_tracing |
在追踪过程中首次调用时编译fn |
trace_module |
追踪模块并返回一个可执行的ScriptModule ,该模块将通过即时编译进行优化 |
fork |
创建异步任务执行函数,并返回对该执行结果的引用 |
wait |
强制完成torch.jit.Future[T]异步任务,返回任务结果 |
ScriptModule |
C++ torch::jit::Module的包装器,包含方法、属性和参数 |
ScriptFunction |
功能上等同于ScriptModule ,但表示单个函数且不包含任何属性或参数 |
freeze |
冻结ScriptModule,将子模块和属性内联为常量 |
optimize_for_inference |
执行一系列优化步骤,为推理目的优化模型 |
enable_onednn_fusion |
根据参数enabled启用或禁用onednn JIT融合 |
onednn_fusion_enabled |
返回onednn JIT融合是否启用 |
set_fusion_strategy |
设置融合过程中可能发生的特化类型和数量 |
strict_fusion |
如果推理中未融合所有节点或训练中未符号微分,则报错 |
save |
保存此模块的离线版本以供其他进程使用 |
load |
加载先前用torch.jit.save 保存的ScriptModule 或ScriptFunction |
ignore |
此装饰器向编译器表明应忽略函数或方法,保留为Python函数 |
unused |
此装饰器向编译器表明应忽略函数或方法,并替换为抛出异常 |
interface |
用于注解不同类型类或模块的装饰器 |
isinstance |
在TorchScript中提供容器类型细化 |
Attribute |
此方法是一个返回值的直通函数,主要用于向TorchScript编译器表明左侧表达式是具有type类型的类实例属性 |
annotate |
用于在TorchScript编译器中指定the_value的类型 |
混合使用追踪与脚本化
在多数情况下,追踪(tracing)或脚本化(scripting)都是将模型转换为 TorchScript 的更简便方式。根据模型不同部分的具体需求,可以组合使用这两种方法。
脚本化函数能够调用追踪生成的函数。当需要在简单前馈模型周围添加控制流逻辑时,这种方式特别有用。例如,序列到序列模型中的束搜索(beam search)通常会用脚本编写,但可以调用通过追踪生成的编码器模块。
示例(在脚本中调用追踪函数):
import torch
def foo(x, y):
return 2 * x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
@torch.jit.script
def bar(x):
return traced_foo(x, x)
被追踪的函数可以调用脚本函数。当模型的大部分只是前馈网络,而其中一小部分需要控制流时,这非常有用。在被追踪函数调用的脚本函数内部,控制流会被正确保留。
示例(在被追踪函数中调用脚本函数):
import torch
@torch.jit.script
def foo(x, y):
if x.max() y.max():
r = x
else:
r = y
return r
def bar(x, y, z):
return foo(x, y) + z
traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))
该组合方式同样适用于 nn.Module
,它可以通过追踪生成一个子模块,该子模块可从脚本模块的方法中调用。
示例(使用追踪模块):
import torch
import torchvision
class MyScriptModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
.resize_(1, 3, 1, 1))
self.resnet = torch.jit.trace(torchvision.models.resnet18(), torch.rand(1, 3, 224, 224))
def forward(self, input):
return self.resnet(input - self.means)
my_script_module = torch.jit.script(MyScriptModule())
TorchScript 语言
TorchScript 是 Python 的一个静态类型子集,因此许多 Python 特性可以直接应用于 TorchScript。详情请参阅完整的 TorchScript 语言参考。
内置函数与模块
TorchScript 支持使用大多数 PyTorch 函数和许多 Python 内置功能。完整支持函数列表请参阅 TorchScript 内置函数。
PyTorch 函数与模块
TorchScript 支持 PyTorch 提供的张量和神经网络函数子集。Tensor 上的大多数方法、torch
命名空间中的所有函数、torch.nn.functional
中的全部函数以及 torch.nn
中的大多数模块均可被 TorchScript 支持。
不支持的 PyTorch 函数和模块列表请参阅 TorchScript 不支持的 PyTorch 结构。
Python 函数与模块
TorchScript 支持许多 Python 的内置函数。
math
模块同样受支持(详见数学模块),但其他 Python 模块(无论是内置还是第三方)均不支持。
Python 语言参考对比
如需查看支持的 Python 功能完整列表,请参阅 Python 语言参考覆盖范围。
调试
禁用 JIT 进行调试
PYTORCH_JIT
设置环境变量 PYTORCH_JIT=0
将禁用所有脚本和追踪注解。当您的 TorchScript 模型中出现难以调试的错误时,可以通过此标志强制所有代码以原生 Python 方式运行。由于该标志会禁用 TorchScript(脚本化和追踪),您可以使用诸如 pdb
之类的工具来调试模型代码。例如:
@torch.jit.script
def scripted_fn(x : torch.Tensor):
for i in range(12):
x = x + x
return x
def fn(x):
x = torch.neg(x)
import pdb; pdb.set_trace()
return scripted_fn(x)
traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))
使用 pdb
调试此脚本时一切正常,但当我们调用 @torch.jit.script
函数时会失效。我们可以全局禁用 JIT 功能,这样就能将 @torch.jit.script
作为普通 Python 函数调用而不进行编译。如果上述脚本名为 disable_jit_example.py
,可以通过以下方式调用:
$ PYTORCH_JIT=0 python disable_jit_example.py
这样我们就能像普通 Python 函数一样单步调试 @torch.jit.script
装饰的函数。如需禁用特定函数的 TorchScript 编译器,请参阅 @torch.jit.ignore
。
代码检查
TorchScript 为所有 ScriptModule
实例提供了代码美化打印器。该美化打印器能够将脚本方法的代码以有效的 Python 语法形式呈现。例如:
@torch.jit.script
def foo(len):
# type: (int) -torch.Tensor
rv = torch.zeros(3, 4)
for i in range(len):
if i < 10:
rv = rv - 1.0
else:
rv = rv + 1.0
return rv
print(foo.code)
一个包含单个 forward
方法的 ScriptModule
会有一个 code
属性,你可以通过该属性来检查 ScriptModule
的代码。
如果 ScriptModule
包含多个方法,你需要访问方法本身的 .code
属性,而不是模块的。我们可以通过访问 .foo.code
来检查 ScriptModule
上名为 foo
的方法代码。
上面的示例会产生以下输出:
def foo(len: int) -Tensor:
rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
rv0 = rv
for i in range(len):
if torch.lt(i, 10):
rv1 = torch.sub(rv0, 1., 1)
else:
rv1 = torch.add(rv0, 1., 1)
rv0 = rv1
return rv0
这是 TorchScript 对 forward
方法代码的编译结果。
您可以通过它来验证 TorchScript(无论是通过追踪还是脚本化方式)是否正确捕获了您的模型代码。
解读图结构
TorchScript 在代码美化打印器之下还有一个更低层次的表示形式,即 IR(中间表示)图。
TorchScript 采用静态单赋值(SSA)中间表示(IR)来描述计算过程。这种格式的指令由 ATen(PyTorch 的 C++ 后端)运算符和其他基础运算符组成,包括用于循环和条件判断的控制流运算符。例如:
@torch.jit.script
def foo(len):
# type: (int) -torch.Tensor
rv = torch.zeros(3, 4)
for i in range(len):
if i < 10:
rv = rv - 1.0
else:
rv = rv + 1.0
return rv
print(foo.graph)
graph
遵循与代码检查章节中描述的相同规则,涉及 forward
方法查找。
上面的示例脚本生成如下图表:
graph(%len.1 : int):
%24 : int = prim::Constant[value=1]()
%17 : bool = prim::Constant[value=1]() # test.py:10:5
%12 : bool? = prim::Constant()
%10 : Device? = prim::Constant()
%6 : int? = prim::Constant()
%1 : int = prim::Constant[value=3]() # test.py:9:22
%2 : int = prim::Constant[value=4]() # test.py:9:25
%20 : int = prim::Constant[value=10]() # test.py:11:16
%23 : float = prim::Constant[value=1]() # test.py:12:23
%4 : int[] = prim::ListConstruct(%1, %2)
%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
%rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
block0(%i.1 : int, %rv.14 : Tensor):
%21 : bool = aten::lt(%i.1, %20) # test.py:11:12
%rv.13 : Tensor = prim::If(%21) # test.py:11:9
block0():
%rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
-(%rv.3)
block1():
%rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
-(%rv.6)
-(%17, %rv.13)
return (%rv)
以指令%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
为例:
%rv.1 : Tensor
表示我们将输出赋值给一个名为rv.1
的(唯一)值,该值的类型为Tensor
,且我们不知道其具体形状。aten::zeros
是运算符(相当于torch.zeros
),输入列表(%4, %6, %6, %10, %12)
指定了应传入哪些作用域中的值作为输入。像aten::zeros
这样的内置函数的模式可以在内置函数中找到。# test.py:9:10
是生成此指令的原始源文件中的位置。在本例中,它位于名为test.py的文件中,第9行,第10个字符。
注意,运算符也可以关联blocks
,即prim::Loop
和prim::If
运算符。在图形打印输出中,这些运算符的格式会反映其等效的源代码形式,以便于调试。
可以按照所示方式检查图形,以确认由ScriptModule
描述的计算是否正确,无论是自动还是手动方式,如下所述。
追踪器
追踪边界情况
在某些特殊情况下,对给定Python函数/模块的追踪可能无法准确反映底层代码的真实行为。这些情况包括:
- 依赖于输入的控制流追踪(例如张量形状)
- 张量视图就地操作的追踪(例如赋值语句左侧的索引操作)
请注意,这些情况未来实际上可能会变得可追踪。
自动追踪检查
自动捕获追踪中多种错误的一种方法是使用 torch.jit.trace()
API 上的 check_inputs
参数。check_inputs
接收一个由输入元组组成的列表,这些输入将用于重新追踪计算并验证结果。例如:
def loop_in_traced_fn(x):
result = x[0]
for i in range(x.size(0)):
result = result * x[i]
return result
inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]
traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)
提供以下诊断信息:
ERROR: Graphs differed across invocations!
Graph diff:
graph(%x : Tensor) {
%1 : int = prim::Constant[value=0]()
%2 : int = prim::Constant[value=0]()
%result.1 : Tensor = aten::select(%x, %1, %2)
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=0]()
%6 : Tensor = aten::select(%x, %4, %5)
%result.2 : Tensor = aten::mul(%result.1, %6)
%8 : int = prim::Constant[value=0]()
%9 : int = prim::Constant[value=1]()
%10 : Tensor = aten::select(%x, %8, %9)
- %result : Tensor = aten::mul(%result.2, %10)
+ %result.3 : Tensor = aten::mul(%result.2, %10)
? ++
%12 : int = prim::Constant[value=0]()
%13 : int = prim::Constant[value=2]()
%14 : Tensor = aten::select(%x, %12, %13)
+ %result : Tensor = aten::mul(%result.3, %14)
+ %16 : int = prim::Constant[value=0]()
+ %17 : int = prim::Constant[value=3]()
+ %18 : Tensor = aten::select(%x, %16, %17)
- %15 : Tensor = aten::mul(%result, %14)
? ^ ^
+ %19 : Tensor = aten::mul(%result, %18)
? ^ ^
- return (%15);
? ^
+ return (%19);
? ^
}
这条消息表明,计算过程在我们首次追踪时和使用 check_inputs
进行追踪时出现了差异。实际上,loop_in_traced_fn
函数体中的循环依赖于输入 x
的形状,因此当我们尝试使用不同形状的另一个 x
时,追踪结果就会发生变化。
对于这种情况,可以使用 torch.jit.script()
来捕获此类数据依赖的控制流:
def fn(x):
result = x[0]
for i in range(x.size(0)):
result = result * x[i]
return result
inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]
scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())
for input_tuple in [inputs] + check_inputs:
torch.testing.assert_close(fn(input_tuple), scripted_fn(input_tuple))
输出结果为:
graph(%x : Tensor) {
%5 : bool = prim::Constant[value=1]()
%1 : int = prim::Constant[value=0]()
%result.1 : Tensor = aten::select(%x, %1, %1)
%4 : int = aten::size(%x, %1)
%result : Tensor = prim::Loop(%4, %5, %result.1)
block0(%i : int, %7 : Tensor) {
%10 : Tensor = aten::select(%x, %1, %i)
%result.2 : Tensor = aten::mul(%7, %10)
-(%5, %result.2)
}
return (%result);
}
追踪器警告
追踪器会对追踪计算中的几种问题模式产生警告。例如,假设对一个包含张量切片(视图)进行原地赋值的函数进行追踪:
def fill_row_zero(x):
x[0] = torch.rand(x.shape[1:2])
return x
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
生成多个警告信息和一个直接返回输入数据的图表
fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
x[0] = torch.rand(x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1、of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
return (%0);
}
我们可以通过修改代码来解决这个问题,不再使用原地更新,而是使用torch.cat
来非原地构建结果张量:
def fill_row_zero(x):
x = torch.cat((torch.rand(1, x.shape[1:2]), x[1:2]), dim=0)
return x
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
常见问题解答
Q: 我想在GPU上训练模型,然后在CPU上进行推理。有哪些最佳实践?
首先将模型从GPU转换到CPU,然后保存它,如下所示:
cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
torch.jit.save(traced_cpu, "cpu.pt")
traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pt")
# ... later, when using the model:
if use_gpu:
model = torch.jit.load("gpu.pt")
else:
model = torch.jit.load("cpu.pt")
model(input)
推荐采用此方式,因为追踪器可能会观测到张量在特定设备上创建的过程,直接转换已加载的模型可能产生意外效果。在保存模型之前进行类型转换,可确保追踪器获取正确的设备信息。
问:如何在ScriptModule
上存储属性?
假设我们有如下模型:
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.x = 2
def forward(self):
return self.x
m = torch.jit.script(Model())
如果直接实例化 Model
会导致编译错误,因为编译器无法识别 x
。有四种方法可以让编译器识别 ScriptModule
上的属性:
1、nn.Parameter
- 用 nn.Parameter
包装的值会像在 nn.Module
中一样正常工作。
2、register_buffer
- 用 register_buffer
包装的值会像在 nn.Module
中一样正常工作。这相当于一个类型为 Tensor
的属性(见第4点)。
3、常量 - 将类成员标注为 Final
(或在类定义级别将其添加到名为 __constants__
的列表中)会将包含的名称标记为常量。常量会直接保存在模型的代码中。详情请参阅内置常量。
4、属性 - 支持类型的值可以作为可变属性添加。大多数类型可以自动推断,但有些可能需要明确指定,详情请参阅模块属性。
问题:我想追踪模块的方法,但一直遇到这个错误:
RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
这个错误通常意味着你正在追踪的方法使用了模块的参数,而你传递的是模块的方法而不是模块实例(例如 my_module_instance.forward
与 my_module_instance
)。
使用模块的方法调用
trace
会将模块参数(可能需要梯度)捕获为常量。另一方面,使用模块实例(例如
my_module
)调用trace
会创建一个新模块,并正确地将参数复制到新模块中,因此如果需要,它们可以累积梯度。
要追踪模块上的特定方法,请参阅 torch.jit.trace_module
。
已知问题
当你在 TorchScript 中使用 Sequential
时,某些 Sequential
子模块的输入可能会被错误推断为 Tensor
,即使它们被标注为其他类型。标准解决方案是继承 nn.Sequential
并重新声明 forward
方法,确保输入类型正确。
附录
迁移至 PyTorch 1.2 递归脚本化 API
本节详细说明 PyTorch 1.2 中 TorchScript 的变化。如果你是 TorchScript 的新用户,可以跳过这部分内容。PyTorch 1.2 对 TorchScript API 主要做了两处改动:
1、torch.jit.script
现在会尝试递归编译遇到的函数、方法和类。一旦调用 torch.jit.script
,编译过程将采用"默认启用"而非"手动启用"机制。
2、torch.jit.script(nn_module_instance)
现已成为创建 ScriptModule
的推荐方式,取代了原先继承 torch.jit.ScriptModule
的做法。
这些改动共同提供了一个更简单易用的 API,用于将你的 nn.Module
转换为可优化并在非 Python 环境中执行的 ScriptModule
。
新的使用方式如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
my_model = Model()
my_scripted_model = torch.jit.script(my_model)
- 模块的
forward
方法默认会被编译。从forward
中调用的方法会按它们在forward
中的使用顺序进行惰性编译。 - 若要编译未被
forward
调用的其他方法,需添加@torch.jit.export
装饰器。 - 如需阻止编译器编译某个方法,可添加
@torch.jit.ignore
或@torch.jit.unused
。@ignore
会保留 Python 方法调用,而@unused
会将其替换为异常。@ignored
方法不可导出;@unused
方法可以导出。 - 大多数属性类型可自动推断,因此无需使用
torch.jit.Attribute
。对于空容器类型,建议使用 PEP 526 风格 的类注解来声明类型。 - 常量可通过
Final
类注解标记,无需将成员名加入__constants__
列表。 - 可用 Python 3 类型提示替代
torch.jit.annotate
函数。
基于这些变更,以下内容已被弃用,新代码中不应继续使用:
@torch.jit.script_method
装饰器- 继承自
torch.jit.ScriptModule
的类 torch.jit.Attribute
包装类__constants__
数组torch.jit.annotate
函数
模块
警告:在 PyTorch 1.2 中,@torch.jit.ignore
注解的行为发生了变化。在 PyTorch 1.2 之前,@ignore
装饰器用于使函数或方法可以从导出的代码中调用。要恢复此功能,请使用 @torch.jit.unused()
。现在 @torch.jit.ignore
等同于 @torch.jit.ignore(drop=False)
。详情请参阅 @torch.jit.ignore
和 @torch.jit.unused
。
当传递给 torch.jit.script
函数时,torch.nn.Module
的数据会被复制到 ScriptModule
中,并由 TorchScript 编译器编译该模块。默认情况下,模块的 forward
方法会被编译。从 forward
调用的方法会按照它们在 forward
中的使用顺序延迟编译,同时也会编译任何带有 @torch.jit.export
注解的方法。
torch.jit.export(fn)
这个装饰器用于标记nn.Module
中的某个方法作为ScriptModule
的入口点,该方法将被编译。
forward
方法默认被视为入口点,因此不需要此装饰器。
从forward
调用的函数和方法会在编译器处理时自动编译,所以它们也不需要这个装饰器。
示例(在方法上使用@torch.jit.export
装饰器):
import torch
import torch.nn as nn
class MyModule(nn.Module):
def implicitly_compiled_method(self, x):
return x + 99
# `forward` is implicitly decorated with `@torch.jit.export`, # so adding it here would have no effect
def forward(self, x):
return x + 10
@torch.jit.export
def another_forward(self, x):
# When the compiler sees this call, it will compile
# `implicitly_compiled_method`
return self.implicitly_compiled_method(x)
def unused_method(self, x):
return x - 20
# `m` will contain compiled methods:
# `forward`
# `another_forward`
# `implicitly_compiled_method`
# `unused_method` will not be compiled since it was not called from # any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())
函数
函数基本保持不变,必要时可以使用 @torch.jit.ignore
或 torch.jit.unused
装饰器进行修饰。
# Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():
return 2
# Marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
return 2
# As with ignore, if nothing calls it then it has no effect.
# If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():
import pdb; pdb.set_trace()
return 4
# Doesn't do anything, this function is already
# the main entry point
@torch.jit.export
def some_fn4():
return 2
TorchScript 类
警告:TorchScript 类的支持目前处于实验阶段。当前最适合用于简单的记录式类型(可理解为附加了方法的NamedTuple
)。
用户定义的 TorchScript 类中所有内容默认会被导出,如有需要可以使用 @torch.jit.ignore
装饰器来忽略特定函数。
属性
TorchScript 编译器需要知道模块属性的类型。大多数类型可以通过成员的值推断出来。空列表和字典无法推断其类型,必须使用 PEP 526 风格 的类注解显式标注类型。如果某个类型既无法推断又未显式标注,则不会将其作为属性添加到最终的 ScriptModule
中。
旧版 API:
from typing import Dict
import torch
class MyModule(torch.jit.ScriptModule):
def __init__(self):
super().__init__()
self.my_dict = torch.jit.Attribute({}, Dict[str, int])
self.my_int = torch.jit.Attribute(20, int)
m = MyModule()
新API:
from typing import Dict
class MyModule(torch.nn.Module):
my_dict: Dict[str, int]
def __init__(self):
super().__init__()
# This type cannot be inferred and must be specified
self.my_dict = {}
# The attribute type here is inferred to be `int`
self.my_int = 20
def forward(self):
pass
m = torch.jit.script(MyModule())
常量
Final
类型构造器可用于将成员标记为常量。如果成员未被标记为常量,它们将被复制到生成的 ScriptModule
中作为属性。使用 Final
可以在已知值固定的情况下开启优化机会,并提供额外的类型安全性。
旧版 API:
class MyModule(torch.jit.ScriptModule):
__constants__ = ['my_constant']
def __init__(self):
super().__init__()
self.my_constant = 2
def forward(self):
pass
m = MyModule()
新 API:
from typing import Final
class MyModule(torch.nn.Module):
my_constant: Final[int]
def __init__(self):
super().__init__()
self.my_constant = 2
def forward(self):
pass
m = torch.jit.script(MyModule())
变量
容器默认具有 Tensor
类型且不可为空(更多信息请参阅默认类型章节)。之前使用 torch.jit.annotate
来告知 TorchScript 编译器类型信息,现在已支持 Python 3 风格的类型提示。
import torch
from typing import Dict, Optional
@torch.jit.script
def make_dict(flag: bool):
x: Dict[str, int] = {}
x['hi'] = 2
b: Optional[int] = None
if flag:
b = 2
return x, b
融合后端
TorchScript 执行优化提供了几种融合后端选择。CPU 上的默认融合器是 NNC,它支持 CPU 和 GPU 的融合操作。而 GPU 上的默认融合器是 NVFuser,它支持更广泛的运算符,并已证明能生成具有更高吞吐量的内核。有关使用和调试的更多细节,请参阅 NVFuser 文档。
参考资料
torch.linalg
常用线性代数运算。
有关常见数值边界情况的说明,请参阅线性代数 (torch.linalg)。
矩阵属性
norm |
计算向量或矩阵范数 |
---|---|
vector_norm |
计算向量范数 |
matrix_norm |
计算矩阵范数 |
diagonal |
torch.diagonal() 的别名,默认参数为dim1 = -2, dim2 = -1 |
det |
计算方阵的行列式 |
slogdet |
计算方阵行列式绝对值的符号和自然对数 |
cond |
计算矩阵关于某个矩阵范数的条件数 |
matrix_rank |
计算矩阵的数值秩 |
矩阵分解
cholesky |
计算复数厄米特矩阵或实数对称正定矩阵的Cholesky分解 |
---|---|
qr |
计算矩阵的QR分解 |
lu |
计算带部分主元消去的矩阵LU分解 |
lu_factor |
计算带部分主元消去的矩阵LU分解的紧凑表示形式 |
eig |
计算方阵的特征值分解(如果存在) |
eigvals |
计算方阵的特征值 |
eigh |
计算复数厄米特矩阵或实数对称矩阵的特征值分解 |
eigvalsh |
计算复数厄米特矩阵或实数对称矩阵的特征值 |
svd |
计算矩阵的奇异值分解(SVD) |
svdvals |
计算矩阵的奇异值 |
求解器
solve |
计算具有唯一解的线性方程组的解。 |
---|---|
solve_triangular |
计算具有唯一解的三角线性方程组的解。 |
lu_solve |
在给定LU分解的情况下,计算具有唯一解的线性方程组的解。 |
lstsq |
计算线性方程组的最小二乘解。 |
逆矩阵
inv |
计算方阵的逆矩阵(如果存在)。 |
---|---|
pinv |
计算矩阵的伪逆(Moore-Penrose 逆)。 |
矩阵函数
matrix_exp |
计算方阵的矩阵指数 |
---|---|
matrix_power |
计算方阵的整数n次幂 |
矩阵运算
cross |
计算两个三维向量的叉积 |
---|---|
matmul |
torch.matmul() 的别名 |
vecdot |
沿指定维度计算两批向量的点积 |
multi_dot |
通过优化乘法顺序来高效计算两个及以上矩阵的连乘,实现最少算术运算 |
householder_product |
计算Householder矩阵乘积的前n列 |
张量运算
tensorinv |
计算 torch.tensordot() 的乘法逆元。 |
---|---|
tensorsolve |
计算方程组 torch.tensordot(A, X) = B 的解 X。 |
杂项函数
vander |
生成范德蒙矩阵 |
---|
实验性函数
cholesky_ex |
计算复厄米特矩阵或实对称正定矩阵的Cholesky分解。 |
---|---|
inv_ex |
计算可逆方阵的逆矩阵。 |
solve_ex |
solve() 的一个变体,除非 check_errors = True,否则不执行错误检查。 |
lu_factor_ex |
这是 lu_factor() 的一个变体,除非 check_errors = True,否则不执行错误检查。 |
ldl_factor |
计算厄米特矩阵或对称矩阵(可能不定)的LDL分解的紧凑表示。 |
ldl_factor_ex |
这是 ldl_factor() 的一个变体,除非 check_errors = True,否则不执行错误检查。 |
ldl_solve |
使用LDL分解计算线性方程组的解。 |
torch.monitor
警告:本模块为原型版本,其接口和功能可能在未来的PyTorch版本中未经通知即发生变更。
torch.monitor
提供了从PyTorch记录事件和计数器的接口。
统计接口设计用于追踪高层次指标,这些指标会定期记录以监控系统性能。由于统计数据会按特定窗口大小进行聚合,您可以在关键循环中记录它们而对性能影响极小。
对于不频繁发生的事件或数值(如损失值、准确率、使用情况追踪),可以直接使用事件接口。
可以注册事件处理器来处理事件,并将其传递至外部事件接收器。
API 参考
class torch.monitor.Aggregation
以下是可用的统计聚合类型:
成员说明:
VALUE :VALUE 返回最后添加的值。
MEAN :MEAN 计算所有添加值的算术平均值。
COUNT :COUNT 返回已添加值的总数量。
SUM :SUM 返回所有添加值的总和。
MAX :MAX 返回添加值中的最大值。
MIN :MIN 返回添加值中的最小值。
property name
class torch.monitor.Stat
Stat 用于在固定时间间隔内高效计算汇总统计量。Stat 会每隔 window_size
时长将统计结果记录为一个 Event 事件。当时间窗口关闭时,统计结果会通过事件处理器以 torch.monitor.Stat
事件的形式记录。
建议将 window_size
设置为较高的值(例如 60 秒),以避免记录过多事件。Stat 使用毫秒级精度。
如果设置了 max_samples
参数,Stat 会通过丢弃超出限制的 add
调用来限制每个窗口的最大样本数。未设置该参数时,窗口期内所有的 add
调用都会被纳入统计。这个可选字段主要用于在样本量可能波动的情况下,使不同窗口期的聚合数据更具可比性。
当 Stat 对象被销毁时,即使当前时间窗口尚未结束,它也会记录所有剩余数据。
__init__(self: torch._C._monitor.Stat, name: str , aggregations: list [torch._C._monitor.Aggregation], window_size: [datetime.timedelta, max_samples: int = 9223372036854775807) → None
构造 Stat
对象。
add(self: torch._C._monitor.Stat, v: float) → None
Adds a value to the stat to be aggregated according to the configured stat type and aggregations.
property count
Number of data points that have currently been collected. Resets
once the event has been logged.
get(self: torch._C._monitor.Stat) → dict[torch._C._monitor.Aggregation, float]
Returns the current value of the stat, primarily for testing purposes. If the stat has logged and no additional values have been added this will be zero.
property name
The name of the stat that was set during creation.
class torch.monitor.data_value_t
data_value_t is one of str
, float
, int
, bool
.
class torch.monitor.Event
Event represents a specific typed event to be logged. This can represent high-level data points such as loss or accuracy per epoch or more low-level aggregations such as through the Stats provided through this library.
All Events of the same type should have the same name so downstream
handlers can correctly process them.
__init__(self: torch._C._monitor.Event, name: str, timestamp: datetime.datetime, data: dict[str, data_value_t]) → None
Constructs the Event
.
property data
The structured data contained within the Event
.
property name
The name of the Event
.
property timestamp
The timestamp when the Event
happened.
class torch.monitor.EventHandlerHandle
EventHandlerHandle is a wrapper type returned by register_event_handler
used to unregister the handler via unregister_event_handler
. This cannot be directly initialized.
torch.monitor.log_event(event: torch._C._monitor.Event) → None
log_event logs the specified event to all of the registered event handlers. It’s up to the event handlers to log the event out to the corresponding event sink.
If there are no event handlers registered this method is a no-op.
torch.monitor.register_event_handler(callback: Callable[[torch._C._monitor.Event], None ]) → torch._C._monitor.EventHandlerHandl)
register_event_handler registers a callback to be called whenever an event is logged via log_event
. These handlers should avoid blocking the main thread since that may interfere with training as they run during the log_event
call.
torch.monitor.unregister_event_handler(handler: torch._C._monitor.EventHandlerHandl)) → None
unregister_event_handler unregisters the EventHandlerHandle
returned after calling register_event_handler
. After this returns the event handler will no longer receive events.
class torch.monitor.TensorboardEventHandler(writer)
TensorboardEventHandler is an event handler that will write known events to the provided SummaryWriter.
This currently only supports torch.monitor.Stat
events which are logged as scalars.
Example :
>>> from torch.utils.tensorboard import SummaryWriter
>>> from torch.monitor import TensorboardEventHandler, register_event_handler
>>> writer = SummaryWriter("log_dir")
>>> register_event_handler(TensorboardEventHandler(writer))
__init__(writer)
构建 TensorboardEventHandler
。
torch.signal 模块
torch.signal 模块的设计灵感来源于 SciPy 的 signal 模块。
torch.signal.windows 窗口函数
bartlett |
计算巴特利特(Bartlett)窗口 |
---|---|
blackman |
计算布莱克曼(Blackman)窗口 |
cosine |
计算余弦窗口,实现方式与SciPy保持一致 |
exponential |
计算指数窗口 |
gaussian |
计算高斯窗口 |
general_cosine |
计算广义余弦窗口 |
general_hamming |
计算广义汉明(Hamming)窗口 |
hamming |
计算汉明(Hamming)窗口 |
hann |
计算汉宁(Hann)窗口 |
kaiser |
计算凯撒(Kaiser)窗口 |
nuttall |
根据Nuttall方法计算最小4项布莱克曼-哈里斯(Blackman-Harris)窗口 |
torch.special
torch.special 模块的设计灵感来源于 SciPy 的 special 模块。
函数
torch.special.airy_ai(input, *, out=None) → Tensor
Airy 函数 A i ( i n p u t ) Ai(input) Ai(input)。
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , 可选)
– 输出张量。
torch.special.bessel_j0(input, *, out=None) → Tensor
第一类零阶贝塞尔函数。
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , 可选 )
– 输出张量。
torch.special.bessel_j1(input, *, out=None) → Tensor
第一类111阶贝塞尔函数。
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
torch.special.digamma(input, *, out=None) → Tensor
计算输入张量的伽玛函数的对数导数。
ϝ ( x ) = d d x ln ( Γ ( x ) ) = Γ ′ ( x ) Γ ( x ) \digamma(x) = \frac{d}{dx} \ln\left(\Gamma\left(x\right)\right) = \frac{\Gamma'(x)}{\Gamma(x)} ϝ(x)=dxdln(Γ(x))=Γ(x)Γ′(x)
参数
input ( Tensor )
– 用于计算digamma函数的输入张量
关键字参数
out ( Tensor , optional)
– 输出张量
注意:此函数与SciPy的scipy.special.digamma功能相似。
注意:从PyTorch 1.8开始,digamma函数在输入为0时会返回-Inf,而此前版本会返回NaN。
示例:
>>> a = torch.tensor([1, 0.5])
>>> torch.special.digamma(a)
tensor([-0.5772, -1.9635])
torch.special.entr(input, *, out=None) → Tensor
计算输入张量 input
中各元素的熵(定义如下)。
$$
\begin{align}
\text{entr(x)} = \begin{cases}
-x * \ln(x) & x 0 \
0 & x = 0.0 \
-\infty & x < 0
\end{cases}
\end{align}
$$
参数说明
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
示例:
>>> a = torch.arange(-0.5, 1, 0.5)
>>> a tensor([-0.5000, 0.0000, 0.5000])
>>> torch.special.entr(a)
tensor([-inf, 0.0000, 0.3466])
torch.special.erf(input, *, out=None) → Tensor
计算输入张量的误差函数。误差函数定义如下:
$$\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e{-t2} dt
$$
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
示例:
>>> torch.special.erf(torch.tensor([0, -1., 10.]))
tensor([0.0000, -0.8427, 1.0000])
torch.special.erfc(input, *, out=None) → Tensor
计算输入张量的互补误差函数。
互补误差函数的定义如下:
$$
\mathrm{erfc}(x) = 1 - \frac{2}{\sqrt{\pi}} \int_{0}^{x} e{-t2} dt
$$
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
示例:
>>> torch.special.erfc(torch.tensor([0, -1., 10.]))
tensor([1.0000, 1.8427, 0.0000])
torch.special.erfcx(input, *, out=None) → Tensor
计算input
中每个元素的缩放互补误差函数。
缩放互补误差函数的定义如下:
e r f c x ( x ) = e x 2 e r f c ( x ) \mathrm{erfcx}(x) = e^{x^2} erfc(x) erfcx(x)=ex2erfc(x)
erfcx(x)=ex2erfc(x)
参数
input ( Tensor )
- 输入张量。
关键字参数
out ( Tensor , 可选)
- 输出张量。
示例
>>> torch.special.erfcx(torch.tensor([0, -1., 10.]))
tensor([1.0000, 5.0090, 0.0561])
torch.special.erfinv(input, *, out=None) → Tensor
计算输入张量的反误差函数值。
反误差函数在区间 (−1,1) 内定义为:
e r f i n v ( e r f ( x ) ) = x \mathrm{erfinv(erf(x))}= x erfinv(erf(x))=x
参数说明
input ( Tensor )
- 输入张量
关键字参数
out ( Tensor , 可选)
- 输出张量
使用示例:
>>> torch.special.erfinv(torch.tensor([0, 0.5, -1.]))
tensor([0.0000, 0.4769, -inf])
torch.special.exp2(input, *, out=None) → Tensor
计算 input
的以 2 为底的指数函数。
y i = 2 x i y_{i} = 2^{x_{i}} yi=2xi
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , 可选)
– 输出张量。
示例:
>>> torch.special.exp2(torch.tensor([0, math.log2(2.), 3, 4]))
tensor([1., 2., 8., 16.])
torch.special.expit(input, *, out=None) → Tensor
计算输入张量 input
各元素的 expit 值(也称为 logistic sigmoid 函数)。
o u t i = 1 1 + e i − i n p u t {out}_{i} = \frac{1}{1 + e^{-input}_{i}} outi=1+ei−input1
outi=1+e−inputi1
参数
input ( Tensor )
- 输入张量。
关键字参数
out ( Tensor , 可选)
- 输出张量。
示例:
>>> t = torch.randn(4)
>>> t
tensor([0.9213, 1.0887, -0.8858, -1.7683])
>>> torch.special.expit(t)
tensor([0.7153, 0.7481, 0.2920, 0.1458])
torch.special.expm1(input, *, out=None) → Tensor
计算输入张量input
各元素的指数值减1。
y i = e x i − 1 y_{i} = e^{x_{i}} - 1 yi=exi−1
注意:对于较小的x值,该函数比直接计算exp(x) - 1能提供更高的精度。
参数说明
input ( Tensor )
- 输入张量。
关键字参数
out ( Tensor , 可选)
- 输出张量。
使用示例:
>>> torch.special.expm1(torch.tensor([0, math.log(2.)]))
tensor([0., 1.])
torch.special.gammainc(input, other, *, out=None) → Tensor
计算正则化的下不完全伽马函数:
$$
\text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_0^{\text{other}_i} t^{\text{input}_i-1} e^{-t} dt
$$
其中 i n p u t i input_i inputi和 o t h e r i other_i otheri均为弱正数且至少有一个严格为正数。若两者均为零或任一为负数,则 out i = nan \text{out}_i=\text{nan} outi=nan。上述公式中的 Γ \Gamma Γ表示伽马函数:
$$
\Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt.
$$
相关函数请参阅torch.special.gammaincc()
和torch.special.gammaln()
。
支持广播至通用形状和浮点输入。
注意:目前不支持对input
的反向传播。如需此功能,请在PyTorch的Github上提交issue。
参数
input ( Tensor )
– 第一个非负输入张量other ( Tensor )
– 第二个非负输入张量
关键字参数
out ( Tensor , optional)
– 输出张量
示例:
>>> a1 = torch.tensor([4.0])
>>> a2 = torch.tensor([3.0, 4.0, 5.0])
>>> a = torch.special.gammaincc(a1, a2)
tensor([0.3528, 0.5665, 0.7350])
tensor([0.3528, 0.5665, 0.7350])
>>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2)
tensor([1., 1., 1.])
torch.special.gammaincc(input, other, *, out=None) → Tensor
计算正则化上不完全伽马函数:
$$\text{out}_{i} = \frac{1}{\Gamma(\text{input}i)} \int{\text{other}_i}^{\infty} t^{\text{input}_i-1} e^{-t} dt
$$
其中 i n p u t i input_i inputi 和 o t h e r i other_i otheri 均为弱正数,且至少有一个严格为正数。若两者均为零或任一为负数,则 o u t i = n a n out_i=nan outi=nan。上式中的 Γ ( ⋅ ) \Gamma(\cdot) Γ(⋅) 表示伽马函数,
$$\Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt.
$$
相关函数请参阅 torch.special.gammainc()
和 torch.special.gammaln()
。
支持广播至相同形状及浮点输入。
注意:目前不支持对 input
的反向传播。如需此功能,请在 PyTorch 的 Github 上提交 issue。
参数
input ( Tensor )
– 第一个非负输入张量other ( Tensor )
– 第二个非负输入张量
关键字参数
out ( Tensor , optional)
– 输出张量。
示例:
>>> a1 = torch.tensor([4.0])
>>> a2 = torch.tensor([3.0, 4.0, 5.0])
>>> a = torch.special.gammaincc(a1, a2)
tensor([0.6472, 0.4335, 0.2650])
>>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2)
tensor([1., 1., 1.])
torch.special.gammaln(input, *, out=None) → Tensor
计算输入张量绝对值的伽玛函数的自然对数。
outi=lnΓ(∣inputi∣)\text{out}{i} = \ln \Gamma(|\text{input}{i}|)
outi=lnΓ(∣inputi∣)
参数
input ( Tensor )
- 输入张量。
关键字参数
out ( Tensor , optional)
- 输出张量。
示例
>>> a = torch.arange(0.5, 2, 0.5)
>>> torch.special.gammaln(a)
tensor([0.5724, 0.0000, -0.1208])
torch.special.i0(input, *, out=None) → Tensor
计算input
中每个元素的第一类零阶修正贝塞尔函数。
$$\text{out}{i} = I_0(\text{input}{i}) = \sum_{k=0}^{\infty} \frac{(\text{input}_{i}2/4)k}{(k!)^2}
$$
参数
input ( Tensor )
- 输入张量
关键字参数
out ( Tensor , 可选)
- 输出张量
示例
>>> torch.i0(torch.arange(5, dtype=torch.float32))
tensor([1.0000, 1.2661, 2.2796, 4.8808, 11.3019])
torch.special.i0e(input, *, out=None) → Tensor
为 input
的每个元素计算指数缩放的第一类零阶修正贝塞尔函数(定义如下)。
$$\text{out}{i} = \exp(-|x|) * i0(x) = \exp(-|x|) * \sum{k=0}^{\infty} \frac{(\text{input}_{i}2/4)k}{(k!)^2}
$$
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
示例:
>>> torch.special.i0e(torch.arange(5, dtype=torch.float32))
tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070])
torch.special.i1(input, *, out=None) → Tensor
计算input
中每个元素的一阶第一类修正贝塞尔函数(定义如下)。
$$\text{out}_{i} = \exp(-|x|) * i1(x) =
\exp(-|x|) * \frac{(\text{input}{i})}{2} * \sum{k=0}^{\infty} \frac{(\text{input}_{i}2/4)k}{(k!) * (k+1)!}
$$
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
示例
>>> torch.special.i1(torch.arange(5, dtype=torch.float32))
tensor([0.0000, 0.5652, 1.5906, 3.9534, 9.7595])
torch.special.i1e(input, *, out=None) → Tensor
计算input
中每个元素的指数缩放一阶第一类修正贝塞尔函数(定义如下):
outi=exp(−∣x∣)∗i1(x)=exp(−∣x∣)∗(inputi)2∗∑k=0∞(inputi2/4)k(k!)∗(k+1)!\text{out}_{i} = \exp(-|x|) * i1(x) =
\exp(-|x|) * \frac{(\text{input}{i})}{2} * \sum{k=0}^{\infty} \frac{(\text{input}_{i}2/4)k}{(k!) * (k+1)!}
outi=exp(−∣x∣)∗i1(x)=exp(−∣x∣)∗2(inputi)∗k=0∑∞(k!)∗(k+1)!(inputi2/4)k
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
示例:
>>> torch.special.i1e(torch.arange(5, dtype=torch.float32))
tensor([0.0000, 0.2079, 0.2153, 0.1968, 0.1788])
torch.special.log1p(input, *, out=None) → Tensor
torch.log1p()
的别名。
torch.special.log_ndtr(input, *, out=None) → Tensor
计算标准高斯概率密度函数从负无穷到input
的逐元素积分对数。
$$\text{log_softmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
$$
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , 可选)
– 输出张量。
示例
>>> torch.special.log_ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
tensor([-6.6077 -3.7832 -1.841 -0.6931 -0.1728 -0.023 -0.0014])
torch.special.log_softmax(input, dim, *, dtype=None) → Tensor
计算经过对数处理的softmax结果。
虽然在数学上等价于log(softmax(x)),但分开执行这两个操作会更慢且数值不稳定。该函数的计算方式如下:
log_softmax(xi)=log(exp(xi)∑jexp(xj))\text{log_softmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
log_softmax(xi)=log(∑jexp(xj)exp(xi))
参数说明
input ( Tensor )
– 输入张量dim ( int )
– 指定计算log_softmax的维度dtype (
torch.dtype, optional)
– 返回张量的期望数据类型
如果指定该参数,在执行操作前会将输入张量转换为dtype
类型。这有助于防止数据类型溢出。默认值:None。
使用示例:
>>> t = torch.ones(2, 2)
>>> torch.special.log_softmax(t, 0)
tensor([[-0.6931, -0.6931], [-0.6931, -0.6931]])
torch.special.logit(input, eps=None, *, out=None) → Tensor
返回一个包含input
元素logit值的新张量。
当eps不为None时,input
会被截断到[eps, 1 - eps]区间。
当eps为None且input
< 0或input
> 1时,函数将返回NaN。
y i = ln ( z i 1 − z i ) z i = { x i if eps is None eps if x i < eps x i if eps ≤ x i ≤ 1 − eps 1 − eps if x i > 1 − eps \begin{align} y_{i} &= \ln(\frac{z_{i}}{1 - z_{i}}) \\ z_{i} &= \begin{cases} x_{i} & \text{if eps is None} \\ \text{eps} & \text{if } x_{i} < \text{eps} \\ x_{i} & \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\ 1 - \text{eps} & \text{if } x_{i} > 1 - \text{eps} \end{cases} \end{align} yizi=ln(1−zizi)=⎩ ⎨ ⎧xiepsxi1−epsif eps is Noneif xi<epsif eps≤xi≤1−epsif xi>1−eps
参数
input ( Tensor )
- 输入张量。eps (float, 可选)
- 用于输入截断的epsilon值。默认值:None
关键字参数
out ( Tensor , 可选)
- 输出张量。
示例:
>>> a = torch.rand(5)
>>> a tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516])
>>> torch.special.logit(a, eps=1e-6)
tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261])
torch.special.logsumexp(input, dim, keepdim=False, *, out=None)
torch.logsumexp()
的别名。
torch.special.multigammaln(input, p, *, out=None) → Tensor
计算给定维度 p 的多元对数伽玛函数,按元素逐个计算,公式如下:
$$\log(\Gamma_{p}(a)) = C + \displaystyle \sum_{i=1}^{p} \log\left(\Gamma\left(a - \frac{i - 1}{2}\right)\right)
$$
其中 C = log ( π ) ⋅ p ( p − 1 ) 4 C = \log(\pi) \cdot \frac{p (p - 1)}{4} C=log(π)⋅4p(p−1) 为伽玛函数。
所有元素必须大于 p − 1 2 \frac{p - 1}{2} 2p−1,否则行为未定义。
参数
input ( Tensor )
- 用于计算多元对数伽玛函数的张量p ( int )
- 维度数量
关键字参数
out ( Tensor , optional)
- 输出张量
示例
>>> a = torch.empty(2, 3).uniform_(1, 2)
>>> a tensor([[1.6835, 1.8474, 1.1929], [1.0475, 1.7162, 1.4180]])
>>> torch.special.multigammaln(a, 2)
tensor([[0.3928, 0.4007, 0.7586], [1.0311, 0.3901, 0.5049]])
torch.special.ndtr(input, *, out=None) → Tensor
计算标准高斯概率密度函数从负无穷到输入值input
的逐元素积分面积。
ndtr ( x ) = 1 2 π ∫ − ∞ x e − 1 2 t 2 d t \text{ndtr}(x) = \frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt ndtr(x)=2π1∫−∞xe−21t2dt
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
示例
>>> torch.special.ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
tensor([0.0013, 0.0228, 0.1587, 0.5000, 0.8413, 0.9772, 0.9987])
torch.special.ndtri(input, *, out=None) → Tensor
计算高斯概率密度函数下(从负无穷积分到x)面积等于input
各元素值的对应参数x。
ndtri ( p ) = 2 erf − 1 ( 2 p − 1 ) \text{ndtri}(p) = \sqrt{2}\text{erf}^{-1}(2p - 1) ndtri(p)=2erf−1(2p−1)
注意:也称为正态分布的分位数函数。
参数说明:
input ( Tensor )
- 输入张量
关键字参数:
out ( Tensor , optional)
- 输出张量
示例:
>>> torch.special.ndtri(torch.tensor([0, 0.25, 0.5, 0.75, 1]))
tensor([ -inf, -0.6745, 0.0000, 0.6745, inf])
torch.special.polygamma(n, input, *, out=None) → Tensor
计算输入张量 input
的 digamma 函数的 n t h n^{th} nth 阶导数。
其中 n ≥ 0 n \geq 0 n≥0 称为多伽马函数的阶数。
$$\psi^{(n)}(x) = \frac{d{(n)}}{dx{(n)}} \psi(x)
$$
注意:此函数仅针对非负整数 n≥0 实现。
参数说明
n ( int )
– 多伽马函数的阶数input ( Tensor )
– 输入张量
关键字参数
out ( Tensor , optional)
– 输出张量
使用示例:
>>> a = torch.tensor([1, 0.5])
>>> torch.special.polygamma(1, a)
tensor([1.64493, 4.9348])
>>> torch.special.polygamma(2, a)
tensor([-2.4041, -16.8288])
>>> torch.special.polygamma(3, a)
tensor([6.4939, 97.4091])
>>> torch.special.polygamma(4, a)
tensor([-24.8863, -771.4742])
torch.special.psi(input, *, out=None) → Tensor
torch.special.round(input, *, out=None) → Tensor
torch.round()
的别名。
torch.special.scaled_modified_bessel_k0(input, *, out=None) → Tensor
二阶修正贝塞尔函数(阶数为0)。
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , 可选)
– 输出张量。
torch.special.scaled_modified_bessel_k1(input, *, out=None) → Tensor
第二类111阶缩放修正贝塞尔函数。
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , 可选 )
– 输出张量。
torch.special.sinc(input, *, out=None) → Tensor
计算 input
的归一化 sinc 函数值。
$$\text{out}_{i} =
\begin{cases}
1, & \text{if}\ \text{input}_{i}=0 \
\sin(\pi \text{input}{i}) / (\pi \text{input}{i}), & \text{otherwise}
\end{cases}
$$
参数
input ( Tensor )
- 输入张量。
关键字参数
out ( Tensor , optional)
- 输出张量。
示例
>>> t = torch.randn(4)
>>> t
tensor([0.2252, -0.2948, 1.0267, -1.1566])
>>> torch.special.sinc(t)
tensor([0.9186, 0.8631, -0.0259, -0.1300])
torch.special.softmax(input, dim, *, dtype=None) → Tensor
计算softmax函数。
Softmax的定义如下:
Softmax ( x i ) = exp ( x i ) ∑ j exp ( x j ) \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} Softmax(xi)=∑jexp(xj)exp(xi)
该函数会沿着指定维度dim对所有切片进行计算,并将结果重新缩放,使元素值落在[0, 1]区间且总和为1。
参数
input ( Tensor )
– 输入张量dim ( int )
– 指定计算softmax的维度dtype (
torch.dtype, 可选)
– 返回张量的期望数据类型
若指定该参数,在执行操作前会将输入张量转换为dtype
类型。这有助于防止数据类型溢出。默认值:None。
示例::
>>> t = torch.ones(2, 2)
>>> torch.special.softmax(t, 0)
tensor([[0.5000, 0.5000], [0.5000, 0.5000]])
torch.special.spherical_bessel_j0(input, *, out=None) → Tensor
一阶球面贝塞尔函数(阶数为000)。
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
torch.special.xlog1py(input, other, *, out=None) → Tensor
计算 input * log1p(other)
,具体分为以下几种情况:
out i = { NaN if other i = NaN 0 if input i = 0.0 and other i ! = NaN input i ∗ log1p ( other i ) otherwise \text{out}_{i} = \begin{cases} \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\ 0 & \text{if } \text{input}_{i} = 0.0 \text{ and } \text{other}_{i} != \text{NaN} \\ \text{input}_{i} * \text{log1p}(\text{other}_{i}) & \text{otherwise} \end{cases} outi=⎩ ⎨ ⎧NaN0inputi∗log1p(otheri)if otheri=NaNif inputi=0.0 and otheri!=NaNotherwise
与 SciPy 的 scipy.special.xlog1py
功能类似。
参数
input (Number* 或 Tensor)
– 乘数other (Number* 或 Tensor)
– 参数
注意:input
和 other
中至少有一个必须是张量。
关键字参数
out (Tensor, 可选)
– 输出张量。
示例:
>>> x = torch.zeros(5,)
>>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')])
>>> torch.special.xlog1py(x, y)
tensor([0., 0., 0., 0., nan])
>>> x = torch.tensor([1, 2, 3])
>>> y = torch.tensor([3, 2, 1])
>>> torch.special.xlog1py(x, y)
tensor([1.3863, 2.1972, 2.0794])
>>> torch.special.xlog1py(x, 4)
tensor([1.6094, 3.2189, 4.8283])
>>> torch.special.xlog1py(2, y)
tensor([2.7726, 2.1972, 1.3863])
torch.special.xlogy(input, other, *, out=None) → Tensor
计算 input * log(other)
,具体有以下几种情况:
outi={NaN若 otheri=NaN0若 inputi=0.0inputi∗log(otheri)其他情况\text{out}_{i} = \begin{cases}
\text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\
0 & \text{if } \text{input}_{i} = 0.0 \\
\text{input}{i} * \log{(\text{other}{i})} & \text{otherwise}
\end{cases}
outi=⎩⎨⎧NaN0inputi∗log(otheri)若 otheri=NaN若 inputi=0.0其他情况类似于 SciPy 的 scipy.special.xlogy 函数。
参数
input (Number* 或 Tensor)
– 乘数other (Number* 或 Tensor)
– 参数
注意:input
和 other
中至少有一个必须是张量。
关键字参数
out (Tensor, 可选)
– 输出张量。
示例:
>>> x = torch.zeros(5,)
>>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')])
>>> torch.special.xlogy(x, y)
tensor([0., 0., 0., 0., nan])
>>> x = torch.tensor([1, 2, 3])
>>> y = torch.tensor([3, 2, 1])
>>> torch.special.xlogy(x, y)
tensor([1.0986, 1.3863, 0.0000])
>>> torch.special.xlogy(x, 4)
tensor([1.3863, 2.7726, 4.1589])
>>> torch.special.xlogy(2, y)
tensor([2.1972, 1.3863, 0.0000])
torch.special.zeta(input, other, *, out=None) → Tensor
逐元素计算 Hurwitz zeta 函数。
$$\zeta(x, q) = \sum_{k=0}^{\infty} \frac{1}{(k + q)^x}
$$
参数
input ( Tensor )
– 对应 x 的输入张量。other ( Tensor )
– 对应 q 的输入张量。
注意:当 q = 1 时即为黎曼 zeta 函数。
关键字参数
out ( Tensor , optional)
– 输出张量。
示例:
>>> x = torch.tensor([2., 4.])
>>> torch.special.zeta(x, 1)
tensor([1.6449, 1.0823])
>>> torch.special.zeta(x, torch.tensor([1., 2.]))
tensor([1.6449, 0.0823])
>>> torch.special.zeta(2, torch.tensor([1., 2.]))
tensor([1.6449, 0.6449])
torch.overrides
该模块提供了多种辅助函数,用于支持__torch_function__
协议。有关__torch_function__
协议的更多详细信息,请参阅扩展torch Python API。
函数
torch.overrides.get_ignored_functions()
返回无法被__torch_function__
覆盖的公共函数。
返回值:一个包含torch API中公开但无法通过__torch_function__
覆盖的函数的元组。这主要是因为这些函数的参数都不是张量或类张量对象。
返回类型:set[Callable]
示例
>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True
>>> torch.add in torch.overrides.get_ignored_functions()
False
torch.overrides.get_overridable_functions()
可通过 __torch_function__
重写的函数列表
返回值:一个字典,将包含可重写函数的命名空间映射到该命名空间中可被重写的函数。
返回类型:Dict[Any, List[Callable]]
torch.overrides.resolve_name(f)
获取传递给__torch_function__
的函数的人类可读字符串名称
参数
f (Callable)
– 需要解析名称的函数。
返回值:该函数的名称;如果对其进行求值,应能返回输入函数。
返回类型:str
torch.overrides.get_testing_overrides()
返回一个包含所有可覆盖函数的虚拟重载的字典
返回值:一个字典,将 PyTorch API 中的可覆盖函数映射到具有相同签名的 lambda 函数,这些 lambda 函数无条件返回 -1。这些 lambda 函数对于测试定义了 __torch_function__
类型的 API 覆盖率非常有用。
返回类型:Dict[Callable, Callable]
示例
>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>
torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)
实现一个检查__torch_function__
重载的函数。
在C++实现中,与此函数等效的是torch::autograd::handle_torch_function。
参数
public_api (function)
- 最初以public_api(args, *kwargs)
形式调用的公开torch API函数,现在正在检查其参数。relevant_args (iterable)
- 需要检查__torch_function__
方法的参数迭代器。args (tuple)
- 最初传入public_api
的任意位置参数。kwargs (tuple)
- 最初传入public_api
的任意关键字参数。
返回
根据情况返回调用implementation
或__torch_function__
方法的结果。
返回类型
object
:raises TypeError: 如果找不到实现。
示例
>>> def func(a):
... if has_torch_function_unary(a):
... return handle_torch_function(func, (a,), a)
... return a + 0
torch.overrides.has_torch_function()
检查可迭代对象中的元素是否实现了__torch_function__
,或者是否启用了__torch_function__
模式。注意:精确的Tensor
和Parameter
被视为不可调度类型。此方法用于保护对handle_torch_function()
的调用,不要用它来检测对象是否类似Tensor——请改用is_tensor_like()
。
:param relevant_args: 需要检查__torch_function__
方法的可迭代对象或参数
:type relevant_args: iterable
返回值
如果relevant_args中任何元素实现了__torch_function__
则返回True,否则返回False。
返回类型 : bool
另请参阅
torch.is_tensor_like
检测对象是否为Tensor-like(包括精确的Tensor
)
torch.overrides.is_tensor_like(inp)
如果传入的输入是类张量(Tensor-like)对象,则返回True
。
当前实现中,只要输入对象的类型具有__torch_function__
属性即视为类张量。
示例:
张量的子类通常属于类张量对象。
>>> class SubTensor(torch.Tensor): ...
>>> is_tensor_like(SubTensor([0]))
True
内置类型或用户自定义类型通常不具备 Tensor 的特性。
>>> is_tensor_like(6)
False
>>> is_tensor_like(None)
False
>>> class NotATensor: ...
>>> is_tensor_like(NotATensor())
False
但是,可以通过实现 __torch_function__
使它们具备类似张量的特性。
>>> class TensorLike:
... @classmethod
... def __torch_function__(cls, func, types, args, kwargs):
... return -1
>>> is_tensor_like(TensorLike())
True
torch.overrides.is_tensor_method_or_property(func)
如果传入的函数是 torch.Tensor
方法或属性的处理程序(如传入 __torch_function__
时),则返回 True。
注意:对于属性,必须传入其 __get__
方法。
这在以下情况下尤其需要:
1、方法/属性有时不包含 module 槽位
2、它们要求第一个传入参数必须是 torch.Tensor
的实例
示例
>>> is_tensor_method_or_property(torch.Tensor.add)
True
>>> is_tensor_method_or_property(torch.add)
False
返回类型:bool
torch.overrides.wrap_torch_function(dispatcher)
用__torch_function__
相关功能包装给定的函数。
参数
dispatcher (Callable)
– 一个可调用对象,返回传入函数中的类Tensor对象的可迭代集合。
注意:此装饰器可能会降低代码性能。通常,将代码表达为一系列自身支持__torch_function__
的函数就足够了。如果您遇到罕见情况(例如在封装底层库时也需要使其支持类Tensor对象),则可以使用此函数。
示例
>>> def dispatcher(a): # Must have the same signature as func
... return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
>>> def func(a): # This will make func dispatchable by __torch_function__
... return a + 0
torch.package
torch.package
提供了创建包含工件和任意 PyTorch 代码的包的支持。这些包可以被保存、共享,用于在之后的时间或不同的机器上加载和执行模型,甚至可以使用 torch::deploy
部署到生产环境。
本文档包含教程、操作指南、说明和 API 参考,将帮助您了解更多关于 torch.package
的信息以及如何使用它。
警告:此模块依赖于不安全的 pickle
模块。仅解包您信任的数据。
恶意构造的 pickle 数据可能会在解包过程中执行任意代码。切勿解包可能来自不受信任来源或可能被篡改的数据。
更多信息,请查阅 pickle
模块的文档。
教程
打包你的第一个模型
我们提供了一个教程,引导你完成打包和解包一个简单模型的流程,该教程可在 Colab 上查看。完成这个练习后,你将熟悉创建和使用 Torch 包的基本 API。
如何实现…
查看包内包含哪些内容?
将包视为ZIP归档文件处理
torch.package
的容器格式采用ZIP标准,因此任何适用于标准ZIP文件的工具都能用于查看其内容。以下是操作ZIP文件的常用方法:
- 执行
unzip my_package.pt
命令可将torch.package
归档解压到磁盘,便于自由检查其内容。
$ unzip my_package.pt && tree my_package
my_package
├── .data
│ ├── 94304870911616.storage
│ ├── 94304900784016.storage
│ ├── extern_modules
│ └── version
├── models
│ └── model_1.pkl
└── torchvision
└── models
├── resnet.py
└── utils.py
~ cd my_package && cat torchvision/models/resnet.py
...
- Python的
zipfile
模块提供了读写ZIP归档文件内容的标准方法。
from zipfile import ZipFile with ZipFile("my_package.pt") as myzip:
file_bytes = myzip.read("torchvision/models/resnet.py")
# edit file_bytes in some way
myzip.writestr("torchvision/models/resnet.py", new_file_bytes)
- Vim 原生支持读取 ZIP 压缩包。你甚至可以直接编辑文件并通过
:write
命令将修改写回压缩包!
# add this to your .vimrc to treat `*.pt` files as zip files
au BufReadCmd *.pt call zip#Browse(expand("<amatch>"))
~ vi my_package.pt
使用 file_structure()
API
PackageImporter
提供了一个 file_structure()
方法,该方法会返回一个可打印且可查询的 Directory
对象。Directory
对象是一个简单的目录结构,可用于查看 torch.package
的当前内容。
Directory
对象本身可以直接打印,并会输出文件树的表示形式。如需过滤返回的内容,可使用 glob 风格的 include
和 exclude
过滤参数。
with PackageExporter('my_package.pt') as pe:
pe.save_pickle('models', 'model_1.pkl', mod)
importer = PackageImporter('my_package.pt')
# can limit printed items with include/exclude args
print(importer.file_structure(include=["**/utils.py", "**/*.pkl"], exclude="**/*.storage"))
print(importer.file_structure()) # will print out all files
Output:
# filtered with glob pattern:
# include=["**/utils.py", "**/*.pkl"], exclude="**/*.storage"
─── my_package.pt
├── models
│ └── model_1.pkl
└── torchvision
└── models
└── utils.py
# all files
─── my_package.pt
├── .data
│ ├── 94304870911616.storage
│ ├── 94304900784016.storage
│ ├── extern_modules
│ └── version
├── models
│ └── model_1.pkl
└── torchvision
└── models
├── resnet.py
└── utils.py
你也可以使用 has_file()
方法查询 Directory
对象。
importer_file_structure = importer.file_structure()
found: bool = importer_file_structure.has_file("package_a/subpackage.py")
查看某个模块为何被列为依赖项?
假设有一个模块 foo
,你想知道为什么 PackageExporter
会将其作为依赖项引入。
PackageExporter.get_rdeps()
方法会返回所有直接依赖 foo
的模块。
如果想查看特定模块 src
如何依赖 foo
,可以使用 PackageExporter.all_paths()
方法,该方法会返回一个 DOT 格式的图表,展示 src
和 foo
之间的所有依赖路径。
如果只想查看 PackageExporter
的完整依赖关系图,可以使用 PackageExporter.dependency_graph_string()
方法。
如何在打包时包含任意资源并后续访问?
PackageExporter
提供了三个方法:save_pickle
、save_text
和 save_binary
,允许你将 Python 对象、文本和二进制数据保存到包中。
with torch.PackageExporter("package.pt") as exporter:
# Pickles the object and saves to `my_resources/tensor.pkl` in the archive.
exporter.save_pickle("my_resources", "tensor.pkl", torch.randn(4))
exporter.save_text("config_stuff", "words.txt", "a sample string")
exporter.save_binary("raw_data", "binary", my_bytes)
PackageImporter
提供了三个互补方法:load_pickle
、load_text
和 load_binary
,用于从包中加载 Python 对象、文本数据和二进制数据。
importer = torch.PackageImporter("package.pt")
my_tensor = importer.load_pickle("my_resources", "tensor.pkl")
text = importer.load_text("config_stuff", "words.txt")
binary = importer.load_binary("raw_data", "binary")
自定义类的打包方式
torch.package
允许自定义类的打包方式。这一行为通过以下两种方式实现:在类上定义方法 __reduce_package__
,并定义对应的解包函数。这与为 Python 常规的 pickle 过程定义 __reduce__
类似。
操作步骤:
1、在目标类上定义方法 __reduce_package__(self, exporter: PackageExporter)
。该方法负责将类实例保存到包中,并应返回一个元组,包含对应的解包函数及调用该函数所需的参数。当 PackageExporter
遇到目标类的实例时,会调用此方法。
2、为类定义一个解包函数。该解包函数负责重建并返回类的实例。其函数签名的第一个参数应为 PackageImporter
实例,其余参数由用户自定义。
# foo.py [Example of customizing how class Foo is packaged]
from torch.package import PackageExporter, PackageImporter
import time
class Foo:
def __init__(self, my_string: str):
super().__init__()
self.my_string = my_string
self.time_imported = 0
self.time_exported = 0
def __reduce_package__(self, exporter: PackageExporter):
"""
Called by ``torch.package.PackageExporter``'s Pickler's ``persistent_id`` when
saving an instance of this object. This method should do the work to save this object inside of the ``torch.package`` archive.
Returns function w/ arguments to load the object from a ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function.
"""
# use this pattern to ensure no naming conflicts with normal dependencies, # anything saved under this module name shouldn't conflict with other
# items in the package
generated_module_name = f"foo-generated._{exporter.get_unique_id()}"
exporter.save_text(
generated_module_name, "foo.txt", self.my_string + ", with exporter modification!", )
time_exported = time.clock_gettime(1)
# returns de-packaging function w/ arguments to invoke with return (unpackage_foo, (generated_module_name, time_exported,))
def unpackage_foo(
importer: PackageImporter, generated_module_name: str, time_exported: float
) -Foo:
"""
Called by ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function
when depickling a Foo object.
Performs work of loading and returning a Foo instance from a ``torch.package`` archive.
"""
time_imported = time.clock_gettime(1)
foo = Foo(importer.load_text(generated_module_name, "foo.txt"))
foo.time_imported = time_imported
foo.time_exported = time_exported
return foo
# example of saving instances of class Foo
import torch
from torch.package import PackageImporter, PackageExporter
import foo
foo_1 = foo.Foo("foo_1 initial string")
foo_2 = foo.Foo("foo_2 initial string") with PackageExporter('foo_package.pt') as pe:
# save as normal, no extra work necessary
pe.save_pickle('foo_collection', 'foo1.pkl', foo_1)
pe.save_pickle('foo_collection', 'foo2.pkl', foo_2)
pi = PackageImporter('foo_package.pt')
print(pi.file_structure())
imported_foo = pi.load_pickle('foo_collection', 'foo1.pkl')
print(f"foo_1 string: '{imported_foo.my_string}'")
print(f"foo_1 export time: {imported_foo.time_exported}")
print(f"foo_1 import time: {imported_foo.time_imported}")
# output of running above script
─── foo_package
├── foo-generated
│ ├── _0
│ │ └── foo.txt
│ └── _1
│ └── foo.txt
├── foo_collection
│ ├── foo1.pkl
│ └── foo2.pkl
└── foo.py
foo_1 string: 'foo_1 initial string, with reduction modification!'
foo_1 export time: 9857706.650140837
foo_1 import time: 9857706.652698385
如何在源码中检测当前是否运行在包环境中?
PackageImporter
会在初始化每个模块时为其添加 __torch_package__
属性。你的代码可以通过检查该属性是否存在,来判断当前是否处于打包后的运行环境中。
# In foo/bar.py:
if "__torch_package__" in dir(): # true if the code is being loaded from a package
def is_in_package():
return True
UserException = Exception
else:
def is_in_package():
return False
UserException = UnpackageableException
现在,代码的行为会根据它是通过Python环境正常导入还是从torch.package
导入而有所不同。
from foo.bar import is_in_package
print(is_in_package()) # False
loaded_module = PackageImporter(my_package).import_module("foo.bar")
loaded_module.is_in_package() # True
警告:通常情况下,让代码在打包前后表现不一致是一种不良实践。这会导致难以调试的问题,且问题表现会因代码导入方式的不同而敏感变化。如果你的包预计会被频繁使用,建议重构代码,确保无论以何种方式加载,其行为都保持一致。
如何将代码补丁打入包中?
PackageExporter
提供了 save_source_string()
方法,允许你将任意 Python 源代码保存到指定的模块中。
with PackageExporter(f) as exporter:
# Save the my_module.foo available in your current Python environment.
exporter.save_module("my_module.foo")
# This saves the provided string to my_module/foo.py in the package archive.
# It will override the my_module.foo that was previously saved.
exporter.save_source_string("my_module.foo", textwrap.dedent(
"""\
def my_function():
print('hello world')
"""
))
# If you want to treat my_module.bar as a package
# (e.g. save to `my_module/bar/__init__.py` instead of `my_module/bar.py)
# pass is_package=True, exporter.save_source_string("my_module.bar", "def foo(): print('hello')\n", is_package=True)
importer = PackageImporter(f)
importer.import_module("my_module.foo").my_function() # prints 'hello world'
如何从打包代码中访问包内容?
PackageImporter
实现了 importlib.resources API,用于从包内部访问资源。
with PackageExporter(f) as exporter:
# saves text to my_resource/a.txt in the archive
exporter.save_text("my_resource", "a.txt", "hello world!")
# saves the tensor to my_pickle/obj.pkl
exporter.save_pickle("my_pickle", "obj.pkl", torch.ones(2, 2))
# see below for module contents
exporter.save_module("foo")
exporter.save_module("bar")
importlib.resources
API 允许从打包代码中访问资源。
# foo.py:
import importlib.resources
import my_resource
# returns "hello world!"
def get_my_resource():
return importlib.resources.read_text(my_resource, "a.txt")
推荐使用 importlib.resources
来访问打包代码中的包内容,因为它符合 Python 标准。不过,也可以直接从打包代码中访问父级 PackageImporter
实例本身。
# bar.py:
import torch_package_importer # this is the PackageImporter that imported this module.
# Prints "hello world!", equivalent to importlib.resources.read_text
def get_my_resource():
return torch_package_importer.load_text("my_resource", "a.txt")
# You also do things that the importlib.resources API does not support, like loading
# a pickled object from the package.
def get_my_pickle():
return torch_package_importer.load_pickle("my_pickle", "obj.pkl")
区分打包代码与非打包代码
要判断一个对象的代码是否来自 torch.package
,可使用 torch.package.is_from_package()
函数。
注意:若对象来自某个包但其定义源自标记为 extern
的模块或来自 stdlib
,此检查将返回 False
。
importer = PackageImporter(f)
mod = importer.import_module('foo')
obj = importer.load_pickle('model', 'model.pkl')
txt = importer.load_text('text', 'my_test.txt')
assert is_from_package(mod)
assert is_from_package(obj)
assert not is_from_package(txt) # str is from stdlib, so this will return False
如何重新导出已导入的对象?
要通过新的 PackageExporter
重新导出之前由 PackageImporter
导入的对象,必须让导出器知晓原始导入器的存在,这样才能正确找到对象依赖项的源代码。
importer = PackageImporter(f)
obj = importer.load_pickle("model", "model.pkl")
# re-export obj in a new package with PackageExporter(f2, importer=(importer, sys_importer)) as exporter:
exporter.save_pickle("model", "model.pkl", obj)
如何打包 TorchScript 模块?
要打包 TorchScript 模型,可以使用与其他对象相同的 save_pickle
和 load_pickle
API。TorchScript 对象作为属性或子模块时也支持直接保存,无需额外操作。
# save TorchScript just like any other object with PackageExporter(file_name) as e:
e.save_pickle("res", "script_model.pkl", scripted_model)
e.save_pickle("res", "mixed_model.pkl", python_model_with_scripted_submodule)
# load as normal
importer = PackageImporter(file_name)
loaded_script = importer.load_pickle("res", "script_model.pkl")
loaded_mixed = importer.load_pickle("res", "mixed_model.pkl"
说明
torch.package
格式概述
torch.package
文件是一个 ZIP 归档文件,通常使用 .pt
扩展名。ZIP 归档内包含两类文件:
- 框架文件:存放在
.data/
目录下 - 用户文件:其余所有文件
例如,以下是一个完整打包的 torchvision
ResNet 模型的结构示例:
resnet
├── .data # All framework-specific data is stored here.
│ │ # It's named to avoid conflicts with user-serialized code.
│ ├── 94286146172688.storage # tensor data
│ ├── 94286146172784.storage
│ ├── extern_modules # text file with names of extern modules (e.g. 'torch')
│ ├── version # version metadata
│ ├── ...
├── model # the pickled model
│ └── model.pkl
└── torchvision # all code dependencies are captured as source files
└── models
├── resnet.py
└── utils.py
框架文件
.data/
目录由 torch.package 所有,其内容被视为私有实现细节。torch.package
格式不保证 .data/
目录内容的具体结构,但所有改动都将保持向后兼容性(即新版本的 PyTorch 始终能够加载旧版 torch.package
)。
目前,.data/
目录包含以下内容:
version
:序列化格式的版本号,用于让torch.package
的导入基础设施知道如何加载该包。extern_modules
:被视为extern
的模块列表。extern
模块将使用加载环境的系统导入器进行导入。*.storage
:序列化的张量数据。
.data
├── 94286146172688.storage
├── 94286146172784.storage
├── extern_modules
├── version
├── ...
用户文件
归档中的所有其他文件都是由用户放置的。其目录结构与 Python 的常规包完全一致。若想深入了解 Python 的打包机制,请参阅这篇文章(内容稍有过时,具体实现细节请以 Python 参考文档为准)。
<package root>
├── model # the pickled model
│ └── model.pkl
├── another_package
│ ├── __init__.py
│ ├── foo.txt # a resource file , see importlib.resources
│ └── ...
└── torchvision
└── models
├── resnet.py # torchvision.models.resnet
└── utils.py # torchvision.models.utils
torch.package
如何查找代码依赖项
分析对象的依赖关系
当你调用 save_pickle(obj, ...)
时,PackageExporter
会正常地对对象进行 pickle 序列化。随后,它会使用标准库模块 pickletools
来解析 pickle 字节码。
在 pickle 序列化过程中,对象会与一个 GLOBAL
操作码一起保存,该操作码描述了如何找到对象类型的实现位置,例如:
GLOBAL 'torchvision.models.resnet Resnet`
依赖解析器会收集所有GLOBAL
操作,并将它们标记为待序列化对象的依赖项。
有关序列化及pickle格式的更多信息,请参阅Python官方文档。
分析模块依赖关系
当识别出某个Python模块作为依赖项时,torch.package
会遍历该模块的Python抽象语法树(AST)表示,并查找其中的导入语句。它完全支持标准导入形式:from x import y
、import z
、from w import v as u
等。当遇到这些导入语句时,torch.package
会将导入的模块注册为依赖项,随后以同样的AST遍历方式解析这些依赖模块。
注意:AST解析对__import__(...)
语法的支持有限,且不支持importlib.import_module
调用。通常不应期望torch.package
能检测到动态导入。
依赖管理
torch.package
会自动发现你的代码和对象所依赖的 Python 模块。这一过程称为依赖解析。
对于依赖解析器找到的每个模块,你必须指定一个操作来处理它。
允许的操作包括:
intern
:将该模块打包到包中。extern
:声明该模块为包的外部依赖项。mock
:将该模块替换为存根。deny
:依赖此模块会在包导出时引发错误。
此外,还有一个重要的操作虽然技术上不属于 torch.package
的一部分:
- 重构:移除或更改代码中的依赖项。
注意,操作仅针对整个 Python 模块定义,无法仅打包模块中的某个函数或类而忽略其余部分。
这是有意为之的设计。Python 并未提供模块内对象之间的清晰边界,模块是依赖组织的唯一明确定义单元,因此 torch.package
也采用这一标准。
操作通过模式应用于模块。模式可以是模块名称(如 "foo.bar"
)或通配符(如 "foo.**"
)。你可以使用 PackageExporter
上的方法将模式与操作关联,例如:
my_exporter.intern("torchvision.**")
my_exporter.extern("numpy")
如果模块匹配某个模式,就会对其应用相应的操作。对于给定的模块,系统会按照模式定义的顺序依次检查,并执行第一个匹配的操作。
intern
如果一个模块被标记为intern
,它将被放入包中。
此操作适用于你的模型代码,或任何你想打包的相关代码。例如,如果你要打包torchvision
中的ResNet模型,就需要将模块torchvision.models.resnet
标记为intern
。
当导入包时,如果打包的代码尝试导入一个被标记为intern
的模块,PackageImporter
会在你的包内查找该模块。如果找不到该模块,则会抛出错误。这确保了每个PackageImporter
与加载环境隔离——即使my_interned_module
同时在你的包和加载环境中可用,PackageImporter
也只会使用包内的版本。
注意:只有Python源码模块可以被标记为intern
。其他类型的模块(如C扩展模块和字节码模块)如果尝试标记为intern
会抛出错误。这类模块需要被标记为mock
或extern
。
extern
如果一个模块被声明为extern
,它将不会被包含在包中。相反,该模块会被添加到当前包的外部依赖列表中。你可以在package_exporter.extern_modules
中找到这个列表。
当导入包时,如果打包后的代码尝试导入一个被声明为extern
的模块,PackageImporter
会使用默认的Python导入器来查找该模块,就像执行了importlib.import_module("my_externed_module")
一样。如果找不到该模块,则会抛出错误。
通过这种方式,你可以在包中依赖第三方库(如numpy
和scipy
),而无需将它们一并打包。
警告:如果任何外部库发生了不向后兼容的更改,你的包可能无法加载。如果需要长期保证包的复现性,请尽量减少使用extern
。
mock
当一个模块被mock
时,它不会被打包。取而代之的是一个存根模块会被打包到该位置。这个存根模块允许你从中获取对象(因此from my_mocked_module import foo
不会报错),但任何使用该对象的操作都会引发NotImplementedError
。
mock
应该用于那些你"确定"在加载的包中不需要,但仍希望在非打包内容中可用的代码。例如初始化/配置代码,或仅用于调试/训练的代码。
警告:一般来说,mock
应该作为最后手段使用。它会导致打包代码和非打包代码之间的行为差异,可能引发后续混淆。建议优先通过重构代码来移除不需要的依赖项。
代码重构
管理依赖关系的最佳方式就是彻底消除依赖!通过重构代码,我们往往可以移除不必要的依赖。以下是编写低依赖代码的指导原则(这些原则本身也是优秀的编码实践):
只引入真正用到的内容。不要在代码中保留未使用的导入项。依赖解析器无法智能识别这些未使用的导入,仍会尝试处理它们。
精确限定导入范围。例如,与其写import foo
然后在代码中使用foo.bar.baz
,不如直接写from foo.bar import baz
。这种方式能更精准地声明实际依赖(foo.bar
),让依赖解析器明确你不需要引入整个foo
包。
将包含无关功能的大文件拆分为小模块。如果你的utils
模块混杂了大量无关功能,任何依赖该模块的代码都不得不引入许多无关依赖——即使你只需要其中一小部分功能。更好的做法是创建功能单一的小模块,这些模块可以彼此独立地打包。
模式
模式允许您通过简洁的语法来指定模块组。其语法和行为遵循 Bazel/Buck 的 glob() 规范。
我们将尝试与模式匹配的模块称为候选对象。候选对象由通过分隔符字符串分隔的多个段组成,例如 foo.bar.baz
。
一个模式包含一个或多个段。段可以是以下类型:
- 字面字符串(如
foo
),表示精确匹配 - 包含通配符的字符串(如
torch
或foo*baz*
)。通配符可以匹配任意字符串,包括空字符串 - 双通配符(
**
)。这将匹配零个或多个完整段
示例:
torch.**
:匹配torch
及其所有子模块,例如torch.nn
和torch.nn.functional
torch.*
:匹配torch.nn
或torch.functional
,但不匹配torch.nn.functional
或torch
torch*.**
:匹配torch
、torchvision
及其所有子模块
在指定操作时,您可以传入多个模式,例如
exporter.intern(["torchvision.models.**", "torchvision.utils.**"])
模块将匹配此操作,只要符合其中任一模式。
您还可以指定要排除的模式,例如
exporter.mock("**", exclude=["torchvision.**"])
如果模块匹配任何排除模式,则该模块不会与此操作匹配。在本示例中,我们模拟了除 torchvision
及其子模块之外的所有模块。
当一个模块可能匹配多个操作时,将优先采用第一个定义的操作。
torch.package
的注意事项
避免在模块中使用全局状态
Python 可以非常方便地在模块作用域内绑定对象和运行代码。这通常没有问题——毕竟函数和类也是通过这种方式绑定到名称的。但当你定义一个模块作用域内的可变对象时,就会引入可变的全局状态,情况就变得复杂了。
可变全局状态非常有用——它可以减少样板代码、允许开放注册到表中等等。但除非非常谨慎地使用,否则在与 torch.package
一起使用时可能会导致问题。
每个 PackageImporter
都会为其内容创建一个独立的环境。这很好,因为它意味着我们可以加载多个包并确保它们彼此隔离。但当模块的编写方式假设存在共享的可变全局状态时,这种行为可能会导致难以调试的错误。
类型在包与加载环境之间不共享
通过 PackageImporter
导入的任何类,都将是该导入器特有的类版本。例如:
from foo import MyClass
my_class_instance = MyClass()
with PackageExporter(f) as exporter:
exporter.save_module("foo")
importer = PackageImporter(f)
imported_MyClass = importer.import_module("foo").MyClass
assert isinstance(my_class_instance, MyClass) # works
assert isinstance(my_class_instance, imported_MyClass) # ERROR!
在这个示例中,MyClass
和 imported_MyClass
不是同一类型。虽然在这个特定例子中,MyClass
和 imported_MyClass
的实现完全相同,你可能会认为它们可以视为同一个类。但设想一下,如果 imported_MyClass
来自一个旧版本的包,其中 MyClass
的实现完全不同——这种情况下,将它们视为同一个类是不安全的。
实际上,每个导入器都有一个唯一标识类的前缀:
print(MyClass.__name__) # prints "foo.MyClass"
print(imported_MyClass.__name__) # prints <torch_package_0>.foo.MyClass
这意味着当其中一个参数来自某个包而另一个不是时,你不应期望isinstance
检查能正常工作。如果需要此功能,请考虑以下选项:
- 采用鸭子类型(直接使用类而非显式检查其是否属于给定类型)。
- 将类型关系明确作为类契约的一部分。例如,可以添加属性标签
self.handler = "handle_me_this_way"
,并让客户端代码检查handler
的值而非直接检查类型。
torch.package
如何实现包之间的隔离
每个 PackageImporter
实例都会为其模块和对象创建一个独立的隔离环境。包中的模块只能导入其他已打包的模块或被标记为 extern
的模块。如果使用多个 PackageImporter
实例加载同一个包,将会得到多个互不干扰的独立环境。
这一机制是通过扩展 Python 的导入系统实现的,PackageImporter
是一个自定义导入器。它提供了与 importlib
导入器相同的核心 API,即实现了 import_module
和 __import__
方法。
当调用 PackageImporter.import_module()
时,PackageImporter
会像系统导入器一样构造并返回一个新模块。不同之处在于,它会修改返回的模块,使其在后续导入请求时使用当前 PackageImporter
实例(即从包内查找资源),而不是从用户的 Python 环境中搜索。
名称修饰(Mangling)
为了避免混淆(“这个 foo.bar
对象是来自我的包,还是来自 Python 环境?”),PackageImporter
会通过添加修饰前缀来修改所有导入模块的 __name__
和 __file__
属性。
对于 __name__
,像 torchvision.models.resnet18
这样的名称会被修饰为 <torch_package_0>.torchvision.models.resnet18
。
对于 __file__
,像 torchvision/models/resnet18.py
这样的路径会被修饰为 <torch_package_0>.torchvision/modules/resnet18.py
。
名称修饰有助于避免不同包之间模块名的意外冲突,并通过使堆栈跟踪和打印语句更清晰地显示它们是否引用打包代码来辅助调试。有关名称修饰的开发者详细信息,请参阅 torch/package/
目录下的 mangling.md
文件。
API 参考
class torch.package.PackagingError(dependency_graph, debug=False)
当导出包出现问题时,会引发此异常。
PackageExporter
会尝试收集所有错误并一次性展示给您。
class torch.package.EmptyMatchError
当模拟对象(mock)或外部依赖(extern)被标记为allow_empty=False
,但在打包过程中未匹配到任何模块时,会抛出此异常。
class torch.package.PackageExporter(f, importer=<torch.package.importer._SysImporter object>, debug=False)
导出器(Exporters)允许你将代码包、序列化的Python数据以及任意二进制和文本资源打包成自包含的独立包。
导入器(Imports)能够以封闭方式加载这些代码,确保代码从包内加载而非通过常规Python导入系统。这种机制使得PyTorch模型代码和数据可以被打包,以便在服务器上运行或用于未来的迁移学习。
包中的代码在创建时会从原始源文件逐份复制,其文件格式采用特殊组织的zip压缩包。包的使用者可以解压该包并编辑代码,以实现自定义修改。
包的导入器会确保模块代码只能从包内加载,除非模块通过extern()
明确声明为外部依赖。zip压缩包中的extern_modules
文件列出了该包所有外部依赖的模块。
这种机制避免了"隐式"依赖问题——即包在本地运行时能正常导入本地安装的依赖,但当包被复制到其他机器时就会运行失败。
当源代码被添加到包中时,导出器可选择性地扫描代码以发现更多依赖关系(通过设置dependencies=True
)。它会查找import语句,将相对引用解析为完整模块名,并执行用户指定的操作(参见:extern()
、mock()
和intern()
)。
__init__(f, importer=<torch.package.importer._SysImporter object>, debug=False)
创建一个导出器。
参数
f ( Union [str,* PathLike[str],* [IO](https://docs.python.org/3/library/typing.html#typing.IO "(in Python v3.13)")[bytes ]])
- 导出目标位置。可以是包含文件名的string
/Path
对象,也可以是二进制I/O对象。importer ( Union [Importer*,* Sequence [Importer]])
- 如果传入单个Importer,则使用它来搜索模块。如果传入一个importer序列,则会基于它们构建一个OrderedImporter
。debug ([bool])
- 如果设为True,会将损坏模块的路径添加到PackagingErrors中。
add_dependency(module_name, dependencies=True)
根据用户指定的模式,将给定模块添加到依赖关系图中。
all_paths(src, dst)
返回从源节点到目标节点所有路径的子图点表示形式。
返回值:包含从源节点到目标节点所有路径的点表示字符串。
参考文档:Graphviz语言规范
返回类型:str
close()
将包写入文件系统。调用close()
方法后,所有后续操作都将无效。
建议改用资源守卫语法:
with PackageExporter("file.zip") as e:
...
denied_modules()
返回当前被拒绝的所有模块。
返回值:一个包含该包中被拒绝模块名称的列表。
返回类型:list [str]
deny(include, *, exclude=())
根据给定的glob模式,从包可导入的模块列表中屏蔽匹配名称的模块。
如果发现任何匹配包的依赖项,将抛出 PackagingError
错误。
参数
include (Union[List[str],* str])
- 可以是字符串(例如"my_package.my_subpackage"
)或模块名称列表,用于指定需要外部化的模块。也支持glob风格的模式,详见mock()
。exclude (Union[List[str],* str])
- 可选参数,用于排除与include字符串匹配的部分模式。
dependency_graph_string()
返回包中依赖关系的有向图字符串表示形式。
返回值:包中依赖关系的字符串表示形式。
返回类型:str
extern(include, *, exclude=(), allow_empty=True)
将 module
包含在包可导入的外部模块列表中。
这将阻止依赖项发现机制将其保存到包中。导入器会直接从标准导入系统加载外部模块。
外部模块的代码也必须存在于加载包的进程中。
参数
include (Union[List[str],* str])
– 字符串(例如"my_package.my_subpackage"
)或待外部化的模块名称字符串列表。也可使用通配符模式,如mock()
所述。exclude (Union[List[str],* str])
– 可选参数,用于排除与 include 字符串匹配的部分模式。allow_empty ([bool])
– 可选标志,指定本次调用extern
方法所定义的外部模块是否必须在打包时匹配到某些模块。若以allow_empty=False
添加外部模块通配符模式,且在匹配到任何模块前调用close()
(显式调用或通过__exit__
调用),则会抛出异常。若allow_empty=True
则不会抛出此类异常。
(注:严格保留所有代码块、术语标记及链接格式,技术术语如extern
、mock()
等未作翻译,被动语态已转换为主动表达,长句进行了合理拆分)
externed_modules()
返回当前被外部化的所有模块。
返回值:一个包含该包中将被外部化的模块名称的列表。
返回类型:list [str]
get_rdeps(module_name)
返回所有依赖于模块 module_name
的模块列表。
返回值:一个包含依赖 module_name
的模块名称列表。
返回类型:list [str]
get_unique_id()
获取一个ID。该ID保证在此包中只会被分配一次。
返回类型:str
intern(include, *, exclude=(), allow_empty=True)
指定需要打包的模块。模块必须匹配某些intern
模式才能被包含在包中,并递归处理其依赖项。
参数
include (Union[List[str],* str])
– 可以是字符串(例如"my_package.my_subpackage")或模块名称列表,用于指定需要外部化的模块。该参数也支持glob风格的模式匹配,具体说明见mock()
文档。exclude (Union[List[str],* str])
– 可选参数,用于排除与include模式匹配的特定模式。allow_empty ([bool])
– 可选标志,指定通过intern
方法设置的内部模块在打包时是否必须匹配到某些模块。如果添加intern
模块glob模式时设置allow_empty=False
,且在调用close()
(显式调用或通过__exit__
)时没有任何模块匹配该模式,则会抛出异常。若设置allow_empty=True
,则不会抛出此类异常。
interned_modules()
返回当前被内部化的所有模块。
返回值:一个包含将被此包内部化的模块名称的列表。
返回类型:list [str]
mock(include, *, exclude=(), allow_empty=True)
用模拟实现替换某些必需的模块。被模拟的模块将返回一个虚假对象,用于处理任何从其访问的属性。由于我们采用逐文件复制的方式,依赖解析有时会找到被模型文件导入但实际功能从未使用的文件(例如自定义序列化代码或训练辅助工具)。
使用此函数可以模拟这些功能,而无需修改原始代码。
参数
include (Union[List[str],* str])
–
一个字符串(例如 "my_package.my_subpackage"
)或字符串列表,表示需要被模拟的模块名称。字符串也可以使用通配符模式,可能匹配多个模块。任何符合此模式字符串的必需依赖项都将被自动模拟。
示例:
'torch.**'
– 匹配 torch
及其所有子模块,例如 'torch.nn'
和 'torch.nn.functional'
'torch.*'
– 匹配 'torch.nn'
或 'torch.functional'
,但不匹配 'torch.nn.functional'
exclude (Union[List[str],* str])
– 一个可选模式,用于排除某些匹配include
字符串的模式。
例如,include='torch.**', exclude='torch.foo'
将模拟除'torch.foo'
之外的所有 torch 包。默认值为[]
。allow_empty ([bool])
– 一个可选标志,指定通过调用mock()
方法指定的模拟实现是否必须在打包过程中匹配到某些模块。如果以allow_empty=False
添加模拟,并且在调用close()
(显式调用或通过__exit__
)时该模拟未匹配到导出包所使用的模块,则会抛出异常。
如果 allow_empty=True
,则不会抛出此类异常。
mocked_modules()
返回当前被模拟的所有模块。
返回值:包含本包中将被模拟的模块名称列表。
返回类型:list [str]
register_extern_hook(hook)
在导出器上注册一个外部钩子。
每当有模块匹配 extern()
模式时,就会调用该钩子。
钩子函数需要遵循以下签名:
hook(exporter: PackageExporter, module_name: str) -None
钩子将按照注册顺序被调用。
返回值:一个句柄,可用于通过调用 handle.remove()
来移除已添加的钩子。
返回类型:torch.utils.hooks.RemovableHandle
register_intern_hook(hook)
在导出器上注册一个内部钩子。
每当模块匹配 intern()
模式时,都会调用该钩子。
它应具有以下签名:
hook(exporter: PackageExporter, module_name: str) -None
钩子将按照注册顺序被调用。
返回值:一个句柄,可用于通过调用 handle.remove()
来移除已添加的钩子。
返回类型:torch.utils.hooks.RemovableHandle
register_mock_hook(hook)
在导出器上注册一个模拟钩子。
每当模块匹配 mock()
模式时,该钩子就会被调用。
它应具有以下签名:
hook(exporter: PackageExporter, module_name: str) -None
钩子函数将按照注册顺序依次调用。
返回值:返回一个句柄,可通过调用 handle.remove()
来移除已添加的钩子。
返回类型:torch.utils.hooks.RemovableHandle
save_binary(package, resource, binary)
将原始字节保存到包中。
参数
package (str)
– 该资源所属的模块包名称(例如"my_package.my_subpackage"
)。resource (str)
– 资源的唯一名称,用于加载时识别。binary (str)
– 要保存的数据。
save_module(module_name, dependencies=True)
将 module
的代码保存到包中。模块代码的解析过程是:先通过 importers
路径查找模块对象,然后利用其 __file__
属性定位源代码。
参数
module_name (str)
– 例如my_package.my_subpackage
,代码将被保存以提供该包的实现代码。dependencies ([bool], 可选)
– 若设为True
,则会扫描源代码中的依赖项。
save_pickle(package, resource, obj, dependencies=True, pickle_protocol=3)
使用pickle将Python对象保存到归档文件中。功能等同于torch.save()
,但会保存到归档而非独立文件。标准pickle不会保存代码,仅保存对象。
如果dependencies
参数为True,此方法还会扫描被pickle的对象,识别重建它们所需的模块,并保存相关代码。
要保存type(obj).__name__
为my_module.MyObject
的对象时,my_module.MyObject
必须能根据importer
顺序解析为对象的类。当保存先前已打包的对象时,importer
列表中必须包含import_module
方法才能正常工作。
参数
package (str)
- 该资源所属的模块包名称(例如"my_package.my_subpackage"
)。resource (str)
- 资源的唯一名称,用于加载时识别。obj (Any)
- 要保存的对象,必须可被pickle序列化。dependencies ([bool], 可选)
- 若为True
,则会扫描源代码中的依赖项。
save_source_file(module_name, file_or_directory, dependencies=True)
将本地文件系统中的 file_or_directory
添加到源码包中,为 module_name
提供代码。
参数
module_name (str)
– 例如"my_package.my_subpackage"
,代码将被保存以提供该包的代码。file_or_directory (str)
– 代码文件或目录的路径。如果是目录,将递归复制该目录中的所有 Python 文件,使用save_source_file()
。如果文件名为"/__init__.py"
,则该代码被视为一个包。dependencies ([bool], 可选)
– 如果为True
,则会扫描源码中的依赖项。
save_source_string(module_name, src, is_package=False, dependencies=True)
将src
作为导出包中module_name
的源代码添加。
参数
module_name (str)
– 例如my_package.my_subpackage
,代码将被保存以提供该包的源代码。src (str)
– 要为此包保存的Python源代码。is_package ([bool], 可选)
– 如果为True
,则将此模块视为包。包允许包含子模块(例如my_package.my_subpackage.my_subsubpackage
),并且可以在其中保存资源。默认为False
。dependencies ([bool], 可选)
– 如果为True
,则会扫描源代码中的依赖项。
save_text(package, resource, text)
将文本数据保存到包中。
参数
package (str)
– 该资源所属模块包的名称(例如"my_package.my_subpackage"
)。resource (str)
– 资源的唯一名称,用于加载时识别。text (str)
– 要保存的内容。
class torch.package.PackageImporter(file_or_buffer, module_allowed=<function PackageImporter.<lambda>>)
导入器(Importers)允许您加载由PackageExporter
写入包的代码。
代码以密封方式加载,使用包中的文件而非常规Python导入系统。这使得PyTorch模型代码和数据可以被打包,从而能在服务器上运行或用于未来的迁移学习。
包导入器确保模块中的代码只能从包内部加载,除非是导出时明确列为外部的模块。
zip存档中的extern_modules
文件列出了包外部依赖的所有模块。
这避免了"隐式"依赖问题——即包在本地运行时能正常工作(因为导入了本地安装的包),但当包被复制到其他机器时就会失败。
__init__(file_or_buffer, module_allowed=<function PackageImporter.<lambda>>)
打开 file_or_buffer
以进行导入操作。此操作会检查导入的包是否仅依赖 module_allowed
所允许的模块。
参数
file_or_buffer ( Union [str,* PathLike[str],* [IO](https://docs.python.org/3/library/typing.html#typing.IO "(in Python v3.13)")[bytes ], PyTorchFileReader])
– 类文件对象(需实现read()
、readline()
、tell()
和seek()
方法)、字符串或包含文件名的os.PathLike
对象。module_allowed (Callable[[str],* [bool]], optional)
– 用于判断是否应允许外部提供模块的方法。可用于确保加载的包不依赖服务器不支持的模块。默认允许所有模块。
抛出异常
ImportError` – 如果包尝试使用被禁止的模块。
file_structure(*, include='**', exclude=())
返回包 zipfile 的文件结构表示。
参数
include (Union[List[str],* str])
– 可选字符串(如"my_package.my_subpackage"
)或可选字符串列表,用于指定要包含在 zipfile 表示中的文件名。也可以是 glob 风格的模式,如PackageExporter.mock()
中所述。exclude (Union[List[str],* str])
– 可选模式,用于排除名称匹配该模式的文件。
返回
返回类型:Directory
id()
返回 torch.package 用于区分 PackageImporter
实例的内部标识符。
格式类似:
<torch_package_0>
import_module(name, package=None)
如果模块尚未加载,则从包中加载该模块并返回。模块会被加载到导入者的本地命名空间,并出现在 self.modules
中而非 sys.modules
里。
参数
name (str)
– 要加载模块的完整限定名。package ([type], 可选)
– 未使用,但为了与importlib.import_module
的函数签名保持一致而保留。默认为None
。
返回值
(可能已加载的)模块对象。
返回类型
load_binary(package, resource)
加载原始字节数据。
参数
package (str)
– 模块包的名称(例如"my_package.my_subpackage"
)。resource (str)
– 资源的唯一名称。
返回值:已加载的数据。
返回类型:bytes
load_pickle(package, resource, map_location=None)
从包中反序列化资源,加载构造对象所需的所有模块
使用 import_module()
。
参数
package (str)
– 模块包的名称(例如"my_package.my_subpackage"
)。resource (str)
– 资源的唯一名称。map_location
– 传递给 torch.load 以确定张量如何映射到设备。默认为None
。
返回
反序列化后的对象。
返回类型:任意
load_text(package, resource, encoding='utf-8', errors='strict')
加载字符串。
参数
package (str)
– 模块包的名称(例如"my_package.my_subpackage"
)。resource (str)
– 资源的唯一名称。encoding (str, 可选)
– 传递给decode
。默认为'utf-8'
。errors (str, 可选)
– 传递给decode
。默认为'strict'
。
返回值:加载的文本。
返回类型:str
python_version()
返回用于创建此软件包的 Python 版本。
注意:此功能为实验性质,不具备向前兼容性。计划后续将其迁移至锁文件中。
返回值:Optional[str]
返回 Python 版本号(例如 3.8.9),如果该软件包未存储版本信息则返回 None
class torch.package.Directory(name, is_dir)
一种文件结构表示形式。它以目录节点形式组织,每个节点包含其子目录列表。通过调用 PackageImporter.file_structure()
可为包创建目录结构。
has_file(filename)
检查文件是否存在于 Directory
中。
参数
filename (str)
- 要搜索的文件路径。
返回值
如果 Directory
包含指定文件则返回 True。
返回类型:bool
torch.profiler
概述
PyTorch Profiler 是一款用于在训练和推理过程中收集性能指标的工具。通过其上下文管理器 API,开发者可以深入分析模型中最耗时的算子、检查输入张量形状和调用堆栈、研究设备内核活动,并可视化执行轨迹。
注意: torch.autograd
模块中的旧版 API 已被视为遗留接口,未来将被弃用。
API 参考
class torch.profiler._KinetoProfile(*, activities=None, record_shapes=False, profile_memory=False, with_stack=False, with_flops=False, with_modules=False, experimental_config=None, execution_trace_observer=None, acc_events=False, custom_trace_id_callback=None)
底层分析器封装了自动梯度分析功能
参数
activities (iterable)
– 用于分析的活跃组列表(CPU、CUDA),支持以下值:
torch.profiler.ProfilerActivity.CPU
, torch.profiler.ProfilerActivity.CUDA
, torch.profiler.ProfilerActivity.XPU
.
默认值:ProfilerActivity.CPU 和(可用时)ProfilerActivity.CUDA 或(可用时)ProfilerActivity.XPU。
record_shapes ([bool])
– 保存算子输入形状信息。profile_memory ([bool])
– 跟踪张量内存分配/释放(详见export_memory_timeline
)。with_stack ([bool])
– 记录算子的源码信息(文件和行号)。with_flops ([bool])
– 使用公式估算特定算子(矩阵乘法和2D卷积)的浮点运算次数。with_modules ([bool])
– 记录与算子调用栈对应的模块层次结构(包括函数名)。
例如:如果模块A的forward调用模块B的forward(其中包含aten::add算子),那么aten::add的模块层次结构就是A.B
注意:目前该功能仅支持TorchScript模型,不支持eager模式模型。
experimental_config (_ExperimentalConfig)
– 一组实验性选项,供Kineto等分析器库使用。注意不保证向后兼容性。execution_trace_observer (ExecutionTraceObserver)
– PyTorch执行轨迹观察器对象。
PyTorch执行轨迹提供了基于图的AI/ML工作负载表示,支持回放基准测试、模拟器和仿真器。
当包含此参数时,观察器的start()和stop()方法将与PyTorch分析器在同一时间窗口被调用。
acc_events ([bool])
– 启用跨多个分析周期的FunctionEvents累积功能
注意:此API为实验性质,未来可能变更。
启用形状和调用栈跟踪会导致额外开销。
当指定record_shapes=True时,分析器会临时持有张量的引用,这可能会阻止某些依赖引用计数的优化,并引入额外的张量拷贝。
add_metadata(key, value)
向跟踪文件中添加用户定义的元数据,包含字符串键和字符串值
add_metadata_json(key, value)
向跟踪文件中添加用户自定义的元数据,包含字符串键和有效的JSON值
events()
返回未聚合的性能分析事件列表,可用于跟踪回调或在性能分析结束后使用
export_chrome_trace(path)
以 Chrome JSON 格式导出收集的跟踪数据。如果启用了 kineto,则仅导出调度中的最后一个周期。
export_memory_timeline(path, device=None)
从分析器收集的内存事件信息树中导出指定设备的数据,并生成时间线图表。使用export_memory_timeline
可导出三种文件格式,通过path
参数的后缀控制:
- 如需生成HTML兼容的图表,使用
.html
后缀。内存时间线图表将以PNG格式嵌入HTML文件中。 - 如需导出由
[时间戳, [按类别划分的内存大小]]
组成的数据点(其中times
是时间戳,sizes
是每个类别的内存使用量),根据后缀选择保存为JSON(.json
)或gzip压缩的JSON(.json.gz
)。 - 如需原始内存事件数据,使用
.raw.json.gz
后缀。每个原始内存事件将包含(时间戳, 操作类型, 字节数, 类别)
,其中:action
是[PREEXISTING, CREATE, INCREMENT_VERSION, DESTROY]
之一category
来自torch.profiler._memory_profiler.Category
枚举
输出:内存时间线数据将以gzip压缩JSON、JSON或HTML格式写入。
export_stacks(path, metric='self_cpu_time_total')
将堆栈跟踪保存到文件
参数
path (str)
- 将堆栈文件保存到此路径;metric (str)
- 使用的指标:“self_cpu_time_total” 或 “self_cuda_time_total”
key_averages(group_by_input_shape=False, group_by_stack_n=0, group_by_overload_name=False)
Averages events, grouping them by operator name and (optionally) input shapes, stack and overload name.
Note: To use shape/stack functionality make sure to set record_shapes/with_stack
when creating profiler context manager.
preset_metadata_json(key, value)
在性能分析器未启动时预设用户自定义元数据,该元数据后续会被添加到跟踪文件中。
元数据格式为字符串键与有效JSON值的组合
toggle_collection_dynamic(enable, activities)
功能说明
可在收集过程中的任意时间点开启/关闭活动收集功能。当前支持切换 Torch 算子(CPU)以及 Kineto 中支持的 CUDA 活动。
参数说明
activities (iterable)
– 用于性能分析的活动组列表,支持以下取值:torch.profiler.ProfilerActivity.CPU
torch.profiler.ProfilerActivity.CUDA
使用示例
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
]
) as p:
code_to_profile_0()
// turn off collection of all CUDA activity
p.toggle_collection_dynamic(False, [torch.profiler.ProfilerActivity.CUDA])
code_to_profile_1()
// turn on collection of all CUDA activity
p.toggle_collection_dynamic(True, [torch.profiler.ProfilerActivity.CUDA])
code_to_profile_2()
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
class torch.profiler.profile(*, activities=None, schedule=None, on_trace_ready=None, record_shapes=False, profile_memory=False, with_stack=False, with_flops=False, with_modules=False, experimental_config=None, execution_trace_observer=None, acc_events=False, use_cuda=None, custom_trace_id_callback=None)
分析器上下文管理器。
参数
activities (iterable)
– 用于分析的活跃组列表(CPU、CUDA),支持以下值:
torch.profiler.ProfilerActivity.CPU
、torch.profiler.ProfilerActivity.CUDA
、torch.profiler.ProfilerActivity.XPU
。
默认值:ProfilerActivity.CPU 和(如果可用)ProfilerActivity.CUDA 或(如果可用)ProfilerActivity.XPU。
schedule (Callable)
– 可调用对象,接收步骤(int)作为单一参数,并返回
ProfilerAction
值,指定在每一步执行的分析器操作。
on_trace_ready (Callable)
– 当schedule
在分析过程中返回ProfilerAction.RECORD_AND_SAVE
时,每一步调用的可调用对象。record_shapes ([bool])
– 保存操作符输入形状的信息。profile_memory ([bool])
– 跟踪张量内存分配/释放。with_stack ([bool])
– 记录操作符的源信息(文件和行号)。with_flops ([bool])
– 使用公式估算特定操作符(矩阵乘法和2D卷积)的浮点运算次数(FLOPs)。with_modules ([bool])
– 记录与操作符调用堆栈对应的模块层次结构(包括函数名称)。例如,如果模块A的前向调用模块B的前向,其中包含一个aten::add操作符,那么aten::add的模块层次结构是A.B。
注意,目前此功能仅支持TorchScript模型,不支持eager模式模型。
experimental_config (_ExperimentalConfig)
– 一组用于Kineto库功能的实验性选项。注意,不保证向后兼容性。execution_trace_observer (ExecutionTraceObserver)
– 一个PyTorch执行跟踪观察器对象。
PyTorch执行跟踪 提供基于图的AI/ML工作负载表示,并支持重放基准测试、模拟器和仿真器。
当包含此参数时,观察器的start()和stop()将在与PyTorch分析器相同的时间窗口内调用。请参阅下面的示例部分获取代码示例。
acc_events ([bool])
– 启用跨多个分析周期的FunctionEvents累积。use_cuda ([bool])
–
自版本1.8.1起弃用:改用 activities
。
注意:使用 schedule()
生成可调用的计划。
非默认计划在分析长时间训练作业时非常有用,允许用户在训练过程的不同迭代中获取多个跟踪。
默认计划只是在上下文管理器持续时间内连续记录所有事件。
注意:使用 tensorboard_trace_handler()
生成TensorBoard的结果文件:
on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)
分析完成后,结果文件可以在指定目录中找到。使用命令:
tensorboard --logdir dir_name
在TensorBoard中查看结果。
更多信息,请参阅 PyTorch Profiler TensorBoard Plugin
注意:启用形状和堆栈跟踪会导致额外的开销。
当指定record_shapes=True时,分析器将临时持有对张量的引用;这可能会进一步阻止某些依赖于引用计数的优化,并引入额外的张量副本。
示例:
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ]
) as p:
code_to_profile()
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
使用性能分析器的 schedule
、on_trace_ready
和 step
函数:
# Non-default profiler schedule allows user to turn profiler on and off
# on different iterations of the training loop;
# trace_handler is called every time a new trace becomes available
def trace_handler(prof):
print(prof.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
# prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ],
# In this example with wait=1, warmup=1, active=2, repeat=1, # profiler will skip the first step/iteration, # start warming up on the second, record
# the third and the forth iterations, # after which the trace will become available
# and on_trace_ready (when set) is called;
# the cycle repeats starting with the next step
schedule=torch.profiler.schedule(
wait=1, warmup=1, active=2, repeat=1), on_trace_ready=trace_handler
# on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
# used when outputting for tensorboard
) as p:
for iter in range(N):
code_iteration_to_profile(iter)
# send a signal to the profiler that the next iteration has started
p.step()
以下示例展示了如何设置执行跟踪观察器(execution_trace_observer)
with torch.profiler.profile(
...
execution_trace_observer=(
ExecutionTraceObserver().register_callback("./execution_trace.json")
), ) as p:
for iter in range(N):
code_iteration_to_profile(iter)
p.step()
你也可以参考 tests/profiler/test_profiler.py
中的 test_execution_trace_with_kineto()
方法。
注意:任何实现了 _ITraceObserver
接口的对象都可以传入使用。
get_trace_id()
返回当前跟踪ID。
set_custom_trace_id_callback(callback)
设置当生成新跟踪ID时要调用的回调函数。
step()
通知性能分析器下一个分析步骤已开始。
class torch.profiler.ProfilerAction(value)
可在指定时间间隔执行的性能分析器操作
class torch.profiler.ProfilerActivity
Members:
CPU
XPU
MTIA
CUDA
HPU
PrivateUse1
property name
torch.profiler.schedule(*, wait, warmup, active, repeat=0, skip_first=0, skip_first_wait=0)
返回一个可调用对象,可作为分析器的schedule
参数使用。该分析器会跳过前skip_first
个步骤,然后等待wait
个步骤,接着进行warmup
个步骤的热身,随后执行active
个步骤的活跃记录,最后从wait
步骤开始重复这个循环。
通过repeat
参数可指定可选的循环次数,值为零表示循环将持续到分析结束。
skip_first_wait
参数控制是否跳过第一个wait
阶段。
当用户希望在循环之间等待时间超过skip_first
但首次分析不适用时,这个功能很有用。例如,若skip_first
为10且wait
为20,当skip_first_wait
为零时,第一个循环将在热身前等待10+20=30个步骤;若skip_first_wait
非零,则仅等待10个步骤。之后所有循环都会在最后一次活跃记录和热身之间等待20个步骤。
返回类型:Callable
torch.profiler.tensorboard_trace_handler(dir_name, worker_name=None, use_gzip=False)
将跟踪文件输出到 dir_name
目录,该目录可直接作为日志目录传递给 TensorBoard。
在分布式场景中,每个工作节点的 worker_name
应保持唯一,默认会设置为 ‘[hostname]_[pid]’。
Intel 插桩与追踪技术 API
torch.profiler.itt.is_available()
检查 ITT 功能是否可用
torch.profiler.itt.mark(msg)
描述某个时间点发生的瞬时事件。
参数
msg (str)
– 与该事件关联的ASCII消息。
torch.profiler.itt.range_push(msg)
将一段范围压入嵌套范围跨度栈。返回所启动范围的从零开始的深度。
参数
msg (str)
– 与该范围关联的ASCII消息
torch.profiler.itt.range_pop()
从嵌套范围跨度栈中弹出一个范围。返回被结束范围的从零开始的深度。
2025-05-10(六)