unsqueeze() 方法与squeeze() 方法

发布于:2024-05-09 ⋅ 阅读:(34) ⋅ 点赞:(0)

unsqueeze() 方法在 PyTorch 中用于在指定的维度位置插入一个维度大小为 1 的新维度。

tips:

()内指定维度位置,‘0’表示第一个维度位置,以此类推‘1’ ‘2’ ‘3’.......

1.增加一个维度

import torch

# 创建一个形状为 [4] 的一维张量
x = torch.tensor([1, 2, 3, 4])

# 使用 unsqueeze 在第一个维度位置增加一个维度,结果形状变为 [1, 4]
x_unsqueezed = x.unsqueeze(0)

print(x.shape)
print(x)
print('*'*50)
print(x_unsqueezed.shape)
print(x_unsqueezed)

# 输出:
torch.Size([4])
tensor([1, 2, 3, 4])
**************************************************
torch.Size([1, 4])
tensor([[1, 2, 3, 4]])

2.在中间维度插入一个维度

# 创建一个形状为 [3, 4] 的二维张量
x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

# 使用 unsqueeze 在第二个维度位置增加一个维度,结果形状变为 [3, 1, 4]
x_unsqueezed = x.unsqueeze(1)

print(x.shape)
print(x)
print('*'*50)
print(x_unsqueezed.shape)
print(x_unsqueezed)

# 输出
torch.Size([3, 4])
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
**************************************************
torch.Size([3, 1, 4])
tensor([[[ 1,  2,  3,  4]],

        [[ 5,  6,  7,  8]],

        [[ 9, 10, 11, 12]]])

3.在特定位置插入多个维度

# 创建一个形状为 [2, 2] 的二维张量
x = torch.tensor([[1, 2],
                  [3, 4]])
# 使用 unsqueeze 在第一和第三维度位置各增加一个维度,结果形状变为 [1, 2, 1, 2]
x_unsqueezed = x.unsqueeze(0).unsqueeze(2)

print(x.shape)
print(x)
print('*'*50)
print(x_unsqueezed.shape)
print(x_unsqueezed)

# 输出:
torch.Size([2, 2])
tensor([[1, 2],
        [3, 4]])
**************************************************
torch.Size([1, 2, 1, 2])
tensor([[[[1, 2]],

         [[3, 4]]]])

4.使用 unsqueeze 进行广播

# 创建一个形状为 [4] 的一维张量
a = torch.tensor([1, 2, 3, 4])

# 创建一个形状为 [3, 4] 的二维张量
b = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

# 使用 unsqueeze 使 a 可以广播到 b 的形状
a_unsqueezed = a.unsqueeze(0)  # 形状变为 [1, 4]

# 现在 a_unsqueezed 和 b 可以进行广播操作
result = a_unsqueezed + b


print(a.shape)
print(a)
print(a_unsqueezed)
print('*'*50)
print(b.shape)
print(b)
print('*'*50)
print(result)

# 输出
torch.Size([4])
tensor([1, 2, 3, 4])
tensor([[1, 2, 3, 4]])
**************************************************
torch.Size([3, 4])
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
**************************************************
tensor([[ 2,  4,  6,  8],
        [ 6,  8, 10, 12],
        [10, 12, 14, 16]])

5.在多维张量中插入多个维度

# 创建一个形状为 [2, 3, 4] 的三维张量
x = torch.randn(2, 3, 4)

# 使用 unsqueeze 在第二和第四维度位置各增加一个维度,结果形状变为 [2, 1, 3, 1, 4]
x_unsqueezed = x.unsqueeze(1).unsqueeze(3)


print(x.shape)
print(x)
print('*'*50)
print(x_unsqueezed.shape)
print(x_unsqueezed)

# 输出
torch.Size([2, 3, 4])
tensor([[[ 0.7232, -1.1270, -0.3702, -0.6435],
         [ 1.2270, -0.6766,  1.0700, -1.4295],
         [-0.6011,  0.0285, -0.2584, -0.9866]],

        [[ 0.8701,  1.1882, -0.6923,  0.9238],
         [ 0.6200,  1.7528,  1.1101,  0.3141],
         [-0.7319,  0.1732, -0.5922,  0.4118]]])
**************************************************
torch.Size([2, 1, 3, 1, 4])
tensor([[[[[ 0.7232, -1.1270, -0.3702, -0.6435]],

          [[ 1.2270, -0.6766,  1.0700, -1.4295]],

          [[-0.6011,  0.0285, -0.2584, -0.9866]]]],



        [[[[ 0.8701,  1.1882, -0.6923,  0.9238]],

          [[ 0.6200,  1.7528,  1.1101,  0.3141]],

          [[-0.7319,  0.1732, -0.5922,  0.4118]]]]])

6.将标量转换为张量

import torch

# 创建一个标量值
scalar = 5

# 将标量转换为一个 PyTorch 张量
scalar_tensor = torch.tensor(scalar)

# 使用 unsqueeze 在第一个维度位置增加一个维度,结果形状变为 [1]
scalar_tensor_unsqueeze = scalar_tensor.unsqueeze(0)


print(type(scalar))
print(scalar_tensor.shape)
print(scalar_tensor)
print('*'*50)
print(scalar_tensor_unsqueeze.shape)
print(scalar_tensor_unsqueeze)

# 输出
<class 'int'>
torch.Size([])
tensor(5)
**************************************************
torch.Size([1])
tensor([5])

squeeze() 函数在 PyTorch 中用于从张量中移除所有长度为 1 的维度。这通常用于减少张量的维度,特别是在某些操作之后,其中某些维度可能只有一个元素,而这些单一元素的维度不再需要。

 1.移除单维度

import torch

# 创建一个形状为 [1, 3, 1] 的张量
x = torch.randn(1, 3, 1)

# 使用 squeeze() 移除单维度
y = x.squeeze()

# 打印 y 的形状
print(y.shape)  # 输出: torch.Size([3])

2.选择性移除单维度 

# 创建一个形状为 [2, 1, 4, 1] 的张量
x = torch.randn(2, 1, 4, 1)

# 使用 squeeze(1) 移除第二个维度
y = x.squeeze(1)

# 打印 y 的形状
print(y.shape)  # 输出: torch.Size([2, 4, 1])

 3.移除所有单维度

# 创建一个形状为 [1, 2, 1, 3, 1] 的张量
x = torch.randn(1, 2, 1, 3, 1)

# 使用 squeeze() 移除所有单维度
y = x.squeeze()

# 打印 y 的形状
print(y.shape)  # 输出: torch.Size([2, 3])