从tflearn Example中学习CNN(1)-程序员宅基地

技术标签: tflearn  cnn  tensorflow  深度学习  

     本人水平有限,难免会有错误,如果发现希望可以及时指出来,利人利己。哈哈哈~

     这是博客写的第一篇文章,主要想从tflearn的例子代码一步步理解CNN模型。这里插一句话,tflearn是tensorflow接口的更高层次的封装,与keras的区别时debug时可以看到源码,并且tflearn代码写的非常工整,适合我这样的菜鸟学习。

    现在深度学习异常火热,如果不会点深度学习,出门都不好意思和人家打招呼。

    这篇博客主要讲解下tflearn例子里的examples/images/convnet_mnist.py对于每个函数中涉及的参数我会一一的给出中文说明,下一篇主要讲解每个参数在CNN中的含义,以及系统讲解下CNN的构建过程。废话不多说,先看代码。


  

from __future__ import division, print_function, absolute_import

import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.normalization import local_response_normalization
from tflearn.layers.estimator import regression
#加载大名顶顶的mnist数据集(http://yann.lecun.com/exdb/mnist/)
import tflearn.datasets.mnist as mnist
X, Y, testX, testY = mnist.load_data(one_hot=True)
X = X.reshape([-1, 28, 28, 1])
testX = testX.reshape([-1, 28, 28, 1])

network = input_data(shape=[None, 28, 28, 1], name='input')
# CNN中的卷积操作,下面会有详细解释
network = conv_2d(network, 32, 3, activation='relu', regularizer="L2")
# 最大池化操作
network = max_pool_2d(network, 2)
# 局部响应归一化操作
network = local_response_normalization(network)
network = conv_2d(network, 64, 3, activation='relu', regularizer="L2")
network = max_pool_2d(network, 2)
network = local_response_normalization(network)
# 全连接操作
network = fully_connected(network, 128, activation='tanh')
# dropout操作
network = dropout(network, 0.8)
network = fully_connected(network, 256, activation='tanh')
network = dropout(network, 0.8)
network = fully_connected(network, 10, activation='softmax')
# 回归操作
network = regression(network, optimizer='adam', learning_rate=0.01,
                     loss='categorical_crossentropy', name='target')

# Training
# DNN操作,构建深度神经网络
model = tflearn.DNN(network, tensorboard_verbose=0)
model.fit({'input': X}, {'target': Y}, n_epoch=20,
           validation_set=({'input': testX}, {'target': testY}),
           snapshot_step=100, show_metric=True, run_id='convnet_mnist')


关于conv_2d函数,在源码里是可以看到总共有14个参数,分别如下:

1.incoming: 输入的张量,形式是[batch, height, width, in_channels]
2.nb_filter: filter的个数
3.filter_size: filter的尺寸,是int类型
4.strides: 卷积操作的步长,默认是[1,1,1,1]
5.padding: padding操作时标志位,"same"或者"valid",默认是“same”
6.activation: 激活函数(ps:这里需要了解的知识很多,会单独讲)
7.bias: bool量,如果True,就是使用bias
8.weights_init: 权重的初始化
9.bias_init: bias的初始化,默认是0,比如众所周知的线性函数y=wx+b,其中的w就相当于weights,b就是bias
10.regularizer: 正则项(这里需要讲解的东西非常多,会单独讲)
11.weight_decay: 权重下降的学习率
12.trainable: bool量,是否可以被训练
13.restore: bool量,训练的模型是否被保存
14.name: 卷积层的名称,默认是"Conv2D"

关于max_pool_2d函数,在源码里有5个参数,分别如下:
1.incoming ,类似于conv_2d里的incoming
2.kernel_size:池化时核的大小,相当于conv_2d时的filter的尺寸
3.strides:类似于conv_2d里的strides
4.padding:同上
5.name:同上

看了这么多参数,好像有些迷糊,我先用一张图解释下每个参数的意义。


其中的filter就是
[1 0 1 
 0 1 0
 1 0 1],size=3,由于每次移动filter都是一个格子,所以strides=1.

关于最大池化可以看看下面这张图,这里面 strides=1,kernel_size =2(就是每个颜色块的大小),图中示意的最大池化(可以提取出显著信息,比如在进行文本分析时可以提取一句话里的关键字,以及图像处理中显著颜色,纹理等),关于池化这里多说一句,有时需要平均池化,有时需要最小池化。

下面说说其中的padding操作,做图像处理的人对于这个操作应该不会陌生,说白了,就是填充。比如你对图像做卷积操作,比如你用的3×3的卷积核,在进行边上操作时,会发现卷积核已经超过原图像,这时需要把原图像进行扩大,扩大出来的就是填充,基本都填充0。
一下关于padding的操作转自:http://www.jianshu.com/p/05c4f1621c7e

1.输入W×W的矩阵,(这里讨论长宽相等情况,不相等的话,推导方法有区别),现在想象一下脑子里有一副W*W的图像

2.假定filter的大小是F×F,卷积核

3.步长stride为S

4.输出的宽高为new_w,new_h

上面已经提到padding总共有两种方式,same,valid

当取valid

            new_weight=new_height=(W-F+1)/S(结果向上取整),

此时输出的矩阵大小比输入时小(这里不讨论F=1时的情况,说到1*1的卷积核,大家可以看看GoogLeNet模型,其中用到1*1卷积核,这个用来降维的,tflearn代码里有GoogLeNet的复现)

当取same时,

            new_height = new_weight= W/S(结果向上取整)

在高度上需要pad的像素数为

            pad_needed_height = (new_height – 1)  × S + F - W

根据上式,输入矩阵上方添加的像素数为

            pad_top = pad_needed_height / 2  (结果取整)

下方添加的像素数为

pad_down = pad_needed_height - pad_top

以此类推,在宽度上需要pad的像素数和左右分别添加的像素数为

pad_needed_width = (new_width – 1)  × S + F - W

pad_left = pad_needed_width  / 2 (结果取整)

pad_right = pad_needed_width – pad_left

下面看图示以及计算过程:



下一篇将详细介绍下激活函数以及正则项这两个大部头。

参考文献:
1.https://ujjwalkarn.me/2016/08/11/intuitive-explanation-convnets/
2.https://github.com/tflearn/tflearn/blob/master/examples/images/convnet_mnist.py
3.http://www.jianshu.com/p/05c4f1621c7e



版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/BugCreater/article/details/53293075

智能推荐

Android Apk反编译 dex2j遇到如下问题 com.googlecode.d2j.DexException: not support version.-程序员宅基地

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。 ...

记一次导入4G大小的SQL文件到MySQL数据库_sql 4g文件-程序员宅基地

Win10操作系统 MySQL版本一个为8.0另一个为5.7通过命令行直接进入MySQL8.0会报错:可能是因为系统分辨不出要进入哪一个MySQL(警告为在命令行里直接输入MySQL密码不安全,使用mysql -uroot -p敲回车再输入密码更安全。)正确打开的方法为进入安装MySQL文件夹下的bin目录,再敲上述代码:先cd一下MySQL8.0的安装目录再使用进入MySQL的命令,最好加上端口号(此处MySQL57版本的端口号是13306,而MySQL8版本的端口是3306):检查数_sql 4g文件

html5 拖拽上传文件时,屏蔽浏览器默认打开文件-程序员宅基地

  我们在使用html5的拖拽上传时,做法往往是监听一个控件范围内的drop事件。但是用户在操作的时候往往会出现文件没有进入到控件范围内就释放的情况,这种情况在以下浏览器中会出现不同的情况,下面是实验结果:    chrome: 如果该文件是浏览器可浏览文件(图片等),浏览器会在当前窗口打开文件的预览;如果是不可浏览文件,则会触发浏览器的下载    fireFox:如果该文件是..._拖拽上传为什么要阻止默认事件

在jsp页面使用富文本编辑器_jsp页面文本编辑器-程序员宅基地

#在jsp页面使用富文本编辑器工具:MyEclipse(1)下载ueditor下载地址:http://ueditor.baidu.com/website/download.html(2)解压缩文件,并把文件夹名改为ueditorueditor\jsp\lib路径下有commons-codec-1.9.jar、commons-fileupload-1.3.1.jar、commons-i..._jsp页面文本编辑器

密码学专题 非对称加密算法指令概述 DSA算法指令_dsa加密算法-程序员宅基地

DSA算法和DSA指令概述DSA算法是美国国家标准的数字签名算法,只具备数字签名的功能不具备密钥交换的功能 生成DSA参数然后生成DSA密钥,DSA参数决定了DSA密钥的长度 三个指令 首先是dsaparam指令,该指令主要用来生成DSA密钥参数,并提供了一些格式转换、C代码生成等其他类似于dhparam指令的功能。一组DSA参数可以用来生成多个不同的DSA密钥,而不是仅仅对应于一个DSA密钥。 gendsa指令用来从现有的DSA参数中生成DSA密钥,使用相同的DSA参数可以生成不同的DSA密钥_dsa加密算法

unity动态加载Resources并且实例化_unity resouce.load 后需要实例化吗-程序员宅基地

直接附上代码,在resources文件家中存在预制物体就可以GameObject Prefab = (GameObject)Resources.Load("Prefabs/task1"); Prefab = Instantiate(Prefab); Prefab.transform.parent = parent; Prefab.transform.position = Vector3.zero; Prefab.transform.localScale = Vector3.one;将预制物体放在需要位置,_unity resouce.load 后需要实例化吗

随便推点

Android 10 (Android Q)中的屏幕刷新率(display refresh rate)切换方法和策略_peak_refresh_rate-程序员宅基地

本文禁止转载,如有需求,请联系作者。什么是屏幕刷新率,什么是应用显示帧率。如何修改LCD的刷新率。Android的显示刷新率切换策略。_peak_refresh_rate

Egret项目中使用protobuf_egret protobuf-程序员宅基地

如何安装npm install [email protected] -gnpm install @egret/protobuf -g如何使用# 假设用户有个名为 egret-project 的白鹭项目cd egret-project# 将代码和项目结构拷贝至白鹭项目中pb-egret add# 将 protofile 文件放在 egret-project/protobuf/protofile 文件夹中pb-egret generate# 文件将会生成到 protobuf/bundles 文_egret protobuf

TCP server Socket编程 VC++6.0-程序员宅基地

研二 wifi嗅探项目 第一阶段 数据提取与分析#include #include #pragma comment(lib,"ws2_32.lib") int main(int argc, char* argv[]) { //一、WSAStartup函数初始化Winsock WORD sockVersion = MAKE

修改Jtable字体颜色-程序员宅基地

今天写一篇关于JTable有关的文章,包括:为JTable单元格设置字体颜色、为JTable单元格设置背景色、让JTable某一列设置为不可能编辑。代码很简单,请朋友们参考如下: package Java; import javax.swing.JFrame; import javax.swing.JTable; import javax.swi

video2frames - 把视频切割成一帧帧的图片-程序员宅基地

def video2frames(video_path, newdir): cap = cv2.VideoCapture(video) count = 0 while (cap.isOpened()): ret, frame = cap.read() if True: if ret == True: # cv2.imshow('video2frames', frame) ...

关于linux内核中结构体初始化的新写法-程序员宅基地

查看linux源代码,经常会被其中一些程序的写法所迷惑,此种初始化写法并不是什么特殊的代码风格,而是所谓的C语言标记化结构初始化语法(designated initializer),而且还是一个ISO标准,C99注意:适用于GCC编译器,GCC能完美支持C99,VC2005支持C89,还不支持C99,只有能完美支持C99的编译器才能编译通过。GCC有扩展标记化结构初始化语法,写法是