Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
K
kaldi-commonvoice
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container Registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Simon Will
kaldi-commonvoice
Commits
e1dd41de
Commit
e1dd41de
authored
8 years ago
by
Daniel Povey
Browse files
Options
Downloads
Patches
Plain Diff
Add file accidentally omitted from commit 86417db, nnet3-latgen-faster-parallel.cc
parent
d85a1100
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
src/nnet3bin/nnet3-latgen-faster-parallel.cc
+269
-0
269 additions, 0 deletions
src/nnet3bin/nnet3-latgen-faster-parallel.cc
with
269 additions
and
0 deletions
src/nnet3bin/nnet3-latgen-faster-parallel.cc
0 → 100644
+
269
−
0
View file @
e1dd41de
// nnet3bin/nnet3-latgen-faster-parallel.cc
// Copyright 2012-2016 Johns Hopkins University (author: Daniel Povey)
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include
"base/timer.h"
#include
"base/kaldi-common.h"
#include
"decoder/decoder-wrappers.h"
#include
"fstext/fstext-lib.h"
#include
"hmm/transition-model.h"
#include
"nnet3/nnet-am-decodable-simple.h"
#include
"thread/kaldi-task-sequence.h"
#include
"tree/context-dep.h"
#include
"util/common-utils.h"
int
main
(
int
argc
,
char
*
argv
[])
{
// note: making this program work with GPUs is as simple as initializing the
// device, but it probably won't make a huge difference in speed for typical
// setups.
try
{
using
namespace
kaldi
;
using
namespace
kaldi
::
nnet3
;
typedef
kaldi
::
int32
int32
;
using
fst
::
SymbolTable
;
using
fst
::
VectorFst
;
using
fst
::
StdArc
;
const
char
*
usage
=
"Generate lattices using nnet3 neural net model.
\n
"
"Usage: nnet3-latgen-faster-parallel [options] <nnet-in> <fst-in|fsts-rspecifier> <features-rspecifier>"
" <lattice-wspecifier> [ <words-wspecifier> [<alignments-wspecifier>] ]
\n
"
;
ParseOptions
po
(
usage
);
Timer
timer
;
bool
allow_partial
=
false
;
TaskSequencerConfig
sequencer_config
;
// has --num-threads option
LatticeFasterDecoderConfig
config
;
NnetSimpleComputationOptions
decodable_opts
;
std
::
string
word_syms_filename
;
std
::
string
ivector_rspecifier
,
online_ivector_rspecifier
,
utt2spk_rspecifier
;
int32
online_ivector_period
=
0
;
sequencer_config
.
Register
(
&
po
);
config
.
Register
(
&
po
);
decodable_opts
.
Register
(
&
po
);
po
.
Register
(
"word-symbol-table"
,
&
word_syms_filename
,
"Symbol table for words [for debug output]"
);
po
.
Register
(
"allow-partial"
,
&
allow_partial
,
"If true, produce output even if end state was not reached."
);
po
.
Register
(
"ivectors"
,
&
ivector_rspecifier
,
"Rspecifier for "
"iVectors as vectors (i.e. not estimated online); per utterance "
"by default, or per speaker if you provide the --utt2spk option."
);
po
.
Register
(
"online-ivectors"
,
&
online_ivector_rspecifier
,
"Rspecifier for "
"iVectors estimated online, as matrices. If you supply this,"
" you must set the --online-ivector-period option."
);
po
.
Register
(
"online-ivector-period"
,
&
online_ivector_period
,
"Number of frames "
"between iVectors in matrices supplied to the --online-ivectors "
"option"
);
po
.
Read
(
argc
,
argv
);
if
(
po
.
NumArgs
()
<
4
||
po
.
NumArgs
()
>
6
)
{
po
.
PrintUsage
();
exit
(
1
);
}
std
::
string
model_in_filename
=
po
.
GetArg
(
1
),
fst_in_str
=
po
.
GetArg
(
2
),
feature_rspecifier
=
po
.
GetArg
(
3
),
lattice_wspecifier
=
po
.
GetArg
(
4
),
words_wspecifier
=
po
.
GetOptArg
(
5
),
alignment_wspecifier
=
po
.
GetOptArg
(
6
);
TaskSequencer
<
DecodeUtteranceLatticeFasterClass
>
sequencer
(
sequencer_config
);
TransitionModel
trans_model
;
AmNnetSimple
am_nnet
;
{
bool
binary
;
Input
ki
(
model_in_filename
,
&
binary
);
trans_model
.
Read
(
ki
.
Stream
(),
binary
);
am_nnet
.
Read
(
ki
.
Stream
(),
binary
);
}
bool
determinize
=
config
.
determinize_lattice
;
CompactLatticeWriter
compact_lattice_writer
;
LatticeWriter
lattice_writer
;
if
(
!
(
determinize
?
compact_lattice_writer
.
Open
(
lattice_wspecifier
)
:
lattice_writer
.
Open
(
lattice_wspecifier
)))
KALDI_ERR
<<
"Could not open table for writing lattices: "
<<
lattice_wspecifier
;
RandomAccessBaseFloatMatrixReader
online_ivector_reader
(
online_ivector_rspecifier
);
RandomAccessBaseFloatVectorReaderMapped
ivector_reader
(
ivector_rspecifier
,
utt2spk_rspecifier
);
Int32VectorWriter
words_writer
(
words_wspecifier
);
Int32VectorWriter
alignment_writer
(
alignment_wspecifier
);
fst
::
SymbolTable
*
word_syms
=
NULL
;
if
(
word_syms_filename
!=
""
)
if
(
!
(
word_syms
=
fst
::
SymbolTable
::
ReadText
(
word_syms_filename
)))
KALDI_ERR
<<
"Could not read symbol table from file "
<<
word_syms_filename
;
double
tot_like
=
0.0
;
kaldi
::
int64
frame_count
=
0
;
int
num_success
=
0
,
num_fail
=
0
;
if
(
ClassifyRspecifier
(
fst_in_str
,
NULL
,
NULL
)
==
kNoRspecifier
)
{
SequentialBaseFloatMatrixReader
feature_reader
(
feature_rspecifier
);
// Input FST is just one FST, not a table of FSTs.
VectorFst
<
StdArc
>
*
decode_fst
=
fst
::
ReadFstKaldi
(
fst_in_str
);
{
LatticeFasterDecoder
decoder
(
*
decode_fst
,
config
);
for
(;
!
feature_reader
.
Done
();
feature_reader
.
Next
())
{
std
::
string
utt
=
feature_reader
.
Key
();
const
Matrix
<
BaseFloat
>
&
features
(
feature_reader
.
Value
());
if
(
features
.
NumRows
()
==
0
)
{
KALDI_WARN
<<
"Zero-length utterance: "
<<
utt
;
num_fail
++
;
continue
;
}
const
Matrix
<
BaseFloat
>
*
online_ivectors
=
NULL
;
const
Vector
<
BaseFloat
>
*
ivector
=
NULL
;
if
(
!
ivector_rspecifier
.
empty
())
{
if
(
!
ivector_reader
.
HasKey
(
utt
))
{
KALDI_WARN
<<
"No iVector available for utterance "
<<
utt
;
num_fail
++
;
continue
;
}
else
{
ivector
=
&
ivector_reader
.
Value
(
utt
);
}
}
if
(
!
online_ivector_rspecifier
.
empty
())
{
if
(
!
online_ivector_reader
.
HasKey
(
utt
))
{
KALDI_WARN
<<
"No online iVector available for utterance "
<<
utt
;
num_fail
++
;
continue
;
}
else
{
online_ivectors
=
&
online_ivector_reader
.
Value
(
utt
);
}
}
LatticeFasterDecoder
*
decoder
=
new
LatticeFasterDecoder
(
*
decode_fst
,
config
);
DecodableInterface
*
nnet_decodable
=
new
DecodableAmNnetSimpleParallel
(
decodable_opts
,
trans_model
,
am_nnet
,
features
,
ivector
,
online_ivectors
,
online_ivector_period
);
DecodeUtteranceLatticeFasterClass
*
task
=
new
DecodeUtteranceLatticeFasterClass
(
decoder
,
nnet_decodable
,
// takes ownership of these two.
trans_model
,
word_syms
,
utt
,
decodable_opts
.
acoustic_scale
,
determinize
,
allow_partial
,
&
alignment_writer
,
&
words_writer
,
&
compact_lattice_writer
,
&
lattice_writer
,
&
tot_like
,
&
frame_count
,
&
num_success
,
&
num_fail
,
NULL
);
sequencer
.
Run
(
task
);
// takes ownership of "task",
// and will delete it when done.
}
}
sequencer
.
Wait
();
// Waits for all tasks to be done.
delete
decode_fst
;
}
else
{
// We have different FSTs for different utterances.
SequentialTableReader
<
fst
::
VectorFstHolder
>
fst_reader
(
fst_in_str
);
RandomAccessBaseFloatMatrixReader
feature_reader
(
feature_rspecifier
);
for
(;
!
fst_reader
.
Done
();
fst_reader
.
Next
())
{
std
::
string
utt
=
fst_reader
.
Key
();
if
(
!
feature_reader
.
HasKey
(
utt
))
{
KALDI_WARN
<<
"Not decoding utterance "
<<
utt
<<
" because no features available."
;
num_fail
++
;
continue
;
}
const
Matrix
<
BaseFloat
>
&
features
=
feature_reader
.
Value
(
utt
);
if
(
features
.
NumRows
()
==
0
)
{
KALDI_WARN
<<
"Zero-length utterance: "
<<
utt
;
num_fail
++
;
continue
;
}
const
Matrix
<
BaseFloat
>
*
online_ivectors
=
NULL
;
const
Vector
<
BaseFloat
>
*
ivector
=
NULL
;
if
(
!
ivector_rspecifier
.
empty
())
{
if
(
!
ivector_reader
.
HasKey
(
utt
))
{
KALDI_WARN
<<
"No iVector available for utterance "
<<
utt
;
num_fail
++
;
continue
;
}
else
{
ivector
=
&
ivector_reader
.
Value
(
utt
);
}
}
if
(
!
online_ivector_rspecifier
.
empty
())
{
if
(
!
online_ivector_reader
.
HasKey
(
utt
))
{
KALDI_WARN
<<
"No online iVector available for utterance "
<<
utt
;
num_fail
++
;
continue
;
}
else
{
online_ivectors
=
&
online_ivector_reader
.
Value
(
utt
);
}
}
LatticeFasterDecoder
*
decoder
=
new
LatticeFasterDecoder
(
fst_reader
.
Value
(),
config
);
DecodableInterface
*
nnet_decodable
=
new
DecodableAmNnetSimpleParallel
(
decodable_opts
,
trans_model
,
am_nnet
,
features
,
ivector
,
online_ivectors
,
online_ivector_period
);
DecodeUtteranceLatticeFasterClass
*
task
=
new
DecodeUtteranceLatticeFasterClass
(
decoder
,
nnet_decodable
,
// takes ownership of these two.
trans_model
,
word_syms
,
utt
,
decodable_opts
.
acoustic_scale
,
determinize
,
allow_partial
,
&
alignment_writer
,
&
words_writer
,
&
compact_lattice_writer
,
&
lattice_writer
,
&
tot_like
,
&
frame_count
,
&
num_success
,
&
num_fail
,
NULL
);
sequencer
.
Run
(
task
);
// takes ownership of "task",
// and will delete it when done.
}
sequencer
.
Wait
();
// Waits for all tasks to be done.
}
double
elapsed
=
timer
.
Elapsed
();
KALDI_LOG
<<
"Time taken "
<<
elapsed
<<
"s: real-time factor assuming 100 frames/sec is "
<<
(
elapsed
*
100.0
/
frame_count
);
KALDI_LOG
<<
"Done "
<<
num_success
<<
" utterances, failed for "
<<
num_fail
;
KALDI_LOG
<<
"Overall log-likelihood per frame is "
<<
(
tot_like
/
frame_count
)
<<
" over "
<<
frame_count
<<
" frames."
;
delete
word_syms
;
if
(
num_success
!=
0
)
return
0
;
else
return
1
;
}
catch
(
const
std
::
exception
&
e
)
{
std
::
cerr
<<
e
.
what
();
return
-
1
;
}
}
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment