Commit 0628d2db authored by Daniel Povey

Merge pull request #149 from vijayaditya/nnet3_unittests

Added --scale option to nnet3-am-copy, useful for shrinkage of parame…
Parents: 279cd192, 24833034
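The per-file headers are missing from the diff below, but the hunks appear to come, in order, from an LSTM recipe script (the $mic variable suggests the AMI setup), the shared LSTM training script (apparently steps/nnet3/lstm/train.sh), and the tools nnet3-am-copy.cc and nnet3-train.cc.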
@@ -33,7 +33,9 @@ num_epochs=5
 # training options
 initial_effective_lrate=0.0003
 final_effective_lrate=0.00003
+shrink=0.0
 num_chunk_per_minibatch=100
+num_bptt_steps=20
 samples_per_iter=20000
 remove_egs=true
 # End configuration section.
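Two new tunables appear in the recipe's configuration block: shrink (default 0.0, i.e. disabled), a factor by which the parameter matrices are scaled after each training iteration, and num_bptt_steps (default 20), which presumably truncates backpropagation through time to that many time steps. Both are forwarded to the training script in the next two hunks.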
@@ -95,6 +97,7 @@ if [ $stage -le 8 ]; then
     --online-ivector-dir exp/$mic/nnet3/ivectors_${train_set}_hires \
     --cmvn-opts "--norm-means=false --norm-vars=false" \
     --initial-effective-lrate $initial_effective_lrate --final-effective-lrate $final_effective_lrate \
+    --shrink $shrink \
     --cmd "$decode_cmd" \
     --num-lstm-layers $num_lstm_layers \
     --cell-dim $cell_dim \
@@ -104,6 +107,7 @@ if [ $stage -le 8 ]; then
     --non-recurrent-projection-dim $non_recurrent_projection_dim \
     --chunk-width $chunk_width \
     --chunk-left-context $chunk_left_context \
+    --num-bptt-steps $num_bptt_steps \
     --norm-based-clipping $norm_based_clipping \
     --ng-per-element-scale-options "$ng_per_element_scale_options" \
     --ng-affine-options "$ng_affine_options" \
@@ -112,7 +116,7 @@ if [ $stage -le 8 ]; then
     data/$mic/${train_set}_hires data/lang $ali_dir $dir || exit 1;
 fi

-if [ $stage -le 8 ]; then
+if [ $stage -le 9 ]; then
   # this version of the decoding treats each utterance separately
   # without carrying forward speaker information.
   for decode_set in dev eval; do
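Renumbering the second block from stage 8 to stage 9 fixes what looks like a copy-and-paste slip: the per-utterance decoding block previously reused the training block's stage number, so the two could not be run or skipped independently. The following hunks modify the training script itself: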
@@ -18,6 +18,7 @@ num_epochs=5       # Number of epochs of training;
                    # the number of iterations is worked out from this.
 initial_effective_lrate=0.0003
 final_effective_lrate=0.00003
+shrink=0.0         # if non-zero, this value will be used to scale the parameter matrices
 rand_prune=4.0     # Relates to a speedup we do for LDA.
 num_chunk_per_minibatch=100  # number of sequences to be processed in parallel every mini-batch
@@ -127,6 +128,7 @@ if [ $# != 4 ]; then
   echo "                                     # the pre-softmax outputs (set to 0.0 to disable the presoftmax element scale)"
   echo "  --num-jobs-initial <num-jobs|1>    # Number of parallel jobs to use for neural net training, at the start."
   echo "  --num-jobs-final <num-jobs|8>      # Number of parallel jobs to use for neural net training, at the end"
+  echo "  --shrink <shrink|0.0>              # if non-zero, this parameter will be used to scale the parameter matrices"
   echo "  --num-threads <num-threads|16>     # Number of parallel threads per job, for CPU-based training (will affect"
   echo "                                     # results as well as speed; may interact with batch size; if you increase"
   echo "                                     # this, you may want to decrease the batch size."
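Since the script follows the usual parse_options.sh convention, the new option can be supplied on the command line. A minimal sketch of an invocation with the standard four positional arguments; the directories and the value 0.98 are placeholders chosen for illustration, not taken from this commit:

  # hypothetical invocation; adjust the data, lang, and alignment paths to your setup
  steps/nnet3/lstm/train.sh --shrink 0.98 --num-bptt-steps 20 \
    data/train_hires data/lang exp/tri4_ali exp/nnet3/lstm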
@@ -581,7 +583,7 @@ while [ $x -lt $num_iters ]; do
       # average the output of the different jobs.
       $cmd $dir/log/average.$x.log \
         nnet3-average $nnets_list - \| \
-        nnet3-am-copy --set-raw-nnet=- $dir/$x.mdl $dir/$[$x+1].mdl || exit 1;
+        nnet3-am-copy --scale=$shrink --set-raw-nnet=- $dir/$x.mdl $dir/$[$x+1].mdl || exit 1;
     else
       # choose the best from the different jobs.
       n=$(perl -e '($nj,$pat)=@ARGV; $best_n=1; $best_logprob=-1.0e+10; for ($n=1;$n<=$nj;$n++) {
@@ -591,7 +593,7 @@
           $best_n=$n; } } print "$best_n\n"; ' $this_num_jobs $dir/log/train.$x.%d.log) || exit 1;
       [ -z "$n" ] && echo "Error getting best model" && exit 1;
       $cmd $dir/log/select.$x.log \
-        nnet3-am-copy --set-raw-nnet=$dir/$[$x+1].$n.raw $dir/$x.mdl $dir/$[$x+1].mdl || exit 1;
+        nnet3-am-copy --scale=$shrink --set-raw-nnet=$dir/$[$x+1].$n.raw $dir/$x.mdl $dir/$[$x+1].mdl || exit 1;
     fi
     rm $nnets_list
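With this change, each iteration's combined model is passed through nnet3-am-copy with --scale=$shrink before being written out, in both the model-averaging and the best-job-selection branches, so the parameters are multiplied by the shrink factor once per iteration. A stand-alone sketch of the averaging branch for one iteration; the paths, the job count, and the shrink value are illustrative, not from the commit:

  # hypothetical iteration x with two parallel jobs
  dir=exp/nnet3/lstm; x=10; shrink=0.98
  nnet3-average $dir/$[$x+1].1.raw $dir/$[$x+1].2.raw - | \
    nnet3-am-copy --scale=$shrink --set-raw-nnet=- $dir/$x.mdl $dir/$[$x+1].mdl

The C++ side of the new option lives in nnet3-am-copy.cc: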
@@ -48,7 +48,8 @@ int main(int argc, char *argv[]) {
         raw = false;
     BaseFloat learning_rate = -1;
     std::string set_raw_nnet = "";
+    BaseFloat scale = 0.0;

     ParseOptions po(usage);
     po.Register("binary", &binary_write, "Write output in binary mode");
     po.Register("raw", &raw, "If true, write only 'raw' neural net "
@@ -59,7 +60,9 @@ int main(int argc, char *argv[]) {
                 "before the learning-rate is changed.");
     po.Register("learning-rate", &learning_rate,
                 "If supplied, all the learning rates of updatable components"
-                "are set to this value.");
+                " are set to this value.");
+    po.Register("scale", &scale, "If non-zero, the parameter matrices are scaled"
+                " by the specified value.");

     po.Read(argc, argv);
@@ -90,6 +93,8 @@ int main(int argc, char *argv[]) {
     if (learning_rate >= 0)
       SetLearningRate(learning_rate, &(am_nnet.GetNnet()));

+    if (scale != 0.0)
+      ScaleNnet(scale, &(am_nnet.GetNnet()));

     if (raw) {
       WriteKaldiObject(am_nnet.GetNnet(), nnet_wxfilename, binary_write);
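Two details are worth noting. The scaling is applied after the --set-raw-nnet and --learning-rate handling, so it acts on the final parameters; and because the guard tests scale != 0.0, the scripts' default of shrink=0.0 means "leave the parameters unchanged" rather than "zero the model". The --learning-rate help string also gains a missing space, so it no longer prints "componentsare set". The final hunk, apparently from src/nnet3bin/nnet3-train.cc, is a small fix bundled into the merge: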
@@ -79,7 +79,7 @@ int main(int argc, char *argv[]) {
     for (; !example_reader.Done(); example_reader.Next())
       trainer.Train(example_reader.Value());

-    bool ok = trainer.PrintTotalStats();
+    ok = trainer.PrintTotalStats();

     // need trainer's destructor to be called before we write model.
   }
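The "bool ok = ..." form declared a fresh variable local to the inner scope, shadowing an ok variable presumably declared outside the block (and consulted after it, e.g. for the program's exit status), so the result of PrintTotalStats() was silently discarded. Dropping the declaration turns the line into an assignment to the outer variable.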