Well, yes, I deserved that reply! And yes, the code does print True. It's not that I disbelieved you... but something is wrong here. Investigation below, thanks to Claude.ai for walking me through it!
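For context, the snippets below assume a setup roughly like this. The embedding dimension, head count, and how e1/e2 were built aren't shown in this post, so the specific numbers and the random placeholder embeddings are just guesses consistent with the shapes that appear below:

import torch
import torch.nn as nn

embed_dim, num_heads, seq_len = 64, 4, 6   # guesses; only seq_len=6 is visible below

# Separate projection layers feeding a stock MultiheadAttention module.
W_q = nn.Linear(embed_dim, embed_dim, bias=False)
W_k = nn.Linear(embed_dim, embed_dim, bias=False)
W_v = nn.Linear(embed_dim, embed_dim, bias=False)
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# e1, e2: embeddings for two six-token sentences, shape (1, 6, embed_dim).
# Random placeholders here; the real ones come from an embedding lookup.
e1 = torch.randn(1, seq_len, embed_dim)
e2 = torch.randn(1, seq_len, embed_dim)

o1, aw1 = mha(W_q(e1), W_k(e1), W_v(e1))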
In [10]: o1[0, :, :3]
Out[10]:
tensor([[ 0.0053, 0.0017, -0.0012],
[ 0.0053, 0.0017, -0.0012],
[ 0.0053, 0.0017, -0.0012],
[ 0.0053, 0.0017, -0.0012],
[ 0.0053, 0.0017, -0.0012],
[ 0.0053, 0.0017, -0.0012]], grad_fn=<SliceBackward0>)
Every token ends up with the same attention output. I'd expect attention(cat, everything) to differ from attention(dog, everything), even without positional encoding.
Further, the attention weights are uniform and identical for both sentences:
In [46]: o1, aw1 = mha(W_q(e1), W_k(e1), W_v(e1))
In [47]: o2, aw2 = mha(W_q(e2), W_k(e2), W_v(e2))
In [48]: aw1.shape
Out[48]: torch.Size([1, 6, 6])
In [49]: aw2.shape
Out[49]: torch.Size([1, 6, 6])
In [50]: aw1
Out[50]:
tensor([[[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]]],
grad_fn=<MeanBackward1>)
In [51]: aw2
Out[51]:
tensor([[[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]]],
grad_fn=<MeanBackward1>)
That is not expected. The reason is that the Linear layers are initialised with very small values, so the query/key dot products all come out close to zero and the softmax collapses to a uniform distribution over the six tokens (hence 1/6 ≈ 0.1667 everywhere).
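A tiny standalone illustration of the collapse (made-up scores, not the real ones):

import torch

# When the pre-softmax scores are all close to zero, softmax returns
# nearly uniform weights, regardless of their relative ordering.
scores = torch.tensor([0.001, -0.002, 0.0015, 0.0008, -0.001, 0.002])
print(torch.softmax(scores, dim=-1))        # ~0.1667 in every slot

# Scale the same scores up and the weights start to differentiate.
print(torch.softmax(100 * scores, dim=-1))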
Trying the same random weights scaled up by a factor of 100:
In [52]: W_q.weight.data *= 100
In [53]: W_k.weight.data *= 100
In [54]: W_v.weight.data *= 100
In [55]: o1, aw1 = mha(W_q(e1), W_k(e1), W_v(e1))
In [56]: o2, aw2 = mha(W_q(e2), W_k(e2), W_v(e2))
In [57]: aw1
Out[57]:
tensor([[[0.2049, 0.1606, 0.1256, 0.1095, 0.1723, 0.2270],
[0.0883, 0.2047, 0.1544, 0.2776, 0.1405, 0.1345],
[0.1196, 0.1719, 0.1831, 0.1541, 0.1374, 0.2339],
[0.1413, 0.2399, 0.1617, 0.2056, 0.1634, 0.0880],
[0.1455, 0.1432, 0.2432, 0.1239, 0.1494, 0.1948],
[0.1897, 0.1817, 0.1920, 0.1478, 0.1618, 0.1270]]],
grad_fn=<MeanBackward1>)
In [58]: aw2
Out[58]:
tensor([[[0.2049, 0.1606, 0.2270, 0.1095, 0.1723, 0.1256],
[0.0883, 0.2047, 0.1345, 0.2776, 0.1405, 0.1544],
[0.1897, 0.1817, 0.1270, 0.1478, 0.1618, 0.1920],
[0.1413, 0.2399, 0.0880, 0.2056, 0.1634, 0.1617],
[0.1455, 0.1432, 0.1948, 0.1239, 0.1494, 0.2432],
[0.1196, 0.1719, 0.2339, 0.1541, 0.1374, 0.1831]]],
grad_fn=<MeanBackward1>)
In [60]: o1[:, :, :5]
Out[60]:
tensor([[[ 0.0145, 0.3128, -0.3659, -0.1884, 0.1724],
[-0.2319, 0.1407, -0.6010, -0.4064, 0.4259],
[-0.3231, 0.1622, -0.6351, -0.1711, 0.4014],
[-0.0596, 0.2610, -0.7388, -0.2987, 0.3214],
[-0.2750, 0.0676, -0.4140, -0.2024, 0.3383],
[-0.1434, 0.0871, -0.3154, -0.0755, 0.3314]]],
grad_fn=<SliceBackward0>)
In [61]: o2[:, :, :5]
Out[61]:
tensor([[[ 0.0145, 0.3128, -0.3659, -0.1884, 0.1724],
[-0.2319, 0.1407, -0.6010, -0.4064, 0.4259],
[-0.1434, 0.0871, -0.3154, -0.0755, 0.3314],
[-0.0596, 0.2610, -0.7388, -0.2987, 0.3214],
[-0.2750, 0.0676, -0.4140, -0.2024, 0.3383],
[-0.3231, 0.1622, -0.6351, -0.1711, 0.4014]]],
grad_fn=<SliceBackward0>)
In [62]: print("Matches: ", torch.allclose(o1, o2, atol=1e-6))
Matches: False
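Also worth noting from the printouts above: o2 looks like o1 with rows 2 and 5 swapped, which is what you'd expect if the two sentences differ only by swapping those two tokens, since attention without positional encoding is permutation-equivariant. A quick check of that guess (the swap positions are my inference from the output, not confirmed):

# Hypothetical: if e2 really is e1 with tokens 2 and 5 swapped, then
# permuting o1's rows the same way should reproduce o2.
perm = torch.tensor([0, 1, 5, 3, 4, 2])
print(torch.allclose(o1[:, perm, :], o2, atol=1e-6))  # expected True if the guess holds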