错误代码示例
// Size of the two trailing (square) dimensions. The loops below index those
// dimensions with i / j, so their bounds must match this size exactly.
// NOTE(review): in the original code this was presumably dim_x * dim_y —
// TODO confirm and set accordingly.
constexpr int64_t kMatDim = 2;

// Build a {2, 2, kMatDim, kMatDim} tensor whose trailing kMatDim x kMatDim
// matrices are identities.
torch::Tensor multi_dim_identity = torch::zeros({ 2, 2, kMatDim, kMatDim }, torch::kComplexDouble);
for (int64_t i = 0; i < kMatDim; ++i) {
    multi_dim_identity.index_put_({ torch::indexing::Slice(), torch::indexing::Slice(), i, i }, 1);
}

// BUG (intentional, this is the "wrong" example): copying a torch::Tensor
// copies only the handle, not the data. all_Kx, all_Ky and multi_dim_identity
// now all refer to the same underlying tensor.
torch::Tensor all_Kx = multi_dim_identity;
torch::Tensor all_Ky = multi_dim_identity;

for (int64_t i = 0; i < 2; ++i) {
    torch::Tensor a = torch::zeros({ 2, kMatDim, kMatDim }, torch::kComplexDouble);
    torch::Tensor b = torch::rand({ 2, kMatDim, kMatDim }, torch::kComplexDouble);
    for (int64_t j = 0; j < kMatDim; ++j) {
        // Both writes hit the same shared tensor at the same positions, so the
        // write of b immediately overwrites the write of a. Afterwards
        // all_Kx == all_Ky, and multi_dim_identity has been modified as well.
        all_Kx.index_put_({ i, torch::indexing::Slice(), j, j }, a.index({ torch::indexing::Slice(), j, j }));
        all_Ky.index_put_({ i, torch::indexing::Slice(), j, j }, b.index({ torch::indexing::Slice(), j, j }));
    }
}
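结果:all_Kx 和 all_Ky 完全相同。在第一个维度(dim 0)的每个切片上,两者的对角元素都是同一份随机数 b,而 all_Kx 本应是 a 的 0。原因是 torch::Tensor 的拷贝构造/赋值只复制句柄(引用计数的浅拷贝),不复制数据,因此 all_Kx、all_Ky 和 multi_dim_identity 指向同一块底层存储。循环中先对 all_Kx 写入 a,实际写进的是这块共享存储。随后对 all_Ky 写入 b,又覆盖了同一位置。所以 all_Kx 也跟着变,最终与 all_Ky 一样,连 multi_dim_identity 本身也被修改了。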
正确代码如下
// Size of the two trailing (square) dimensions. The loops below index those
// dimensions with i / j, so their bounds must match this size exactly.
// NOTE(review): in the original code this was presumably dim_x * dim_y —
// TODO confirm and set accordingly.
constexpr int64_t kMatDim = 2;

// Build a {2, 2, kMatDim, kMatDim} tensor whose trailing kMatDim x kMatDim
// matrices are identities.
torch::Tensor multi_dim_identity = torch::zeros({ 2, 2, kMatDim, kMatDim }, torch::kComplexDouble);
for (int64_t i = 0; i < kMatDim; ++i) {
    multi_dim_identity.index_put_({ torch::indexing::Slice(), torch::indexing::Slice(), i, i }, 1);
}

// clone() allocates new storage and copies the data, so all_Kx and all_Ky are
// independent of each other and of multi_dim_identity. Plain assignment
// (all_Kx = multi_dim_identity) would only copy the handle and make all three
// share one tensor.
torch::Tensor all_Kx = multi_dim_identity.clone();
torch::Tensor all_Ky = multi_dim_identity.clone();

for (int64_t i = 0; i < 2; ++i) {
    torch::Tensor a = torch::zeros({ 2, kMatDim, kMatDim }, torch::kComplexDouble);
    torch::Tensor b = torch::rand({ 2, kMatDim, kMatDim }, torch::kComplexDouble);
    // Overwrite the diagonal entries of slice i: all_Kx gets a's diagonals,
    // all_Ky gets b's. Because the tensors are independent, neither write
    // affects the other.
    for (int64_t j = 0; j < kMatDim; ++j) {
        all_Kx.index_put_({ i, torch::indexing::Slice(), j, j }, a.index({ torch::indexing::Slice(), j, j }));
        all_Ky.index_put_({ i, torch::indexing::Slice(), j, j }, b.index({ torch::indexing::Slice(), j, j }));
    }
}
clone() 会分配新的存储并复制数据(深拷贝),因此 all_Kx、all_Ky 与 multi_dim_identity 三者互相独立。修改其中一个不会影响另外两个,all_Kx 和 all_Ky 也就不再相同。