Taking rows of a matrix given a batch of indices - Python

How can I extract columns of a matrix given a batch of indices (in Python, using JAX)?

import jax.numpy as jnp

i = [[0, 1], [1, 2], [2, 3]]
a = jnp.array([[1, 2, 3, 4], [2, 3, 4, 5]])

def extract(A, idx):
    # Select the columns of A listed in idx
    return A[:, idx]

B = extract(a, i)

I expect to get this result (where the matrices are stacked):

B = [[[1, 2],
      [2, 3]],

     [[2, 3],
      [3, 4]],

     [[3, 4],
      [4, 5]]]

And NOT:

B_ = [[[1, 2],
       [2, 3],
       [3, 4]],

      [[2, 3],
       [3, 4],
       [4, 5]]]

Here each row of a is indexed separately and the results are stacked (shape (2, 3, 2)), but I want one matrix per index pair, stacked along the first axis (shape (3, 2, 2)).
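For reference, a minimal loop-free sketch without vmap, assuming i is first converted to an integer array (the name idx is introduced here only for illustration). Indexing a[:, idx] with a (3, 2) index array produces the unwanted (2, 3, 2) stacking shown above; moving axis 1 to the front turns it into the desired (3, 2, 2) stack.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3, 4],
               [2, 3, 4, 5]])
idx = jnp.array([[0, 1], [1, 2], [2, 3]])  # shape (3, 2): one index pair per output matrix

# Advanced indexing on the column axis: rows_first[r, k, j] = a[r, idx[k, j]]
rows_first = a[:, idx]                     # shape (2, 3, 2): the "NOT" layout (B_)

# Move the "which index pair" axis to the front: B[k, r, j] = a[r, idx[k, j]]
B = jnp.moveaxis(rows_first, 1, 0)         # shape (3, 2, 2): one (2, 2) matrix per pair

print(B.shape)  # (3, 2, 2)
```

This is the same gather as the vmap approach, just expressed as one indexing operation followed by an axis permutation.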

I tried using


but this gives me an error, since a and i don't have the same dimensions. Is there an alternative that doesn't use loops?

2 answers

  • answered 2022-05-04 16:40 jakevdp

    You can do this with vmap if you specify in_axes in the right way, and convert your index list into an index array:

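    import jax.numpy as jnp
    from jax import vmap

    # `a` is shared by every call; the index array is split along its leading axis
    vmap(extract, in_axes=(None, 0))(a, jnp.array(i))
    # DeviceArray([[[1, 2],
    #               [2, 3]],
    #              [[2, 3],
    #               [3, 4]],
    #              [[3, 4],
    #               [4, 5]]], dtype=int32)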

    When you say in_axes=(None, 0), it specifies that you want the first argument to be unmapped, and you want the second argument to be mapped along its leading axis.

    You need to convert i from a list to an array because JAX only maps over array arguments: if vmap encounters a collection such as a list, tuple, dict, or other pytree, it tries to map over each array-like leaf inside it. For the nested list i, those leaves are scalar Python ints, which have no leading axis to map over.
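To make the in_axes=(None, 0) semantics concrete, here is a small sketch comparing the vmap call against an explicit Python loop that stacks one extract result per index pair (the loop is for comparison only; the names batched, looped, and idx are illustrative).

```python
import jax.numpy as jnp
from jax import vmap

def extract(A, idx):
    # Select the columns of A listed in idx; for a length-2 idx this is a (rows, 2) matrix
    return A[:, idx]

a = jnp.array([[1, 2, 3, 4],
               [2, 3, 4, 5]])
i = [[0, 1], [1, 2], [2, 3]]
idx = jnp.array(i)  # a single (3, 2) int array leaf instead of a nested list of ints

# in_axes=(None, 0): `a` is passed unchanged to every call, `idx` is split along axis 0,
# so each call of extract sees one index pair of shape (2,).
batched = vmap(extract, in_axes=(None, 0))(a, idx)

# The same result written as an explicit loop over the rows of idx.
looped = jnp.stack([extract(a, row) for row in idx])

print(batched.shape)                            # (3, 2, 2)
print(bool(jnp.array_equal(batched, looped)))   # True
```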

  • answered 2022-05-04 16:43 lemon

    You can index the transposed matrix a directly:


    Does it work for you?
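A minimal sketch of the transposed-indexing idea, assuming i has been converted to an integer array (idx below). Note that a.T[idx] gathers columns of a as rows, so each (2, 2) block comes out as the transpose of a[:, idx[k]]. For this particular a and i the two coincide, but not in general; swapping the last two axes gives the orientation asked for in the question for any matrix.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3, 4],
               [2, 3, 4, 5]])
idx = jnp.array([[0, 1], [1, 2], [2, 3]])  # assumed: i converted to a (3, 2) int array

# a.T has shape (4, 2); indexing its leading axis with idx gives shape (3, 2, 2),
# where block k is the transpose of a[:, idx[k]].
blocks_t = a.T[idx]

# Swap the last two axes so block k equals a[:, idx[k]].
B = jnp.swapaxes(blocks_t, 1, 2)

# With a different matrix the transposed blocks no longer match the requested layout.
c = jnp.array([[1, 2, 3, 4],
               [5, 6, 7, 8]])
print(bool(jnp.array_equal(c.T[idx], jnp.swapaxes(c.T[idx], 1, 2))))  # False
```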
