Skip to content
Snippets Groups Projects
Commit e6e5d44b authored by Daniel Povey's avatar Daniel Povey Committed by Dan Povey
Browse files

CCTC transition-model code: getting it to compile.

parent ec421f05
No related branches found
No related tags found
No related merge requests found
......@@ -7,7 +7,7 @@ EXTRA_CXXFLAGS += -Wno-sign-compare
TESTFILES = language-model-test
OBJFILES = language-model.o # ctc-functions.o
OBJFILES = language-model.o cctc-transition-model.o # ctc-functions.o
LIBNAME = kaldi-ctc
......
......@@ -35,7 +35,8 @@ BaseFloat CctcTransitionModel::GraphLabelToLmProb(int32 graph_label) const {
int32 CctcTransitionModel::GraphLabelToHistoryState(int32 graph_label) const {
int32 history = graph_label / (num_phones_ + 1);
KALDI_ASSERT(static_cast<size_t>(history) < history_states_.size());
KALDI_ASSERT(static_cast<size_t>(history) < history_state_info_.size());
return history;
}
int32 CctcTransitionModel::GraphLabelToNextHistoryState(
......@@ -52,7 +53,7 @@ int32 CctcTransitionModel::InitialHistoryState() const {
int32 CctcTransitionModel::PairToGraphLabel(int32 history_state,
int32 phone) const {
KALDI_ASSERT(static_cast<size_t>(phone) < static_cast<size_t>(num_phones) &&
KALDI_ASSERT(static_cast<size_t>(phone) < static_cast<size_t>(num_phones_) &&
static_cast<size_t>(history_state) < history_state_info_.size());
return history_state * (num_phones_ + 1) + phone;
}
......@@ -79,7 +80,7 @@ void CctcTransitionModel::Check() const {
// see blank should not change the history state.
KALDI_ASSERT(info.next_history_state[0] == h);
for (int32 p = 1; p < num_phones; p++) {
int32 next_h = info.next_history_state[i];
int32 next_h = info.next_history_state[p];
KALDI_ASSERT(next_h >= 0 && next_h < num_histories);
}
// output-index if we predict blank should be after the
......@@ -114,7 +115,8 @@ void CctcTransitionModel::Check() const {
// we always get to the same history state regardless of where we started.
int32 num_test = 50;
for (int32 i = 0; i < num_test; i++) {
int32 h1 = RandInt(0, histories - 1), h2 = RandInt(0, histories - 1);
int32 h1 = RandInt(0, num_histories - 1),
h2 = RandInt(0, num_histories - 1);
for (int32 n = 0; n < phone_left_context_; n++) {
int32 p = RandInt(1, num_phones); // a real phone.
h1 = history_state_info_[h1].next_history_state[p];
......@@ -122,13 +124,88 @@ void CctcTransitionModel::Check() const {
}
KALDI_ASSERT(h1 == h2 && "Test of phone_left_context_ failed.");
}
}
void CctcTransitionModel::ComputeWeights() {
void CctcTransitionModel::Write(std::ostream &os, bool binary) const {
Check();
WriteToken(os, binary, "<CctcTransitionModel>");
if (!binary) os << "\n";
WriteToken(os, binary, "<NumPhones>");
WriteBasicType(os, binary, num_phones_);
if (!binary) os << "\n";
WriteToken(os, binary, "<PhoneLeftContext>");
WriteBasicType(os, binary, phone_left_context_);
if (!binary) os << "\n";
WriteToken(os, binary, "<NumOutputIndexes>");
WriteBasicType(os, binary, num_output_indexes_);
if (!binary) os << "\n";
WriteToken(os, binary, "<NumNonBlankIndexes>");
WriteBasicType(os, binary, num_non_blank_indexes_);
if (!binary) os << "\n";
WriteToken(os, binary, "<InitialHistoryState>");
WriteBasicType(os, binary, initial_history_state_);
if (!binary) os << "\n";
WriteToken(os, binary, "<NumHistoryStates>");
int32 num_history_states = history_state_info_.size();
WriteBasicType(os, binary, num_history_states);
if (!binary) os << "\n";
WriteToken(os, binary, "<HistoryStates>");
for (int32 h = 0; h < num_history_states; h++) {
const HistoryStateInfo &info = history_state_info_[h];
WriteIntegerVector(os, binary, info.next_history_state);
if (!binary) os << "\n";
WriteIntegerVector(os, binary, info.output_index);
if (!binary) os << "\n";
info.phone_lm_prob.Write(os, binary);
}
WriteToken(os, binary, "</CctcTransitionModel>");
}
void CctcTransitionModel::Read(std::istream &is, bool binary) {
ExpectToken(is, binary, "<CctcTransitionModel>");
ExpectToken(is, binary, "<NumPhones>");
ReadBasicType(is, binary, &num_phones_);
ExpectToken(is, binary, "<PhoneLeftContext>");
ReadBasicType(is, binary, &phone_left_context_);
ExpectToken(is, binary, "<NumOutputIndexes>");
ReadBasicType(is, binary, &num_output_indexes_);
ExpectToken(is, binary, "<NumNonBlankIndexes>");
ReadBasicType(is, binary, &num_non_blank_indexes_);
ExpectToken(is, binary, "<InitialHistoryState>");
ReadBasicType(is, binary, &initial_history_state_);
ExpectToken(is, binary, "<NumHistoryStates>");
int32 num_history_states = history_state_info_.size();
ReadBasicType(is, binary, &num_history_states);
ExpectToken(is, binary, "<HistoryStates>");
for (int32 h = 0; h < num_history_states; h++) {
HistoryStateInfo &info = history_state_info_[h];
ReadIntegerVector(is, binary, &info.next_history_state);
ReadIntegerVector(is, binary, &info.output_index);
info.phone_lm_prob.Read(is, binary);
}
ExpectToken(is, binary, "</CctcTransitionModel>");
Check();
ComputeWeights();
}
// Recomputes the derived member weights_ from history_state_info_.
// weights_ has one row per history state and one column per output index;
// element (h, i) is the sum of the language-model probabilities, in history
// state h, of every symbol p whose output index is i.
void CctcTransitionModel::ComputeWeights() {
  int32 num_history_states = history_state_info_.size(),
      num_output_indexes = num_output_indexes_,
      num_phones = num_phones_;
  // NOTE(review): the accumulation below relies on the Matrix constructor
  // zero-initializing its contents -- confirm against the Matrix class.
  Matrix<BaseFloat> weights(num_history_states,
                            num_output_indexes);
  for (int32 h = 0; h < num_history_states; h++) {
    const HistoryStateInfo &info = history_state_info_[h];
    SubVector<BaseFloat> row(weights, h);
    // Symbol 0 is the blank and real phones are 1..num_phones, so the
    // per-state tables hold num_phones + 1 entries and the loop bound must be
    // inclusive; stopping at p < num_phones would drop the last phone.
    for (int32 p = 0; p <= num_phones; p++) {
      int32 output_index = info.output_index[p];
      BaseFloat lm_prob = info.phone_lm_prob(p);
      row(output_index) += lm_prob;
    }
  }
  // Install the result without copying the matrix data.
  weights_.Swap(&weights);
}
void CctcTransitionModelCreator::InitCctcTransitionModel(
......@@ -139,11 +216,12 @@ void CctcTransitionModelCreator::InitCctcTransitionModel(
num_output_indexes_ = num_tree_leaves_ + lm_hist_state_map_.NumLmHistoryStates();
KALDI_LOG << "There are " << num_output_indexes_ << " output indexes, = "
<< num_tree_leaves_ << " for non-blank, and "
<< lm_hist_state_map_.NumLmHistoryStates() << " for blank."
<< lm_hist_state_map_.NumLmHistoryStates() << " for blank.";
GetInitialHistoryStates();
while (MergeHistoryStatesOnePass());
OutputToTransitionModel(model);
model->Check();
// Compute the model weights_, just in case they might be needed.
model->ComputeWeights();
}
......@@ -157,7 +235,7 @@ void CctcTransitionModelCreator::GetInitialHistories(SetType *hist_set) const {
int32 tree_left_context = ctx_dep_.ContextWidth() - 1;
std::queue<std::vector<int32> > hist_state_queue;
std::vector<std::vector<int32> > hist_state_queue;
for (int32 i = 0; i < lm_hist_state_map_.NumLmHistoryStates(); i++) {
const std::vector<int32> &hist = lm_hist_state_map_.GetHistoryForState(i);
......@@ -238,7 +316,7 @@ int32 CctcTransitionModelCreator::GetOutputIndex(
}
void CctcTransitionModelCreator:GetInitialHistoryStates() {
void CctcTransitionModelCreator::GetInitialHistoryStates() {
SetType hist_set;
GetInitialHistories(&hist_set);
KALDI_ASSERT(!hist_set.empty());
......@@ -276,7 +354,7 @@ void CctcTransitionModelCreator:GetInitialHistoryStates() {
// the decision tree).
int32 tree_left_context = ctx_dep_.ContextWidth() - 1;
std::vector<int32> sentence_start_hist(tree_left_context, 0);
MapType::iter iter;
MapType::iterator iter;
if ((iter = hist_to_state.find(sentence_start_hist)) == hist_to_state.end())
KALDI_ERR << "Cannot find history state for beginning of sentence.";
initial_history_state_ = iter->second;
......@@ -284,8 +362,9 @@ void CctcTransitionModelCreator:GetInitialHistoryStates() {
}
void CctcTransitionModelCreator::CreateHistoryInfo(const std::vector<int32> &hist_vec,
const MapType &hist_to_state) {
void CctcTransitionModelCreator::CreateHistoryInfo(
const std::vector<std::vector<int32> > &hist_vec,
const MapType &hist_to_state) {
int32 num_histories = hist_vec.size(), // before merging.
num_phones = phone_lang_model_.VocabSize(),
tree_left_context = ctx_dep_.ContextWidth() - 1;
......@@ -295,7 +374,7 @@ void CctcTransitionModelCreator::CreateHistoryInfo(const std::vector<int32> &his
for (int32 h = 0; h < num_histories; h++) {
const std::vector<int32> &hist = hist_vec[h];
HistoryState &state = history_states_[h];
state.lm_history_state = hist_state_map_.GetLmHistoryState(hist);
state.lm_history_state = lm_hist_state_map_.GetLmHistoryState(hist);
state.output_index.resize(num_phones + 1);
state.next_history_state.resize(num_phones + 1);
KALDI_ASSERT(hist.size() >= static_cast<size_t>(tree_left_context));
......@@ -319,7 +398,7 @@ void CctcTransitionModelCreator::CreateHistoryInfo(const std::vector<int32> &his
}
}
bool MergeHistoryStatesOnePass() {
bool CctcTransitionModelCreator::MergeHistoryStatesOnePass() {
int32 num_history_states = history_states_.size();
std::vector<int32> old2new_history_state(num_history_states);
......@@ -329,7 +408,7 @@ bool MergeHistoryStatesOnePass() {
HistoryMapType hist_to_new;
for (int32 h = 0; h < num_history_states; h++) {
const HistoryState *hist_state = history_states_[h];
const HistoryState *hist_state = &(history_states_[h]);
std::pair<const HistoryState*, int32> pair_to_insert(hist_state,
new_num_history_states);
std::pair<HistoryMapType::iterator, bool>
......@@ -385,9 +464,8 @@ void CctcTransitionModelCreator::OutputToTransitionModel(
// first clear some stuff, just in case.
trans_model->weights_.Resize(0, 0);
trans_model->history_state_info_.clear();
trans_model->graph_label_info_.clear();
int32 num_histories = hist_vec.size(), // before merging.
int32 num_histories = history_states_.size(),
num_phones = phone_lang_model_.VocabSize(),
ngram_order = phone_lang_model_.NgramOrder(),
tree_context_width = ctx_dep_.ContextWidth();
......@@ -399,7 +477,7 @@ void CctcTransitionModelCreator::OutputToTransitionModel(
trans_model->num_non_blank_indexes_ = num_tree_leaves_;
trans_model->initial_history_state_ = initial_history_state_;
KALDI_ASSERT(initial_history_state_ < num_histories);
trans_model_->history_state_info_.resize(num_histories);
trans_model->history_state_info_.resize(num_histories);
for (int32 h = 0; h < num_histories; h++)
OutputHistoryState(h, trans_model);
};
......@@ -427,9 +505,8 @@ void CctcTransitionModelCreator::OutputHistoryState(
int32 lm_history_state = src.lm_history_state;
for (int32 p = 0; p <= num_phones; p++)
info.phone_lm_prob(p) = hist_state_map_.GetProb(phone_lang_model_,
lm_history_state,
p);
info.phone_lm_prob(p) = lm_hist_state_map_.GetProb(phone_lang_model_,
lm_history_state, p);
// language model should sum to one over its output space.
KALDI_ASSERT(fabs(1.0 - info.phone_lm_prob.Sum()) < 1.001);
// eos_prob is the probability of the end-of-sequence/end-of-sentence symbol.
......
......@@ -28,7 +28,10 @@
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fstext/fstext-lib.h"
#include "tree/context-dep.h"
#include "lat/kaldi-lattice.h"
#include "cudamatrix/cu-matrix.h"
#include "ctc/language-model.h"
namespace kaldi {
namespace ctc {
......@@ -189,7 +192,7 @@ class CctcTransitionModel {
int32 NumNonBlankIndexes() const { return num_non_blank_indexes_; }
// return the number of history-states the model contains.
int32 NumHistoryStates() { return history_state_info_.size(); }
int32 NumHistoryStates() const { return history_state_info_.size(); }
// return the number of phones. Phones are one-based, so NumPhones() is the
// index of the largest phone, but phone 0 is used to mean the blank symbol.
......@@ -242,6 +245,10 @@ class CctcTransitionModel {
// expression for the likelihood of this phone (or blank).
int32 GraphLabelToOutputIndex(int32 graph_label) const;
void Write(std::ostream &os, bool binary) const;
void Read(std::istream &is, bool binary);
protected:
// Check that the contents of this object make sense (does not check
// weights_, which is a derived variable).
......@@ -325,7 +332,9 @@ class CctcTransitionModelCreator {
void GetInitialHistories(SetType *hist_set) const;
// called from GetInitialHistoryStates(). Writes to history_state_info_.
void CreateHistoryInfo(const std::vector<int32> &hist_vec,
// Input is a vector of histories represented as phone left-context
// histories, each of length at least equal to the decision-tree left context.
void CreateHistoryInfo(const std::vector<std::vector<int32> > &hist_vec,
const MapType &hist_to_state);
// writes to history_state_info_ the initial history states, pre-merging.
......@@ -373,27 +382,26 @@ class CctcTransitionModelCreator {
// merging.
std::vector<int32> next_history_state;
bool operator == (const HistoryState &other) const {
return lm_history_state == other.lm_history_state &&
output_index == other.output_index &&
next_history_state == other.next_history_state;
}
};
// hashing object that hashes struct HistoryState (from a pointer).
struct HistoryStateHasher {
size_t operator () (const HistoryState *const hist_info) {
VectorHasher vec_hasher;
int32 p1 = 31;
size_t operator () (const HistoryState *const hist_info) const {
VectorHasher<int32> vec_hasher;
size_t p1 = 31;
return p1 * hist_info->lm_history_state +
vec_hasher(hist_info->output_index) +
vec_hasher(hist_info->next_history_state);
}
};
struct HistoryStateEqual {
size_t operator () (const HistoryState *const hist_info1,
const HistoryState *const hist_info2) {
return *hist_info1 == *hist_info2;
}
};
typedef unordered_map<const HistoryState*, int32,
HistoryStateHasher, HistoryStateEqual> HistoryMapType;
HistoryStateHasher> HistoryMapType;
const ContextDependency &ctx_dep_;
......@@ -411,8 +419,6 @@ class CctcTransitionModelCreator {
int32 initial_history_state_;
std::vector<HistoryState> history_states_;
};
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment