Python

Numpy axis 이해하기

체봄 2021. 9. 5. 14:30

numpy로 연산을 할 때 axis를 지정하는 함수가 많은데, 헷갈리는 경우가 많아 기록해둔다.

 

우선 말로 정리하자면 axis=0이면 가장 높은 차원을 의미하고, 1, 2, ...와 같이 증가할수록 한 차원씩 낮은 차원을 의미한다.

axis=-1이면 가장 낮은 차원을 의미하고, -2, -3과 같이 감소할수록 한 차원씩 높은 차원을 의미한다.

 

 

3차원 배열의 예시로 보면 쉽다.

import numpy as np

arr1 = np.array([[[1,2,3],
                  [4,5,6]],
                 [[7,8,9],
                  [10,11,12]]]) # shape: (2,2,3)
arr2 = np.array([[[1,1,1],
                  [1,1,1]],
                 [[1,1,1],
                  [1,1,1]]]) # shape: (2,2,3)

shape이 (2, 2, 3)인 두 3차원 배열이 있다.

두 배열을 이어 붙이는 np.concatenate() 함수를 통해서 axis 값이 어떤 차원의 값을 의미하는지 알아볼 것이다.

 

concat1 = np.concatenate((arr1, arr2), axis=0)  # axis=-3과 동일
concat1.shape # shape: (4,2,3)

axis=0이므로 가장 높은 차원인 3차원에 해당한다. 3차원에 대해 concatenate 했으니 3차원에 해당하는 shape가 변경되었다.

axis를 음수로 나타내면 현재 3차원 배열이므로 세번째로 낮은 차원이기 때문에 axis=-3로 나타낼 수 있다.

 

concat2 = np.concatenate((arr1, arr2), axis=1)  # axis=-2와 동일
concat2.shape # shape: (2,4,3)

axis=1이므로 두번째로 높은 차원인 2차원에 대해 concat 연산을 진행하므로, shape를 보면 2차원에 해당하는 부분이 변경되었다.

axis를 음수로 나타내면 두번째로 낮은 차원이므로 axis=-2로 나타낼 수 있다.

 

concat3 = np.concatenate((arr1, arr2), axis=2)  # axis=-1과 동일
concat3.shape # shape: (2,2,6)

axis=2이므로 세번째로 높은 차원(가장 낮은 차원)인 1차원에 대해 연산을 진행하였고, 1차원에 해당하는 shape이 변경되었다.

axis를 음수로 나타내면 1차원은 가장 낮은 차원이므로 axis=-1로 나타낼 수 있다.

반응형