Python实现EM算法实例代码


Posted in Python onOctober 04, 2020

EM算法实例

通过实例可以快速了解EM算法的基本思想,具体推导请点文末链接。图a是让我们预热的,图b是EM算法的实例。

这是一个抛硬币的例子,H表示正面向上,T表示反面向上,参数θ表示正面朝上的概率。硬币有两个,A和B,硬币是有偏的。本次实验总共做了5组,每组随机选一个硬币,连续抛10次。如果知道每次抛的是哪个硬币,那么计算参数θ就非常简单了,如

下图所示:

Python实现EM算法实例代码

如果不知道每次抛的是哪个硬币呢?那么,我们就需要用EM算法,基本步骤为:

  1、给θ_AθA​和θ_BθB​一个初始值;

  2、(E-step)估计每组实验是硬币A的概率(本组实验是硬币B的概率=1-本组实验是硬币A的概率)。分别计算每组实验中,选择A硬币且正面朝上次数的期望值,选择B硬币且正面朝上次数的期望值;

  3、(M-step)利用第三步求得的期望值重新计算θ_AθA​和θ_BθB​;

  4、当迭代到一定次数,或者算法收敛到一定精度,结束算法,否则,回到第2步。

Python实现EM算法实例代码

计算过程详解:初始值θ_A^{(0)}θA(0)​=0.6,θ_B^{(0)}θB(0)​=0.5。

由两个硬币的初始值0.6和0.5,容易得出投掷出5正5反的概率是p_A=C^5_{10}*(0.6^5)*(0.4^5)pA​=C105​∗(0.65)∗(0.45),p_B=C_{10}^5*(0.5^5)*(0.5^5)pB​=C105​∗(0.55)∗(0.55), p_ApA​/(p_ApA​+p_BpB​)=0.449, 0.45就是0.449近似而来的,表示第一组实验选择的硬币是A的概率为0.45。然后,0.449 * 5H = 2.2H ,0.449 * 5T = 2.2T ,表示第一组实验选择A硬币且正面朝上次数和反面朝上次数的期望值都是2.2,其他的值依次类推。最后,求出θ_A^{(1)}θA(1)​=0.71,θ_B^{(1)}θB(1)​=0.58。重复上述过程,不断迭代,直到算法收敛到一定精度为止。

这篇博客对EM算法的推导非常详细,链接如下:

https://blog.csdn.net/zhihua_oba/article/details/73776553

Python实现

#coding=utf-8
from numpy import *
from scipy import stats
import time
start = time.perf_counter()

def em_single(priors,observations):
 """
 EM算法的单次迭代
 Arguments
 ------------
 priors:[theta_A,theta_B]
 observation:[m X n matrix]

 Returns
 ---------------
 new_priors:[new_theta_A,new_theta_B]
 :param priors:
 :param observations:
 :return:
 """
 counts = {'A': {'H': 0, 'T': 0}, 'B': {'H': 0, 'T': 0}}
 theta_A = priors[0]
 theta_B = priors[1]
 #E step
 for observation in observations:
  len_observation = len(observation)
  num_heads = observation.sum()
  num_tails = len_observation-num_heads
  #二项分布求解公式
  contribution_A = stats.binom.pmf(num_heads,len_observation,theta_A)
  contribution_B = stats.binom.pmf(num_heads,len_observation,theta_B)

  weight_A = contribution_A / (contribution_A + contribution_B)
  weight_B = contribution_B / (contribution_A + contribution_B)
  #更新在当前参数下A,B硬币产生的正反面次数
  counts['A']['H'] += weight_A * num_heads
  counts['A']['T'] += weight_A * num_tails
  counts['B']['H'] += weight_B * num_heads
  counts['B']['T'] += weight_B * num_tails

 # M step
 new_theta_A = counts['A']['H'] / (counts['A']['H'] + counts['A']['T'])
 new_theta_B = counts['B']['H'] / (counts['B']['H'] + counts['B']['T'])
 return [new_theta_A,new_theta_B]


def em(observations,prior,tol = 1e-6,iterations=10000):
 """
 EM算法
 :param observations :观测数据
 :param prior:模型初值
 :param tol:迭代结束阈值
 :param iterations:最大迭代次数
 :return:局部最优的模型参数
 """
 iteration = 0;
 while iteration < iterations:
  new_prior = em_single(prior,observations)
  delta_change = abs(prior[0]-new_prior[0])
  if delta_change < tol:
   break
  else:
   prior = new_prior
   iteration +=1
 return [new_prior,iteration]

#硬币投掷结果
observations = array([[1,0,0,0,1,1,0,1,0,1],
      [1,1,1,1,0,1,1,1,0,1],
      [1,0,1,1,1,1,1,0,1,1],
      [1,0,1,0,0,0,1,1,0,0],
      [0,1,1,1,0,1,1,1,0,1]])
print (em(observations,[0.6,0.5]))
end = time.perf_counter()
print('Running time: %f seconds'%(end-start))

总结

到此这篇关于Python实现EM算法实例的文章就介绍到这了,更多相关Python实现EM算法实例内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
常用python数据类型转换函数总结
Mar 11 Python
Python2.6版本中实现字典推导 PEP 274(Dict Comprehensions)
Apr 28 Python
深入讲解Python中面向对象编程的相关知识
May 25 Python
python基于phantomjs实现导入图片
May 13 Python
python+POP3实现批量下载邮件附件
Jun 19 Python
python opencv实现切变换 不裁减图片
Jul 26 Python
Tensorflow 实现修改张量特定元素的值方法
Jul 30 Python
浅谈pyqt5中信号与槽的认识
Feb 17 Python
Python玩转加密的技巧【推荐】
May 13 Python
Python手绘可视化工具cutecharts使用实例
Dec 05 Python
Python读取表格类型文件代码实例
Feb 17 Python
深入浅析Python代码规范性检测
Jul 31 Python
python em算法的实现
Oct 03 #Python
浅析Python中字符串的intern机制
Oct 03 #Python
Python实现AES加密,解密的两种方法
Oct 03 #Python
python实现AdaBoost算法的示例
Oct 03 #Python
Django创建一个后台的基本步骤记录
Oct 02 #Python
Python中qutip用法示例详解
Oct 02 #Python
如何利用Python给自己的头像加一个小国旗(小月饼)
Oct 02 #Python
You might like
在wamp集成环境下升级php版本(实现方法)
2013/07/01 PHP
php删除数组元素示例分享
2014/02/17 PHP
Symfony2实现在doctrine中内置数据的方法
2016/02/05 PHP
取得一定长度的内容,处理中文
2006/12/20 Javascript
jQuery formValidator表单验证插件开源了 含API帮助、源码、示例
2008/08/14 Javascript
利用jquery的获取JS文件中的字符串内容
2012/02/14 Javascript
javascript随机抽取0-100之间不重复的10个数
2016/02/25 Javascript
如何判断Javascript对象是否存在的简单实例
2016/05/18 Javascript
JS获取屏幕高度的简单实现代码
2016/05/24 Javascript
AngularJS 过滤与排序详解及实例代码
2016/09/14 Javascript
angular实现表单验证及提交功能
2017/02/01 Javascript
提高Web性能的前端优化技巧总结
2017/02/27 Javascript
浅谈webpack打包之后的文件过大的解决方法
2018/03/07 Javascript
详解Vue3 Composition API中的提取和重用逻辑
2020/04/29 Javascript
react结合bootstrap实现评论功能
2020/05/30 Javascript
详解JavaScript之ES5的继承
2020/07/08 Javascript
windows下安装python paramiko模块的代码
2013/02/10 Python
Python 序列化 pickle/cPickle模块使用介绍
2014/11/30 Python
python列表操作实例
2015/01/14 Python
深入浅析Python获取对象信息的函数type()、isinstance()、dir()
2018/09/17 Python
Python下opencv图像阈值处理的使用笔记
2019/08/04 Python
python主线程与子线程的结束顺序实例解析
2019/12/17 Python
Python3.x+迅雷x 自动下载高分电影的实现方法
2020/01/12 Python
Python使用configparser库读取配置文件
2020/02/22 Python
css3新单位vw、vh的使用教程
2018/03/23 HTML / CSS
HTML5实现页面切换激活的PageVisibility API使用初探
2016/05/13 HTML / CSS
详解HTML5 data-* 自定义属性
2018/01/24 HTML / CSS
Willer台湾:日本高速巴士/夜行巴士预约
2017/07/09 全球购物
金融专业个人求职信
2013/09/22 职场文书
领导干部群众路线个人对照检查材料思想汇报
2014/09/30 职场文书
2014年幼儿园园长工作总结
2014/12/17 职场文书
2015年公共机构节能宣传周活动总结
2015/03/26 职场文书
法人代表资格证明书
2015/06/18 职场文书
村官2015年度工作总结
2015/10/14 职场文书
2016年优秀班主任先进事迹材料
2016/02/26 职场文书
python playwright 自动等待和断言详解
2021/11/27 Python