由于网络中用到了torch.pairwise_distance(),发现在不同版本的pytorch下,计算结果不一致的情况。下面就来解决这个问题。
一、torch.nn. PairwiseDistance ( p = 2.0 , eps = 1e-06 , keepdim = False )用法
首先,了解一下pytorch里 torch.pairwise_distance()的用法。
其中,各项参数如下:
p ( real , optional ) – 范数度。可以为负。默认值:2
eps ( float , optional ) – 小值以避免被零除。默认值:1e-6
keepdim ( bool , optional ) – 确定是否保持向量维度。默认值:False
按照其上意思,对于x1(B,C,H,W)和x2(B,C,H,W)进行欧式距离计算,如下:
import torch.nn.functional as F
dist = F.pairwise_distance(x1, x2, keepdim=True)
则可获得dist(B,1,H,W)。
在pytorch1.8.0中进行如上计算,结果完全正确。
然而,在pytorch1.13.0中进行如上计算,则发现结果的size乱码了,得到dist(B,C,H,1)。蜜汁困惑啊!!!
后来,在这篇blog里发现,因为torch.pairwise_distance函数,会对最后一维进行展开,所以应该先把张量维度重构为(Batch,sizeA(B),sizeB(A),Channel),再进行计算即可。
因此,解决办法是,先将x1和x2进行维度变换,转换成(B,H,W,C)再进行欧式距离计算,即可。
x1=x1.contiguous().permute(0,2,3,1)
x2=x2.contiguous().permute(0,2,3,1)
文章评论