神经网络学习3【计算流程公式推导+Python代码框架】_神经网络计算流程-程序员宅基地

技术标签: python  神经网络  

上一篇文章 神经网络学习2【分类器+升华至神经元 搭建神经网络】中由浅入深的理解了从分类器开始如何理解神经网络的内部计算。
接下来我们借用数学工具去逐步推导其计算公式,并同时建立出Python代码框架。

1.神经网络计算流程公式推导

1.1两层神经网络计算

尝试使用只有两层、每层两个神经元的较小的神经网络,来演示神经网络如何工作:
在这里插入图片描述

两个输入值分别为1.0和0.5,每个节点使用激活函数。使用一些随机权重:
在这里插入图片描述
第一层仅作输入层,不需要使用激活函数。
第二层中需要做一些计算,以及使用激活函数,当然,在这个简单的网络里,它也做输出层。计算好后就输出。鉴于它处在这张网里,它会得到的是组合的输入。
在这里插入图片描述
与上张图对照,a、b、c即为输入x,权重w。
我们观察到有两个指入的箭头在层2的第一个节点,故在计算其输入的时候,其总和输入值x =(1.0 * 0.9)+(0.5 * 0.3)= 1.05。
计算该节点的输出。答案为y = 1 /(1 + 0.3499)= 1/ 1.3499。因此,y = 0.7408。
权重是神经网络进行学习的内容,这些权重持续进行优化,得到越来越好的结果。
同理,层2的第二个节点的总和输入值x =(1.0 * 0.2)+(0.5 * 0.8)= 0.6。
使用S激活函数y = 1/(1 + 0.5488) = 1/1.5488计算节点输出,得到y = 0.6457。
至此,对于节点的组合输入我们可以总结出一个较为一般的公式:
x = ( 第一个节点的输出 * 链接权重 ) + (第二个节点的输出 * 链接权重 )
在这里插入图片描述

1.2 引入矩阵乘法计算多层神经网络

如果说我们使用3层甚至更多层的神经网络,那么哪怕是使用计算机,这样步骤多的计算也看起来非常的繁琐。但是如果再仔细观察,是否会发现如果我们把它写成不同的形式,这样的计算有点熟悉的感觉?
运用矩阵的点乘:
在这里插入图片描述
第一个矩阵包含两层节点之间的权重。第二个矩阵包含第一层输入层的信号。通过两个矩阵相乘,我们得到的答案是输入到第二层节点组合调节后的信号。
在这里插入图片描述
这不就是我们刚才进行过的计算?却只需要一个式子就可以轻松表达出来。如果说现在它的方便之处感觉还不太明显的话,动手试着搭建一个类似的三层神经网络,它的便利之处就可以大大体现出来了!
X = W •I
W 是权重矩阵,I 是输入矩阵,X 是组合调节后的信号。
更便利的是,如果使用矩阵乘法来帮助我们计算,在程序中,我们只需要用import命令告诉Python,去借助numpy模块进行计算。numpy模块包含了例如数组这样有用的工具以及使用这些工具的进行计算的能力。
最后再使用激活函数即可,并不需要矩阵乘法。
我们所需做的,是对矩阵X 的每个单独元素应用S函数y = 1 /(1 + e^(-x ))。
激活函数只是简单地应用阈值,使反应变得更像是在生物神经元中观察到的行为。因此,来自第二层的最终输出是:O = sigmoid ( X )

1.通过神经网络向前馈送信号所需的大量运算可以表示为矩阵乘法。
2.不管神经网络的规模多大,将输入输出表达为矩阵乘法,使得我们可以更简洁地进行书写。

1.3 使用矩阵乘法的三层神经网络示例

具有3层、每层具有3个节点的神经网络示例:
在这里插入图片描述
第一层为输入层,最后一层为输出层,中间层我们称之为隐藏层。
输入矩阵I 为:
在这里插入图片描述
接下来是中间的隐藏层。需要计算出输入到中间层每个节点的组合(调节)信号。隐藏层输入的链接权重矩阵:
在这里插入图片描述
第二个矩阵W hidden_output:
在这里插入图片描述
继续算出输入到隐藏层的组合调节输入值。
X hidden = W input_hidden • I
在这里插入图片描述
至此,得到:
在这里插入图片描述
对这些节点应用了S激活函数:O hidden = sigmoid( X hidden)
在这里插入图片描述
获得隐藏层的输出:
在这里插入图片描述
继续计算最终层的组合调节输入X = W •I

这一层的输入信号是第二层的输出信号,也就是我们刚刚解出的O hidden 。所使用的权重就是第二层和第三层之间的链接权重W hidden_output
X output = W hidden_output • O hidden
在这里插入图片描述
得到了输出层的输入,更新:
在这里插入图片描述
应用S激活函数获得最终的输出:
在这里插入图片描述
完整的:
在这里插入图片描述
第一步的计算我们已经完成,但别忘记,初始的权重矩阵是随机的,神经网络的学习过程就是去更新权重,获得可靠的权重矩阵。
那么下一步,将神经网络的输出值与训练样本中的输出值进行比较,计算出误差。我们需要使用这个误差值来调整神经网络本身,进而改进神经网络的输出值。

到这里我知道该继续理解误差反馈如何更新权重矩阵。但是由于我更想把误差反馈的公式推导和其具体的代码一块列出来,所以就先去建立python代码框架。

2. Python代码框架建立

使用Python 类和对象,点击学习。

通过前面的学习,能够想象到神经网络至少应该有3个函数:

  • 初始化函数 —— 设定输入层节点、隐藏层节点和输出层节点的数量。
  • 训练 —— 学习给定训练集样本后,优化权重。
  • 查询 —— 给定输入,从输出节点给出答案。

初步建立出代码框架即可:

#neural network class definition
class NeuralNetwork:

    #initialise the neural netmork
    def __init__(self):
        pass

    #train the neural network
    def train(self):
        pass

    #query the neural network
    def query(self):
        pass

接下来我们只需要一步步的,充实框架里的内容即可!

2.1 init()函数

在__init__()函数里,我们要做就是初始化神经网络,在这里设置好输入层的节点、隐藏层的节点和输出层节点的数量。
这些节点数量定义了神经网络的形状和尺寸。建立时要注意,写代码时我们更希望的是去写一个应用性广泛的函数,即普适性比较强的函数。就是说我希望我建立好的这个网络,是可以随意更改大小的,在面对不同情况时,而并不是在遇到不同的环境就需要更改一次代码。

  1. 对于__init__()做出解释,该神经网络是用类来建立的。
    当第一次创建对象时,Python会调用这个名为__init__()的函数,创建和初始化只属于这个对象的变量名。
    如此设置,可以根据所需要的神经网络的不同的尺寸去新建不同的对象,只需更改开头的参数即可。
  2. 不要忘记设置learning rate
  3. 输入函数中的重点 权重——网络的核心
    在这一步中我们将确立网络的节点和链接是网络中最重要的部分,是改进网络时优化的目标。
    注意,权重是神经网络的固有部分,不是一个临时数据集,不会随着调用结束而消失。这意味着,权重必须是初始化的一部分,并且可以使用其他函数来访问。

在输入层与隐藏层之间的链接权重矩阵Winput_hidden ,大小为hidden_nodes 乘以 input_nodes。
在隐藏层和输出层之间的链接权重矩阵Whidden_output ,大小为output_nodes 乘以 hidden_nodes。

注意: 书《Python神经网络编程》中,上述隐藏层和输出层之间的链接权重矩阵Whidden_output 的大小是写反了的。这里我进行了一些细致的推导:
在这里插入图片描述

  1. numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))
    有些人更喜欢稍微复杂的方法来创建初始随机权重。在这里使用的就是正态分布采样权重,其中平均值是0,标准方差为节点传入链接数目的开方,即 (传入链接数目)^(-0.5)
    由于我们需要的是随机矩阵,而不是单个数字,因此采用分布中心值、标准方差和numpy数组的大小作为参数。
 # initialise the neural network
 def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
     # set number of nodes in each inputs,hidden,output layer
     self.inodes = inputnodes
     self.hnodes = hiddennodes
     self.onodes = outputnodes

     self.lr = learningrate

     # link weight matrices, wih and who
     self.wih = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))
     self.who = numpy.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))

     self.activation_function = lambda x: scipy.special.expit(x)
     pass

2.2 query()函数

查询网络。在query()函数接受神经网络的输入,返回神经网络的输出。

  1. 要注意,当信号馈送至给定的隐藏层节点或输出节点时,我们使用链接权重调节信号,还应用S激活函数来一致来自这些节点的信号。
    S函数:self.activation_function = lambda x: scipy.special.expit(x)
    这里使用了Python中,lambda的用法
    定义在初始化函数里,用lambda来创建函数,方便又快捷,这个函数接受了x,返回scipy.special.expit(x)。使用lambda创建的函数是没有名字的,经验丰富的程序员喜欢称他们为匿名函数,但是这里分配给它一个名字self.activation_function()。所以后面需要调用激活函数的时候调用self.activation_function()即可。
    这一段是需要放在上面的__init__()函数里的。
  2. query()函数的输入只需要inputs_list。
  3. 输入是一个列表,写在方括号内。
# query the neural network
  def query(self, inputs_list):
  # convert inputs list to 2d array
     inputs = numpy.array(inputs_list, ndmin=2).T

  # calculate signals into hidden layer
     hidden_inputs = numpy.dot(self.wih, inputs)

     hidden_outputs = self.activation_function(hidden_inputs)

     final_inputs = numpy.dot(self.who, hidden_outputs)

     final_outputs = self.activation_function(final_inputs)
     return final_outputs

需要在代码头部导入:

import numpy
import scipy.special

好啦~还剩下关于误差传递和权重矩阵更新的部分,这部分的计算也应该是放在train()函数里的。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/HUSHLALA/article/details/119702996

智能推荐

ORACLE ERP 月结与年结处理流程_oracle erp期间损益结转在哪-程序员宅基地

文章浏览阅读6.6k次。月结与年结处理流程 月结与年结处理,是企业财务比较特殊而重要的业务操作。在实施与推广OracleERP系统过程中,如何结合现行的会计制度与惯例,充分利用软件功能,做好相应的关账、开账工作,是困扰许多企业财务人员乃至实施顾问的一个热点问题。为此,笔者提出自己一些解决思路,供同仁参考。由于时间紧迫,错漏_oracle erp期间损益结转在哪

编译树莓派4b的openwrt镜像_openwrt 树莓派4b 编译-程序员宅基地

文章浏览阅读585次。openwrt编译(树莓派4b)编译环境搭建(ubuntu18.04)1.安装基础软件sudo apt-get updatesudo apt-get upgradesudo apt-get -y install build-essential asciidoc binutils bzip2 gawk gettext git libncurses5-dev libz-dev patch python3.5 python2.7 unzip zlib1g-dev lib32gcc1 libc6-dev-_openwrt 树莓派4b 编译

SecureCRT乱码,SecureFX中文乱码,SecureFX显示中文乱码,SecureCRT中文乱码,SecureFX乱码_电子保险柜显示secure-程序员宅基地

文章浏览阅读366次。SecureCRT和SecureFx设置中文乱码SecureCRT和SecureFx连接服务器时中文显示乱码,找了好多资料好久都没整出来,后来整出来了,因此把个人的解决办法提供出来,已变帮助更多的人,同时也方便以后自己配置时不至于到处找资料。Linux系统环境:Red Hat Enterprise Linux6 (64位)问题一:SecureCRT和SecureFx会话中创建文件或者v..._电子保险柜显示secure

流利说 Level 4 全文_英语流利说 level4文本-程序员宅基地

文章浏览阅读5.8w次,点赞50次,收藏89次。Level 4Unit 11/4ListeningLesson 1 Visiting a Friend 1-2Lesson2VocabularyLesson 3 Pains and SicknessLesson 4 LandformsDialogueLesson 5 Eating Out2/4ListeningLesson 1 A Trip to..._英语流利说 level4文本

Scrapy - bilibili视频信息爬取,使用scrapy-redis分布式,b站抓取速度约为16核服务器2500万条/天_scrapy b站视频下载-程序员宅基地

文章浏览阅读2.5k次。bilibili_video_stathttps://github.com/Wangler2333/bilibili_video_stat爬取b站视频信息,供大数据分析用户喜好。使用scrapy-redis分布式,在16核服务器上实现抓取2500万条/天。可长期部署抓取,实现视频趋势分析1.提供代理ip池2.提供user agent池3.使用scrapy-redis分布式4.使..._scrapy b站视频下载

软件工程是什么_什么是软件工程?-程序员宅基地

文章浏览阅读1k次。软件工程是什么Software engineers and computer programmers both develop software applications needed by working computers. The difference between the two positions lies in the responsibilities and the approa...

随便推点

超强记忆笔记二_star如何记忆-程序员宅基地

文章浏览阅读303次。人体桩是最实用最方便的桩子,我们之所以在人体上找12个桩子,是因为很多知识点都是12个,大家可以自己尝试用人体桩记忆12个月的英文单词,凡是数量在12个之内的知识点都可以用人体桩来记忆。“谐音法”:将一些抽象的词通过谐音转换为具体的词。对于一些很难具体的词可以采用“潜意识出图法”。一般情况下,我们一是能用谐音的尽可能用谐音法;二是在进行潜意识出图的时候,如果上类似的情况,我们尽可能把图像的_star如何记忆

完美解决java.lang.classNotFoundException:org.apache.jsp.xxx.jsp_org.apache.jsp.metadata.repositorymanage_jsp-程序员宅基地

文章浏览阅读4.9k次,点赞12次,收藏14次。在初学JSTL库中常遇见的错误,我花了一下午时间翻各种网站,最后得到解决一般运行会报这种错误,不要慌,把JSTL库的jar包,放在tomact的lib目录下即可解决然后重启服务器,就好了。_org.apache.jsp.metadata.repositorymanage_jsp

NX/UG二次开发—其他—UG工具调用其他开发工具_ug通过dll工具启动另一个c# dll工具-程序员宅基地

文章浏览阅读1.8k次,点赞3次,收藏7次。 dllPath: 被调用的dll路径entryFunctionName: 工具的入口函数void Function::CallOtherDll(char *dllPath, char *entryFunctionName){ typedef void(*load_ufusr_f_p_t)(char *param, int *retcod,..._ug通过dll工具启动另一个c# dll工具

c3p0死锁-程序员宅基地

文章浏览阅读101次。1.APPARENT DEADLOCK!!! Creating emergency threads for unassigned pending tasks!抛出以下异常信息:com.mchange.v2.async.ThreadPoolAsynchronousRunner$DeadlockDetector@13067b2 -- APPARENT DEADLOCK!!! Creat..._c3p0 apparent deadlock

uva 10020- Minimal coverage (贪心思想 简单区间覆盖)-程序员宅基地

文章浏览阅读1k次。题目大意:给出一个范围M,然后给出若干的区间,以0 0 终止, 要求用最少的区间将0 ~M 覆盖,输出最少个数以及方案。解题思路:典型的区间覆盖问题,算法竞赛入门经典P154上有讲。/*author: charkj_z *//*time: 0.108s *//*rank: 674 *//*为什么不把没用的地方去掉? 因为去掉了我觉得不像我能写出来的*//*Ac code :_minimal coverage

nand读_cat.qen.dad.nand.ten.leg.red.btg.pig.six.milk.怎么读-程序员宅基地

文章浏览阅读478次。串口通讯分同步通讯和异步通讯,通常使用的都是异步串口,通讯时双方约好波特率、数据位、停止位、奇偶校验位等常用的波特率38400、115200起始位:空闲时,电平为高,检测到下降沿,则视为起始位,然后接收一帧数据通常使用RS232的9针串口,其中最为重要的是2、3、5脚2 :RXD接收数据3 :TXD发送数据5 :GND接地——————/2440引脚配置——设置数据格_cat.qen.dad.nand.ten.leg.red.btg.pig.six.milk.怎么读