# Taking rows of a matrix given a batch of indices - Python

How can we extract the rows of a matrix given a batch of indices (in Python)?

```python
import jax.numpy as jnp

i = [[0,1],[1,2],[2,3]]
a = jnp.array([[1,2,3,4],[2,3,4,5]])

def extract(A,idx):
    A = A[:,idx]
    return A

B = extract(a,i)
```

I expect to get this result (where the matrices are stacked):

```python
B = [[[1,2],
      [2,3]],

     [[2,3],
      [3,4]],

     [[3,4],
      [4,5]]]
```

And NOT:

```python
B_ = [[[1, 2],
       [2, 3],
       [3, 4]],

      [[2, 3],
       [3, 4],
       [4, 5]]]
```

In this case the result is grouped by the rows of `a` (shape `(2, 3, 2)`), but I want one matrix per index pair, with the different matrices stacked (shape `(3, 2, 2)`).

I tried using

`jax.vmap(extract)(a, i)`,

but this gives me an error, since `a` and `i` don't have the same dimension. Is there an alternative that doesn't use loops?

You can do this with `vmap` if you specify `in_axes` in the right way, and convert your index list into an index array:

```python
import jax.numpy as jnp
from jax import vmap

# extract, a and i are defined as in the question
vmap(extract, in_axes=(None, 0))(a, jnp.array(i))
# DeviceArray([[[1, 2],
#               [2, 3]],
#
#              [[2, 3],
#               [3, 4]],
#
#              [[3, 4],
#               [4, 5]]], dtype=int32)
```

Passing `in_axes=(None, 0)` specifies that the first argument is unmapped (it is passed whole to every call) and that the second argument is mapped along its leading axis.
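To make the `in_axes` semantics concrete, here is a minimal sketch comparing the `vmap` call with the explicit Python loop it replaces. It redefines `a` and `extract` from the question so that it runs on its own; the names `idx` (the array form of `i`), `batched` and `looped` are illustrative only.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 2, 3, 4], [2, 3, 4, 5]])
idx = jnp.array([[0, 1], [1, 2], [2, 3]])  # array form of i, shape (3, 2)

def extract(A, idx):
    # Select the columns of A listed in idx
    return A[:, idx]

# in_axes=(None, 0): A is passed whole to every call, idx is sliced along axis 0
batched = jax.vmap(extract, in_axes=(None, 0))(a, idx)

# The same computation as an explicit Python loop over the rows of idx
looped = jnp.stack([extract(a, idx[k]) for k in range(idx.shape[0])])

print(batched.shape)  # (3, 2, 2)
print(bool(jnp.array_equal(batched, looped)))  # True
```

Each call `extract(a, idx[k])` returns a `(2, 2)` matrix, and `vmap` stacks the three results along a new leading axis, which is exactly the layout asked for.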

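You need to convert `i` from a list to an array because JAX only maps over array arguments: if `vmap` encounters a collection such as a list, tuple, dict, or a general pytree, it attempts to map over each array-like leaf within the collection. For the nested list `i`, those leaves are the Python integers `0, 1, 1, 2, 2, 3`, which have no leading axis to map over, so `vmap` raises an error.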
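A small sketch of that failure mode, assuming a recent JAX version. The exact exception type and message may differ between versions, so the sketch catches both `ValueError` and `TypeError` rather than relying on one of them.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 2, 3, 4], [2, 3, 4, 5]])
i = [[0, 1], [1, 2], [2, 3]]

def extract(A, idx):
    return A[:, idx]

# vmap treats the nested list as a pytree; its leaves are plain Python ints.
print(jax.tree_util.tree_leaves(i))  # [0, 1, 1, 2, 2, 3]

# Mapping each scalar leaf along axis 0 is impossible, so this call fails.
try:
    jax.vmap(extract, in_axes=(None, 0))(a, i)
    list_failed = False
except (ValueError, TypeError) as err:
    list_failed = True
    print("mapping over the list failed:", err)

# As a single (3, 2) integer array, i has a leading axis of size 3 to map over.
result = jax.vmap(extract, in_axes=(None, 0))(a, jnp.array(i))
print(result.shape)  # (3, 2, 2)
```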

Alternatively, you can index the transposed matrix `a.T` directly, without `vmap`:

```python
# row c of a.T is column c of a; result has shape (3, 2, 2)
a.T[i,:]
```
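Why this works: `a.T` has shape `(4, 2)`, so indexing its first axis with a `(3, 2)` integer array replaces that axis by the index shape, giving `(3, 2, 2)`. The sketch below checks this against two equivalent formulations, `jnp.take` on the transpose and `jnp.moveaxis` applied to the question's `a[:, idx]` result. It converts `i` to an array first (named `idx` here, for illustration) so that it stays on the plain integer-array indexing path.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3, 4], [2, 3, 4, 5]])
idx = jnp.array([[0, 1], [1, 2], [2, 3]])  # shape (3, 2)

# Advanced indexing on the first axis of a.T: (3, 2) index shape + (2,) = (3, 2, 2)
B = a.T[idx, :]

# The same gather written with jnp.take along axis 0 of the transpose
B_take = jnp.take(a.T, idx, axis=0)

# a[:, idx] keeps the row axis of a first, shape (2, 3, 2);
# moving that axis to the end gives the desired (3, 2, 2) layout
B_moved = jnp.moveaxis(a[:, idx], 0, -1)

print(B.shape)  # (3, 2, 2)
```

The `jnp.moveaxis` variant shows that the question's original `extract` output differs from the desired result only in the order of its axes.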