tensorflow实现AlexNet_漂洋过海的油条的博客-程序员秘密

一、AlexNet网络介绍


AlexNet包含了6亿3000万个连接,6000万个参数和65万个神经元,拥有5个卷积层,其中3个卷积层后面连接了最大池化层,最后有3个全连接层。

二、AlexNet优点

1、成功使用RELU作为CNN的激活函数,并验证其效果在较深的网络超过了sigmoid,成功解决了sigmoid在网络较深时的梯度弥散问题。

2、训练时使用Dropout随机忽略一部分神经元,以避免模型过拟合。

3、在CNN中使用重叠的最大池化。

4、提出了LRN层,对局部神经元的活动创建竞争机制,使得其中响应比较大的值变得相对更大,并抑制其他反馈较小的神经元,增强模型的泛化能力。

5、数据增强,增大了数据集的量。

三、实现

我们主要建立一个完整的AlexNet卷积神经网络,然后对它的每个batch的forward和backward的速度进行测试。

定义一个用来显示网络每一层结构的函数print_actications,这个函数接受一个tensor作为输入,并显示其名称和尺寸。

from datetime import datetime
import  math
import  time
import tensorflow as tf

batch_size = 32
num_batches = 100

def print_activations(t):
    print(t.op.name, '', t.get_shape().as_list())

接下来设计AlexNet的网络结构。这个inference函数将会很大,包括多个卷积和池化层。

def inference(images):
    parameters = []
    
    with tf.name_scope('conv1') as scope: #scope可以规范化变了名称
        kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64],
                    dtype=tf.float32, stddev=1e-1), name='weights')
        conv = tf.nn.conv2d(images, kernel,[1, 4, 4,1], padding='SAME')
        biases = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32),
                             trainable=True, name='biases')
        bias = tf.nn.bias_add(conv, biases)
        conv1 = tf.nn.relu(bias, name=scope)
        print_activations(conv1) #将这一层最后输出的tensor conv1的结构打印出来,并将这一层可训练的参数kernel、biases添加到parameters中
        parameters += [kernel, biases]
        
lrn1 = tf.nn.lrn(conv1, 4, bias=1.0, alpah=0.001/9, beta=0.75, name='lrn1')#depth_radius 设为4,
pool1 = tf.nn.max_pool(lrn1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                               padding='VALID', name='pool1')#VALID为取样时不能超过边框,不像SAME模式那样可以填充边界外的点。
print_activations(pool1)#将输出结果pool1的结构打印出来

#第二层卷积层
    with tf.name_scope('conv2') as scope:  # scope可以规范化变了名称
         kernel = tf.Variable(tf.truncated_normal([5, 5, 64, 192],
                                             dtype=tf.float32, stddev=1e-1), name='weights')
         conv = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding='SAME')
         biases = tf.Variable(tf.constant(0.0, shape=[192],
                        dtype=tf.float32),trainable=True, name='biases')
         bias = tf.nn.bias_add(conv, biases)
         conv2 = tf.nn.relu(bias, name=scope)
         parameters += [kernel, biases]
    print_activations(conv2)
 
lrn2 = tf.nn.lrn(conv2, 4, bias=1.0, alpah=0.001/9, beta=0.75, name='lrn2')#depth_radius 设为4,
pool2 = tf.nn.max_pool(lrn2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                               padding='VALID', name='pool2')#VALID为取样时不能超过边框,不像SAME模式那样可以填充边界外的点。
print_activations(pool2)#将输出结果pool2的结构打印出来

#第三层
    with tf.name_scope('conv3') as scope:  # scope可以规范化变了名称
        kernel = tf.Variable(tf.truncated_normal([3, 3, 192, 384],
                                             dtype=tf.float32, stddev=1e-1), name='weights')
        conv = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding='SAME')
        biases = tf.Variable(tf.constant(0.0, shape=[384],
                                     dtype=tf.float32), trainable=True, name='biases')
        bias = tf.nn.bias_add(conv, biases)
        conv3 = tf.nn.relu(bias, name=scope)
        parameters += [kernel, biases]
    print_activations(conv3)

#第四层
    with tf.name_scope('conv4') as scope:  # scope可以规范化变了名称
        kernel = tf.Variable(tf.truncated_normal([3, 3, 384, 256],
                                             dtype=tf.float32, stddev=1e-1), name='weights')
        conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding='SAME')
        biases = tf.Variable(tf.constant(0.0, shape=[256],
                     dtype=tf.float32), trainable=True, name='biases')
        bias = tf.nn.bias_add(conv, biases)
        conv4 = tf.nn.relu(bias, name=scope)
        parameters += [kernel, biases]
    print_activations(conv4)
    
#第五层
        with tf.name_scope('conv5') as scope:  # scope可以规范化变了名称
            kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256],
                         dtype=tf.float32, stddev=1e-1), name='weights')
        conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding='SAME')
        biases = tf.Variable(tf.constant(0.0, shape=[256],
                     dtype=tf.float32), trainable=True, name='biases')
        bias = tf.nn.bias_add(conv, biases)
        conv3 = tf.nn.relu(bias, name=scope)
        parameters += [kernel, biases]
        print_activations(conv5)
        
    pool5 = tf.nn.max_pool(conv5, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                            padding='VALID', name='pool5')
    print_activations(pool5)
    return pool5, parameters

接下来实现一个评估AlexNet每轮计算时间的函数time_tensorflow_run。这个函数的第一个输入时tensorflow的session,第二个变了是需要评测的运算算子,第三个是测试的名称。

先定义预热轮数num_steps_burn_in=10,作用时给程序热身,头几轮迭代有显存加载等问题可以跳过。所以我们只需考量10轮迭代之后的计算时间。同时,也记录总时间total_duration和平方和total_duration_squared用来计算方差。

#time_tensorflow_run函数时评估AlexNet每轮计算时间,用来计算某个算子的运行时间
def time_tensorflow_run(session, target, info_string):
    num_steps_burn_in = 10 #轮数
    total_duration = 0.0 #总时间
    total_duration_squared = 0.0 #平方和

我们进行num_batches + num_steps_burn_in次迭代计算,使用time.time()记录时间,每次迭代通过session.run(target)执行。在初始热身的num_steps_burn_in次迭代后,每10轮迭代显示当前迭代所需要的时间。同时每轮将total_duration和total_duration_squared累加以便后面计算每轮耗时的均值和标准差。

在循环结束后,计算每轮迭代的平均耗时mn和标准差sd,最后将结果显示出来。这样就完成了计算每轮迭代耗时的评测函数time_tensorflow_run。

    for i in range(num_batches + num_steps_burn_in):
        start_time = time.time()
        _ = session.run(target)
        duration = time.time() - start_time
        if i >= num_steps_burn_in:
            if not i % 10:
                print('%s: step %d, duration = %.3f' %
                        (datetime.now(), i - num_steps_burn_in, duration))
            total_duration += duration
            total_duration_squared += duration * duration
#循环完成后,计算每轮迭代的平均耗时mn和标准差sd,最后将结果显示出来。
    mn = total_duration / num_batches
    vr = total_duration_squared / num_batches - mn * mn
    sd = math.sqrt(vr)
    print('%s: %s across %d steps, %.3f +/- %.3f sec / batch' %
          (datetime.now(), info_string, num_batches, mn, sd))

接下来时主函数run_benchmark。首先使用with tf.Graph().as_default()定义默认的Graph方便后面使用。

使用tf.random_normal函数构造正态分布(标准差为0.1)的随机tensor,第一个维度是batch_size,第二个和第三个维度时图片的尺寸 image_size=224,第四个维度是图片的颜色通道数。

用inference函数构建整个AlexNet网络,得到最后一个池化层的输出pool5和网络中需要训练的参数的集合parameters.

使用tf.Session创建新的Session,并通过tf.global_variables_initializer()初始化所有参数。

接着进行AlexNet的forward计算的评测,这里直接使用time_tensorflow_run统计运行时间,传入的target就是pool5,即卷积网络最后一个池化层的输出。然后进行backward,这里需要给最后的输出pool5设置一个优化目标loss。我们使用tf.nn.l2_loss计算pool5的loss,在使用tf.gradients求相对于loss的所有模型参数的过程。

最后使用time_tensorflow_run统计backward的运算时间,这里的target就是求整个网络梯度gard的操作。

def run_benchmark():
    with tf.Graph().as_default():
        image_size = 224
        images = tf.Variable(tf.random_normal([batch_size,
                                               image_size,
                                               image_size, 3],
                                              dtype=tf.float32,
                                              stddev=1e-1))
        
        pool5, parameters = inference(images)
        
        init = tf.global_variables_initializer()
        sess = tf.Session()
        sess.run(init)
        
        time_tensorflow_run(sess, pool5, "Forward")
        object = tf.nn.l2_loss(pool5)
        grad = tf.gradients(objective, parameters)
        time_tensorflow_run(sess, grad, "Forward-backward")

最后执行主函数

run_benchmark()

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_40533355/article/details/80423986

智能推荐

安装nginx及fastdfs-nginx-module_阿Q咚咚咚的博客-程序员秘密

先了解背景:FastDFS为什么要结合Nginx以及FastDFS原理,请参考文章:https://baijiahao.baidu.com/s?id=1628343949188630389&wfr=spider&for=pc准备工作:安装安装Nginx所需的环境,参考文献:https://www.cnblogs.com/yanyh/p/9801466.htmlapt in...

富文本保存的base64字符串转换成图片并保存到OSS_小程序如何将base64格式转换成png_谦客的博客-程序员秘密

富文本保存的base64字符串转换成图片并保存到OSS1. 问题最近在项目中由于疏忽,遇到了一件刚上线就比较棘手的事: 富文本保存图片,前端直接保存图片base64数据到服务器,导致产生大量的数据,直接导致数据库打不开,上线之后页面刷新超慢,平均查询一个商品详情要10s-20s,简直忍无可忍。2. 思路富文本base64格式数据:<p><img src="*****Z"></p>将数据中的base64格式数据

Pythom(6.25)异常与日志_pythom异常_蛇群中的一只羊的博客-程序员秘密

一、异常        Python程序的语法是正确的,在运行它的时候,也有可能发生错误。运行期检测到的错误被称为异常。大多数的异常都不会被程序处理,都以错误信息的形式展现在这里:        异常产生的时机:系统产生        如果产生异常,程序中止:程序不健壮        给程序添加异常,使程序变得健壮        try:             语句体            1/...

thinkphp5 使用phpoffice/phpspreadsheet导入和导出excel_枯灯一枚的博客-程序员秘密

PHPExcel已经被废弃在PHP7.2中已经无法获取更新,官方重新开了一个新包phpspreadsheetcomposer安装:composer require phpoffice/phpspreadsheet一,导出,1,view中:<a href="#" class="label label-primary set" style="margin-right: 8...

Layui 下拉框多选 —老司机首选(测试效果已ok)_林夕_影的博客-程序员秘密

Layui 下拉框多选 —老司机首选https://blog.csdn.net/YBaog/article/details/79933223

atoi函数和itoa函数详解(整数与字符串之间的互相转换函数)_Lwhere~的博客-程序员秘密

记忆技巧: int/之类的转string itoa(其实就是 i to a) string 转 int atoi函数(其实就是a to i)1.int/float to string/array:C语言提供了几个标准库函数,可以将任意类型(整型、长整型、浮点型等)的数字转换为字符串,下面列举了各函数的方法及其说明。● itoa():将整型值转换为字符串。●...

随便推点

axios 拦截器与取消 pending 状态请求_axios pending__let的博客-程序员秘密

axios 拦截器与取消 pending 状态请求/** * axios 拦截器配置 */import axios from 'axios'import { Notification } from 'element-ui'import router from '../router/index.js'// 跳转到登录页面const gotoLoginPage = function...

C++:究竟还有没有未来?_阿言教编程的博客-程序员秘密

很多人说C++现在已经过时了,快要被淘汰了,真的是这样吗?权威部门统计,我国目前C/C++软件开发人才缺口每年为10万人左右,未来随着信息化、数据化不断提速,这一数字还将成倍增长。从事编程领域工作多年,最先接触的是C#,但是后续由于其跨平台性的限制,逐渐转向C++。其实最开始我是十分抵触C++的,因为写C#习惯了,用起C++来真的十分不习惯。不仅仅是难,系统库的查看方面也不如C#的简洁清晰...

Caffe添加自定义的层_caffe添加自定义层_aworkholic的博客-程序员秘密

介绍在使用Caffe时,可能已有的层不满足需求,需要实现自己的层,最好的方式是修改caffe.proto文件,增加对应cpp、h、cu的声明和实现,编译caffe库即可。

侧面菜单与滚动条随动 element ui_element ui 侧边栏菜单设置滚动条_阿门阿钱的博客-程序员秘密

技术:javascript vue实现功能: 侧面菜单与滚动条随动实现原理:通过监听滚动条的数值,并更改左侧菜单栏的状态。代码如下:<template> <div> <!-- 添加布局--> <el-container style="display: flex;"> <!-- 左侧导航栏--> <el-aside width="130px" style=...

每天5个运维常用指令——wc,iconv,dos2unix,diff,vimdiff_dos wc命令_卤香狗蛋的博客-程序员秘密

【wc】统计文件的行数、单词数或字节数。利用wc指令我们可以计算文件的Byte数、字数、或是列数,若不指定文件名称、或是所给予的文件名为"-",则wc指令会从标准输入设备读取数据。语法wc [-clw][--help][--version][文件...]参数:-c或--bytes或--chars 只显示Bytes数。 -l或--lines 只显示行数。 -w或--word...

WPF中MVVM模式原理分析与实践(转)_weixin_34050519的博客-程序员秘密

1, 前提   可以说MVVM是专为WPF打造的模式, 也可以说MVVM仅仅是MVC的一个变种, 但无论如何, 就实践而言, 如果你或你的团队没有使用"Binding"的习惯, 那么研究MVVM就没有多大意义.  另外,个人觉得, 使用Command以及打造一种合理的简化的方式去使用Command也与使用Binding一样重要.  2, 诞生  为了解决现实世界中的问题,我们需要将...