numpy模块中axis的理解——以np.argmax为例

2023-11-19

numpy模块中axis的理解——以np.argmax为例

np.argmax参数数量及其作用

np.argmax是用于取得数组中每一行或者每一列的的最大值。常用于机器学习中获取分类结果、计算精确度等。
函数如下:

np.argmax(
	a, 
	axis=None, 
	out=None)

部分参数解释:
a:输入矩阵;
axis:对于二维向量而言,0代表对行进行最大值选取,此时对每一列进行操作;1代表对列进行最大值选取,此时对每一行进行操作。三维向量的情况更为复杂,需要结合例子说明。实际上axis的大小代表着进入到第axis+1个[ ]内,对其剩余的部分进行对比;
out:可以指定输出矩阵的变量

axis不同情况的示例

代码较长,许多都是注释,请大家耐心观看

import numpy as np
# 一维向量测试
# 取出x中元素最大值所对应的索引
# 此时最大值为11,其对应的位置索引值为11
x = np.arange(12)
index = np.argmax(x)
print("1 dimension test:",index)

# 二维向量测试
# 0代表对行进行最大值选取,此时对每一列进行操作
x = np.arange(12).reshape(3,4)
index = np.argmax(x,axis = 0)
# 结果为[2 2 2 2]
print("2 dimension test, axis = 0:",index)

# 二维向量测试
# 1代表对列进行最大值选取,此时对每一行进行操作
x = np.arange(12).reshape(3,4)
index = np.argmax(x,axis = 1)
# 结果为[3 3 3]
print("2 dimension test, axis = 1:",index)

# 三维向量测试
# 0代表进入第一个[]内进行对比
x = np.arange(24).reshape(2,3,4)
x[1,0,3] = 1
# x = 
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]

#  [[12 13 14  1]
#   [16 17 18 19]
#   [20 21 22 23]]]
index = np.argmax(x,axis = 0)
print("3 dimension test, axis = 0:",index)
# 当axis=0时,进入第一个[]内进行对比,此时x剩下两部分。
#  [[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]

#  [[12 13 14  1]
#   [16 17 18 19]
#   [20 21 22 23]]
# 两部分格式相同,将剩下的两部分每一个单位进行对比,对比结果为
#  [[1  1  1  0]
#   [1  1  1  1]
#   [1  1  1  1]]
# 除去我设置的特殊位置外,其他位置均为第二部分大。

# 三维向量测试
# 1代表进入第二个[]内进行对比
# x = 
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]

#  [[12 13 14  1]
#   [16 17 18 19]
#   [20 21 22 23]]]
index = np.argmax(x,axis = 1)
print("3 dimension test, axis = 1:",index)
# 当axis=1时,进入第二个[]内进行对比。
# [ [ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]

#   [12 13 14  1]
#   [16 17 18 19]
#   [20 21 22 23] ]
# 对于第二个[]内的内容而言,均剩下三部分,我特意将两个第二个[]内的内容分开更容易辨认
# 第一个是
#   [ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]
# 第二个是
#   [12 13 14  1]
#   [16 17 18 19]
#   [20 21 22 23]
# 都是第三行的值最大,所以输出结果为
#  [[ 2  2  2  2]
#   [ 2  2  2  2]]

# 三维向量测试
# 2代表进入第三个[]内进行对比
x = np.arange(24).reshape(2,3,4)
x[1,0,3] = 1
# x = 
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]

#  [[12 13 14  1]
#   [16 17 18 19]
#   [20 21 22 23]]]
index = np.argmax(x,axis = 2)
print("3 dimension test, axis = 2:",index)
# 当axis=2时,进入第三个[]内进行对比。
# [[  0  1  2  3 
#     4  5  6  7 
#     8  9 10 11 ]
#  [ 12 13 14  1 
#    16 17 18 19 
#    20 21 22 23 ]]
# 对于第三个[]内的内容而言,均剩下四部分,我特意将六个第三个[]内的内容分开更容易辨认
# 第一个是
# 0  1  2  3 
# 第二个是
# 4  5  6  7
# ……
# 最后对比结果为
#  [[ 3  3  3 ]
#   [ 2  3  3 ]]

实际上axis的大小代表着进入到第axis+1个[ ]内,对其剩余的部分进行对比。
运行结果为:

1 dimension test: 11
2 dimension test, axis = 0: [2 2 2 2]
2 dimension test, axis = 1: [3 3 3]
3 dimension test, axis = 0: [[1 1 1 0]
 							 [1 1 1 1]
 							 [1 1 1 1]]
3 dimension test, axis = 1: [[2 2 2 2]
							 [2 2 2 2]]
3 dimension test, axis = 2: [[3 3 3]
							 [2 3 3]]
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

numpy模块中axis的理解——以np.argmax为例 的相关文章