1.Pytorch计算公式
a,b为两个张量,且a.size=B,N,3),b.size)=B,M,3),计算a中各点到b中各点的距离,返回距离张量c,c.size)=B,N,M)。不考虑Batch时,可以将理解:c的第i行j列的值表示a中第i个点到b中第j个点的距离。
import torch def EuclideanDistancet1,t2): dim=lent1.size)) if dim==2: N,C=t1.size) M,_=t2.size) dist = -2 * torch.matmult1, t2.permute1, 0)) dist += torch.sumt1 ** 2, -1).viewN, 1) dist += torch.sumt2 ** 2, -1).view1, M) dist=torch.sqrtdist) return dist elif dim==3: B,N,_=t1.size) _,M,_=t2.size) dist = -2 * torch.matmult1, t2.permute0, 2, 1)) dist += torch.sumt1 ** 2, -1).viewB, N, 1) dist += torch.sumt2 ** 2, -1).viewB, 1, M) dist=torch.sqrtdist) return dist else: print'error...') printf'dimensional 2.......') a=torch.Tensor[[0,0],[1,1]]) b=torch.Tensor[[1,0],[3,4]]) printf'size of a:{a.size)}\tsize of b:{b.size)}') printf'distance of point a and b is: {EuclideanDistancea,b)}') printf'\ndimensional 3.......') a=torch.unsqueezea,dim=0) b=torch.unsqueezeb,dim=0) printf'size of a:{a.size)}\tsize of b:{b.size)}') printf'distance of point a and b is: {EuclideanDistancea,b)}')
2.代码理解
2.1定义待计算张量
现有张量a,b如下:
2.2距离公式
有距离公式如下:
2.3分步计算
(1)计算:
d1=-2 * torch.matmula, b.permute0, 2, 1))
(1)结果如下:
(2)计算:
d2=torch.suma** 2, -1)
d3=torch.sumb** 2, -1)
(2)结果如下:
当前有:d1.size=B,N,M),d2.size)=B,N,1),d3.size)=B,M,1)
可以看到d1中的i行中保持不变的部分为a中的第i个点,d1中第j列中不变的部分对应b中的j行。因此,只需在d1的行上加上一个d2的对应行,列上加d3的对应行即可。
(3)相加:
d=d1+d2.viewB,N,1)+d3.viewB,1,M)
(3)结果如下:
(4)开根
d=torch.sqrtd)