机器学习实战之预测数值型数据:回归_常用的回归预测-程序员宅基地

技术标签: 因吉  机器学习  Smale  # 机器学习  Python  

引入

  分类的目标变量是标称型数据,而回归则是对连续型的数据做出预测。回归能做什么呢?Peter Harrington的观点是可以做任何事情,包括他本人提到的一个比较有新意的应用:预测名人的离婚率。

1 线性回归

线性回归

  优点:结果易于理解,计算并不复杂;
  缺点:对非线性数据拟合效果不好;
  使用数据类型:数值型和标称型数据。

  回归的目的是预测数值型的目标值。最接近的的办法是依据输入写一个目标值的计算公式。假设你想对评估一个自己今天的颓废程度,可能会这样计算:

颓废程度 = 起床时间 × 1.2 + 发呆时间 × 1.1 - 学习时间 × 1.3

  以上是个示例,便是所谓的回归方程(regression equation),其中的数字称作回归系数(regression weights),求回归系数的过程就是回归。一旦求得回归系数,再给定输入,便能轻松地获得预测值。
  说到回归,一般指线性回归(linear regression),之后所述的回归都是这个意思。需要说明的是,存在另一种称为非线性回归的回归模型,该模型不认同上面的做法,比如认为输出可能是输入的乘积。例如:

颓废程度 =1.2 × 起床时间 × 发呆时间 × 1.1 / 学习时间

1.1 基本概念

“回归”一词的来历

  今天所知道的回归是由达尔文(Charles Darwin)的表兄弟Francis Galton发明的。Galton于1877年完成了第一次回归预测,目的是根据上一代豌豆种子(双亲)的尺寸来预测下一代豌豆种子(孩子)的尺寸。Galton在大量对象上应用了回归分析,甚至包括人的身高。他注意到,如果双亲的高度比平均高度高,他们的子女也倾向于比平均高度高,但尚不及双亲。孩子的高度向着平均高度回退(回归)。Galton在多项研究上都注意到这个现象,所有尽管这个英文单词根数值预测没有任何关系,但这种研究方法仍被称为回归。

  如何求回归系数呢?假设输入数据都存放在矩阵 X X X中,而回归系数存放在向量 W W W中。那么对于给定的输入 X i X_i Xi,预测结果将会通过 Y i = X i t W Y_i=X^t_iW Yi=XitW给出。现在的问题是,已有一部分 X X X和对应的 Y Y Y,如何找到 W W W呢?一个常用的方法就是找出使误差最小的 W W W。这里的误差是指预测值 Y Y Y和真实值 Y Y Y之间的差值,使用该误差的简单累加将使得正差值和负差值相互抵消,,故可采用平方误差。平方误差可以写做:
∑ i = 1 m ( y i − x i T w ) 2 (1-1) \sum^m_{i=1}(y_i-x^T_iw)^2\tag{1-1} i=1m(yixiTw)2(1-1)  用矩阵表示还可以写做 ( Y − X W ) T ( Y − X W ) (Y-XW)^T(Y-XW) (YXW)T(YXW)。如果对 W W W求导,得到 X T ( Y − X W ) X^T(Y-XW) XT(YXW),令其等于0,解出 W W W如下:
W ^ = ( X T X ) − 1 X T Y \hat W=(X^TX)^{-1}X^TY W^=(XTX)1XTY   W W W上的帽子表示:这是当前可以估计出的最优解。值得注意的是,上述公式中包含 ( X T X ) − 1 (X^TX)^{-1} (XTX)1,也就是需要对矩阵求逆,因此这个方程只在矩阵存在的时候使用。然而,矩阵的逆不一定存在,故需在代码中进行判断。

1.2 标准线性回归

  上述通过 W ^ = ( X T X ) − 1 X T Y \hat W=(X^TX)^{-1}X^TY W^=(XTX)1XTY的方式求解最佳回归系数,该方法被称为OLS,即普通最小二乘法(ordinary least squares)。示例数据集如下,存于文件ex0.txt中:

1.000000	0.067732	3.176513
1.000000	0.427810	3.816464
1.000000	0.995731	4.550095
1.000000	0.738336	4.256571
1.000000	0.981083	4.560815
1.000000	0.526171	3.929515
1.000000	0.378887	3.526170
1.000000	0.033859	3.156393
1.000000	0.132791	3.110301
1.000000	0.138306	3.149813
1.000000	0.247809	3.476346
1.000000	0.648270	4.119688
1.000000	0.731209	4.282233
1.000000	0.236833	3.486582
1.000000	0.969788	4.655492
1.000000	0.607492	3.965162
1.000000	0.358622	3.514900
1.000000	0.147846	3.125947
1.000000	0.637820	4.094115
1.000000	0.230372	3.476039
1.000000	0.070237	3.210610
1.000000	0.067154	3.190612
1.000000	0.925577	4.631504
1.000000	0.717733	4.295890
1.000000	0.015371	3.085028
1.000000	0.335070	3.448080
1.000000	0.040486	3.167440
1.000000	0.212575	3.364266
1.000000	0.617218	3.993482
1.000000	0.541196	3.891471
1.000000	0.045353	3.143259
1.000000	0.126762	3.114204
1.000000	0.556486	3.851484
1.000000	0.901144	4.621899
1.000000	0.958476	4.580768
1.000000	0.274561	3.620992
1.000000	0.394396	3.580501
1.000000	0.872480	4.618706
1.000000	0.409932	3.676867
1.000000	0.908969	4.641845
1.000000	0.166819	3.175939
1.000000	0.665016	4.264980
1.000000	0.263727	3.558448
1.000000	0.231214	3.436632
1.000000	0.552928	3.831052
1.000000	0.047744	3.182853
1.000000	0.365746	3.498906
1.000000	0.495002	3.946833
1.000000	0.493466	3.900583
1.000000	0.792101	4.238522
1.000000	0.769660	4.233080
1.000000	0.251821	3.521557
1.000000	0.181951	3.203344
1.000000	0.808177	4.278105
1.000000	0.334116	3.555705
1.000000	0.338630	3.502661
1.000000	0.452584	3.859776
1.000000	0.694770	4.275956
1.000000	0.590902	3.916191
1.000000	0.307928	3.587961
1.000000	0.148364	3.183004
1.000000	0.702180	4.225236
1.000000	0.721544	4.231083
1.000000	0.666886	4.240544
1.000000	0.124931	3.222372
1.000000	0.618286	4.021445
1.000000	0.381086	3.567479
1.000000	0.385643	3.562580
1.000000	0.777175	4.262059
1.000000	0.116089	3.208813
1.000000	0.115487	3.169825
1.000000	0.663510	4.193949
1.000000	0.254884	3.491678
1.000000	0.993888	4.533306
1.000000	0.295434	3.550108
1.000000	0.952523	4.636427
1.000000	0.307047	3.557078
1.000000	0.277261	3.552874
1.000000	0.279101	3.494159
1.000000	0.175724	3.206828
1.000000	0.156383	3.195266
1.000000	0.733165	4.221292
1.000000	0.848142	4.413372
1.000000	0.771184	4.184347
1.000000	0.429492	3.742878
1.000000	0.162176	3.201878
1.000000	0.917064	4.648964
1.000000	0.315044	3.510117
1.000000	0.201473	3.274434
1.000000	0.297038	3.579622
1.000000	0.336647	3.489244
1.000000	0.666109	4.237386
1.000000	0.583888	3.913749
1.000000	0.085031	3.228990
1.000000	0.687006	4.286286
1.000000	0.949655	4.628614
1.000000	0.189912	3.239536
1.000000	0.844027	4.457997
1.000000	0.333288	3.513384
1.000000	0.427035	3.729674
1.000000	0.466369	3.834274
1.000000	0.550659	3.811155
1.000000	0.278213	3.598316
1.000000	0.918769	4.692514
1.000000	0.886555	4.604859
1.000000	0.569488	3.864912
1.000000	0.066379	3.184236
1.000000	0.335751	3.500796
1.000000	0.426863	3.743365
1.000000	0.395746	3.622905
1.000000	0.694221	4.310796
1.000000	0.272760	3.583357
1.000000	0.503495	3.901852
1.000000	0.067119	3.233521
1.000000	0.038326	3.105266
1.000000	0.599122	3.865544
1.000000	0.947054	4.628625
1.000000	0.671279	4.231213
1.000000	0.434811	3.791149
1.000000	0.509381	3.968271
1.000000	0.749442	4.253910
1.000000	0.058014	3.194710
1.000000	0.482978	3.996503
1.000000	0.466776	3.904358
1.000000	0.357767	3.503976
1.000000	0.949123	4.557545
1.000000	0.417320	3.699876
1.000000	0.920461	4.613614
1.000000	0.156433	3.140401
1.000000	0.656662	4.206717
1.000000	0.616418	3.969524
1.000000	0.853428	4.476096
1.000000	0.133295	3.136528
1.000000	0.693007	4.279071
1.000000	0.178449	3.200603
1.000000	0.199526	3.299012
1.000000	0.073224	3.209873
1.000000	0.286515	3.632942
1.000000	0.182026	3.248361
1.000000	0.621523	3.995783
1.000000	0.344584	3.563262
1.000000	0.398556	3.649712
1.000000	0.480369	3.951845
1.000000	0.153350	3.145031
1.000000	0.171846	3.181577
1.000000	0.867082	4.637087
1.000000	0.223855	3.404964
1.000000	0.528301	3.873188
1.000000	0.890192	4.633648
1.000000	0.106352	3.154768
1.000000	0.917886	4.623637
1.000000	0.014855	3.078132
1.000000	0.567682	3.913596
1.000000	0.068854	3.221817
1.000000	0.603535	3.938071
1.000000	0.532050	3.880822
1.000000	0.651362	4.176436
1.000000	0.901225	4.648161
1.000000	0.204337	3.332312
1.000000	0.696081	4.240614
1.000000	0.963924	4.532224
1.000000	0.981390	4.557105
1.000000	0.987911	4.610072
1.000000	0.990947	4.636569
1.000000	0.736021	4.229813
1.000000	0.253574	3.500860
1.000000	0.674722	4.245514
1.000000	0.939368	4.605182
1.000000	0.235419	3.454340
1.000000	0.110521	3.180775
1.000000	0.218023	3.380820
1.000000	0.869778	4.565020
1.000000	0.196830	3.279973
1.000000	0.958178	4.554241
1.000000	0.972673	4.633520
1.000000	0.745797	4.281037
1.000000	0.445674	3.844426
1.000000	0.470557	3.891601
1.000000	0.549236	3.849728
1.000000	0.335691	3.492215
1.000000	0.884739	4.592374
1.000000	0.918916	4.632025
1.000000	0.441815	3.756750
1.000000	0.116598	3.133555
1.000000	0.359274	3.567919
1.000000	0.814811	4.363382
1.000000	0.387125	3.560165
1.000000	0.982243	4.564305
1.000000	0.780880	4.215055
1.000000	0.652565	4.174999
1.000000	0.870030	4.586640
1.000000	0.604755	3.960008
1.000000	0.255212	3.529963
1.000000	0.730546	4.213412
1.000000	0.493829	3.908685
1.000000	0.257017	3.585821
1.000000	0.833735	4.374394
1.000000	0.070095	3.213817
1.000000	0.527070	3.952681
1.000000	0.116163	3.129283

  选取第二、三列绘制如下:
在这里插入图片描述

图1-1 训练数据集


  创建regression.py文件并添加以下代码:

程序清单1-1: 标准回归函数、数据导入函数及测试函数

import matplotlib.pyplot as plt
from numpy import *

def load_data_set(file_name):
    with open(file_name) as fd:
        fd_data = fd.readlines()
    x_set = []; y_set = []
    for data in fd_data:
        data = data.strip().split('\t')
        data = [float(value) for value in data]
        x_set.append(data[:-1])
        y_set.append(data[-1])
        plt.scatter(data[1], data[2], c='red')
    return x_set, y_set

def stand_regres(x_set, y_set):
    x_mat = mat(x_set); y_mat = mat(y_set).T
    x_tx = x_mat.T * x_mat
    if linalg.det(x_tx) == 0.0:    #linalg.det()用于计算行列式,若行列式为0,则矩阵不可进行求逆运算
        print("This matrix is singular,cannot do inverse")
        return
    w = x_tx.I * (x_mat.T * y_mat)    #对应普通二乘法公式
    return w

def test1():
    x_set, y_set = load_data_set('E:/Machine Learing/myMachineLearning/data/ex0.txt')
    w = stand_regres(x_set, y_set)
    print("W:", w.T)
    x_copy = mat(x_set).copy()    #拷贝
    x_copy .sort(0)    #排序点
    y_hat = x_copy * w    #获得预测值
    plt.plot(x_copy [:,1], y_hat)
    plt.show()

if __name__ == '__main__':
    test1()

  运行结果:

W: [[3.00774324 1.69532264]]

在这里插入图片描述

图1-2 训练数据集和它的最佳拟合直线


  需要注意的是给定数据集中的第一列总是等于1.0,即X0。这是因为我们假定偏移量是一个常数,第二、三列才是数据真实的属性。在绘图时,预测结果保存于y_hat中,但是原本数据集中实例是杂乱的,故在绘制前进行排序。
  
  这一模型简单,几乎所有数据集都可以用上述方式建立模型,那么,如何判断这些模型的优劣呢?有种方式是计算预测值y_hat与真实值y的匹配程度,即相关系数。Python中计算相关系数的命令是corrcoef(y_hat.T,y_mat)。最终结果如下:

Correlation coefficient:
 [[1.         0.98647356]
 [0.98647356 1.        ]]

  对角线上的数据是1.0,这是因为自己与自己匹配的结果;而预测结果与真实结果的相关性达到了0.98。

2 局部加权线性回归

  线性回归很可能出现欠拟合现象,因为它求的是具有最小均方误差的无偏估计。显然,如果模型欠拟合则不能取得最好的预测效果。所以有的方法允许在估计中引入一些偏差,从而降低预测的均方误差。

2.1 基本概念

此处介绍一个方法:局部加权线性回归(Locally Weighted Linear Regression, LWLR)。在该方法中,给待预测点附近的每个点赋予一定的权重,其他则与标准回归一致。
  与kNN一样,该算法每次预测均需事先选取出对应的数据子集,其所对应的 W W W如下:
W ^ = ( X T W ′ X ) − 1 X T W ′ Y (2-1) \hat W=(X^TW'X)^{-1}X^TW'Y\tag{2-1} W^=(XTWX)1XTWY(2-1)其中 W ′ W' W特指权重。
  LWLR使用“核”来对附近的点赋予更高的权重。核的类型可以自由选择,最常用的则是高斯核,高斯核对应的权重如下:
W ′ ( i , i ) = e x p ( ∣ x ( i ) − x ∣ − 2 k 2 ) (2-2) W'(i,i)=exp(\frac{|x^{(i)}-x|}{-2k^2})\tag{2-2} W(i,i)=exp(2k2x(i)x)(2-2)  由此就构建了一个只包含对角元素的权重矩阵 W ′ W' W,并且点 x x x x ( i ) x(i) x(i)越近, W ′ ( i , i ) W'(i,i) W(i,i)便越大。式 2 − 2 2-2 22中包含了一个用户指定的参数 k k k,它决定了对附近的点赋予多大的权重。当然,上述公式还可以写作以下形式:
W ′ ( i , i ) = e x p ( − ( x − x ( i ) ) 2 2 k 2 ) (2-3) W'(i,i)=exp(-\frac{(x-x(i))^2}{2k^2})\tag{2-3} W(i,i)=exp(2k2(xx(i))2)(2-3)  取k分布等于0.5、0.1、0.01且 x ( i ) = 0.5 x(i)=0.5 x(i)=0.5作为示例绘制权重变化图如下:
在这里插入图片描述

图2-1 k相关权重图


  以下为具体实现过程。于regression.py文件并添加以下代码:

程序清单2-1: 局部加权线性回归函数

def lwlr(test_point, x_set, y_set, k=1.0):    #输入参数:单个实例、x、y、k
    x_mat = mat(x_set); y_mat = mat(y_set).T
    m = shape(x_mat)[0]
    weights = mat(eye((m)))    #创建对角矩阵
    for j in range(m):    #计算权重
        diff_mat = test_point - x_mat[j,:]
        weights[j, j] = exp(diff_mat * diff_mat.T / (-2.0 * k**2))
    x_tx = x_mat.T * (weights * x_mat)
    if linalg.det(x_tx) == 0.0:    #判断行列式是否为零
        print("This matrix is singular, cannot do inverse")
        return
    w = x_tx.I * (x_mat.T * (weights * y_mat))
    return test_point * w

def lwlr_test(test_set, x_set, y_set, k=1.0):
    test_mat = mat(test_set)
    m = shape(test_mat)[0]
    y_hat = zeros(m)
    for i in range(m):    #对每一个实例进行预测
        y_hat[i] = lwlr(test_mat[i], x_set, y_set, k)
    return y_hat

def test2():
    x_set, y_set = load_data_set('E:\Machine Learing\myMachineLearning\data\ex0.txt')
    y_hat = lwlr_test(x_set, x_set, y_set)
    print("The predicted is:\n", y_hat[:6])
    print("The real is:\n", y_set[:6])
    x_mat = mat(x_set); y_mat = mat(y_set)
    sorted_index = x_mat[:,1].argsort(0)    #获取矩阵排序后的索引,但不改变矩阵
    x_sorted = x_mat[sorted_index][:,0,:]    #排序后的sorted_index
    plt.axis([0, 1, 3, 5])
    plt.subplot(311)
    plt.plot(x_sorted[:,1], y_hat[sorted_index])
    plt.scatter(x_mat[:,1].flatten().A[0], y_mat.T.flatten().A[0], s=2, c='red')
    plt.subplot(312)
    y_hat = lwlr_test(x_set, x_set, y_set, 0.1)
    plt.plot(x_sorted[:, 1], y_hat[sorted_index])
    plt.scatter(x_mat[:, 1].flatten().A[0], y_mat.T.flatten().A[0], s=2, c='red')
    plt.subplot(313)
    y_hat = lwlr_test(x_set, x_set, y_set, 0.003)
    plt.plot(x_sorted[:, 1], y_hat[sorted_index])
    plt.scatter(x_mat[:, 1].flatten().A[0], y_mat.T.flatten().A[0], s=2, c='red')
    plt.show()

if __name__ == '__main__':
    test2()
    # test1()

  运行结果:

The predicted is:
 [3.12204471 3.73284336 4.69692033 4.25997574 4.67205815 3.89979584]
The real is:
 [3.176513, 3.816464, 4.550095, 4.256571, 4.560815, 3.929515]

在这里插入图片描述

图2-2 不同k值下局部加权线性回归的结果


  如图2-2:
  1)k=1.0时,权重很大,如同将所有数据视为等权重,得出的最佳拟合直线与标准回归一致,出现欠拟合;
  2)k=0.01时,权重适中,抓住了数据的潜在模式;
  3)k=0.003时,权重较小,纳入了太多噪声点,拟合的直线与数据点过于贴切,出现过拟合。

2.2 示例:预测鲍鱼的年龄

  鲍鱼数据集‘abalone.txt’来源于UCI数据集合,记录了鲍鱼的年龄。鲍鱼的年龄可以从鲍鱼壳的层数推算得到,具体内容如下,其中每个实例的最后一个元素代表鲍鱼真实年龄(由于数据集较大,只列出前一百行):

1	0.455	0.365	0.095	0.514	0.2245	0.101	0.15	15
1	0.35	0.265	0.09	0.2255	0.0995	0.0485	0.07	7
-1	0.53	0.42	0.135	0.677	0.2565	0.1415	0.21	9
1	0.44	0.365	0.125	0.516	0.2155	0.114	0.155	10
0	0.33	0.255	0.08	0.205	0.0895	0.0395	0.055	7
0	0.425	0.3	0.095	0.3515	0.141	0.0775	0.12	8
-1	0.53	0.415	0.15	0.7775	0.237	0.1415	0.33	20
-1	0.545	0.425	0.125	0.768	0.294	0.1495	0.26	16
1	0.475	0.37	0.125	0.5095	0.2165	0.1125	0.165	9
-1	0.55	0.44	0.15	0.8945	0.3145	0.151	0.32	19
-1	0.525	0.38	0.14	0.6065	0.194	0.1475	0.21	14
1	0.43	0.35	0.11	0.406	0.1675	0.081	0.135	10
1	0.49	0.38	0.135	0.5415	0.2175	0.095	0.19	11
-1	0.535	0.405	0.145	0.6845	0.2725	0.171	0.205	10
-1	0.47	0.355	0.1	0.4755	0.1675	0.0805	0.185	10
1	0.5	0.4	0.13	0.6645	0.258	0.133	0.24	12
0	0.355	0.28	0.085	0.2905	0.095	0.0395	0.115	7
-1	0.44	0.34	0.1	0.451	0.188	0.087	0.13	10
1	0.365	0.295	0.08	0.2555	0.097	0.043	0.1	7
1	0.45	0.32	0.1	0.381	0.1705	0.075	0.115	9
1	0.355	0.28	0.095	0.2455	0.0955	0.062	0.075	11
0	0.38	0.275	0.1	0.2255	0.08	0.049	0.085	10
-1	0.565	0.44	0.155	0.9395	0.4275	0.214	0.27	12
-1	0.55	0.415	0.135	0.7635	0.318	0.21	0.2	9
-1	0.615	0.48	0.165	1.1615	0.513	0.301	0.305	10
-1	0.56	0.44	0.14	0.9285	0.3825	0.188	0.3	11
-1	0.58	0.45	0.185	0.9955	0.3945	0.272	0.285	11
1	0.59	0.445	0.14	0.931	0.356	0.234	0.28	12
1	0.605	0.475	0.18	0.9365	0.394	0.219	0.295	15
1	0.575	0.425	0.14	0.8635	0.393	0.227	0.2	11
1	0.58	0.47	0.165	0.9975	0.3935	0.242	0.33	10
-1	0.68	0.56	0.165	1.639	0.6055	0.2805	0.46	15
1	0.665	0.525	0.165	1.338	0.5515	0.3575	0.35	18
-1	0.68	0.55	0.175	1.798	0.815	0.3925	0.455	19
-1	0.705	0.55	0.2	1.7095	0.633	0.4115	0.49	13
1	0.465	0.355	0.105	0.4795	0.227	0.124	0.125	8
-1	0.54	0.475	0.155	1.217	0.5305	0.3075	0.34	16
-1	0.45	0.355	0.105	0.5225	0.237	0.1165	0.145	8
-1	0.575	0.445	0.135	0.883	0.381	0.2035	0.26	11
1	0.355	0.29	0.09	0.3275	0.134	0.086	0.09	9
-1	0.45	0.335	0.105	0.425	0.1865	0.091	0.115	9
-1	0.55	0.425	0.135	0.8515	0.362	0.196	0.27	14
0	0.24	0.175	0.045	0.07	0.0315	0.0235	0.02	5
0	0.205	0.15	0.055	0.042	0.0255	0.015	0.012	5
0	0.21	0.15	0.05	0.042	0.0175	0.0125	0.015	4
0	0.39	0.295	0.095	0.203	0.0875	0.045	0.075	7
1	0.47	0.37	0.12	0.5795	0.293	0.227	0.14	9
-1	0.46	0.375	0.12	0.4605	0.1775	0.11	0.15	7
0	0.325	0.245	0.07	0.161	0.0755	0.0255	0.045	6
-1	0.525	0.425	0.16	0.8355	0.3545	0.2135	0.245	9
0	0.52	0.41	0.12	0.595	0.2385	0.111	0.19	8
1	0.4	0.32	0.095	0.303	0.1335	0.06	0.1	7
1	0.485	0.36	0.13	0.5415	0.2595	0.096	0.16	10
-1	0.47	0.36	0.12	0.4775	0.2105	0.1055	0.15	10
1	0.405	0.31	0.1	0.385	0.173	0.0915	0.11	7
-1	0.5	0.4	0.14	0.6615	0.2565	0.1755	0.22	8
1	0.445	0.35	0.12	0.4425	0.192	0.0955	0.135	8
1	0.47	0.385	0.135	0.5895	0.2765	0.12	0.17	8
0	0.245	0.19	0.06	0.086	0.042	0.014	0.025	4
-1	0.505	0.4	0.125	0.583	0.246	0.13	0.175	7
1	0.45	0.345	0.105	0.4115	0.18	0.1125	0.135	7
1	0.505	0.405	0.11	0.625	0.305	0.16	0.175	9
-1	0.53	0.41	0.13	0.6965	0.302	0.1935	0.2	10
1	0.425	0.325	0.095	0.3785	0.1705	0.08	0.1	7
1	0.52	0.4	0.12	0.58	0.234	0.1315	0.185	8
1	0.475	0.355	0.12	0.48	0.234	0.1015	0.135	8
-1	0.565	0.44	0.16	0.915	0.354	0.1935	0.32	12
-1	0.595	0.495	0.185	1.285	0.416	0.224	0.485	13
-1	0.475	0.39	0.12	0.5305	0.2135	0.1155	0.17	10
0	0.31	0.235	0.07	0.151	0.063	0.0405	0.045	6
1	0.555	0.425	0.13	0.7665	0.264	0.168	0.275	13
-1	0.4	0.32	0.11	0.353	0.1405	0.0985	0.1	8
-1	0.595	0.475	0.17	1.247	0.48	0.225	0.425	20
1	0.57	0.48	0.175	1.185	0.474	0.261	0.38	11
-1	0.605	0.45	0.195	1.098	0.481	0.2895	0.315	13
-1	0.6	0.475	0.15	1.0075	0.4425	0.221	0.28	15
1	0.595	0.475	0.14	0.944	0.3625	0.189	0.315	9
-1	0.6	0.47	0.15	0.922	0.363	0.194	0.305	10
-1	0.555	0.425	0.14	0.788	0.282	0.1595	0.285	11
-1	0.615	0.475	0.17	1.1025	0.4695	0.2355	0.345	14
-1	0.575	0.445	0.14	0.941	0.3845	0.252	0.285	9
1	0.62	0.51	0.175	1.615	0.5105	0.192	0.675	12
-1	0.52	0.425	0.165	0.9885	0.396	0.225	0.32	16
1	0.595	0.475	0.16	1.3175	0.408	0.234	0.58	21
1	0.58	0.45	0.14	1.013	0.38	0.216	0.36	14
-1	0.57	0.465	0.18	1.295	0.339	0.2225	0.44	12
1	0.625	0.465	0.14	1.195	0.4825	0.205	0.4	13
1	0.56	0.44	0.16	0.8645	0.3305	0.2075	0.26	10
-1	0.46	0.355	0.13	0.517	0.2205	0.114	0.165	9
-1	0.575	0.45	0.16	0.9775	0.3135	0.231	0.33	12
1	0.565	0.425	0.135	0.8115	0.341	0.1675	0.255	15
1	0.555	0.44	0.15	0.755	0.307	0.1525	0.26	12
1	0.595	0.465	0.175	1.115	0.4015	0.254	0.39	13
-1	0.625	0.495	0.165	1.262	0.507	0.318	0.39	10
1	0.695	0.56	0.19	1.494	0.588	0.3425	0.485	15
1	0.665	0.535	0.195	1.606	0.5755	0.388	0.48	14
1	0.535	0.435	0.15	0.725	0.269	0.1385	0.25	9
1	0.47	0.375	0.13	0.523	0.214	0.132	0.145	8
1	0.47	0.37	0.13	0.5225	0.201	0.133	0.165	7
-1	0.475	0.375	0.125	0.5785	0.2775	0.085	0.155	10

于regression.py文件并添加以下代码:

程序清单2-2: 预测鲍鱼年龄

def test3():
    x_set, y_set = load_data_set('E:\Machine Learing\myMachineLearning\data/abalone.txt')
    x_label = list(arange(0.1, 10, 0.1))
    y_hat01 = lwlr_test(x_set[0:99], x_set[0:99], y_set[0:99], 0.1)
    rss_error01 = rss_error(y_set[0:99], y_hat01.T)
    y_hat1 = lwlr_test(x_set[0:99], x_set[0:99], y_set[0:99], 1)
    rss_error1 = rss_error(y_set[0:99], y_hat1.T)
    y_hat10 = lwlr_test(x_set[0:99], x_set[0:99], y_set[0:99], 10)
    rss_error10 = rss_error(y_set[0:99], y_hat10.T)
    print("The training error:\nThe prediction error when k=0.1:", rss_error01)
    print("The prediction error when k=1:", rss_error1)
    print("The prediction error when k=10:", rss_error10)
    y_hat01 = lwlr_test(x_set[100:199], x_set[0:99], y_set[0:99], 0.1)
    rss_error01 = rss_error(y_set[100:199], y_hat01.T)
    y_hat1 = lwlr_test(x_set[100:199], x_set[0:99], y_set[0:99], 1)
    rss_error1 = rss_error(y_set[100:199], y_hat1.T)
    y_hat10 = lwlr_test(x_set[100:199], x_set[0:99], y_set[0:99], 10)
    rss_error10 = rss_error(y_set[100:199], y_hat10.T)
    print("The test error:\nThe prediction error when k=0.1:", rss_error01)
    print("The prediction error when k=1:", rss_error1)
    print("The prediction error when k=10:", rss_error10)

if __name__ == '__main__':
    test3()

  运行结果:

The training error:
The prediction error when k=0.1: 56.78868743050092
The prediction error when k=1: 429.89056187038
The prediction error when k=10: 549.1181708827924
The test error:
The prediction error when k=0.1: 57913.51550155911
The prediction error when k=1: 573.5261441895982
The prediction error when k=10: 517.5711905381903

  如前所述,k过小时尽管训练预测误差较小,但测试预测误差较大,即欠拟合;反之k过大,则出现训练预测误差大于测试误差的情况。具体的k取值多少,则需要依多次实验而定。

3 缩减系数来“理解”数据

  当数据的特征比样本点多时,线性回归和局部加权线性回归便不再适用,这是因为在计算 ( X T X ) − 1 (X^TX)^{-1} (XTX)1时会出错。即此时输入数据的矩阵 X X X不是满秩矩阵。
  为了解决该问题,统计学家引入了岭回归(ridge regression)的概念。

3.1 岭回归

  简单来说,岭回归就是在矩阵 X T X X^TX XTX上加一个 λ I \lambda I λI从而使得矩阵满秩,进而能对 X T X + λ I X^TX+\lambda I XTX+λI求逆。其中矩阵 I I I是一个 n × n n×n n×n的单位矩阵,对角线的元素全为1,其他元素全为0.而 λ \lambda λ是一个用户输入的数值。由此回归系数的公式变为:
W ^ = ( X T X + λ I ) − 1 X T Y (3-1) \hat{W}=(X^TX+\lambda I)^{-1}X^TY\tag{3-1} W^=(XTX+λI)1XTY(3-1)

岭回归中的岭是什么?
  岭回归使用了单位矩阵乘以常量 λ \lambda λ,观察其中的单位矩阵 i i i,可以发现值 I I I贯穿整个对角线,其余元素全是0。形象的,在0构成的平面上有一条1组成的“岭”,这便是“岭”的由来。

  岭回归最先用来处理特征数多于样本数的情况,现在也用于在估计中加入偏差,从而得到更好的估计。这里通过引入 λ \lambda λ来限制所有 W W W的和,通过引入该惩罚项,能够减少不重要的参数,这个技术在统计学也叫做 缩减(shrinkage)。
缩减方法可以去掉不中的参数,因此能够更好地理解数据,自然能比简单线性回归取得更好的结果。与之前类似,通过预测误差最小化得到 λ \lambda λ,再求得 W W W。于regression.py文件并添加以下代码:

程序清单3-1: 岭回归

"""岭回归"""
def ridge_regres(x_mat, y_mat, lam=0.2):
    x_tx = x_mat.T * x_mat
    denom = x_tx + eye(shape(x_mat)[1]) * lam    #对应岭回归公式
    if linalg.det(denom) == 0.0:    #lam=0时依然会出现错误
        print("This matrix is singular,cannot do inverse")
        return
    w = denom.I * (x_mat.T * y_mat)    #.I为求逆
    return w

def ridge_test(x_set, y_set, N=30):
    x_mat = mat(x_set); y_mat = mat(y_set).T
    y_mean = mean(y_mat, 0)
    y_mat = y_mat - y_mean    #特征标准化处理,使得每维特征具有同等重要性,不考虑特征代表什么
    x_means = mean(x_mat, 0)
    x_var = var(x_mat, 0)    #每列方差
    x_mat = (x_mat - x_means) / x_var
    num_test = N
    w_mat = zeros((num_test, shape(x_mat)[1]))
    for i in range(num_test):
        w = ridge_regres(x_mat, y_mat, exp(i - 10))    #lam呈指数级变化,这样可以看出lam为较小值与较大值时对结果的影响
        w_mat[i,:] = w.T
    return w_mat

def test4():
    x_set, y_set = load_data_set('E:\Machine Learing\myMachineLearning\data/abalone.txt')
    w_mat = ridge_test(x_set, y_set)
    print("W matrix:", w_mat)
    plt.plot(w_mat)
    plt.xlabel('log(lambda)')
    plt.axis([-1, 30, -1, 2.5])
    plt.show()

  运行结果:

...
 [-3.13618246e-06  3.43488557e-04  4.29265642e-04  9.86279863e-04
   8.16188652e-05  1.39858822e-04  3.40121256e-04  3.34847052e-04]]

在这里插入图片描述

图3-1 岭回归回归系数变化图


   λ \lambda λ非常小,系数与普通回归系数一样; λ \lambda λ非常大时,所有回归系数缩减为0。可以在中间某处找到使得预测结果最好的 λ \lambda λ。如何找到?则可以交叉验证。
  还有其他一些缩减方法,如lasso、LAR、PCA回归以及子集选择等。与岭回归一样,这些方法不仅可以提高预测精度,也可以解释回归系数。

3.2 lasso

  在增加约束的时,普通的最小二乘法回归会得到与岭回归一样的公式:
∑ k = 1 n W k 2 ≤ λ (3-2) \sum^n_{k=1}W^2_k\le\lambda\tag{3-2} k=1nWk2λ(3-2)  上式限定了所有回归系数的平方和不能大于 λ \lambda λ,从而避免当两个或更多的特征相关时,出现一个很大正系数或者负系数的情况。与此类似,lasso也对回归系数做了限定:
∑ k = 1 n ∣ W k ∣ ≤ λ (3-3) \sum^n_{k=1}|W_k|\le\lambda\tag{3-3} k=1nWkλ(3-3)  这里的约束条件用绝对值代替平方和,使得在 λ \lambda λ足够小时,一些系数被迫缩减为0。这个特征可以更好地理解数据,但是也大大增加了计算复杂度,如果需要在此条件下解出回归系数,需要使用二次规划算法。

3.3 前向逐步回归

  前向逐步回归在更加简单的情况下,也能达到和lasso差不多的效果。它属于一种贪心算法,即每一步都尽可能减小误差。一开始,所有权重都设为1,然后每一步所做的决策是对某个权重增加或者减小一个很小的值。其伪代码如下:
  数据标准化,使其分布满足0均值和单位方差
  在每轮迭代中:
    设置当前最小误差lowest_error为正无穷
    对每个特征:
      增大或减小:
        改变一个系数得到一个新的 W W W
        计算新 W W W下的误差error
        如果当前误差error小于最小误差lowest_error:
          设置W_best等于当前 W W W
      将 W W W设置为新的W_best

于regression.py文件并添加以下代码:

程序清单3-2: 前向逐步线性回归

def stage_wise(x_set, y_set, eps=0.1, max_iter=100):    #eps为每次迭代需要调整的步长
    x_mat = mat(x_set); y_mat = mat(y_set).T
    y_mean = mean(y_mat, 0)    #平均值
    y_mat = y_mat - y_mean
    x_mat = regularize(x_mat)
    m ,n =shape(x_mat)
    return_mat = zeros((max_iter, n))
    w = zeros((n, 1))
    w_max = w.copy()
    for i in range(max_iter):
        # print(w.T)
        lowest_error = inf;
        for j in range(n):
            for sign in [-1, 1]:
                w_test = w.copy()
                w_test[j] += eps * sign
                y_test = x_mat * w_test
                rss_e = rss_error(y_mat.A, y_test.A)
                if rss_e < lowest_error:
                    lowest_error = rss_e
                    w_max = w_test
        w = w_max.copy()
        return_mat[i] = w.T
    return return_mat

def test5():
    x_set, y_set = load_data_set('E:\Machine Learing\myMachineLearning\data/abalone.txt')
    return_mat = stage_wise(x_set, y_set, 0.01, 200)
    print("Return mat:\n", return_mat)
    plt.plot(return_mat)
    plt.show()

if __name__ == '__main__':
    test5()
    # test4()
    # test3()
    # test2()
    # test1()

  运行结果:

Return mat:
 [[ 0.    0.    0.   ...  0.    0.    0.  ]
 [ 0.    0.    0.   ...  0.    0.    0.  ]
 [ 0.    0.    0.   ...  0.    0.    0.  ]
 ...
 [ 0.05  0.    0.09 ... -0.64  0.    0.36]
 [ 0.04  0.    0.09 ... -0.64  0.    0.36]
 [ 0.05  0.    0.09 ... -0.64  0.    0.36]]

在这里插入图片描述

图3-2 前向逐步回归系数变化图(eps=0.01,max_iter=200)


  上述结果中,W1和W6都是0,这表明它们不对目标值造成任何影响,很有可能这些特征是不需要的。另外在eps设置为0.01时,一段时间后系数就已经饱和并在特定值之间来回震荡,这是因为步长太长的原因。接下来尝试更小步长并绘制:
在这里插入图片描述

图3-3 前向逐步回归系数变化图(eps=0.005,max_iter=1000)


  前向逐步线性回归主要的优点在于:
  可以帮助理解现在模型并改进。当构建一个模型之后,可以运行该算法找出重要的特征,这样就有可能及时停止收集那些不重要特征的收集。

4 权衡偏差与方差

  一旦发现模型与测量值之间存在差异,就说明出现了误差。当考虑模型中“噪声”或者说误差时,必须考虑其来源:
  1)对复杂的过程进行简化时,将导致模型和测量值之间出现“噪声”或误差;
  2)若无法理解数据的真实生成过程,会导致差异的发生;
  3)测量过程本身也可能产生“噪声”或问题。
  例如之前用到的‘ex0.txt’数据集,是认为制造的,其生成公式如下:
y = 3.0 + 1.7 x + 0.1 s i n ( 30 x ) + 0.06 N ( 0 , 1 ) y=3.0+1.7x+0.1sin(30x)+0.06N(0,1) y=3.0+1.7x+0.1sin(30x)+0.06N(0,1)其中 N ( 0 , 1 ) N(0,1) N(0,1)是一个均值为0、方差为1的正太分布。若用一条进行进行拟合,那么最佳拟合应该是 3.0 + 1.7 x 3.0 + 1.7x 3.0+1.7x这部分,这样一来,误差部分则是 0.1 s i n ( 30 x ) + 0.06 N ( 0 , 1 ) 0.1sin(30x)+0.06N(0,1) 0.1sin(30x)+0.06N(0,1)
  下图为训练误差和测试误差的曲线图(来源):
在这里插入图片描述

图4-1 偏差方差折中与测试误差及训练误差的关系(eps=0.005,max_iter=1000)


  根据局部加权线性回归中的实验知道:
  1)降低核的大小,那么训练误差将变小,即图中红色线;
  2)降低核的大小,那么测试误差将有一个先变小后变大的过程,即黑色线。
  一般认为,上述两种误差由三部分组成:
  偏差、测量误差、随机噪声。
  
  如果从鲍鱼数据集中取一个随机样本集,例如取其中100个数据,并用线性模型拟合,将会得到一组回归系数。同理,再取另一组随机样本集并拟合,将会得到另一组回归系数。这些系数间的差异就是模型方差1大小的反映。


  1. 方差指模型之间的差异;偏差指模型预测值和数据之间的差异。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_44575152/article/details/100640051

智能推荐

oracle 12c 集群安装后的检查_12c查看crs状态-程序员宅基地

文章浏览阅读1.6k次。安装配置gi、安装数据库软件、dbca建库见下:http://blog.csdn.net/kadwf123/article/details/784299611、检查集群节点及状态:[root@rac2 ~]# olsnodes -srac1 Activerac2 Activerac3 Activerac4 Active[root@rac2 ~]_12c查看crs状态

解决jupyter notebook无法找到虚拟环境的问题_jupyter没有pytorch环境-程序员宅基地

文章浏览阅读1.3w次,点赞45次,收藏99次。我个人用的是anaconda3的一个python集成环境,自带jupyter notebook,但在我打开jupyter notebook界面后,却找不到对应的虚拟环境,原来是jupyter notebook只是通用于下载anaconda时自带的环境,其他环境要想使用必须手动下载一些库:1.首先进入到自己创建的虚拟环境(pytorch是虚拟环境的名字)activate pytorch2.在该环境下下载这个库conda install ipykernelconda install nb__jupyter没有pytorch环境

国内安装scoop的保姆教程_scoop-cn-程序员宅基地

文章浏览阅读5.2k次,点赞19次,收藏28次。选择scoop纯属意外,也是无奈,因为电脑用户被锁了管理员权限,所有exe安装程序都无法安装,只可以用绿色软件,最后被我发现scoop,省去了到处下载XXX绿色版的烦恼,当然scoop里需要管理员权限的软件也跟我无缘了(譬如everything)。推荐添加dorado这个bucket镜像,里面很多中文软件,但是部分国外的软件下载地址在github,可能无法下载。以上两个是官方bucket的国内镜像,所有软件建议优先从这里下载。上面可以看到很多bucket以及软件数。如果官网登陆不了可以试一下以下方式。_scoop-cn

Element ui colorpicker在Vue中的使用_vue el-color-picker-程序员宅基地

文章浏览阅读4.5k次,点赞2次,收藏3次。首先要有一个color-picker组件 <el-color-picker v-model="headcolor"></el-color-picker>在data里面data() { return {headcolor: ’ #278add ’ //这里可以选择一个默认的颜色} }然后在你想要改变颜色的地方用v-bind绑定就好了,例如:这里的:sty..._vue el-color-picker

迅为iTOP-4412精英版之烧写内核移植后的镜像_exynos 4412 刷机-程序员宅基地

文章浏览阅读640次。基于芯片日益增长的问题,所以内核开发者们引入了新的方法,就是在内核中只保留函数,而数据则不包含,由用户(应用程序员)自己把数据按照规定的格式编写,并放在约定的地方,为了不占用过多的内存,还要求数据以根精简的方式编写。boot启动时,传参给内核,告诉内核设备树文件和kernel的位置,内核启动时根据地址去找到设备树文件,再利用专用的编译器去反编译dtb文件,将dtb还原成数据结构,以供驱动的函数去调用。firmware是三星的一个固件的设备信息,因为找不到固件,所以内核启动不成功。_exynos 4412 刷机

Linux系统配置jdk_linux配置jdk-程序员宅基地

文章浏览阅读2w次,点赞24次,收藏42次。Linux系统配置jdkLinux学习教程,Linux入门教程(超详细)_linux配置jdk

随便推点

matlab(4):特殊符号的输入_matlab微米怎么输入-程序员宅基地

文章浏览阅读3.3k次,点赞5次,收藏19次。xlabel('\delta');ylabel('AUC');具体符号的对照表参照下图:_matlab微米怎么输入

C语言程序设计-文件(打开与关闭、顺序、二进制读写)-程序员宅基地

文章浏览阅读119次。顺序读写指的是按照文件中数据的顺序进行读取或写入。对于文本文件,可以使用fgets、fputs、fscanf、fprintf等函数进行顺序读写。在C语言中,对文件的操作通常涉及文件的打开、读写以及关闭。文件的打开使用fopen函数,而关闭则使用fclose函数。在C语言中,可以使用fread和fwrite函数进行二进制读写。‍ Biaoge 于2024-03-09 23:51发布 阅读量:7 ️文章类型:【 C语言程序设计 】在C语言中,用于打开文件的函数是____,用于关闭文件的函数是____。

Touchdesigner自学笔记之三_touchdesigner怎么让一个模型跟着鼠标移动-程序员宅基地

文章浏览阅读3.4k次,点赞2次,收藏13次。跟随鼠标移动的粒子以grid(SOP)为partical(SOP)的资源模板,调整后连接【Geo组合+point spirit(MAT)】,在连接【feedback组合】适当调整。影响粒子动态的节点【metaball(SOP)+force(SOP)】添加mouse in(CHOP)鼠标位置到metaball的坐标,实现鼠标影响。..._touchdesigner怎么让一个模型跟着鼠标移动

【附源码】基于java的校园停车场管理系统的设计与实现61m0e9计算机毕设SSM_基于java技术的停车场管理系统实现与设计-程序员宅基地

文章浏览阅读178次。项目运行环境配置:Jdk1.8 + Tomcat7.0 + Mysql + HBuilderX(Webstorm也行)+ Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。项目技术:Springboot + mybatis + Maven +mysql5.7或8.0+html+css+js等等组成,B/S模式 + Maven管理等等。环境需要1.运行环境:最好是java jdk 1.8,我们在这个平台上运行的。其他版本理论上也可以。_基于java技术的停车场管理系统实现与设计

Android系统播放器MediaPlayer源码分析_android多媒体播放源码分析 时序图-程序员宅基地

文章浏览阅读3.5k次。前言对于MediaPlayer播放器的源码分析内容相对来说比较多,会从Java-&amp;amp;gt;Jni-&amp;amp;gt;C/C++慢慢分析,后面会慢慢更新。另外,博客只作为自己学习记录的一种方式,对于其他的不过多的评论。MediaPlayerDemopublic class MainActivity extends AppCompatActivity implements SurfaceHolder.Cal..._android多媒体播放源码分析 时序图

java 数据结构与算法 ——快速排序法-程序员宅基地

文章浏览阅读2.4k次,点赞41次,收藏13次。java 数据结构与算法 ——快速排序法_快速排序法