【机器学习实战】利用KNN和其他分类器对手写数字进行识别_使用knn实现对sklearn自带的手写数据集digits的分类任务并计算其准确率,之后,使用-程序员宅基地

技术标签: 分类器  手写数字识别  机器学习  

一、在sklearn中创建KNN分类器

KNeighborsClassifier(n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30)

看一下这几个参数

1. n_neighbors:即 KNN 中的 K 值,代表的是邻居的数量。如果K 值比较小,会造成过拟合;如果 K 值比较大,无法将未知物体分类出来。一般我们使用默认值 5。

2. weights:是用来确定邻居的权重,有两种方式:

  1. weights=‘uniform’,代表所有邻居的权重相同;
  2. weights=‘distance’,代表权重是距离的倒数,即与距离成反比。

3. algorithm:用来规定计算邻居的方法,它有四种方式:

  1. algorithm=‘auto’,根据数据的情况自动选择适合的算法,默认情况选择 auto;
  2. algorithm=‘kd_tree’,也叫作 KD 树,是多维空间的数据结构,方便对关键数据进行检索,不过 KD 树适用于维度少的情况,一般维数不超过 20,如果维数大于 20 之后,效率反而会下降;
  3. algorithm=‘ball_tree’,也叫作球树,它和 KD 树一样都是多维空间的数据结果,不同于 KD 树,球树更适用于维度大的情况;
  4. algorithm=‘brute’,也叫作暴力搜索,它和 KD 树不同的地方是在于采用的是线性扫描,而不是通过构造树结构进行快速检索。当训练集大的时候,效率很低。

4.leaf_size:代表构造 KD 树或球树时的叶子数,默认是 30,调整 leaf_size 会影响到树的构造和搜索速度。

创建完 KNN 分类器之后,我们就可以输入训练集对它进行训练,这里我们使用 fit() 函数,传入训练集中的样本特征矩阵和分类标识,会自动得到训练好的 KNN 分类器。然后可以使用 predict() 函数来对结果进行预测,这里传入测试集的特征矩阵,可以得到测试集的预测分类结果。

二、工作流程

我们用 sklearn 自带的手写数字数据集做 KNN 分类,你可以把这个数据集理解成一个简版的 MNIST 数据集,它只包括了 1797 幅数字图像,每幅图像大小是 8*8 像素。

先划分一下流程:

整个训练过程基本上都会包括三个阶段:

  1. 数据加载:直接从 sklearn 中加载自带的手写数字数据集;
  2. 准备阶段:在这个阶段中,我们需要对数据集有个初步的了解,比如样本的个数、图像长什么样、识别结果是怎样的。你可以通过可视化的方式来查看图像的呈现。通过数据规范化可以让数据都在同一个数量级的维度。另外,因为训练集是图像,每幅图像是个 8*8 的矩阵,我们不需要对它进行特征选择,将全部的图像数据作为特征值矩阵即可
  3. 分类阶段:通过训练可以得到分类器,然后用测试集进行准确率的计算。

三、实战环节

1.导包

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.datasets import load_digits
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.model_selection import train_test_split

from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.naive_bayes import MultinomialNB
from sklearn.ensemble import AdaBoostClassifier

from sklearn.metrics import accuracy_score

2.加载数据并探索

# 加载数据
digits = load_digits()
data = digits.data

# 数据探索
print(data.shape)
# 查看第一幅图像
print(digits.images[0])
# 第一幅图像代表的数字含义
print(digits.target[0])

# 将第一幅图像显示出来
plt.imshow(digits.images[0])
plt.show()

输出:

(1797, 64)
[[ 0.  0.  5. 13.  9.  1.  0.  0.]
 [ 0.  0. 13. 15. 10. 15.  5.  0.]
 [ 0.  3. 15.  2.  0. 11.  8.  0.]
 [ 0.  4. 12.  0.  0.  8.  8.  0.]
 [ 0.  5.  8.  0.  0.  9.  8.  0.]
 [ 0.  4. 11.  0.  1. 12.  7.  0.]
 [ 0.  2. 14.  5. 10. 12.  0.  0.]
 [ 0.  0.  6. 13. 10.  0.  0.  0.]]
0

我们对原始数据集中的第一幅进行数据可视化,可以看到图像是个 8*8 的像素矩阵,上面这幅图像是一个“0”,从训练集的分类标注中我们也可以看到分类标注为“0”。

3.分割数据集并规范化

sklearn 自带的手写数字数据集一共包括了 1797 个样本,每幅图像都是 8*8 像素的矩阵。因为并没有专门的测试集,所以我们需要对数据集做划分,划分成训练集和测试集。因为 KNN 算法和距离定义相关,我们需要对数据进行规范化处理,采用 Z-Score 规范化,代码如下:

# 数据及目标
data1 = digits.data
target1 = digits.target

# 分割数据,将25%的数据作为测试集,其余作为训练集(你也可以指定其他比例的数据作为训练集)
train_x, test_x, train_y, test_y = train_test_split(data1, target1, test_size=0.25)

# 采用z-score规范化
ss = StandardScaler()
train_ss_scaled = ss.fit_transform(train_x)
test_ss_scaled = ss.transform(test_x)

# 采用0-1归一化
mm = MinMaxScaler()
train_mm_scaled = mm.fit_transform(train_x)
test_mm_scaled = mm.transform(test_x)

这里之所以用了0-1归一化,是因为多项式朴素贝叶斯分类这个模型,传入的数据不能有负数。因为 Z-Score 会将数值规范化为一个标准的正态分布,即均值为 0,方差为 1,数值会包含负数。因此我们需要采用 Min-Max 规范化,将数据规范化到[0,1]范围内。

4.建立模型,并进行比较

这里构造五个分类器, 分别是K近邻,SVM, 多项式朴素贝叶斯, 决策树模型, AdaBoost模型。并分别看看他们的效果。

models = {}
models['knn'] = KNeighborsClassifier()
models['svm'] = SVC()
models['bayes'] = MultinomialNB()
models['tree'] = DecisionTreeClassifier()
models['ada'] = AdaBoostClassifier(base_estimator=models['tree'], learning_rate=0.1)

for model_key in models.keys():
    if model_key == 'knn' or model_key == 'svm' or model_key == 'ada':
        model = models[model_key]
        model.fit(train_ss_scaled, train_y)
        predict = model.predict(test_ss_scaled)
        print(model_key, "准确率:", accuracy_score(test_y, predict))
    else:
        model = models[model_key]
        model.fit(train_mm_scaled, train_y)
        predict = model.predict(test_mm_scaled)
        print(model_key, "准确率: ", accuracy_score(test_y, predict))

输出:

knn 准确率: 0.9777777777777777
svm 准确率: 0.9866666666666667
bayes 准确率:  0.8888888888888888
tree 准确率:  0.8444444444444444
ada 准确率: 0.8355555555555556

你能看出来 KNN 的准确率还是不错的,和 SVM 不相上下。并且竟然比AdaBoost效果都要好,而让我纳闷的是决策树和AdaBoost怎么效果这么差,不可思议。后来我发现了,原来是样本数量的问题,我们最多数据集才1000多照片,数量太少了,AdaBoost的作用发挥不出来,所以我对数据进行扩增,复制了三遍原来的数据:

data2 = np.vstack((data1, data1, data1))
target2 = np.hstack((target1, target1, target1))

变成了5000多张数据,然后再进行测试,结果就是AdaBoost和tree的效果提升了,甚至可以和SVM效果媲美了。

输出:

knn 准确率: 0.9821958456973294
svm 准确率: 0.9970326409495549
bayes 准确率:  0.9013353115727003
tree 准确率:  0.9955489614243324
ada 准确率: 0.9933234421364985

 

 

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_40431584/article/details/105123066

智能推荐

Cookie跨域以及Cookie共享问题_跨域cookie共享-程序员宅基地

文章浏览阅读5.3k次。解决跨域以后,如何允许跨域请求携带cookie,例如访问B的接口,默认情况下是不允许带cookie的,此时需要设置axios的withcredentials的属性为true,告诉浏览器在访问B网站时,将B网站的cookie带上,此时光前端设置还不行,还需要后端在响应头中添加 allow-withcredentials = true,这样就可以保证跨域请求也可以携带cookie。在站点A下面访问B域名的接口,那么这是一个跨域请求,如果不做处理,此时这个请求就跨域了,浏览器在接收到响应以后会直接报错。_跨域cookie共享

安全架构与企业风险管理:实现全面的安全保障-程序员宅基地

文章浏览阅读456次,点赞14次,收藏20次。1.背景介绍在当今的数字时代,数据安全和企业风险管理已经成为企业最关键的问题之一。随着互联网和人工智能技术的发展,企业数据的安全性和隐私保护成为了越来越重要的问题。因此,安全架构和企业风险管理已经成为企业最关键的问题之一。在这篇文章中,我们将讨论安全架构与企业风险管理的关系,以及如何实现全面的安全保障。我们将从以下几个方面进行讨论:背景介绍核心概念与联系核心算法原理和具体操作步骤...

P问题、NP问题、NPC问题、NP hard问题-程序员宅基地

文章浏览阅读4.1w次,点赞50次,收藏268次。图论算法摘要1. 图的概念图一个图(graph) G=(V,E)G=(V,E)G=(V,E) 由顶点(vertex)集 VVV 和边(edge)集 EEE 组成。每一条边就是一个点对 (a,b),a,b∈V(a,b),a,b∈V(a,b),a,b∈V。有时候也把边叫做弧(arc)。有向图如果点对(a,b),a,b∈V(a,b),a,b∈V(a,b),a,b∈V是有序的,那么图就是有向的..._npc问题

【UE4 C++】大规模人群绕行避让的最优解DetourCrowdAIController如何开启_detour crowd-程序员宅基地

文章浏览阅读7.6k次,点赞8次,收藏24次。目录问题阐述与解决效果RVO Avoidance与Detour Crowd AI Controller的区别如何使用Detour Crowd AI Controller蓝图C++弃用的写法新版写法(也很简便)问题阐述与解决效果在项目存在大规模寻路人群时,很容易出现两个角色的寻路路径相冲突,就会造成这种互斥现象。使用AI ControllerUE4为此..._detour crowd

《信息系统安全》课后习题答案(陈萍)_信息系统安全第二版课后答案-程序员宅基地

文章浏览阅读1.1w次,点赞22次,收藏146次。《信息系统安全》教材(作者:陈萍,张涛,赵敏)的课后习题答案_信息系统安全第二版课后答案

Ajax跨域问题_ajax请求跨域-程序员宅基地

文章浏览阅读3.2k次,点赞3次,收藏13次。ajax 是不能跨域。那么怎么解决前端发送请求的跨域问题呢。超详细,1、设置响应头、2、通过jsonp 3、通过调用jQuery封装的jsonp 4、httpclient 5、nginx_ajax请求跨域

随便推点

学做 方玲玉 网络营销_网络营销实务(方玲玉)课件及习题参考答案-程序员宅基地

文章浏览阅读1.2k次。内容简介:网络营销实务(方玲玉)课件及参考答案教学内容第01讲 网络正在改写传统商业规则第02讲 网络营销:传统营销的继承与超越(1)实训1 传统企业经营现状及网络平台建设情况调研第03讲 网络营销:传统营销的继承与超越(2)第04讲 创新创意:网络营销的核心竞争力实训2 成功网络卖家网络营销创新创意分析第05讲 目标市场及竞争对手分析第06讲 网民消费模式分析实训3 网络目标用户、竞争对手及消费..._网络营销实务课后题答案

JAVA java学习(16)——————javaweb主流框架介绍(小结)_javaweb框架-程序员宅基地

文章浏览阅读993次。Java Web开发的用到的框架之多简直令人发指,而且因为版本的更新换代导致的问题也是层出不穷。然而这也是Web技术不断演化的结果,要么选择接受,要么引领节奏。原来常用的Javaweb框架是SSH(Struts + Spring + Hibernate)后来随着Spring的强大以及Struts漏洞上的等等问题,演变成为了Spring + SpringMVC + Hibernate/Mybatis。互联网这块比较常见的是Mybatis。再后来也慢慢演变为了Springboot + Mybatis。1. _javaweb框架

和我一起写lua - 确认操作系统-程序员宅基地

文章浏览阅读294次。最近写的lua脚本需要运行在多个平台,因而一些平台相关的属性必须区别设置。如路径分隔符。在lua中,没有找到相关判断操作系统的函数。因此相关设置一直手工设置,增加了环境配置的时间。 在luarocks模块中,有一个luarocks.site_config模块(一个lua文件),其安装时便设定了操作系统类型。因此我们可以从这个模块获取操作系统:示例:require "..._lua判断操作系统

广度优先搜索算法及其MATLAB实现_广度优先算法可行路径matlab-程序员宅基地

文章浏览阅读6.2k次,点赞6次,收藏39次。摘要广度优先搜索算法(又称宽度优先搜索)是最简便的图的搜索算法之一,这一算法也是很多重要的图的算法的原型。Dijkstra单源最短路径算法和Prim最小生成树算法都采用了和宽度优先搜索类似的思想。其别名又叫BFS,属于一种盲目搜寻法,目的是系统地展开并检查图中的所有节点,以找寻结果。换句话说,它并不考虑结果的可能位置,彻底地搜索整张图,直到找到结果为止。(来自百度百科)算法思想1.对图中的任..._广度优先算法可行路径matlab

微信和支付宝相关支付业务场景介绍_支付宝的应用场景-程序员宅基地

文章浏览阅读1.1w次,点赞5次,收藏38次。支付宝 当面付 条码支付 应用场景:商家使用扫码设备,扫描用户支付宝钱包上的条码/二维码,完成收款。支付流程:API列表: 接口名称 描述 API地址 alipay.trade.pay 统一收单交易支付接口 https://docs.op..._支付宝的应用场景

iphone隐藏底条_iPhone12隐藏底部横条方法 iPhone12怎么隐藏底部小白条-程序员宅基地

文章浏览阅读7.7k次。iPhone12怎么隐藏底部小白条?很多iPhone 12用户反馈在看手机或者玩游戏的时候,屏幕底部的小白横条非常碍眼,但是又不知道怎么隐藏掉,所以小编今天整理了下iPhone12隐藏底部横条方法,帮大家一键隐藏底部横条,一起来看看吧!iPhone12隐藏底部横条方法:利用“引导式访问“功能。打开 iPhone “设置”-“辅助功能”,下拉找到“引导式访问”并开启: 在使用该功能之前,建议仔细阅..._iphone玩王者荣耀怎么把下面那个横条去掉

推荐文章

热门文章

相关标签