打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
向前向后算法--EM算法及其java实现

向前-向后算法(forward-backward algorithm)

本文承接上篇博客《隐马尔可夫模型及的评估和解码问题》,用到的概念和例子都是那里面的。

学习问题

在HMM模型中,已知隐藏状态的集合S,观察值的集合O,以及一个观察序列(o1,o2,...,on),求使得该观察序列出现的可能性最大的模型参数(包括初始状态概率矩阵π,状态转移矩阵A,发射矩阵B)。这正好就是EM算法要求解的问题:已知一系列的观察值X,在隐含变量Y未知的情况下求最佳参数θ*,使得:

在中文词性标注里,根据为训练语料,我们观察到了一系列的词(对应EM中的X),如果每个词的词性(即隐藏状态)也是知道的,那它就不需要用EM来求模型参数θ了,因为Y是已知的,不存在隐含变量了。当没有隐含变量时,直接用maximum likelihood就可以把模型参数求出来。

预备知识

首先你得对下面的公式表示认同。

以下都是针对相互独立的事件,

P(A,B)=P(B|A)*P(A)

P(A,B,C)=P(C)*P(A,B|C)=P(A,C|B)*P(B)=P(B,C|A)*P(A)

P(A,B,C,D)=P(D)*P(A,B|D)*P(C|A)=P(D)*P(A,B|D)*P(C|B)

P(A,B|C)=P(D1,A,B|C)+P(D2,A,B|C)     D1,D2是事件D的一个全划分

理解了上面几个式子,你也就能理解本文中出现的公式是怎么推导出来的了。

EM算法求解

我们已经知道如果隐含变量Y是已知的,那么求解模型参数直接利用Maximum Likelihood就可以了。EM算法的基本思路是:随机初始化一组参数θ(0),根据后验概率Pr(Y|X;θ)来更新Y的期望E(Y),然后用E(Y)代替Y求出新的模型参数θ(1)。如此迭代直到θ趋于稳定。

在HMM问题中,隐含变量自然就是状态变量,要求状态变量的期望值,其实就是求时刻ti观察到xi时处于状态si的概率,为了求此概率,需要用到向前变量和向后变量。

向前变量

向前变量 是假定的参数

它表示t时刻满足状态

,且t时刻之前(包括t时刻)满足给定的观测序列
的概率。

  1. 令初始值
  2. 归纳法计算
  3. 最后计算
复杂度
向后变量
向后变量
               
它表示在时刻t出现状态
,且t时刻以后的观察序列满足
的概率。
  1. 初始值
  2. 归纳计算

E-Step

定义变量

为t时刻处于状态i,t+1时刻处于状态j的概率。

        

定义变量

表示t时刻呈现状态i的概率。

实际上

   
    

    

 

是从其他所有状态转移到状态i的次数的期望值。

是从状态i转移出去的次数的期望值。

是从状态i转移到状态j的次数的期望值。

M-Step

是在初始时刻出现状态i的频率的期望值,
是从状态i转移到状态j的次数的期望值  除以  从状态i转移出去的次数的期望值,
是在状态j下观察到活动为k的次数的期望值  除以  从其他所有状态转移到状态j的次数的期望值,
 
然后用新的参数
再来计算向前变量、向后变量、
。如此循环迭代,直到前后两次参数的变化量小于某个值为止。
下面给出我的java代码:
package nlp;
/**
*
@author Orisun
* date 2011-10-22
*/
import java.util.ArrayList;

public class BaumWelch {

int M; // 隐藏状态的种数
int N; // 输出活动的种数
double[] PI; // 初始状态概率矩阵
double[][] A; // 状态转移矩阵
double[][] B; // 混淆矩阵

ArrayList<Integer> observation = new ArrayList<Integer>(); // 观察到的集合
ArrayList<Integer> state = new ArrayList<Integer>(); // 中间状态集合
int[] out_seq = { 2, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1,
1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 1 }; // 测试用的观察序列
int[] hidden_seq = { 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1,
1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1 }; // 测试用的隐藏状态序列
int T = 32; // 序列长度为32

double[][] alpha = new double[T][]; // 向前变量
double PO;
double[][] beta = new double[T][]; // 向后变量
double[][] gamma = new double[T][];
double[][][] xi = new double[T - 1][][];

// 初始化参数。Baum-Welch得到的是局部最优解,所以初始参数直接影响解的好坏
public void initParameters() {
M = 2;
N = 2;
PI = new double[M];
PI[0] = 0.5;
PI[1] = 0.5;
A = new double[M][];
B = new double[M][];
for (int i = 0; i < M; i++) {
A[i] = new double[M];
B[i] = new double[N];
}
A[0][0] = 0.8125;
A[0][1] = 0.1875;
A[1][0] = 0.2;
A[1][1] = 0.8;
B[0][0] = 0.875;
B[0][1] = 0.125;
B[1][0] = 0.25;
B[1][1] = 0.75;

observation.add(1);
observation.add(2);
state.add(1);
state.add(2);

for (int t = 0; t < T; t++) {
alpha[t] = new double[M];
beta[t] = new double[M];
gamma[t] = new double[M];
}
for (int t = 0; t < T - 1; t++) {
xi[t] = new double[M][];
for (int i = 0; i < M; i++)
xi[t][i] = new double[M];
}
}

// 更新向前变量
public void updateAlpha() {
for (int i = 0; i < M; i++) {
alpha[0][i] = PI[i] * B[i][observation.indexOf(out_seq[0])];
}
for (int t = 1; t < T; t++) {
for (int i = 0; i < M; i++) {
alpha[t][i] = 0;
for (int j = 0; j < M; j++) {
alpha[t][i] += alpha[t - 1][j] * A[j][i];
}
alpha[t][i] *= B[i][observation.indexOf(out_seq[t])];
}
}
}

// 更新观察序列出现的概率,它在一些公式中当分母
public void updatePO() {
for (int i = 0; i < M; i++)
PO += alpha[T - 1][i];
}

// 更新向后变量
public void updateBeta() {
for (int i = 0; i < M; i++) {
beta[T - 1][i] = 1;
}
for (int t = T - 2; t >= 0; t--) {
for (int i = 0; i < M; i++) {
for (int j = 0; j < M; j++) {
beta[t][i] += A[i][j]
* B[j][observation.indexOf(out_seq[t + 1])]
* beta[t + 1][j];
}
}
}
}

// 更新xi
public void updateXi() {
for (int t = 0; t < T - 1; t++) {
double frac = 0.0;
for (int i = 0; i < M; i++) {
for (int j = 0; j < M; j++) {
frac += alpha[t][i] * A[i][j]
* B[j][observation.indexOf(out_seq[t + 1])]
* beta[t + 1][j];
}
}
for (int i = 0; i < M; i++) {
for (int j = 0; j < M; j++) {
xi[t][i][j] = alpha[t][i] * A[i][j]
* B[j][observation.indexOf(out_seq[t + 1])]
* beta[t + 1][j] / frac;
}
}
}
}

// 更新gamma
public void updateGamma() {
for (int t = 0; t < T - 1; t++) {
double frac = 0.0;
for (int i = 0; i < M; i++) {
frac += alpha[t][i] * beta[t][i];
}
// double frac = PO;
for (int i = 0; i < M; i++) {
gamma[t][i] = alpha[t][i] * beta[t][i] / frac;
}
// for(int i=0;i<M;i++){
// gamma[t][i]=0;
// for(int j=0;j<M;j++)
// gamma[t][i]+=xi[t][i][j];
// }
}
}

// 更新状态概率矩阵
public void updatePI() {
for (int i = 0; i < M; i++)
PI[i] = gamma[0][i];
}

// 更新状态转移矩阵
public void updateA() {
for (int i = 0; i < M; i++) {
double frac = 0.0;
for (int t = 0; t < T - 1; t++) {
frac += gamma[t][i];
}
for (int j = 0; j < M; j++) {
double dem = 0.0;
// for (int t = 0; t < T - 1; t++) {
// dem += xi[t][i][j];
// for (int k = 0; k < M; k++)
// frac += xi[t][i][k];
// }
for (int t = 0; t < T - 1; t++) {
dem += xi[t][i][j];
}
A[i][j] = dem / frac;
}
}
}

// 更新混淆矩阵
public void updateB() {
for (int i = 0; i < M; i++) {
double frac = 0.0;
for (int t = 0; t < T; t++)
frac += gamma[t][i];
for (int j = 0; j < N; j++) {
double dem = 0.0;
for (int t = 0; t < T; t++) {
if (out_seq[t] == observation.get(j))
dem += gamma[t][i];
}
B[i][j] = dem / frac;
}
}
}

// 运行Baum-Welch算法
public void run() {
initParameters();
int iter = 22; // 迭代次数
while (iter-- > 0) {
// E-Step
updateAlpha();
// updatePO();
updateBeta();
updateGamma();
updatePI();
updateXi();
// M-Step
updateA();
updateB();
}
}

public static void main(String[] args) {
BaumWelch bw = new BaumWelch();
bw.run();
System.out.println("训练后的初始状态概率矩阵:");
for (int i = 0; i < bw.M; i++)
System.out.print(bw.PI[i] + "\t");
System.out.println();
System.out.println("训练后的状态转移矩阵:");
for (int i = 0; i < bw.M; i++) {
for (int j = 0; j < bw.M; j++) {
System.out.print(bw.A[i][j] + "\t");
}
System.out.println();
}
System.out.println("训练后的混淆矩阵:");
for (int i = 0; i < bw.M; i++) {
for (int j = 0; j < bw.N; j++) {
System.out.print(bw.B[i][j] + "\t");
}
System.out.println();
}
}
}
迭代22次后得到的参数:
训练后的初始状态概率矩阵:
6.72801479161809E-301.0
训练后的状态转移矩阵:
0.76720211710795320.23282165928765827
0.357061195165864760.6429096688758965
训练后的混淆矩阵:
0.99589658628791480.004103413712085399
2.135019831171061E-60.9999978649801687

原文来自:博客园(华夏35度)http://www.cnblogs.com/zhangchaoyang 作者:Orisun
本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
【热】打开小程序,算一算2024你的财运
矩阵的LU分解 c++
概率模型
向前
OpenCV矩阵运算
HMM学习最佳范例
数字信号处理
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服