Python实现EM算法实例代码


Posted in Python onOctober 04, 2020

EM算法实例

通过实例可以快速了解EM算法的基本思想,具体推导请点文末链接。图a是让我们预热的,图b是EM算法的实例。

这是一个抛硬币的例子,H表示正面向上,T表示反面向上,参数θ表示正面朝上的概率。硬币有两个,A和B,硬币是有偏的。本次实验总共做了5组,每组随机选一个硬币,连续抛10次。如果知道每次抛的是哪个硬币,那么计算参数θ就非常简单了,如

下图所示:

Python实现EM算法实例代码

如果不知道每次抛的是哪个硬币呢?那么,我们就需要用EM算法,基本步骤为:

  1、给θ_AθA​和θ_BθB​一个初始值;

  2、(E-step)估计每组实验是硬币A的概率(本组实验是硬币B的概率=1-本组实验是硬币A的概率)。分别计算每组实验中,选择A硬币且正面朝上次数的期望值,选择B硬币且正面朝上次数的期望值;

  3、(M-step)利用第三步求得的期望值重新计算θ_AθA​和θ_BθB​;

  4、当迭代到一定次数,或者算法收敛到一定精度,结束算法,否则,回到第2步。

Python实现EM算法实例代码

计算过程详解:初始值θ_A^{(0)}θA(0)​=0.6,θ_B^{(0)}θB(0)​=0.5。

由两个硬币的初始值0.6和0.5,容易得出投掷出5正5反的概率是p_A=C^5_{10}*(0.6^5)*(0.4^5)pA​=C105​∗(0.65)∗(0.45),p_B=C_{10}^5*(0.5^5)*(0.5^5)pB​=C105​∗(0.55)∗(0.55), p_ApA​/(p_ApA​+p_BpB​)=0.449, 0.45就是0.449近似而来的,表示第一组实验选择的硬币是A的概率为0.45。然后,0.449 * 5H = 2.2H ,0.449 * 5T = 2.2T ,表示第一组实验选择A硬币且正面朝上次数和反面朝上次数的期望值都是2.2,其他的值依次类推。最后,求出θ_A^{(1)}θA(1)​=0.71,θ_B^{(1)}θB(1)​=0.58。重复上述过程,不断迭代,直到算法收敛到一定精度为止。

这篇博客对EM算法的推导非常详细,链接如下:

https://blog.csdn.net/zhihua_oba/article/details/73776553

Python实现

#coding=utf-8
from numpy import *
from scipy import stats
import time
start = time.perf_counter()

def em_single(priors,observations):
 """
 EM算法的单次迭代
 Arguments
 ------------
 priors:[theta_A,theta_B]
 observation:[m X n matrix]

 Returns
 ---------------
 new_priors:[new_theta_A,new_theta_B]
 :param priors:
 :param observations:
 :return:
 """
 counts = {'A': {'H': 0, 'T': 0}, 'B': {'H': 0, 'T': 0}}
 theta_A = priors[0]
 theta_B = priors[1]
 #E step
 for observation in observations:
  len_observation = len(observation)
  num_heads = observation.sum()
  num_tails = len_observation-num_heads
  #二项分布求解公式
  contribution_A = stats.binom.pmf(num_heads,len_observation,theta_A)
  contribution_B = stats.binom.pmf(num_heads,len_observation,theta_B)

  weight_A = contribution_A / (contribution_A + contribution_B)
  weight_B = contribution_B / (contribution_A + contribution_B)
  #更新在当前参数下A,B硬币产生的正反面次数
  counts['A']['H'] += weight_A * num_heads
  counts['A']['T'] += weight_A * num_tails
  counts['B']['H'] += weight_B * num_heads
  counts['B']['T'] += weight_B * num_tails

 # M step
 new_theta_A = counts['A']['H'] / (counts['A']['H'] + counts['A']['T'])
 new_theta_B = counts['B']['H'] / (counts['B']['H'] + counts['B']['T'])
 return [new_theta_A,new_theta_B]


def em(observations,prior,tol = 1e-6,iterations=10000):
 """
 EM算法
 :param observations :观测数据
 :param prior:模型初值
 :param tol:迭代结束阈值
 :param iterations:最大迭代次数
 :return:局部最优的模型参数
 """
 iteration = 0;
 while iteration < iterations:
  new_prior = em_single(prior,observations)
  delta_change = abs(prior[0]-new_prior[0])
  if delta_change < tol:
   break
  else:
   prior = new_prior
   iteration +=1
 return [new_prior,iteration]

#硬币投掷结果
observations = array([[1,0,0,0,1,1,0,1,0,1],
      [1,1,1,1,0,1,1,1,0,1],
      [1,0,1,1,1,1,1,0,1,1],
      [1,0,1,0,0,0,1,1,0,0],
      [0,1,1,1,0,1,1,1,0,1]])
print (em(observations,[0.6,0.5]))
end = time.perf_counter()
print('Running time: %f seconds'%(end-start))

总结

到此这篇关于Python实现EM算法实例的文章就介绍到这了,更多相关Python实现EM算法实例内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
详解Python list 与 NumPy.ndarry 切片之间的对比
Jul 24 Python
flask框架使用orm连接数据库的方法示例
Jul 16 Python
python requests 测试代理ip是否生效
Jul 25 Python
利用Python如何批量修改数据库执行Sql文件
Jul 29 Python
基于pycharm导入模块显示不存在的解决方法
Oct 13 Python
Python爬虫运用正则表达式的方法和优缺点
Aug 25 Python
python excel转换csv代码实例
Aug 26 Python
python opencv 检测移动物体并截图保存实例
Mar 10 Python
python实现一个猜拳游戏
Apr 05 Python
python中关于数据类型的学习笔记
Jul 19 Python
python如何做代码性能分析
Apr 26 Python
Django分页器的用法你都了解吗
May 26 Python
python em算法的实现
Oct 03 #Python
浅析Python中字符串的intern机制
Oct 03 #Python
Python实现AES加密,解密的两种方法
Oct 03 #Python
python实现AdaBoost算法的示例
Oct 03 #Python
Django创建一个后台的基本步骤记录
Oct 02 #Python
Python中qutip用法示例详解
Oct 02 #Python
如何利用Python给自己的头像加一个小国旗(小月饼)
Oct 02 #Python
You might like
PHP中的正则表达式函数介绍
2012/02/27 PHP
php反射类ReflectionClass用法分析
2016/05/12 PHP
php 计算两个时间相差的天数、小时数、分钟数、秒数详解及实例代码
2016/11/09 PHP
js 动态选中下拉框
2009/11/26 Javascript
父节点获取子节点的字符串示例代码
2014/02/26 Javascript
javascript中的return和闭包函数浅析
2014/06/06 Javascript
JavaScript的Date()方法使用详解
2015/06/09 Javascript
jQuery基于ajax实现星星评论代码
2015/08/07 Javascript
javascript DOM的详解及实例代码
2017/03/06 Javascript
[02:16]卖萌的僵尸 DOTA2神话信使飞僵小宝来袭
2014/03/24 DOTA
解析Python中while true的使用
2015/10/13 Python
完美解决Python2操作中文名文件乱码的问题
2017/01/04 Python
Python处理PDF及生成多层PDF实例代码
2017/04/24 Python
在unittest中使用 logging 模块记录测试数据的方法
2018/11/30 Python
Python 存储字符串时节省空间的方法
2019/04/23 Python
python实现简单成绩录入系统
2019/09/19 Python
Mac中PyCharm配置Anaconda环境的方法
2020/03/04 Python
详解python中的闭包
2020/09/07 Python
详解win10下pytorch-gpu安装以及CUDA详细安装过程
2021/01/28 Python
世界上最大的巴士旅游观光公司:Big Bus Tours
2016/10/20 全球购物
卡塔尔航空官方网站:Qatar Airways
2017/02/08 全球购物
英国现代市场:ARKET
2019/04/10 全球购物
红旗方阵解说词
2014/02/12 职场文书
陈欧广告词
2014/03/14 职场文书
2014年教师政治学习材料
2014/06/02 职场文书
任命书格式
2014/06/05 职场文书
敬老月活动总结
2014/08/28 职场文书
工作散漫检讨书
2014/09/16 职场文书
2014年校长工作总结
2014/12/11 职场文书
于丹讲座视频观后感
2015/06/15 职场文书
运动会加油稿
2015/07/22 职场文书
物业管理交接协议书
2016/03/24 职场文书
Node.js实现断点续传
2021/06/23 Javascript
使用GO语言实现Mysql数据库CURD的简单示例
2021/08/07 Golang
关于Python使用turtle库画任意图的问题
2022/04/01 Python
CSS实现鼠标悬浮动画特效
2023/05/07 HTML / CSS