작성자 - 백동희(donghee.paek@gmail.com)

본 요약문은 2020년 7월 30일 모두의 연구소 슬로우페이퍼 7기 온라인 풀잎스쿨에서 토의한 내용과 논문을 기반으로 작성되었습니다.

Training GAN with too little data

Generative Model 중 하나인 GAN을 학습시킬 때, 너무 적은 데이터를 학습시키면 Discriminator의 오버피팅 현상이 발생합니다. 여기서 Discriminator는 간단하게 이미지를 진짜(Real Data)인지, 가짜(Generated Data)인지 판단하여 Sigmoid를 통과한 확률을 리턴하는 네트워크라 보시면 됩니다. 진짜면 1에 가까운 값을, 가짜면 0에 가까운 값을 리턴합니다.

저자는 적은 데이터셋을 사용할 때 발생하는 Discriminator의 오버피팅 문제를 아래와 같이 두가지 방식으로 설명합니다.

1. FID Score의 Divergence

Frechet Inception Distance(FID)를 계산하기 위해서 생성된 이미지와 진짜 이미지를 Inception V3 네트워크를 통과시킵니다. 네트워크 중간층에서 각 이미지를 나타내는 feature들을 가져올 수 있는데, 이 feature들의 분포를 가우시안이라 가정합니다. 그리고 실제 데이터의 feature(x)와 생성된 데이터의 feature(g)의 거리를 계산합니다.

여기서 실제 데이터 feature의 평균, 분산은 매번 계산할 필요없이 한번 계산하여, 가지고 있으면 됩니다. 만약 생성된 데이터의 feature의 평균, 분산이 실제 데이터의 feature의 평균, 분산과 가깝다면 FID score는 작은 값을 가지게 됩니다. 따라서 FID score가 작을수록 GAN의 성능이 높다고 볼 수 있습니다. 하지만 그림 1의 (a)를 보게되면 적은 데이터(70k 이하)로 학습시킨 경우, Test 데이터 셋에 대해서 FID score가 수렴하지않고 증가(Diverge)합니다. 반대로 큰 데이터 셋(140k)으로 학습시킨 경우, FID score가 수렴합니다.

2. Discriminator의 판별 능력

저자는 Discriminator가 적은 데이터셋에 대해서 오버피팅하는 것을 "Generator를 학습시킬 때, Discriminator의 피드백이 의미가 없어진다."라고 언급합니다. 이 말의 의미를 그림 1의 (b), (c)를 보고 유추할 수 있었습니다. (b)는 (c)에 대해서 상대적으로 큰 데이터셋에 대한 D(x) 값을 나타냅니다. 아까 D는 주어진 데이터 x를 진짜라고 판단하는 확률(0~1 사이)이라 했으나, 그림 1을 보면 약 -6부터 6까지의 값을 가집니다. 따라서 여기서의 D(x)는 Sigmoid 함수를 통과하기 전에 D가 반환하는 값이라 이해했습니다. 즉, D(x)가 음수에 가까울수록 가짜라고 판단하며, D(x)가 양수에 가까울수록 진짜라고 판단한다 보면 될듯합니다.

논문에선 Validation 데이터셋에 대한 명확한 언급이 없는데, 진짜(Real)와 가짜(Generated)가 섞여있는 데이터 셋이라 가정하고 보았습니다. 큰 데이터셋에 대해 학습된 GAN을 보면, Validation 데이터셋(초록색)을 전부 가짜라 판단하지 않습니다. 하지만 작은 데이터셋에 대해 학습된 GAN을 보면, 진짜 데이터도 가짜, 가짜 데이터도 가짜라고 판단합니다. 즉, Classification을 하는 Discriminator가 Training 데이터셋에만 오버피팅되어, Training 데이터셋 외에는 전부 가짜라 판단하는 상황이라 보았습니다. 이를 2016년도 이안 굿펠로의 GAN 논문의 Generator를 학습시킬 때의 Cost에 빚대어 생각해보았습니다.

현재 상황에서, D(G(z))를 진짜든, 가짜든 0으로 Return하는 상황입니다. 0이 역시그모이드 함수를 통과하면(그림 1에서 (c)의 상황과 동일), 음수값을 리턴하게 됩니다. exponential에서 음수에 해당하는 부분을 생각해보면 기울기인 그래디언트가 없습니다. 이를 "Generator를 학습시킬 때, Discriminator의 피드백이 의미가 없어진다"라는 것과 일맥상통한다고 보았습니다. Discriminator, 일종의 Classification Network의 오버피팅을 극복하는 일반적인 방법이 Data Augmentation 기법입니다.

Data Augmentation

Data Augmentation(DA)은 이미지로 빚대어 생각하면 아래와 같이 원본 이미지에 인위적인 변화를 주는 것입니다. 그리고 이 인위적인 변화(원본 이미지 입장에서의 노이즈)를 준 이미지를 학습 데이터에 사용하는 것입니다. 딥러닝의 고질적인 문제로 트레이닝 데이터셋에 편향되게 학습하는 것을 오버피팅이라 합니다. 이 오버피팅 문제를 해결하기 위해서, 사용되고 있는 방법으로는 Regularization, Normalization 방법이 있습니다. 하지만 이 방식들을 Biased 학습 방향을 경감시키는 정도라고 볼 수 있습니다. 이를 해결하기 위해서, "학습의 방향을 모든 방향으로 넓힐 수 있지않을까?"라는 방법으로 제시된 것이 DA라 볼 수 있습니다. Biased된 학습은 오류를 발생시키지만, DA를 통해 적당한 힘으로 학습 면적을 "아주 조금: 편향된 방향에서 많이 벗어나지 않고", "골고루: 아래와 같은 여러 Augmentation 방식" 넓히자는 의미입니다.

Training GAN with Data Augmentation

기존 Classification Networks에선 오버피팅을 해결하기 위한 하나의 방법으로 Data Augmentation을 사용해왔습니다. 하지만 GAN에서는 다릅니다. Generating이 있기 때문에, Data Augmentation을 적용하여 GAN을 학습시킬시, 그림 3의 (b)에서 E와 같이 Augmentation을 포함한 이미지를 생성해버립니다.