import numpy as np
x = np.random.random((5,3,3)) * 10 // 1
array([[[8., 8., 1.],
[8., 1., 6.],
[0., 7., 7.]],
[[2., 7., 2.],
[4., 5., 1.],
[6., 8., 5.]],
[[5., 1., 1.],
[0., 1., 8.],
[5., 7., 1.]],
[[0., 5., 9.],
[6., 7., 3.],
[5., 0., 5.]],
[[0., 1., 9.],
[1., 2., 3.],
[0., 4., 7.]]])
y = x.reshape(5,9).argmax(axis=1)
# array([0, 7, 5, 2, 2])
f = lambda i: (i // 3, i % 3)
# или вернее так:
f = lambda i: (i // x.shape[1], i % x.shape[2])
(a, b) = f(y)
# array([0, 2, 1, 0, 0]), array([0, 1, 2, 2, 2])
result = np.empty((a.size + b.size,), dtype=a.dtype)
result[0::2] = a
result[1::2] = b
result.reshape(5,2)
array([[0, 0],
[2, 1],
[1, 2],
[0, 2],
[0, 2]])