Python——EM(期望极大算法)实战(附详细代码与注解)(二)_程旭员的博客-程序员秘密

技术标签: Python实战-机器学习  EM-GMM  机器学习  GMM  EM  Python  

开始之前

各位朋友,大家好!针对上回讲的EM算法,有朋友反馈还是没弄清楚,今天,我再来详细的讲一下EM算法。请耐心食用本教程,滴滴滴~,上车!

前提准备

Jupyter notebook 或 Pycharm
火狐浏览器或谷歌浏览器
win7或win10电脑一台
网盘提取csv数据

需求分析

实现高斯混合模型的 EM 算法(GMM_EM)
高斯混合模型是多个高斯模型的线性叠加而成的,高斯混合模型的概率分布表示如下:
在这里插入图片描述
其中,k表示模型的个数, α k α_k αk 是第 k 个模型的系数,表示出现该模型的概率,ϕ(x;μk,Σk) 是第 k 个高斯模型的概率分布。

E步:样本 x i x_i xi来自于第 k 个模型的概率,我们把这个概率称为模型 k 对样本 x i x_i xi 的“责任”,也叫“响应度”,记作 γ ( i k ) γ_(ik) γ(ik),计算公式如下:
在这里插入图片描述
M步:根据样本和当前 γ 矩阵重新估计参数,注意这里 x 为列向量,计算公式如下:
在这里插入图片描述

【目标】给定一堆没有标签的样本和模型个数 K,以此求得混合模型的参数,然后就可以用这个模型来对样本进行聚类。

python代码如下:

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal   #本问题考虑的是高斯混合模型,所以导入多元高斯分布multivariate_normal

def prob_Y_k(Y,mu_k,cov_k):                 #Y为样本矩阵
    norm = multivariate_normal(mean = mu_k , cov = cov_k)    #生成多元正太分布,mu为第k个模型的均值,cov为第k个模型的协方差矩阵(协方差矩阵必须是实对称矩阵)
    return norm.pdf(Y)        #返回样本Y来自于第k个模型的概率

def Estep(Y,mu,cov,alpha):       #Y为样本矩阵,alpha为权重
    
    N = Y.shape[0]         #样本数
    K = alpha.shape[0]      #模型数
    
    assert N>1 , "There must be more than one sample!"     #为避免单个样本导致返回的结果的类型不一致,因此要求样本数必须大于一
    assert K>1 , "There must be more than one gaussian model!"    #为避免单个模型结果的类型不一致,因此要求模型须大于一
    
    gamma = np.mat(np.zeros((N,K)))    #初始化响应度矩阵,行对应样本数,列对应模型数
    prob = np.zeros((N,K))            #初始化所有样本出现的概率矩阵,行对应样本数,列对应响应度
    for k in range(K):
        prob[:,k] = prob_Y_k(Y,mu[k],cov[k])         #第k个模型的概率prob_Y_k
    prob = np.mat(prob)                   #K个prob放入数组中
    
    for k in range(K):
        gamma[:,k] = alpha[k] * prob[:,k]   #计算模型k对样本i的响应度
    for i in range(N):
        gamma[i,:] /= np.sum(gamma[i,:])   #第i个样本的占总样本的响应程度
    return gamma        #gamma为响应度矩阵
   
 def Mstep(Y,gamma):        #传入样本矩阵Y和Estep得到的gamma响应度矩阵
    N, D = Y.shape         #N为样本数,D为特征数
    K = gamma.shape[1]     #模型数
    mu = np.zeros((K,D))    #初始化参数均值mu,每个模型的D维各有均值故mu的矩阵为K行D列
    cov = []               #初始化参数协方差矩阵
    alpha = np.zeros(K)     # 初始化权重数组,每个模型都有权值
    
    #接下来是更新每个模型的参数
    for k in range(K):
        Nk = np.sum(gamma[:,k])       #第k个模型所有样本的响应度之和
        mu[k,:] = np.sum(np.multiply(Y, gamma[:,k]),axis=0)/Nk    #更新参数均值mu,对每个特征求均值
        cov_k = (Y - mu[k]).T * np.multiply((Y - mu[k]), gamma[:,k]) / Nk   #更新cov
        cov = np.append(cov_k)
        alpha[k] = Nk / N
    cov = np.array(cov)
    return mu, cov, alpha

def normalize_data(Y):              #将所有数据进行归一化处理,
    for i in range(Y.shape[1]):
        max_data = Y[:,i].max()
        min_data = Y[:,i].min()
        Y[:,i] = (Y[:,i] - min_data)/(max_data - min_data)      #此处用到min-max归一化
        debug("Data Normalized")
    return Y

def init_params(shape,K):     #在执行该算法之前,需要先给出一个初始化的模型参数。我们让每个模型的μ为随机值,Σ 为单位矩阵,α 为 1/K,即每个模型初始时都是等概率出现的。
    N, D = shape
    mu = np.random.rand(K, D)         #生成一个K行D列的[0,1)之间的数组
    cov = np.array([np.eye(D)] * K)    #生成K个D维的对角矩阵
    alpha = np.array([1.0 / K] * K)    #生成K个权重
    debug("Parameters initialized.")     
    debug("mu:",mu, "cov:",cov ,"alpha:",alpha,sep = "\n" )
    return mu, cov, alpha

def GMM_EM(Y, K, times):       #高斯混合EM算法,Y为给定样本矩阵,K为模型个数,times为迭代次数,目的是求该模型的参数
    Y = normalize_data(Y)      #调用前面定义的normalize_data函数,归一化样本矩阵Y
    mu, cov, alpha = init_params(Y.shape, K)      #调用init_params函数得到初始化的参数mu,cov,alpha
    for i in range(times):
        gamma = Estep(Y, mu, cov, alpha)         #调用Estep得到响应度矩阵
        mu, cov, alpha = Mstep(Y, gamma)         #调用Mstep得到更新后的参数mu,cov,alpha
    debug("{sep} Result {sep}".format(sep="-"*20))
    debug("mu:", mu , "cov:",cov , "alpha:",alpha , sep="\n")
    return mu,cov,alpha

import matplotlib.pyplot as plt
from gmm import *

DEBUG = True 
Y = np.loadtxt("gmm.data")        #载入数据
matY = np.matrix(Y ,copy = True)

K = 2        #模型个数(相当于聚类的类别个数)

mu, cov, alpha = GMM_EM(matY , K , 100)    #调用GMM_EM函数,计算GMM模型参数

N = Y.shape[0]
gamma = Estep(matY, mu, cov, alpha)      #求当前模型参数下,各模型对样本的响应矩阵

category = gamma.argmax(axis = 1).flatten().tolist()[0]      #对每个样本,求响应度最大的模型下标,作为其类别标识

class1 = np.array([Y[i] for i in range(N) if category[i] == 0])   #将每个样本放入对应样本的列表中
class2 = np.array([Y[i] for i in range(N) if category[i] == 1])   

plt.plot(class1[:,0],class1[:,1], 'rs' ,label = "class1")
plt.plot(class2[:,0],class2[:,1], 'bo' ,label = "class2")
plt.legend(loc = "best")
plt.title("GMM Clustering By EM Algorithm")
plt.show()

import numpy as np
import matplotlib.pyplot as plt

cov1 = np.mat("0.3 0 ; 0 0.1")          #2维协方差矩阵(必须是对角矩阵)  
cov2 = np.mat("0.2 0 ; 0 0.3")
mu1 = np.array([0,1])
mu2 = np.array([2,1])

sample = np.zeros((100,2))        #初始化100个样本,样本特征为2
sample[:30, :] = np.random.multivariate_normal(mean=mu1, cov=cov1, size=30)      #生成多元正态分布矩阵
sample[30:, :] = np.random.multivariate_normal(mean=mu2, cov=cov2, size=70)      
np.savetxt("sample.data",sample)  # 将array保存到txt文件中

plt.plot(sample[:30, 0], sample[:30, 1], "bo")   #30个样本用蓝色圆圈标记
plt.plot(sample[30:, 0], sample[30:, 1], "rs")   #70个样本用红色方块标记
plt.title("sample_data")
plt.show()

效果:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
【附】gmm.data链接
提取码:765t

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_37763870/article/details/103332671

智能推荐

unity+高通vuforia开发增强现实(AR)基础_vuforia增强_pony-Stark的博客-程序员秘密

                                                 unity+高通vuforia开发增强现实(AR)教程(一)增强现实(Augmented Reality,简称AR),是在虚拟现实的基础上发展起来的新技术,也被称之为混合现实。是通过计算机系统提供的信息增加用户对现实世界感知的技术,将虚拟的信息应用到真实世界,并将计算机生成的虚拟物体、

totolink服务器未响应,WiFi效果差的罪魁祸首竟然是这个 TOTOLINK为你深度讲解_bathroom火冒的博客-程序员秘密

今天小编想问大家几个问题,这都是有关你和路由器之间的:问题1:你家的无线路由器放在那里?问题2:你知道WiFi信号最怕受到哪些干扰吗?问题3:路由器的天线,应该怎样摆放才对?问题4:应该怎样测试家里各个位置的信号强度呢?想想看,这四个问题你都能回答吗?你面对上述问题是否有合适的解决方案?现在的上网体验怎么样呢?是不是很多人都饱受上网体验不畅的痛快,还有很多人对上述问题根本不知道解决方案?别着急,今...

Whctf 2017 -UNTITLED- Writeup_baikeng3674的博客-程序员秘密

Whctf 2017 -UNTITLED- Writeup转载请表明出处http://www.cnblogs.com/WangAoBo/p/7541481.html分析:下载下来的附件是一个py脚本,如下 1 from Crypto.Util.number import getPrime,long_to_bytes,bytes_to_long 2 ...

mysql笔记(14)_HardyDragon_CC的博客-程序员秘密

union(并查询/组合查询) 类似 与 where和or的搭配,例如:UNION指示MySQL执行两条SELECT语句,并把输出组合成单个查询结果集。SELECT vend_id ,prod_id ,prod_price FROM products p WHERE prod_price <=5UNION SELECT vend_id ,prod_id ,prod_price FROM products p WHERE vend_id IN (1001,1002);---SEL

解决“HTTP/1.1 405 Method not allowed”问题_"get /relation http/1.1\" 405"_bing.shao的博客-程序员秘密

解决“HTTP/1.1 405 Method not allowed”问题      Apache、IIS、Nginx等绝大多数web服务器,都不允许静态文件响应POST请求,否则会返回“HTTP/1.1 405 Method not allowed”错误。 即,将出错页面表单的method=“post”改为“get”即可...appweb也是如此

2021掌握JVM 运行时数据区,其实不是很难,加薪也是要技巧可言的_2021运行时数据区_专注方法攻略分享的博客-程序员秘密

<h1><a id="_2"></a>一、概念</h1> Java 内存区域和内存模型是不一样的东西,内存区域是指 Jvm 运行时将数据分区域存储,强调对内存空间的划分。 而内存模型(Java Memory Model,简称 JMM )是定义了线程和主内存之间的抽象关系,即 JMM 定义了 JVM 在计算机内存(RAM)中的工作方式, 如果我们要想深入了解Java并发编程,就要先理解好Java内存模型。 二、JVM 运...

随便推点

Java 生成 pdf_爱游戏爱动漫的肥宅的博客-程序员秘密

Java 可以通过 IText 直接把内容生成 pdf 文件,下面的 Demo 主要演示了文件属性、页眉页脚、表格、段落文字、图片的生成。先上代码:package html2PDF;import java.io.File;import java.io.FileOutputStream;import java.io.IOException;import java.net.Malform...

tf.keras.models.load_model模型加载时报错_qq_34373543的博客-程序员秘密

tf.keras.models.load_model模型加载时报错File “h5py_objects.pyx”, line 54, in h5py._objects.with_phil.wrapper File “h5py_objects.pyx”, line 55, in h5py._objects.with_phil.wrapper File “h5py\h5f.pyx”, line 156, in h5py.h5f.is_hdf5 OSError保存模型model.save('./data/m

golang实现全局唯一id snowflake算法_golang 生成唯一id_migu666的博客-程序员秘密

在应用程序中,经常需要全局唯一的ID作为数据库主键。在一台节点容易全局唯一,那在多台节点呢?有两个思路:1使用散列函数,如sha256,加上时间戳、mac地址、cpu负荷、随机数等组成,id足够长,引入多个不确定因素,以至于碰撞几率非常小,可以认为是全局唯一。例如uuid就是这种。但是uuid是字符串的形式,对于DB来说,占用的空间至少大一倍,DB的索引是需要存储和对比的,因此在存储空间...

UITableView_Zhang_dy_blog的博客-程序员秘密

1.为tableView上添加图片,实现图片的拉大的效果self.image=[[UIImageView alloc] initWithImage:[UIImage imageNamed:@"6.jpg"]];self.image.frame=CGRectMake(0, 0, self.view.frame.size.width, 200);//给tableView添加头视

JS 中的广度与深度优先遍历_前端res.shift();_zhongjunyao的博客-程序员秘密

现在有一种类似树的数据结构,但是不存在共同的根节点 root,每一个节点的结构为 {key: 'one', value: '1', children: [...]},都包含 key 和 value,如果存在 children 则内部会存在 n 个和此结构相同的节点,现模拟数据如下图:已知一个 value 如 3-2-1,需要取出该路径上的所有 key,即期望得到 ['three', 'three-...

C++ 运算符函数重载,语法练习_sergery的博客-程序员秘密

参考: http://wenku.baidu.com/view/6075e19951e79b8968022606.html1. 用独立函数重载运算符 :#include #include using namespace std;class complex{public: //存 void setreal(double r){real = r;} void

推荐文章

热门文章

相关标签