# Why don't Vision Transformers combine the key and query linear layers into one matrix?

I was studying some Vision Transformer code (e.g. vit-pytorch) and found this in the attention module:

```python
# x is the input
key = nn.Linear(..., bias=False)(x)
query = nn.Linear(..., bias=False)(x)
similar_matrix = torch.matmul(query, key.transpose(...))
```

Since a `Linear` layer without bias is just a matrix multiplication, I get

```
key = K^T @ x
query = Q^T @ x
similar_matrix = query^T @ key = x^T @ (Q @ K^T) @ x
(K and Q are learnable matrices, @ is matmul, ^T is transpose)
```

Here, it seems `Q @ K^T` could be combined into a single matrix to reduce the number of parameters and the amount of computation.

Why isn't this done? Is it because it trains worse?

Let's clear a few things up.

Since `bias=False`, your reasoning is indeed correct: the layer has only a `weight`, and its `bias` is `None`. Let's quickly check:

```python
import torch
import torch.nn as nn

m = nn.Linear(3, 4, bias=False)
n = nn.Linear(3, 4, bias=False)

if hasattr(m, 'weight'):
    print(m.weight)  # prints the weight matrix
else:
    print("doesn't exist")
if hasattr(m, 'bias'):
    print(m.bias)  # None
else:
    print("doesn't exist")
```

Now let's run this snippet, which follows your idea:

```python
x = torch.randint(1, 5, (2, 3)).float()
m = nn.Linear(3, 4, bias=False)
n = nn.Linear(3, 4, bias=False)

K = m.weight
Q = n.weight

k = m(x)  # x @ K.T
q = n(x)  # x @ Q.T
print(q @ k.T)
print(x @ Q.T @ K @ x.T)

# output (values depend on the random initialization; both prints agree)
# tensor([[-4.6655, -5.2234], ...
# tensor([[-4.6655, -5.2234], ...
```

So both expressions are equivalent (confirmed). Note that in this rewrite the number of parameters is exactly the same as in the original approach, since we still keep both weight matrices and have only reordered the multiplication.
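There is also a parameter-count angle worth noting: in a real ViT the per-head dimension is smaller than the model dimension, so `Q` and `K` act as a low-rank factorization. If you instead learned the combined `Q.T @ K` directly as one full matrix, it would be `d_model x d_model` and would actually have *more* parameters. A quick sketch, using ViT-Base-like sizes (768 and 64, assumed here purely for illustration):

```python
import torch.nn as nn

# Assumed ViT-Base-like sizes for illustration: model width 768, head width 64
d_model, d_head = 768, 64

# Two separate low-rank projections, as in the attention code above:
q_proj = nn.Linear(d_model, d_head, bias=False)
k_proj = nn.Linear(d_model, d_head, bias=False)
factored = q_proj.weight.numel() + k_proj.weight.numel()  # 2 * 768 * 64

# A single learned matrix replacing Q.T @ K would be d_model x d_model:
combined = d_model * d_model  # 768 * 768

print(factored, combined)  # 98304 589824
```

So for typical head sizes the two-matrix form is the cheaper one in parameters, not the other way around.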

Which is faster?

```python
from timeit import timeit

# note: globals must be passed as a keyword argument
timeit("m(x) @ n(x).T", globals=globals(), number=1000)
timeit("x @ Q.T @ K @ x.T", globals=globals(), number=1000)

# output
# 1000 loops, best of 3: 50.1 usec per loop
# 1000 loops, best of 3: 31.1 usec per loop
```

So the first form takes slightly more time than the second. The gap persists if we increase the sizes:

```python
x = torch.randint(1, 5, (2000, 3000)).float()
m = nn.Linear(3000, 4000, bias=False)
n = nn.Linear(3000, 4000, bias=False)

K = m.weight
Q = n.weight
timeit("m(x) @ n(x).T", globals=globals(), number=1)
timeit("x @ Q.T @ K @ x.T", globals=globals(), number=1)

# output
# 1 loop, best of 3: 2.93 sec per loop
# 1 loop, best of 3: 2.73 sec per loop
```

So here as well the second form is somewhat faster (about 7% in this run). In principle, then, one could reorder the computation for a modest performance gain, as you guessed.
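For completeness, if one really wanted the fold with `bias=False`, the two weight matrices can be multiplied together once up front. A minimal sketch (not from the original code):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 3)
m = nn.Linear(3, 4, bias=False)
n = nn.Linear(3, 4, bias=False)

# Fold the two weight matrices into a single (3, 3) matrix up front:
W = n.weight.T @ m.weight  # i.e. Q.T @ K

# The folded form matches the two-projection form:
print(torch.allclose(n(x) @ m(x).T, x @ W @ x.T))  # True
```

This trades two skinny matmuls per forward pass for one precomputed square matrix, which is where the small speedup above comes from.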

But the main reason they didn't do this is probably to keep the code open to future changes where a `bias` is used: in that case the first version remains readable and conforms to the existing conventions, while the folded version no longer works. In a small example the fold might be acceptable, but in larger models it gets quite ugly. For better readability and future flexibility, I believe they chose the first one.
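And indeed, as soon as `bias=True`, the identity stops holding, because the bias terms do not fold into a single matrix. A quick check (same small sizes as above, used only for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 3)
m = nn.Linear(3, 4, bias=True)  # now with bias
n = nn.Linear(3, 4, bias=True)
K, Q = m.weight, n.weight

# Extra bias cross-terms remain, so the folded expression no longer matches:
print(torch.allclose(n(x) @ m(x).T, x @ Q.T @ K @ x.T))  # False
```

So any refactor that folds the weights would have to be undone the moment someone turns the bias back on.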