python cnn 实例_python实现简单的卷积神经网络CNN案例1:定义CNN网络结构_weixin_39883374的博客-程序员秘密

技术标签: python cnn 实例  

本案例中定义的CNN网络模型如下:

cnn.py文件中的__init__()函数主要作用是对卷积神经网络的参数w1,w2,w3等进行初始化,下面是该函数的代码:

def __init__(self, input_dim=(3, 32, 32), num_filters=32, filter_size=7,

hidden_dim=100, num_classes=10, weight_scale=1e-3, reg=0.0,

dtype=np.float32):

self.params = {}

self.reg = reg

self.dtype = dtype

# Initialize weights and biases

C, H, W = input_dim

self.params['W1'] = weight_scale * np.random.randn(num_filters, C, filter_size, filter_size)

self.params['b1'] = np.zeros(num_filters)

self.params['W2'] = weight_scale * np.random.randn(int(num_filters*H*W/4), hidden_dim)

self.params['b2'] = np.zeros(hidden_dim)

self.params['W3'] = weight_scale * np.random.randn(hidden_dim, num_classes)

self.params['b3'] = np.zeros(num_classes)

for k, v in self.params.items():

self.params[k] = v.astype(dtype)

无论是caffe还是tensorflow类型都会把数据类型转为float32类型,所以__init__()函数最后一个参数定义为dtype=np.float32。

cnn.py文件中还有另外一个函数 loss()函数,这两个函数起到了定义CNN卷积网络结构的作用。下面是loss()函数的代码:

def loss(self, X, y=None):

W1, b1 = self.params['W1'], self.params['b1']

W2, b2 = self.params['W2'], self.params['b2']

W3, b3 = self.params['W3'], self.params['b3']

# pass conv_param to the forward pass for the convolutional layer

filter_size = W1.shape[2]

conv_param = {'stride': 1, 'pad': (int)((filter_size - 1) / 2)}

# pass pool_param to the forward pass for the max-pooling layer

pool_param = {'pool_height': 2, 'pool_width': 2, 'stride': 2}

# compute the forward pass

a1, cache1 = conv_relu_pool_forward(X, W1, b1, conv_param, pool_param)

a2, cache2 = affine_relu_forward(a1, W2, b2)

scores, cache3 = affine_forward(a2, W3, b3)

if y is None:

return scores

# compute the backward pass

data_loss, dscores = softmax_loss(scores, y)

da2, dW3, db3 = affine_backward(dscores, cache3)

da1, dW2, db2 = affine_relu_backward(da2, cache2)

dX, dW1, db1 = conv_relu_pool_backward(da1, cache1)

# Add regularization

dW1 += self.reg * W1

dW2 += self.reg * W2

dW3 += self.reg * W3

reg_loss = 0.5 * self.reg * sum(np.sum(W * W) for W in [W1, W2, W3])

loss = data_loss + reg_loss

grads = {'W1': dW1, 'b1': db1, 'W2': dW2, 'b2': db2, 'W3': dW3, 'b3': db3}

return loss, grads

1.参数设置

从loss()函数我们可以看到卷积层的参数为: conv_param = {'stride': 1, 'pad': (int)((filter_size - 1) / 2)}

filter_size在_init_()函数中已经定义为7,所以pad值等于3,根据公式1:

conv卷积层的输出高度和宽度为(32-7+2*3)/1+1=32,所以conv卷积层的输出为32*32*32,32分别代表过滤器个数、高度和宽度。

loss()函数中pool池化层的参数定义为: pool_param = {'pool_height': 2, 'pool_width': 2, 'stride': 2}

同理根据公式1计算池化层输出的高度和宽度为:(32-2+2*0)/2+1=16,所以pool池化层的输出为:32*16*16,也就是说经过池化层之后高度和宽度分别减半了。

所以_init()函数中初始化池化层与FC全连接层间的参数w2时,定义为: self.params['W2'] = weight_scale * np.random.randn(int(num_filters*H*W/4), hidden_dim),num_filters*H*W/4就是pool池化层的输出为:32*16*16。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_39883374/article/details/110885335

智能推荐

足球——2011-2012西甲球队队标_weixin_34050427的博客-程序员秘密

皇家马德里 巴塞罗那 瓦伦西亚 马拉加 莱万特 毕尔巴鄂竞技 马德里竞技 奥萨苏纳 塞维利亚 西班牙人 赫塔菲 马洛卡 皇家贝蒂斯 皇家社会 巴列卡诺 格拉纳达CF 比利亚雷亚尔 希洪竞技 萨拉戈萨 桑坦德竞技 转载于:https://www.cnblog...

ws2812 程序设计与应用(1)DMA 控制 PWM 占空比原理及实现(STM32)_stm32 dma 中断 pwm_Yonas-Luo的博客-程序员秘密

本文开发环境:MCU型号:STM32F10389T6IDE环境: MDK 5.27代码生成工具:STM32CubeMx 5.6.1HAL库版本:STM32Cube_FW_F1_V1.8.0本文内容:STM32 使用 DMA+PWM 方式驱动 ws2812(x4)附件:MDK5 示例工程WS2812 中英文数据手册文章目录一、WS2812 简介时序传输二、ws2812 驱动的几种方式三、DMA+PWM+TIM 驱动 ws2812四、STM32CubeMx 配置基础.

Python判断文件和字符串编码类型_python 判别字符串解码类型库_浅醉樱花雨的博客-程序员秘密

python判断文件和字符串编码类型可以用chardet和cchardet工具包,可以识别大多数的编码类型。使用示例:import chardetwith open("test.txt", "rb") as f: msg = f.read() result = chardet.detect(msg) print(result['encoding'])charde...

使用dbms_profiler 调试PL/SQL 性能_pl/sql 怎么用dbms_profiler_文档搬运工的博客-程序员秘密

目的: 调查PL/SQL(8.1级以上)的性能参考文档:Performance of New PL/SQL Features (Doc ID 104377.1)Using DBMS_PROFILER (Doc ID 97270.1)https://docs.oracle.com/cd/E11882_01/appdev.112/e40758/d_profil.htm#ARPLS67481-- 安装在sys账号下执行$ORACLE_HOME/rdbms/admin/profload.sql...

多个haproxy 之间跳转_weixin_30527323的博客-程序员秘密

C:\>ping wechatTest.winfae.com正在 Ping wechatTest.winfae.com [120.55.118.6] 具有 32 字节的数据:来自 120.55.118.6 的回复: 字节=32 时间=5ms TTL=54来自 120.55.118.6 的回复: 字节=32 时间=4ms TTL=54指向120.55.118.6的hapr...

Python---对象_ennuoo的博客-程序员秘密

Python  对象一切事物皆对象  对象创建基于类以上面这种方式创建太复杂Python以这种方式dir(list)看list里面的功能/方法有哪些;help(list) 详细;help(list.append)

随便推点

JavaScript去掉双引号_zhuanghw全栈工程师的博客-程序员秘密

var reg = new RegExp('"',"g");str = str.replace(reg, "");   运行前:   字符串str内容如下:      "小明","20"     "小东","30" 运行后:   字符串str内容如下:       小明,20       小东,30...

【干货分享】前端面试知识点锦集01(HTML篇)——附答案_weixin_30760895的博客-程序员秘密

一、HTML部分1、浏览器页面有哪三层构成,分别是什么,作用是什么?构成:结构层、表示层、行为层分别是:HTML、CSS、JavaScript作用:HTML实现页面结构,CSS完成页面的表现与风格,JavaScript实现一些客户端的功能与业务。2、HTML5的优点与缺点?优点:a、网络标准统一、HTML5本身是由W3C推荐出来的。b、多设备、跨平台c、即时更新。d、提高可用性和改...

详细介绍Linux2.6设备的驱动模型[转]_batoom的博客-程序员秘密

<br />linux 2.6 设备驱动模型 1.背景 随着设备拓扑结构越来越复杂,需要为内核建立一个统一的设备模型,对系统结构做一般性的抽象描述。 有了该抽象结构,可支持多种不同的任务: a) 电源管理  完成电源管理工作需要对系统结构的理解,且有严格的顺序,如:一个USB宿主适配器,在处理完所有与其相连接的设备面前是不能关闭的; b) 与用户空间通信  由/sysfs虚拟文件系统展示设备的属性 c) 热插拔设备 d) 对象生命周期 2.sysfs 虚拟文件系统 sysfs 是一个特殊的文件系统,类似于/

超适合3D建模初学者的学习游戏建模小技巧!想成为次世代大神必须要知道的这些小技巧_Grape_3DModeler的博客-程序员秘密

今天来为大家分享适合建模初学者学习的几个小技巧。想要学习一项技能,方向和方法对了才可以事半功倍。制定合理地目标想要快速学习建模,一定要给自己制定一个合理地目标,不然你的状况可能就是“今天没空先不学了”、“明天有个约会,要不先休息一下吧”。那么如何制定一个合理地目标呢?1制定短期地目标每个人可能都有长期的大目标,但是短期地目标也是至关重要的。只有完成一个个短期目标,才可以向着长期目标迈进。2不要过于着急很多人做事三分钟热度,刚开始可能有比较强大的驱动力,但是物极必反,太

程序员面试金典 - 面试题 05.08. 绘制直线(位运算)_Michael阿明的博客-程序员秘密

1. 题目绘制直线。有个单色屏幕存储在一个一维数组中,使得32个连续像素可以存放在一个 int 里。屏幕宽度为w,且w可被32整除(即一个 int 不会分布在两行上),屏幕高度可由数组长度及屏幕宽度推算得出。请实现一个函数,绘制从点(x1, y)到点(x2, y)的水平线。给出数组的长度 length,宽度 w(以比特为单位)、直线开始位置 x1(比特为单位)、直线结束位置 x2(比特为单...

【19调剂】南华大学特聘教授团队招收数学物理计算机方向的调剂生_计算机与软件考研的博客-程序员秘密

点击文末的阅读原文或者公众号界面左下角的调剂信息或者公众号回复“调剂”是计算机/软件等专业的所有调剂信息集合,会一直更新的。要求:人品好(一诺千金)、热爱科研。数学竞赛获奖者、或喜欢编程...

推荐文章

热门文章

相关标签