Commit 07ccd077 authored by Dan Povey

trunk: merging some minor fixes/improvements from sandbox/dan, including speed improvement to lattice rescoring code.

git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@4426 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent 944ff9ce
gmm/diag-gmm.cc
@@ -4,7 +4,7 @@
 // Saarland University (Author: Arnab Ghoshal);
 // Georg Stemmer; Jan Silovsky
 // 2012 Arnab Ghoshal
-// 2013 Johns Hopkins University (author: Daniel Povey)
+// 2013-2014 Johns Hopkins University (author: Daniel Povey)
 // See ../../COPYING for clarification regarding multiple authors
 //
@@ -941,4 +941,28 @@ void DiagGmm::Generate(VectorBase<BaseFloat> *output) {
   }
 }
 
+DiagGmm::DiagGmm(const GaussClusterable &gc,
+                 BaseFloat var_floor): valid_gconsts_(false) {
+  Vector<BaseFloat> x (gc.x_stats());
+  Vector<BaseFloat> x2 (gc.x2_stats());
+  BaseFloat count = gc.count();
+  KALDI_ASSERT(count > 0.0);
+  this->Resize(1, x.Dim());
+  x.Scale(1.0/count);
+  x2.Scale(1.0/count);
+  x2.AddVec2(-1.0, x);  // subtract mean^2.
+  x2.ApplyFloor(var_floor);
+  x2.InvertElements();  // get inv-var.
+  KALDI_ASSERT(x2.Min() > 0);
+  Matrix<BaseFloat> mean(1, x.Dim());
+  mean.Row(0).CopyFromVec(x);
+  Matrix<BaseFloat> inv_var(1, x.Dim());
+  inv_var.Row(0).CopyFromVec(x2);
+  this->SetInvVarsAndMeans(inv_var, mean);
+  Vector<BaseFloat> weights(1);
+  weights(0) = 1.0;
+  this->SetWeights(weights);
+  this->ComputeGconsts();
+}
+
 }  // End namespace kaldi
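
Note (not part of the patch): a usage sketch of the constructor added above. Only DiagGmm(const GaussClusterable&, BaseFloat) itself comes from this commit; the feature matrix, the unit per-frame weights and the var_floor value of 0.01 are illustrative assumptions.

#include "gmm/diag-gmm.h"
#include "tree/clusterable-classes.h"

// Build a single-Gaussian DiagGmm from accumulated count/x/x^2 stats.
void SingleGaussianFromStats(const kaldi::Matrix<kaldi::BaseFloat> &feats) {
  using namespace kaldi;
  BaseFloat var_floor = 0.01;  // illustrative value; the real tools read it from options.
  GaussClusterable gc(feats.NumCols(), var_floor);
  for (int32 t = 0; t < feats.NumRows(); t++)
    gc.AddStats(feats.Row(t), 1.0);  // accumulate count, x and x^2 stats per frame.
  DiagGmm gmm(gc, var_floor);        // mean = x/count, var = max(x2/count - mean^2, floor).
  KALDI_ASSERT(gmm.NumGauss() == 1 && gmm.Dim() == feats.NumCols());
}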
gmm/diag-gmm.h
@@ -4,7 +4,7 @@
 // Saarland University (Author: Arnab Ghoshal);
 // Georg Stemmer; Jan Silovsky
 // 2012 Arnab Ghoshal
-// 2013 Johns Hopkins University (author: Daniel Povey)
+// 2013-2014 Johns Hopkins University (author: Daniel Povey)
 // See ../../COPYING for clarification regarding multiple authors
 //
@@ -31,6 +31,7 @@
 #include "gmm/model-common.h"
 #include "matrix/matrix-lib.h"
 #include "tree/cluster-utils.h"
+#include "tree/clusterable-classes.h"
 
 namespace kaldi {
@@ -51,6 +52,10 @@ class DiagGmm {
     CopyFromDiagGmm(gmm);
   }
 
+  /// Initializer from GaussClusterable initializes the DiagGmm as
+  /// a single Gaussian from tree stats.
+  DiagGmm(const GaussClusterable &gc, BaseFloat var_floor);
+
   /// Copies from DiagGmmNormal; does not resize.
   void CopyFromNormal(const DiagGmmNormal &diag_gmm_normal);
......
gmmbin/gmm-init-model.cc
@@ -71,11 +71,9 @@ void InitAmGmm(const BuildTreeStatsType &stats,
   for (size_t i = 0; i < summed_stats.size(); i++) {
     GaussClusterable *c =
         static_cast<GaussClusterable*>(summed_stats[i] != NULL ? summed_stats[i] : avg_stats);
-    DiagGmm gmm;
-    Vector<BaseFloat> x (c->x_stats());
-    Vector<BaseFloat> x2 (c->x2_stats());
+    DiagGmm gmm(*c, var_floor);
+    am_gmm->AddPdf(gmm);
     BaseFloat count = c->count();
-    gmm.Resize(1, x.Dim());
     if (count < 100) {
       std::vector<int32> bad_pdfs(1, i), bad_phones;
       GetPhonesForPdfs(trans_model, bad_pdfs, &bad_phones);
@@ -85,24 +83,6 @@
       KALDI_WARN << "Very small count for state " << i << ": "
                  << count << "; corresponding phone list: " << ss.str();
     }
-    x.Scale(1.0/count);
-    x2.Scale(1.0/count);
-    x2.AddVec2(-1.0, x);  // subtract mean^2.
-    x2.ApplyFloor(var_floor);
-    x2.InvertElements();  // get inv-var.
-    KALDI_ASSERT(x2.Min() > 0);
-    Matrix<BaseFloat> mean(1, x.Dim());
-    mean.Row(0).CopyFromVec(x);
-    Matrix<BaseFloat> inv_var(1, x.Dim());
-    inv_var.Row(0).CopyFromVec(x2);
-    gmm.SetInvVarsAndMeans(inv_var, mean);
-    Vector<BaseFloat> weights(1);
-    weights(0) = 1.0;
-    gmm.SetWeights(weights);
-    gmm.ComputeGconsts();
-    am_gmm->AddPdf(gmm);
   }
   DeletePointers(&summed_stats);
   delete avg_stats;
@@ -139,6 +119,7 @@ void InitAmGmmFromOld(const BuildTreeStatsType &stats,
                       int32 P,  // central-position
                       const std::string &old_tree_rxfilename,
                       const std::string &old_model_rxfilename,
+                      BaseFloat var_floor,
                       AmDiagGmm *am_gmm) {
   AmDiagGmm old_am_gmm;
@@ -161,15 +142,30 @@ void InitAmGmmFromOld(const BuildTreeStatsType &stats,
   std::vector<BuildTreeStatsType> split_stats;
   SplitStatsByMap(stats, to_pdf_map, &split_stats);
   // Make sure each leaf has stats.
-  for (size_t i = 0; i < split_stats.size(); i++)
-    KALDI_ASSERT(! split_stats[i].empty() && "Tree has leaves with no stats."
-                 " Modify your roots file as necessary to fix this.");
-  KALDI_ASSERT(static_cast<int32>(split_stats.size()-1) == to_pdf_map.MaxResult()
-               && "Tree may have final leaf with no stats. "
-               "Modify your roots file as necessary to fix this.");
+  for (size_t i = 0; i < split_stats.size(); i++) {
+    if (split_stats[i].empty()) {
+      KALDI_WARN << "Leaf " << i << " of new tree has no stats.";
+    }
+  }
+  if (static_cast<int32>(split_stats.size()) != to_pdf_map.MaxResult() + 1) {
+    KALDI_ASSERT(static_cast<int32>(split_stats.size()) <
+                 to_pdf_map.MaxResult() + 1);
+    KALDI_WARN << "Tree may have final leaf with no stats.";
+    split_stats.resize(to_pdf_map.MaxResult() + 1);
+    // avoid indexing errors later.
+  }
 
   int32 oldN = old_tree.ContextWidth(), oldP = old_tree.CentralPosition();
 
+  // avg_stats will be used for leaves that have no stats.
+  Clusterable *avg_stats = SumStats(stats);
+  GaussClusterable *avg_stats_gc = dynamic_cast<GaussClusterable*>(avg_stats);
+  KALDI_ASSERT(avg_stats_gc != NULL && "Empty stats input.");
+  DiagGmm avg_gmm(*avg_stats_gc, var_floor);
+  delete avg_stats;
+  avg_stats = NULL;
+  avg_stats_gc = NULL;
+
   const EventMap &old_map = old_tree.ToPdfMap();
   KALDI_ASSERT(am_gmm->NumPdfs() == 0);
@@ -188,7 +184,6 @@ void InitAmGmmFromOld(const BuildTreeStatsType &stats,
     // that align to this "new" pdf... we'll use it to work out the old pdf-id
     // that's "closest" in stats overlap to this new pdf ("pdf").
     std::map<int32, BaseFloat> oldpdf_to_count;
-    KALDI_ASSERT(!my_stats.empty());  // would be code error; checked already.
     for (size_t i = 0; i < my_stats.size(); i++) {
       EventType evec = my_stats[i].first;
       EventAnswerType ans;
@@ -199,7 +194,6 @@ void InitAmGmmFromOld(const BuildTreeStatsType &stats,
       if (oldpdf_to_count.count(ans) == 0) oldpdf_to_count[ans] = stats_count;
       else oldpdf_to_count[ans] += stats_count;
     }
-    KALDI_ASSERT(!oldpdf_to_count.empty());
     BaseFloat max_count = 0; int32 max_old_pdf = -1;
     for (std::map<int32, BaseFloat>::const_iterator iter = oldpdf_to_count.begin();
          iter != oldpdf_to_count.end();
@@ -209,11 +203,15 @@
         max_old_pdf = iter->first;
       }
     }
-    KALDI_ASSERT(max_count != 0 && max_old_pdf != -1);
-    am_gmm->AddPdf(old_am_gmm.GetPdf(max_old_pdf));  // Here is where we copy the relevant old PDF.
+    if (max_count == 0) {  // no overlap - probably a leaf with no stats at all.
+      KALDI_WARN << "Leaf " << pdf << " of new tree being initialized with "
+                 << "globally averaged stats.";
+      am_gmm->AddPdf(avg_gmm);
+    } else {
+      am_gmm->AddPdf(old_am_gmm.GetPdf(max_old_pdf));  // Here is where we copy the relevant old PDF.
+    }
   }
 }
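
Note (not part of the patch): the behavioral change above is that a new-tree leaf whose stats share no overlap with any old pdf no longer trips an assertion; it is initialized from the globally averaged stats (avg_gmm) instead. A hypothetical standalone restatement of the selection rule, with an illustrative name and signature:

#include <map>
#include "base/kaldi-common.h"

// Return the old pdf-id with the largest stats overlap, or -1 if there is no
// overlap at all, in which case the caller falls back to the averaged GMM.
kaldi::int32 ClosestOldPdf(
    const std::map<kaldi::int32, kaldi::BaseFloat> &oldpdf_to_count) {
  kaldi::BaseFloat max_count = 0;
  kaldi::int32 max_old_pdf = -1;
  for (std::map<kaldi::int32, kaldi::BaseFloat>::const_iterator iter =
           oldpdf_to_count.begin(); iter != oldpdf_to_count.end(); ++iter) {
    if (iter->second > max_count) {
      max_count = iter->second;
      max_old_pdf = iter->first;
    }
  }
  return max_old_pdf;
}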
@@ -289,6 +287,7 @@ int main(int argc, char *argv[]) {
                        ctx_dep.CentralPosition(),
                        old_tree_filename,
                        old_model_filename,
+                       var_floor,
                        &am_gmm);
     }
......
latbin/lattice-compose.cc
@@ -45,9 +45,13 @@ int main(int argc, char *argv[]) {
         " or: lattice-compose ark:1.lats G.fst ark:composed.lats\n";
 
     ParseOptions po(usage);
 
+    int32 num_states_cache = 50000;
     int32 phi_label = fst::kNoLabel;  // == -1
     po.Register("phi-label", &phi_label, "If >0, the label on backoff arcs of the LM");
+    po.Register("num-states-cache", &num_states_cache,
+                "Number of states we cache when mapping LM FST to lattice type. "
+                "More -> more memory but faster.");
 
     po.Read(argc, argv);
 
     if (po.NumArgs() != 3) {
@@ -81,9 +85,10 @@
     if (phi_label > 0)
       PropagateFinal(phi_label, fst2);
 
+    fst::CacheOptions cache_opts(true, num_states_cache);
     fst::StdToLatticeMapper<BaseFloat> mapper;
     fst::MapFst<StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> >
-        mapped_fst2(*fst2, mapper);
+        mapped_fst2(*fst2, mapper, cache_opts);
 
     for (; !lattice_reader1.Done(); lattice_reader1.Next()) {
       std::string key = lattice_reader1.Key();
       KALDI_VLOG(1) << "Processing lattice for key " << key;
......
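
Note (not part of the patch): the speed/memory-relevant change in lattice-compose is the explicit cache size passed to the delayed MapFst. A minimal sketch of the pattern, mirroring the lines above; the function name, the way the LM FST is obtained, and the includes are assumptions:

#include "base/kaldi-common.h"
#include "fstext/fstext-lib.h"
#include "lat/kaldi-lattice.h"

// Wrap a tropical-semiring LM FST as a delayed, lattice-semiring FST.
// CacheOptions(true, N) enables garbage collection of the expansion cache and
// bounds it at roughly N states, so repeatedly visited states are served from
// the cache while memory use stays limited.
void WrapLmFstForLattices(const fst::StdVectorFst &lm, kaldi::int32 num_states_cache) {
  using namespace kaldi;
  fst::CacheOptions cache_opts(true, num_states_cache);
  fst::StdToLatticeMapper<BaseFloat> mapper;
  fst::MapFst<fst::StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> >
      mapped_lm(lm, mapper, cache_opts);
  // mapped_lm is expanded on demand during composition with each lattice,
  // exactly as mapped_fst2 is used in the loop above.
}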
latbin/lattice-lmrescore.cc
@@ -43,8 +43,12 @@ int main(int argc, char *argv[]) {
     ParseOptions po(usage);
     BaseFloat lm_scale = 1.0;
+    int32 num_states_cache = 50000;
 
     po.Register("lm-scale", &lm_scale, "Scaling factor for language model costs; frequently 1.0 or -1.0");
+    po.Register("num-states-cache", &num_states_cache,
+                "Number of states we cache when mapping LM FST to lattice type. "
+                "More -> more memory but faster.");
 
     po.Read(argc, argv);
@@ -69,9 +73,10 @@
     // mapped_fst is the LM fst interpreted using the LatticeWeight semiring,
     // with all the cost on the first member of the pair (since it's a graph
     // weight).
+    fst::CacheOptions cache_opts(true, num_states_cache);
     fst::StdToLatticeMapper<BaseFloat> mapper;
     fst::MapFst<StdArc, LatticeArc, fst::StdToLatticeMapper<BaseFloat> >
-        lm_fst(*std_lm_fst, mapper);
+        lm_fst(*std_lm_fst, mapper, cache_opts);
     delete std_lm_fst;
 
     // The next fifteen or so lines are a kind of optimization and
......
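
Note (not part of the patch): the comment above describes what StdToLatticeMapper does to each weight. A tiny illustration, assuming the usual Kaldi typedefs from lat/kaldi-lattice.h: a tropical LM cost w becomes the LatticeWeight (w, 0), i.e. all cost on the graph half of the pair and none on the acoustic half.

#include "base/kaldi-common.h"
#include "lat/kaldi-lattice.h"

void LatticeWeightMappingExample() {
  using namespace kaldi;
  BaseFloat lm_cost = 2.5;          // cost of some arc in the LM FST.
  LatticeWeight lw(lm_cost, 0.0);   // (graph cost, acoustic cost), as the mapper produces.
  KALDI_ASSERT(lw.Value1() == 2.5 && lw.Value2() == 0.0);
}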
// bin/copy-matrix.cc
// online2bin/ivector-randomize.cc
// Copyright 2014 Johns Hopkins University (author: Daniel Povey)
......