There's PyTorch's FlexAttention, which could maybe make this practical, but currently it's just way too buggy.
Also note that, depending on your model dimensions and sequence lengths, the attention computation often plays only a minor role (maybe 10% of the overall compute or so), and the MLP computation dominates.
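For intuition, here's a back-of-the-envelope version of that claim. The dimensions (d_model=4096, seq_len=2048, 4x FFN) are made-up illustrative values, not from any particular model:

```python
# Rough per-token, per-layer FLOPs (1 multiply-accumulate = 2 FLOPs).
# Illustrative dimensions only.
d, s, m = 4096, 2048, 4  # d_model, seq_len, FFN width multiplier

qkvo = 2 * 4 * d * d      # Q, K, V, O projections (plain matmuls)
attn = 2 * 2 * s * d      # softmax(QK^T)V proper: scores + weighted sum
mlp  = 2 * 2 * m * d * d  # FFN up- and down-projection

total = qkvo + attn + mlp
print(f"attention proper: {attn / total:.0%}")  # ~8% with these numbers
```

Note the attention term grows linearly with sequence length while the matmul terms don't, so at very long contexts it stops being minor.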
Maybe it's better now, but I'd still consider it completely irresponsible to use FlexAttention without a corresponding unit test checking its accuracy against an equivalent eager implementation.
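A minimal sketch of the kind of test I mean; shapes and tolerances are illustrative, not from any particular codebase:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def eager_attention(q, k, v):
    # Plain softmax attention as the reference implementation.
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale
    return scores.softmax(dim=-1) @ v

def test_flex_matches_eager():
    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 4, 128, 64) for _ in range(3))  # (B, H, S, D)

    # No score_mod here, i.e. vanilla attention; real tests should also
    # cover whatever score_mod/block_mask you actually use.
    out_flex = flex_attention(q, k, v)
    out_eager = eager_attention(q, k, v)
    torch.testing.assert_close(out_flex, out_eager, rtol=1e-4, atol=1e-4)

    # The compiled kernel is what you'd actually run in training, so test
    # that too (depending on your PyTorch version this may need a GPU).
    out_compiled = torch.compile(flex_attention)(q, k, v)
    torch.testing.assert_close(out_compiled, out_eager, rtol=1e-4, atol=1e-4)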
Nvidia isn't likely to start releasing updated firmware for an obscure architecture for which there is limited evidence of improvement, and even less adoption.
I've been burnt way too many times by fancy new methods that claimed improvements: I'd spend a ton of effort implementing them, and they'd end up being poop.
Every person working in the field and pushing papers should read this blog post and apply what's written in it: https://kellerjordan.github.io/posts/muon/#discussion-solvin...
Also more integration-like tests, where I take an already pretrained model, load it using an established library (e.g. Huggingface Transformers), load the very same checkpoint into my reimplementation (where I vary the implementation, e.g. swap the attention implementation), and compare the outputs. Funnily enough, I recently even found a bug in HF's Transformers this way: when I updated to a newer version, my previously matching outputs suddenly stopped matching.
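A sketch of that integration-style test. `load_my_model` is a placeholder for however you construct your reimplementation, and "gpt2" just stands in for whatever checkpoint both implementations can load:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "gpt2"  # any checkpoint both implementations support

tok = AutoTokenizer.from_pretrained(name)
ref = AutoModelForCausalLM.from_pretrained(name).eval()
mine = load_my_model(name).eval()  # hypothetical: your reimplementation

ids = tok("The quick brown fox", return_tensors="pt").input_ids
with torch.no_grad():
    ref_logits = ref(ids).logits
    my_logits = mine(ids)

# Exact equality is too strict across different kernels/dtypes;
# compare within a tolerance instead.
torch.testing.assert_close(my_logits, ref_logits, rtol=1e-4, atol=1e-4)
```

Running this once per attention-implementation variant is what catches the kind of silent divergence described above.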