Appearance
torch.squeeze
- 去批次
- 去通道:针对图片,神经网络需要通道维度,但是显示的时候不需要
torch.unsqueeze
- 新增批次
- 新增通道:灰度图 shape=(1, W, H)
torch.cat 和 torch.view
TODO
torch.reshape 详解
torch.reshape 是 PyTorch 中用于改变张量形状的重要函数,它可以在不改变数据总数的情况下,重新组织张量的维度。让我详细解释这个函数的工作原理。
基本概念
torch.reshape(input, shape) 接受两个参数:
input: 要重塑的张量shape: 目标形状,可以是一个整数或一个整数元组
基本用法示例
python
# 创建一个一维张量
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
print(x.shape) # 输出: torch.Size([12])
# 将一维张量重塑为 3x4 的二维张量
y = torch.reshape(x, (3, 4))
print(y)
# 输出:
# tensor([[ 1, 2, 3, 4],
# [ 5, 6, 7, 8],
# [ 9, 10, 11, 12]])
print(y.shape) # 输出: torch.Size([3, 4])
# 将一维张量重塑为 2x3x2 的三维张量
z = torch.reshape(x, (2, 3, 2))
print(z)
# 输出:
# tensor([[[ 1, 2],
# [ 3, 4],
# [ 5, 6]],
#
# [[ 7, 8],
# [ 9, 10],
# [11, 12]]])
print(z.shape) # 输出: torch.Size([2, 3, 2])特殊用法
1. 使用 -1 自动推断维度
python
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
# 使用 -1 让 PyTorch 自动计算第一个维度的大小
y = torch.reshape(x, (-1, 4))
print(y.shape) # 输出: torch.Size([3, 4])
# 使用 -1 让 PyTorch 自动计算第二个维度的大小
z = torch.reshape(x, (3, -1))
print(z.shape) # 输出: torch.Size([3, 4])
# 在多维张量中使用 -1
a = torch.reshape(x, (2, 3, -1))
print(a.shape) # 输出: torch.Size([2, 3, 2])当使用 -1 时,PyTorch 会根据张量的总元素数和其他指定维度的大小,自动计算该维度的大小。注意,在一个形状中只能使用一个 -1。
2. 使用 0 保留原始维度
python
x = torch.zeros(2, 3, 4)
print(x.shape) # 输出: torch.Size([2, 3, 4])
# 使用 0 保留某些原始维度的大小
y = torch.reshape(x, (0, 6)) # 保留第一个维度为2,计算第二个维度为6
print(y.shape) # 输出: torch.Size([2, 6])
z = torch.reshape(x, (4, 0, 2)) # 计算第一个维度为4,保留第二个维度为3,指定第三个维度为2
print(z.shape) # 输出: torch.Size([4, 3, 2])reshape 与 view 的区别
在 PyTorch 中,reshape 和 view 都可以用来改变张量的形状,但它们有一些重要的区别:
- 内存连续性要求:
view要求张量在内存中是连续的(contiguous)reshape不要求张量在内存中是连续的,它会自动处理非连续的情况
python
x = torch.tensor([1, 2, 3, 4, 5, 6])
y = x.view(2, 3) # 正常工作
# 创建一个非连续的张量
z = torch.tensor([[1, 2, 3], [4, 5, 6]])
w = z.transpose(0, 1) # 转置使张量变得非连续
# view 会报错,因为 w 不是连续的
try:
w.view(3, 2)
except RuntimeError as e:
print(e) # 输出: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
# reshape 可以正常工作
v = w.reshape(3, 2)
print(v)
# 输出:
# tensor([[1, 4],
# [2, 5],
# [3, 6]])- 内存共享:
- 如果可能,
view会返回原始张量的视图(共享内存) reshape可能会返回视图或副本,取决于内存布局
- 如果可能,
python
x = torch.tensor([1, 2, 3, 4, 5, 6])
y = x.view(2, 3)
z = x.reshape(2, 3)
# 修改原始张量
x[0] = 10
# view 和 reshape 的结果都会受到影响(因为它们共享内存)
print(y)
# 输出:
# tensor([[10, 2, 3],
# [ 4, 5, 6]])
print(z)
# 输出:
# tensor([[10, 2, 3],
# [ 4, 5, 6]])实际应用场景
1. 图像数据处理
python
# 假设我们有一个展平的图像数据
flat_image = torch.randn(784) # 28x28 图像展平为 784 个元素
# 将其重塑为图像的原始形状
image = torch.reshape(flat_image, (28, 28))
print(image.shape) # 输出: torch.Size([28, 28])
# 或者重塑为带通道的图像(假设是灰度图)
image_with_channel = torch.reshape(flat_image, (1, 28, 28))
print(image_with_channel.shape) # 输出: torch.Size([1, 28, 28])2. 批处理数据准备
python
# 假设我们有多个样本的特征
features = torch.randn(100, 64) # 100个样本,每个样本64个特征
# 将其重塑为适合RNN输入的形状 (sequence_length, batch_size, input_size)
rnn_input = torch.reshape(features, (10, 10, 64))
print(rnn_input.shape) # 输出: torch.Size([10, 10, 64])3. 卷积神经网络中的特征图处理
python
# 假设我们有一个卷积层的输出
conv_output = torch.randn(16, 32, 8, 8) # batch_size=16, channels=32, height=8, width=8
# 将其重塑为全连接层的输入
fc_input = torch.reshape(conv_output, (16, -1)) # 保持批次大小,展平其他维度
print(fc_input.shape) # 输出: torch.Size([16, 2048]) # 32*8*8=2048注意事项
- 元素总数必须匹配:重塑后的张量必须包含与原始张量相同数量的元素。
python
x = torch.tensor([1, 2, 3, 4, 5, 6])
# 这会报错,因为 2*2=4 ≠ 6
try:
torch.reshape(x, (2, 2))
except RuntimeError as e:
print(e) # 输出: shape '[2, 2]' is invalid for input of size 6- 内存考虑:如果
reshape必须创建副本(因为张量不是连续的),那么修改原始张量不会影响重塑后的张量,反之亦然。
python
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = x.transpose(0, 1) # 非连续张量
z = y.reshape(3, 2) # 可能会创建副本
# 修改原始张量
x[0, 0] = 10
# 检查 z 是否受到影响
print(z)
# 输出可能是:
# tensor([[1, 4],
# [2, 5],
# [3, 6]])
# 注意 z[0,0] 仍然是 1,而不是 10,因为 reshape 可能创建了副本总结
torch.reshape 是一个灵活的工具,用于改变张量的形状,而不改变数据的总数。它比 view 更灵活,因为它可以处理非连续的张量,但这也意味着它有时会创建副本而不是视图。在实际应用中,reshape 常用于:
- 数据预处理,如将展平的图像数据重塑为图像的原始形状
- 准备批处理数据,使其适合特定模型的输入要求
- 在神经网络的不同层之间转换数据形状
理解 reshape 的工作原理对于有效地使用 PyTorch 进行深度学习开发非常重要。
torch.stack 详解
torch.stack 是 PyTorch 中用于沿新维度连接一系列张量的重要函数。与 torch.cat(concatenate)不同,stack 会在一个新的维度上堆叠张量,而不是在现有维度上连接它们。让我详细解释这个函数的工作原理。
基本概念
torch.stack(tensors, dim=0) 接受两个主要参数:
tensors: 要堆叠的张量序列(列表或元组)dim: 沿哪个维度插入并堆叠张量(默认为 0)
基本用法示例
python
# 创建几个形状相同的张量
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
c = torch.tensor([7, 8, 9])
# 沿新维度0堆叠这些张量
stacked_dim0 = torch.stack([a, b, c], dim=0)
print(stacked_dim0)
# 输出:
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
print(stacked_dim0.shape) # 输出: torch.Size([3, 3])
# 沿新维度1堆叠这些张量
stacked_dim1 = torch.stack([a, b, c], dim=1)
print(stacked_dim1)
# 输出:
# tensor([[1, 4, 7],
# [2, 5, 8],
# [3, 6, 9]])
print(stacked_dim1.shape) # 输出: torch.Size([3, 3])stack 与 cat 的区别
理解 stack 和 cat 的区别非常重要:
python
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 使用 stack
stacked = torch.stack([a, b], dim=0)
print(stacked)
# 输出:
# tensor([[1, 2, 3],
# [4, 5, 6]])
print(stacked.shape) # 输出: torch.Size([2, 3])
# 使用 cat
concatenated = torch.cat([a, b], dim=0)
print(concatenated)
# 输出: tensor([1, 2, 3, 4, 5, 6])
print(concatenated.shape) # 输出: torch.Size([6])stack在新维度上堆叠张量,增加了维度数量cat在现有维度上连接张量,不增加维度数量
不同维度的堆叠示例
1. 一维张量的堆叠
python
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 沿维度0堆叠
stacked_dim0 = torch.stack([a, b], dim=0)
print(stacked_dim0)
# 输出:
# tensor([[1, 2, 3],
# [4, 5, 6]])
print(stacked_dim0.shape) # 输出: torch.Size([2, 3])
# 沿维度1堆叠
stacked_dim1 = torch.stack([a, b], dim=1)
print(stacked_dim1)
# 输出:
# tensor([[1, 4],
# [2, 5],
# [3, 6]])
print(stacked_dim1.shape) # 输出: torch.Size([3, 2])2. 二维张量的堆叠
python
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
# 沿维度0堆叠
stacked_dim0 = torch.stack([a, b], dim=0)
print(stacked_dim0)
# 输出:
# tensor([[[1, 2],
# [3, 4]],
#
# [[5, 6],
# [7, 8]]])
print(stacked_dim0.shape) # 输出: torch.Size([2, 2, 2])
# 沿维度1堆叠
stacked_dim1 = torch.stack([a, b], dim=1)
print(stacked_dim1)
# 输出:
# tensor([[[1, 2],
# [5, 6]],
#
# [[3, 4],
# [7, 8]]])
print(stacked_dim1.shape) # 输出: torch.Size([2, 2, 2])
# 沿维度2堆叠
stacked_dim2 = torch.stack([a, b], dim=2)
print(stacked_dim2)
# 输出:
# tensor([[[1, 5],
# [2, 6]],
#
# [[3, 7],
# [4, 8]]])
print(stacked_dim2.shape) # 输出: torch.Size([2, 2, 2])使用负数作为维度参数
与其他 PyTorch 函数一样,stack 也支持负数作为维度参数:
python
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
c = torch.tensor([7, 8, 9])
# 使用 dim=-1(最后一个维度)
stacked_neg1 = torch.stack([a, b, c], dim=-1)
print(stacked_neg1)
# 输出:
# tensor([[1, 4, 7],
# [2, 5, 8],
# [3, 6, 9]])
print(stacked_neg1.shape) # 输出: torch.Size([3, 3])
# 使用 dim=-2(倒数第二个维度)
stacked_neg2 = torch.stack([a, b, c], dim=-2)
print(stacked_neg2)
# 输出:
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
print(stacked_neg2.shape) # 输出: torch.Size([3, 3])实际应用场景
1. 创建批次数据
python
# 假设我们有多个样本,每个样本是一个特征向量
sample1 = torch.tensor([1.0, 2.0, 3.0])
sample2 = torch.tensor([4.0, 5.0, 6.0])
sample3 = torch.tensor([7.0, 8.0, 9.0])
# 使用 stack 创建批次数据
batch = torch.stack([sample1, sample2, sample3], dim=0)
print(batch.shape) # 输出: torch.Size([3, 3]) # batch_size=3, feature_size=32. 序列数据处理
python
# 假设我们有一个序列的多个时间步
time_step1 = torch.tensor([[1, 2], [3, 4]]) # 假设是2个特征在时间步1
time_step2 = torch.tensor([[5, 6], [7, 8]]) # 假设是2个特征在时间步2
time_step3 = torch.tensor([[9, 10], [11, 12]]) # 假设是2个特征在时间步3
# 使用 stack 创建序列数据
sequence = torch.stack([time_step1, time_step2, time_step3], dim=0)
print(sequence.shape) # 输出: torch.Size([3, 2, 2]) # sequence_length=3, batch_size=2, feature_size=23. 多通道图像处理
python
# 假设我们有图像的多个通道(如RGB)
red_channel = torch.tensor([[1, 2], [3, 4]]) # 红色通道
green_channel = torch.tensor([[5, 6], [7, 8]]) # 绿色通道
blue_channel = torch.tensor([[9, 10], [11, 12]]) # 蓝色通道
# 使用 stack 创建多通道图像
rgb_image = torch.stack([red_channel, green_channel, blue_channel], dim=0)
print(rgb_image.shape) # 输出: torch.Size([3, 2, 2]) # channels=3, height=2, width=24. 模型集成
python
# 假设我们有多个模型的预测结果
model1_pred = torch.tensor([0.1, 0.9, 0.2]) # 模型1的预测
model2_pred = torch.tensor([0.3, 0.7, 0.4]) # 模型2的预测
model3_pred = torch.tensor([0.5, 0.5, 0.6]) # 模型3的预测
# 使用 stack 组合预测结果
ensemble_preds = torch.stack([model1_pred, model2_pred, model3_pred], dim=0)
print(ensemble_preds.shape) # 输出: torch.Size([3, 3]) # num_models=3, num_classes=3
# 计算平均预测
avg_pred = torch.mean(ensemble_preds, dim=0)
print(avg_pred) # 输出: tensor([0.3000, 0.7000, 0.4000])注意事项
- 形状要求:所有要堆叠的张量必须具有相同的形状。
python
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5]) # 形状与a不同
try:
torch.stack([a, b], dim=0)
except RuntimeError as e:
print(e) # 输出: stack expects each tensor to be equal size, but got [3] and [2] at entry 0 in the list- 维度范围:
dim参数必须在[-input.dim() - 1, input.dim() + 1)范围内。
python
a = torch.tensor([1, 2, 3])
try:
torch.stack([a], dim=2) # 超出范围
except RuntimeError as e:
print(e) # 输出: dim 2 out of range for tensor of dimension 1- 空列表:不能堆叠空列表。
python
try:
torch.stack([], dim=0)
except RuntimeError as e:
print(e) # 输出: stack expects a non-empty TensorList与其他函数的比较
stack vs cat
python
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# stack 增加维度
stacked = torch.stack([a, b], dim=0)
print(stacked.shape) # 输出: torch.Size([2, 3])
# cat 不增加维度
concatenated = torch.cat([a, b], dim=0)
print(concatenated.shape) # 输出: torch.Size([6])stack vs unsqueeze + cat
实际上,stack 可以被看作是先对每个张量进行 unsqueeze,然后再 cat:
python
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 使用 stack
stacked = torch.stack([a, b], dim=0)
print(stacked)
# 输出:
# tensor([[1, 2, 3],
# [4, 5, 6]])
# 使用 unsqueeze + cat 实现相同效果
unsqueezed_a = torch.unsqueeze(a, 0) # 在维度0增加一个维度
unsqueezed_b = torch.unsqueeze(b, 0) # 在维度0增加一个维度
manual_stack = torch.cat([unsqueezed_a, unsqueezed_b], dim=0)
print(manual_stack)
# 输出:
# tensor([[1, 2, 3],
# [4, 5, 6]])总结
torch.stack 是一个非常有用的函数,用于沿新维度堆叠一系列形状相同的张量。它的主要特点包括:
- 增加维度:与
cat不同,stack会增加一个新的维度 - 形状要求:所有输入张量必须具有相同的形状
- 灵活性:可以指定沿哪个维度堆叠张量
- 广泛应用:在创建批次数据、处理序列数据、多通道图像处理和模型集成等场景中非常有用
理解 stack 的工作原理对于有效地使用 PyTorch 进行深度学习开发非常重要,特别是在处理需要组合多个张量的场景时。
torch.vstack 和 torch.hstack 详解
torch.vstack 和 torch.hstack 是 PyTorch 中用于垂直和水平堆叠张量的便捷函数。它们实际上是 torch.cat 的特例,专门用于在特定维度上连接张量。让我详细解释这两个函数的工作原理。
torch.vstack (垂直堆叠)
torch.vstack 用于垂直堆叠张量,即在第一个维度(维度 0)上连接张量。它要求所有输入张量的形状除了第一个维度外必须相同。
基本用法
python
# 一维张量的垂直堆叠
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
vstacked = torch.vstack([a, b])
print(vstacked)
# 输出:
# tensor([[1, 2, 3],
# [4, 5, 6]])
print(vstacked.shape) # 输出: torch.Size([2, 3])二维张量的垂直堆叠
python
# 二维张量的垂直堆叠
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[7, 8, 9], [10, 11, 12]])
vstacked = torch.vstack([a, b])
print(vstacked)
# 输出:
# tensor([[ 1, 2, 3],
# [ 4, 5, 6],
# [ 7, 8, 9],
# [10, 11, 12]])
print(vstacked.shape) # 输出: torch.Size([4, 3])与 torch.cat 的关系
torch.vstack(tensors) 等价于 torch.cat(tensors, dim=0):
python
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 使用 vstack
vstack_result = torch.vstack([a, b])
# 使用 cat 实现相同效果
cat_result = torch.cat([a, b], dim=0)
print(torch.equal(vstack_result, cat_result)) # 输出: Truetorch.hstack (水平堆叠)
torch.hstack 用于水平堆叠张量,即在最后一个维度上连接张量。对于一维张量,它相当于在维度 0 上连接;对于二维或更高维张量,它相当于在维度 1 上连接。
一维张量的水平堆叠
python
# 一维张量的水平堆叠
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
hstacked = torch.hstack([a, b])
print(hstacked)
# 输出: tensor([1, 2, 3, 4, 5, 6])
print(hstacked.shape) # 输出: torch.Size([6])二维张量的水平堆叠
python
# 二维张量的水平堆叠
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[7, 8], [9, 10]])
hstacked = torch.hstack([a, b])
print(hstacked)
# 输出:
# tensor([[ 1, 2, 3, 7, 8],
# [ 4, 5, 6, 9, 10]])
print(hstacked.shape) # 输出: torch.Size([2, 5])与 torch.cat 的关系
对于一维张量,torch.hstack(tensors) 等价于 torch.cat(tensors, dim=0):
python
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 使用 hstack
hstack_result = torch.hstack([a, b])
# 使用 cat 实现相同效果
cat_result = torch.cat([a, b], dim=0)
print(torch.equal(hstack_result, cat_result)) # 输出: True对于二维或更高维张量,torch.hstack(tensors) 等价于 torch.cat(tensors, dim=1):
python
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[7, 8], [9, 10]])
# 使用 hstack
hstack_result = torch.hstack([a, b])
# 使用 cat 实现相同效果
cat_result = torch.cat([a, b], dim=1)
print(torch.equal(hstack_result, cat_result)) # 输出: True高维张量的堆叠
三维张量的垂直堆叠
python
# 三维张量的垂直堆叠
a = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
b = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
vstacked = torch.vstack([a, b])
print(vstacked.shape) # 输出: torch.Size([4, 2, 2])三维张量的水平堆叠
python
# 三维张量的水平堆叠
a = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
b = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
hstacked = torch.hstack([a, b])
print(hstacked.shape) # 输出: torch.Size([2, 4, 2])实际应用场景
1. 数据集合并
python
# 假设我们有多个批次的数据
batch1 = torch.randn(32, 100) # 32个样本,每个样本100个特征
batch2 = torch.randn(32, 100) # 32个样本,每个样本100个特征
batch3 = torch.randn(32, 100) # 32个样本,每个样本100个特征
# 使用 vstack 合并批次
combined_batch = torch.vstack([batch1, batch2, batch3])
print(combined_batch.shape) # 输出: torch.Size([96, 100]) # 96个样本,每个样本100个特征2. 特征拼接
python
# 假设我们有两组特征
features1 = torch.randn(100, 50) # 100个样本,每个样本50个特征
features2 = torch.randn(100, 30) # 100个样本,每个样本30个特征
# 使用 hstack 拼接特征
combined_features = torch.hstack([features1, features2])
print(combined_features.shape) # 输出: torch.Size([100, 80]) # 100个样本,每个样本80个特征3. 图像处理
python
# 假设我们有多个图像行
row1 = torch.ones(3, 64) # 3个通道,64个像素宽
row2 = torch.zeros(3, 64) # 3个通道,64个像素宽
row3 = torch.ones(3, 64) * 0.5 # 3个通道,64个像素宽
# 使用 vstack 垂直堆叠图像行
image = torch.vstack([row1, row2, row3])
print(image.shape) # 输出: torch.Size([9, 64]) # 9个通道(3*3),64个像素宽
# 假设我们有多个图像列
col1 = torch.ones(64, 3) # 64个像素高,3个通道
col2 = torch.zeros(64, 3) # 64个像素高,3个通道
col3 = torch.ones(64, 3) * 0.5 # 64个像素高,3个通道
# 使用 hstack 水平堆叠图像列
image = torch.hstack([col1, col2, col3])
print(image.shape) # 输出: torch.Size([64, 9]) # 64个像素高,9个通道(3*3)注意事项
1. 形状兼容性
对于 torch.vstack,除了第一个维度外,所有其他维度的大小必须相同:
python
a = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 形状: [2, 3]
b = torch.tensor([[7, 8], [9, 10]]) # 形状: [2, 2]
try:
torch.vstack([a, b])
except RuntimeError as e:
print(e) # 输出: Sizes of tensors must match except in dimension 0. Expected size 3 but got size 2 for tensor number 1 in the list.对于 torch.hstack,除了最后一个维度外,所有其他维度的大小必须相同:
python
a = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 形状: [2, 3]
b = torch.tensor([[7, 8, 9]]) # 形状: [1, 3]
try:
torch.hstack([a, b])
except RuntimeError as e:
print(e) # 输出: Sizes of tensors must match except in dimension 1. Expected size 2 but got size 1 for tensor number 1 in the list.2. 空列表处理
python
try:
torch.vstack([])
except RuntimeError as e:
print(e) # 输出: vstack expects a non-empty TensorList
try:
torch.hstack([])
except RuntimeError as e:
print(e) # 输出: hstack expects a non-empty TensorListvstack 和 hstack 与 stack 的比较
python
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# vstack 在维度0上连接张量
vstack_result = torch.vstack([a, b])
print(vstack_result)
# 输出:
# tensor([[1, 2, 3],
# [4, 5, 6]])
print(vstack_result.shape) # 输出: torch.Size([2, 3])
# hstack 在最后一个维度上连接张量
hstack_result = torch.hstack([a, b])
print(hstack_result)
# 输出: tensor([1, 2, 3, 4, 5, 6])
print(hstack_result.shape) # 输出: torch.Size([6])
# stack 在新维度上堆叠张量
stack_result = torch.stack([a, b])
print(stack_result)
# 输出:
# tensor([[1, 2, 3],
# [4, 5, 6]])
print(stack_result.shape) # 输出: torch.Size([2, 3])
# 虽然 vstack_result 和 stack_result 看起来相同,但它们的实现方式不同
# vstack 实际上是 cat 的特例,而 stack 会增加一个新的维度总结
torch.vstack 和 torch.hstack 是 PyTorch 中用于垂直和水平堆叠张量的便捷函数:
torch.vstack:- 在第一个维度(维度 0)上连接张量
- 等价于
torch.cat(tensors, dim=0) - 要求所有输入张量的形状除了第一个维度外必须相同
torch.hstack:- 在最后一个维度上连接张量
- 对于一维张量,等价于
torch.cat(tensors, dim=0) - 对于二维或更高维张量,等价于
torch.cat(tensors, dim=1) - 要求所有输入张量的形状除了最后一个维度外必须相同
与
torch.stack的区别:vstack和hstack是cat的特例,不增加维度数量stack会在新维度上堆叠张量,增加维度数量
这些函数在数据合并、特征拼接和图像处理等场景中非常有用,提供了直观且便捷的方式来组合张量。