Pytorch基础:Tensor的flatten()方法
相关阅读Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm1001.2014.3001.5482在Pytorch中flatten()是Tensor类的一个重要方法同时它也是一个torch模块中的一个函数它们的语法如下所示。Tensor.flatten(start_dim0, end_dim-1) → Tensor torch.flatten(input, start_dim0, end_dim-1) → Tensor input (Tensor) – the input tensor start_dim (int) – the first dim to flatten end_dim (int) – the last dim to flattenflatten()函数或方法用于将一个张量以特定方法展平 如果传递了参数则会将从start_dim到end_dim之间的维度展开。默认情况下将从第0维展平至最后1维。flatten()函数或方法可能返回原始张量、原始张量的视图(共享底层存储)或原始张量的副本如果没有维度被展平则返回原始张量同一个对象。如果输出张量可以视为等效地使用view()方法展平则返回视图(共享底层存储)。如果输出张量不能视为等效地使用view()方法展平则返回数据副本。张量的视图可能是一个非连续张量关于它的更多细节可以看下面的文章。Pytorch基础Tensor的连续性https://blog.csdn.net/weixin_45791458/article/details/140736700?ops_request_misc%257B%2522request%255Fid%2522%253A%2522eb4c722817c335758581a52404bb2dce%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fblog.%2522%257Drequest_ideb4c722817c335758581a52404bb2dcebiz_id0utm_mediumdistribute.pc_search_result.none-task-blog-2~blog~first_rank_ecpm_v1~rank_v31_ecpm-2-140736700-null-null.nonecaseutm_term%E9%9D%9E%E8%BF%9E%E7%BB%ADspm1018.2226.3001.4450关于view()方法的更多细节可以看下面的文章。Pytorch基础Tensor的连续性https://blog.csdn.net/weixin_45791458/article/details/140736723?sharetypeblogdetailsharerId140736723sharereferPCsharesourceweixin_45791458spm1011.2480.3001.8118下面以三个例子分别说明上述三种情况# 例1 import torch input_tensor torch.tensor([[1, 2], [3, 4]]) flattened_tensor torch.flatten(input_tensor, start_dim0, end_dim0) print(input_tensor) print(flattened_tensor) print(id(flattened_tensor) id(input_tensor)) # 查看是否是同一个张量对象 print(flattened_tensor.storage().data_ptr() input_tensor.storage().data_ptr()) # 查看是否共享底层存储 输出 tensor([[1, 2], [3, 4]]) tensor([[1, 2], [3, 4]]) True True# 例2 import torch input_tensor torch.tensor([[1, 2], [3, 4]]) flattened_tensor torch.flatten(input_tensor, start_dim0, end_dim1) print(input_tensor) print(flattened_tensor) print(id(flattened_tensor) id(input_tensor)) # 查看是否是同一个张量对象 print(flattened_tensor.storage().data_ptr() input_tensor.storage().data_ptr()) # 查看是否共享底层存储 输出 tensor([[1, 2], [3, 4]]) tensor([1, 2, 3, 4]) False True# 例3 import torch input_tensor torch.tensor([[1, 2], [3, 4]]).transpose(0, 1) flattened_tensor torch.flatten(input_tensor, start_dim0, end_dim1) print(input_tensor) print(flattened_tensor) print(id(flattened_tensor) id(input_tensor)) # 查看是否是同一个张量对象 print(flattened_tensor.storage().data_ptr() input_tensor.storage().data_ptr()) # 查看是否共享底层存储 输出 tensor([[1, 3], [2, 4]]) tensor([1, 3, 2, 4]) False False