Skip to content
Snippets Groups Projects
Commit bba0b6b7 authored by Dan Povey
Browse files

trunk: Fix to bug in arpa2fst that sometimes caused wrong LMs when words...

trunk: Fix to bug in arpa2fst that sometimes caused wrong LMs when words contained underscores; speed up arpa2fst by using better string hasher.

git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@3000 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent 2bd27c09
No related branches found
No related tags found
No related merge requests found
......@@ -30,7 +30,7 @@
namespace kaldi {
// add the string contained in inpline to the current transducer
// starting at initial state
StateId LangModelFst::ReadTxtLine(const string &inpline) {
LangModelFst::StateId LangModelFst::ReadTxtLine(const string &inpline) {
KALDI_ASSERT(pfst_);
KALDI_ASSERT(pfst_->InputSymbols());
KALDI_ASSERT(pfst_->OutputSymbols());
......
......@@ -58,8 +58,11 @@ enum GrammarType {
/// LangModelFst is a standard vector FST that also provides
/// Read() and Write() functions for file-based language models
/// or text files defining strings and grammars.
class LangModelFst : public fst::VectorFst<fst::StdArc> {
class LangModelFst: public fst::VectorFst<fst::StdArc> {
public:
typedef fst::StdArc::Weight LmWeight;
typedef fst::StdArc::StateId StateId;
LangModelFst() {
pfst_ = new fst::VectorFst<fst::StdArc>;
......
......@@ -32,22 +32,23 @@ namespace kaldi {
// typedef fst::StdArc::StateId StateId;
// newlyAdded will be updated
StateId LmFstConverter::AddStateFromSymb(
const std::vector<string> &ngramString,
int kstart, int kend,
const char *sep,
fst::StdVectorFst *pfst,
fst::SymbolTable *psst,
bool &newlyAdded) {
LmFstConverter::StateId LmFstConverter::AddStateFromSymb(
const std::vector<string> &ngramString,
int kstart, int kend,
fst::StdVectorFst *pfst,
bool &newlyAdded) {
fst::StdArc::StateId sid;
std::string separator;
separator.resize(1);
separator[0] = '\0';
std::string hist;
if (kstart == 0) {
hist.append(sep);
hist.append(separator);
} else {
for (int k = kstart; k >= kend; k--) {
hist.append(ngramString[k]);
hist.append(sep);
hist.append(separator);
}
}
......@@ -81,7 +82,7 @@ void LmFstConverter::ConnectUnusedStates(fst::StdVectorFst *pfst) {
connected++;
}
}
cerr << "Connected "<<connected<<" states without outgoing arcs."<<endl;
cerr << "Connected " << connected << " states without outgoing arcs." << endl;
}
void LmFstConverter::AddArcsForNgramProb(
......@@ -89,8 +90,8 @@ void LmFstConverter::AddArcsForNgramProb(
float logProb,
float logBow,
std::vector<string> &ngs,
fst::StdVectorFst *pfst,
fst::SymbolTable *psst,
fst::StdVectorFst *fst,
fst::SymbolTable *symtab,
const string startSent,
const string endSent) {
fst::StdArc::StateId src, dst, dbo;
......@@ -102,72 +103,72 @@ void LmFstConverter::AddArcsForNgramProb(
if (ilev >= 2) {
// General case works from N down to 2-grams
src = AddStateFromSymb(ngs, ilev, 2, "_", pfst, psst, newSrc);
src = AddStateFromSymb(ngs, ilev, 2, fst, newSrc);
if (ilev != maxlev) {
// add all intermediate levels from 2 to current
// last ones will be current backoff source and destination
for (int iilev=2; iilev <= ilev; iilev++) {
dst = AddStateFromSymb(ngs, iilev, 1, "_", pfst, psst, newDst);
dbo = AddStateFromSymb(ngs, iilev-1, 1, "_", pfst, psst, newDbo);
dst = AddStateFromSymb(ngs, iilev, 1, fst, newDst);
dbo = AddStateFromSymb(ngs, iilev-1, 1, fst, newDbo);
bkState_[dst] = dbo;
}
} else {
// add all intermediate levels from 2 to current
// last ones will be current backoff source and destination
for (int iilev=2; iilev <= ilev; iilev++) {
dst = AddStateFromSymb(ngs, iilev-1, 1, "_", pfst, psst, newDst);
dbo = AddStateFromSymb(ngs, iilev-2, 1, "_", pfst, psst, newDbo);
dst = AddStateFromSymb(ngs, iilev-1, 1, fst, newDst);
dbo = AddStateFromSymb(ngs, iilev-2, 1, fst, newDbo);
bkState_[dst] = dbo;
}
}
} else {
// special case for 1-grams: start from 0-gram
if (curwrd.compare(startSent) != 0) {
src = AddStateFromSymb(ngs, 0, 1, "_", pfst, psst, newSrc);
src = AddStateFromSymb(ngs, 0, 1, fst, newSrc);
} else {
// extra special case if in addition we are at beginning of sentence
// starts from initial state and has no cost
src = pfst->Start();
src = fst->Start();
prob = fst::StdArc::Weight::One();
}
dst = AddStateFromSymb(ngs, 1, 1, "_", pfst, psst, newDst);
dbo = AddStateFromSymb(ngs, 0, 1, "_", pfst, psst, newDbo);
dst = AddStateFromSymb(ngs, 1, 1, fst, newDst);
dbo = AddStateFromSymb(ngs, 0, 1, fst, newDbo);
bkState_[dst] = dbo;
}
// state is final if last word is end of sentence
if (curwrd.compare(endSent) == 0) {
pfst->SetFinal(dst, fst::StdArc::Weight::One());
fst->SetFinal(dst, fst::StdArc::Weight::One());
}
// add labels to symbol tables
ilab = pfst->MutableInputSymbols()->AddSymbol(curwrd);
olab = pfst->MutableOutputSymbols()->AddSymbol(curwrd);
ilab = fst->MutableInputSymbols()->AddSymbol(curwrd);
olab = fst->MutableOutputSymbols()->AddSymbol(curwrd);
// add arc with weight "prob" between source and destination states
// cerr << "n-gram prob, fstAddArc: src "<< src << " dst " << dst;
// cerr << " lab " << ilab << endl;
pfst->AddArc(src, fst::StdArc(ilab, olab, prob, dst));
fst->AddArc(src, fst::StdArc(ilab, olab, prob, dst));
// add backoffs to any newly created destination state
// but only if non-final
if (!IsFinal(pfst, dst) && newDst && dbo != dst) {
if (!IsFinal(fst, dst) && newDst && dbo != dst) {
ilab = olab = 0;
// cerr << "backoff, fstAddArc: src "<< src << " dst " << dst;
// cerr << " lab " << ilab << endl;
pfst->AddArc(dst, fst::StdArc(ilab, olab, bow, dbo));
fst->AddArc(dst, fst::StdArc(ilab, olab, bow, dbo));
}
}
#ifndef HAVE_IRSTLM
bool LmTable::ReadFstFromLmFile(std::istream &istrm,
fst::StdVectorFst *pfst,
fst::StdVectorFst *fst,
bool useNaturalOpt,
const string startSent,
const string endSent) {
#ifdef KALDI_PARANOID
KALDI_ASSERT(pfst);
KALDI_ASSERT(pfst->InputSymbols() && pfst->OutputSymbols());
KALDI_ASSERT(fst);
KALDI_ASSERT(fst->InputSymbols() && fst->OutputSymbols());
#endif
conv_->UseNaturalLog(useNaturalOpt);
......@@ -294,12 +295,12 @@ bool LmTable::ReadFstFromLmFile(std::istream &istrm,
}
}
conv_->AddArcsForNgramProb(ilev, maxlev, prob, bow,
ngramString, pfst,
ngramString, fst,
pStateSymbs, startSent, endSent);
} // end of loop on individual n-gram lines
}
conv_->ConnectUnusedStates(pfst);
conv_->ConnectUnusedStates(fst);
// not used anymore: delete pStateSymbs;
......@@ -312,7 +313,7 @@ bool LmTable::ReadFstFromLmFile(std::istream &istrm,
// #ifdef HAVE_IRSTLM implementation
bool LmTable::ReadFstFromLmFile(std::istream &istrm,
fst::StdVectorFst *pfst,
fst::StdVectorFst *fst,
bool useNaturalOpt,
const string startSent,
const string endSent) {
......@@ -320,7 +321,7 @@ bool LmTable::ReadFstFromLmFile(std::istream &istrm,
ngram ng(this->getDict(), 0);
conv_->UseNaturalLog(useNaturalOpt);
DumpStart(ng, pfst, startSent, endSent);
DumpStart(ng, fst, startSent, endSent);
// should do some check before returning true
return true;
......@@ -328,12 +329,12 @@ bool LmTable::ReadFstFromLmFile(std::istream &istrm,
// run through all nodes in table (as in dumplm)
void LmTable::DumpStart(ngram ng,
fst::StdVectorFst *pfst,
fst::StdVectorFst *fst,
const string startSent,
const string endSent) {
#ifdef KALDI_PARANOID
KALDI_ASSERT(pfst);
KALDI_ASSERT(pfst->InputSymbols() && pfst->OutputSymbols());
KALDI_ASSERT(fst);
KALDI_ASSERT(fst->InputSymbols() && fst->OutputSymbols());
#endif
// we need a state symbol table while traversing word contexts
fst::SymbolTable *pStateSymbs = new fst::SymbolTable("kaldi-lm-state");
......@@ -343,7 +344,7 @@ void LmTable::DumpStart(ngram ng,
ng.size = 0;
cerr << "Processing " << l << "-grams" << endl;
DumpContinue(ng, 1, l, 0, cursize[1],
pfst, pStateSymbs, startSent, endSent);
fst, pStateSymbs, startSent, endSent);
}
delete pStateSymbs;
......@@ -353,7 +354,7 @@ void LmTable::DumpStart(ngram ng,
// run through given levels and positions in table
void LmTable::DumpContinue(ngram ng, int ilev, int elev,
table_entry_pos_t ipos, table_entry_pos_t epos,
fst::StdVectorFst *pfst,
fst::StdVectorFst *fst,
fst::SymbolTable *pStateSymbs,
const string startSent, const string endSent) {
LMT_TYPE ndt = tbltype[ilev];
......@@ -384,7 +385,7 @@ void LmTable::DumpContinue(ngram ng, int ilev, int elev,
(table_pos_t) i * ndsz, ndt);
if (isucc < esucc) // there are successors!
DumpContinue(ng, ilev+1, elev, isucc, esucc,
pfst, pStateSymbs, startSent, endSent);
fst, pStateSymbs, startSent, endSent);
// else
// cerr << "no successors for " << ng << "\n";
} else {
......@@ -418,7 +419,7 @@ void LmTable::DumpContinue(ngram ng, int ilev, int elev,
// else if (ibo != 0.0) cerr << "\t" << ibo;
}
conv_->AddArcsForNgramProb(ilev, maxlev, ipr, ibo,
ngramString, pfst, pStateSymbs,
ngramString, fst, pStateSymbs,
startSent, endSent);
}
}
......
......@@ -44,6 +44,7 @@ using std::tr1::unordered_map;
#include "fst/fst-decl.h"
#include "fst/arc.h"
#include "base/kaldi-common.h"
#include "util/stl-utils.h"
#ifdef _MSC_VER
# define STRTOF(cur_cstr, end_cstr) static_cast<float>(strtod(cur_cstr, end_cstr));
......@@ -64,13 +65,14 @@ namespace kaldi {
* does not require an external library.
*/
typedef fst::StdArc::Weight LmWeight;
typedef fst::StdArc::StateId StateId;
/// @brief Helper methods to convert toolkit internal representations into FST.
class LmFstConverter {
typedef fst::StdArc::Weight LmWeight;
typedef fst::StdArc::StateId StateId;
typedef unordered_map<StateId, StateId> BkStateMap;
typedef unordered_map<std::string, StateId> HistStateMap;
typedef unordered_map<std::string, StateId, StringHasher> HistStateMap;
public:
......@@ -111,9 +113,7 @@ class LmFstConverter {
StateId AddStateFromSymb(const std::vector<string> &ngramString,
int kstart,
int kend,
const char *sep,
fst::StdVectorFst *pfst,
fst::SymbolTable *psst,
bool &newlyAdded);
StateId FindState(const std::string str) {
......@@ -181,7 +181,7 @@ class LmTable : public lmtable {
void DumpContinue(ngram ng,
int ilev, int elev,
table_entry_pos_t ipos, table_entry_pos_t epos,
fst::StdVectorFst *pfst, fst::SymbolTable *pStateSymbs,
fst::StdVectorFst *pfst,
const string startSent, const string endSent);
LmFstConverter *conv_;
......
......@@ -40,10 +40,10 @@ namespace kaldi {
#define MAX_SENTENCE_LENGTH 1000
/// @brief Recursively prints all complete paths starting at s and their score.
static LmWeight PrintCompletePath(fst::SymbolTable *pst,
fst::StdVectorFst *pfst,
fst::StdArc::StateId s,
LmWeight score) {
static LangModelFst::LmWeight PrintCompletePath(fst::SymbolTable *pst,
fst::StdVectorFst *pfst,
fst::StdArc::StateId s,
LangModelFst::LmWeight score) {
fst::ArcIterator<fst::StdVectorFst> ai(*pfst, s);
for (ai.Reset(); !ai.Done(); ai.Next()) {
std::cout << pst->Find(ai.Value().ilabel) << " ";
......@@ -72,8 +72,8 @@ static LmWeight PrintCompletePath(fst::SymbolTable *pst,
}
/// @brief Recursively prints all complete paths starting from initial state.
static LmWeight PrintCompletePaths(fst::SymbolTable *pst,
fst::StdVectorFst *pfst) {
static LangModelFst::LmWeight PrintCompletePaths(fst::SymbolTable *pst,
fst::StdVectorFst *pfst) {
KALDI_ASSERT(pst);
KALDI_ASSERT(pfst);
KALDI_ASSERT(pfst->Start() >=0);
......@@ -239,7 +239,7 @@ bool TestLmTableEvalScore(const string &inpfile,
// read in reference score
std::ifstream strm(refScoreFile.c_str(), std::ifstream::in);
LmWeight refScore;
LangModelFst::LmWeight refScore;
strm >> refScore;
std::cout << "Reference score is " << refScore << '\n';
......@@ -278,9 +278,9 @@ bool TestLmTableEvalScore(const string &inpfile,
fst::ShortestPath(composedFst, bestFst, 1);
std::cout << "Best path has " << bestFst->NumStates() << " states" << '\n';
LmWeight testScore = PrintCompletePaths(
bestFst->MutableInputSymbols(),
bestFst);
LangModelFst::LmWeight testScore = PrintCompletePaths(
bestFst->MutableInputSymbols(),
bestFst);
std::cout << "Complete path score is " << testScore << '\n';
if (testScore.Value() <= refScore.Value()) {
......
......@@ -254,9 +254,9 @@ struct PairHasher { // hashing function for pair<int>
/// A hashing function object for strings.
struct StringHasher { // hashing function for std::string
size_t operator()(const std::string &str) const {
size_t ans = 0;
const char *c = str.c_str();
for (; *c != '\0'; c++) {
size_t ans = 0, len = str.length();
const char *c = str.c_str(), *end = c + len;
for (; c != end; c++) {
ans *= kPrime;
ans += *c;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment