PyTorch torch.topk
torch: https://pytorch.org/docs/stable/torch.html
- torch.topk (Python function, in torch.topk)
- torch.Tensor.topk (Python method, in torch.Tensor.topk)
1. torch.topk
https://pytorch.org/docs/stable/generated/torch.topk.html
torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
Returns the k largest elements of the given input tensor along a given dimension.
If dim is not given, the last dimension of the input is chosen.
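A small sketch of how dim selects the axis. It uses a made-up 2 x 4 tensor x, chosen only for illustration. Omitting dim reduces along the last dimension, so topk works within each row. Passing dim=0 works within each column instead. The method form x.topk(...) corresponds to the torch.Tensor.topk entry listed at the top and should give the same result.

```python
import torch

# Hypothetical 2 x 4 input chosen so there are no ties.
x = torch.tensor([[1., 5., 3., 2.],
                  [9., 0., 7., 4.]])

# dim omitted -> last dimension: top-2 within each row.
row_vals, row_idx = torch.topk(x, 2)
# row_vals -> [[5, 3], [9, 7]], row_idx -> [[1, 2], [0, 2]]

# dim=0 -> top-1 within each column.
col_vals, col_idx = torch.topk(x, 1, dim=0)
# col_vals -> [[9, 5, 7, 4]], col_idx -> [[1, 0, 1, 1]]

# The Tensor method form gives the same result as the function form.
assert torch.equal(x.topk(2).values, row_vals)
```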
If largest is False, then the k smallest elements are returned.
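A minimal sketch of largest=False on a made-up 1-D tensor. With the default sorted=True, the smallest elements come back in ascending order. As a cross-check, the sketch compares the result against taking the largest elements of the negated input and then negating back. This is only an equivalence for illustration and not a claim about how PyTorch implements it.

```python
import torch

# Hypothetical input with distinct values (no ties).
x = torch.tensor([4., 1., 7., 3., 9.])

# largest=False: the k smallest elements, ascending because sorted=True.
small_vals, small_idx = torch.topk(x, 2, largest=False)
# small_vals -> [1., 3.], small_idx -> [1, 3]

# Same selection via the largest elements of -x, negated back.
neg_vals, neg_idx = torch.topk(-x, 2)
assert torch.equal(-neg_vals, small_vals)
assert torch.equal(neg_idx, small_idx)
```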
A namedtuple of (values, indices) is returned with the values and indices of the largest k elements of each row of the input tensor in the given dimension dim.
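A short sketch of working with the returned namedtuple. The 2 x 3 input is made up. The sketch shows that the result supports both tuple unpacking and attribute access. It also shows that the int64 indices can be passed to torch.gather along the same dim to recover the values.

```python
import torch

# Hypothetical 2 x 3 input with distinct values per row.
x = torch.tensor([[1., 7., 2.],
                  [5., 3., 9.]])

result = torch.topk(x, k=2, dim=1)

# Tuple unpacking and attribute access give the same tensors.
values, indices = result
assert torch.equal(result.values, values)
assert torch.equal(result.indices, indices)

# Indices are int64 (LongTensor) positions along dim=1.
# Gathering with them reproduces the values.
gathered = torch.gather(x, 1, indices)
assert torch.equal(gathered, values)
```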
If the boolean option sorted is True, the returned k elements are themselves sorted.
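A sketch contrasting sorted=True and sorted=False on random data, with a seed fixed only for reproducibility. With sorted=False, the same k elements are selected, but their order is not guaranteed. For that reason, the comparison sorts them first instead of assuming any particular order.

```python
import torch

torch.manual_seed(0)
x = torch.randn(1000)

sorted_vals, _ = torch.topk(x, 10, sorted=True)
unsorted_vals, unsorted_idx = torch.topk(x, 10, sorted=False)

# sorted=True: values are in descending order.
assert torch.all(sorted_vals[:-1] >= sorted_vals[1:])

# sorted=False: same set of values, order not guaranteed, so sort before comparing.
assert torch.equal(torch.sort(unsorted_vals, descending=True).values, sorted_vals)

# The indices still point at the returned values.
assert torch.equal(x[unsorted_idx], unsorted_vals)
```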
- Parameters
  - input (Tensor) - the input tensor.
  - k (int) - the k in "top-k"
  - dim (int, optional) - the dimension to sort along
  - largest (bool, optional) - controls whether to return largest or smallest elements
  - sorted (bool, optional) - controls whether to return the elements in sorted order
- Keyword Arguments
  - out (tuple, optional) - the output tuple of (Tensor, LongTensor) that can be optionally given to be used as output buffers
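A minimal sketch of the out keyword, assuming preallocated buffers whose shapes already match the result. The buffers are a float tensor for the values and an int64 (Long) tensor for the indices, matching the (Tensor, LongTensor) description above. The tensors and buffer names here are illustrative only.

```python
import torch

# Hypothetical input.
x = torch.tensor([3., 8., 1., 6.])

# Preallocated output buffers sized for k=2.
out_vals = torch.empty(2)                   # float32 values
out_idx = torch.empty(2, dtype=torch.long)  # int64 indices

# Results are written into the buffers.
torch.topk(x, 2, out=(out_vals, out_idx))
# out_vals -> [8., 6.], out_idx -> [1, 3]
```

Reusing buffers like this can avoid a fresh allocation on every call inside a loop. For one-off calls, the plain return value is simpler.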
2. Example
(base) yongqiang@yongqiang:~$ python
Python 3.11.4 (main, Jul 5 2023, 13:45:01) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>>
>>> input = torch.arange(1., 9.)
>>> input
tensor([1., 2., 3., 4., 5., 6., 7., 8.])
>>>
>>> torch.topk(input, 3)
torch.return_types.topk(
values=tensor([8., 7., 6.]),
indices=tensor([7, 6, 5]))
>>>
>>> values, indices = torch.topk(input, 4)
>>> values
tensor([8., 7., 6., 5.])
>>> indices
tensor([7, 6, 5, 4])
>>> exit()
(base) yongqiang@yongqiang:~$
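As a consistency check on the session above, torch.topk with the default sorted=True gives the same values and indices as the first k entries of a descending torch.sort, provided the values are distinct. This is a sketch for intuition only and does not describe how topk is implemented internally.

```python
import torch

# Same values as the REPL session: 1., 2., ..., 8.
x = torch.arange(1., 9.)

srt = torch.sort(x, descending=True)
for k in range(1, x.numel() + 1):
    top = torch.topk(x, k)
    # With distinct values, topk(k) equals the first k entries of a descending sort.
    assert torch.equal(top.values, srt.values[:k])
    assert torch.equal(top.indices, srt.indices[:k])
```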
3. Example
https://github.com/karpathy/llama2.c/blob/master/model.py
import torch
# Fake logits over a 10-token vocabulary
logits = torch.arange(1., 11.)
print("logits.shape:", logits.shape)
print("logits:\n", logits)
# Add a batch dimension: shape (1, 10), i.e. (batch, vocab)
logits = logits.view(1, 10)
print("\nlogits.shape:", logits.shape)
print("logits:\n", logits)
# k=1 along the last dim: the single most likely token and its index (greedy choice)
values, indices = torch.topk(logits, k=1, dim=-1)
print("\nvalues:\n", values)
print("indices:\n", indices)
# Clamp k to the vocabulary size before calling topk
top_k = 5
print("\nlogits.size(-1):", logits.size(-1))
values, indices = torch.topk(logits, min(top_k, logits.size(-1)))
print("values:\n", values)
print("indices:\n", indices)
/home/yongqiang/miniconda3/bin/python /home/yongqiang/llm_work/llama2.c/yongqiang.py
logits.shape: torch.Size([10])
logits:
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
logits.shape: torch.Size([1, 10])
logits:
tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]])
values:
tensor([[10.]])
indices:
tensor([[9]])
logits.size(-1): 10
values:
tensor([[10., 9., 8., 7., 6.]])
indices:
tensor([[9, 8, 7, 6, 5]])
Process finished with exit code 0
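The min(top_k, logits.size(-1)) guard matters because torch.topk raises an error when k is larger than the size of the chosen dimension. Below is a sketch of how such a call is typically used for top-k sampling. The logits outside the top k are masked to -inf, then softmax and torch.multinomial draw the next token. The helper top_k_filter is a hypothetical name, and the masking line follows the nanoGPT-style pattern. Treat it as an approximation and check the linked model.py for the exact code.

```python
import torch
import torch.nn.functional as F


def top_k_filter(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Hypothetical helper: keep the top_k logits per row, set the rest to -inf."""
    # Clamp k so topk never asks for more elements than the row has.
    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
    filtered = logits.clone()
    # v[:, [-1]] is the k-th largest value per row, shape (batch, 1),
    # so it broadcasts against (batch, vocab).
    filtered[filtered < v[:, [-1]]] = -float("inf")
    return filtered


logits = torch.arange(1., 11.).view(1, 10)  # same logits as the script above
filtered = top_k_filter(logits, top_k=5)
probs = F.softmax(filtered, dim=-1)  # exactly zero probability outside the top 5
torch.manual_seed(0)
next_idx = torch.multinomial(probs, num_samples=1)  # shape (1, 1), an index in 5..9
```

With top_k larger than the vocabulary, the clamp makes k equal to the row length. In that case nothing is masked.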
References
[1] Yongqiang Cheng, https://yongqiang.blog.csdn.net/