PyTorch学习(12):PyTorch取极值(max, argmax, min, argmin)

发布于:2024-06-07 ⋅ 阅读:(229) ⋅ 点赞:(0)

PyTorch学习(1):torch.meshgrid的使用-CSDN博客

PyTorch学习(2):torch.device-CSDN博客

PyTorch学习(9):torch.topk-CSDN博客

PyTorch学习(10):torch.where-CSDN博客

PyTorch学习(11):PyTorch的形状变换(view, reshape)与维度变换(transpose, permute)-CSDN博客


 


目录

1. 写在前面

2. max

3. argmax

4. min

5. argmin


1. 写在前面

        PyTorch提供了大量的API接口,帮助我们快速的搭建训练工程,开发新的算法。求极值或者获得极值所处的位置是算法种常见的操作,比如网络的后解算过程中,我们常通过求极值来获得某一个预测点的类别信息。PyTorch提供了max,argmax,min,argmin这四个接口帮助我们快速的获得极值信息。

2. max

        torch.max(input, dim)函数用于在一个张量的指定维度上找到最大值。它返回两个输出:第一个是找到的最大值,第二个是这些最大值的索引。其原型如下。

 torch.max(input, dim, keepdim=False, *, out=None)

其中,

        input是要取极值输入Tensor;

        dim用于指定维度,在dim维度上取极值;

        keepdim控制是否保持原维度信息;

        具体使用可参考如下示例程序。

import torch


# 创建一个3x4的张量

tensor = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])

# 在第0维(行)上找到最大值及其索引

max_values, max_indices = torch.max(tensor, 0)

print("Max values along dimension 0:", max_values)

print("Indices of max values along dimension 0:", max_indices)

# 在第1维(列)上找到最大值及其索引

max_values, max_indices = torch.max(tensor, 1)

print("Max values along dimension 1:", max_values)

print("Indices of max values along dimension 1:", max_indices)

3. argmax

torch.argmax(input, dim)函数返回张量中指定维度上最大值的索引。其原型如下。

torch.argmax(input, dim, keepdim=False)

其中,

input未输入tensor;

dim为指定操作维度;

keepdim控制是否保持原维度信息;

具体使用可参考如下程序。

# 在第0维(行)上找到最大值的索引

max_index = torch.argmax(tensor, 0)

print("Index of max value along dimension 0:", max_index)


# 在第1维(列)上找到最大值的索引

max_index = torch.argmax(tensor, 1)

print("Index of max value along dimension 1:", max_index)

4. min

torch.min(input, dim)函数与max函数类似,但它用于找到张量中指定维度上的最小值。其原型如下。

 torch.min(input, dim, keepdim=False, *, out=None)

其参数信息可参考torch.max。

使用方式可参考如下程序。

# 在第0维(行)上找到最小值及其索引

min_values, min_indices = torch.min(tensor, 0)

print("Min values along dimension 0:", min_values)

print("Indices of min values along dimension 0:", min_indices)


# 在第1维(列)上找到最小值及其索引

min_values, min_indices = torch.min(tensor, 1)

print("Min values along dimension 1:", min_values)

print("Indices of min values along dimension 1:", min_indices)

5. argmin

torch.argmin(input, dim)函数返回张量中指定维度上最小值的索引。原型如下。

torch.argmin(input, dim=None, keepdim=False)

参数可参考torch.argmax。

具体使用可参考如下程序。

# 在第0维(行)上找到最小值的索引

min_index = torch.argmin(tensor, 0)

print("Index of min value along dimension 0:", min_index)


# 在第1维(列)上找到最小值的索引

min_index = torch.argmin(tensor, 1)

print("Index of min value along dimension 1:", min_index)


网站公告

今日签到

点亮在社区的每一天
去签到