如何使用Python numpy.where()方法

在Python中,我們可以使用numpy.where()函數來根據條件從numpy數組中選擇元素。

不僅如此,如果條件滿足,我們還可以對這些元素進行一些操作。

讓我們看看如何使用這個函數,通過一些說明性的例子!


Python numpy.where()的語法

此函數接受一個類似numpy的數組(例如整數/布爾值的NumPy數組)。

它返回一個新的numpy數組,根據條件(即一個類似numpy的布爾值數組)進行篩選。

例如,條件可以取值array([[True, True, True]]),這是一個類似numpy的布爾值數組。(默認情況下,NumPy只支持數值,但我們也可以將它們轉換為bool

例如,如果條件array([[True, True, False]]),且我們的數組是a = ndarray([[1, 2, 3]]),對數組應用條件(a[:, condition]),我們將得到數組ndarray([[1 2]])

import numpy as np

a = np.arange(10)
print(a[a <= 2]) # 只會捕獲小於等於 2 的元素,忽略其他元素

輸出

array([0 1 2])

注意:相同的條件也可以表示為 a <= 2。這是建議的條件陣列格式,因為將其寫為布林陣列非常繁瑣。

但是如果我們想保留結果的維度,而不會喪失原始陣列的元素怎麼辦呢?我們可以使用 numpy.where() 來實現這一點。

numpy.where(condition [, x, y])

我們還有兩個參數 xy。那些是什麼?

基本上,這意味著如果對於陣列中的某個元素,condition 為真,則新陣列將從 x 中選擇元素。

否則,如果為假,則將從 y 中選擇元素。

有了這個,我們的最終輸出陣列將是一個陣列,其中包含從 x 選擇的元素,無論 condition = True 還是從 y 選擇的元素,無論 condition = False

請注意,雖然 xy 是可選的,但如果您指定了 x,您 必須 也要指定 y。這是因為在這種情況下,輸出陣列的形狀必須與輸入陣列相同。

注意:相同的邏輯也適用於單維和多維陣列。在這兩種情況下,我們根據條件進行過濾。還請記住,xycondition 的形狀會一起廣播。

現在,讓我們看一些例子,以正確理解這個函數。


使用Python numpy.where()

假設我們想從一個numpy數組中提取正元素並將所有負元素設置為0,讓我們使用numpy.where()來編寫代碼。

1. 使用numpy.where()替換元素

我們將在這裡使用一個二維隨機數組,並僅輸出正元素。

import numpy as np

# 隨機初始化一個(2D數組)
a = np.random.randn(2, 3)
print(a)

# 每當條件成立時(即僅正元素),b將是a的所有元素
# 否則,將其設置為0
b = np.where(a > 0, a, 0)

print(b)

可能的輸出

[[-1.06455975  0.94589166 -1.94987123]
 [-1.72083344 -0.69813711  1.05448464]]
[[0.         0.94589166 0.        ]
 [0.         0.         1.05448464]]

正如您所看到的,現在僅保留了正元素!

2.僅使用條件的numpy.where()

可能對上述代碼有些混淆,因為有些人可能認為更直觀的方式是簡單地這樣寫條件:

import random
import numpy as np

a = np.random.randn(2, 3)
b = np.where(a > 0)
print(b)

如果您現在嘗試運行上述代碼,進行此更改,您將獲得以下輸出:

(array([0, 1]), array([2, 1]))

仔細觀察,b 現在是一個numpy數組的元組。而且每個數組都是正元素的位置。這是什麼意思?

每當我們只提供條件時,此函數實際上等效於np.asarray.nonzero()

在我們的示例中,np.asarray(a > 0) 將在應用條件後返回一個類似布爾值的數組,而 np.nonzero(arr_like)將返回arr_like的非零元素的索引。(參見這裡的連結)

因此,我們現在將看一個更簡單的例子,這個例子向我們展示了我們可以如何靈活使用numpy!

import numpy as np

a = np.arange(10)

b = np.where(a < 5, a, a * 10)

print(a)
print(b)

輸出

[0 1 2 3 4 5 6 7 8 9]
[ 0  1  2  3  4 50 60 70 80 90]

這裡,條件是a < 5,這將是numpy類似數組[True True True True True False False False False False]x 是數組 a,而y是數組 a * 10。因此,我們只從 a 中選擇,如果 a < 5,就從 a * 10 中選擇。

因此,這將所有大於等於 5 的元素,通過乘以 10 進行轉換。這確實是我們得到的結果!


使用 numpy.where() 进行广播

如果我们提供了所有的 conditionxy 数组,numpy 将会将它们一起广播。

import numpy as np

a = np.arange(12).reshape(3, 4)

b = np.arange(4).reshape(1, 4)

print(a)
print(b)

# 广播 (a < 5, a 和 b * 10)
# 形状为 (3, 4)、(3, 4) 和 (1, 4)
c = np.where(a < 5, a, b * 10)

print(c)

输出

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
[[0 1 2 3]]
[[ 0  1  2  3]
 [ 4 10 20 30]
 [ 0 10 20 30]]

在这里,输出是根据条件选择的,因此所有元素,但是这里,b 被广播到 a 的形状。(它的一个维度只有一个元素,因此在广播过程中不会出错)

所以,b 现在将变为 [[0 1 2 3] [0 1 2 3] [0 1 2 3]],现在,我们甚至可以从这个广播的数组中选择元素。

因此,输出的形状与 a 的形状相同。


结论

在本文中,我们学习了如何使用 Python 的 numpy.where() 函数根据另一个条件数组选择数组。


參考資料


Source:
https://www.digitalocean.com/community/tutorials/python-numpy-where