使用PyTorch求平方根报错如何解决?针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。
问题描述
初步使用PyTorch进行平方根计算,通过range)创建一个张量,然后对其求平方根。
a = torch.tensorlistrange9))) b = torch.sqrta)
报出以下错误:
RuntimeError: sqrt_vml_cpu not implemented for 'Long'
原因
Long类型的数据不支持log对数运算, 为什么Tensor是Long类型? 因为创建List数组时默认使用的是int, 所以从List转成torch.Tensor后, 数据类型变成了Long。
printa.dtype)
torch.int64
解决方法
提前将数据类型指定为浮点型, 重新执行:
b = torch.sqrta.totorch.double)) printb)
tensor[0.0000, 1.0000, 1.4142, 1.7321, 2.0000, 2.2361, 2.4495, 2.6458, 2.8284], dtype=torch.float64)
补充:pytorch20 pytorch常见运算详解
矩阵与标量
这个是矩阵(张量)每一个元素与标量进行操作。
import torch a = torch.tensor[1,2]) printa+1) >>> tensor[2, 3])
哈达玛积
这个就是两个相同尺寸的张量相乘,然后对应元素的相乘就是这个哈达玛积,也成为element wise。
a = torch.tensor[1,2]) b = torch.tensor[2,3]) printa*b) printtorch.mula,b)) >>> tensor[2, 6]) >>> tensor[2, 6])
这个torch.mul)和*是等价的。
当然,除法也是类似的:
a = torch.tensor[1.,2.]) b = torch.tensor[2.,3.]) printa/b) printtorch.diva/b)) >>> tensor[0.5000, 0.6667]) >>> tensor[0.5000, 0.6667])
我们可以发现的torch.div)其实就是/, 类似的:torch.add就是+,torch.sub)就是-,不过符号的运算更简单常用。
矩阵乘法
如果我们想实现线性代数中的矩阵相乘怎么办呢?
这样的操作有三个写法:
torch.mm)
torch.matmul)
@,这个需要记忆,不然遇到这个可能会挺蒙蔽的
a = torch.tensor[[1.],[2.]]) b = torch.tensor[2.,3.]).view1,2) printtorch.mma, b)) printtorch.matmula, b)) printa @ b)
这是对二维矩阵而言的,假如参与运算的是一个多维张量,那么只有torch.matmul)可以使用。等等,多维张量怎么进行矩阵的乘法?在多维张量中,参与矩阵运算的其实只有后两个维度,前面的维度其实就像是索引一样,举个例子:
a = torch.rand1,2,64,32)) b = torch.rand1,2,32,64)) printtorch.matmula, b).shape) >>> torch.Size[1, 2, 64, 64])
a = torch.rand3,2,64,32)) b = torch.rand1,2,32,64)) printtorch.matmula, b).shape) >>> torch.Size[3, 2, 64, 64])
这样也是可以相乘的,因为这里涉及一个自动传播Broadcasting机制,这个在后面会讲,这里就知道,如果这种情况下,会把b的第一维度复制3次 ,然后变成和a一样的尺寸,进行矩阵相乘。
幂与开方
print'幂运算') a = torch.tensor[1.,2.]) b = torch.tensor[2.,3.]) c1 = a ** b c2 = torch.powa, b) printc1,c2) >>> tensor[1., 8.]) tensor[1., 8.])
和上面一样,不多说了。开方运算可以用torch.sqrt),当然也可以用a**0.5)。
对数运算
在上学的时候,我们知道ln是以e为底的,但是在pytorch中,并不是这样。
pytorch中log是以e自然数为底数的,然后log2和log10才是以2和10为底数的运算。
import numpy as np print'对数运算') a = torch.tensor[2,10,np.e]) printtorch.loga)) printtorch.log2a)) printtorch.log10a)) >>> tensor[0.6931, 2.3026, 1.0000]) >>> tensor[1.0000, 3.3219, 1.4427]) >>> tensor[0.3010, 1.0000, 0.4343])
近似值运算
.ceil) 向上取整
.floor)向下取整
.trunc)取整数
.frac)取小数
.round)四舍五入
.ceil) 向上取整.floor)向下取整.trunc)取整数.frac)取小数.round)四舍五入
a = torch.tensor1.2345) printa.ceil)) >>>tensor2.) printa.floor)) >>> tensor1.) printa.trunc)) >>> tensor1.) printa.frac)) >>> tensor0.2345) printa.round)) >>> tensor1.)
剪裁运算
这个是让一个数,限制在你自己设置的一个范围内[min,max],小于min的话就被设置为min,大于max的话就被设置为max。这个操作在一些对抗生成网络中,好像是WGAN-GP,通过强行限制模型的参数的值。
a = torch.rand5) printa) printa.clamp0.3,0.7))
pytorch的优点
1.PyTorch是相当简洁且高效快速的框架;2.设计追求最少的封装;3.设计符合人类思维,它让用户尽可能地专注于实现自己的想法;4.与google的Tensorflow类似,FAIR的支持足以确保PyTorch获得持续的开发更新;5.PyTorch作者亲自维护的论坛 供用户交流和求教问题6.入门简单