How to filter a NumPy array based on a list of indices

In this post, we will learn how to filter a NumPy array based on a list of indices, using the NumPy function take() for row-wise and column-wise selection and integer-array indexing for selecting individual elements.

1. Filter a NumPy array based on a list of indices row-wise


The np.take() function returns the elements of a NumPy array at the given indices along the given axis. The default value of axis is None; in that case, the flattened array is used.
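To make the axis=None default concrete, here is a small sketch. The 2 x 4 array is made up for illustration. Without an axis, np.take() counts positions in the flattened array, row by row.

```python
import numpy as np

# Hypothetical 2 x 4 array, used only for illustration
arr = np.array([[12, 14, 70, 80], [12, 75, 60, 50]])

# No axis given: indices refer to the flattened array
# [12, 14, 70, 80, 12, 75, 60, 50]
flat_pick = np.take(arr, [1, 5])
print(flat_pick)  # [14 75]
```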

In this example, we filter the NumPy array by a list of indices using the np.take() function, passing axis=0 so that whole rows are selected.

import numpy as np

# Row indices to keep
indices = [1, 2, 3]

newArr = np.array([[12, 14, 70, 80], [12, 75, 60, 50], [3, 6, 9, 12], [4, 8, 12, 16]])

# axis=0 takes whole rows at the given indices
axis = 0
print('original array:\n', newArr)

fltrArr = np.take(newArr, indices, axis=axis)
print('\n Filtered array:\n', fltrArr)

Output

original array:
 [[12 14 70 80]
 [12 75 60 50]
 [ 3  6  9 12]
 [ 4  8 12 16]]

 Filtered array:
 [[12 75 60 50]
 [ 3  6  9 12]
 [ 4  8 12 16]]
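For row selection, plain integer-array indexing gives the same result as np.take() with axis=0. The sketch below compares the two on the same array.

```python
import numpy as np

newArr = np.array([[12, 14, 70, 80], [12, 75, 60, 50], [3, 6, 9, 12], [4, 8, 12, 16]])
indices = [1, 2, 3]

via_take = np.take(newArr, indices, axis=0)
# A list of integers on the first axis selects those rows
via_indexing = newArr[indices]

print(np.array_equal(via_take, via_indexing))  # True
```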

2. Filter a NumPy array based on a list of indices column-wise


In this example, we filter the NumPy array by a list of indices using the np.take() function, passing axis=1 so that whole columns are selected.

import numpy as np

# Column indices to keep
indices = [1, 2]

newArr = np.array([[12, 14, 70, 80], [12, 75, 60, 50], [3, 6, 9, 12], [4, 8, 12, 16]])

# axis=1 takes whole columns at the given indices
axis = 1
print('original array:\n', newArr)

fltrArr = np.take(newArr, indices, axis=axis)
print('\n Filtered array:\n', fltrArr)

Output

original array:
 [[12 14 70 80]
 [12 75 60 50]
 [ 3  6  9 12]
 [ 4  8 12 16]]

 Filtered array:
 [[14 70]
 [75 60]
 [ 6  9]
 [ 8 12]]
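The column-wise case also has an indexing equivalent: keep every row with ":" and pass the column list as the second index. A short comparison sketch:

```python
import numpy as np

newArr = np.array([[12, 14, 70, 80], [12, 75, 60, 50], [3, 6, 9, 12], [4, 8, 12, 16]])
indices = [1, 2]

via_take = np.take(newArr, indices, axis=1)
# ":" keeps all rows; the list picks columns 1 and 2
via_indexing = newArr[:, indices]

print(np.array_equal(via_take, via_indexing))  # True
```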

3. Filter a multidimensional NumPy array based on a list of indices


In this example, we filter a multidimensional array by a list of (row, column) index pairs. First, we create an array of 40 values with np.arange() and reshape it into 5 rows and 8 columns.

  • Next, we create a NumPy array of index pairs, one [row, column] pair per element we want to select.
  • Finally, we index the original array with the row numbers index[:, 0] and the column numbers index[:, 1], and print the result.

import numpy as np

# 40 values (0..39) arranged as 5 rows x 8 columns
originalArr = np.arange(40).reshape(5, 8)

print('Original Array:\n', originalArr)

# Each row is one [row, column] pair to select
index = np.array([[0, 1], [0, 2], [0, 4], [3, 3]])

# index[:, 0] holds the row numbers, index[:, 1] the column numbers
print('\n Filtered NumPy array elements:\n', originalArr[index[:, 0], index[:, 1]])

Output

Original Array:
 [[ 0  1  2  3  4  5  6  7]
 [ 8  9 10 11 12 13 14 15]
 [16 17 18 19 20 21 22 23]
 [24 25 26 27 28 29 30 31]
 [32 33 34 35 36 37 38 39]]

 Filtered NumPy array elements:
 [ 1  2  4 27]
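The same element selection can also be expressed with np.take(). This is an alternative approach, not the one used above. np.ravel_multi_index() converts each (row, column) pair into a position in the flattened array, and np.take() without an axis then reads those positions.

```python
import numpy as np

originalArr = np.arange(40).reshape(5, 8)
index = np.array([[0, 1], [0, 2], [0, 4], [3, 3]])

# Convert each (row, column) pair into a flat position: row * 8 + column
flat_idx = np.ravel_multi_index((index[:, 0], index[:, 1]), originalArr.shape)

# With axis=None, np.take() reads from the flattened array
picked = np.take(originalArr, flat_idx)

print(flat_idx)  # [ 1  2  4 27]
print(picked)    # [ 1  2  4 27]
```

Because np.arange() fills the array with 0..39, the flat positions and the selected values happen to coincide here.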