Transfer learning 이란?

2012년 Alex Krizhevsky가 AlexNet[1]이라는 convolutional neural network (CNN)으로 ImageNet 데이터셋에서 기존 비전 알고리즘을 모두 압도하는 성능을 보고한 이후로 CNN은 비전문제에 매우 광범위하게 적용되고 있습니다. 하지만 기존 비전 알고리즘에 비해 CNN이 갖는 치명적인 단점이 있었는데 그것은 바로 깊은 CNN을 학습시키기 위해 매우 많은 양의 데이터가 필요하다는 것이었습니다. 이러한 문제점 때문에 초창기 CNN은 주로 ImageNet과 같이 매우 많은 이미지들로 구성된 데이터셋에만 적용되어 왔습니다.

많은 연구자들이 어떻게 하면 데이터가 적은 도메인에 CNN을 적용할 수 있을 지를 고민하고 연구한 결과 여러 방법들을 생각해냅니다. 그 방법들 중 하나가 본 글의 주제인 transfer learning입니다. Transfer learning이란 데이터가 아주 많은 도메인(source 도메인)에서 학습된 지식을 데이터가 적은 도메인(target 도메인)으로 전이(transfer)시키는 학습 방법론입니다. CNN에서 transfer learning은 주로 데이터가 적은 target 도메인 학습을 위한 CNN의 weight 값을 source 도메인에서 학습된 CNN의 weight 값으로 초기화하는 방식으로 적용됩니다. Transfer learning에 대한 아이디어는 2012~2013년에 활발히 논의가 되어 2014년에 이와 관련된 많은 논문들이 publish 되었습니다[2-7]. 이 중에서 특히 [7]에서 보고한 transfer learning에 관한 다양한 실험 결과들은 연구자들로 하여금 CNN에 transfer learning을 적용하는 방식에 대한 많은 영감을 주었습니다. [2-7] 논문들은 모두 CNN 연구 초창기에 transfer learning이라는 방법론이 정착할 수 있도록 하는 데에 많은 기여를 한 논문들이므로 한 번쯤 읽어보는 것을 추천드립니다.

Transfer learning 연구들에서 사용되는 source domain과 target domain의 유사성에 대한 고찰

2014년 이후로는 거의 모든 CNN 연구에 transfer learning이 적용되었습니다. 그런데 대부분의 연구들에서 transfer learning 적용 시 source와 target 도메인들을 서로 동일한 특성을 가지고 있는 데이터셋으로 세팅하여 사용하였습니다. 대표적인 예로는 source 도메인에 ImageNet 데이터셋을, target 도메인에 비교적 작은 데이터셋인 Pascal VOC나 Caltech-101 데이터셋을 사용하는 경우가 있습니다. 이러한 source와 target 도메인의 데이터셋들은 전체적으로 비슷한 유형의 이미지들로 구성되어 있으며 심지어 서로 동일한 클래스도 존재합니다(그림 1).

이러한 source와 target 도메인은 단지 사람이 보기에 유사하다고 생각되는 정도가 아니라 CNN에서 추출하는 feature들도 유사한 특성을 보입니다. 그림 2는 동일한 CNN에 source와 target 도메인의 이미지들을 통과시켜 feature vector를 추출한 후 PCA, t-SNE로 2차원 공간에 embedding 시킨 결과입니다. 그림에서 볼 수 있듯이 CNN은 target 도메인인 Caltech-101과 Pascal VOC 데이터셋의 feature가 ImageNet 데이터셋의 feature와 매우 유사하다고 판단합니다. 따라서 이와 같이 유사한 도메인 사이에서 transfer learning이 잘 작동하는 것은 어떻게 보면 당연한 결과라고 할 수도 있습니다.

하지만 세이지리서치에서 주로 타겟하고 있는 도메인은 제조업 도메인입니다. 제조업 도메인의 이미지들이 학계에서 주로 사용되는 ImageNet 데이터셋의 이미지들과 얼마나 다른 지를 DAGM 데이터셋[8]을 통해 살펴보겠습니다. DAGM 데이터셋은 섬유 형상 배경에 가상의 결함들을 만들어 놓은 데이터셋으로 그림 3과 같이 이미지의 외관이 ImageNet 데이터셋과 매우 다릅니다. 이미지의 외관뿐만 아니라 CNN 역시 이 두 데이터셋의 feature가 서로 다르다고 인식합니다. 위에서 설명한 방법과 동일한 방식으로 ImageNet과 DAGM 데이터셋에서 추출한 feature들을 시각화하면 그림 4와 같습니다. 그림 4의 결과를 통해 CNN은 학계에서 흔히 사용되는 Caltech-101이나 Pascal VOC와는 달리 DAGM과 같은 제조업 데이터의 feature는 ImageNet 데이터의 feature와 매우 다르다고 인식하는 것을 확인할 수 있습니다.

따라서 다른 연구들에서와 같이 transfer learning을 잘 적용하기 위해서는 우리가 타겟으로 하려는 제조업 도메인의 데이터와 유사한 특성을 가지고 있는 source 도메인의 데이터셋을 이용하는 것이 필요해 보이기도 합니다. 하지만 각 제조업체들에서 생산하는 제품에 대한 데이터는 매우 엄격한 보안 아래에서 관리되며 따라서 안타깝게도 제조업 도메인에서 ImageNet과 같이 방대한 데이터셋을 구축하는 것은 거의 불가능에 가깝습니다. 그렇다면 이러한 제조업 도메인에서는 transfer learning을 어떻게 적용해야 할까요?

그림 1. Transfer learning을 적용하는 많은 연구들에서 사용되는 source와 target 도메인 사이의 유사성

그림 1. Transfer learning을 적용하는 많은 연구들에서 사용되는 source와 target 도메인 사이의 유사성

https://s3-us-west-2.amazonaws.com/secure.notion-static.com/c993a2a4-ca68-47b3-8006-6ae2caecf63c/Untitled.png

그림 2. CNN에서 추출한 각 데이터셋의 feature를 2차원 공간에 embedding 시킨 결과(위: PCA, 아래: t-SNE)

그림 2. CNN에서 추출한 각 데이터셋의 feature를 2차원 공간에 embedding 시킨 결과(위: PCA, 아래: t-SNE)

그림 3. ImageNet 데이터셋과 DAGM 데이터셋

그림 3. ImageNet 데이터셋과 DAGM 데이터셋

그림 4. ImageNet 데이터와 DAGM 데이터의 feature embedding 결과(왼쪽: PCA, 오른쪽: t-SNE)

그림 4. ImageNet 데이터와 DAGM 데이터의 feature embedding 결과(왼쪽: PCA, 오른쪽: t-SNE)

Transfer learning for industrial visual inspection

결론부터 말씀드리면 transfer learning은 source와 target 도메인의 특성이 매우 다를 때에도 잘 작동합니다. 즉, 제조업 도메인의 검사를 위한 CNN을 학습시킬 때에도 ImageNet 데이터셋을 source 도메인으로 선택하여 transfer learning을 적용하면 됩니다. 그렇다면 지금부터는 서로 전혀 다른 source와 target 도메인 사이에 transfer learning을 적용했을 때 CNN의 성능이 어떻게 변하는지, CNN에서 학습되는 feature에는 어떠한 영향을 미치는 지에 대해 살펴보겠습니다. 본격적인 논의에 앞서 다음 몇 가지 사항들을 정리하고 가겠습니다.

그림 5. DAGM 데이터셋

그림 5. DAGM 데이터셋

제조업 도메인에서 transfer learning이 CNN의 성능에 미치는 영향

제조업 도메인에서 transfer learning이 CNN 성능에 미치는 영향을 확인하기 위해 다음과 같은 세 가지 세팅으로 CNN을 학습하였습니다.

  1. Data augmentation 없이 target 데이터셋에 CNN을 처음부터 학습시킴: Scratch (no aug)
  2. Target 데이터셋에 data augmentation (rotate, flip)을 적용한 후 CNN을 처음부터 학습시킴: Scratch (aug)
  3. Transfer learning 적용하여 target 네트워크를 학습시킴: Transfer Learning

실험 결과는 그림 6과 같습니다. 먼저 *Scratch (no aug)*는 낮은 분류 정확도(70.87%)에 수렴하는 것을 확인할 수 있습니다. 이를 통해 DAGM 데이터셋과 같이 적은 수의 데이터로 깊은 네트워크를 학습시킬 경우 over-fitting이 발생하는 것을 확인할 수 있습니다. 이러한 over-fitting 문제는 data augmentation으로 해결할 수 있습니다 (Scratch (aug) 결과). 하지만 그래프에서 볼 수 있듯이 최종 정확도(99.55%)에 수렴하기까지 매우 오랜 시간(320k iteration)을 학습해야 하는 것을 확인할 수 있습니다. 반면에 transfer learning을 적용할 경우 어떠한 data augmentation 없이도 최종 정확도(99.90%)까지 매우 적은 학습(2,242 iteration)만에 도달하는 것을 확인할 수 있습니다.

이 실험 결과를 통해 제조업 도메인과 같이 방대한 source 도메인을 구축하기 힘든 도메인에는 target과 전혀 다른 특성의 source 도메인으로부터 transfer learning을 적용하는 것이 큰 도움이 된다는 것을 알 수 있습니다.

 그림 6. 각 실험의 classification 성능 비교

그림 6. 각 실험의 classification 성능 비교

Transfer learning으로 학습된 CNN의 특징

Transfer learning을 사용하면 저희의 타겟인 제조업 도메인에 적은 데이터만으로 CNN을 성공적으로 적용할 수 있는 것을 확인하였습니다. 그렇다면 transfer learning을 이용하여 학습한 CNN (이후 Transfer learning이라고 기술)은 어떤 특징을 가지길래 처음부터 학습한 CNN (이후 Scratch라고 기술)과 비교하여 훨씬 높은 학습 효율을 낼 수 있는 지에 대해 알아보겠습니다.

Transfer learningScratch에 비해 다음과 같은 세 가지 특징을 보입니다.

  1. 동일한 수준의 정확도에 도달하기까지 Transfer learning에서는 Scratch보다 적은 양의 weight update가 일어납니다. 또한 네트워크의 아랫단보다 윗단의 weight update가 많이 됩니다.
  2. Transfer learningScratch보다 더 sparse한 feature를 학습합니다.
  3. Transfer learningScratch보다 더 disentangle된 feature를 학습합니다.

Transfer learning의 각 특징들을 살펴보기 위해 DAGM 실험에서 높은 분류 정확도에 도달한 두 네트워크 Transfer learning과 *Scratch (aug)*의 특징들을 비교 분석해보겠습니다.

1. Transfer learningScratch보다 적은 양의 weight를 update 하며 네트워크의 윗단을 주로 update 합니다.

CNN의 각 $l$ 번째 layer의 weight 변화율을 아래 식과 같이 weight의 초기값 $w_{0}$과 학습이 수렴한 후의 weight 값 $w_{\text{converged}}$의 차이로 정의합니다.

$$ \gamma^{l} \triangleq \frac{\parallel w_{\text{converged}}^{l} -w_{0}^{l} \parallel_{2}}{\parallel w_{0}^{l} \parallel_{2}} $$

그림 7은 학습이 수렴한 후 Transfer learning과 *Scratch (aug)*의 각 layer에서의 weight 변화율 $\gamma^{l}$을 비교한 결과입니다. 흥미롭게도 두 네트워크에서 weight 변화율은 전혀 반대의 경향을 보입니다. Transfer learning에서는 *Scratch (aug)*에 비해 전체적으로 훨씬 적은 양의 weight 변화가 일어나는 것을 확인할 수 있습니다. 특히 Transfer learning은 네트워크의 아랫단(1~7번째 layer)의 weight를 거의 update하지 않는 것을 확인할 수 있습니다. 즉, Transfer learning의 아랫단에서는 ImageNet 데이터셋에서 학습된 source 네트워크와 거의 동일한 low-level feature를 추출합니다 (neural network의 아랫단에서는 이미지의 edge와 같은 low-level feature들이 학습된다고 알려져있습니다 [3]). Transfer learning에서 주로 학습 되는 것은 네트워크의 윗단이며 이를 통해 source 도메인과는 전혀 다른 target 도메인의 task가 학습이 된다는 것을 알 수 있습니다.

이와는 반대로 *Scratch (aug)*는 윗단보다는 아랫단이 훨씬 더 많이 학습됩니다. 즉, CNN을 새로운 target 도메인에 처음부터 학습시킬 때는 해당 도메인에 맞는 low-level feature들을 학습하는 데에 학습의 많은 부분이 투자된다는 것도 확인할 수 있습니다.

그림 7. Transfer learning과 *Scratch (aug)*의 각 layer의 weight 변화율

그림 7. Transfer learning과 *Scratch (aug)*의 각 layer의 weight 변화율

2. Transfer learningScratch보다 더 sparse한 feature를 학습합니다.