【Pytorch】topk函数

发布于:2024-08-11 ⋅ 阅读:(132) ⋅ 点赞:(0)

topk 是 PyTorch 中的一个函数,用于从张量中选取最大(或最小)的 k 个元素及其对应的索引。其定义如下:

values, indices = torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)

参数说明

  • input (Tensor): 输入张量。
  • k (int): 要选取的最大(或最小)元素的数量。
  • dim (int, 可选): 指定沿着哪个维度进行操作。默认为 None,此时沿着最后一个维度进行操作。
  • largest (bool, 可选): 如果为True,则选取最大的 k 个元素;如果为 False,则选取最小的 k 个元素。默认为 True。
  • sorted (bool, 可选): 如果为 True,则返回的值是排序过的(即最大的值排在前面)。如果为 False,则返回的值是按照它们在原张量中的顺序排列。默认为 True。
  • out (tuple, 可选): 可以指定一个元组来存储输出结果。元组应该包含两个张量,分别用于存储值和索引。默认为None。

代码片段赏析:

    def get_embedding_indices(self, points):
        r"""Compute the indices of pair-wise distance embedding and triplet-wise angular embedding.

        Args:
            points: torch.Tensor (B, N, 3), input point cloud

        Returns:
            d_indices: torch.FloatTensor (B, N, N), distance embedding indices
            a_indices: torch.FloatTensor (B, N, N, k), angular embedding indices
        """
        batch_size, num_point, _ = points.shape

        dist_map = torch.sqrt(pairwise_distance(points, points))  # (B, N, N)
        d_indices = dist_map / self.sigma_d

        k = self.angle_k
        ##! largest=False的含义是选择K个距离最小的点, dim=2代表从dist_map的第二个维度来选择
        knn_indices = dist_map.topk(k=k + 1, dim=2, largest=False)[1][:, :, 1:]  # 这里将 dist_map.topk会返回values和indices, 用索引1来取出indices后, 再从所有的knn中去掉自身, 所以取[:, :, 1:] 

网站公告

今日签到

点亮在社区的每一天
去签到