213 points Philpax | 7 comments
1. imjonse ◴[] No.42170648[source]
I don't think the first code example should work (it indeed says false here).

When given a permuted sequence, the attention output is permuted in the same way, not identical. Positional encodings are needed because, without them, two tokens produce the same value in the final attention matrix regardless of their absolute and relative positions; that is enough to lose a lot of meaning.
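
A minimal sketch of that claim, in plain Python rather than the post's PyTorch code. The toy `self_attention` below uses identity projections (q = k = v = token), and the names `x`, `perm` and `close` are made up for illustration, so treat it as an assumption-laden demo rather than the article's example:

```python
import math
import random

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(tokens):
    # Toy single-head attention, no positional encoding,
    # identity projections (q = k = v = token): an assumption for brevity.
    d = len(tokens[0])
    out = []
    for q in tokens:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in tokens]
        weights = softmax(scores)
        out.append([sum(w * v[i] for w, v in zip(weights, tokens)) for i in range(d)])
    return out

def close(a, b, tol=1e-9):
    # Element-wise comparison of two lists of vectors.
    return all(abs(p - q) <= tol for ra, rb in zip(a, b) for p, q in zip(ra, rb))

random.seed(0)
x = [[random.gauss(0, 1) for _ in range(4)] for _ in range(6)]  # 6 tokens, dim 4
perm = [0, 1, 5, 3, 4, 2]          # swap two tokens, like "dog" and "cat"
x_perm = [x[i] for i in perm]

out = self_attention(x)
out_perm = self_attention(x_perm)

same_order = close(out, out_perm)                         # outputs compared as-is
permuted_match = close([out[i] for i in perm], out_perm)  # outputs permuted the same way
print("identical:", same_order, "| equal after permuting:", permuted_match)
```

The two outputs only agree once the first one is reordered with the same permutation, which is the equivariance described above.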

replies(2): >>42170863 #>>42172280 #
2. FL33TW00D ◴[] No.42170863[source]
The first code example says False because of the high precision; I've updated the example.
replies(1): >>42170879 #
3. jmmcd ◴[] No.42170879[source]
But u/imjonse's reasoning seems right. I haven't run either version of the code, but when reading it I expected that to be False. The output is still a list with an order.

the dog chased the cat: position 1 in the output is attention(dog, everything)

the cat chased the dog: position 1 in the output is attention(cat, everything)

replies(1): >>42170914 #
4. FL33TW00D ◴[] No.42170914{3}[source]
Run the code and look at the values!
replies(1): >>42171962 #
5. jmmcd ◴[] No.42171962{4}[source]
Well, yes, I deserved that reply! And yes the code is printing True. It's not that I disbelieved you... but something is wrong here. Investigation below, thanks to Claude.ai for walking me through it!

    In [10]: o1[0, :, :3]
    Out[10]:
    tensor([[ 0.0053,  0.0017, -0.0012],
        [ 0.0053,  0.0017, -0.0012],
        [ 0.0053,  0.0017, -0.0012],
        [ 0.0053,  0.0017, -0.0012],
        [ 0.0053,  0.0017, -0.0012],
        [ 0.0053,  0.0017, -0.0012]],
       grad_fn=<SliceBackward0>)
Every token has the same attention output. I'd expect attention(cat, everything) to differ from attention(dog, everything), even without positional encoding.

Further, the attention weights are uniform and identical for both sentences:

    In [46]: o1, aw1 = mha(W_q(e1), W_k(e1), W_v(e1))
    In [47]: o2, aw2 = mha(W_q(e2), W_k(e2), W_v(e2))
    In [48]: aw1.shape
    Out[48]: torch.Size([1, 6, 6])
    In [49]: aw2.shape
    Out[49]: torch.Size([1, 6, 6])
    In [50]: aw1
    Out[50]:
    tensor([[[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]]],
       grad_fn=<MeanBackward1>)

    In [51]: aw2
    Out[51]:
    tensor([[[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]]],
       grad_fn=<MeanBackward1>)
That is not expected. It's because the Linear layers are initialised with such small values that the query-key scores are all close to zero, so the softmax collapses to a uniform distribution.
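
A rough sketch of the collapse, using stdlib Python only. The `scores` values and the 1e-5 scale factor are hypothetical, picked just to show the effect; they are not taken from the post's code:

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical query-key scores for one query over six tokens.
scores = [0.3, -1.2, 0.8, 0.1, -0.5, 1.5]

large = softmax(scores)                      # scores at an ordinary scale
tiny = softmax([s * 1e-5 for s in scores])   # same scores, shrunk towards zero

print("normal scale:", [round(w, 4) for w in large])
print("tiny scale:  ", [round(w, 4) for w in tiny])
```

With weights that close to 1/6 each, every query returns essentially the plain mean of the value vectors, which would explain why every row of o1 is the same and why swapping two tokens changes nothing.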

Trying random weights on a larger scale:

    In [52]: W_q.weight.data *= 100
         W_k.weight.data *= 100
         W_v.weight.data *= 100

    In [55]: o1, aw1 = mha(W_q(e1), W_k(e1), W_v(e1))
    In [56]: o2, aw2 = mha(W_q(e2), W_k(e2), W_v(e2))
    In [57]: aw1
    Out[57]:
    tensor([[[0.2049, 0.1606, 0.1256, 0.1095, 0.1723, 0.2270],
         [0.0883, 0.2047, 0.1544, 0.2776, 0.1405, 0.1345],
         [0.1196, 0.1719, 0.1831, 0.1541, 0.1374, 0.2339],
         [0.1413, 0.2399, 0.1617, 0.2056, 0.1634, 0.0880],
         [0.1455, 0.1432, 0.2432, 0.1239, 0.1494, 0.1948],
         [0.1897, 0.1817, 0.1920, 0.1478, 0.1618, 0.1270]]],
       grad_fn=<MeanBackward1>)

    In [58]: aw2
    Out[58]:
    tensor([[[0.2049, 0.1606, 0.2270, 0.1095, 0.1723, 0.1256],
         [0.0883, 0.2047, 0.1345, 0.2776, 0.1405, 0.1544],
         [0.1897, 0.1817, 0.1270, 0.1478, 0.1618, 0.1920],
         [0.1413, 0.2399, 0.0880, 0.2056, 0.1634, 0.1617],
         [0.1455, 0.1432, 0.1948, 0.1239, 0.1494, 0.2432],
         [0.1196, 0.1719, 0.2339, 0.1541, 0.1374, 0.1831]]],
       grad_fn=<MeanBackward1>)

    In [60]: o1[:, :, :5]
    Out[60]:
    tensor([[[ 0.0145,  0.3128, -0.3659, -0.1884,  0.1724],
         [-0.2319,  0.1407, -0.6010, -0.4064,  0.4259],
         [-0.3231,  0.1622, -0.6351, -0.1711,  0.4014],
         [-0.0596,  0.2610, -0.7388, -0.2987,  0.3214],
         [-0.2750,  0.0676, -0.4140, -0.2024,  0.3383],
         [-0.1434,  0.0871, -0.3154, -0.0755,  0.3314]]],
       grad_fn=<SliceBackward0>)

    In [61]: o2[:, :, :5]
    Out[61]:
    tensor([[[ 0.0145,  0.3128, -0.3659, -0.1884,  0.1724],
         [-0.2319,  0.1407, -0.6010, -0.4064,  0.4259],
         [-0.1434,  0.0871, -0.3154, -0.0755,  0.3314],
         [-0.0596,  0.2610, -0.7388, -0.2987,  0.3214],
         [-0.2750,  0.0676, -0.4140, -0.2024,  0.3383],
         [-0.3231,  0.1622, -0.6351, -0.1711,  0.4014]]],
       grad_fn=<SliceBackward0>)

    In [62]: print("Matches: ", torch.allclose(o1, o2, atol=1e-6))
    Matches:  False
replies(1): >>42172097 #
6. FL33TW00D ◴[] No.42172097{5}[source]
Hm! Very interesting! Thank you for taking the time to debug that.

I'm going to have to think hard about how to rewrite the motivating example to explain this best.

Edit: updated the post, thanks for pointing out the pernicious init values!

7. aconz2 ◴[] No.42172280[source]
To add on, since this took me a while to understand: for a single token, self-attention is permutation invariant because we take the qK-weighted sum (one query dotted with all the keys) of all the values; that sum is what gives the invariance, because + is commutative. For all the tokens together, though, the mha output matrix is not invariant but equivariant: the same permutation you applied to the input tokens is applied to the output matrix. A more useful example might be to take one position, like the last one, and compute its mha output for every permutation of the previous tokens; those should all be the same.
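
A small sketch of that last suggestion, again in plain Python with a toy single-head `attend` helper (identity projections, hypothetical names, random vectors standing in for token embeddings):

```python
import itertools
import math
import random

def attend(query, context):
    # Attention output for one query over a list of context vectors
    # (keys = values = context): a simplification for illustration.
    d = len(query)
    scores = [sum(a * b for a, b in zip(query, k)) / math.sqrt(d) for k in context]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[i] for e, v in zip(exps, context)) for i in range(d)]

random.seed(0)
tokens = [[random.gauss(0, 1) for _ in range(4)] for _ in range(5)]
prev, last = tokens[:-1], tokens[-1]

reference = attend(last, prev + [last])

# Recompute the last position's output for every ordering of the previous tokens.
all_same = all(
    all(abs(a - b) <= 1e-9 for a, b in zip(attend(last, list(p) + [last]), reference))
    for p in itertools.permutations(prev)
)
print("last-position output unchanged under all permutations:", all_same)
```

The comparison uses a small tolerance because floating-point sums can differ in the last bits when the terms are added in a different order.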