Python实现EM算法实例代码


Posted in Python onOctober 04, 2020

EM算法实例

通过实例可以快速了解EM算法的基本思想,具体推导请点文末链接。图a是让我们预热的,图b是EM算法的实例。

这是一个抛硬币的例子,H表示正面向上,T表示反面向上,参数θ表示正面朝上的概率。硬币有两个,A和B,硬币是有偏的。本次实验总共做了5组,每组随机选一个硬币,连续抛10次。如果知道每次抛的是哪个硬币,那么计算参数θ就非常简单了,如

下图所示:

Python实现EM算法实例代码

如果不知道每次抛的是哪个硬币呢?那么,我们就需要用EM算法,基本步骤为:

  1、给θ_AθA​和θ_BθB​一个初始值;

  2、(E-step)估计每组实验是硬币A的概率(本组实验是硬币B的概率=1-本组实验是硬币A的概率)。分别计算每组实验中,选择A硬币且正面朝上次数的期望值,选择B硬币且正面朝上次数的期望值;

  3、(M-step)利用第三步求得的期望值重新计算θ_AθA​和θ_BθB​;

  4、当迭代到一定次数,或者算法收敛到一定精度,结束算法,否则,回到第2步。

Python实现EM算法实例代码

计算过程详解:初始值θ_A^{(0)}θA(0)​=0.6,θ_B^{(0)}θB(0)​=0.5。

由两个硬币的初始值0.6和0.5,容易得出投掷出5正5反的概率是p_A=C^5_{10}*(0.6^5)*(0.4^5)pA​=C105​∗(0.65)∗(0.45),p_B=C_{10}^5*(0.5^5)*(0.5^5)pB​=C105​∗(0.55)∗(0.55), p_ApA​/(p_ApA​+p_BpB​)=0.449, 0.45就是0.449近似而来的,表示第一组实验选择的硬币是A的概率为0.45。然后,0.449 * 5H = 2.2H ,0.449 * 5T = 2.2T ,表示第一组实验选择A硬币且正面朝上次数和反面朝上次数的期望值都是2.2,其他的值依次类推。最后,求出θ_A^{(1)}θA(1)​=0.71,θ_B^{(1)}θB(1)​=0.58。重复上述过程,不断迭代,直到算法收敛到一定精度为止。

这篇博客对EM算法的推导非常详细,链接如下:

https://blog.csdn.net/zhihua_oba/article/details/73776553

Python实现

#coding=utf-8
from numpy import *
from scipy import stats
import time
start = time.perf_counter()

def em_single(priors,observations):
 """
 EM算法的单次迭代
 Arguments
 ------------
 priors:[theta_A,theta_B]
 observation:[m X n matrix]

 Returns
 ---------------
 new_priors:[new_theta_A,new_theta_B]
 :param priors:
 :param observations:
 :return:
 """
 counts = {'A': {'H': 0, 'T': 0}, 'B': {'H': 0, 'T': 0}}
 theta_A = priors[0]
 theta_B = priors[1]
 #E step
 for observation in observations:
  len_observation = len(observation)
  num_heads = observation.sum()
  num_tails = len_observation-num_heads
  #二项分布求解公式
  contribution_A = stats.binom.pmf(num_heads,len_observation,theta_A)
  contribution_B = stats.binom.pmf(num_heads,len_observation,theta_B)

  weight_A = contribution_A / (contribution_A + contribution_B)
  weight_B = contribution_B / (contribution_A + contribution_B)
  #更新在当前参数下A,B硬币产生的正反面次数
  counts['A']['H'] += weight_A * num_heads
  counts['A']['T'] += weight_A * num_tails
  counts['B']['H'] += weight_B * num_heads
  counts['B']['T'] += weight_B * num_tails

 # M step
 new_theta_A = counts['A']['H'] / (counts['A']['H'] + counts['A']['T'])
 new_theta_B = counts['B']['H'] / (counts['B']['H'] + counts['B']['T'])
 return [new_theta_A,new_theta_B]


def em(observations,prior,tol = 1e-6,iterations=10000):
 """
 EM算法
 :param observations :观测数据
 :param prior:模型初值
 :param tol:迭代结束阈值
 :param iterations:最大迭代次数
 :return:局部最优的模型参数
 """
 iteration = 0;
 while iteration < iterations:
  new_prior = em_single(prior,observations)
  delta_change = abs(prior[0]-new_prior[0])
  if delta_change < tol:
   break
  else:
   prior = new_prior
   iteration +=1
 return [new_prior,iteration]

#硬币投掷结果
observations = array([[1,0,0,0,1,1,0,1,0,1],
      [1,1,1,1,0,1,1,1,0,1],
      [1,0,1,1,1,1,1,0,1,1],
      [1,0,1,0,0,0,1,1,0,0],
      [0,1,1,1,0,1,1,1,0,1]])
print (em(observations,[0.6,0.5]))
end = time.perf_counter()
print('Running time: %f seconds'%(end-start))

总结

到此这篇关于Python实现EM算法实例的文章就介绍到这了,更多相关Python实现EM算法实例内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
深入浅析python继承问题
May 29 Python
python 简单的绘图工具turtle使用详解
Jun 21 Python
python3+PyQt5实现支持多线程的页面索引器应用程序
Apr 20 Python
Python3使用turtle绘制超立方体图形示例
Jun 19 Python
python 统计数组中元素出现次数并进行排序的实例
Jul 02 Python
python中pip的安装与使用教程
Aug 10 Python
flask应用部署到服务器的方法
Jul 12 Python
django mysql数据库及图片上传接口详解
Jul 18 Python
原生python实现knn分类算法
Oct 24 Python
Python开发之pip安装及使用方法详解
Feb 21 Python
Python实现自动装机功能案例分析
Oct 22 Python
在Python中如何使用yield
Jun 07 Python
python em算法的实现
Oct 03 #Python
浅析Python中字符串的intern机制
Oct 03 #Python
Python实现AES加密,解密的两种方法
Oct 03 #Python
python实现AdaBoost算法的示例
Oct 03 #Python
Django创建一个后台的基本步骤记录
Oct 02 #Python
Python中qutip用法示例详解
Oct 02 #Python
如何利用Python给自己的头像加一个小国旗(小月饼)
Oct 02 #Python
You might like
php中函数的形参与实参的问题说明
2010/09/01 PHP
php数组查找函数总结
2014/11/18 PHP
php自动加载方式集合
2016/04/04 PHP
php使用自定义函数实现汉字分割替换功能示例
2017/01/30 PHP
如何直接访问php实例对象中的private属性详解
2017/10/12 PHP
BOOM vs RR BO5 第二场 2.14
2021/03/10 DOTA
网页中表单按回车就自动提交的问题的解决方案
2014/11/03 Javascript
快速学习JavaScript的6个思维技巧
2015/10/13 Javascript
详解JavaScript跨域总结与解决办法
2016/10/31 Javascript
Bootstrap实现提示框和弹出框效果
2017/01/11 Javascript
ES6下React组件的写法示例代码
2017/05/04 Javascript
react router 4.0以上的路由应用详解
2017/09/21 Javascript
解决 viewer.js 动态更新图片导致无法预览的问题
2019/05/14 Javascript
layer.prompt输入层的例子
2019/09/24 Javascript
taro小程序添加骨架屏的实现代码
2019/11/15 Javascript
vue实现在线学生录入系统
2020/05/30 Javascript
用vue写一个日历
2020/11/02 Javascript
深入讲解Python函数中参数的使用及默认参数的陷阱
2016/03/13 Python
使用Python和xlwt向Excel文件中写入中文的实例
2018/04/21 Python
Python文本统计功能之西游记用字统计操作示例
2018/05/07 Python
梅尔倒谱系数(MFCC)实现
2019/06/19 Python
利用PyCharm操作Github(仓库新建、更新,代码回滚)
2019/12/18 Python
pytorch之添加BN的实现
2020/01/06 Python
Python面向对象程序设计之继承、多态原理与用法详解
2020/03/23 Python
python实现五子棋程序
2020/04/24 Python
Python基于tkinter canvas实现图片裁剪功能
2020/11/05 Python
css3实现背景颜色渐变让图片不再是唯一的实现方式
2012/12/18 HTML / CSS
用html5绘制折线图的实例代码
2016/03/25 HTML / CSS
3种方式实现瀑布流布局小结
2019/09/05 HTML / CSS
面向对象编程的优势是什么
2015/12/17 面试题
销售会计工作职责
2013/12/02 职场文书
医学院校毕业生自荐信范文
2014/01/01 职场文书
大学生职业生涯规划书范文
2014/01/04 职场文书
惊涛骇浪观后感
2015/06/05 职场文书
2016党员干部反腐倡廉心得体会
2016/01/13 职场文书
python如何利用cv2模块读取显示保存图片
2021/06/04 Python