Skip to content
Snippets Groups Projects
Commit 4df219da authored by Daniel Povey's avatar Daniel Povey Committed by Dan Povey
Browse files

Adding language modeling code (and test code for it) that is to be used as a...

Adding language modeling code (and test code for it) that is to be used as a phone language model in CTC modeling.
parent a40fecd2
No related branches found
No related tags found
No related merge requests found
......@@ -142,7 +142,7 @@ if [ $stage -le 3 ]; then
# words (including alternatives regarding optional silences).
# --lattice-beam=$beam keeps all the alternatives that were within the beam,
# it means we do no pruning of the lattice (lattices from a training transcription
# will be samll anyway).
# will be small anyway).
echo "$0: generating lattices containing alternate pronunciations."
$cmd JOB=1:$nj $dir/log/generate_lattices.JOB.log \
gmm-latgen-faster --acoustic-scale=$acoustic_scale --beam=$final_beam \
......
all:
include ../kaldi.mk
EXTRA_CXXFLAGS += -Wno-sign-compare
TESTFILES = language-model-test
OBJFILES = language-model.o # ctc-functions.o
LIBNAME = kaldi-ctc
ADDLIBS = ../hmm/kaldi-hmm.a ../tree/kaldi-tree.a ../matrix/kaldi-matrix.a \
../cudamatrix/kaldi-cudamatrix.a ../util/kaldi-util.a \
../base/kaldi-base.a
include ../makefiles/default_rules.mk
// ctc/language-model-test.cc
// Copyright 2015 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "ctc/language-model.h"
namespace kaldi {
namespace ctc {
void GetTestingData(int32 *vocab_size,
std::vector<std::vector<int32> > *data,
std::vector<std::vector<int32> > *validation_data) {
// read the code of a C++ file as training data.
bool binary;
Input input("language-model.cc", &binary);
KALDI_ASSERT(!binary);
std::istream &is = input.Stream();
std::string line;
*vocab_size = 255;
int32 line_count = 0;
for (; getline(is, line); line_count++) {
std::vector<int32> int_line(line.size());
for (size_t i = 0; i < line.size(); i++) {
int32 this_char = line[i];
if (this_char == 0) {
this_char = 1; // should never happen, but just make sure, as 0 is
// treated as BOS/EOS in the language modeling code.
}
int_line[i] = this_char;
}
if (line_count % 10 != 0)
data->push_back(int_line);
else
validation_data->push_back(int_line);
}
KALDI_ASSERT(line_count > 0);
}
void TestLmHistoryStateMap(const LanguageModel &lm) {
LmHistoryStateMap map;
map.Init(lm);
KALDI_LOG << "Number of history states is " << map.NumLmHistoryStates();
int32 vocab_size = lm.VocabSize();
int32 num_test = 500;
for (int32 i = 0; i < num_test; i++) {
int32 history_length = RandInt(0, lm.NgramOrder() - 1);
std::vector<int32> history(history_length);
// get a random history.
for (int32 i = 0; i < history_length; i++)
history[i] = RandInt(0, vocab_size);
int32 history_state = map.GetLmHistoryState(history);
std::vector<int32> ngram(history);
int32 random_word = RandInt(0, vocab_size);
ngram.push_back(random_word);
KALDI_ASSERT(map.GetProb(lm, history_state, random_word) ==
lm.GetProb(ngram));
}
}
void TestNormalization(const LanguageModel &lm) {
int32 vocab_size = lm.VocabSize();
int32 num_test = 500;
for (int32 i = 0; i < num_test; i++) {
int32 history_length = RandInt(0, lm.NgramOrder() - 1);
std::vector<int32> history(history_length);
// get a random history.
for (int32 i = 0; i < history_length; i++)
history[i] = RandInt(0, vocab_size);
double prob_sum = 0.0;
std::vector<int32> vec(history);
vec.push_back(0);
for (int32 word = 0; word <= vocab_size; word++) {
vec[history_length] = word;
prob_sum += lm.GetProb(vec);
}
KALDI_ASSERT(ApproxEqual(prob_sum, 1.0));
}
}
void LanguageModelTest() {
int32 order = RandInt(1, 4);
int32 vocab_size;
std::vector<std::vector<int32> > data, validation_data;
GetTestingData(&vocab_size, &data, &validation_data);
LanguageModelOptions opts;
opts.ngram_order = order;
if (RandInt(0,3) == 0)
opts.state_count_cutoff1 = 100.0;
if (RandInt(0,3) == 0) {
opts.state_count_cutoff1 = 10.0;
opts.state_count_cutoff2plus = 10.0;
}
if (RandInt(0,5) == 0) {
opts.state_count_cutoff1 = 0.0;
opts.state_count_cutoff2plus = 0.0;
}
LanguageModelEstimator estimator(opts, vocab_size);
for (size_t i = 0; i < data.size(); i++) {
std::vector<int32> &sentence = data[i];
estimator.AddCounts(sentence);
}
estimator.Discount();
LanguageModel lm;
estimator.Output(&lm);
KALDI_LOG << "For order " << order << ", cutoffs "
<< opts.state_count_cutoff1 << ","
<< opts.state_count_cutoff2plus << ", perplexity is "
<< ComputePerplexity(lm, validation_data) << "[valid]"
<< " and " << ComputePerplexity(lm, data) << "[train].";
TestNormalization(lm);
TestLmHistoryStateMap(lm);
}
} // namespace ctc
} // namespace kaldi
int main() {
for (int32 i = 0; i < 30; i++)
kaldi::ctc::LanguageModelTest();
}
This diff is collapsed.
// ctc/language-model.h
// Copyright 2015 Johns Hopkins University (Author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_CTC_LANGUAGE_MODEL_H_
#define KALDI_CTC_LANGUAGE_MODEL_H_
#include <vector>
#include <map>
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fstext/fstext-lib.h"
#include "lat/kaldi-lattice.h"
namespace kaldi {
// CTC means Connectionist Temporal Classification, see the paper by Graves et
// al.
//
// What we are implementing is an extension of CTC that we're calling
// context-dependent CTC (CCTC). It requires the estimation of an n-gram
// language model on phones.
//
// This header implements a language-model class that's suitable for this
// phone-level model. The implementation is efficient for the case where
// the number of symbols is quite small, e.g. not more than a few hundred,
// so tabulating probabilities makes sense.
// We don't put too much effort in making this the best possible language
// model and adding bells and whistles.
// We're implementing the count cutoffs in the way that we feel makes the
// most sense; it probably won't exactly match the way it's done in, say,
// SRILM. And the way the discounting is done is also not quite the same
// it's done in the original Kneser-Ney publication, as we do it using
// continuous rather discrete counts; and we fix the constants rather
// than estimating them.
// this is Kneser-Ney "with addition" instead of backoff (see A Bit of Progress
// in Language Modeling).
namespace ctc {
struct LanguageModelOptions {
int32 ngram_order;
int32 state_count_cutoff1;
int32 state_count_cutoff2plus;
BaseFloat discount1; // Discounting factor for singletons in Kneser-Ney type
// scheme
BaseFloat discount2plus; // Discounting factor for things with >1 count in
// Kneser-Ney type scheme.
LanguageModelOptions():
ngram_order(3), // recommend only 1 or 2 or 3.
state_count_cutoff1(0),
state_count_cutoff2plus(200), // count cutoff for n-grams of order >= 3 (if used)
discount1(0.8),
discount2plus(1.3) { }
void Register(OptionsItf *opts) {
opts->Register("ngram-order", &ngram_order, "n-gram order for the phone "
"language model used while training the CTC model");
opts->Register("state-count-cutoff1", &state_count_cutoff1,
"Count cutoff for language-model history states of order 1 "
"(meaning one left word is known, i.e. bigram states)");
opts->Register("state-count-cutoff2plus",
&state_count_cutoff2plus,
"Count cutoff for language-model history states of order >= 2 ");
opts->Register("discount1", &discount1, "Discount constant for 1-counts");
opts->Register("discount2plus", &discount2plus,
"Discount constant for 2-counts or greater");
}
};
class LanguageModel {
public:
LanguageModel(): vocab_size_(0), ngram_order_(0) { }
int32 NgramOrder() const { return ngram_order_; }
// Note: phone indexes are 1-based, so they range from 1 to vocab_size_.
// with 0 for BOS and EOS.
int32 VocabSize() const { return vocab_size_; }
// Get the language-model probability [not log-prob] for this history-plus-phone.
// zeros in the non-final position are interpreted as <s>.
// zeros in the final position are interpreted as </s>.
BaseFloat GetProb(const std::vector<int32> &ngram) const;
void Write(std::ostream &os, bool binary) const;
void Read(std::istream &is, bool binary);
protected:
friend class LanguageModelEstimator;
friend class LmHistoryStateMap;
typedef unordered_map<std::vector<int32>, BaseFloat,
VectorHasher<int32> > MapType;
typedef unordered_map<std::vector<int32>, std::pair<BaseFloat,BaseFloat>, VectorHasher<int32> > PairMapType;
int32 vocab_size_;
int32 ngram_order_;
// map from n-grams of the highest order to probabilities.
MapType highest_order_probs_;
// map from all other ngrams to (n-gram probability, history-state backoff
// weight). Note: history-state backoff weights will be 1.0 for history-states
// that don't exist. If a history-state exists, a direct n-gram probability
// must exist. Note: for normal n-grams we don't do any kind of pruning that could remove it;
// and for the history [0] (for BOS) the predicted n-gram [0] (for EOS) will always
// exist.
PairMapType other_probs_;
};
// Computes the perplexity of the language model on the sentences provided;
// they should not contain zeros (we'll add the BOS/EOS things internally).
BaseFloat ComputePerplexity(const LanguageModel &lm,
std::vector<std::vector<int32> > &sentences);
// This class allows you to map a language model history to an integer id which,
// with the predicted word, is sufficient to work out the probability. It's
// useful in the CCTC code. Because this isn't something that would normally
// appear in the interface of a language model, we make it a separate class. Lm
// stands for "language model". Because the term lm_history_state is used a lot
// in the CCTC code and also history_state exists there and means something
// different, we felt it was necessary to include "lm" in these names.
class LmHistoryStateMap {
public:
// Returns the number of history states. A history state is a zero-based
// index, so they go from 0 to NumHistoryStates() - 1.
// these will
int32 NumLmHistoryStates() const { return lm_history_states_.size(); }
const std::vector<int32>& GetHistoryForState(int32 lm_history_state) const;
BaseFloat GetProb(const LanguageModel &lm, int32 lm_history_state,
int32 predicted_word) const;
// Maps a history to an integer lm-history-state.
int32 GetLmHistoryState(std::vector<int32> &hist) const;
// Initialize the history states.
void Init(const LanguageModel &lm);
private:
typedef unordered_map<std::vector<int32>, int32, VectorHasher<int32> > IntMapType;
std::vector<std::vector<int32> > lm_history_states_;
IntMapType history_to_state_;
};
class LanguageModelEstimator {
public:
LanguageModelEstimator(const LanguageModelOptions &opts,
int32 vocab_size);
// Adds counts for this sentence. Basically does: for each n-gram,
// count[n-gram] += 1.
void AddCounts(std::vector<int32> &sentence);
// Does the discounting.
void Discount();
// outputs to the LM.
void Output(LanguageModel *lm) const;
private:
// Returns the probability for this word given this history; used inside
// Output(). This includes not just the direct prob, but the additional prob
// that comes via backoff, since this is Kneser-Ney "with addition". (this is
// what we need to store in the otuput language model).
BaseFloat GetProb(const std::vector<int32> &ngram) const;
// Gets the backoff probability for this state, i.e. the
// probability mass assigned to backoff.
BaseFloat GetBackoffProb(std::vector<int32> &hist) const;
typedef unordered_map<std::vector<int32>, BaseFloat, VectorHasher<int32> > MapType;
typedef unordered_map<std::vector<int32>, std::pair<BaseFloat, BaseFloat>, VectorHasher<int32> > PairMapType;
typedef unordered_set<std::vector<int32>, VectorHasher<int32> > SetType;
// applies discounting or to the counts for all stored n-grams of this order.
// If order >= 2 we apply this continuous Kneser-Ney-like discounting; if
// order == 1 we apply add-one smoothing.
void DiscountForOrder(int32 order);
// order must be >= 1 and < ngram_order. This function finds all
// history-states of order 'order' (i.e. containing 'order' words in the
// history-state), such that (the total count for that history-state is less
// than min_count, and the history-state is not listed in 'protected_states'),
// and it completely discounts all n-grams in those history-states, adding
// their count to the backoff state. We apply pruning at the level of
// history-states because this is the level at which added cost is incurred.
// For history-states which were not removed by this procedure, this function
// computes their backoff state by removing the first phone, and adds it to
// "protected_backoff_states". This will protect
// that backoff state from pruning when and if we prune the one-smaller order.
void ApplyHistoryStateCountCutoffForOrder(
int32 order, BaseFloat min_count,
const SetType &protected_states,
SetType *protected_backoff_states);
// This function does, conceptually, counts_[vec] += count.
// It's called during training and during discounting.
inline void AddCountForNgram(const std::vector<int32> &vec, BaseFloat count);
// This function does, conceptually,
// history_state_counts_[hist] += (tot_count, discounted_count).
// It's called during discounting.
inline void AddCountsForHistoryState(const std::vector<int32> &hist,
BaseFloat tot_count,
BaseFloat discounted_count);
// This function, conceptually, returns counts_[vec] (or 0 if not there).
inline BaseFloat GetCountForNgram(const std::vector<int32> &vec) const;
// This function, conceptually, returns history_state_counts_[vec] (or (0,0) if
// not there). It represents (total-count, discounted-count).
inline std::pair<BaseFloat,BaseFloat> GetCountsForHistoryState(
const std::vector<int32> &vec) const;
// Returns a discounting-amount for this count. The amount returned will be
// between zero and discount2plus. It's the amount we are to subtract
// from the count while discounting. This simple implementation doesn't
// allow us to have different discounting amounts for different n-gram
// orders. Because the count is continous, the discount1 and discount2plus
// value are interpreted as values to interpolate to when non-integer
// counts are provided.
BaseFloat GetDiscountAmount(BaseFloat count) const;
// Outputs into "backoff_vec", which must be initially empty, all but the
// first element of "vec".
inline static void RemoveFront(const std::vector<int32> &vec,
std::vector<int32> *backoff_vec);
// This function does (*map)[vec] += count, while
// ensuring that it does the right thing if the vector wasn't
// a key of the map to start with.
inline static void AddToMap(const std::vector<int32> &vec,
BaseFloat count,
MapType *map);
inline static void AddPairToMap(const std::vector<int32> &vec,
BaseFloat count1, BaseFloat count2,
PairMapType *map);
// data members:
const LanguageModelOptions &opts_;
// the allowed words go from 1 to vocab_size_. 0 is reserved for
// epsilon.
int32 vocab_size_;
// counts_ stores the raw counts. We don't make much attempt at
// memory efficiency here since a phone-level language model is quite a small
// thing. It's indexed first by the n-gram order, then by the ngram itself;
// this makes things easier when iterating in the discounting code.
std::vector<MapType> counts_;
// stores a map from a history-state to a pair (the total count for this
// history-state; the count that has been removed from this history-state via
// backoff). Indexed first by the order of the history state (which equals
// the vector length).
std::vector<PairMapType> history_state_counts_;
};
} // namespace ctc
} // namespace kaldi
#endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment