Skip to content
Snippets Groups Projects
Commit 6e6d28b8 authored by Karel Vesely's avatar Karel Vesely
Browse files

add the expand copy nnet trasforms, option randomize to nnet trainer


git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@1197 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent c10fb02b
No related branches found
No related tags found
No related merge requests found
......@@ -21,6 +21,7 @@
#include "nnet/nnet-activation.h"
#include "nnet/nnet-biasedlinearity.h"
#include "nnet/nnet-rbm.h"
#include "nnet/nnet-various.h"
namespace kaldi {
......@@ -83,6 +84,12 @@ Component* Component::Read(std::istream &is, bool binary, Nnet *nnet) {
case Component::kRbm :
p_comp = new Rbm(dim_in, dim_out, nnet);
break;
case Component::kExpand :
p_comp = new Expand(dim_in, dim_out, nnet);
break;
case Component::kCopy :
p_comp = new Copy(dim_in, dim_out, nnet);
break;
case Component::kUnknown :
default :
KALDI_ERR << "Missing type: " << token;
......
......@@ -57,7 +57,7 @@ class Expand : public Component {
frame_offsets_.CopyFromVec(vec_i);
}
void WriteData(std::ostream &os, bool binary) {
void WriteData(std::ostream &os, bool binary) const {
std::vector<int32> vec_i;
frame_offsets_.CopyToVec(&vec_i);
Vector<double> vec_d(vec_i.size());
......@@ -111,7 +111,7 @@ class Copy : public Component {
copy_from_indices_.CopyFromVec(vec_i);
}
void WriteData(std::ostream &os, bool binary) {
void WriteData(std::ostream &os, bool binary) const {
std::vector<int32> vec_i;
copy_from_indices_.CopyToVec(&vec_i);
Vector<double> vec_d(vec_i.size());
......
......@@ -35,9 +35,11 @@ int main(int argc, char *argv[]) {
ParseOptions po(usage);
bool binary = false,
crossvalidate = false;
crossvalidate = false,
randomize = true;
po.Register("binary", &binary, "Write output in binary mode");
po.Register("cross-validate", &crossvalidate, "Perform cross-validation (don't backpropagate)");
po.Register("randomize", &randomize, "Perform the frame-level shuffling within the Cache::");
BaseFloat learn_rate = 0.008,
momentum = 0.0,
......@@ -139,7 +141,7 @@ int main(int argc, char *argv[]) {
time_next += t_features.Elapsed();
}
// randomize
if (!crossvalidate) {
if (!crossvalidate && randomize) {
cache.Randomize();
}
// report
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment