convLSTM 理解与实现_convlstm实现_Runner_of_nku的博客-程序员秘密

技术标签: 机器学习  

本文主要是有关convLSTM的pytorch实现代码的理解,原理请移步其他博客。

在pytorch中实现LSTM或者GRU等RNN一般需要重写cell,每个cell中包含某一个时序的计算,也就是以下:

在传统LSTM中,LSTM每次要调用t次cell,t就是时序的总长度,如果是n层LSTM就相当于一共调用了n*t次cell

class ConvLSTMCell(nn.Module):

    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_size: (int, int)
            Height and width of input tensor as (height, width).
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """

        super(ConvLSTMCell, self).__init__()

        self.height, self.width = input_size
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        '''
        :param input_tensor:[batch,dim,inp_height,inp_width]
        :param cur_state: [h,c] h:[batch,dim,H,W]
        :return:
        '''
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  

        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size):
        return (Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)),
                Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)))

 cell的实现主要如上,其中可以看出来,参数是通过一次Conv2d的调用就申请好的,这里是这样理解的:

in_channels=self.input_dim + self.hidden_dim

out_channels=4 * self.hidden_dim

由于在计算 i f c o 时我们都引入了卷积运算,且运算的对象都是一样的,要么是对当前时序的输入input要么是对上一时刻的输出h,所以Wxi Wxf Wcf Wxo的卷积核规模或者说参数的规模是相同的,是相同size的tensor,同理Whi Whf Whc Wco也是一样。

在当前时刻,我们的输入有input[batch,input_dim,H,W],h[batch,hidden_dim,H,W](这里的input和h的HW是相同的,因为在设置卷积padding的时候约定了输入的规模)然后我们通过torch.cat将输入合并为[batch,input_dim+hidden_dim,H,W]

此时我们的核规模是[4*hidden_dim,input_dim+hidden_dim,kH,kW]

在卷积过程中,这里参考pytorch文档

 这个公式可以理解为矩阵乘运算,不过矩阵乘运算中进行的是每个元素相乘相加,这里的矩阵乘运算是矩阵中每个核和每个输入的tensor进行卷积操作再相加。

举个栗子

假设我们的输入是28*28的手写体,为了符合时序的要求,我们将手写体划分为16个时序的7*7图片,核大小设置为(3,3),此时我们的输入输出的dim都是1,所以以上的input=[batch,1,7,7] h=[batch,1,7,7] core=[4,2,3,3]

不考虑bias结果就如图所示了,所以再经过一次split切割,就能拿到很多中间步骤的结果啦。 

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/Runner_of_nku/article/details/98876653

智能推荐

C#AE练习 (6)创建与编辑数据_菜鸟学飞ing去看世界的博客-程序员秘密

创建与编辑数据//在CreateFields函数里创建字段集public IFieldsEdit CreateFields(){ //使用编辑接口IFieldEdit(IFieldsEdit)创建新的Field对象或新的Fields集 //创建字段集合. 使用IFieldsEdit接口,要将Field对象加入到Fields集中 //IFields pFields = new Field...

sqlserver 字符串拆分和取某分隔符之前的字符串_weixin_33770878的博客-程序员秘密

ALTER FUNCTION [dbo].[f_splitSTR]( @s varchar(8000), --待分拆的字符串 @split varchar(10) --数据分隔符 )RETURNS @re TABLE( col varchar(max)) AS BEGIN DECLARE ...

MATLAB 条形图或饼状图 图案填充_aizhao3648的博客-程序员秘密

function [im_hatch,colorlist] = applyhatch_pluscolor(h,patterns,CvBW,Hinvert,colorlist, ... dpi,hatchsc,lw)%APPLYHATCH_PLUSCOLOR Apply hatched...

Thinkphp3.2 中的where条件复杂条件下的条件组合_DREAM-追梦的博客-程序员秘密

组合查询的主体还是采用数组方式查询,只是加入了一些特殊的查询支持,包括字符串模式查询(_string)、复合查询(_complex)、请求字符串查询(_query),混合查询中的特殊查询每次查询只能定义一个,由于采用数组的索引方式,索引相同的特殊查询会被覆盖。一、字符串模式查询数组条件可以和字符串条件(采用_string 作为查询条件)混合使用,例如:$User = M("Us

前端发送http请求的几种方式_LV0720的博客-程序员秘密

前端发送http的几种方式1. XMLHttpRequest2. ajax3. axios4. fetch1. XMLHttpRequest所有现代浏览器均内建了XMLHttpRequest对象,IE5、IE6使用ActiveX对象。 var xmlHttp; if(window.XMLHttpRequest){ xmlHttp = new XMLHttpRequest(); }else{ xmlHttp = new ActiveXObject(); } xmlHttp.open(m

简单易学Matlab深度学习教程--数据可视化_SUNNY小飞的博客-程序员秘密

数据可视化要使用plot函数来绘制图形,需要执行以下步骤:通过指定要绘制函数的变量x的值的范围来定义x。定义函数,y = f(x)调用plot命令,如下:plot(x,y)以下示例将演示该概念。下面绘制x的值范围是从0到100,使用简单函数y = x,增量值为5。创建脚本文件并键入以下代码 -x = [0:5:100];y = x;plot(x, y)在图上添加标题,标签,...

随便推点

ROS的安装、卸载以及Turtlebot包的安装_ros-hydro-turtlebot-simulator下载安装_张京林要加油的博客-程序员秘密

一、前言我的运行环境: 操作系统:Ubuntu Kylin 14.04 ROS版本:Indigo 背景说明:本文记录了ROS的安装和卸载过程与Turtlebot包的安装过程以及其间遇到的问题二、ROS的安装1、ROS与Ubuntu版本的考虑ROS作为机器人操作系统并不是一个像Windows、Linux那样可以独立运行的操作系统,他需要运行在Linux环境下。所以运行ROS的前提是电脑要装有Li

套接字的select、WsaAsyncSelect、WsaEventSelect模型_kakaluote3223的博客-程序员秘密

套接字的select、WsaAsyncSelect、WsaEventSelect模型的区别及实现

Mellanox Infiniband 架构设计快速实战指南 - A_ShawnTheLearner的博客-程序员秘密

Nvidia GTC 2020正在进行中,Nvidia(英伟达)最终完成了对以色列高速计算网络厂商Mellanox的收购,再一次被黄老板高度评价。Mellanox Infiniband至今依然是高速计算互联网络的主流厂商和解决方案,在HPC领域应用广泛。本文基于Mellanox官方提供的材料,介绍如何快速上手Infiniband高速互联网络的设计规划。

SPP框架的基本使用_anmengdai0123的博客-程序员秘密

入职两天Day1、Day2:学习SPP框架SPP是什么?SPP提供了一系列的基础功能,是一个通用的网络服务器运行框架。主要由proxy,worker,controller三个模块组成。它提供API给开发人员,因此只需要关心业务逻辑的处理,进行插件开发。直接调用其暴露出来的接口开发可以节省开发时间,提高效率。SPP的运作流程?流程总结:Client发送请求会由proxy进行...

win10家庭版升级教育版,专业版和企业版最新密钥和方法分享_教育版升级专业版_ysd948006909的博客-程序员秘密

很多同学购买电脑时自带的系统都是windows家庭版的,在使用时会发现windows10家庭版系统主题太少,还有其他功能也有很多限制,今天小编就教你如何升级自己的windows10系统.一.升级系统最重要的是有可用的密钥,今天小编就把各个版本的升级密钥给大家分享一下,希望对大家有用,如果发现密钥已经过期或者升级激活遇到麻烦,可以联系小编Vx: xiaotudou1803508  获取最新密钥或者提...

elementui基础学习_暗影杀神的博客-程序员秘密

一、前后端分离1.什么是前后端分离?前端和后端分离开前端: 将浏览器中为用户进行页面展示的部分称之为前端后端: 为前端提供业务逻辑和数据准备的所有代码统称为后端前后端分工:​ 前后端开发工作的分工就是前后端分离。 前后端分工;(错误的认识)真的前后端分离:​ 不仅仅是前端和后端的分工开发,而是架构的模式2.交互形式在前后端分离架构中,后端只需要负责按照约定的数据格...

推荐文章

热门文章

相关标签