torch.full_like()
是 PyTorch 中的一个张量创建函数,用于创建一个与输入张量形状相同但所有元素值都填充为指定标量值的新张量。下面详细讲解其用法和特性:
1. 函数签名
torch.full_like(input, fill_value, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format)
2. 参数说明
- input (Tensor)
输入张量,新张量将继承其形状(shape)。
不修改原张量,仅参考其形状。
- fill_value (标量)
- 新张量中所有元素的值(填充值)。
- 关键字参数(可选):
dtype (torch.dtype):新张量的数据类型。默认与 input 相同。
device (torch.device):新张量所在的设备(CPU/GPU)。默认与 input 相同。
requires_grad (bool):是否需要梯度(用于自动求导)。默认 False。
layout (torch.layout):内存布局(如 torch.strided)。默认与 input 相同。
memory_format (torch.memory_format):内存格式。默认 torch.preserve_format(保留输入格式)。
3. 核心特性
形状继承:新张量形状与 input 完全一致。
全同填充:所有元素值均为 fill_value。
灵活性:可通过关键字参数覆盖输入张量的属性(如数据类型、设备等)。
示例代码
import torch
# 示例1:基础用法
x = torch.tensor([[1, 2], [3, 4]]) # 形状 (2, 2)
y = torch.full_like(x, fill_value=5)
print(y)
# 输出:
# tensor([[5, 5],
# [5, 5]])
# 示例2:指定数据类型(覆盖输入类型)
z = torch.full_like(x, fill_value=3.14, dtype=torch.float32)
print(z)
# 输出:
# tensor([[3.1400, 3.1400],
# [3.1400, 3.1400]])
# 示例3:改变设备(如GPU)
if torch.cuda.is_available():
device = torch.device("cuda")
x_gpu = x.to(device)
w = torch.full_like(x_gpu, fill_value=10, dtype=torch.float16)
print(w.device) # 输出: cuda:0
# 示例4:创建需要梯度的张量
v = torch.full_like(x, fill_value=2.0, dtype=torch.float32, requires_grad=True)
print(v.requires_grad) # 输出: True
4. 与相关函数的对比
函数 | 描述 | 区别 |
---|---|---|
torch.full_like(input, fill_value) | 按输入张量形状填充 | 形状来自 input |
torch.full(size, fill_value) | 直接指定形状填充 | 需手动设置 size |
torch.ones_like(input) | 创建全1张量 | 固定填充值 1 |
torch.zeros_like(input) | 创建全0张量 | 固定填充值 0 |
5. 典型应用场景
5.1 初始化固定值张量
快速创建与现有张量形状相同的常量张量(如掩码、偏置)。
mask = torch.full_like(image_tensor, fill_value=0) # 创建全0掩码
5.2 指定数据类型/设备
在保持形状的同时转换数据类型或设备。
# 在GPU上创建与x形状相同的全1张量(float32类型)
ones_gpu = torch.full_like(x, 1.0, dtype=torch.float32, device="cuda")
5.3 梯度计算准备
创建需要梯度跟踪的常量张量。
learnable_bias = torch.full_like(output, 0.1, requires_grad=True)
6. 注意事项
6.1 数据类型一致性:
如果 fill_value 与 dtype 不兼容(如用整数填充浮点类型),PyTorch 会自动转换:
a = torch.full_like(x, 3.14, dtype=torch.int) # 值会被截断为整数3
6.2 内存格式:
默认 memory_format=torch.preserve_format 会继承输入格式,如需特定格式(如通道优先),需显式指定:
nchw_tensor = torch.full_like(input, 0, memory_format=torch.contiguous_format)
6.3 无原地操作:
此函数总是返回新张量,不修改输入张量。
通过 torch.full_like(),您可以高效创建与现有张量形状匹配的常量张量,同时灵活控制数据类型、设备等属性,非常适合深度学习中的张量初始化操作。