This is another potential improvement to the transformer architecture from Facebook (the other one that comes to mind, from the same authors, is https://arxiv.org/abs/2405.18719). But it comes with a major problem that might not be obvious at first glance: it's just not usable in practice without a ton of work. It modifies the innards of the attention mechanism, so it's incompatible with Flash Attention (or any other optimized attention library), and you don't want to train anything beyond toy models without Flash Attention; the performance hit is just way too big.
There's PyTorch's FlexAttention, which could maybe make this practical (a rough sketch of how it would plug in is below), but currently it's just way too buggy.
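For what it's worth, here's a minimal sketch of where a custom attention tweak would hook into FlexAttention's score_mod (PyTorch >= 2.5, CUDA assumed). The bias here is a made-up placeholder, not the actual modification from the paper; it's only meant to show the shape of the API:

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    # Hypothetical additive bias -- a stand-in for whatever score
    # modification the paper needs, just to show where it plugs in.
    # score_mod receives the raw attention score plus batch, head,
    # query-index, and key-index tensors, and returns a new score.
    def toy_bias(score, b, h, q_idx, kv_idx):
        return score + (q_idx - kv_idx) * 0.01

    B, H, S, D = 2, 4, 128, 64
    q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))

    # Compiling fuses the score_mod into a single attention kernel,
    # which is what could recover most of the Flash Attention speed.
    compiled = torch.compile(flex_attention)
    out = compiled(q, k, v, score_mod=toy_bias)

That torch.compile step is the whole point: in eager mode the score_mod runs as ordinary (slow) PyTorch ops, and it's the kernel fusion that would make a modified attention trainable at scale.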