multihead_attention: pre-transpose incremental state (#232)
Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/232

Though transpose operations are essentially free during PyTorch execution, they can become costly when exported to Caffe2 inference nets via ONNX tracing, especially when applied repeatedly to large tensors. For this reason, we update `MultiheadAttention` to store its incremental state with shape (bsz, num_heads, seq_len, head_dim), i.e., after transposing the projected input. This should yield non-trivially faster exported models without changing the semantics or speed of PyTorch execution.

Reviewed By: myleott

Differential Revision: D10186506

fbshipit-source-id: 8a42712423ee767ea49ed88d2a4653f900d14fba
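A minimal sketch of the idea (the function and variable names here are illustrative, not the actual `MultiheadAttention` internals): the projected key/value step is transposed once into (bsz, num_heads, seq_len, head_dim) before caching, so each decoding step appends to the cache with a cheap concat instead of re-transposing the whole accumulated state.

```python
import torch

bsz, num_heads, head_dim = 2, 4, 8

def update_incremental_state(cache, new_step):
    """Hypothetical cache update; new_step has shape
    (seq_len, bsz, num_heads * head_dim), as produced by the k/v projections."""
    seq_len = new_step.size(0)
    # Transpose ONCE into the cached layout (bsz, num_heads, seq_len, head_dim).
    x = new_step.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
    x = x.view(bsz, num_heads, seq_len, head_dim)
    if cache is None:
        return x
    # The cached state is already pre-transposed, so appending along the
    # time axis needs no further transpose of the (growing) cache tensor.
    return torch.cat([cache, x], dim=2)

cache = None
for _ in range(3):
    step = torch.randn(1, bsz, num_heads * head_dim)  # one decoding step
    cache = update_incremental_state(cache, step)

print(tuple(cache.shape))  # (2, 4, 3, 8)
```

Storing the cache in the post-transpose layout moves the transpose off the per-step hot path, which is what makes the ONNX-traced Caffe2 graph cheaper.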