Python
Numpy axis 이해하기
체봄
2021. 9. 5. 14:30
numpy로 연산을 할 때 axis를 지정하는 함수가 많은데, 헷갈리는 경우가 많아 기록해둔다.
우선 말로 정리하자면 axis=0이면 가장 높은 차원을 의미하고, 1, 2, ...와 같이 증가할수록 한 차원씩 낮은 차원을 의미한다.
axis=-1이면 가장 낮은 차원을 의미하고, -2, -3과 같이 감소할수록 한 차원씩 높은 차원을 의미한다.
3차원 배열의 예시로 보면 쉽다.
import numpy as np
arr1 = np.array([[[1,2,3],
[4,5,6]],
[[7,8,9],
[10,11,12]]]) # shape: (2,2,3)
arr2 = np.array([[[1,1,1],
[1,1,1]],
[[1,1,1],
[1,1,1]]]) # shape: (2,2,3)
shape이 (2, 2, 3)인 두 3차원 배열이 있다.
두 배열을 이어 붙이는 np.concatenate() 함수를 통해서 axis 값이 어떤 차원의 값을 의미하는지 알아볼 것이다.
concat1 = np.concatenate((arr1, arr2), axis=0) # axis=-3과 동일
concat1.shape # shape: (4,2,3)
axis=0이므로 가장 높은 차원인 3차원에 해당한다. 3차원에 대해 concatenate 했으니 3차원에 해당하는 shape가 변경되었다.
axis를 음수로 나타내면 현재 3차원 배열이므로 세번째로 낮은 차원이기 때문에 axis=-3로 나타낼 수 있다.
concat2 = np.concatenate((arr1, arr2), axis=1) # axis=-2와 동일
concat2.shape # shape: (2,4,3)
axis=1이므로 두번째로 높은 차원인 2차원에 대해 concat 연산을 진행하므로, shape를 보면 2차원에 해당하는 부분이 변경되었다.
axis를 음수로 나타내면 두번째로 낮은 차원이므로 axis=-2로 나타낼 수 있다.
concat3 = np.concatenate((arr1, arr2), axis=2) # axis=-1과 동일
concat3.shape # shape: (2,2,6)
axis=2이므로 세번째로 높은 차원(가장 낮은 차원)인 1차원에 대해 연산을 진행하였고, 1차원에 해당하는 shape이 변경되었다.
axis를 음수로 나타내면 1차원은 가장 낮은 차원이므로 axis=-1로 나타낼 수 있다.
반응형