Newer
Older
Dan Povey
committed
// lat/lattice-functions.cc
// Copyright 2009-2011 Saarland University 2012 Daniel Povey
// Authors: Arnab Ghoshal Daniel Povey Chao Weng
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
Dan Povey
committed
#include "lat/lattice-functions.h"
#include "hmm/transition-model.h"
#include "util/stl-utils.h"
namespace kaldi {
int32 LatticeStateTimes(const Lattice &lat, vector<int32> *times) {
kaldi::uint64 props = lat.Properties(fst::kFstProperties, false);
if (!(props & fst::kTopSorted))
KALDI_ERR << "Input lattice must be topologically sorted.";
KALDI_ASSERT(lat.Start() == 0);
int32 num_states = lat.NumStates();
times->resize(num_states, -1);
(*times)[0] = 0;
for (int32 state = 0; state < num_states; ++state) {
int32 cur_time = (*times)[state];
for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
const LatticeArc& arc = aiter.Value();
if (arc.ilabel != 0) { // Non-epsilon input label on arc
// next time instance
if ((*times)[arc.nextstate] == -1) {
(*times)[arc.nextstate] = cur_time + 1;
} else {
KALDI_ASSERT((*times)[arc.nextstate] == cur_time + 1);
}
} else { // epsilon input label on arc
// Same time instance
if ((*times)[arc.nextstate] == -1)
(*times)[arc.nextstate] = cur_time;
else
KALDI_ASSERT((*times)[arc.nextstate] == cur_time);
}
}
}
return (*std::max_element(times->begin(), times->end()));
}
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
int32 CompactLatticeStateTimes(const CompactLattice &lat, vector<int32> *times) {
kaldi::uint64 props = lat.Properties(fst::kFstProperties, false);
if (!(props & fst::kTopSorted))
KALDI_ERR << "Input lattice must be topologically sorted.";
KALDI_ASSERT(lat.Start() == 0);
int32 num_states = lat.NumStates();
times->clear();
times->resize(num_states, -1);
(*times)[0] = 0;
int32 utt_len = -1;
for (int32 state = 0; state < num_states; ++state) {
int32 cur_time = (*times)[state];
for (fst::ArcIterator<CompactLattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
const CompactLatticeArc& arc = aiter.Value();
int32 arc_len = static_cast<int32>(arc.weight.String().size());
if ((*times)[arc.nextstate] == -1)
(*times)[arc.nextstate] = cur_time + arc_len;
else
KALDI_ASSERT((*times)[arc.nextstate] == cur_time + arc_len);
}
if (lat.Final(state) != CompactLatticeWeight::Zero()) {
int32 this_utt_len = (*times)[state] + lat.Final(state).String().size();
if (utt_len == -1) utt_len = this_utt_len;
else {
if (this_utt_len != utt_len) {
KALDI_WARN << "Utterance does not "
"seem to have a consistent length.";
utt_len = std::max(utt_len, this_utt_len);
}
}
}
}
if (utt_len == -1) {
KALDI_WARN << "Utterance does not have a final-state.";
return 0;
}
return utt_len;
}
// Helper functions for lattice forward-backward
static void ForwardNode(const Lattice &lat, int32 state,
vector<double> *state_alphas);
static void BackwardNode(const Lattice &lat, int32 state, int32 cur_time,
double tot_forward_prob,
const vector< vector<int32> > &active_states,
const vector<double> &state_alphas,
vector<double> *state_betas,
BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *arc_post) {
// Make sure the lattice is topologically sorted.
kaldi::uint64 props = lat.Properties(fst::kFstProperties, false);
if (!(props & fst::kTopSorted))
KALDI_ERR << "Input lattice must be topologically sorted.";
KALDI_ASSERT(lat.Start() == 0);
int32 num_states = lat.NumStates();
vector<int32> state_times;
int32 max_time = LatticeStateTimes(lat, &state_times);
vector< vector<int32> > active_states(max_time + 1);
// the +1 is needed since time is indexed from 0.
vector<double> state_alphas(num_states, kLogZeroDouble),
state_betas(num_states, kLogZeroDouble);
state_alphas[0] = 0.0;
double tot_forward_prob = kLogZeroDouble;
// Forward pass
for (int32 state = 0; state < num_states; ++state) {
int32 cur_time = state_times[state];
active_states[cur_time].push_back(state);
if (lat.Final(state) != LatticeWeight::Zero()) { // Check if final state.
BaseFloat final_loglike = -(lat.Final(state).Value1() + lat.Final(state).Value2());
state_betas[state] = final_loglike;
tot_forward_prob = LogAdd(tot_forward_prob, state_alphas[state] + final_loglike);
} else {
ForwardNode(lat, state, &state_alphas);
}
}
// Backward pass and collect posteriors
vector< map<int32, double> > tmp_arc_post(max_time);
for (int32 state = num_states - 1; state > 0; --state) {
int32 cur_time = state_times[state];
BackwardNode(lat, state, cur_time, tot_forward_prob, active_states,
state_alphas, &state_betas, &tmp_arc_post[cur_time - 1]);
}
double tot_backward_prob = state_betas[0]; // Initial state id == 0
if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-9)) {
KALDI_ERR << "Total forward probability over lattice = " << tot_forward_prob
<< ", while total backward probability = " << tot_backward_prob;
}
// Output the computed posteriors
arc_post->resize(max_time);
for (int32 cur_time = 0; cur_time < max_time; ++cur_time) {
map<int32, double>::const_iterator post_itr =
tmp_arc_post[cur_time].begin();
for (; post_itr != tmp_arc_post[cur_time].end(); ++post_itr) {
(*arc_post)[cur_time].push_back(std::make_pair(post_itr->first,
post_itr->second));
}
}
return tot_forward_prob;
}
void LatticeActivePhones(const Lattice &lat, const TransitionModel &trans,
const vector<int32> &silence_phones,
vector< std::set<int32> > *active_phones) {
KALDI_ASSERT(IsSortedAndUniq(silence_phones));
vector<int32> state_times;
int32 num_states = lat.NumStates();
int32 max_time = LatticeStateTimes(lat, &state_times);
active_phones->resize(max_time);
for (int32 state = 0; state < num_states; ++state) {
int32 cur_time = state_times[state];
for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
const LatticeArc& arc = aiter.Value();
if (arc.ilabel != 0) { // Non-epsilon arc
int32 phone = trans.TransitionIdToPhone(arc.ilabel);
if (!std::binary_search(silence_phones.begin(),
silence_phones.end(), phone))
(*active_phones)[cur_time].insert(phone);
}
} // end looping over arcs
} // end looping over states
}
void ConvertLatticeToPhones(const TransitionModel &trans,
Lattice *lat) {
typedef LatticeArc Arc;
int32 num_states = lat->NumStates();
for (int32 state = 0; state < num_states; ++state) {
for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
Arc arc(aiter.Value());
arc.olabel = 0; // remove any word.
if (arc.ilabel != 0 // has a transition-id on input..
&& trans.IsFinal(arc.ilabel)) // there is one of these per phone...
arc.olabel = trans.TransitionIdToPhone(arc.ilabel);
aiter.SetValue(arc);
} // end looping over arcs
} // end looping over states
}
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
bool LatticeBoost(const TransitionModel &trans,
const std::vector<std::set<int32> > &active_phones,
const std::vector<int32> &silence_phones,
BaseFloat b,
BaseFloat max_silence_error,
Lattice *lat) {
kaldi::uint64 props = lat->Properties(fst::kFstProperties, false);
if (!(props & fst::kTopSorted)) {
if (fst::TopSort(lat) == false) {
KALDI_WARN << "Cycles detected in lattice";
return false;
}
}
KALDI_ASSERT(IsSortedAndUniq(silence_phones));
KALDI_ASSERT(max_silence_error >= 0.0 && max_silence_error <= 1.0);
vector<int32> state_times;
int32 num_states = lat->NumStates();
LatticeStateTimes(*lat, &state_times);
for (int32 state = 0; state < num_states; ++state) {
int32 cur_time = state_times[state];
if (cur_time < 0 || cur_time > active_phones.size()) {
KALDI_WARN << "Lattice is too long for active_phones: mismatched den and num lattices/alignments?";
return false;
}
for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
LatticeArc arc = aiter.Value();
if (arc.ilabel != 0) { // Non-epsilon arc
if (arc.ilabel < 0 || arc.ilabel > trans.NumTransitionIds()) {
KALDI_WARN << "Lattice has out-of-range transition-ids: lattice/model mismatch?";
return false;
}
int32 phone = trans.TransitionIdToPhone(arc.ilabel);
BaseFloat frame_error;
if (active_phones[cur_time].count(phone) == 1) {
frame_error = 0.0;
} else { // an error...
if (std::binary_search(silence_phones.begin(), silence_phones.end(), phone))
frame_error = max_silence_error;
else
frame_error = 1.0;
}
BaseFloat delta_cost = -b * frame_error; // negative cost if
// frame is wrong, to boost likelihood of arcs with errors on them.
// Add this cost to the graph part.
arc.weight.SetValue1(arc.weight.Value1() + delta_cost);
aiter.SetValue(arc);
}
}
}
return true;
}
int32 LatticePhoneFrameAccuracy(const Lattice &hyp, const TransitionModel &trans,
const vector< map<int32, int32> > &ref_phones,
vector< map<int32, char> > *arc_accs,
vector<int32> *state_times) {
vector<int32> state_times_hyp;
int32 max_time_hyp = LatticeStateTimes(hyp, &state_times_hyp),
max_time_ref = ref_phones.size();
if (max_time_ref != max_time_hyp) {
KALDI_ERR << "Reference and hypothesis lattices must have same numbers of "
<< "frames. Found " << max_time_ref << " in ref and "
<< max_time_hyp << " in hyp.";
}
int32 num_states_hyp = hyp.NumStates();
for (int32 state = 0; state < num_states_hyp; ++state) {
int32 cur_time = state_times_hyp[state];
for (fst::ArcIterator<Lattice> aiter(hyp, state); !aiter.Done();
aiter.Next()) {
const LatticeArc& arc = aiter.Value();
if (arc.ilabel != 0) { // Non-epsilon arc
int32 phone = trans.TransitionIdToPhone(arc.ilabel);
(*arc_accs)[cur_time][phone] =
(ref_phones[cur_time].find(phone) == ref_phones[cur_time].end())?
0 : 1;
}
} // end looping over arcs
} // end looping over states
if (state_times != NULL)
(*state_times) = state_times_hyp;
return max_time_hyp;
// Helper functions for MPE lattice forward-backward
static void ForwardNodeMpe(const Lattice &lat, const TransitionModel &trans,
int32 state, int32 cur_time,
const vector< map<int32, char> > &arc_accs,
const vector<double> state_alphas,
vector<double> *state_alphas_mpe);
static void BackwardNodeMpe(const Lattice &lat, const TransitionModel &trans,
int32 state, int32 cur_time,
double tot_forward_prob, double tot_forward_score,
const vector< vector<int32> > &active_states,
const vector< map<int32, char> > &arc_accs,
const vector<double> &state_alphas,
const vector<double> &state_alphas_mpe,
const vector<double> &state_betas,
vector<double> *state_betas_mpe,
map<int32, double> *post);
BaseFloat LatticeForwardBackwardMpe(const Lattice &lat,
const TransitionModel &trans,
const vector< map<int32, char> > &arc_accs,
Posterior *arc_post) {
// Make sure the lattice is topologically sorted.
kaldi::uint64 props = lat.Properties(fst::kFstProperties, false);
if (!(props & fst::kTopSorted))
KALDI_ERR << "Input lattice must be topologically sorted.";
KALDI_ASSERT(lat.Start() == 0);
int32 num_states = lat.NumStates();
vector<int32> state_times;
int32 max_time = LatticeStateTimes(lat, &state_times);
vector< vector<int32> > active_states(max_time + 1);
// the +1 is needed since time is indexed from 0
vector<double> state_alphas(num_states, kLogZeroDouble),
state_alphas_mpe(num_states, 0), //forward variable for mpe
state_betas(num_states, kLogZeroDouble),
state_betas_mpe(num_states, 0); //backward variable for mpe
state_alphas[0] = 0.0;
state_alphas_mpe[0] = 0.0;
double tot_forward_prob = kLogZeroDouble;
double tot_forward_score = 0;
for (int32 state = 0; state < num_states; ++state) {
int32 cur_time = state_times[state];
active_states[cur_time].push_back(state);
if (lat.Final(state) != LatticeWeight::Zero()) { // Check if final state.
BaseFloat final_loglike = -(lat.Final(state).Value1() + lat.Final(state).Value2());
state_betas[state] = final_loglike;
tot_forward_prob = LogAdd(tot_forward_prob, state_alphas[state] + final_loglike);
} else {
ForwardNode(lat, state, &state_alphas);
}
}
//Second Pass Forward, calculate forward for MPE,
for (int32 state = 0; state < num_states; ++state) {
int32 cur_time = state_times[state];
if (lat.Final(state) != LatticeWeight::Zero()) { // Check if final state.
tot_forward_score += state_alphas_mpe[state];
ForwardNodeMpe(lat, trans, state, cur_time, arc_accs,
state_alphas, &state_alphas_mpe);
//First Pass Backward,
vector< map<int32, double> > tmp_arc_post(max_time);
for (int32 state = num_states - 1; state > 0; --state) {
int32 cur_time = state_times[state];
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
BackwardNode(lat, state, cur_time, tot_forward_prob, active_states,
state_alphas, &state_betas, &tmp_arc_post[cur_time - 1]);
}
//First Pass Forward Backward check
double tot_backward_prob = state_betas[0]; // Initial state id == 0
if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-9)) {
KALDI_ERR << "Total forward probability over lattice = " << tot_forward_prob
<< ", while total backward probability = " << tot_backward_prob;
}
//Second Pass Backward, collect Mpe style posteriors
vector< map<int32, double> > tmp_arc_post_mpe(max_time);
for (int32 state = num_states - 1; state > 0; --state) {
int32 cur_time = state_times[state];
BackwardNodeMpe(lat, trans, state, cur_time, tot_forward_prob,
tot_forward_score, active_states, arc_accs, state_alphas,
state_alphas_mpe, state_betas, &state_betas_mpe,
&tmp_arc_post_mpe[cur_time - 1]);
}
//Second Pass Forward Backward check
double tot_backward_score = state_betas_mpe[0]; // Initial state id == 0
if (!ApproxEqual(tot_forward_score, tot_backward_score, 1e-9)) {
KALDI_ERR << "Total forward score over lattice = " << tot_forward_score
<< ", while total backward probability = " << tot_backward_score;
}
// Output the computed posteriors
arc_post->resize(max_time);
for (int32 cur_time = 0; cur_time < max_time; ++cur_time) {
map<int32, double>::const_iterator post_itr =
tmp_arc_post_mpe[cur_time].begin();
for (; post_itr != tmp_arc_post_mpe[cur_time].end(); ++post_itr) {
(*arc_post)[cur_time].push_back(std::make_pair(post_itr->first,
post_itr->second));
}
// ----------------------- Helper function definitions -----------------------
// static
void ForwardNode(const Lattice &lat, int32 state,
vector<double> *state_alphas) {
for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
const LatticeArc& arc = aiter.Value();
double graph_score = arc.weight.Value1(),
am_score = arc.weight.Value2(),
arc_loglike = (*state_alphas)[state] - am_score - graph_score;
(*state_alphas)[arc.nextstate] = LogAdd((*state_alphas)[arc.nextstate],
}
}
// static
void BackwardNode(const Lattice &lat, int32 state, int32 cur_time,
double tot_forward_prob,
const vector< vector<int32> > &active_states,
const vector<double> &state_alphas,
vector<double> *state_betas,
// Epsilon arcs leading into the state
for (vector<int32>::const_iterator st_it = active_states[cur_time].begin();
st_it != active_states[cur_time].end(); ++st_it) {
if ((*st_it) < state) {
for (fst::ArcIterator<Lattice> aiter(lat, (*st_it)); !aiter.Done();
aiter.Next()) {
const LatticeArc& arc = aiter.Value();
if (arc.nextstate == state) {
KALDI_ASSERT(arc.ilabel == 0);
double arc_loglike = (*state_betas)[state] - arc.weight.Value1()
- arc.weight.Value2();
(*state_betas)[(*st_it)] = LogAdd((*state_betas)[(*st_it)],
}
}
}
}
if (cur_time == 0) return;
// Non-epsilon arcs leading into the state
int32 prev_time = cur_time - 1;
for (vector<int32>::const_iterator st_it = active_states[prev_time].begin();
st_it != active_states[prev_time].end(); ++st_it) {
for (fst::ArcIterator<Lattice> aiter(lat, (*st_it)); !aiter.Done();
aiter.Next()) {
const LatticeArc& arc = aiter.Value();
if (arc.nextstate == state) {
int32 key = arc.ilabel;
KALDI_ASSERT(key != 0);
double graph_score = arc.weight.Value1(),
am_score = arc.weight.Value2(),
arc_loglike = (*state_betas)[state] - graph_score - am_score;
(*state_betas)[(*st_it)] = LogAdd((*state_betas)[(*st_it)],
double gamma = std::exp(state_alphas[(*st_it)] - graph_score - am_score
+ (*state_betas)[state] - tot_forward_prob);
if (post->find(key) == post->end()) // New label found at prev_time
(*post)[key] = gamma;
else // Arc label already seen at this time
(*post)[key] += gamma;
}
}
}
}
void ForwardNodeMpe(const Lattice &lat, const TransitionModel &trans,
int32 state, int32 cur_time,
const vector< map<int32, char> > &arc_accs,
const vector<double> state_alphas,
vector<double> *state_alphas_mpe) {
for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
const LatticeArc& arc = aiter.Value();
double graph_score = arc.weight.Value1(),
am_score = arc.weight.Value2(),
arc_loglike = am_score + graph_score;
double frame_acc = 0.0;
if (arc.ilabel != 0) {
int32 phone = trans.TransitionIdToPhone(arc.ilabel);
frame_acc = (arc_accs[cur_time].find(phone) == arc_accs[cur_time].end())?
0.0 : 1.0;
}
double arc_scale = std::exp(state_alphas[state] - arc_loglike
- state_alphas[arc.nextstate]);
(*state_alphas_mpe)[arc.nextstate] += arc_scale * ((*state_alphas_mpe)[state]
+ frame_acc);
//The "posteriors" this function collect is the regular one scaled
//by the Mpe phone arc accuracy differentiation, which could be
//postive or negative
//static
void BackwardNodeMpe(const Lattice &lat, const TransitionModel &trans,
int32 state, int32 cur_time,
double tot_forward_prob, double tot_forward_score,
const vector< vector<int32> > &active_states,
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
const vector< map<int32, char> > &arc_accs,
const vector<double> &state_alphas,
const vector<double> &state_alphas_mpe,
const vector<double> &state_betas,
vector<double> *state_betas_mpe,
map<int32, double> *post) {
// Epsilon arcs leading into the state
for (vector<int32>::const_iterator st_it = active_states[cur_time].begin();
st_it != active_states[cur_time].end(); ++st_it) {
if ((*st_it) < state) {
for (fst::ArcIterator<Lattice> aiter(lat, (*st_it)); !aiter.Done();
aiter.Next()) {
const LatticeArc& arc = aiter.Value();
if (arc.nextstate == state) {
KALDI_ASSERT(arc.ilabel == 0);
double arc_loglike = arc.weight.Value1() + arc.weight.Value2();
double arc_scale = std::exp(state_betas[state] - arc_loglike
- state_betas[*(st_it)]);
(*state_betas_mpe)[(*st_it)] += arc_scale * (*state_betas_mpe)[state];
}
}
}
}
if (cur_time == 0) return;
// Non-epsilon arcs leading into the state
int32 prev_time = cur_time - 1;
for (vector<int32>::const_iterator st_it = active_states[prev_time].begin();
st_it != active_states[prev_time].end(); ++st_it) {
for (fst::ArcIterator<Lattice> aiter(lat, (*st_it)); !aiter.Done();
aiter.Next()) {
const LatticeArc& arc = aiter.Value();
if (arc.nextstate == state) {
int32 key = arc.ilabel;
KALDI_ASSERT(key != 0);
double graph_score = arc.weight.Value1(),
am_score = arc.weight.Value2(),
arc_loglike = graph_score + am_score;
double gamma = std::exp(state_alphas[(*st_it)] - graph_score - am_score
+ state_betas[state] - tot_forward_prob);
//calculate Mpe phone acc differentiation
int32 phone = trans.TransitionIdToPhone(arc.ilabel);
//note: should be prev_time here ?
double frame_acc = (arc_accs[prev_time].find(phone) == arc_accs[prev_time].end())?
0.0 : 1.0;
double arc_scale = std::exp(state_betas[state] - arc_loglike
- state_betas[*(st_it)]);
(*state_betas_mpe)[(*st_it)] += arc_scale * ((*state_betas_mpe)[state]
+ frame_acc);
double acc_diff = state_alphas_mpe[(*st_it)] + frame_acc
+ (*state_betas_mpe)[state] - tot_forward_score;
if (post->find(key) == post->end()) // New label found at prev_time
(*post)[key] = gamma * acc_diff;
else // Arc label already seen at this time
(*post)[key] += gamma * acc_diff;
}
}
}
} // namespace kaldi