PyTorch 高级索引
PyTorch 高级索引
2024年5月15日
摘要
在使用 PyTorch 的过程中,常规的访问 Tensor 的索引方式并不能够很好的适应一些复杂场景。本文将介绍一些在 PyTorch 中访问 Tensor 的高级索引方式。
索引方式
假设存在一个 2\times 4 的 2D Tensor,第一排从 0 到 3,第二排从 4 到 7:
import torch
sample = torch.arrange(8).reshape(2, 4)
# tensor([[0, 1, 2, 3],
# [4, 5, 6, 7]])
基础方式
访问元素
有两种方式可以进行基础的元素访问:
# Divide dimension by comma
element = sample[0, 1]
# tensor(1)
# C-Style
element = sample[0][1]
# tensor(1)
其次,也可以使用负数进行索引,-1
代表最后一个元素,-2
代表倒数第二个,以此类推:
element = sample[-1, -2]
element = sample[-1][-2]
# tensor(6)
访问切片
可以用 a:b
表示选取 [a, b)
的索引:
token = sample[0, 1:3]
token = sample[0][1:3]
# tensor([1, 2])
其中,当 a == b
时,当前维度切片长度为 0:
token = sample[0, 1:1]
token = sample[0][1:1]
# tensor([], dtype=torch.int64)
当 b == a + 1
时,当前维度切片长度为 1,需要注意的是,切片长度为 1 并不代表这个维度不存在了:
token = sample[0:1, 1:2]
token = sample[0:1][1:2]
# tensor([[1]])
element = sample[0, 1]
element = sample[0][1]
# tensor(1)
可以看到,使用 b == a + 1
的形式作为索引,取出的值虽然和直接用 a
做索引一致,但是维度数量能保持不变,直接使用标量 a
做索引会去除当前维度。
切片中一样可以使用负数索引代表倒数第几个元素。
当 [a, b)
超出 Tensor 的范围时,超出部分的索引将被忽略。
a
不写则为 1
,例如 :b
;b
不写则为 -1
,例如 a:
。
高级方式
间隔访问切片
除了连续访问,还能够间隔以等差数列索引来访问,格式为 start:end:step
,其将生成由 start
开始(闭区间),end
结束(开区间),以 step
为步长的等差数列列表,并将这个列表内的值作为索引实现间隔取值,当 Tensor 中不存在索引列表所需的值时,此索引将被忽略。
token = sample[0, 1:100:2]
token = sample[0][1:100:2]
# tensor([1, 3])
其中,start
和 end
可以是负数,step
不行。
其中,start
不写则为 1
;end
不写则为 -1
;step
不写则为 1
。
列表索引访问
除了使用 start:end:step
生成等差数列索引列表,我们还可以直接手动提供索引列表:
token = sample[1, [1, 3]]
token = sample[1][[1, 3]]
# tensor([5, 7])
Bool 访问
除了使用标量列表索引访问,我们还可以使用 Bool 列表来访问:
token = sample[1, [True, False, False, True]]
token = sample[1][[True, False, False, True]]
# tensor([4, 7])
slice
如果一种索引规则比较常用,我们可以创建一个 slice
对象来储存,以后用这个 slice
对象来索引:
index = slice(1, None, 2)
token = sample[1][index]
# tensor([5, 7])
slice
的规则和间隔访问切片
2024年5月15日
摘要
在使用 PyTorch 的过程中,常规的访问 Tensor 的索引方式并不能够很好的适应一些复杂场景。本文将介绍一些在 PyTorch 中访问 Tensor 的高级索引方式。
索引方式
假设存在一个 2\times 4 的 2D Tensor,第一排从 0 到 3,第二排从 4 到 7:
import torch
sample = torch.arrange(8).reshape(2, 4)
# tensor([[0, 1, 2, 3],
# [4, 5, 6, 7]])
基础方式
访问元素
有两种方式可以进行基础的元素访问:
# Divide dimension by comma
element = sample[0, 1]
# tensor(1)
# C-Style
element = sample[0][1]
# tensor(1)
其次,也可以使用负数进行索引,-1
代表最后一个元素,-2
代表倒数第二个,以此类推:
element = sample[-1, -2]
element = sample[-1][-2]
# tensor(6)
访问切片
可以用 a:b
表示选取 [a, b)
的索引:
token = sample[0, 1:3]
token = sample[0][1:3]
# tensor([1, 2])
其中,当 a == b
时,当前维度切片长度为 0:
token = sample[0, 1:1]
token = sample[0][1:1]
# tensor([], dtype=torch.int64)
当 b == a + 1
时,当前维度切片长度为 1,需要注意的是,切片长度为 1 并不代表这个维度不存在了:
token = sample[0:1, 1:2]
token = sample[0:1][1:2]
# tensor([[1]])
element = sample[0, 1]
element = sample[0][1]
# tensor(1)
可以看到,使用 b == a + 1
的形式作为索引,取出的值虽然和直接用 a
做索引一致,但是维度数量能保持不变,直接使用标量 a
做索引会去除当前维度。
切片中一样可以使用负数索引代表倒数第几个元素。
当 [a, b)
超出 Tensor 的范围时,超出部分的索引将被忽略。
a
不写则为 1
,例如 :b
;b
不写则为 -1
,例如 a:
。
高级方式
间隔访问切片
除了连续访问,还能够间隔以等差数列索引来访问,格式为 start:end:step
,其将生成由 start
开始(闭区间),end
结束(开区间),以 step
为步长的等差数列列表,并将这个列表内的值作为索引实现间隔取值,当 Tensor 中不存在索引列表所需的值时,此索引将被忽略。
token = sample[0, 1:100:2]
token = sample[0][1:100:2]
# tensor([1, 3])
其中,start
和 end
可以是负数,step
不行。
其中,start
不写则为 1
;end
不写则为 -1
;step
不写则为 1
。
列表索引访问
除了使用 start:end:step
生成等差数列索引列表,我们还可以直接手动提供索引列表:
token = sample[1, [1, 3]]
token = sample[1][[1, 3]]
# tensor([5, 7])
Bool 访问
除了使用标量列表索引访问,我们还可以使用 Bool 列表来访问:
token = sample[1, [True, False, False, True]]
token = sample[1][[True, False, False, True]]
# tensor([4, 7])
slice
如果一种索引规则比较常用,我们可以创建一个 slice
对象来储存,以后用这个 slice
对象来索引:
index = slice(1, None, 2)
token = sample[1][index]
# tensor([5, 7])
slice
的规则和间隔访问切片一致。