Skip to content
Snippets Groups Projects
Commit fe3a698d authored by Dan Povey's avatar Dan Povey
Browse files

Various changes to CCTC training code

parent d765fb3b
No related branches found
No related tags found
No related merge requests found
......@@ -10,7 +10,6 @@ TESTFILES = language-model-test cctc-transition-model-test
OBJFILES = language-model.o cctc-transition-model.o cctc-graph.o \
ctc-supervision.o cctc-training.o
# ctc-functions.o
LIBNAME = kaldi-ctc
......
......@@ -23,35 +23,155 @@
namespace kaldi {
namespace ctc {
void CctcTrainer::CheckDims() const {
KALDI_ASSERT(weights_.NumRows() == trans_model_.NumHistoryStates() &&
weights_.NumCols() == trans_model_.NumIndexes());
// Constructor.  Initializes the members opts_, trans_model_, cu_weights_,
// supervision_ and nnet_output_ from the corresponding arguments, then calls
// CheckDims() to verify that their dimensions are mutually consistent.
// Note: trans_model_, cu_weights_, supervision_ and nnet_output_ are declared
// as const references in the class, so the objects passed in here must
// outlive this CctcComputation object.
CctcComputation::CctcComputation(
    const CctcTrainingOptions &opts,
    const CctcTransitionModel &trans_model,
    const CuMatrix<BaseFloat> &cu_weights,
    const CtcSupervision &supervision,
    const CuMatrixBase<BaseFloat> &nnet_output)
    : opts_(opts),
      trans_model_(trans_model),
      cu_weights_(cu_weights),
      supervision_(supervision),
      nnet_output_(nnet_output) {
  // Fail early if the inputs do not match each other.
  CheckDims();
}
void CctcComputation::CheckDims() const {
KALDI_ASSERT(cu_weights_.NumRows() == trans_model_.NumHistoryStates() &&
cu_weights_.NumCols() == trans_model_.NumOutputIndexes());
KALDI_ASSERT(nnet_output_.NumRows() == supervision_.num_frames &&
nnet_output_.NumCols() == trans_model_.NumIndexes());
KALDI_ASSERT(supervision_.label_dim == trans_model_.NumIndexes());
nnet_output_.NumCols() == trans_model_.NumOutputIndexes());
KALDI_ASSERT(supervision_.label_dim == trans_model_.NumOutputIndexes());
}
void CctcTraining::Forward() {
CheckDims();
// Called from Forward().  Walks over all arcs of supervision_.fst (in state
// order, then arc order) and sets up:
//   - numerator_indexes_: the distinct (t, output-index) pairs that occur;
//   - denominator_indexes_: the distinct (t, history-state) pairs that occur;
//   - fst_indexes_: for each arc, the pair (index into numerator_indexes_,
//     index into denominator_indexes_);
//   - arc_probs_: for each arc, the language-model probability of its graph
//     label (as returned by trans_model_.GraphLabelToLmProb()).
// Relies on the FST's states being ordered so that the state times returned by
// ComputeFstStateTimes() never decrease and increase by at most one from one
// state to the next (this is asserted below).
void CctcComputation::ComputeLookupIndexes() {
  std::vector<int32> fst_state_times;
  ComputeFstStateTimes(supervision_.fst, &fst_state_times);
  int32 num_states = supervision_.fst.NumStates();
  int32 num_arcs_guess = num_states * 2;
  // arc_probs_temp will store the language-model probabilities for each arc.
  std::vector<BaseFloat> arc_probs_temp;
  // fst_indexes_ is a class member: start from an empty list, so that calling
  // this function more than once cannot leave stale entries in it (which would
  // make num_arcs below larger than arc_probs_temp.size() and cause the
  // memcpy to read past the end of arc_probs_temp).
  fst_indexes_.clear();
  fst_indexes_.reserve(num_arcs_guess);
  arc_probs_temp.reserve(num_arcs_guess);
  int32 cur_time = 0;
  // the following are CPU versions of numerator_indexes_ and
  // denominator_indexes_.  numerator_indexes_cpu is a list of pairs (t,
  // output-index) and denominator_indexes_cpu is a list of pairs (t,
  // history-state-index).
  std::vector<Int32Pair> numerator_indexes_cpu, denominator_indexes_cpu;
  // numerator_index_map_this_frame is a map, only valid for t == cur_time,
  // from the output-index to the index into numerator_indexes_cpu for the
  // pair (cur_time, output-index).
  unordered_map<int32,int32> numerator_index_map_this_frame;
  // denominator_index_map_this_frame is a map, only valid for t == cur_time,
  // from the history-state to the index into denominator_indexes_cpu for the
  // pair (cur_time, history-state).
  unordered_map<int32,int32> denominator_index_map_this_frame;
  typedef unordered_map<int32,int32>::iterator IterType;
  for (int32 state = 0; state < num_states; state++) {
    int32 t = fst_state_times[state];
    if (t != cur_time) {
      // We have moved on to the next frame, so the per-frame maps no longer
      // apply.
      KALDI_ASSERT(t == cur_time + 1);
      numerator_index_map_this_frame.clear();
      denominator_index_map_this_frame.clear();
      cur_time = t;
    }
    for (fst::ArcIterator<fst::StdVectorFst> aiter(supervision_.fst, state);
         !aiter.Done(); aiter.Next()) {
      int32 graph_label = aiter.Value().ilabel,
          output_index = trans_model_.GraphLabelToOutputIndex(graph_label),
          history_state = trans_model_.GraphLabelToHistoryState(graph_label);
      // These are the indexes the pairs will get if they turn out to be new.
      int32 numerator_index = numerator_indexes_cpu.size(),
          denominator_index = denominator_indexes_cpu.size();
      Int32Pair num_pair, den_pair;  // can't use constructors as declared in C.
      num_pair.first = t;
      num_pair.second = output_index;
      den_pair.first = t;
      den_pair.second = history_state;
      // the next few lines are a more efficient way of doing the following:
      // if (numerator_index_map_this_frame.count(output_index) == 0) {
      //   numerator_index_map_this_frame[output_index] = numerator_index;
      // else
      //   numerator_index = numerator_index_map_this_frame[output_index];
      std::pair<IterType,bool> p = numerator_index_map_this_frame.insert(
          std::pair<const int32, int32>(output_index, numerator_index));
      if (p.second) {  // Was inserted -> map had no key 'output_index'
        numerator_indexes_cpu.push_back(num_pair);
      } else {  // was not inserted -> set numerator_index to the existing index.
        numerator_index = p.first->second;
        KALDI_PARANOID_ASSERT(numerator_indexes_cpu[numerator_index] ==
                              num_pair);
      }
      // the next few lines are a more efficient way of doing the following:
      // if (denominator_index_map_this_frame.count(history_state) == 0) {
      //   denominator_index_map_this_frame[history_state] = denominator_index;
      // else
      //   denominator_index = denominator_index_map_this_frame[history_state];
      p = denominator_index_map_this_frame.insert(
          std::pair<const int32, int32>(history_state, denominator_index));
      if (p.second) {  // Was inserted -> map had no key 'history_state'
        denominator_indexes_cpu.push_back(den_pair);
      } else {  // was not inserted -> set denominator_index to the existing index.
        denominator_index = p.first->second;
        KALDI_PARANOID_ASSERT(denominator_indexes_cpu[denominator_index] ==
                              den_pair);
      }
      fst_indexes_.push_back(std::pair<int32,int32>(numerator_index,
                                                    denominator_index));
      arc_probs_temp.push_back(trans_model_.GraphLabelToLmProb(graph_label));
    }
  }
  // Copy the CPU-side index lists into the CuArray members.
  numerator_indexes_ = numerator_indexes_cpu;
  denominator_indexes_ = denominator_indexes_cpu;
  int32 num_arcs = fst_indexes_.size();
  // The assert also guarantees that &(arc_probs_temp[0]) below is valid.
  KALDI_ASSERT(num_arcs > 0);
  arc_probs_.Resize(num_arcs);
  memcpy(static_cast<void*>(arc_probs_.Data()),
         static_cast<void*>(&(arc_probs_temp[0])),
         num_arcs * sizeof(BaseFloat));
}
BaseFloat CctcComputation::Forward() {
ComputeLookupIndexes();
exp_nnet_output_ = nnet_output_;
exp_nnet_output_.ApplyExp();
normalizers_.Resize(exp_nnet_output_.NumRows(),
trans_model_.NumHistoryStates());
normalizers_.AddMatMat(1.0, exp_nnet_output_, kNoTrans, weights_, kTrans);
normalizers_.AddMatMat(1.0, exp_nnet_output_, kNoTrans, cu_weights_, kTrans,
0.0);
LookUpLikelihoods();
ComputeAlphas();
ComputeAlpha();
return tot_log_prob_;
}
void CctcComputation::LookUpLikelihoods() {
numerator_probs_.Resize(numerator_indexes_.Dim(), kUndefined);
exp_nnet_output_.Lookup(numerator_indexes_, numerator_probs_.Data());
denominator_probs_.Resize(denominator_indexes_.Dim(), kUndefined);
normalizers_.Lookup(denominator_indexes_, denominator_probs_.Data());
// Note: at this point, arc_probs_ contains the phone language model
// probabilities.
BaseFloat *arc_prob_data = arc_probs_.Data();
const BaseFloat *numerator_prob_data = numerator_probs_.Data(),
*denominator_prob_data = denominator_probs_.Data();
std::vector<std::pair<int32,int32> >::const_iterator
iter = fst_indexes_.begin(), end = fst_indexes_.end();
for (; iter != end; ++iter, ++arc_prob_data)
*arc_prob_data *= numerator_prob_data[iter->first] /
denominator_prob_data[iter->second];
}
bool CctcTraining::Backward(CuMatrixBase<BaseFloat> *nnet_output_deriv) {
bool CctcComputation::Backward(CuMatrixBase<BaseFloat> *nnet_output_deriv) {
ComputeBeta();
return ComputeDerivatives(nnet_output_deriv);
}
bool CctcTraining::ComputeDerivatives(
bool CctcComputation::ComputeDerivatives(
CuMatrixBase<BaseFloat> *nnet_output_deriv) {
// we assume nnet_output_deriv is already zeroed; we add to it.
int32 num_states = supervision_.fst.NumStates();
......@@ -65,13 +185,13 @@ bool CctcTraining::ComputeDerivatives(
numerator_probs_.SetZero(); // we'll use this to store derivatives w.r.t. the
// numerator log-prob; these derivatives are just
// sums of occupation counts.
BaseFloat numerator_deriv_data = numerator_probs_.Data();
BaseFloat *numerator_deriv_data = numerator_probs_.Data();
// size and zero denominator_deriv_. It will contain the sum of negated
// occupancies that map to each element of the denominator_indexes_ and
// denominator_prob_ vectors.
denominator_deriv_.Resize(denominator_probs_.Dim());
BaseFloat denominator_deriv_data = denominator_deriv_.Data();
BaseFloat *denominator_deriv_data = denominator_deriv_.Data();
const BaseFloat *arc_prob_data = arc_probs_.Data();
for (int32 state = 0; state < num_states; state++) {
......@@ -81,8 +201,8 @@ bool CctcTraining::ComputeDerivatives(
int32 nextstate = arc.nextstate;
double arc_posterior =
exp(alpha_data[state] + beta_data[nextstate] - tot_log_prob_) *
arc_probs_[arc_index];
KALDI_ASSERT(arc_prob >= 0.0 && arc_prob < 1.1);
arc_prob_data[arc_index];
KALDI_ASSERT(arc_posterior >= 0.0 && arc_posterior < 1.1);
int32 numerator_index = fst_indexes_iter->first,
denominator_index = fst_indexes_iter->second;
// interpret this as d(objf)/d(log of numerator)
......@@ -99,9 +219,32 @@ bool CctcTraining::ComputeDerivatives(
// We will reuse the normalizers_ array to be the derivatives
// w.r.t. the normalizers.
normalizers_.SetZero();
normalizers_.AddElements(1.0, denominator_indexes_,
denominator_deriv_data);
// Even though the next statement adds it with zero coefficient, we need
// to set it to zero to guard against inf's or NaN's.
nnet_output_deriv->SetZero();
// After the following statement, 'nnet_output_deriv' contains the derivative
// with respect to 'exp_nnet_output_', considering only the denominator term.
nnet_output_deriv->AddMatMat(1.0, normalizers_, kNoTrans,
cu_weights_, kNoTrans, 0.0);
// After the following statement, 'nnet_output_deriv' contains the derivative with
// respect to 'nnet_output_', considering only the denominator term.
// we use the chain rule: dy/dx = exp(x) * dy/d(exp x).
nnet_output_deriv->MulElements(exp_nnet_output_);
// After the following statement, 'nnet_output_deriv' should contain the
// entire derivative, also including the numerator term. Note: at this point,
// numerator_probs_ contains summed posteriors, which equal the derivative of
// the likelihood w.r.t. the nnet log output (considering only the numerator
// term).
nnet_output_deriv->AddElements(1.0, numerator_indexes_, numerator_probs_.Data());
BaseFloat sum = nnet_output_deriv->Sum();
return (sum == sum && sum - sum == 0); // check for NaN/inf.
}
......@@ -137,8 +280,6 @@ bool CctcTraining::ComputeDerivatives(
// lm_prob * num / den.
}
} // namespace ctc
} // namespace kaldi
......@@ -32,6 +32,10 @@
#include "lat/kaldi-lattice.h"
#include "matrix/kaldi-matrix.h"
#include "ctc/language-model.h"
#include "ctc/cctc-transition-model.h"
#include "ctc/ctc-supervision.h"
#include "cudamatrix/cu-matrix.h"
#include "cudamatrix/cu-array.h"
namespace kaldi {
namespace ctc {
......@@ -75,9 +79,9 @@ class CctcComputation {
// Does the forward computation. Returns the total log-prob.
BaseFloat Forward();
// Does the backward computation and adds the derivative w.r.t. the neural
// network output to 'nnet_output_deriv' (so you should probably set it to
// zero beforehand).
// Does the backward computation and writes the derivative w.r.t. the neural
// network output to 'nnet_output_deriv' (which does not have to be initialized
// beforehand).
// Returns true if everything was OK (which it should be, normally), and
// false if some kind of NaN or inf was discovered, in which case you
// shouldn't use the derivatives. We're concerned about this because
......@@ -91,10 +95,9 @@ class CctcComputation {
const CctcTransitionModel &trans_model_;
// CUDA copy of trans_model_.Weights(). Dimension is
// trans_model_.NumHistoryStates() by trans_model_.NumOutputIndexes().
const CuMatrix<BaseFloat> &weights_;
const CuMatrix<BaseFloat> &cu_weights_;
// The supervision object
const CtcSupervision &supervision_;
// The neural net output
const CuMatrixBase<BaseFloat> &nnet_output_;
// the exp of the neural net output.
......@@ -122,7 +125,7 @@ class CctcComputation {
// exp_nnet_output_ for the forward-backward computation. The order is not
// important, but indexes into this vector appear in .first members in
// fst_indexes.
std::vector<Int32Pair> numerator_indexes_;
CuArray<Int32Pair> numerator_indexes_;
// the numerator of the probability. in the forward computation,
// numerator_probs_[i] equals exp_nnet_output_(row,column), where (row,column)
// is the i'th member of numerator_indexes. In the backward computation,
......@@ -134,7 +137,7 @@ class CctcComputation {
// normalizers_ for the forward-backward computation. The order is not
// important, but indexes into this vector appear in .second members in
// fst_indexes.
std::vector<Int32Pair> denominator_indexes_;
CuArray<Int32Pair> denominator_indexes_;
// the denominator of the probability. denominator_probs_[i] equals
// normalizers_(row,column), where (row,column) is the i'th member of
// denominator_indexes.
......@@ -163,7 +166,7 @@ class CctcComputation {
// numerator_indexes_ and denominator_indexes_.
void ComputeLookupIndexes();
// This function, called from Forward(), computes denomator_probs_ and
// This function, called from Forward(), computes denominator_probs_ and
// numerator_probs_ via batch lookup operations in exp_nnet_output_ and
// normalizers_, and then computes arc_probs_.
void LookUpLikelihoods();
......
......@@ -423,6 +423,8 @@ void TestCctcSupervision(const CctcTransitionModel &trans_model) {
// ShortestPath effectively chooses an arbitrary path, because all paths have
// unit weight / zero cost.
ShortestPath(supervision.fst, &one_path);
std::vector<int32> graph_label_seq_in, graph_label_seq_out;
fst::TropicalWeight tot_weight;
GetLinearSymbolSequence(one_path, &graph_label_seq_in,
......@@ -430,6 +432,15 @@ void TestCctcSupervision(const CctcTransitionModel &trans_model) {
KALDI_ASSERT(tot_weight == fst::TropicalWeight::One() &&
graph_label_seq_in == graph_label_seq_out);
{ // basic testing of ComputeFstStateTimes (it has a lot of asserts).
std::vector<int32> state_times;
int32 length = ComputeFstStateTimes(supervision.fst, &state_times);
KALDI_ASSERT(static_cast<size_t>(length) == graph_label_seq_out.size());
for (size_t i = 0; i + 1 < state_times.size(); i++)
KALDI_ASSERT(state_times[i] <= state_times[i+1]);
}
std::vector<int32> phones_from_graph;
for (size_t i = 0; i < graph_label_seq_in.size(); i++) {
int32 this_phone = trans_model.GraphLabelToPhone(graph_label_seq_in[i]);
......
......@@ -528,6 +528,41 @@ void CtcSupervision::Read(std::istream &is, bool binary) {
ExpectToken(is, binary, "</CtcSupervision>");
}
// Sets (*state_times)[s] to the number of arcs on a path from the start state
// (which must be state zero) to state s, and returns the common path length
// of all final states.  Calls KALDI_ERR if a state is visited before its time
// is known, if two paths give a state (or the final states) different
// lengths, or if there is no final state.
int32 ComputeFstStateTimes(const fst::StdVectorFst &fst,
                           std::vector<int32> *state_times) {
  if (fst.Start() != 0) {  // this is implied by our properties.
    KALDI_ERR << "Expecting input FST start state to be zero";
  }
  const int32 num_states = fst.NumStates();
  // -1 means "time not yet determined".
  state_times->assign(num_states, -1);
  (*state_times)[0] = 0;
  int32 path_length = -1;  // stays -1 until the first final state is seen.
  for (int32 s = 0; s < num_states; s++) {
    const int32 this_time = (*state_times)[s];
    if (this_time < 0) {
      // No lower-numbered state had an arc to this one.
      KALDI_ERR << "Input FST does not have required properties.";
    }
    const int32 successor_time = this_time + 1;
    for (fst::ArcIterator<fst::StdVectorFst> aiter(fst, s);
         !aiter.Done(); aiter.Next()) {
      const fst::StdArc &arc = aiter.Value();
      KALDI_ASSERT(arc.ilabel != 0);
      int32 &dest_time = (*state_times)[arc.nextstate];
      if (dest_time == -1) {
        dest_time = successor_time;
      } else if (dest_time != successor_time) {
        // Two paths of different lengths reach the same state.
        KALDI_ERR << "Input FST does not have required properties.";
      }
    }
    if (fst.Final(s) != fst::TropicalWeight::Zero()) {
      // All final states must be at the same time.
      if (path_length == -1) {
        path_length = this_time;
      } else if (path_length != this_time) {
        KALDI_ERR << "Input FST does not have required properties.";
      }
    }
  }
  if (path_length < 0) {  // no final state was found.
    KALDI_ERR << "Input FST does not have required properties.";
  }
  return path_length;
}
} // namespace ctc
} // namespace kaldi
......@@ -330,6 +330,31 @@ class CtcSupervisionSplitter {
std::vector<int32> frame_;
};
/// Assuming the 'fst' is epsilon-free, connected, and has the property that all
/// paths from the start-state to a given state are of the same length, output
/// a vector containing that length (from the start-state to each state) to
/// 'state_times'.  The member 'fst' of struct CtcSupervision has this property.
/// Returns the total number of frames, i.e. the path length shared by all
/// final states.  This function is similar to LatticeStateTimes() and
/// CompactLatticeStateTimes() declared in lat/lattice-functions.h, except that
/// unlike LatticeStateTimes(), we don't allow epsilons (not because they are
/// hard to handle but because in this context we don't expect them).  This
/// function also expects that the input fst will have the property that the
/// state times are in nondecreasing order (as SortBreadthFirstSearch() does
/// for FSTs satisfying the other properties we mentioned).  This just happens
/// to be something we enforce while creating these FSTs.
///
/// The function calls KALDI_ERR if the start state is not state zero, if a
/// state is reached before any lower-numbered state has an arc to it, if two
/// paths reach the same state (or two final states) with different lengths,
/// or if the fst has no final state.  It asserts that no arc has a zero
/// (epsilon) input label.
///
/// @param fst[in] The input fst: should be epsilon-free; connected; nonempty;
///                should have the property that all paths to a given state (or
///                to a nonzero final-prob) should have the same number of arcs;
///                and its states should be sorted on this path length (e.g.
///                SortBreadthFirst will do this).
/// @param state_times[out] The state times that we output; will be set to
///                a vector with the dimension fst.NumStates()
/// @return  Returns the path length
int32 ComputeFstStateTimes(const fst::StdVectorFst &fst,
std::vector<int32> *state_times);
} // namespace ctc
} // namespace kaldi
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment