textcnn模型实践_textcnn epoch一般多少_会发paper的学渣的博客-程序员秘密

技术标签: tensorflow  NLP  tensorflow2.x  深度学习  keras  

对应的tensorflow版本:2.5.0+

textcnn模型如下:

import tensorflow as tf


class ConvMaxPooling1d(tf.keras.layers.Layer):
    """One TextCNN branch: Conv1D followed by global max pooling over time.

    Maps (batch_size, steps, embedding_size) -> (batch_size, filters).
    """

    def __init__(self, filters, kernel):
        """
        Args:
            filters: number of convolution filters (output channels).
            kernel: convolution window size along the step axis.
        """
        super(ConvMaxPooling1d, self).__init__()
        self.kernel_size = kernel
        # (batch_size, steps, embedding_size) -> (batch_size, steps - kernel + 1, filters)
        self.conv = tf.keras.layers.Conv1D(filters=filters, kernel_size=kernel, activation='relu')
        # (batch_size, steps - kernel + 1, filters) -> (batch_size, filters)
        self.pool = tf.keras.layers.GlobalMaxPool1D()
        # BUGFIX: removed stray `tf.random.uniform()` call — `shape` is a
        # required argument, so the call raised TypeError at construction,
        # and its return value was discarded anyway.

    def call(self, inputs, masks=None):
        """Apply convolution then max-pool. `masks` is accepted but unused."""
        conv_out = self.conv(inputs)
        pool_out = self.pool(conv_out)
        return pool_out


class TextCNN(tf.keras.models.Model):
    """TextCNN binary classifier.

    Pipeline: vocab-file embedding lookup -> parallel Conv1D/max-pool
    branches (one per kernel size) -> concat -> dense -> dropout ->
    sigmoid output in [0, 1].
    """

    # NOTE: defaults changed from lists to tuples — same values, but avoids
    # the mutable-default-argument pitfall; fully backward-compatible.
    def __init__(self, vocab, embedding_size, hidden_size, filters_list=(50, 60, 70, 80),
                 kernels=(2, 3, 4, 5), dropout=0.5, sentence_length=20):
        """
        Args:
            vocab: path of the vocabulary file (one token per line).
            embedding_size: dimension of the token embedding vectors.
            hidden_size: units of the dense layer before the classifier.
            filters_list: filters per conv branch, zipped with `kernels`.
            kernels: kernel size per conv branch.
            dropout: dropout rate applied after the dense layer.
            sentence_length: fixed number of tokens per input sentence.
        """
        super(TextCNN, self).__init__()
        # Out-of-vocabulary tokens map to index 0.
        ind = tf.feature_column.categorical_column_with_vocabulary_file("sentence_vocab", vocabulary_file=vocab,
                                                                        default_value=0)
        self.embedding_size = embedding_size
        self.sentence_length = sentence_length
        self.dense_feature_layer = tf.keras.layers.DenseFeatures(
            [tf.feature_column.embedding_column(ind, dimension=embedding_size)])

        self.conv_maxs = [ConvMaxPooling1d(f, k) for f, k in zip(filters_list, kernels)]
        self.dropout = tf.keras.layers.Dropout(dropout)
        self.dense = tf.keras.layers.Dense(hidden_size, activation='relu')
        self.classifier = tf.keras.layers.Dense(1, activation='sigmoid')

    # @tf.function(input_signature=(tf.TensorSpec(shape=(None, None), dtype=tf.dtypes.string),))
    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: string token tensor; reshaped to one token per row for
                the embedding lookup, then back to
                (batch_size, sentence_length, embedding_size).
        Returns:
            (batch_size, 1) sigmoid probabilities.
        """
        # ***************word token embedding begin***************
        inputs = tf.convert_to_tensor(inputs)
        inputs_tensor = tf.reshape(inputs, (-1, 1))
        embed_word_vectors1 = self.dense_feature_layer({"sentence_vocab": inputs_tensor})
        embeddings = tf.reshape(embed_word_vectors1, (-1, self.sentence_length, self.embedding_size))
        # ***************word token embedding end***************
        # Each branch: (batch_size, steps, embedding_size) -> (batch_size, filters).
        conv_outs = [layer(embeddings, None) for layer in self.conv_maxs]
        # Concat of pooled vectors: [(batch_size, filters_i)] -> (batch_size, sum(filters)).
        # (The original comment incorrectly kept a step dimension here.)
        concat_out = tf.concat(conv_outs, axis=-1)
        dense_out = self.dense(concat_out)
        drop_out = self.dropout(dense_out)
        logits = self.classifier(drop_out)

        return logits

模型训练代码如下:

import tensorflow as tf

from model.TextCNN import TextCNN
from utils.path_utils import get_full_path
from utils.read_batch_data import get_data


class SingleRNNModelTest:
    """Builds a TextCNN and drives a simple Keras `fit` training run."""

    def __init__(self,
                 epoch=3,
                 batch_size=100,
                 embedding_size=256,
                 learning_rate=0.001,
                 model_path="model_version",
                 sentence_vocb_length=20,
                 fill_vocab='TTTTTT',
                 vocab_file_path="data/vocab_clean.txt"
                 ):
        """
        Args:
            epoch: number of training epochs.
            batch_size: samples per gradient step.
            embedding_size: token embedding dimension for the model.
            learning_rate: Adam learning rate.
            model_path: subdirectory name for the TensorBoard run.
            sentence_vocb_length: fixed token count per sentence.
            fill_vocab: padding token used to fill short sentences.
            vocab_file_path: relative path to the vocabulary file.
        """
        self.epoch = epoch
        self.batch_size = batch_size
        self.model_path = model_path
        self.sentence_vocb_length = sentence_vocb_length
        self.fill_vocab = fill_vocab

        self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        self.loss = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
        self.model = TextCNN(vocab=get_full_path(vocab_file_path), embedding_size=embedding_size, hidden_size=20,
                             sentence_length=sentence_vocb_length)
        self.summary_writer = tf.summary.create_file_writer('./tensorboard/news_label_model/{}'.format(model_path))


    def train(self):
        """Load the training set, compile the model, and run `fit`."""
        # ========================== Create dataset =======================
        features, labels = get_data("data/train_data/prepare/train_data_v2.txt", self.sentence_vocb_length, self.fill_vocab)

        # One forward pass builds the model's variables before training.
        self.model(features)

        callbacks = [
            tf.keras.callbacks.TensorBoard(log_dir=get_full_path("data/fit_log/graph"), write_graph=True),
            tf.keras.callbacks.ModelCheckpoint(get_full_path("data/fit_log/fit_model3"), monitor="val_loss",
                                               mode="min"),
        ]
        metrics = [
            tf.keras.metrics.BinaryAccuracy(threshold=0.5),
            tf.keras.metrics.AUC(curve='PR', name='p-r'),
            tf.keras.metrics.AUC(curve='ROC', name='ROC'),
        ]
        self.model.compile(optimizer=self.optimizer, loss=self.loss, metrics=metrics)
        self.model.fit(x=features, y=labels, batch_size=self.batch_size, epochs=self.epoch,
                       shuffle=True, callbacks=callbacks)

if __name__ == '__main__':
    import os

    # BUGFIX: CUDA_* environment variables must be set BEFORE TensorFlow
    # initializes its GPU context; the original set them after the device
    # query below, where they likely had no effect. (TF is already imported
    # at module level, so ideally these would be set before that import.)
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # If you have GPU, the value is the GPU serial number.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    # =============================== GPU ==============================
    gpu = tf.config.experimental.list_physical_devices(device_type='GPU')
    print("gpu message:{}".format(gpu))

    # Training hyperparameters.
    epoch = 20
    batch_size = 1000
    sentence_vocb_length = 25
    embedding_size = 216
    learning_rate = 0.001

    train_instance = SingleRNNModelTest(epoch=epoch, batch_size=batch_size, sentence_vocb_length=sentence_vocb_length,
                                        embedding_size=embedding_size, learning_rate=learning_rate)

    train_instance.train()

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/sslfk/article/details/122681949

智能推荐

用sysbench测试数据库吞吐量指标_configure.ac:49: error: possibly undefined macro: _正义之兔的博客-程序员秘密

1.  在Github上下载最新版本的SysBench,# wget -O sysbench-1.0.14.tar.gz https://github.com/akopytov/sysbench/archive/1.0.14.tar.gz,文件下载为sysbench-1.0.14.tar.gz2. tar -vzxf sysbench-0.4.12.14.tar.gz 解压缩,生成新目录sysbe...

【UOJ #131】【NOI 2015】品酒大会_as2886089的博客-程序员秘密

http://uoj.ac/problem/131求出后缀数组和height数组后,从大到小扫相似度进行合并,每次相当于合并两个紧挨着的区间。合并区间可以用并查集来实现,每个区间的信息都记录在这个区间的并查集的根上,合并并查集时用一个根的信息更新另一个根的信息同时计算两个答案。时间复杂度\(O(n\log n)\)。#include<cstdio>#include&...

7-48 银行排队问题之单窗口“夹塞”版 (30分)--map,vector_Robinxbw的博客-程序员秘密

1 #include <iostream>2 #include<iomanip>3 #include <map>4 #include <string>5 #include <cstring>6 #include <queue>7 #include <vector>8 using names...

02、Spring AOP_排骨玉米汤的博客-程序员秘密

02、Spring AOP1 转账案例1.1 基础功能1.2 传统事务1 转账案例需求使用spring框架整合DBUtils技术,实现用户转账功能1.1 基础功能步骤分析1. 创建java项目,导入坐标 2. 编写Account实体类 3. 编写AccountDao接口和实现类 4. 编写AccountService接口和实现类 5. 编写spring核心配置文件 6. 编写测试代码1)创建java项目,导入坐标<dependencies> <depende

随便推点

android图形框架之surfaceflinger分析(一)_android.hardware.graphics.allocator_welljrj的博客-程序员秘密

1. 概念 surfaceflinger作用是接受多个来源的图形显示数据,将他们合成,然后发送到显示设备。比如打开应用,常见的有三层显示,顶部的statusbar底部或者侧面的导航栏以及应用的界面,每个层是单独更新和渲染,这些界面都是有surfaceflinger合成一个刷新到硬件显示。在显示过程中使用到了bufferqueue,surfaceflinger作为consumer方,...

第GPS定位与高德地图的使用_物联网gps设备上报怎么从高德上接受数据_MX_XXS的博客-程序员秘密

一.GPS定位:android 的三种定位方式二.GPS常用的类:二.GPS定位代码:三.使用高德地图获取定位数据:官网:https://lbs.amap.com四.使用高德地图:官网:https://lbs.amap.com一.GPS定位:android 的三种定位方式1.GPS定位:需要GPS硬件支持,直接和卫星交互来获取当前经纬度。  优点:速度快、精度高、可在无网络情况下使用...

5.2 变长参数表(函数的实参个数可变)编程示例_sjmp的博客-程序员秘密

/* algo5-2.c 变长参数表(函数的实参个数可变)编程示例 */ #include"c1.h" #include /* 实现变长参数表要包括的头文件 */ typedef int ElemType; ElemType Max(int num,...) /* ...表示变长参数表,位于形参表的最后,前面必须有至少一个固定参数 */ { /* 函数功能:返回num个数中的最

找两个字符串的最长公共子串的最大长度_lyl194458的博客-程序员秘密

题目:找两个字符串的最长公共子串的最大长度分析:从较短的那个子串的长度串依次判断是不是长串的子串,即从将可能的最长公共子串开始比较,那么只要匹配到,一定是最长的子串。int max_common_str(const string& s1, const string& s2){ string longstr; string shortstr; if (s1.s...

字节跳动再扩招1万人,太猛了。。。_程序员小乐的博客-程序员秘密

去年由于这该死的疫情,我被工作6年的旅游集团裁员了,不过当时我并不在意,想着是自己有8年的工作经验,找工作肯定轻轻松松的。索性就直接给自己放了一个月假,天天吃喝玩乐,直到媳妇催我的时候,我...

Linq 基本语法_linq 语法_锤子馆长的博客-程序员秘密

以下都是转载内容1.简单的linq语法 //1 var ss = from r in db.Am_recProScheme select r; //2 var ss1 = db.Am_recProScheme; //3