torch.split()

Tags: python, deep learning, pytorch

torch.split()

Official docs: https://pytorch.org/docs/stable/torch.html
Official description: "Splits the tensor into chunks." — the PyTorch function for splitting a tensor.
Purpose: split a (multi-dimensional) tensor into several smaller tensors.

If split_size_or_sections is an integer type, then tensor will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.
If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.

The signature is:

torch.split(tensor, split_size_or_sections, dim=0)
Parameters:
- tensor: the input tensor to split
- split_size_or_sections (see the short 1D sketch after this list):
    - If it is an integer, each resulting chunk has size split_size_or_sections along dim (the last chunk may be smaller); it does NOT mean the tensor is split into split_size_or_sections chunks
    - If it is a list, the tensor is split along dim into chunks whose sizes are given by the list
- dim: the dimension to split along; defaults to dim=0, the first dimension
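
Before the full 3D examples below, here is a minimal 1D sketch (my own illustration, not from the original post; the tensor x is just a placeholder) showing the difference between the integer and list forms of split_size_or_sections:

import torch

x = torch.arange(5)                                    # tensor([0, 1, 2, 3, 4])
# Integer: chunks of size 2 along dim=0; the last chunk is smaller (size 1)
print([t.tolist() for t in torch.split(x, 2)])         # [[0, 1], [2, 3], [4]]
# List: chunk sizes are given explicitly and must sum to x.size(0) == 5
print([t.tolist() for t in torch.split(x, [1, 4])])    # [[0], [1, 2, 3, 4]]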

Example 1. Suppose we have the following 3D tensor:

import torch

# Generate a random tensor of shape (2, 4, 8)
random_tensor = torch.rand(2, 4, 8)
random_tensor
tensor([[[0.8644, 0.0177, 0.7970, 0.7016, 0.4632, 0.9147, 0.8053, 0.4261],
         [0.0450, 0.9565, 0.8375, 0.9347, 0.8196, 0.8751, 0.4523, 0.9660],
         [0.8350, 0.1566, 0.8367, 0.0345, 0.6804, 0.7308, 0.8989, 0.9943],
         [0.2294, 0.2361, 0.1537, 0.9923, 0.7680, 0.0824, 0.3566, 0.6546]],

        [[0.2106, 0.2736, 0.8687, 0.4333, 0.4102, 0.4820, 0.7104, 0.7776],
         [0.6558, 0.1098, 0.4384, 0.4891, 0.3681, 0.7371, 0.2555, 0.2687],
         [0.4181, 0.6644, 0.1816, 0.2111, 0.8317, 0.4180, 0.7011, 0.7221],
         [0.1922, 0.4405, 0.6633, 0.5787, 0.9912, 0.0370, 0.9894, 0.8748]]])

We can split it in the following ways:

  1. torch.split(random_tensor, 2, dim=1): split along the second dimension (dim=1)
split_2 = torch.split(random_tensor, 2, dim=1)  # returns a tuple
split_2
# (tensor([[[0.8644, 0.0177, 0.7970, 0.7016, 0.4632, 0.9147, 0.8053, 0.4261],
#           [0.0450, 0.9565, 0.8375, 0.9347, 0.8196, 0.8751, 0.4523, 0.9660]],
 
#          [[0.2106, 0.2736, 0.8687, 0.4333, 0.4102, 0.4820, 0.7104, 0.7776],
#           [0.6558, 0.1098, 0.4384, 0.4891, 0.3681, 0.7371, 0.2555, 0.2687]]]),
#  tensor([[[0.8350, 0.1566, 0.8367, 0.0345, 0.6804, 0.7308, 0.8989, 0.9943],
#           [0.2294, 0.2361, 0.1537, 0.9923, 0.7680, 0.0824, 0.3566, 0.6546]],
 
#          [[0.4181, 0.6644, 0.1816, 0.2111, 0.8317, 0.4180, 0.7011, 0.7221],
#           [0.1922, 0.4405, 0.6633, 0.5787, 0.9912, 0.0370, 0.9894, 0.8748]]]))
len(split_2)   # 2   
split_2[0]
# tensor([[[0.8644, 0.0177, 0.7970, 0.7016, 0.4632, 0.9147, 0.8053, 0.4261],
#          [0.0450, 0.9565, 0.8375, 0.9347, 0.8196, 0.8751, 0.4523, 0.9660]],

#         [[0.2106, 0.2736, 0.8687, 0.4333, 0.4102, 0.4820, 0.7104, 0.7776],
#          [0.6558, 0.1098, 0.4384, 0.4891, 0.3681, 0.7371, 0.2555, 0.2687]]])
split_2[0].size()  # torch.Size([2, 2, 8])
split_2[1]
# tensor([[[0.8350, 0.1566, 0.8367, 0.0345, 0.6804, 0.7308, 0.8989, 0.9943],
#          [0.2294, 0.2361, 0.1537, 0.9923, 0.7680, 0.0824, 0.3566, 0.6546]],

#         [[0.4181, 0.6644, 0.1816, 0.2111, 0.8317, 0.4180, 0.7011, 0.7221],
#          [0.1922, 0.4405, 0.6633, 0.5787, 0.9912, 0.0370, 0.9894, 0.8748]]])
split_2[1].size()  # torch.Size([2, 2, 8])
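
As a quick sanity check (a sketch of mine, not part of the original post): according to the official docs each chunk is a view of the original data, and torch.split only partitions the tensor along dim, so concatenating the chunks with torch.cat along the same dimension reproduces the original tensor.

reassembled = torch.cat(split_2, dim=1)                # glue the chunks back together along dim=1
print(torch.equal(reassembled, random_tensor))         # True
print([t.size() for t in split_2])                     # [torch.Size([2, 2, 8]), torch.Size([2, 2, 8])]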
  2. torch.split(random_tensor, 3, dim=1): compare with the previous case
split_3 = torch.split(random_tensor, 3, dim=1)
split_3
# (tensor([[[0.8644, 0.0177, 0.7970, 0.7016, 0.4632, 0.9147, 0.8053, 0.4261],
#           [0.0450, 0.9565, 0.8375, 0.9347, 0.8196, 0.8751, 0.4523, 0.9660],
#           [0.8350, 0.1566, 0.8367, 0.0345, 0.6804, 0.7308, 0.8989, 0.9943]],
 
#          [[0.2106, 0.2736, 0.8687, 0.4333, 0.4102, 0.4820, 0.7104, 0.7776],
#           [0.6558, 0.1098, 0.4384, 0.4891, 0.3681, 0.7371, 0.2555, 0.2687],
#           [0.4181, 0.6644, 0.1816, 0.2111, 0.8317, 0.4180, 0.7011, 0.7221]]]),
#  tensor([[[0.2294, 0.2361, 0.1537, 0.9923, 0.7680, 0.0824, 0.3566, 0.6546]],
 
#          [[0.1922, 0.4405, 0.6633, 0.5787, 0.9912, 0.0370, 0.9894, 0.8748]]]))
len(split_3)  # 2   dim=1 has length 4: the first chunk takes 3, the second should also take 3, but only 1 remains, so it takes 1
split_3[0]    # torch.Size([2, 3, 8])
# tensor([[[0.8644, 0.0177, 0.7970, 0.7016, 0.4632, 0.9147, 0.8053, 0.4261],
#          [0.0450, 0.9565, 0.8375, 0.9347, 0.8196, 0.8751, 0.4523, 0.9660],
#          [0.8350, 0.1566, 0.8367, 0.0345, 0.6804, 0.7308, 0.8989, 0.9943]],

#         [[0.2106, 0.2736, 0.8687, 0.4333, 0.4102, 0.4820, 0.7104, 0.7776],
#          [0.6558, 0.1098, 0.4384, 0.4891, 0.3681, 0.7371, 0.2555, 0.2687],
#          [0.4181, 0.6644, 0.1816, 0.2111, 0.8317, 0.4180, 0.7011, 0.7221]]])
split_3[1]   # torch.Size([2, 1, 8])
# tensor([[[0.2294, 0.2361, 0.1537, 0.9923, 0.7680, 0.0824, 0.3566, 0.6546]],

#         [[0.1922, 0.4405, 0.6633, 0.5787, 0.9912, 0.0370, 0.9894, 0.8748]]])
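
To make the "last chunk may be smaller" rule explicit, here is a small sketch (assuming random_tensor from Example 1 is still defined): the number of chunks is ceil(size of dim / split_size).

import math

dim_size = random_tensor.size(1)                       # 4
for split_size in (1, 2, 3, 4):
    chunks = torch.split(random_tensor, split_size, dim=1)
    assert len(chunks) == math.ceil(dim_size / split_size)
    print(split_size, [c.size(1) for c in chunks])
# 1 [1, 1, 1, 1]
# 2 [2, 2]
# 3 [3, 1]
# 4 [4]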
  3. torch.split(random_tensor, [1, 3], dim=1)
split_1_3 = torch.split(random_tensor, [1, 3], dim=1) # the list entries must sum to the size of the split dimension
split_1_3
# (tensor([[[0.8644, 0.0177, 0.7970, 0.7016, 0.4632, 0.9147, 0.8053, 0.4261]],
 
#          [[0.2106, 0.2736, 0.8687, 0.4333, 0.4102, 0.4820, 0.7104, 0.7776]]]),
#  tensor([[[0.0450, 0.9565, 0.8375, 0.9347, 0.8196, 0.8751, 0.4523, 0.9660],
#           [0.8350, 0.1566, 0.8367, 0.0345, 0.6804, 0.7308, 0.8989, 0.9943],
#           [0.2294, 0.2361, 0.1537, 0.9923, 0.7680, 0.0824, 0.3566, 0.6546]],
 
#          [[0.6558, 0.1098, 0.4384, 0.4891, 0.3681, 0.7371, 0.2555, 0.2687],
#           [0.4181, 0.6644, 0.1816, 0.2111, 0.8317, 0.4180, 0.7011, 0.7221],
#           [0.1922, 0.4405, 0.6633, 0.5787, 0.9912, 0.0370, 0.9894, 0.8748]]]))
len(split_1_3)  # 2
split_1_3[0]  # torch.Size([2, 1, 8])
# tensor([[[0.8644, 0.0177, 0.7970, 0.7016, 0.4632, 0.9147, 0.8053, 0.4261]],

#         [[0.2106, 0.2736, 0.8687, 0.4333, 0.4102, 0.4820, 0.7104, 0.7776]]])
split_1_3[1]   # torch.Size([2, 3, 8])
# tensor([[[0.0450, 0.9565, 0.8375, 0.9347, 0.8196, 0.8751, 0.4523, 0.9660],
#          [0.8350, 0.1566, 0.8367, 0.0345, 0.6804, 0.7308, 0.8989, 0.9943],
#          [0.2294, 0.2361, 0.1537, 0.9923, 0.7680, 0.0824, 0.3566, 0.6546]],

#         [[0.6558, 0.1098, 0.4384, 0.4891, 0.3681, 0.7371, 0.2555, 0.2687],
#          [0.4181, 0.6644, 0.1816, 0.2111, 0.8317, 0.4180, 0.7011, 0.7221],
#          [0.1922, 0.4405, 0.6633, 0.5787, 0.9912, 0.0370, 0.9894, 0.8748]]])
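
As noted in the comment above, the sizes in the list must add up to the size of the split dimension. A sketch of what happens otherwise (my own example; the exact error text may differ between PyTorch versions):

try:
    torch.split(random_tensor, [1, 2], dim=1)          # 1 + 2 != 4, so this fails
except RuntimeError as err:
    print("RuntimeError:", err)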

Example 2. Suppose we have the following 3D tensor:

random_tensor = torch.rand(2, 2, 3)
random_tensor
tensor([[[0.0445, 0.0481, 0.1199],
         [0.2850, 0.1215, 0.0584]],

        [[0.1323, 0.4458, 0.0899],
         [0.3338, 0.3624, 0.7511]]])
  1. torch.split(random_tensor, 2, dim=1): split along the second dimension (dim=1). The first chunk takes 2 slices and the data is exhausted; since dim=1 has exactly 2 slices here, the single chunk returned is just the original tensor
split_2 = torch.split(random_tensor, 2, dim=1)   # returns a tuple
split_2
# (tensor([[[0.0445, 0.0481, 0.1199],
#           [0.2850, 0.1215, 0.0584]],
 
#          [[0.1323, 0.4458, 0.0899],
#           [0.3338, 0.3624, 0.7511]]]),)
len(split_2)   # 1
split_2[0]     # torch.Size([2, 2, 3])
# tensor([[[0.0445, 0.0481, 0.1199],
#          [0.2850, 0.1215, 0.0584]],

#         [[0.1323, 0.4458, 0.0899],
#          [0.3338, 0.3624, 0.7511]]])
split_2[1]   # IndexError: tuple index out of range (the tuple has only one element)
  2. torch.split(random_tensor, [1, 2], dim=2): split along the third dimension (dim=2)
split_1_2 = torch.split(random_tensor, [1, 2], dim=2) # returns a tuple
split_1_2
# (tensor([[[0.0445],
#           [0.2850]],
 
#          [[0.1323],
#           [0.3338]]]),
#  tensor([[[0.0481, 0.1199],
#           [0.1215, 0.0584]],
 
#          [[0.4458, 0.0899],
#           [0.3624, 0.7511]]]))
len(split_1_2)  # 2
split_1_2[0]    # torch.Size([2, 2, 1])
# tensor([[[0.0445],
#          [0.2850]],

#         [[0.1323],
#          [0.3338]]])
split_1_2[1]   # torch.Size([2, 2, 2])
# tensor([[[0.0481, 0.1199],
#          [0.1215, 0.0584]],

#         [[0.4458, 0.0899],
#          [0.3624, 0.7511]]])

In summary, torch.split() is a very handy function: it makes it easy to break a tensor into a tuple of tensors of the shapes and sizes you need for downstream processing.
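
A typical use, sketched below with illustrative names (batch, mini) that are not from the original post: splitting a batch along dim=0 into fixed-size mini-batches and iterating over them.

import torch

batch = torch.rand(10, 8)                              # hypothetical batch of 10 samples with 8 features
for i, mini in enumerate(torch.split(batch, 4, dim=0)):
    print(i, mini.size())                              # torch.Size([4, 8]), torch.Size([4, 8]), torch.Size([2, 8])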

Tips:
Thanks to @qq_42798074 and @qq_41720271 for their corrections.

Copyright notice: This is an original post by the author, released under the CC 4.0 BY-SA license. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/qq_39398917/article/details/131075177
