Python实现EM算法实例代码


Posted in Python onOctober 04, 2020

EM算法实例

通过实例可以快速了解EM算法的基本思想,具体推导请点文末链接。图a是让我们预热的,图b是EM算法的实例。

这是一个抛硬币的例子,H表示正面向上,T表示反面向上,参数θ表示正面朝上的概率。硬币有两个,A和B,硬币是有偏的。本次实验总共做了5组,每组随机选一个硬币,连续抛10次。如果知道每次抛的是哪个硬币,那么计算参数θ就非常简单了,如

下图所示:

Python实现EM算法实例代码

如果不知道每次抛的是哪个硬币呢?那么,我们就需要用EM算法,基本步骤为:

  1、给θ_AθA​和θ_BθB​一个初始值;

  2、(E-step)估计每组实验是硬币A的概率(本组实验是硬币B的概率=1-本组实验是硬币A的概率)。分别计算每组实验中,选择A硬币且正面朝上次数的期望值,选择B硬币且正面朝上次数的期望值;

  3、(M-step)利用第三步求得的期望值重新计算θ_AθA​和θ_BθB​;

  4、当迭代到一定次数,或者算法收敛到一定精度,结束算法,否则,回到第2步。

Python实现EM算法实例代码

计算过程详解:初始值θ_A^{(0)}θA(0)​=0.6,θ_B^{(0)}θB(0)​=0.5。

由两个硬币的初始值0.6和0.5,容易得出投掷出5正5反的概率是p_A=C^5_{10}*(0.6^5)*(0.4^5)pA​=C105​∗(0.65)∗(0.45),p_B=C_{10}^5*(0.5^5)*(0.5^5)pB​=C105​∗(0.55)∗(0.55), p_ApA​/(p_ApA​+p_BpB​)=0.449, 0.45就是0.449近似而来的,表示第一组实验选择的硬币是A的概率为0.45。然后,0.449 * 5H = 2.2H ,0.449 * 5T = 2.2T ,表示第一组实验选择A硬币且正面朝上次数和反面朝上次数的期望值都是2.2,其他的值依次类推。最后,求出θ_A^{(1)}θA(1)​=0.71,θ_B^{(1)}θB(1)​=0.58。重复上述过程,不断迭代,直到算法收敛到一定精度为止。

这篇博客对EM算法的推导非常详细,链接如下:

https://blog.csdn.net/zhihua_oba/article/details/73776553

Python实现

#coding=utf-8
from numpy import *
from scipy import stats
import time
start = time.perf_counter()

def em_single(priors,observations):
 """
 EM算法的单次迭代
 Arguments
 ------------
 priors:[theta_A,theta_B]
 observation:[m X n matrix]

 Returns
 ---------------
 new_priors:[new_theta_A,new_theta_B]
 :param priors:
 :param observations:
 :return:
 """
 counts = {'A': {'H': 0, 'T': 0}, 'B': {'H': 0, 'T': 0}}
 theta_A = priors[0]
 theta_B = priors[1]
 #E step
 for observation in observations:
  len_observation = len(observation)
  num_heads = observation.sum()
  num_tails = len_observation-num_heads
  #二项分布求解公式
  contribution_A = stats.binom.pmf(num_heads,len_observation,theta_A)
  contribution_B = stats.binom.pmf(num_heads,len_observation,theta_B)

  weight_A = contribution_A / (contribution_A + contribution_B)
  weight_B = contribution_B / (contribution_A + contribution_B)
  #更新在当前参数下A,B硬币产生的正反面次数
  counts['A']['H'] += weight_A * num_heads
  counts['A']['T'] += weight_A * num_tails
  counts['B']['H'] += weight_B * num_heads
  counts['B']['T'] += weight_B * num_tails

 # M step
 new_theta_A = counts['A']['H'] / (counts['A']['H'] + counts['A']['T'])
 new_theta_B = counts['B']['H'] / (counts['B']['H'] + counts['B']['T'])
 return [new_theta_A,new_theta_B]


def em(observations,prior,tol = 1e-6,iterations=10000):
 """
 EM算法
 :param observations :观测数据
 :param prior:模型初值
 :param tol:迭代结束阈值
 :param iterations:最大迭代次数
 :return:局部最优的模型参数
 """
 iteration = 0;
 while iteration < iterations:
  new_prior = em_single(prior,observations)
  delta_change = abs(prior[0]-new_prior[0])
  if delta_change < tol:
   break
  else:
   prior = new_prior
   iteration +=1
 return [new_prior,iteration]

#硬币投掷结果
observations = array([[1,0,0,0,1,1,0,1,0,1],
      [1,1,1,1,0,1,1,1,0,1],
      [1,0,1,1,1,1,1,0,1,1],
      [1,0,1,0,0,0,1,1,0,0],
      [0,1,1,1,0,1,1,1,0,1]])
print (em(observations,[0.6,0.5]))
end = time.perf_counter()
print('Running time: %f seconds'%(end-start))

总结

到此这篇关于Python实现EM算法实例的文章就介绍到这了,更多相关Python实现EM算法实例内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python中的jquery PyQuery库使用小结
May 13 Python
python创建和删除目录的方法
Apr 29 Python
使用python 和 lint 删除项目无用资源的方法
Dec 20 Python
Python编程实现线性回归和批量梯度下降法代码实例
Jan 04 Python
Python实现PS图像调整黑白效果示例
Jan 25 Python
用于业余项目的8个优秀Python库
Sep 21 Python
几行Python代码爬取3000+上市公司的信息
Jan 24 Python
python使用opencv实现马赛克效果示例
Sep 28 Python
Pytorch中.new()的作用详解
Feb 18 Python
python实现遍历文件夹图片并重命名
Mar 23 Python
pandas 强制类型转换 df.astype实例
Apr 09 Python
Python爬虫入门有哪些基础知识点
Jun 02 Python
python em算法的实现
Oct 03 #Python
浅析Python中字符串的intern机制
Oct 03 #Python
Python实现AES加密,解密的两种方法
Oct 03 #Python
python实现AdaBoost算法的示例
Oct 03 #Python
Django创建一个后台的基本步骤记录
Oct 02 #Python
Python中qutip用法示例详解
Oct 02 #Python
如何利用Python给自己的头像加一个小国旗(小月饼)
Oct 02 #Python
You might like
利用PHP动态生成VRML网页
2006/10/09 PHP
php win下Socket方式发邮件类
2009/08/21 PHP
ThinkPHP访问不存在的模块跳转到404页面的方法
2014/06/19 PHP
PHP封装的HttpClient类用法实例
2015/06/17 PHP
php根据一个给定范围和步进生成数组的方法
2015/06/19 PHP
在CentOS系统上从零开始搭建WordPress博客的全流程记录
2016/04/21 PHP
PHP实现的DES加密解密类定义与用法示例
2020/11/02 PHP
PHP实现数据四舍五入的方法小结【4种方法】
2019/03/27 PHP
不错的asp中显示新闻的功能
2006/10/13 Javascript
javascript 日期时间函数(经典+完善+实用)
2009/05/27 Javascript
JavaScript 变量命名规则
2009/09/23 Javascript
分别用marquee和div+js实现首尾相连循环滚动效果,仅3行代码
2011/09/21 Javascript
jQuery实现的在线答题功能
2015/04/12 Javascript
vue事件修饰符和按键修饰符用法总结
2017/07/25 Javascript
详解angular笔记路由之angular-router
2017/09/12 Javascript
Vue中父子组件通讯之todolist组件功能开发
2018/05/21 Javascript
详解React+Koa实现服务端渲染(SSR)
2018/05/23 Javascript
Layui数据表格之获取表格中所有的数据方法
2018/08/20 Javascript
10分钟彻底搞懂Http的强制缓存和协商缓存(小结)
2018/08/30 Javascript
解决angular 使用原生拖拽页面卡顿及表单控件输入延迟问题
2020/04/21 Javascript
python和shell实现的校验IP地址合法性脚本分享
2014/10/23 Python
使用Python的Zato发送AMQP消息的教程
2015/04/16 Python
python 字符串转列表 list 出现\ufeff的解决方法
2017/06/22 Python
Django RBAC权限管理设计过程详解
2019/08/06 Python
Python解析微信dat文件的方法
2020/11/30 Python
工作违纪检讨书
2014/02/17 职场文书
租房安全协议书
2014/08/20 职场文书
小学生五年级大队长竞选发言稿
2014/09/12 职场文书
白鹤梁导游词
2015/02/06 职场文书
结婚保证书(卖身契)
2015/02/26 职场文书
民主评议教师党员自我评价
2015/03/04 职场文书
预备党员入党感言
2015/08/01 职场文书
少先队中队工作总结
2015/08/14 职场文书
晶体管单管来复再生式收音机
2021/04/22 无线电
CSS完成视差滚动效果
2021/04/27 HTML / CSS
Python使用scapy模块发包收包
2021/05/07 Python