简单讨论一下argmax函数及其用法,由于其在numpyPyTorch中都有出现,所以先在numpy中讨论,然后补充介绍在PyTorch中的用法。

同理我们可以理解argmin

API

numpy.argmax(a, axis=None, out=None)[source]
Returns the indices of the maximum values along an axis.  

Returns:
index_array : ndarray of ints
Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.

如API中描述的,它返回某一轴上最大值所在位置的下标。它有一个非常重要的参数axis

axis

axis : int, optional
By default, the index is into the flattened array, otherwise along the specified axis.

比如:

# 一维
b = np.asarray([0, 5, 2, 3, 4, 5])
np.argmax(b)  # 1
# 二维  
a = np.asarray([[0, 2, 5], [4, 1, 3]])
np.argmax(a) # 2
np.argmax(a, axis=1) # array([2, 0])

对于一维的情况,最大值是5,但是注意它只返回第一个最大值元素的下标,所以这里是1。
对于二维的情况,需要指定其在哪一个轴上做,比如这里a.shape(2,3),如果axis不指定,那就默认把它平坦化,变成array([0, 2, 5, 4, 1, 3],同一维情况一样,返回2;如果指定axis=1,表明在[0, 2, 5][4, 1, 3]上分别执行,前者返回2,后者返回0,所以结果返回array([2, 0])

问:如果指定axis=0呢?
这就需要我们对axis有更清晰的理解了。以数组a为例,它元素可以用$a_ij$来表示,其中i表示第0维,j表示第1维,这里的维就可以理解为轴,i表示纵轴,j表示横轴。
所以axis=1时,就表示以横轴为单位进行运算,因此返回数组与纵轴元素的数目相同。
同理,axis=0时,以纵轴为单位运算,即a[0][0] vs a[1][0], a[0][1] vs a[1][1], a[0][2] vs a[1][2],返回数组与纵轴元素的数目相同,结果为array([1, 0, 0])

更高维的情况呢?
总结为一句话,a_i...k...n,设axis=k,那么沿着第k个下标的位置进行操作。
例如三维的情况:

check = np.asarray([[[1, 2, 3],[10, 11, 12],[4 , 5 , 6 ],[7, 8, 9]],
                    [[4, 5, 6],[7 , 8 , 9 ],[1 , 2 , 3 ],[0, 1, 2]],
                    [[1, 2, 3],[4,  5,  6], [10, 11, 12],[3, 4, 5]]])
#  check.shape  (3, 4, 3)
np.argmax(check, axis=0)  # eg: 1, 4, 1 -> 1
# array([[1 1 1] [0 0 0] [2 2 2] [0 0 0]])
np.argmax(check, axis=1)  # eg: 1, 10, 4, 7 -> 1
# array([[1 1 1] [1 1 1] [2 2 2]])
np.argmax(check, axis=2)  # eg: 1, 2, 3
# array([[2 2 2 2] [2 2 2 2] [2 2 2 2]])

PyTorch

PyTorch里的对应axis参数为dim,更为直观,表示在第几维做操作。