tensor的数据结构。tensor分为头信息区(Tensor)和存储区(Storage),信息区主要保存着tensor的形状(size)、步长(stride)、数据类型(type)等信息,而真正的数据则保存成连续数组。由于数据动辄成千上万,因此信息区元素占用内存较少,主要内存占用则取决于tensor中元素的数目,也即存储区的大小。
from __future__ import print_function
import torch as t
a = t.arange(0, 6)
a.storage()
'''
输出的结果
0
1
2
3
4
5
'''
b = a.view(2, 3)
b,b.storage()
'''
输出的结果
tensor([[0, 1, 2],
[3, 4, 5]]),
0
1
2
3
4
5
'''
# 一个对象的id值可以看作它在内存中的地址
# storage的内存地址一样,即是同一个storage
id(b.storage()) == id(a.storage())
'''
输出的结果
True
'''
# a改变,b也随之改变,因为他们共享storage
a[1] = 100
b
'''
输出的结果
tensor([[ 0, 100, 2],
[ 3, 4, 5]])
'''
c = a[2:]
c,c.storage()
'''
输出的结果
tensor([2, 3, 4, 5]),
0
100
2
3
4
5
'''
c.data_ptr(), a.data_ptr() # data_ptr返回tensor首元素的内存地址
# 可以看出相差8,这是因为2*4=8--相差两个元素,每个元素占4个字节(float)
'''
输出的结果
(61277776, 61277760)
'''
c[0] = -100 # c[0]的内存地址对应a[2]的内存地址
a
'''
输出的结果
tensor([6666, 100, -100, 3, 4, 5])
'''
d = t.LongTensor(c.storage())
d
d[0] = 6666
b
'''
输出的结果
tensor([6666, 100, -100, 3, 4, 5])
'''
# 下面4个tensor共享storage
id(a.storage()) == id(b.storage()) == id(c.storage()) == id(d.storage())
'''
输出的结果
True
'''
a.storage_offset(), c.storage_offset(), d.storage_offset()
'''
输出的结果
(0, 2, 0
'''
e = b[::2, ::2] # 隔2行/列取一个元素
id(e.storage()) == id(a.storage())
'''
输出的结果
True
'''
b.stride(), e.stride()
'''
输出的结果
((3, 1), (6, 2))
'''
e.is_contiguous()
'''
输出的结果
False
'''
普通索引可以通过只修改tensor的offset,stride和size,而不修改storage来实现