sklearn数据集分割方法汇总_sklearn split_会飞的哼哧的博客-程序员宅基地

技术标签: 机器学习  数据集划分  

一、简介

  在现实的机器学习任务中,我们往往是利用搜集到的尽可能多的样本集来输入算法进行训练,以尽可能高的精度为目标,但这里便出现一个问题,一是很多情况下我们不能说搜集到的样本集就能代表真实的全体,其分布也不一定就与真实的全体相同,但是有一点很明确,样本集数量越大则其接近真实全体的可能性也就越大;二是很多算法容易发生过拟合(overfitting),即其过度学习到训练集中一些比较特别的情况,使得其误认为训练集之外的其他集合也适用于这些规则,这使得我们训练好的算法在输入训练数据进行验证时结果非常好,但在训练集之外的新测试样本上精度则剧烈下降,这样训练出的模型可以说没有使用价值;因此怎样对数据集进行合理的抽样-训练-验证就至关重要,下面就对机器学习中常见的抽样技术进行介绍,并通过sklearn进行演示;

 

二、留出法

  留出法(hold-out)在前面的很多篇博客中我都有用到,但当时没有仔细介绍,其基本思想是将数据集D(即我们获得的所有样本数据)划分为两个互斥的集合,将其中一个作为训练集S,另一个作为验证集T,即D=SUT,S∩T=Φ。在S上训练出模型后,再用T来评估其测试误差,作为泛化误差的估计值;

  需要注意的是,训练集/验证集的划分要尽可能保持数据分布的一致性,尽量减少因数据划分过程引入额外的偏差而对最终结果产生的影响,例如在分类任务中,要尽量保持ST内的样本各个类别的比例大抵一致,这可以通过分层抽样(stratified sampling)来实现;

  因为我们希望实现的是通过这个留出法的过程来评估数据集D的性能,但由于留出法需要划分训练集与验证集,这就不可避免的减少了训练素材,若验证集样本数量过于小,导致训练集与原数据集D接近,而与验证集差别过大,进而导致无论训练出的模型效果如何,都无法在验证集上取得真实的评估结果,从而降低了评估效果的保真性(fidelity),因此训练集与验证集间的比例就不能过于随便,通常情况下我们将2/3到4/5的样本划分出来用于训练;

  在sklearn中我们使用sklearn.model_selection中的train_test_split()来分割我们的数据集,其具体参数如下:

X:待分割的样本集中的自变量部分,通常为二维数组或矩阵的形式;

y:待分割的样本集中的因变量部分,通常为一维数组;

test_size:用于指定验证集所占的比例,有以下几种输入类型:

  1.float型,0.0~1.0之间,此时传入的参数即作为验证集的比例;

  2.int型,此时传入的参数的绝对值即作为验证集样本的数量;

  3.None,这时需要另一个参数train_size有输入才生效,此时验证集去为train_size指定的比例或数量的补集;

  4.缺省时为0.25,但要注意只有在train_size和test_size都不输入值时缺省值才会生效;

train_size:基本同test_size,但缺省值为None,其实test_size和train_size输入一个即可;

random_state:int型,控制随机数种子,默认为None,即纯随机(伪随机);

stratify:控制分类问题中的分层抽样,默认为None,即不进行分层抽样,当传入为数组时,则依据该数组进行分层抽样(一般传入因变量所在列);

shuffle:bool型,用来控制是否在分割数据前打乱原数据集的顺序,默认为True,分层抽样时即stratify为None时该参数必须传入False;

返回值:

依次返回训练集自变量、测试集自变量、训练集因变量、测试集因变量,因此使用该函数赋值需在等号右边采取X_train, X_test, y_train, y_test'的形式;

下面以鸢尾花数据(三个class)为例,分别演示简单随机抽样和分层抽样时的不同情况:

未分层时:

from sklearn.model_selection import train_test_split
from sklearn import datasets
import pandas as pd

'''载入数据'''
X,y = datasets.load_iris(return_X_y=True)

'''不采取分层抽样时的数据集分割'''
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.3)

'''打印各个数据集的形状'''
print(X_train.shape,X_test.shape,y_train.shape,y_test.shape)

'''打印训练集中因变量的各类别数目情况'''
print(pd.value_counts(y_train))

'''打印验证集集中因变量的各类别数目情况'''
print(pd.value_counts(y_test))

分层时:

'''采取分层抽样时的数据集分割'''
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.3,stratify=y)

'''打印各个数据集的形状'''
print(X_train.shape,X_test.shape,y_train.shape,y_test.shape)

'''打印训练集中因变量的各类别数目情况'''
print(pd.value_counts(y_train))

'''打印验证集集中因变量的各类别数目情况'''
print(pd.value_counts(y_test))

 

三、交叉验证法

  交叉验证法(cross validation)先将数据集D划分为k个大小相似的互斥子集,即D=D1UD2U...UDk,Di∩Dj=Φ(i≠j),每个子集Di都尽可能保持数据分布的一致性,即从D中通过分层采样得到。然后每次用k-1个子集的并集作为训练集,剩下的那一个子集作为验证集;这样就可获得k组训练+验证集,从而可以进行k次训练与测试,最终返回的是这k个测试结果的均值。显然,交叉验证法的稳定性和保真性在很大程度上取决与k的取值,因此交叉验证法又称作“k折交叉验证”(k-fold cross validation),k最常见的取值为10,即“10折交叉验证”,其他常见的有5,20等;

  假定数据集D中包含m个样本,若令k=m,则得到了交叉验证法的一个特例:留一法(Leave-one-out),显然,留一法不受随机样本划分方式的影响,因为m个样本只有唯一的方式划分m个子集——每个子集包含一个样本,留一法使用的训练集与初始数据集相比只少了一个样本,这就使得在绝大多数情况下,留一法中被实际评估的模型与期望评估的用D训练出的模型很相似,因此,留一法的评估结果往往被认为比较准确,但其也有一个很大的缺陷:当数据集比较大时,训练m个模型的计算成本是难以想象的;

在sklearn.model_selection中集成了众多用于交叉验证的方法,下面对其中常用的进行介绍:

 

cross_val_score():

  这是一个用于直接计算某个已确定参数的模型其交叉验证分数的方法,具体参数如下:

estimator:已经初始化的学习器模型;

X:自变量所在的数组;

y:因变量所在的数组;

scoring:str型,控制函数返回的模型评价指标,默认为准确率;

cv:控制交叉验证中分割样本集的策略,即k折交叉中的k,默认是3,即3折交叉验证,有以下多种输入形式:

  1.int型,则输入的参数即为k;

  2.None,则使用默认的3折;

  3.一个生成器类型的对象,用来控制交叉验证,优点是节省内存,下面的演示中会具体介绍;

  *若estimator是一个分类器,则默认使用分层抽样来产生子集。

n_jobs:int型,用来控制并行运算中使用的核心数,默认为1,即单核;特别的,设置为-1时开启所有核心;

函数返回值:

对应scoring指定的cv个评价指标;

下面以一个简单的小例子进行演示:

from sklearn.model_selection import cross_val_score
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier


X,y = datasets.load_breast_cancer(return_X_y=True)

clf = KNeighborsClassifier()


'''打印每次交叉验证的准确率'''
score = cross_val_score(clf,X,y,cv=5,scoring='accuracy')

print('accuracy:'+str(score)+'\n')

'''打印每次交叉验证的f1得分'''
score = cross_val_score(clf,X,y,cv=5,scoring='f1')

print('f1 score:'+str(score)+'\n')

'''打印正确率的95%置信区间'''
print(str(round(score.mean(),3))+'(+/-'+str(round(2*score.std(),3))+')')

 

cross_validate():

  这个方法与cross_val_score()很相似,但有几处新特性:

  1.cross_validate()可以返回多个评价指标,这在需要一次性产生多个不同种类评分时很方便;

  2.cross_validate()不仅返回模型评价指标,还会返回训练花费时长、

 其具体参数如下:

estimator:已经初始化的分类器模型;

X:自变量;

y:因变量;

scoring:字符型或列表形式的多个字符型,控制产出的评价指标,可以通过在列表中写入多个评分类型来实现多指标输出;

cv:控制交叉验证的子集个数;

n_jobs:控制并行运算利用的核心数,同cross_val_score();

return_train_score:bool型,控制是否在得分中计算训练集回带进模型的结果;

函数输出项:字典形式的训练时间、计算得分时间、及各得分情况;

下面以一个简单的小例子进行说明:

from sklearn.model_selection import cross_validate
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier


X,y = datasets.load_breast_cancer(return_X_y=True)

clf = KNeighborsClassifier()

'''定义需要输出的评价指标'''
scoring = ['accuracy','f1']

'''打印每次交叉验证的准确率'''
score = cross_validate(clf,X,y,scoring=scoring,cv=5,return_train_score=True)

score

 

 

四、基于生成器的采样方法

  sklearn中除了上述的直接完成整套交叉验证的方法外,还存在着一些基于生成器的方法,这些方法的好处是利用Python中生成器(generator)的方式,以非常节省内存的方式完成每一次的交叉验证,下面一一罗列:

 

KFold():

  以生成器的方式产出每一次交叉验证所需的训练集与验证集,其主要参数如下:

n_splits:int型,控制k折交叉中的k,默认是3;

shuffle:bool型,控制是否在采样前打乱原数据顺序;

random_state:设置随机数种子,默认为None,即不固定随机水平;

下面以一个简单的小例子进行演示:

from sklearn.model_selection import KFold
import numpy as np


X = np.random.randint(1,10,20)

kf = KFold(n_splits=5)

for train,test in kf.split(X):
    print(train,'\n',test)

 

 LeaveOneOut():

  对应先前所介绍的留出法中的特例,留一法,因为其性质很固定,所以无参数需要调节,下面以一个简单的小例子进行演示:

from sklearn.model_selection import LeaveOneOut
import numpy as np


X = np.random.randint(1,10,5)

kf = LeaveOneOut()

for train,test in kf.split(X):
    print(train,'\n',test)

 

 LeavePOut():

  LeaveOneOut()的一个变种,唯一的不同就是每次留出p个而不是1个样本作为验证集,唯一的参数是p,下面是一个简单的小例子:

from sklearn.model_selection import LeavePOut
import numpy as np


X = np.random.randint(1,10,5)

kf = LeavePOut(p=2)

for train,test in kf.split(X):
    print(train,'\n',test)

 

TimeSeriesSplit():

  在机器学习中还存在着一种叫做时间序列的数据类型,这种数据的特点是高度的自相关性,前后相邻时段的数据关联程度非常高,因此在对这种数据进行分割时不可以像其他机器学习任务那样简单随机抽样的方式采样,对时间序列数据的采样不能破坏其时段的连续型,在sklearn.model_selection中我们使用TimeSeriesSplit()来分割时序数据,其主要参数如下:

n_splits:int型,控制产生(训练集+验证集)的数量;

max_train_size:控制最大的时序数据长度;

下面是一个简单的小例子:

from sklearn.model_selection import TimeSeriesSplit
import numpy as np


X = np.random.randint(1,10,20)

kf = TimeSeriesSplit(n_splits=4)

for train,test in kf.split(X):
    print(train,'\n',test)

 

  以上就是sklearn中关于样本抽样的常见功能,如有笔误,望指出。

 

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_38233659/article/details/101553277

智能推荐

数据可视化,Seaborn画图原来这么好看_saeborn功能介绍_web前端学习扣群:244500143的博客-程序员宅基地

matplotlib是python最常见的绘图包,强大之处不言而喻。然而在数据科学领域,可视化库-Seaborn也是重量级的存在。由于matplotlib比较底层,想要绘制漂亮的图非常麻烦,需要写大量的代码。Seaborn是在matplotlib基础上进行了高级API封装,图表装饰更加容易,你可以用更少的代码做出更美观的图。同时,Seaborn高度兼容了numy、pandas、scipy等..._saeborn功能介绍

浅析Mysql索引Innodb及底层存储结构B+树_沉默加速度gaoys的博客-程序员宅基地

前言最近博主在学习mysql索引相关知识,看了很多博客,公开课然后自己总结一下,最近的收获吧。本文从二叉树->平衡二叉树->B树(有同学读B减树,是不正确的) ->B+树 再到B+树在Myisam和Innodb中的体现形式,会提到索引失效的情况,以及创建索引时的一些注意事项及其原理。采用大量的图片+部分文字更加清晰的描述。其中部分图片来源于咕泡学院公开课。准备提到my...

LNMP网站服务器架构搭建-程序员宅基地

LNMP架构搭建1.安装Nginxyum -y install make zlib zlib-devel gcc-c++ libtool openssl openssl-devel #安装编译工具及库文件#安装PCRE让 Nginx支持Rewrite功能wget https://downloads.sourceforge.net/project/pcre/pcre/8.35/pcre-8...

杭电ACM2010:水仙花数_ljjdada的博客-程序员宅基地

#include using namespace std;int main(){ int m, n; int a, b, c,flag=0; int x[1000]; while (cin >> m >> n) { if (m > n){ n = m + n; m = n

初探LinkedList线程安全问题(一)_Scholfield的博客-程序员宅基地

Java中LinkedList是线程不安全的,那么如果在多线程程序中有多个线程访问LinkedList的话会出现什么问题呢? 抛出ConcurrentModificationException JDK代码里,ListItr的add(), next(), previous(), remove(), set()方法都会跑出ConcurrentModificationException。 fina

windows下使用Gitblit搭建局域网内的GIT私有仓库_gitblit创建私有库-程序员宅基地

一、目的使用Gitblit搭建局域网内的GIT私有仓库二、准备工具JDK 1.8.0_65 gitbilt-1.9.1三、搭建步骤1.安装java环境下栽的JDK 1.8.0_65为压缩文件,直接将其解压后,配置相关环境变量即可。需要设置JAVA_HOME、PATH、CLASSPATH三个环境变量。JAVA_HOME指明JDK安装路径,J:\APP-test\jdk1.8.0_65 ..._gitblit创建私有库

随便推点

magento -- paypal 支付方式不能显示在下单页面_xinhaozheng的博客-程序员宅基地

好久没有写关于magento的文章了,不过近期会抽点时间来更新一下。magento新版1.4的速度有了很大提高。尽管还是存在很多的BUG。 今天一个碰到了一个很奇怪的问题,明明在后台中打开了 paypal支付方式,可是在前台下单时就是不出于 paypal支付方式,查了半天也没查出个问题来。跟踪代码,原先怀疑是模块被关闭,事实是没有关闭,再怀疑模板关闭了相应的BLOCK。模板的layo

【荐】Angular官方代码风格指南_野草_前端的博客-程序员宅基地

本文为笔者对Angular官网风格指南的整理版本,删除/增加了部分内容。另外,原文对每个规范都作出了原因的解释,个别还有示例,需要的请点击查看原文。原链接:英文文档 / 中文文档单一职责单一文件一个文件定义一样东西,比如一个组件、一个服务、一个管道、一个指令每个文件最多不要超过400行单一函数定义功能单一的函数一个函数最多不要超过75行命名规范文件名采用feature.type.*

微信小程序 新闻列表及详情页_微信小程序 新闻详情页_JJJenny0607的博客-程序员宅基地

微信小程序 新闻列表及详情页页面效果新闻列表<view class="conatiner"> <view class="news-item" wx:for="{{newsList}}" wx:for-index="index"> <view class="title" bindtap="todetail" data-options="{{item.id}}" >{{item.title}}</view> <vi_微信小程序 新闻详情页

su - oracle 报错 syntax error near unexpected token `then'_Ziv.Ding的博客-程序员宅基地

[root@oel grid]# su - oracle-bash: /etc/profile: line 60: syntax error near unexpected token `then'-bash: /etc/profile: line 60: `if[$USER="oracle"];then'vi /etc/profile去到第 60行,看看语法错了哪里,改过profi...

编译原理——理解LL/LR/SLR/LALR_ll lr 实现区别_邵政道的博客-程序员宅基地

LL(1)文法属于自上而下的分析方法。也就是说,同一个非终结符的多种递推方式中,首字母一定不同。这样就可以只用根据一个首字母就可以判断出是哪一个递推式子。文法名字由来第一个L代表从左边开始扫描;第二个L表示产生最左推导数字1表示每一步推导式只需要向后看一个符号就可以LL(1)文法的明显性质没有公共左因子(如果有,那么无法只读一个字符就判断如何递归)不是二义的(每个读入的字符都..._ll lr 实现区别

在互联网大厂,我月入过万,合租却让我落泪_互联网大厂能月入过万吗?_软件测试君的博客-程序员宅基地

没搬过家的人不叫“北漂”,合租的房子不叫“我家”,对于刚刚毕业或是刚参加工作没多久的人而言,省钱,是合租的唯一理由。每当窝在家里,屋外门声响起,就会不自觉地调低手机的声音;屋外有人,膀胱的容量就临时性扩大,能不见面就不见面;见面不知道该不该打招呼,只能沉默以对。没有人愿意和别人分享自己的私密空间,除非缺钱。但对于高收入的互联网人而言,似乎在繁华的北京城租住一套独居房并不是什么苦难的事情,无非是有些“奢侈”。但实际是大多数互联网人十分清楚自己的使命:在黄金年龄拿到足够多的钱,要么落户,要么回家。.._互联网大厂能月入过万吗?

推荐文章

热门文章

相关标签