c++ libtorch tensor 注意浅拷贝

发布于:2024-10-12 ⋅ 阅读:(89) ⋅ 点赞:(0)

错误代码示例

torch::Tensor multi_dim_identity = torch::zeros({ 2, 2, 2, 2 }, torch::kComplexDouble);
for (int i = 0; i < 2; ++i) {
    multi_dim_identity.index_put_({ torch::indexing::Slice(), torch::indexing::Slice(), i, i }, 1);
}

torch::Tensor all_Kx = multi_dim_identity;
torch::Tensor all_Ky = multi_dim_identity;

for (int i = 0; i < 2; ++i) {
   torch::Tensor a = torch::zeros({ 2, 2, 2 }, torch::kComplexDouble);
   torch::Tensor b = torch::rand({ 2, 2, 2 }, torch::kComplexDouble);
   for (int j = 0; j < dim_x * dim_y; ++j) {
       all_Kx.index_put_({ i, torch::indexing::Slice(), j, j }, a.index({torch::indexing::Slice(), j, j}));
       all_Ky.index_put_({ i, torch::indexing::Slice(), j, j }, b.index({torch::indexing::Slice(), j, j}));
}

}

结果 all_Kx和all_Ky一样,在每个第1维度上都是一样的随机b,因为all_Kx和all_Ky都是multi_dim_identity的浅拷贝,all_Kx先赋值,其实是赋值给了multi_dim_identity,然后all_Ky再赋值,其实是赋值给了multi_dim_identity,导致all_Kx也跟着变,所以和all_Ky一样

正确代码如下

torch::Tensor multi_dim_identity = torch::zeros({ 2, 2, 2, 2 }, torch::kComplexDouble);
for (int i = 0; i < 2; ++i) {
    multi_dim_identity.index_put_({ torch::indexing::Slice(), torch::indexing::Slice(), i, i }, 1);
}

torch::Tensor all_Kx = multi_dim_identity.clone();
torch::Tensor all_Ky = multi_dim_identity.clone();

for (int i = 0; i < 2; ++i) {
   torch::Tensor a = torch::zeros({ 2, 2, 2 }, torch::kComplexDouble);
   torch::Tensor b = torch::rand({ 2, 2, 2 }, torch::kComplexDouble);
   for (int j = 0; j < dim_x * dim_y; ++j) {
       all_Kx.index_put_({ i, torch::indexing::Slice(), j, j }, a.index({torch::indexing::Slice(), j, j}));
       all_Ky.index_put_({ i, torch::indexing::Slice(), j, j }, b.index({torch::indexing::Slice(), j, j}));
}

}

用clone方法可以深拷贝,这样all_Kx和all_Ky就不一样


网站公告

今日签到

点亮在社区的每一天
去签到