keras自定义网络层_(源码解读)_keras实现自定义网络层-程序员宅基地

技术标签: Keras源码  深度学习  Keras定义层  自定义神经网络  

keras是基于Tensorflow等的一个神经网络的上层框架,通过Keras我们可以简单的构造出自己的神经网络,同时Keras针对主流的算法框架、激活函数和优化函数等进行自己的实现,某些方面只需要我们进行简单的调用,Keras的出现大大简化了网络构建的成本。

Keras自定义网络层需要一下步骤:

1、继承一个Layer

keras顶级Layer类定义在engine包的base_layer.py文件,其中的class Layer(object)类定义了基本的关于Layer的方法和变量:这定义的方法变量众多,随着Keras版本的更新越来越满足我们特殊的需求。

class Layer(object):
#抽象类主要方法总结

  #为layer增加权重
  def add_weight()
  
  #对输入数据和本Layer定义的数据进行效验,如输入数据不符合定义规定报错
  def assert_input_compatibility(self, inputs):

  #训练过程中前向和后向传播的主要逻辑实现
  def call()

  #计算本层的输出形状
  def compute_output_shape()

  #计算输出的掩膜
  def compute_mask()

  #构造layer的权重
  def build()

  #检索给定节点上的层的输入\出形状
  def get_input_shape_at()
  def get_output_shape_at()

  #检索给定节点上的层的输入/出张量
  def get_input_at():
  def get_output_at():

  #检索给定节点上的层的输入/出掩模张量。
  def get_input_mask_at(self, node_index):
  def get_output_mask_at(self, node_index):

以Keras里面全连接层为例,Dense继承自Layer类,主要实现了build()、call()、compute_out_shaper()和get_config()方法。

class Dense(Layer):
  def build(self, input_shape):
  def call(self, inputs):
  def compute_output_shape(self, input_shape):
  def get_config(self):
  ...

2、重写其中的方法

keras自定义网络layer,主要根据自己网络的需要,对其父类的部分方法进行重写,当然如果有特殊的需要也可以对其父类集合Tensorflow进行改写。

build(input_shape):

通过build()方法定义自己layer的权重,此方法最后必须实现变量self.build = True。实现过程中我们可以对权重值进行约束和初始化或者正则化,分别调用self.kernel_constraint、self.kernel_initializer和self.kernel_regularizer可以实现。

call(x):

通过call()方法进行功能逻辑实现,是该层的计算逻辑或计算图。显然,这个层的核心应该是一段符号式的输入张量到输出张量的计算过程。

class BinaryConv2D(Conv2D):

    #定义构造方法
    def __init__(self, filters, kernel_lr_multiplier='Glorot', bias_lr_multiplier=None, H=1., **kwargs):
        super(BinaryConv2D, self).__init__(filters, **kwargs)
        self.H = H
        self.kernel_lr_multiplier = kernel_lr_multiplier
        self.bias_lr_multiplier = bias_lr_multiplier
        
    #这是你定义权重的地方。input_shape=(28,28,1)
    def build(self, input_shape):
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1 
        if input_shape[channel_axis] is None:
                raise ValueError('The channel dimension of the inputs '
                                 'should be defined. Found `None`.')

        input_dim = input_shape[channel_axis] #获取channel
        kernel_shape = self.kernel_size + (input_dim, self.filters) #*(3,3,1,128)
            
        base = self.kernel_size[0] * self.kernel_size[1] #9
        if self.H == 'Glorot':
            nb_input = int(input_dim * base)
            nb_output = int(self.filters * base)
            self.H = np.float32(np.sqrt(1.5 / (nb_input + nb_output)))
            #print('Glorot H: {}'.format(self.H))
            
        if self.kernel_lr_multiplier == 'Glorot':
            nb_input = int(input_dim * base) #9
            nb_output = int(self.filters * base) #1152
            self.kernel_lr_multiplier = np.float32(1. / np.sqrt(1.5/ (nb_input + nb_output)))
            #print('Glorot learning rate multiplier: {}'.format(self.lr_multiplier))

        self.kernel_constraint = Clip(-self.H, self.H) #对主权重矩阵进行约束
        self.kernel_initializer = initializers.RandomUniform(-self.H, self.H)
        self.kernel = self.add_weight(shape=kernel_shape,
                                 initializer=self.kernel_initializer,
                                 name='kernel',
                                 regularizer=self.kernel_regularizer,
                                 constraint=self.kernel_constraint)

        if self.use_bias:
            self.lr_multipliers = [self.kernel_lr_multiplier, self.bias_lr_multiplier]
            self.bias = self.add_weight((self.output_dim,),
                                     initializer=self.bias_initializers,
                                     name='bias',
                                     regularizer=self.bias_regularizer,
                                     constraint=self.bias_constraint)

        else:
            self.lr_multipliers = [self.kernel_lr_multiplier]
            self.bias = None

        # Set input spec.
        self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
        self.built = True

    #这里是编写层的功能逻辑的地方。
    def call(self, inputs):
        binary_kernel = binarize(self.kernel, H=self.H) 
        outputs = K.conv2d(
            inputs,
            binary_kernel,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate)

        if self.use_bias:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

      def get_config(self):
        config = {'H': self.H,
                  'kernel_lr_multiplier': self.kernel_lr_multiplier,
                  'bias_lr_multiplier': self.bias_lr_multiplier}
        base_config = super(BinaryConv2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
  

 

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_36931982/article/details/90369187

智能推荐

记录Linux安装Oracle19c_debian 安装 oracle19c-程序员宅基地

文章浏览阅读712次。最近单位要求本渣学习服务器脚本编写完成定点市级机构下发的数据库表导入项目服务器数据库,按工作顺序就先打算在自己笔记本电脑上通过虚拟机来模拟生产环境,部署虚拟环境后安装Linux版本Oracle19c数据库。经过数天研究终完成安装,记录如下。安装准备:1、虚拟软件 Oracle VM VirtualBox2、镜像 CentOS-7-x86_64-Minimal-2009.iso3、Xshell 7.04、Xftp 7.05、Xmanager Enterprise 56、LINUX._debian 安装 oracle19c

Halcon分类器之高斯混合模型分类器_训练高斯混合模型分类器-程序员宅基地

文章浏览阅读2k次,点赞2次,收藏20次。Halcon分类器示例自我理解看了很多网上的例子,总有一种纸上得来终觉浅,绝知此事要躬行的感觉。说干就干,将Halcon自带分类器例子classify_metal_parts.hdev按照自己的理解重新写一遍,示例中的分类器是MLP(多层感知机),我将它改变为GMM(高斯混合模型)。希望可以帮助刚入门的同学学习理解,大神请绕路吧,当然也喜欢各位看官帮我找出不足之处,共同进步。谢谢!分类效果如图..._训练高斯混合模型分类器

Office转PDF工具类_"officetopdf.wordtopdf(\"d:\\\\1234.doc\", \"d:\\\-程序员宅基地

文章浏览阅读819次。使用Jacob转换office文件,Jacob.dll文件需要放到jdk\bin目录下Jacob.dll文件下载地址https://download.csdn.net/download/zss0101/10546500package com.zss.util;import java.io.File;import com.jacob.activeX.ActiveXComponent;..._"officetopdf.wordtopdf(\"d:\\\\1234.doc\", \"d:\\\\1234.pdf\");"

redis实现队列_redistemplate convertandsend方法实现队列-程序员宅基地

文章浏览阅读1k次,点赞30次,收藏30次。上面的例子我们已经了一个简易的消息队列。我们继续思考一个现实的场景,假定这些是一些游戏商品,它需要添加"延迟销售"特性,在未来某个时候才可以开始处理这些游戏商品数据。那么要实现这个延迟的特性,我们需要修改现有队列的实现。在消息数据的信息中包含延迟处理消息的执行时间,如果工作进程发现消息的执行时间还没到,那么它将会在短暂的等待之后重新把消息数据推入队列中。(延迟发送消息)_redistemplate convertandsend方法实现队列

java基础-程序员宅基地

文章浏览阅读287次,点赞5次,收藏5次。java基础篇

使用gparted对linux根目录扩容(windows+linux双系统)_双系统linux扩容-程序员宅基地

文章浏览阅读298次。linux扩容根目录与/home_双系统linux扩容

随便推点

Python使用pika调用RabbitMQ_python pika 通过主机名称来访问mq-程序员宅基地

文章浏览阅读388次。定义RabbitMQ类import jsonimport osimport sysimport pikafrom Data import Datafrom MongoDB import MongoDBfrom constants import *class RabbitMQ: def __init__(self, queue_name): """ 初始化队列对象 :param queue_name: 队列名称 "_python pika 通过主机名称来访问mq

Python利用openpyxl处理excel文件_在 python 中可以通過 openpyxl 套件來很好的操作 excel 讀寫-程序员宅基地

文章浏览阅读568次。**openpyxl简介**openpyxl是一个开源项目,openpyxl模块是一个读写Excel 2010文档的Python库,如果要处理更早格式的Excel文档,需要用到其它库(如:xlrd、xlwt等),这是openpyxl比较其他模块的不足之处。openpyxl是一款比较综合的工具,不仅能够同时读取和修改Excel文档,而且可以对Excel文件内单元格进行详细设置,包括单元格样式等内容,甚至还支持图表插入、打印设置等内容,使用openpyxl可以读写xltm, xltx, xlsm, xls_在 python 中可以通過 openpyxl 套件來很好的操作 excel 讀寫

Unity判断某个物体是否在某个规定的区域_unity判断物体在范围内-程序员宅基地

文章浏览阅读1.4w次,点赞7次,收藏56次。Unity自带的两种写法:①物体的位置是否在某个区域内Vector3 pos = someRenderer.transform.position;Bounds bounds = myBoxCollider.bounds;bool rendererIsInsideTheBox = bounds.Contains(pos);②物体的矩形与区域的矩形是否交叉Bounds rendererBo..._unity判断物体在范围内

[深度学习] 使用深度学习开发的循线小车-程序员宅基地

文章浏览阅读295次,点赞6次,收藏4次。报错:Unable to find image 'openexplorer/ai_toolchain_centos_7_xj3:v2.3.3' locally。报错: ./best_line_follower_model_xy.pth cannot be opened。可以看到生成的文件 best_line_follower_model_xy.pth。报错:Module onnx is not installed!安装onox,onnxruntime。这是由于没有文件夹的写权限。

MDB-RS232测试NAYAX的VPOS刷卡器注意事项-程序员宅基地

文章浏览阅读393次,点赞10次,收藏8次。MDB-RS232测试NAYAX的VPOS非现金MDB协议刷卡器注意事项

Pytorch和Tensorflow,谁会笑到最后?-程序员宅基地

文章浏览阅读2.5k次。作者 |土豆变成泥来源 |知秋路(ID:gh_4a538bd95663)【导读】作为谷歌tensorflow某项目的Contributor,已经迅速弃坑转向Pytorch。目前Ten..._pytorch与tensorflow