Commit 265f42b7 authored by James Cross, committed by Facebook GitHub Bot

multihead_attention: pre-transpose incremental state (#232)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/232

Though transpose operations are essentially free during PyTorch execution, they can result in costly operations when exported to Caffe2 inference nets via ONNX tracing, especially when applied repeatedly to large tensors.

For this reason, we update `MultiheadAttention` to store its incremental state with shape (bsz, num_heads, seq_len, head_dim), that is, after transposing the projected input. This should result in non-trivially faster exported models without changing the semantics or speed of PyTorch execution.
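The caching pattern described above can be sketched as follows. This is a hypothetical minimal illustration, not the actual fairseq code: the `project` stand-in and the loop structure are assumptions, but it shows the key idea of transposing the projected key once per step and then appending along the time dimension of an already-transposed cache, so no transpose is ever applied to the full cached tensor.

```python
import torch

# Hypothetical sketch (not the real MultiheadAttention implementation):
# keep the incremental state in shape (bsz, num_heads, cached_len, head_dim),
# so each decoding step only transposes the single new timestep.
bsz, num_heads, head_dim = 2, 4, 8


def project(x):
    # Stand-in for the learned key/value projection.
    # x: (seq_len, bsz, num_heads * head_dim)
    seq_len = x.size(0)
    return x.reshape(seq_len, bsz * num_heads, head_dim)


cache = None  # incremental state: (bsz, num_heads, cached_len, head_dim)
for step in range(3):
    x = torch.randn(1, bsz, num_heads * head_dim)  # one new timestep
    k = project(x)                                 # (1, bsz*num_heads, head_dim)
    # Transpose/reshape only the new timestep, then append along dim=2;
    # the cached tensor itself is never transposed again.
    k = k.transpose(0, 1).reshape(bsz, num_heads, 1, head_dim)
    cache = k if cache is None else torch.cat([cache, k], dim=2)

print(tuple(cache.shape))  # (2, 4, 3, 8)
```

Because the cache already has the (bsz, num_heads, seq_len, head_dim) layout that attention bmm operations consume, an ONNX trace of the incremental step contains no repeated transpose over the growing cached tensor.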

Reviewed By: myleott

Differential Revision: D10186506

fbshipit-source-id: 8a42712423ee767ea49ed88d2a4653f900d14fba
parent b9e29a47