基于pytorch与opencv的手写汉字识别系统_opencv 手写文字识别-程序员宅基地

技术标签: pytorch  人工智能  opencv  

手写汉字识别

b站地址:https://www.bilibili.com/video/BV1384y1P76m/?vd_source=65a01bd1c4223f2aede873e40c0cdb3e

前言

本次实验的任务是汉字识别。使用pytorch深度学习框架和opencv在HWDB手写汉字数据集进行实验。由于数据集过于庞大,这里只选取了前1311个类作为实验。
文末附有源码下载地址。

效果预览

数据集介绍

HWDB是一个手写汉字数据集,该数据集来自于中科院自动化研究所,一共有三个版本,分别为HWDB1.0、HWDB1.1和HWDB1.2。
本文使用的数据集共有1311种汉字,大概共有几十万张图片,其中20%的图片用于验证,80%的图片用于训练。图片的格式为png,下图为部分数据集图片。
在这里插入图片描述

模型介绍(ResNet18)

resnet18的结构图如下所示:
在这里插入图片描述
pytorch内部自带resnet18模型,不过原始的模型最后的分类数为1000,而本文的汉字类别数为1311,所以需要修改模型的最后一层全连接层,代码如下所示:

device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#加载resnet18模型
net=models.resnet18(pretrained=False)
net.conv1=nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#修改模型最后一层
net.fc=nn.Linear(in_features=512, out_features=1311, bias=True)
net=net.to(device)

读取数据

使用dataset读取数据代码如下:

from torch.utils.data import DataLoader,Dataset
import cv2
import numpy as np
import torch
import imgaug.augmenters as iaa
import random
#读取训练图片类
class Mydataset(Dataset):
    def __init__(self,lines,train=True):
        super(Mydataset, self).__init__()
        #储存图像所有路径
        self.lines=lines
        self.train=train


    def __getitem__(self, item):
        """读取图像,并转换成rgb格式"""
        #图片路径
        img_path=self.lines[item].strip().split()[0]
        #图片标签
        img_lab=self.lines[item].strip().split()[1]


        img=cv2.imread(img_path)[...,::-1]
        # 图像标签转换成整数
        img_lab = int(img_lab)
        #数据增强
        if self.train:
            img=self.get_random_data(img)
        else:img=cv2.resize(img,(64,64))
        #灰度化
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        #进行二值化
        _,img=cv2.threshold(img,0,255,cv2.THRESH_OTSU)
        img=255-img

        #显示图像
        # cv2.imshow('img',img)
        # cv2.waitKey(0)
        """数据归一化,并在添加一个维度"""
        img=np.expand_dims(img,axis=0)/255
        img=img.astype('float32')


        return img,img_lab


    def __len__(self):
        #返回训练图片数量
        return len(self.lines)

    def get_random_data(self,img):
        """随机增强图像"""
        seq = iaa.Sequential([
            iaa.Multiply((0.8, 1.5)),  # change brightness, doesn't affect BBs(bounding boxes)
            iaa.GaussianBlur(sigma=(0, 1.0)),  # 标准差为0到3之间的值
            iaa.Crop(percent=(0, 0.06)),
            iaa.Grayscale(alpha=(0, 1)),
            iaa.Affine(
                scale=(0.9, 1.),  # 尺度变换
                rotate=(-20, 20),
                cval=(250),
                mode='constant'),
            iaa.Resize(64)
        ])
        img=seq.augment(image=img)
        return img

if __name__ == '__main__':
    lines=open('data.txt','r').readlines()
    mydata=Mydataset(lines=lines)
    myloader=DataLoader(mydata,batch_size=3,shuffle=True)
    for i,j in myloader:
        print(i.shape,j)



训练模型代码

import torch.nn as nn
import torchvision.models as models
import torch
import random
import torch.optim as optim
from dataset import Mydataset
from torch.utils.data import DataLoader
from tqdm import tqdm


#获取学习率函数
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']
#计算准确率函数
def metric_func(pred,lab):
    _,index=torch.max(pred,dim=-1)
    acc=torch.where(index==lab,1.,0.).mean()
    return acc


device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#加载resnet18模型
net=models.resnet18(pretrained=False)
net.conv1=nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#修改模型最后一层
net.fc=nn.Linear(in_features=512, out_features=1311, bias=True)
net=net.to(device)

#划分训练和验证比例
rate=0.2
"""读取所有训练图像路径,并划分成训练集和验证集"""
lines=open('data.txt','r').readlines()[:2618]
val_lines=random.sample(lines,k=int(len(lines)*rate))
train_lines=list(set(lines)-set(val_lines))


#学习率
lr          = 2e-3
#设置batchsize
batch_size  = 40

num_train   = len(train_lines)
epoch_step  = num_train // batch_size

#设置损失函数
loss_fun     = nn.CrossEntropyLoss()
#设置优化器
optimizer  = optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
#学习率衰减
lr_scheduler  = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)

"""迭代读取训练数据"""
train_data=Mydataset(train_lines,train=True)
val_data=Mydataset(val_lines,train=False)
train_loader = DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
val_loader   = DataLoader(dataset=val_data,batch_size=batch_size,shuffle=False)

if __name__ == '__main__':

    #设置迭代次数200次
    Epoch=50
    epoch_step = num_train // batch_size
    for epoch in range(1, Epoch + 1):
        net.train()

        total_loss = 0
        loss_sum = 0.0
        with tqdm(total=epoch_step, desc=f'Epoch {
      epoch}/{
      Epoch}', postfix=dict, mininterval=0.3) as pbar:
            for step, (features, labels) in enumerate(train_loader, 1):
                features = features.to(device)
                labels = labels.to(device)
                batch_size = labels.size()[0]

                optimizer.zero_grad()
                out = net(features)
                loss = loss_fun(out, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss
                pbar.set_postfix(**{
    'loss': total_loss.item() / (step),
                                        'lr': get_lr(optimizer)})
                pbar.update(1)


        # 验证
        net.eval()
        acc_sum = 0
        for val_step, (features, labels) in enumerate(val_loader, 1):
            with torch.no_grad():
                features = features.to(device)
                labels = labels.to(device)
                predictions = net(features)
                val_metric = metric_func(predictions, labels)
            acc_sum += val_metric.item()
        print('val_acc=%.4f' % (acc_sum / val_step))

        #保存模型
        if (epoch) % 1 == 0:
            torch.save(net.state_dict(), 'logs/Epoch%d-Loss%.4f_.pth' % (
                epoch, total_loss / (epoch_step + 1)))


        lr_scheduler.step()



训练日志如下所示,验证集准确率可以达到0.95以上
在这里插入图片描述

源代码下载

项目目录如下所示:
在这里插入图片描述
源码地址:下载地址列表2

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/2302_82079084/article/details/135126816

智能推荐

使用nginx解决浏览器跨域问题_nginx不停的xhr-程序员宅基地

文章浏览阅读1k次。通过使用ajax方法跨域请求是浏览器所不允许的,浏览器出于安全考虑是禁止的。警告信息如下:不过jQuery对跨域问题也有解决方案,使用jsonp的方式解决,方法如下:$.ajax({ async:false, url: 'http://www.mysite.com/demo.do', // 跨域URL ty..._nginx不停的xhr

在 Oracle 中配置 extproc 以访问 ST_Geometry-程序员宅基地

文章浏览阅读2k次。关于在 Oracle 中配置 extproc 以访问 ST_Geometry,也就是我们所说的 使用空间SQL 的方法,官方文档链接如下。http://desktop.arcgis.com/zh-cn/arcmap/latest/manage-data/gdbs-in-oracle/configure-oracle-extproc.htm其实简单总结一下,主要就分为以下几个步骤。..._extproc

Linux C++ gbk转为utf-8_linux c++ gbk->utf8-程序员宅基地

文章浏览阅读1.5w次。linux下没有上面的两个函数,需要使用函数 mbstowcs和wcstombsmbstowcs将多字节编码转换为宽字节编码wcstombs将宽字节编码转换为多字节编码这两个函数,转换过程中受到系统编码类型的影响,需要通过设置来设定转换前和转换后的编码类型。通过函数setlocale进行系统编码的设置。linux下输入命名locale -a查看系统支持的编码_linux c++ gbk->utf8

IMP-00009: 导出文件异常结束-程序员宅基地

文章浏览阅读750次。今天准备从生产库向测试库进行数据导入,结果在imp导入的时候遇到“ IMP-00009:导出文件异常结束” 错误,google一下,发现可能有如下原因导致imp的数据太大,没有写buffer和commit两个数据库字符集不同从低版本exp的dmp文件,向高版本imp导出的dmp文件出错传输dmp文件时,文件损坏解决办法:imp时指定..._imp-00009导出文件异常结束

python程序员需要深入掌握的技能_Python用数据说明程序员需要掌握的技能-程序员宅基地

文章浏览阅读143次。当下是一个大数据的时代,各个行业都离不开数据的支持。因此,网络爬虫就应运而生。网络爬虫当下最为火热的是Python,Python开发爬虫相对简单,而且功能库相当完善,力压众多开发语言。本次教程我们爬取前程无忧的招聘信息来分析Python程序员需要掌握那些编程技术。首先在谷歌浏览器打开前程无忧的首页,按F12打开浏览器的开发者工具。浏览器开发者工具是用于捕捉网站的请求信息,通过分析请求信息可以了解请..._初级python程序员能力要求

Spring @Service生成bean名称的规则(当类的名字是以两个或以上的大写字母开头的话,bean的名字会与类名保持一致)_@service beanname-程序员宅基地

文章浏览阅读7.6k次,点赞2次,收藏6次。@Service标注的bean,类名:ABDemoService查看源码后发现,原来是经过一个特殊处理:当类的名字是以两个或以上的大写字母开头的话,bean的名字会与类名保持一致public class AnnotationBeanNameGenerator implements BeanNameGenerator { private static final String C..._@service beanname

随便推点

二叉树的各种创建方法_二叉树的建立-程序员宅基地

文章浏览阅读6.9w次,点赞73次,收藏463次。1.前序创建#include<stdio.h>#include<string.h>#include<stdlib.h>#include<malloc.h>#include<iostream>#include<stack>#include<queue>using namespace std;typed_二叉树的建立

解决asp.net导出excel时中文文件名乱码_asp.net utf8 导出中文字符乱码-程序员宅基地

文章浏览阅读7.1k次。在Asp.net上使用Excel导出功能,如果文件名出现中文,便会以乱码视之。 解决方法: fileName = HttpUtility.UrlEncode(fileName, System.Text.Encoding.UTF8);_asp.net utf8 导出中文字符乱码

笔记-编译原理-实验一-词法分析器设计_对pl/0作以下修改扩充。增加单词-程序员宅基地

文章浏览阅读2.1k次,点赞4次,收藏23次。第一次实验 词法分析实验报告设计思想词法分析的主要任务是根据文法的词汇表以及对应约定的编码进行一定的识别,找出文件中所有的合法的单词,并给出一定的信息作为最后的结果,用于后续语法分析程序的使用;本实验针对 PL/0 语言 的文法、词汇表编写一个词法分析程序,对于每个单词根据词汇表输出: (单词种类, 单词的值) 二元对。词汇表:种别编码单词符号助记符0beginb..._对pl/0作以下修改扩充。增加单词

android adb shell 权限,android adb shell权限被拒绝-程序员宅基地

文章浏览阅读773次。我在使用adb.exe时遇到了麻烦.我想使用与bash相同的adb.exe shell提示符,所以我决定更改默认的bash二进制文件(当然二进制文件是交叉编译的,一切都很完美)更改bash二进制文件遵循以下顺序> adb remount> adb push bash / system / bin /> adb shell> cd / system / bin> chm..._adb shell mv 权限

投影仪-相机标定_相机-投影仪标定-程序员宅基地

文章浏览阅读6.8k次,点赞12次,收藏125次。1. 单目相机标定引言相机标定已经研究多年,标定的算法可以分为基于摄影测量的标定和自标定。其中,应用最为广泛的还是张正友标定法。这是一种简单灵活、高鲁棒性、低成本的相机标定算法。仅需要一台相机和一块平面标定板构建相机标定系统,在标定过程中,相机拍摄多个角度下(至少两个角度,推荐10~20个角度)的标定板图像(相机和标定板都可以移动),即可对相机的内外参数进行标定。下面介绍张氏标定法(以下也这么称呼)的原理。原理相机模型和单应矩阵相机标定,就是对相机的内外参数进行计算的过程,从而得到物体到图像的投影_相机-投影仪标定

Wayland架构、渲染、硬件支持-程序员宅基地

文章浏览阅读2.2k次。文章目录Wayland 架构Wayland 渲染Wayland的 硬件支持简 述: 翻译一篇关于和 wayland 有关的技术文章, 其英文标题为Wayland Architecture .Wayland 架构若是想要更好的理解 Wayland 架构及其与 X (X11 or X Window System) 结构;一种很好的方法是将事件从输入设备就开始跟踪, 查看期间所有的屏幕上出现的变化。这就是我们现在对 X 的理解。 内核是从一个输入设备中获取一个事件,并通过 evdev 输入_wayland

推荐文章

热门文章

相关标签