这篇文章主要为大家详细介绍了Python之pytorch中的matmul与mm,bmm区别说明的简单示例,具有一定的参考价值,可以用来参考一下。
感兴趣的小伙伴,下面一起跟随四海网的雯雯来看看吧!
pytorch中matmul和mm和bmm区别 matmulmmbmm结论
先看下官网上对这三个函数的介绍。
顾名思义, 就是两个batch矩阵乘法.
从官方文档可以看出
1、mm只能进行矩阵乘法,也就是输入的两个tensor维度只能是( n × m ) (n\times m)(n×m)和( m × p ) (m\times p)(m×p)
2、bmm是两个三维张量相乘, 两个输入tensor维度是( b × n × m ) (b\times n\times m)(b×n×m)和( b × m × p ) (b\times m\times p)(b×m×p), 第一维b代表batch size,输出为( b × n × p ) (b\times n \times p)(b×n×p)
3、matmul可以进行张量乘法, 输入可以是高维.
点乘都是broadcast的,可以用torch.mul(a, b)实现,也可以直接用*实现。
代码如下:
>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3]).reshape((3,1))
>>> b
tensor([[1.],
[2.],
[3.]])
>>> torch.mul(a, b)
tensor([[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.]])
pytorch中的matmul与mm,bmm区别说明
当a, b维度不一致时,会自动填充到相同维度相点乘。
矩阵相乘有torch.mm和torch.matmul两个函数。其中前一个是针对二维矩阵,后一个是高维。当torch.mm用于大于二维时将报错。
代码如下:
>>> a = torch.ones(3,4)
>>> b = torch.ones(4,2)
>>> torch.mm(a, b)
tensor([[4., 4.],
[4., 4.],
[4., 4.]])
pytorch中的matmul与mm,bmm区别说明
代码如下:
>>> a = torch.ones(3,4)
>>> b = torch.ones(5,4,2)
>>> torch.matmul(a, b).shape
torch.Size([5, 3, 2])
pytorch中的matmul与mm,bmm区别说明
代码如下:
>>> a = torch.ones(5,4,2)
>>> b = torch.ones(5,2,3)
>>> torch.matmul(a, b).shape
torch.Size([5, 4, 3])
pytorch中的matmul与mm,bmm区别说明
代码如下:
>>> a = torch.ones(5,4,2)
>>> b = torch.ones(5,2,3)
>>> torch.matmul(b, a).shape
报错。
pytorch中的matmul与mm,bmm区别说明
以上为个人经验,希望能给大家一个参考,也希望大家多多支持四海网。如有错误或未考虑完全的地方,望不吝赐教。
本文来自:http://www.q1010.com/181/18833-0.html
注:关于Python之pytorch中的matmul与mm,bmm区别说明的简单示例的内容就先介绍到这里,更多相关文章的可以留意四海网的其他信息。
关键词:python
四海网收集整理一些常用的php代码,JS代码,数据库mysql等技术文章。