Skip to content
Snippets Groups Projects
Commit 153983ac authored by Chao Weng's avatar Chao Weng
Browse files

helper functions to do frame level MPE

git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@893 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent 4c421219
No related branches found
No related tags found
No related merge requests found
// lat/lattice-functions.cc
// Copyright 2009-2011 Saarland University 2012 Daniel Povey
// Authors: Arnab Ghoshal Daniel Povey
// Authors: Arnab Ghoshal Daniel Povey Chao Weng
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -306,13 +306,20 @@ int32 LatticePhoneFrameAccuracy(const Lattice &hyp, const TransitionModel &trans
static void ForwardNodeMpe(const Lattice &lat, const TransitionModel &trans,
int32 state, int32 cur_time,
const vector< map<int32, char> > &arc_accs,
vector< pair<double, double> > *state_alphas);
const vector<double> state_alphas,
vector<double> *state_alphas_mpe);
static void BackwardNodeMpe(const Lattice &lat, const TransitionModel &trans,
int32 state, int32 cur_time,
pair<double, double> tot_forward_score,
double tot_forward_prob, double tot_forward_score,
const vector< vector<int32> > &active_states,
const vector< pair<double, double> > &state_alphas,
vector< pair<double, double> > *state_betas);
const vector< map<int32, char> > &arc_accs,
const vector<double> &state_alphas,
const vector<double> &state_alphas_mpe,
const vector<double> &state_betas,
vector<double> *state_betas_mpe,
map<int32, double> *post);
BaseFloat LatticeForwardBackwardMpe(const Lattice &lat,
const TransitionModel &trans,
......@@ -329,54 +336,84 @@ BaseFloat LatticeForwardBackwardMpe(const Lattice &lat,
int32 max_time = LatticeStateTimes(lat, &state_times);
vector< vector<int32> > active_states(max_time + 1);
// the +1 is needed since time is indexed from 0
vector<double> state_alphas(num_states, kLogZeroDouble),
state_alphas_mpe(num_states, 0), //forward variable for mpe
state_betas(num_states, kLogZeroDouble),
state_betas_mpe(num_states, 0); //backward variable for mpe
state_alphas[0] = 0.0;
state_alphas_mpe[0] = 0.0;
double tot_forward_prob = kLogZeroDouble;
double tot_forward_score = 0;
vector< pair<double, double> > state_alphas(num_states,
std::make_pair(kLogZeroDouble, 0)),
state_betas(num_states, std::make_pair(kLogZeroDouble, 0));
state_alphas[0].first = 0.0;
pair<double, double> tot_forward_score = std::make_pair(kLogZeroDouble, 0.0);
// Forward pass
//First Pass Forward,
for (int32 state = 0; state < num_states; ++state) {
int32 cur_time = state_times[state];
active_states[cur_time].push_back(state);
if (lat.Final(state) != LatticeWeight::Zero()) { // Check if final state.
BaseFloat final_loglike = -(lat.Final(state).Value1() + lat.Final(state).Value2());
state_betas[state] = final_loglike;
tot_forward_prob = LogAdd(tot_forward_prob, state_alphas[state] + final_loglike);
} else {
ForwardNode(lat, state, &state_alphas);
}
}
//Second Pass Forward, calculate forward for MPE,
for (int32 state = 0; state < num_states; ++state) {
int32 cur_time = state_times[state];
if (lat.Final(state) != LatticeWeight::Zero()) { // Check if final state.
state_betas[state] = std::make_pair(state_alphas[state].first, 0.0);
tot_forward_score.first = LogAdd(tot_forward_score.first,
state_alphas[state].first);
tot_forward_score.second += state_alphas[state].second;
tot_forward_score += state_alphas_mpe[state];
} else {
ForwardNodeMpe(lat, trans, state, cur_time, arc_accs, &state_alphas);
ForwardNodeMpe(lat, trans, state, cur_time, arc_accs,
state_alphas, &state_alphas_mpe);
}
}
// Backward pass and collect posteriors
vector< map<int32, double> > tmp_arc_post_pos(max_time),
tmp_arc_post_neg(max_time);
for (int32 state = num_states -1; state > 0; --state) {
//First Pass Backward,
vector< map<int32, double> > tmp_arc_post(max_time);
for (int32 state = num_states - 1; state > 0; --state) {
int32 cur_time = state_times[state];
BackwardNodeMpe(lat, trans, state, cur_time, tot_forward_score,
active_states, state_alphas, &state_betas);
BackwardNode(lat, state, cur_time, tot_forward_prob, active_states,
state_alphas, &state_betas, &tmp_arc_post[cur_time - 1]);
}
//First Pass Forward Backward check
double tot_backward_prob = state_betas[0]; // Initial state id == 0
if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-9)) {
KALDI_ERR << "Total forward probability over lattice = " << tot_forward_prob
<< ", while total backward probability = " << tot_backward_prob;
}
//Second Pass Backward, collect Mpe style posteriors
vector< map<int32, double> > tmp_arc_post_mpe(max_time);
for (int32 state = num_states - 1; state > 0; --state) {
int32 cur_time = state_times[state];
BackwardNodeMpe(lat, trans, state, cur_time, tot_forward_prob,
tot_forward_score, active_states, arc_accs, state_alphas,
state_alphas_mpe, state_betas, &state_betas_mpe,
&tmp_arc_post_mpe[cur_time - 1]);
}
//Second Pass Forward Backward check
double tot_backward_score = state_betas_mpe[0]; // Initial state id == 0
if (!ApproxEqual(tot_forward_score, tot_backward_score, 1e-9)) {
KALDI_ERR << "Total forward score over lattice = " << tot_forward_score
<< ", while total backward probability = " << tot_backward_score;
}
// Output the computed posteriors
arc_post->resize(max_time);
for (int32 cur_time = 0; cur_time < max_time; ++cur_time) {
map<int32, double>::const_iterator post_itr =
tmp_arc_post_mpe[cur_time].begin();
for (; post_itr != tmp_arc_post_mpe[cur_time].end(); ++post_itr) {
(*arc_post)[cur_time].push_back(std::make_pair(post_itr->first,
post_itr->second));
}
}
// double tot_backward_prob = state_betas[0]; // Initial state id == 0
// if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-9)) {
// KALDI_ERR << "Total forward probability over lattice = " << tot_forward_prob
// << ", while total backward probability = " << tot_backward_prob;
// }
// // Output the computed posteriors
// arc_post->resize(max_time);
// for (int32 cur_time = 0; cur_time < max_time; ++cur_time) {
// map<int32, double>::const_iterator post_itr =
// tmp_arc_post[cur_time].begin();
// for (; post_itr != tmp_arc_post[cur_time].end(); ++post_itr) {
// (*arc_post)[cur_time].push_back(std::make_pair(post_itr->first,
// post_itr->second));
// }
// }
return tot_forward_score.second;
return tot_forward_score;
}
......@@ -449,41 +486,103 @@ void BackwardNode(const Lattice &lat, int32 state, int32 cur_time,
}
}
// static
void ForwardNodeMpe(const Lattice &lat, const TransitionModel &tr,
void ForwardNodeMpe(const Lattice &lat, const TransitionModel &trans,
int32 state, int32 cur_time,
const vector< map<int32, char> > &arc_accs,
vector< pair<double, double> > *state_alphas) {
const vector<double> state_alphas,
vector<double> *state_alphas_mpe) {
for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
const LatticeArc& arc = aiter.Value();
double graph_score = arc.weight.Value1(),
am_score = arc.weight.Value2(),
arc_loglike = (*state_alphas)[state].first - am_score - graph_score;
(*state_alphas)[arc.nextstate].first =
LogAdd((*state_alphas)[arc.nextstate].first, arc_loglike);
arc_loglike = am_score + graph_score;
double frame_acc = 0.0;
if (arc.ilabel != 0) {
int32 phone = tr.TransitionIdToPhone(arc.ilabel);
int32 phone = trans.TransitionIdToPhone(arc.ilabel);
frame_acc = (arc_accs[cur_time].find(phone) == arc_accs[cur_time].end())?
0.0 : 1.0;
}
(*state_alphas)[arc.nextstate].second += ((*state_alphas)[state].second
+ frame_acc);
double arc_scale = std::exp(state_alphas[state] - arc_loglike
- state_alphas[arc.nextstate]);
(*state_alphas_mpe)[arc.nextstate] += arc_scale * ((*state_alphas_mpe)[state]
+ frame_acc);
}
}
//The "posteriors" this function collect is the regular one scaled
//by the Mpe phone arc accuracy differentiation, which could be
//postive or negative
//static
void BackwardNodeMpe(const Lattice &lat, const TransitionModel &trans,
int32 state, int32 cur_time,
pair<double, double> tot_forward_score,
double tot_forward_prob, double tot_forward_score,
const vector< vector<int32> > &active_states,
const vector< pair<double, double> > &state_alphas,
vector< pair<double, double> > *state_betas) {
const vector< map<int32, char> > &arc_accs,
const vector<double> &state_alphas,
const vector<double> &state_alphas_mpe,
const vector<double> &state_betas,
vector<double> *state_betas_mpe,
map<int32, double> *post) {
// Epsilon arcs leading into the state
for (vector<int32>::const_iterator st_it = active_states[cur_time].begin();
st_it != active_states[cur_time].end(); ++st_it) {
if ((*st_it) < state) {
for (fst::ArcIterator<Lattice> aiter(lat, (*st_it)); !aiter.Done();
aiter.Next()) {
const LatticeArc& arc = aiter.Value();
if (arc.nextstate == state) {
KALDI_ASSERT(arc.ilabel == 0);
double arc_loglike = arc.weight.Value1() + arc.weight.Value2();
double arc_scale = std::exp(state_betas[state] - arc_loglike
- state_betas[*(st_it)]);
(*state_betas_mpe)[(*st_it)] += arc_scale * (*state_betas_mpe)[state];
}
}
}
}
if (cur_time == 0) return;
// Non-epsilon arcs leading into the state
int32 prev_time = cur_time - 1;
for (vector<int32>::const_iterator st_it = active_states[prev_time].begin();
st_it != active_states[prev_time].end(); ++st_it) {
for (fst::ArcIterator<Lattice> aiter(lat, (*st_it)); !aiter.Done();
aiter.Next()) {
const LatticeArc& arc = aiter.Value();
if (arc.nextstate == state) {
int32 key = arc.ilabel;
KALDI_ASSERT(key != 0);
double graph_score = arc.weight.Value1(),
am_score = arc.weight.Value2(),
arc_loglike = graph_score + am_score;
double gamma = std::exp(state_alphas[(*st_it)] - graph_score - am_score
+ state_betas[state] - tot_forward_prob);
//calculate Mpe phone acc differentiation
int32 phone = trans.TransitionIdToPhone(arc.ilabel);
//note: should be prev_time here ?
double frame_acc = (arc_accs[prev_time].find(phone) == arc_accs[prev_time].end())?
0.0 : 1.0;
double arc_scale = std::exp(state_betas[state] - arc_loglike
- state_betas[*(st_it)]);
(*state_betas_mpe)[(*st_it)] += arc_scale * ((*state_betas_mpe)[state]
+ frame_acc);
double acc_diff = state_alphas_mpe[(*st_it)] + frame_acc
+ (*state_betas_mpe)[state] - tot_forward_score;
if (post->find(key) == post->end()) // New label found at prev_time
(*post)[key] = gamma * acc_diff;
else // Arc label already seen at this time
(*post)[key] += gamma * acc_diff;
}
}
}
}
} // namespace kaldi
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment