Why don't Vision Transformers combine the key and query linear layers into one matrix?

I was studying some vision transformer code (e.g. vit-pytorch) and found the following in an attention module:

# x is the input
key = nn.Linear(..., bias=False)(x)
query = nn.Linear(..., bias=False)(x)
similar_matrix = torch.matmul(query, key.transpose(...))

Because a Linear layer with bias=False is just a matrix multiplication, I get

key = K^T @ x
query = Q^T @ x
similar_matrix = query^T @ key = x^T @ (Q @ K^T) @ x
(K and Q are the learnable weight matrices, @ is matmul, ^T is transpose)

Since Q @ K^T appears here as a single product, I think the two layers could be combined into one matrix to reduce the number of parameters and the amount of computation (see the sketch below).

Why isn't this done? Is it because the model doesn't train as well?
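
For concreteness, here is a minimal sketch of what I mean (the shapes and names are illustrative, not the actual vit-pytorch code):

import torch
import torch.nn as nn

d, d_k, n_tokens = 8, 4, 5              # embedding dim, head dim, number of tokens (illustrative)
x = torch.randn(n_tokens, d)

to_q = nn.Linear(d, d_k, bias=False)    # query projection, weight Q of shape (d_k, d)
to_k = nn.Linear(d, d_k, bias=False)    # key projection,   weight K of shape (d_k, d)

# usual two-projection form: (x Q^T)(x K^T)^T = x Q^T K x^T
scores_two = to_q(x) @ to_k(x).T

# "fused" form: precompute a single (d, d) matrix W = Q^T K
W = to_q.weight.T @ to_k.weight
scores_fused = x @ W @ x.T

print(torch.allclose(scores_two, scores_fused, atol=1e-5))  # True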

1 answer

  • answered 2022-05-06 06:00 user2736738

    Let's clear a few things up.

    Since bias=False, your interpretation is correct: the layer has only a weight, and its bias is None. Let's quickly check:

    import torch
    import torch.nn as nn
    
    m = nn.Linear(3, 4, bias=False)
    n = nn.Linear(3, 4, bias=False)
    
    if hasattr(m, 'weight'):
        print(m.weight) # prints the weight
    else:
        print("doesn't exist")
    if hasattr(m, 'bias'):
        print(m.bias) # None
    else:
        print("doesn't exist")
    

    Now let's run this snippet, which follows your idea:

    x = torch.randint(1, 5, (2, 3)).float()
    m = nn.Linear(3, 4, bias=False)
    n = nn.Linear(3, 4, bias=False)

    K = m.weight
    Q = n.weight

    k = m(x)  # x @ K.T
    q = n(x)  # x @ Q.T
    print(q @ k.T)
    print(x @ Q.T @ K @ x.T)
    
    
    # output (your numbers will differ because the weights are randomly initialized, but the two prints always match)
    tensor([[-4.6655, -5.2234],
            [-6.8665, -7.6535]], grad_fn=<MmBackward0>)
    tensor([[-4.6655, -5.2234],
            [-6.8665, -7.6535]], grad_fn=<MmBackward0>)
    

    So the two expressions are equivalent (confirmed). Note that the number of parameters is exactly the same either way, since both forms train the same two weight matrices K and Q.
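
    To make the parameter count concrete, a quick check (continuing from the snippet above):

    # both orderings train exactly the same tensors: m.weight (K) and n.weight (Q)
    params_two_calls = sum(p.numel() for p in m.parameters()) + sum(p.numel() for p in n.parameters())
    params_reordered = K.numel() + Q.numel()   # K and Q are m.weight and n.weight themselves
    print(params_two_calls, params_reordered)  # 24 24 for the (3, 4) layers above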

    Which one is faster? (The timings below use IPython's %timeit.)

    timeit("m(x) @  n(x).T", globals())
    timeit("x@Q.T @  K @x.T", globals())
    
     #Output
    1000 loops, best of 3: 50.1 usec per loop
    1000 loops, best of 3: 31.1 usec per loop
    

    So the first form takes slightly more time than the second. The gap persists if we increase the sizes:

    x = torch.randint(1, 5, (2000, 3000)).float()
    m = nn.Linear(3000, 4000, bias=False)
    n = nn.Linear(3000, 4000, bias=False)

    K = m.weight
    Q = n.weight
    %timeit m(x) @ n(x).T
    %timeit x @ Q.T @ K @ x.T
    # output
    1 loops, best of 3: 2.93 sec per loop
    1 loops, best of 3: 2.73 sec per loop
    

    Here as well the second form is somewhat faster (2.73 s vs 2.93 s per loop). So purely for speed, one could indeed use the second form, as you guessed.

    But the main reason they didn't do this is most likely to keep the code open to future changes where a bias is used; with biases the product no longer collapses into a single matrix (see the sketch below), and two explicit projections are more readable and match how attention is conventionally written. In a toy example the fused form is fine, but in a larger model it quickly becomes hard to follow. For readability and future flexibility, I believe they chose the first form.
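
    For example, here is a quick sketch (toy shapes, not the library code) of why a bias breaks the fusion: with biases, n(x) @ m(x).T expands to (x Q^T + b_q)(x K^T + b_k)^T, which adds cross terms that no single matrix sandwiched between x and x^T can reproduce.

    import torch
    import torch.nn as nn

    x = torch.randn(2, 3)
    m = nn.Linear(3, 4, bias=True)   # key projection, now with a bias
    n = nn.Linear(3, 4, bias=True)   # query projection, now with a bias

    K, Q = m.weight, n.weight

    # without the bias terms the two expressions below would match (as shown earlier);
    # with biases they no longer do
    print(torch.allclose(n(x) @ m(x).T, x @ Q.T @ K @ x.T))  # False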
