hiro_5656's blog

機械学習やクラウド技術について勉強したことを発信していきます!

numpy配列を効率的に連結させる!


まえがき

データ分析を行なっていると、複数のnumpy配列を特定の軸で連結させたいことはないだろうか。
私もこれまでは、空のnumpy配列を用意してそれにnp.append()などで連結するなどの方法をとっていたのだがとても面倒でした。

今回はより効率的に複数のnumpy配列を連結する方法を紹介していきます!

① リストで連結する

要素の追加や削除が特に大きな制約なく行えるリストはPythonの長所ですよね。
リストの追加には、appendを使います。

list0 = []
list1 = [0, 1, 2]
list2 = [3, 4, 5]

list0.append(list1)
list0.append(list2)
list0
### -----------------------------------
[[0, 1, 2], [3, 4, 5]]

リストの連結には、extendを使います。

list0 = []
list1 = [0, 1, 2]
list2 = [3, 4, 5]

list0.extend(list1)
list0.extend(list2)
list0
### -----------------------------------
[0, 1, 2, 3, 4, 5]

では連結させる要素がリストではなく、numpy.array()型だった場合はどうなるであろうか。

list0 = []
arr1 = np.arange(0, 2*2*3).reshape(2, 2, 3)
arr2 = np.arange(100, 100+4*2*3).reshape(4, 2, 3)

list0.append(arr1)
list0.append(arr2)

list0
### -----------------------------------
[array([[[ 0,  1,  2],
         [ 3,  4,  5]],
 
        [[ 6,  7,  8],
         [ 9, 10, 11]]]),
 array([[[100, 101, 102],
         [103, 104, 105]],
 
        [[106, 107, 108],
         [109, 110, 111]],
 
        [[112, 113, 114],
         [115, 116, 117]],
 
        [[118, 119, 120],
         [121, 122, 123]]])]

numpy.array()型の要素がリスト化されるだけである。

問題はextendした場合である。

list0 = []
arr1 = np.arange(0, 2*2*3).reshape(2, 2, 3)
arr2 = np.arange(100, 100+4*2*3).reshape(4, 2, 3)

list0.extend(arr1)
list0.extend(arr2)

list0
### -----------------------------------
[array([[0, 1, 2],
        [3, 4, 5]]),
 array([[ 6,  7,  8],
        [ 9, 10, 11]]),
 array([[100, 101, 102],
        [103, 104, 105]]),
 array([[106, 107, 108],
        [109, 110, 111]]),
 array([[112, 113, 114],
        [115, 116, 117]]),
 array([[118, 119, 120],
        [121, 122, 123]])]

なんとnumpy.array()型のaxis=0の方向に連結されるのである...!

print(arr1.shape)
print(arr2.shape)
print(np.array(list0).shape)
### -----------------------------------
(2, 2, 3)
(4, 2, 3)
(6, 2, 3)

このやり方の嬉しい点は、形状を定義した空の配列を用意する必要がなく、リストを拡張するだけなので高速に処理を行えるところである。

② 任意の軸でnumpy配列を連結する方法

リストのextendを使ったやり方は、axis=0の軸でしか連結できないがこれを任意の軸で連結させたい場合には、numpy.concatenate()を使用する。

list0 = []
arr1 = np.arange(0, 2*2*3).reshape(2, 2, 3)
arr2 = np.arange(100, 100+2*4*3).reshape(2, 4, 3)
arr3 = np.arange(200, 200+2*3*3).reshape(2, 3, 3)

list0.append(arr1)
list0.append(arr2)
list0.append(arr3)

list0
### -----------------------------------
[array([[[ 0,  1,  2],
         [ 3,  4,  5]],
 
        [[ 6,  7,  8],
         [ 9, 10, 11]]]),
 array([[[100, 101, 102],
         [103, 104, 105],
         [106, 107, 108],
         [109, 110, 111]],
 
        [[112, 113, 114],
         [115, 116, 117],
         [118, 119, 120],
         [121, 122, 123]]]),
 array([[[200, 201, 202],
         [203, 204, 205],
         [206, 207, 208]],
 
        [[209, 210, 211],
         [212, 213, 214],
         [215, 216, 217]]])]

arr1, arr2, arr3 の3つのnumpy配列をaxis=1の軸で連結させたい場合は以下のようにする。

np.concatenate([arr for arr in list0], axis=1)
### -----------------------------------
array([[[  0,   1,   2],
        [  3,   4,   5],
        [100, 101, 102],
        [103, 104, 105],
        [106, 107, 108],
        [109, 110, 111],
        [200, 201, 202],
        [203, 204, 205],
        [206, 207, 208]],

       [[  6,   7,   8],
        [  9,  10,  11],
        [112, 113, 114],
        [115, 116, 117],
        [118, 119, 120],
        [121, 122, 123],
        [209, 210, 211],
        [212, 213, 214],
        [215, 216, 217]]])

pythonの内包表記を利用して、連結する要素のリストをnumpy.concatenate()の引数に渡すことで任意の軸での連結を行なっている。

np.concatenate([arr for arr in list0], axis=1).shape
### -----------------------------------
(2, 9, 3)

③ 新規の軸で連結する

同じ形状のnumpy配列を新たな軸で連結したい場合もあるだろう。その場合は、numpy.stack()を用いる。

list0 = []
arr1 = np.arange(0, 2*2*3).reshape(2, 2, 3)
arr2 = np.arange(100, 100+2*2*3).reshape(2, 2, 3)
arr3 = np.arange(200, 200+2*2*3).reshape(2, 2, 3)

list0.append(arr1)
list0.append(arr2)
list0.append(arr3)

list0
### -----------------------------------
[array([[[ 0,  1,  2],
         [ 3,  4,  5]],
 
        [[ 6,  7,  8],
         [ 9, 10, 11]]]),
 array([[[100, 101, 102],
         [103, 104, 105]],
 
        [[106, 107, 108],
         [109, 110, 111]]]),
 array([[[200, 201, 202],
         [203, 204, 205]],
 
        [[206, 207, 208],
         [209, 210, 211]]])]
np.stack([arr for arr in list0], axis=0)
### -----------------------------------
array([[[[  0,   1,   2],
         [  3,   4,   5]],

        [[  6,   7,   8],
         [  9,  10,  11]]],


       [[[100, 101, 102],
         [103, 104, 105]],

        [[106, 107, 108],
         [109, 110, 111]]],


       [[[200, 201, 202],
         [203, 204, 205]],

        [[206, 207, 208],
         [209, 210, 211]]]])
np.stack([arr for arr in list0], axis=0).shape
### -----------------------------------
(3, 2, 2, 3)