
I’ve seen a lot of convoluted tutorials on attention, but nothing made it click for me more than understanding it as mixing a projected version of your tokens across a sequence, weighted by their "softmaxed" unnormalized correlations in a different space.
Let's strip that back a bit and assess each part:
Projected version of your tokens
We're not mixing the original copy of the tokens, but the Value projected versions.
Unnormalized
correlation is measured by the dot product, which is unbounded, as opposed to lying between [-1, 1]
Correlations
the NxN matrix of dot products between every pair of vectors. These are the mixing weights used to create the new tokens as weighted sums.
In a different space
tokens are first projected into another space by Query/Key projections when calculating the correlations
Softmaxed
As opposed to using the values of the unnormalized correlations directly as is, we apply the softmax kernel which normalizes and "sparsifies" the weights.
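Putting those pieces back together, a single head of attention is only a few lines. Here's a rough numpy sketch of the whole thing (variable names are just for illustration; the usual 1/sqrt(d) scaling and any masking are left out, since this post doesn't touch them):

```python
import numpy as np
from scipy.special import softmax

def attention(X, Wq, Wk, Wv):
    # X: (N, d) sequence of tokens; Wq, Wk, Wv: (d, d) learned projections
    Q, K, V = X @ Wq, X @ Wk, X @ Wv   # "projected versions of your tokens", "in a different space"
    C = Q @ K.T                        # "unnormalized correlations": the NxN matrix of dot products
    A = softmax(C, axis=-1)            # "softmaxed": each row normalized, small values squashed toward 0
    return A @ V                       # "mixing across the sequence": weighted sums of the Value tokens
```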
Let's try mixing your tokens across a sequence weighted by their unnormalized correlations. Sans the "different space", "projected version of your tokens", "softmaxed" parts. We'll then build up to attention from there.
If you'd like to follow along, all experiments are performed in this repo.
1) Boneless Attention. Just correlation
Before drawing out this minimal attention with matrices and vectors, let's look at it in a symbolic way.

Each one of these arrows/vectors is a token. The vector is in 2d space and can be described with an x coordinate and a y coordinate. The full set of vectors is our sequence.
Now, we'd like to create a new set of arrows aka vectors aka tokens, by mixing our existing ones together, but the ratios we mix them at should correspond to how similar or related they are. How can we measure how correlated two tokens are? By how closely they point in the same direction.

Here, we’ve plotted the similarities of the first token with the others. The opacity of the red arrows represents how similar the first token is to the one it points to. Note that it also points to itself.
Now with our similarities in hand, we can mix the tokens to obtain a new token corresponding to the first in this sequence.

We can illustrate that a given token should contribute less to the output by making it smaller. When we add all the vectors up tip to tail we'll see it has less of an effect.
We then repeat the process for all the other vectors in our sequence.
Now we're ready for a bit more mathy notation.
First, we need our correlation matrix.

Next we will "mix" our tokens, as a weighted sum with weights determined by the values in the correlation matrix.
Because we have 3 vectors, each new vector will be a mix of those 3. Looking at the first one (on the left cell of the blue row), we have:
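new x1 = c11·x1 + c12·x2 + c13·x3

where x1, x2, x3 are the original tokens and c11, c12, c13 are the entries of the first row of the correlation matrix.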
Now doing them all simultaneously, as a matrix multiplication.
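In numpy, the whole thing is just two matrix multiplications (the three 2d tokens below are arbitrary values, only meant to mirror the diagrams):

```python
import numpy as np

# three 2d tokens as rows of X; values chosen arbitrarily
X = np.array([[ 1.0,  0.2],
              [ 0.8,  0.6],
              [-0.1,  1.0]])

C = X @ X.T      # 3x3 correlation matrix: every pairwise dot product
new_X = C @ X    # each new token is a weighted sum of all tokens, weighted by its row of C

print(np.round(C, 2))
print(np.round(new_X, 2))
```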

Notice that the vectors didn't change a ton, aside from growing a bit in length and shifting slightly in direction, a product of the fact that they interacted strongly with themselves but not as much with their peers.
If we repeat this process over and over again, vectors that are less than 90 degrees apart from each other will increasingly move closer to each other.
This stripped back form of attention, but also attention somewhat in general, I like to describe as a "consensus" algorithm. It encourages the vectors to all agree on a direction, which is a smoothed, mixed result of all the initial vectors. This is a parameterless (boneless) attention. There are no learned weights.
However, because the raw correlation weights don't sum to one, the vectors also explode in length as we iterate, which makes for not such an exciting operation.
Next up, enter softmax.
2) Boneless Attention. Softmaxed correlation
Same as before, except now we'll apply softmax normalization to the correlation matrix. This makes the weights sum to one, avoids negative values, and also has the capability of "sparsifying" the operation, often driving many values close to 0.
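In code, the only change from the previous version is one softmax applied to each row of the correlation matrix (a minimal sketch):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def boneless_attention(X):
    C = X @ X.T                 # raw correlations
    A = softmax(C, axis=-1)     # each row now sums to 1, small correlations get squashed toward 0
    return A @ X                # mix the original tokens with the normalized weights
```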

Now, despite having no actual parameters, I'd argue this confers some practical value.
One example is where I used parameterless attention as a clustering algorithm for colors in an image.
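Roughly, the idea goes something like the sketch below (a hypothetical reimplementation, not the repo's actual code): treat each pixel's RGB value as a token and iterate the parameterless mixing step; similar colors pull each other together and the image settles into a handful of flat color clusters. The temperature scaling here is my own addition to sharpen the softmax.

```python
import numpy as np

def cluster_colors(image, steps=10, temperature=20.0):
    # hypothetical sketch: iterated parameterless attention over pixel colors
    # (for a real image you'd subsample or downscale first, since C is pixels x pixels)
    X = image.reshape(-1, 3).astype(np.float32)       # pixels as tokens in RGB space, values in [0, 1]
    for _ in range(steps):
        C = temperature * (X @ X.T)                   # scaled pairwise correlations
        C -= C.max(axis=-1, keepdims=True)            # numerical stability
        A = np.exp(C)
        A /= A.sum(axis=-1, keepdims=True)            # softmax over each row
        X = A @ X                                     # each pixel drifts toward similar pixels
    return X.reshape(image.shape)
```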

Another is that you can actually train transformer networks like language models using this kind of layer. The performance is not nearly as satisfying, but it's also very far from terrible; raw correlation alone is at least doing something. (This is 1000 steps on GPT2, just as a proof of concept, so take it with a massive grain of salt.)

Now, let's add in the Value projected version of our tokens.
3) Attention. Softmaxed correlation, weighted sum over Value projected versions
Here, as opposed to using the same version of our original vectors both for determining the correlation weights and as the tokens that get mixed together, we'll use a separate, Value-projected set for the vectors being mixed.
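As a sketch, the only change is which set of vectors gets mixed (Wv being a learned, randomly initialized projection):

```python
import numpy as np
from scipy.special import softmax

def value_only_attention(X, Wv):
    C = X @ X.T                  # correlations still measured between the original tokens...
    A = softmax(C, axis=-1)
    V = X @ Wv                   # ...but the Value-projected tokens are what actually get mixed
    return A @ V
```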

This gives us a bit more of an uplift. Interestingly, the loss initially plummets faster than the baseline's, perhaps because correlation in the native space is a stronger starting point than correlation in a space created by randomly initialized projections. But once the baseline's Query and Key weights have had time to adjust, it quickly surpasses this modded version.

4) Bonafide Attention. Correlation can be taken in another space.
Here now, we introduce the Query and Key learnable projection matrices. When we measure our correlations between tokens now, we'll do them between two different versions of the original set.
Why would we want to do this?
One of the problems with our stripped-back attention is that it is guided towards consensus, but not much beyond that. We also found that the self-interaction was always very high, loosely predetermining part of the outcome before we even proceed.
By projecting to another space, we can more finely decide which tokens should correlate with each other, increasing the expressive power of attention.
This operation is WAY more flexible and expressive in higher dimensions, but we can imagine this simple 2d case where we have 4 vectors. We could apply a stretching transformation as such, increasing the similarity between nearby vectors and further decreasing it between ones farther away, letting us modify correlation weights and thus the output of attention.
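A tiny numeric version of the same idea (the vectors and the stretch matrix below are arbitrary): a single linear stretch can change which tokens correlate most strongly with each other.

```python
import numpy as np

a = np.array([1.0, 0.2])
b = np.array([0.9, 0.4])
c = np.array([0.2, 1.0])
M = np.diag([0.2, 3.0])                        # arbitrary stretch: squash x, expand y

print(a @ b, a @ c)                            # original space: a correlates more with b
print((M @ a) @ (M @ b), (M @ a) @ (M @ c))    # stretched space: a now correlates more with c
```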


Now for fun, let's see what happens if we double the initial values of the correlations prior to feeding them to softmax. This should help give a feel of the advantages of softmax and its sparsifying properties. For some more interesting tidbits see here.
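A quick numeric example of that sharpening (the correlation values here are made up):

```python
import numpy as np
from scipy.special import softmax

c = np.array([2.0, 1.0, 0.5, -1.0])   # one made-up row of correlations

print(np.round(softmax(c), 3))        # largest weight is about 0.61
print(np.round(softmax(2 * c), 3))    # doubling sharpens it to about 0.84, small weights pushed toward 0
```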

These then become the weights by which the value tokens (green in the previous diagrams) are mixed.
At this point the tutorial is complete: this is attention! The remainder of this post is further experiments on attention, leveraging the fact that it is somewhat similar to the idea of calculating raw correlations.
5) Bonafide Attention. Initialize Query/Key projections to identity function
Seeing the initial fast loss drop of the previous method, it seemed worthwhile to try initializing the Query and Key projection matrices to the identity function, but, unlike pure "just correlation" attention, letting them adapt freely from there. Perhaps correlation is a good place to start from?
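A minimal PyTorch sketch of that initialization (module and variable names here are illustrative, not the repo's actual code):

```python
import torch
import torch.nn as nn

class IdentityInitQK(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        # start out as plain correlation (Q = x, K = x), but remain trainable
        with torch.no_grad():
            self.wq.weight.copy_(torch.eye(d_model))
            self.wk.weight.copy_(torch.eye(d_model))

    def forward(self, x):
        return self.wq(x), self.wk(x)
```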
This might be our strongest contender so far, though the baseline still surpasses it.

I imagine a counterargument here is that this kind of initialization simply doesn't optimize well, and that we use random initialization for a reason.
There's still more we can try though. What is a common way we can include the identity function? Residuals!
6) Bonafide Attention. Make Query/Key projections residual layers
Here now we change the query and key calculation from
Q = wQ(x), K = wK(x)
to
Q = wQ(x) + x, K = wK(x) + x
This is motivated by correlation making for a useful starting bias, but also wanting to give adequate freedom to the optimization process.
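In PyTorch terms, the change amounts to one skip connection per projection (again an illustrative sketch):

```python
import torch.nn as nn

class ResidualQK(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        # Q = wQ(x) + x, K = wK(x) + x: a learned projection plus the raw-correlation bias
        return self.wq(x) + x, self.wk(x) + x
```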
Here the two were run for about 4k steps. They are approximately neck and neck, with residual attention seemingly having a slight edge.

A Discussion on Low Rank Attention Weights
A point I’ve seen a few papers grapple with is the rank we observe in the weights of attention layers. This is approximated by taking the singular value decomposition and seeing what percentage of the singular values are "properly large".
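One reasonable way to compute that kind of measurement (not necessarily the exact metric used in the papers below): count how many singular values it takes to cover, say, 95% of the total singular value mass.

```python
import numpy as np

def effective_rank(W, energy=0.95):
    # number of singular values needed to reach `energy` fraction of the total (a rough proxy for rank)
    s = np.linalg.svd(W, compute_uv=False)       # singular values, sorted in descending order
    cumulative = np.cumsum(s) / s.sum()
    return int(np.searchsorted(cumulative, energy) + 1)
```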
Low-rank layers are generally seen as a bad thing. They can suggest dimensional collapse and an underutilization of the network's capacity. But is this always true? When we begin training, sampling our weights from a normal distribution, our matrices start out 100% full rank, so the rank can really only go down from there regardless.
An instance here is the knowledge that CLIP embeddings are low rank: a few axes explain the majority of the variation in the output, making them feasibly much more compressible and lower dimensional than they appear. Here we have the singular values extracted from the covariance matrix of 500,000 OpenCLIP VITG embeddings.
The spectrum very rapidly approaches 0.

We can also see how this affects the output by decoding the embeddings back into images, after pruning out smaller singular vectors. This is an example from the DALLE2 paper.

Starting from the top 320 dimensions, we can prune away the last 120 without a substantial change to the output. Though once we get down to the 40 to 80 range, we can see these larger dimensions are much more vital.
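The pruning itself is simple; here's a sketch of the kind of truncation being described, keeping only the top-k principal directions of a batch of embeddings before handing them to a decoder (names are illustrative):

```python
import numpy as np

def truncate_embeddings(E, k):
    # project each embedding (a row of E) onto the top-k principal directions of the batch
    mean = E.mean(axis=0, keepdims=True)
    _, _, Vt = np.linalg.svd(E - mean, full_matrices=False)   # rows of Vt are the principal directions
    P = Vt[:k].T @ Vt[:k]                                      # projector onto the top-k subspace
    return (E - mean) @ P + mean
```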
The Modality Gap paper looks in depth at CLIP and the factors around contrastive learning that induce this low-rankness.
This paper then also finds a way to address it, recovering full use of the dimensions and slightly increasing performance.
So at least for embedding representations, low rankness seems undesirable but evidently far from catastrophic.
The paper Weight decay induces low-rank attention layers has a scary-sounding title. However, the results are a bit surprising.
They measure low-rankness by the observed 95% saturation point. We can see that higher values of weight decay, indicated by lambda, very much reduce the rank of attention layers. I'm tempted to say the charts on the right are particularly interesting, in that the pseudo-rank drops and then increases over the course of training. However, this measurement doesn't imply that dimensions have fully collapsed and then recovered; it could just as much be our largest dimensions decreasing as the small ones increasing. Nevertheless, it'd be interesting to see that chart continued a bit further to the right.

Despite finding this massive drop in rank, performance does not plummet in the same manner. In fact, performance seems to improve a bit in some cases, and this effect extends to MLP layers as well. The benefit could be explained by the low-rankness itself, or it might be some other factor, such as weight decay shrinking the magnitude of the weights and thereby regularizing. It could also be that "pseudo-rank" isn't telling us the whole story. Regardless, the drop in rank doesn't come with the loss of performance we would expect.

So perhaps this is not something we should aim to avoid but something we can lean into? Maybe attention layers, as they are, are overparameterized by dense matrices, and possibly this suggests they are a good place to use low-rank factorizations in the style of LoRA.
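For instance, instead of a dense d x d projection, each of the attention projections could be factored through a much smaller bottleneck. A hypothetical sketch of the general shape of this (not DeepSeek's actual formulation):

```python
import torch.nn as nn

class LowRankProjection(nn.Module):
    # approximates a dense d x d weight as up(down(x)), with rank r << d and far fewer parameters
    def __init__(self, d_model, rank):
        super().__init__()
        self.down = nn.Linear(d_model, rank, bias=False)
        self.up = nn.Linear(rank, d_model, bias=False)

    def forward(self, x):
        return self.up(self.down(x))
```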
And this is kind of exactly what DeepSeek's Multi-head Latent Attention does, both reducing computation a bit and shrinking the size of the KV cache.
This was a similar motivation for the experiments on correlation done in this post. If correlation can already do alright, do we need such a heavily parameterized operation?