PyTorch torch.topk

Published: 2024-12-09

torch
https://pytorch.org/docs/stable/torch.html

  • torch.topk (Python function, in torch.topk)
  • torch.Tensor.topk (Python method, in torch.Tensor.topk)

1. torch.topk

https://pytorch.org/docs/stable/generated/torch.topk.html

torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)

Returns the k largest elements of the given input tensor along a given dimension.

If dim is not given, the last dimension of the input is chosen.
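As a small illustration of the default, the sketch below compares omitting dim with passing dim=-1 and dim=0. The 2x3 tensor x and the variable names are made up for this example and do not come from the original post.

```python
import torch

# Hypothetical 2x3 input chosen only for illustration.
x = torch.tensor([[1., 5., 3.],
                  [9., 2., 7.]])

# Omitting dim searches along the last dimension (within each row here).
default_result = torch.topk(x, k=2)
last_dim_result = torch.topk(x, k=2, dim=-1)

# dim=0 searches down each column instead.
column_result = torch.topk(x, k=2, dim=0)

print(default_result.values)  # top 2 of each row
print(column_result.values)   # top 2 of each column
```

With a 1-D tensor the distinction disappears, because the only dimension is also the last one.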

If largest is False, the k smallest elements are returned instead.
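A minimal sketch of largest=False follows. The input tensor is an arbitrary example, not taken from the original post.

```python
import torch

# Arbitrary 1-D input for illustration.
x = torch.tensor([4., 1., 8., 3., 6.])

# Ask for the two smallest elements instead of the two largest.
smallest = torch.topk(x, k=2, largest=False)

print(smallest.values)   # tensor([1., 3.])
print(smallest.indices)  # tensor([1, 3])
```

With the default sorted=True, the smallest elements come back in ascending order.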

A namedtuple (values, indices) is returned. It holds the values and indices of the k largest elements (or the k smallest, if largest is False) of each row of the input tensor along the given dimension dim. The indices are positions in the original input along dim.
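The return value can be read by field name or unpacked like an ordinary tuple. The sketch below also uses torch.gather to check that the returned indices point back into the input. The tensor x is a made-up example.

```python
import torch

# Made-up 2x3 input for illustration.
x = torch.tensor([[0.1, 0.7, 0.2],
                  [0.6, 0.3, 0.9]])

result = torch.topk(x, k=2, dim=1)

# Access by field name ...
print(result.values)
print(result.indices)

# ... or unpack like a plain tuple.
values, indices = result

# The indices are positions along dim=1 of x, so gathering with them
# reproduces the returned values.
recovered = torch.gather(x, 1, result.indices)
print(torch.equal(recovered, result.values))  # True
```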

If the boolean option sorted is True, the returned k elements are themselves guaranteed to be in sorted order.
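The sketch below contrasts sorted=True with sorted=False on an arbitrary input. With sorted=False the same k elements are selected, but their order should not be relied on, so the code only compares them after sorting.

```python
import torch

# Arbitrary 1-D input for illustration.
x = torch.tensor([3., 9., 1., 7., 5.])

sorted_result = torch.topk(x, k=3, sorted=True)
unsorted_result = torch.topk(x, k=3, sorted=False)

print(sorted_result.values)  # tensor([9., 7., 5.])

# Same elements as above, but in an unspecified order.
print(unsorted_result.values)

# Sorting the unsorted result recovers the sorted one.
same_elements = sorted(unsorted_result.values.tolist(), reverse=True)
print(same_elements == sorted_result.values.tolist())  # True
```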

  • Parameters

input (Tensor) - the input tensor.

k (int) - the k in “top-k”, i.e. the number of elements to return along dim; it must not exceed the size of that dimension

dim (int, optional) - the dimension to sort along; defaults to the last dimension

largest (bool, optional) - controls whether to return the largest or the smallest elements

sorted (bool, optional) - controls whether to return the elements in sorted order

  • Keyword Arguments

out (tuple, optional) - the output tuple of (Tensor, LongTensor) that can be optionally given to be used as output buffers
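The following sketch shows two call styles related to the parameters above: the method form torch.Tensor.topk, and the out keyword with preallocated buffers. The buffer names values_buf and indices_buf are made up for this example, and the buffers are assumed to be allocated with the matching shape and dtype up front.

```python
import torch

# Arbitrary 1-D input for illustration.
x = torch.tensor([2., 8., 4., 6.])

# Method form: same result as torch.topk(x, 2).
method_result = x.topk(2)
print(method_result.values)   # tensor([8., 6.])
print(method_result.indices)  # tensor([1, 3])

# Preallocated output buffers (hypothetical names): a tensor with the
# input's dtype for the values, and a LongTensor for the indices.
values_buf = torch.empty(2)
indices_buf = torch.empty(2, dtype=torch.long)

# Results are written into the buffers instead of new tensors.
torch.topk(x, 2, out=(values_buf, indices_buf))
print(values_buf)
print(indices_buf)
```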

2. Example

(base) yongqiang@yongqiang:~$ python
Python 3.11.4 (main, Jul  5 2023, 13:45:01) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>>
>>> input = torch.arange(1., 9.)
>>> input
tensor([1., 2., 3., 4., 5., 6., 7., 8.])
>>>
>>> torch.topk(input, 3)
torch.return_types.topk(
values=tensor([8., 7., 6.]),
indices=tensor([7, 6, 5]))
>>>
>>> values, indices = torch.topk(input, 4)
>>> values
tensor([8., 7., 6., 5.])
>>> indices
tensor([7, 6, 5, 4])
>>> exit()
(base) yongqiang@yongqiang:~$

3. Example

https://github.com/karpathy/llama2.c/blob/master/model.py

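import torch

# Build a 1-D tensor of ten stand-in "logits": 1.0 through 10.0.
logits = torch.arange(1., 11.)
print("logits.shape:", logits.shape)
print("logits:\n", logits)

# Reshape to (1, 10), i.e. one row of ten scores.
logits = logits.view(1, 10)
print("\nlogits.shape:", logits.shape)
print("logits:\n", logits)

# k=1 along the last dimension returns the max value and its index per row.
values, indices = torch.topk(logits, k=1, dim=-1)
print("\nvalues:\n", values)
print("indices:\n", indices)

# Clamp k to the size of the last dimension so topk never asks for
# more elements than exist.
top_k = 5
print("\nlogits.size(-1):", logits.size(-1))
values, indices = torch.topk(logits, min(top_k, logits.size(-1)))
print("values:\n", values)
print("indices:\n", indices)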

/home/yongqiang/miniconda3/bin/python /home/yongqiang/llm_work/llama2.c/yongqiang.py 
logits.shape: torch.Size([10])
logits:
 tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])

logits.shape: torch.Size([1, 10])
logits:
 tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]])

values:
 tensor([[10.]])
indices:
 tensor([[9]])

logits.size(-1): 10
values:
 tensor([[10.,  9.,  8.,  7.,  6.]])
indices:
 tensor([[9, 8, 7, 6, 5]])

Process finished with exit code 0
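In sampling code such as the generate routine in the linked model.py, a topk call like the one above is typically followed by masking and sampling. The sketch below is one possible reading of that pattern, not a copy of the original file: the names filtered, probs and next_index are chosen for this example, and the fixed seed is only there to make the run repeatable.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # only to make this sketch repeatable

# Same stand-in logits as in the example above: shape (1, 10).
logits = torch.arange(1., 11.).view(1, 10)
top_k = 5

# Clamp k so it never exceeds the size of the last dimension.
k = min(top_k, logits.size(-1))
values, _ = torch.topk(logits, k)

# Mask every logit below the k-th largest one with -inf.
# values[:, [-1]] has shape (1, 1) and broadcasts across the row.
filtered = logits.clone()
filtered[filtered < values[:, [-1]]] = -float("inf")

# Softmax turns the masked logits into probabilities; masked entries get 0.
probs = F.softmax(filtered, dim=-1)

# Sample one index from the remaining top-k candidates.
next_index = torch.multinomial(probs, num_samples=1)

print(probs)
print(next_index)
```

Because the five smallest logits are masked out, the sampled index always falls among the five largest entries (positions 5 to 9 here).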

References

[1] Yongqiang Cheng, https://yongqiang.blog.csdn.net/

