Skip to content

张量操作

  • PyTorch 的一个语法糖:dim=-1 表示最后一个维度,dim=-2 表示倒数第二个维度,以此类推

创建张量

torch.randn

torch.zeros

torch.ones

换维度:.T

转置

py
x = torch.randn(2, 3, 4)
xt = x.T

换维度:transpose

交换两个维度

Self-Attention、矩阵乘法前对齐维度

py
x = torch.randn(2, 3, 4)
x.transpose(1, 2).shape  # (2, 4, 3)

换维度:permute

任意重排所有维度

复杂换维,如 (B, H, T, D) → (B, T, H, D)

py
x = torch.randn(2, 3, 4)
x.permute(0, 2, 1).shape  # (2, 4, 3)

换维度:reshape/view

改变形状(不能调维顺序)

  • 拉平维度 (B, T, D) → (B*T, D)
  • 恢复维度 (B*T, D) → (B, T, D)
py
x = torch.randn(2, 3, 4)
x.reshape(6, 4).shape  # (6, 4)

广播(broadcasting)

PyTorch 的广播规则是:

如果两个张量的形状不同,从右往左比:

  • 如果某个维度相等,OK;
  • 如果其中一个维度是 1,会在该维度上广播(复制);
  • 否则报错。

数据表示

TODO