参考资料:
https://github.com/pytorch/pytorch/issues/73515
https://www.cnblogs.com/X1OO/articles/18171700
由于业务原因,需要在Pytorch代码中使用分布式通讯来把计算负载平均到多张显卡上。在无数次确认我的业务代码没问题之后,我开始把怀疑的对象转移到分布式通讯的问题上:
单卡推理的中间层输出
多卡推理的中间层输出
如上两图,在打印了中间层输出之后,我发现在After gather之后,多卡推理与单卡推理的中间变量的均值、最大值和最小值是完全一致的。但是紧邻的一个log却显示,它们各自在进入decode方法之后,值就随即发生了变化。这就不符合我的认知了,因为我可以完全保证从gather到decode没有任何对tensor做特殊处理的操作。并且在我的固有观念里,只要两个形状较大的tensor统计值类似,基本就可以保证两个tensor是一模一样的,那么问题到底出现在哪呢?
不信邪的我把After gather之后统计值相同的两个tensor都保存下载进行了分析,一分析我就傻眼了:
只见两个tensor统计值完全相同,甚至通过排序之后发现Tensor中的元素也似乎完全相同,但是这两个Tensor就是不一样的。在此检查了一下代码中没有对维度进行特殊操作之后,我把目光放到了我写的分布式gather函数里:
1 def _conv_gather_avg(input_, dim):
2 cp_world_size = get_context_parallel_world_size()
3 # Bypass the function if context parallel is 1
4 if cp_world_size == 1:
5 return input_.contiguous()
6
7 # input_ = input_.contiguous()
8
9 group = get_context_parallel_group()
10 cp_rank = get_context_parallel_rank()
11 tensor_list = [torch.empty_like(input_) for _ in range(cp_world_size)]
12 tensor_list[cp_rank] = input_
13 torch.distributed.all_gather(tensor_list, input_, group=group)
14 # Note: torch.cat already creates a contiguous tensor.
15 output = torch.cat(tensor_list, dim=dim).contiguous()
16 # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
17 return output
注意第七行一开始是没有的,这里的代码我是借鉴了其他人的,我发现很多地方都强调了contiguous这个方法,难道它真的有这么重要?于是抱着试一试的态度,我在第七行上加上了input_ = input_.contiguous(),然后神奇的事情就发生了,gather之后的tensor居然就能够精度对上了。总结一下问题就是如果在all_gather之前不对输入input_运行contiguous的话,会导致gather之后的tensor虽然值都一样,但是排列顺序完全混乱。下面引用参考资料讲一下为什么Pytorch分布式通讯中all_gather要求tensor连续。
工作太忙了没时间讲了... 有时间再补充