pytorch 入门(1): pytorch库基本用法举例_知者智者的博客-程序员宅基地

技术标签: 机器学习-深度学习  深度学习  pytorch  Python  

这一系列文章是对pytorch 入门教程的翻译和学习总结。英文原版可以从以下网址获得:

https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html

 

目标:

  • 理解Pytorch Tensor 库和神经网络
  • 训练一个小的神经网络来分类图片。

本手册假定你对numpy库有一个基本的了解。

 

注意:

确保你的测试环境已经安装了torch 和 torchvision 包

 

Pytorch是什么?

Pytoch是一个基于Python的科学计算的包,目标使用者有两类:

  • 替换NumPy以达到使用GPU算力的目的
  • 一个深度学习的研究平台,能够提供最大的灵活性和速度

1 Tensors(张量)

Tensors 类似于 NumPy库中的ndarrays, 而且,Tensor还能够使用GPU做计算加速。

使用torch,首先import torch包,如下

 

$ python3

from __future__ import print_function

import torch

 

1.1 创建一个未初始化的 5x3 矩阵:

x = torch.empty(5, 3)

print(x)

输出:

tensor([[1.9160e-11, 0.0000e+00, 8.5305e+02],
        [1.0894e+27, 8.9683e-44, 0.0000e+00],
        [4.4842e-44, 0.0000e+00, 1.9160e-11],
        [0.0000e+00, 1.7857e+05, 3.9586e+12],
        [1.3452e-43, 0.0000e+00, 4.4842e-44]])

注意:

创建一个未初始化的矩阵后,矩阵中的值是不确定的, 使用分配内存时,内存中的值作为初始值。

 

1.2 创建一个随机初始化的矩阵

x = torch.rand(5, 3)

print(x)

输出:

tensor([[1.9160e-11, 0.0000e+00, 8.5305e+02],
        [1.0894e+27, 8.9683e-44, 0.0000e+00],
        [4.4842e-44, 0.0000e+00, 1.9160e-11],
        [0.0000e+00, 1.7857e+05, 3.9586e+12],
        [1.3452e-43, 0.0000e+00, 4.4842e-44]])

创建的矩阵是经过初始化的,初始化为随机值。

 

1.3 创建一个矩阵,初始化为0并且数据类型为long

x = torch.zeros(5, 3, dtype=torch.long)

print(x)

输出:

tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]])

 

创建一个张量并赋值

x = torch.tensor([[2.2, 1.2], [2.1, 1.5], [4.5, 3.1]])

print(x)

输出:

tensor([[2.2000, 1.2000], [2.1000, 1.5000], [4.5000, 3.1000]])

 

1.4 基于已有的张量创建新张量

使用new_ones时,新的tensor会继承原来tensor的属性(例如tdype,device等),除非用户指定了新属性。

x = x.new_ones(5, 3, dtype=torch.double) # new_* methods take in sizes

print(x)

x = torch.randn_like(x, dtype=torch.float) # override dtype!

rint(x)

输出:

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], dtype=torch.float64)

tensor([[-0.4050, -0.8701, -0.3150],
        [ 0.0908, -0.3928,  0.0118],
        [ 1.0196,  0.0511,  1.0521],
        [ 0.3389, -0.1678, -0.3757],
        [-0.2501,  0.4394,  0.2627]])

 

获取张量的大小:

print(x.size())

输出:

torch.Size([5, 3])

torch.Size()是一个元组,支持所有的元组操作。

 

1.5 Operations 运算

torch中有很多运算语法,这里看一下加法运算

x = torch.tensor([1.2, 2.0])

y = torch.tensor([2.1, 3.0])

print(x + y)

print(torch.add(x,y))

 

输出:

>>> print(x + y) tensor([3.3000, 5.0000])

>>> print(torch.add(x,y)) tensor([3.3000, 5.0000])

 

result = torch.empty(1, 2)

torch.add(x, y, out=result)

print(result)

输出:

tensor([3.3000, 5.0000])

 

在变量上加:

y.add(x)

print(y)

y.add_(x)

print(y)

输出:

>>> y.add(x) tensor([3.3000, 5.0000])

>>> print(y) tensor([2.1000, 3.0000])

>>> y.add_(x) tensor([3.3000, 5.0000])

>>> print(y) tensor([3.3000, 5.0000])

注意: 加上“_”的操作会改变变量自己的值,不加"_"的操作,不改变变量的值。

例如x.copy_(y), x.t_(), 将会改变 x的值

 

1.6 改变tensor的大小, view()

x = torch.randn(4, 4)

y = x.view(16)

z = x.view(-1, 8) # the size -1 is inferred from other dimensions

print(x.size(), y.size(), z.size())

输出:

>>> x = torch.randn(4,4)

>>> print(x) tensor([[-1.8123, 0.4619, -1.0568, -0.3072], [-0.4922, -0.7467, -0.6142, 0.7062], [-0.3275, 0.3135, 0.1623, 0.2957], [ 0.3594, 0.6116, 0.7314, -0.2364]])

>>> y = x.view(16)

>>> print(y)

tensor([-1.8123, 0.4619, -1.0568, -0.3072, -0.4922, -0.7467, -0.6142, 0.7062, -0.3275, 0.3135, 0.1623, 0.2957, 0.3594, 0.6116, 0.7314, -0.2364])

>>> z = x.view(-1, 8)

>>> print(z)

tensor([[-1.8123, 0.4619, -1.0568, -0.3072, -0.4922, -0.7467, -0.6142, 0.7062], [-0.3275, 0.3135, 0.1623, 0.2957, 0.3594, 0.6116, 0.7314, -0.2364]])

>>> print(x.size(), y.size(), z.size())

torch.Size([4, 4]) torch.Size([16]) torch.Size([2, 8])

 

1.7 获取tensor元素的值 .item()

x = torch.randn(1)

print(x)

print(x.item())

x = torch.randn(2,2)

print(x[1,1].item())

输出:

>>> x = torch.randn(1)

>>> print(x)

tensor([0.0017])

>>> print(x.item())

0.0016834917478263378

>>> x = torch.randn(2,2)

>>> print(x[1,1].item())

0.9299264550209045


2 NumPy Bridge

可以将一个Torch Tensor转为NumPy array, 也可以将一个NumPy array转为Torch Tensor, 非常方便。

转换之后,Torch Tensor 和 NumPy array底层存储仍然共用相同的内存地址(Tensor在CPU上的情况下), 所以,改变一个值,另一个也会相应改变。

 

2.1 Torch Tensor 转为NumPy Array

a = torch.ones(5)

print(a)

b = a.numpy()

print(b)

输出:

>>> a = torch.ones(5)

>>> print(a)

tensor([1., 1., 1., 1., 1.])

>>> b = a.numpy()

>>> print(b)

[1. 1. 1. 1. 1.]

 

a的值改变后,b的值也会变:

a.add_(1)

print(a)

print(b)

输出:

>>> a.add_(1)

tensor([2., 2., 2., 2., 2.])

>>> print(b)

[2. 2. 2. 2. 2.]

 

2.2 NumPy Array 转为 Torch Tensor

import numpy as np

a = np.ones(5)

b = torch.from_numpy(a)

np.add(a, 1, out=a)

print(a)

print(b)

 

输出:

>>> import numpy as np

>>> a = np.ones(5)

>>> b = torch.from_numpy(a)

>>> np.add(a,1,out=a)

array([2., 2., 2., 2., 2.])

>>> print(b)

tensor([2., 2., 2., 2., 2.], dtype=torch.float64)

 

除了CharTensor之外,CPU上的所有tensors都支持与numpy间的相互转换。

 

3 CUDA Tensors

Tensors 可以使用.to() 方法移动到任何device上。

# let us run this cell only if CUDA is available
# We will use ``torch.device`` objects to move tensors in and out of GPU
if torch.cuda.is_available():
    device = torch.device("cuda")          # a CUDA device object
    y = torch.ones_like(x, device=device)  # directly create a tensor on GPU
    x = x.to(device)                       # or just use strings ``.to("cuda")``
    z = x + y
    print(z)
    print(z.to("cpu", torch.double))       # ``.to`` can also change dtype together!

 

输出:

>>> torch.cuda.is_available()

True

>>> device = torch.device("cuda")

>>> y = torch.ones_like(x, device=device)

>>> print(x)

tensor([[ 0.2147, -1.1994], [-0.6267, 0.9299]])

>>> print(y)

tensor([[1., 1.], [1., 1.]], device='cuda:0')

>>> x = x.to(device)

>>> print(x)

tensor([[ 0.2147, -1.1994], [-0.6267, 0.9299]], device='cuda:0')

>>> z = x+y

>>> print(z) tensor([[ 1.2147, -0.1994], [ 0.3733, 1.9299]], device='cuda:0')

>>> print(z.to("cpu", torch.double))

tensor([[ 1.2147, -0.1994], [ 0.3733, 1.9299]], dtype=torch.float64)

完整的用法,见:

https://pytorch.org/docs/stable/torch.html

 

 

 

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/lclfans1983/article/details/108372105

智能推荐

java中Keytool的使用总结-程序员宅基地

以前用过几次这个东东,但每次都重新查询一次。本文原始出处是这里 。-----------------------------------------------------------Keytool 是一个Java 数据证书的管理工具 ,Keytool 将密钥(key)和证书(certificates)存在一个称为keystore的文件中在keystore里,包含两种数据: 密钥实体(Ke

uniapp自定义组件父子组件props传递对象数据时,当对象中包含函数,子组件无法引用到对象中的函数的解决办法_@ther的博客-程序员宅基地

uniapp自定义组件父子组件props传递对象数据时,当对象中包含函数,子组件无法引用到对象中的函数的解决办法出现这种情况 是因为uniapp 在传递数据的时候使用的是JSON.parse(JSON.stringify(obj1))这样传递的 无法传递函数。具体参考[https://blog.csdn.net/py_boy/article/details/107089150]解决办法是重写挂载在Vue原型对象上的__patch__方法 如下import {myPatch} from "./ext

关于在maven中导入依赖失败的问题-程序员宅基地

关于在maven中导入依赖失败的问题方法1:找到maven本地仓库,删掉引入失败的jar包,再重新导入。如果失败进入方法2方法2:打开idea右侧的maven>>打开项目>>Lifecycle>>分别点击clean、install如果这都不行的话,进入方法3.方法3:先执行方法1 ,然后打开https://maven.aliyun.com/mvn/search直接下载你需要的依赖,然后存放到你的本地仓库4、如果这都不行的话,建议你删除该项目,重新建

PyQt5/PySide2 ‘module‘ object has no attribute ‘QStringListModel‘-程序员宅基地

问题描述:在某些版本的PyQt5/PySide2中使用 QtGui.QStringListModel 会出现模块不存在的错误。原因分析:较近版本的 PyQt5/PySide2 把 QStringListModel 放到了 QtCore下方。使用QtCore.QStringListModel就可以解决。事实上在Qt中 QStringListModel 一直是在QtCore的下方,因为同属于模型,逻辑比较一致。Git的讨论...

NoClassDefFoundError Could not initialize class com.fasterxml.jackson.databind.ObjectMapper-程序员宅基地

应用启动异常:Factory method 'jacksonObjectMapper' threw exception; nested exception is java.lang.NoClassDefFoundError: Could not initialize class com.fasterxml.jackson.databind.ObjectMapperll is not a number, throw NumberFormatException.Stopping available co

客户端/服务器模式下,pvpython操作完数据后,paraview客户端未响应_paraview闪退-程序员宅基地

Paraview client crashes while loading dataParaview客户端在加载数据时崩溃paraview|high performance computingRunning ParaView in Client-Server ModeThe SM Tools are a set of tools developed to create and edit streaming movie files (sm). SM工具是一组用于创建和编辑流媒体电影文件(SM)的工_paraview闪退

随便推点

每天一篇论文302/365 A General and Adaptive Robust Loss Function-程序员宅基地

A General and Adaptive Robust Loss Function摘要给出了Cauchy/Lorentzian,Geman-mccluer,Welsch/Leclerc,广义Charbonnier,Charbonnier/pseudo-Huber/L1-L2和L2损失函数的一个推广。通过引入鲁棒性作为一个连续参数,我们的损失函数允许基于鲁棒损失最小化的算法被推广,从而提高了..._a general and adaptive robust loss function

java.lang.IllegalArgumentException: The document is really a OOXML file (解决)_Mr_going的博客-程序员宅基地

异常原因:java.lang.IllegalArgumentException: The document is really a OOXML file因为项目需要用到在线预览word功能,所以在网上找了代码,运行后却报错了。如下:以上实现代码摘自CSDN 点我,载你过去。。。错误原因这是由于上述代码中的 HWPFDocument 对象只能读取2003之前版本的word, 也就是说仅支持 .doc格式。如果你想读取2007之后版本的word (.docx格式),则需要使用 XWPFD

php 管理员界面源码,ThinkPHP通用后台管理系统TP-Admin-程序员宅基地

ThinkPHP通用后台管理系统TP-Admin以Thinkphp为底层框架,融合PHPCMS思想进行开发建立的一个大型的系统后台;以模块化开发方式做为功能开发形式;框架易于功能扩展,代码维护,优秀的二次开发能力,可满足所有网站的应用需求。主要功能介绍全新框架采用全球认可的最为先进的开放理念——OOP(面向对象),进行全新框架设计。框架结构更为清晰,代码更易于维护。模块化做为功能的开发形式,让扩 ..._thinkphp后台管理系统

LG·烦恼的高考志愿【二分】_现有 所学校,每所学校预计分数线是 。有 位学生,估 分分别为 。-程序员宅基地

luogu P1678 烦恼的高考志愿Description--Input--Output--Sample Input--Sample Output--说明--解题思路--代码--Description–背景计算机竞赛小组的神牛V神终于结束了万恶的高考,然而作为班长的他还不能闲下来,班主任老t给了他一个艰巨的任务:帮同学找出最合理的大学填报方案。可是v神太忙了,身后还有一群小姑娘等着和他约会..._现有 所学校,每所学校预计分数线是 。有 位学生,估 分分别为 。

True Impostor技术原理总结与实践-程序员宅基地

True Impostor技术基本介绍Impostor,原意“伪装者”,是一种使用极简单的mesh来模拟真实mesh模型的一种优化技术,可以高效的在场景中绘制大量同类的模型而不需要绘制大量的多边形。Impostor技术是介于Billboard和mesh之间的一种模型,在节省顶点的同时实现模型全角度细节的展示。Impostor将2维纹理映射到一张矩形mesh上,模拟高精度模型的假象,实现方法可以..._impostor技术

数据结构——二叉排序树_二叉排序树的结构体-程序员宅基地

(1)二叉排序树,又称二叉查找树。一棵二叉树或是空二叉树,或是具有如下性质的二叉树:①左子树上所有结点的关键字均小于根结点的关键字;②右子树上所有结点的关键字均大于根结点的关键字;③左子树和右子树又各是一棵二叉排序树。(2)由于左子树结点值_二叉排序树的结构体

推荐文章

热门文章

相关标签