Skip to content
Snippets Groups Projects
Commit 7c7a1767 authored by Justin Luitjens's avatar Justin Luitjens Committed by Daniel Povey
Browse files

[src] Implemented CUDA acclerated online cmvn. (#3370)

This patch is part of a larger effort to implement the entire online feature pipeline in CUDA so that wav data is transfered to the device and never copied back to the host.
This patch includes a new binary cudafeatbin/apply-cmvn-online.cc which for the most part matches online2bin/apply-cmvn-online.
This binary is primarily for correctness testing and debugging as it makes no effort to compute multiple features in parallel on the device.
The CUDA performance is dominiated by the cost of copying the feature to and from the device. While there is a small speedup I do not expect this binary to be used in production.
Instead users will use the upcomming online-pipeline which will take features directly from the mfcc computation on the device and pass results to the next part of the pipeline.

Summary of changes:

Makefile:
   Added online2 dependencies to cudafeat, cudafeatbin, cudadecoder, and cudadecoderbin.
cudafeat\:
   Makefile:  added online2 dependency, added new .cu/.h files
   feature-online-cmvn-cuda.cu/h:  implements online-cmvn in cuda.
cudafeatbin\:
   Makefile:  added new binary, added online2 dependency
   apply-cmvn-online-cuda.cc:  binary which mimics online2bin/apply-cmvn-online

Correctness testing:

The correctness was tested by generating set of 20000 features and then running the CPU binary and GPU binary and comparing results using featbin/compare-feats.

../online2bin/apply-cmvn-online /workspace/models/LibriSpeech/ivector_extractor/global_cmvn.stats "scp:mfcc.scp" "ark,scp:cmvn.ark,cmvn.scp"
./apply-cmvn-online-cuda /workspace/models/LibriSpeech/ivector_extractor/global_cmvn.stats "scp:mfcc.scp" "ark,scp:cmvn-cuda.ark,cmvn-cuda.scp"

../featbin/compare-feats ark:cmvn-cuda.ark ark:cmvn.ark
LOG (compare-feats[5.5.1301~3-17818]:main():compare-feats.cc:105) self-product of 1st features for each column dimension:  [ 5.52221e+09 9.1134e+09 5.92818e+09 7.42173e+09 7.48633e+09 7.21316e+09 6.9515e+09 7.03883e+09 6.40267e+09 5.83088e+09 5.01438e+09 5.1575e+09 4.28688e+09 3.529e+09 3.12182e+09 2.28721e+09 1.76343e+09 1.35117e+09 8.72517e+08 5.31836e+08 2.65112e+08 9.20308e+07 1.24084e+07 3.56008e+06 4.25283e+07 1.09786e+08 1.88937e+08 2.60207e+08 3.23115e+08 3.56371e+08 3.69035e+08 3.65216e+08 3.89125e+08 4.07064e+08 3.40407e+08 2.65444e+08 2.50244e+08 2.05726e+08 1.60606e+08 1.07217e+08 ]

LOG (compare-feats[5.5.1301~3-17818]:main():compare-feats.cc:106) self-product of 2nd features for each column dimension:  [ 5.5223e+09 9.11355e+09 5.92812e+09 7.4218e+09 7.48666e+09 7.21338e+09 6.95174e+09 7.03895e+09 6.40254e+09 5.83113e+09 5.01411e+09 5.15774e+09 4.28692e+09 3.52918e+09 3.122e+09 2.28693e+09 1.76326e+09 1.3513e+09 8.72521e+08 5.31802e+08 2.65137e+08 9.20296e+07 1.2408e+07 3.5604e+06 4.25301e+07 1.09793e+08 1.88933e+08 2.60217e+08 3.23124e+08 3.56371e+08 3.69007e+08 3.65176e+08 3.89104e+08 4.07067e+08 3.40416e+08 2.65498e+08 2.50196e+08 2.057e+08 1.60612e+08 1.07192e+08 ]

LOG (compare-feats[5.5.1301~3-17818]:main():compare-feats.cc:107) cross-product for each column dimension:  [ 5.52209e+09 9.11229e+09 5.92538e+09 7.41665e+09 7.47877e+09 7.20269e+09 6.93785e+09 7.02284e+09 6.38411e+09 5.81143e+09 4.99389e+09 5.13753e+09 4.26792e+09 3.51154e+09 3.10676e+09 2.27436e+09 1.75322e+09 1.34367e+09 8.67367e+08 5.28672e+08 2.63516e+08 9.14194e+07 1.23215e+07 3.53409e+06 4.21905e+07 1.08872e+08 1.87238e+08 2.57779e+08 3.19827e+08 3.5252e+08 3.64691e+08 3.60529e+08 3.84482e+08 4.02396e+08 3.36136e+08 2.61631e+08 2.46931e+08 2.03079e+08 1.5856e+08 1.05738e+08 ]

LOG (compare-feats[5.5.1301~3-17818]:main():compare-feats.cc:111) Similarity metric for each dimension  [ 0.99997 0.999871 0.999532 0.999311 0.998968 0.998533 0.998019 0.997719 0.997111 0.996644 0.995941 0.996104 0.995572 0.995028 0.995147 0.994445 0.994258 0.994402 0.994095 0.994084 0.993934 0.993363 0.993015 0.992655 0.992037 0.991645 0.991017 0.990649 0.98981 0.989195 0.988267 0.987222 0.988093 0.98853 0.987442 0.985534 0.986858 0.987196 0.987242 0.986318 ]
 (1.0 means identical, the smaller the more different)
LOG (compare-feats[5.5.1301~3-17818]:main():compare-feats.cc:116) Overall similarity for the two feats is:0.993119 (1.0 means identical, the smaller the more different)
LOG (compare-feats[5.5.1301~3-17818]:main():compare-feats.cc:119) Processed 20960 feature files, 0 had errors.
LOG (compare-feats[5.5.1301~3-17818]:main():compare-feats.cc:126) Features are considered similar since 0.993119 >= 0.99
parent 63c54e2d
No related branches found
No related tags found
No related merge requests found
......@@ -171,9 +171,10 @@ ivector: base util matrix transform tree gmm
#3)Dependencies for optional parts of Kaldi
onlinebin: base matrix util feat tree gmm transform sgmm2 fstext hmm lm decoder lat cudamatrix nnet nnet2 online
# python-kaldi-decoding: base matrix util feat tree gmm transform sgmm2 fstext hmm decoder lat online
cudafeat: base matrix util gmm transform tree feat cudamatrix
cudafeat: base matrix util gmm transform tree feat cudamatrix online2
cudafeatbin: base matrix util gmm transform tree feat cudamatrix cudafeat online2
online: decoder gmm transform feat matrix util base lat hmm tree
online2: decoder gmm transform feat matrix util base lat hmm tree ivector cudamatrix nnet2 nnet3 chain
kws: base util hmm tree matrix lat
cudadecoder: cudamatrix online2 nnet3 ivector feat fstext lat chain transform
cudadecoderbin: cudadecoder cudamatrix online2 nnet3 ivector feat fstext lat chain transform
cudadecoder: cudamatrix cudafeat online2 nnet3 ivector feat fstext lat chain transform
cudadecoderbin: cudadecoder cudafeat cudamatrix online2 nnet3 ivector feat fstext lat chain transform
......@@ -7,15 +7,15 @@ ifeq ($(CUDA), true)
TESTFILES =
OBJFILES += feature-window-cuda.o feature-mfcc-cuda.o
ifeq ($(CUDA), true)
OBJFILES += feature-window-cuda.o feature-mfcc-cuda.o feature-online-cmvn-cuda.o
endif
LIBNAME = kaldi-cudafeat
ADDLIBS = ../feat/kaldi-feat.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \
../base/kaldi-base.a ../cudamatrix/kaldi-cudamatrix.a \
../gmm/kaldi-gmm.a
../gmm/kaldi-gmm.a ../online2/kaldi-online2.a
LDFLAGS += $(CUDA_LDFLAGS)
LDLIBS += $(CUDA_LDLIBS)
......
// cudafeat/feature-online-cmvn-cuda.cu
//
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Justin Luitjens
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cub/cub.cuh>
#include "cudafeat/feature-online-cmvn-cuda.h"
#include "cudamatrix/cu-matrix.h"
#include "cudamatrix/cu-vector.h"
__device__ inline float2 operator-(const float2 &a, const float2 &b) {
float2 retval;
retval.x = a.x - b.x;
retval.y = a.y - b.y;
return retval;
}
__device__ inline float2 operator+(const float2 &a, const float2 &b) {
float2 retval;
retval.x = a.x + b.x;
retval.y = a.y + b.y;
return retval;
}
#if __CUDA_ARCH__ == 750
__launch_bounds__ (1024, 1)
#else
__launch_bounds__ (1024, 2)
#endif
__global__ void compute_cmvn_stats_kernel(const float *data, int32_t ldd,
int32_t num_frames, int32_t feat_dim,
float *stats, int32_t lds) {
typedef cub::BlockScan<float2, 1024> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
int32_t feat = blockIdx.x;
float2 running_sum = {0.0f, 0.0f};
// for each frame, keep threads alive for cub
for (int32_t r = 0; r < num_frames; r += blockDim.x) {
int32_t rid = r + threadIdx.x;
float val = 0.0f;
if (rid < num_frames) {
// uncoalesced, could transpose data or do some shared memory swizzling...
val = data[rid * ldd + feat];
}
float2 sum = {val, val * val}; // this elements value and value squared
float2 psum; // row prefix sum
float2 total; // total count
BlockScan(temp_storage).InclusiveSum(sum, psum, total);
// offset by running sum
psum = psum + running_sum;
// increase running sum by new total
running_sum = running_sum + total;
// un-coalesced
if (rid < num_frames) {
reinterpret_cast<float2 *>(&stats[rid * lds])[feat] = psum;
}
}
}
__global__ void apply_cmvn_kernel(
int32_t cmvn_window, bool var_norm, bool mean_norm, const float *feat_in,
int32_t ldi, int32_t num_rows, int32_t num_cols,
const float *__restrict__ stats, int32_t lds,
const float *__restrict__ global_stats, int32_t ldg, int32_t global_frames,
const float *__restrict__ speaker_stats, int32_t ldss,
int32_t speaker_frames, float *feat_out, int32_t ldo) {
int32_t r = blockIdx.x;
for (int c = threadIdx.x; c < num_cols; c += blockDim.x) {
float2 frame_stats =
reinterpret_cast<const float2 __restrict__ *>(&stats[r * lds])[c];
float val = feat_in[r * ldi + c];
float window_length = min(r + 1, cmvn_window);
// we have to subtract row r-cmvn_window stats
if (r >= cmvn_window) {
// window starting row
int32_t o = r - cmvn_window;
// stats at the start row of the window that must be removed
float2 ostats =
reinterpret_cast<const float2 __restrict__ *>(&stats[o * lds])[c];
// remove start of the window stats
frame_stats = frame_stats - ostats;
}
// Smooth stats by speaker frames if necessary
float smooth_frames = cmvn_window - window_length;
if (smooth_frames > 0 && speaker_frames > 0) {
float count_from_speaker = min(smooth_frames, (float)speaker_frames);
float speaker_count = speaker_stats[num_cols];
if (count_from_speaker > 0.0) {
float alpha = count_from_speaker / speaker_count;
frame_stats.x += alpha * speaker_stats[c]; // update mean
frame_stats.y += alpha * speaker_stats[ldss + c]; // update variance
window_length += alpha * speaker_count; // update window length
// recompute smooth frames now that we have speaker stats
smooth_frames = cmvn_window - window_length;
}
}
// Smooth stats by global frames if necessary
if (smooth_frames > 0 && global_frames > 0) {
float count_from_global = min(smooth_frames, (float)global_frames);
float global_count = global_stats[num_cols];
if (count_from_global > 0.0) {
float alpha = count_from_global / global_count;
frame_stats.x += alpha * global_stats[c]; // update mean
frame_stats.y += alpha * global_stats[ldg + c]; // update variance
window_length += alpha * global_count; // update window length
}
}
float mean = frame_stats.x / window_length;
float var = frame_stats.y / window_length - mean * mean;
float floor = 1e-20;
if (var < floor) // avoid dividing by zero
var = floor;
if (!var_norm) {
// skip variance normalization
var = 1.0f;
}
if (!mean_norm) {
assert(false);
// skip mean normalization
mean = 0.0f;
}
// shift by mean and scale by variance
feat_out[r * ldo + c] = (val - mean) / sqrtf(var);
}
}
namespace kaldi {
void CudaOnlineCmvn::ComputeFeatures(const CuMatrixBase<BaseFloat> &feats_in,
CuMatrix<BaseFloat> *feats_out) {
int32_t num_frames = feats_in.NumRows();
int32_t feat_dim = feats_in.NumCols();
feats_out->Resize(num_frames, feat_dim, kUndefined);
CuMatrix<float> stats(num_frames, feat_dim * 2, kUndefined);
int threads = 1024;
int blocks = feat_dim;
// compute windowed sum/sum2 prefix sum along column of feats
compute_cmvn_stats_kernel<<<blocks, threads>>>(
feats_in.Data(), feats_in.Stride(), num_frames, feat_dim, stats.Data(),
stats.Stride());
CU_SAFE_CALL(cudaGetLastError());
threads = (feat_dim + 31) / 32 * 32; // round up to 32 threads
if (threads > 1024) threads = 1024;
const CuMatrix<float> &gstats = cmvn_state_.global_cmvn_stats;
const CuMatrix<float> &sstats = cmvn_state_.speaker_cmvn_stats;
int global_frames = opts_.global_frames;
int speaker_frames = opts_.speaker_frames;
if (gstats.NumRows() == 0) global_frames = 0;
if (sstats.NumRows() == 0) speaker_frames = 0;
// apply cmvn
apply_cmvn_kernel<<<num_frames, threads>>>(
opts_.cmn_window, opts_.normalize_variance, opts_.normalize_mean,
feats_in.Data(), feats_in.Stride(), num_frames, feat_dim, stats.Data(),
stats.Stride(), gstats.Data(), gstats.Stride(), global_frames,
sstats.Data(), sstats.Stride(), speaker_frames, feats_out->Data(),
feats_out->Stride());
CU_SAFE_CALL(cudaGetLastError());
}
}
// cudafeat/feature-online-cmvn-cuda.h
//
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Justin Luitjens
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_CUDAFEAT_FEATURE_ONLINE_CMVN_CUDA_H_
#define KALDI_CUDAFEAT_FEATURE_ONLINE_CMVN_CUDA_H_
#include "cudamatrix/cu-matrix.h"
#include "cudamatrix/cu-vector.h"
#include "feat/online-feature.h"
namespace kaldi {
struct CudaOnlineCmvnState {
// The following is the global CMVN stats, in the usual
// format, of dimension 2 x (dim+1), as [ sum-stats count
// sum-sqared-stats 0 ]
CuMatrix<float> global_cmvn_stats;
CuMatrix<float> speaker_cmvn_stats;
CudaOnlineCmvnState(){};
CudaOnlineCmvnState(const OnlineCmvnState &cmvn_state)
: global_cmvn_stats(cmvn_state.global_cmvn_stats),
speaker_cmvn_stats(cmvn_state.speaker_cmvn_stats) {}
CudaOnlineCmvnState(const CudaOnlineCmvnState &cmvn_state)
: global_cmvn_stats(cmvn_state.global_cmvn_stats),
speaker_cmvn_stats(cmvn_state.speaker_cmvn_stats) {}
};
class CudaOnlineCmvn {
public:
CudaOnlineCmvn(const OnlineCmvnOptions &opts, const CudaOnlineCmvnState &cmvn_state)
: opts_(opts), cmvn_state_(cmvn_state){};
~CudaOnlineCmvn(){};
void ComputeFeatures(const CuMatrixBase<BaseFloat> &feats_in,
CuMatrix<BaseFloat> *feats_out);
private:
const OnlineCmvnOptions &opts_;
const CudaOnlineCmvnState &cmvn_state_;
};
}
#endif
......@@ -9,7 +9,7 @@ LDLIBS += $(CUDA_LDLIBS)
BINFILES =
ifeq ($(CUDA), true)
BINFILES += compute-mfcc-feats-cuda
BINFILES += compute-mfcc-feats-cuda apply-cmvn-online-cuda
endif
OBJFILES =
......@@ -20,6 +20,6 @@ ADDLIBS = ../cudafeat/kaldi-cudafeat.a ../cudamatrix/kaldi-cudamatrix.a \
../hmm/kaldi-hmm.a ../feat/kaldi-feat.a \
../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \
../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \
../base/kaldi-base.a
../base/kaldi-base.a ../online2/kaldi-online2.a
include ../makefiles/default_rules.mk
// online2bin/apply-cmvn-online.cc
// Copyright 2014 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "feat/online-feature.h"
#include "cudafeat/feature-online-cmvn-cuda.h"
int main(int argc, char *argv[]) {
try {
typedef kaldi::int32 int32;
using namespace kaldi;
const char *usage =
"Apply online cepstral mean (and possibly variance) computation online,\n"
"using the same code as used for online decoding in the 'new' setup in\n"
"online2/ and online2bin/.'\n"
"The computation is done on the device in serial. "
"spk2utt is not supported.\n"
"\n"
"Usage: apply-cmvn-online-cuda [options] <global-cmvn-stats> <feature-rspecifier> "
"<feature-wspecifier>\n"
"e.g. apply-cmvn-online-cuda 'matrix-sum scp:data/train/cmvn.scp -|' data/train/split8/1/feats.scp ark:-\n";
ParseOptions po(usage);
OnlineCmvnOptions cmvn_opts;
std::string spk2utt_rspecifier;
cmvn_opts.Register(&po);
po.Read(argc, argv);
if (po.NumArgs() != 3) {
po.PrintUsage();
exit(1);
}
g_cuda_allocator.SetOptions(g_allocator_options);
CuDevice::Instantiate().SelectGpuId("yes");
CuDevice::Instantiate().AllowMultithreading();
std::string global_stats_rxfilename = po.GetArg(1),
feature_rspecifier = po.GetArg(2),
feature_wspecifier = po.GetArg(3);
// global_cmvn_stats helps us initialize to online CMVN to
// reasonable values at the beginning of the utterance.
Matrix<double> global_cmvn_stats;
ReadKaldiObject(global_stats_rxfilename, &global_cmvn_stats);
BaseFloatMatrixWriter feature_writer(feature_wspecifier);
int32 num_done = 0;
int64 tot_t = 0;
OnlineCmvnState cmvn_state(global_cmvn_stats);
CudaOnlineCmvnState cu_cmvn_state(cmvn_state);
CudaOnlineCmvn cuda_cmvn(cmvn_opts, cu_cmvn_state);
SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
for (; !feature_reader.Done(); feature_reader.Next()) {
std::string utt = feature_reader.Key();
const Matrix<BaseFloat> &feats = feature_reader.Value();
int32_t numRows = feats.NumRows();
int32_t numCols = feats.NumCols();
CuMatrix<BaseFloat> cu_feats_in(feats);
CuMatrix<BaseFloat> cu_feats_out(numRows, numCols, kUndefined);
Matrix<BaseFloat> normalized_feats(numRows, numCols, kUndefined);
cuda_cmvn.ComputeFeatures(cu_feats_in, &cu_feats_out);
normalized_feats.CopyFromMat(cu_feats_out);
num_done++;
tot_t += feats.NumRows();
feature_writer.Write(utt, normalized_feats);
num_done++;
}
KALDI_LOG << "Applied online CMVN to " << num_done << " files, or "
<< tot_t << " frames.";
return (num_done != 0 ? 0 : 1);
} catch(const std::exception &e) {
std::cerr << e.what();
return -1;
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment