오늘은 24년 4월에 나온 메타의 Imagine Flash 논문을 리뷰한다. 메타의 diffusion 모델인 Emu를 베이스라인으로 잡고, Emu의 성능을 끌어올리는 후속 논문이다. Emu 외 다른 모델에도 범용적으로 쓸 수 있다고 한다.
참고로 Emu는 latent diffusion model로 latent embedding은 오토인코더(AE. Auto Encoder) 방식으로 2.8B 크기의 U-Net을 사용했다.
1. Introduction
이미지 생성 분야인 Denoising Diffusion Models(DMs)에서 문제점은 시간과 cost가 많이 드는 과정이란 거다. '시간' 측면에서, denoising 스텝 수와 스텝별 latency가 시간을 많이 걸리게 하는 두 주범이다. 그렇기에 sampling process, 스텝 수를 줄이는 연구들이 있었고 가장 많이 줄인 건 5 step까지 진전을 보였다. 하지만 이미지 퀄리티를 유지하면서 스텝 수를 줄이는 건 쉽지 않다.
main contribution은 다음과 같으며, 본 논문에서는 스텝 수를 3까지 줄이면서 Emu 모델보다 이미지 생성 능력에 있어 유사하거나 더 좋아지는 결과를 보았다.
- Backward Distillation 도입. student 모델의 학습과 inference 때 간극을 줄이기 위함이다. (Knowledge Distillation 쪽이구나.) 일전에 diffusion에서는 noise가 정말 랜덤에서 시작하느냐에 따른 의문에서 시작해 여러 시도들이 나왔었는데 그 대응책 중 하나로 논문에서 제안했다.
- Shifted Reconstruction Loss (SRL) 도입. Experiments를 보면 이부분의 도입으로 많은 개선을 보였다.
- Noise Correction 도입. 추론 시에 약간의 수정을 통해 퀄리티를 높이는 방법으로, 학습과는 무관하다. 이를 통해 생성된 이미지를 더 선명하고 밝거나 어둡게 만들 수 있게 된다.
2. Related Work
다음은 Diffusion Model에서 성능은 유지하면서 추론 속도는 향상하는 태스크들에 대한 소개다.
solvers and curvature rectification
추론 속도 향상을 위한 초기 아이디어다.
이 분야에서는 추론 path를 선형화하고 큰 스텝 사이즈를 허용해서, 추론 step 수를 줄이는 연구들이 있다. 다만 추론 스텝 사이즈가 커진다는 것은 생성 이미지 퀄리티를 어느 정도 버리는 것과 같다.
Reducing model size
모델 백본 사이즈를 아예 줄여버려서 추론 속도를 향상시켰다.
모델이 작으니 per-step latency도 줄어들고 추론 속도도 빨라진다. 다만 정말 인퍼런스 스케일을 줄이려면 결국 추론 step 수도 줄여야 한다.
Reducing sampling steps
본 논문에서 채택한, 추론 시 step 수를 줄이는 방법이다.
물론 step 수를 줄이면 이미지 생성 능력이 떨어지기 때문에, 이를 억제하기 위한 constraint를 추가해 주는 쪽의 연구들이 있다. quality와 speed를 어느 정도 타협하는 관점이 타당해 보이지만, 본 논문에서는 student's backward path를 증류하는 방법을 통해 퀄리티도 놓치지 않았다고 한다. (원래는 forward path를 distillation 한다.)
3. Background on Diffusion Models
diffusion 모델을 크게 두 부분으로 나누면 forward(encoder)와 backward(decoder)가 될 것이다.
forward에서는 멀쩡한 input 이미지 데이터(x_0)에 T번동안 계속해서 가우시안 noise를 주입해서 이미지를 아예 노이즈 데이터(x_t)로 만드는 것이 목표다. 아래 식의 입실론 ϵ가 noise다.
분홍색 계수는 signal-to-noise ratio(SNR)로, noise를 얼마나 주입할 건지를 나타낸다. (variance preserving을 썼으나 생략.) 이 forward process의 식을 미분 가능한 확률 식으로 바꾸면(SDE. stochastic differential equation) 아래와 같이 미분식을 쓸 수 있다. 이때의 f는 vector-valued drift 계수, g는 diffusion 계수다. 이와 관련한 matlab 함수를 첨부한다.
backward에서는 노이즈덩어리로 만든 데이터를 다시 복원해서 denoising, generation 한다. Denoising Diffusion Implicit Models이라는 모델을 참고해서 x_0과 x_t의 관계를 다음과 같이 정의한다. iterative numerical solver f를 통해 x_T에서 원본 데이터 x_0을 추정하는 식을 아래와 같이 f를 활용해 풀어낼 수 있다.
x_0이 x_t와 linear 하게 표현된다는 것은, '1) 추정한 원본 데이터 ^x_0와 2) 추정되는 noise를 3) 맨 처음 소개했던 1번 forward 식에 넣었더니 3) 다시 x_t-1가 되더라'로 표현할 수 있다. 이를 식으로 표현하면 아래와 같다.
Eq 4. 에서 ^x_0만 정리하고 바깥 t-1을 t로 치환하면, t 까지 Noise를 주입했던 상태에서 시작해 첫 원본 샘플 데이터를 추정할 수 있음을 보인다. 이를 통해 t 번째 step에서 원본데이터를 유추할 수 있음을 보였다.
4. Method
본 논문에서 쓰이는 용어부터 정리해 보자.
- Φ : teacher model. pretrained diffusion model이다.
- ^ϵ : 측정된 noise
- Θ: student model.
즉 Imagine Flash는 Emu와 같은 teacher model Φ에서 knowledge distillation(KD)를 통해 sampling step을 줄일 수 있는 student model Θ을 얻는 것이 목적이다.
Backward Distillation
diffusion background 설명의 forward에서 나왔던 SNR(α와 ϵ. 참고로 ϵ = (1- α))은 스텝이 끝까지 도달했을 때 (t=T) 0에 도달하지 못한다는 문제점이 있다. α가 0이 아니니 원본데이터 x_0에 대한 정보가 녹아있게 되고(low-frequency information), 그러면 x_T는 학습에서 도출한 순수한 noise가 아니게 된다. 즉 학습과 인퍼런스 시의 차이가 생기는 것이다. 만약 x_T를 노이즈 ϵ 자리에 집어넣어도 t < T 인 상황의 x_t에선 여전히 원 이미지 x_0에 대한 정보가 섞여있을 수밖에 없다. t가 작을수록 원본 시그널은 더욱 보존될 것이다.
여기서 기존의 방법대로 distillation을 진행한다면 student model이 어떻게 학습하는지를 보자. xΦ_(t->0)은 k step teacher prediction을, ^x_0은 추정한 원본 샘플 데이터이다. ^x_0을 추정하는 건 backward의 마지막 수식에서 다뤘다.
위 수식에서 보다시피 추론 시 xΦ_0 이 ^x_0과 유사해지게 되는데, ^x_0은 x0의 데이터가 섞인 순수하지 않은 noise ^x_T에서 도출되게 학습되므로 편파적인 error를 갖는 것이다.
이를 해결하려면 학습과 추론 모두 일관된 시그널을 보장할 수 있어야 한다. 그래서 본 논문에서 제시한 방법이 backward distillation이다. 이는 기존의 forward distillation 방법과 다르게 x_t (t번째로 노이즈 주입한 결과)에서 샘플링을 진행하지 않고 아래와 같이 학습이 진행된다.
1. 학습시키고자 하는 student model이 backward로 x_T에서 x_t를 추정한다. 이를 xΘ_T→t, 또는 latent code로 표현하며 아래 수식과 같다. f는 위에서 3.background 설명의 Eq 4에서 나온, 데이터를 추정하는 식이다.
2. 1번에서 구한 xΘ_T→t을 학습 과정에서 student, teacher model에 각각 input으로 준다.
3. 아래 수식은 backward distillation에서의 학습 gradient이다. 하늘색, 보라색 모두 input으로 xΘ_T→t을 받고 있는 걸 알 수 있다. 지금이 t step 시기일 때, 하늘색 항은 teacher model이 latent code에서 시작하여 k step 만큼 denoising 한 target이다. 보라색 항은 student model이 latent code에서 시작하여 x_0을 바로 추정한 결과다. 이 둘의 차이가 작을수록 student 모델의 distillation이 잘 되었다고 볼 수 있다.
요약
forward distillation은 x_t 에서 x_0을 추정하는 것을 student 모델이 모사한다. x_t는 원본 이미지 x_0에 대한 정보가 섞여있는 노이즈이고, 원본 이미지에 대한 정보는 teacher model이 학습 시 주입한 결과다. inferecne를 생각하면 student 모델에게는 사전정보가 없는 왜곡된 noise인 것이다.
backward distillation에서는 x_T->t에서 x_0을 추정하는 것을 학습하게 된다. x_T->t는 x_T에서 x_t를 추정한 결과이기 때문에 training과 inferecne에서 간극이 없다.
SRL: Shifted Reconstruction Loss
앞서 backward diffusion에서는 T -> t -> 0의 순서로 image generation이 진행된다. 초기 t가 T에 가까울 때는 전체적인 image composition(= coarse grained)에 대해 학습한다. 반면 t가 0에 가까운 후기 단계에선 이미지의 high level detail(=fine grained)을 학습한다. student model이 image의 fine grained, coarse grained feature를 둘 다 학습할 수 있게 한다면 더 좋지 않을까?
이 컨셉을 달성하기 위해 teacher 모델이 denoising를 시작하는 스텝을 student의 denoising step과 분리시킨다. Eq 7. backward distillation에선 student, teacher가 같은 noisy latent code xΘ_T→t에서 각각 x_0을 추정했었다. SRL에선 teacher는 그대로 두고, student의 추론 시점을 T->t에서 tΦ 로 치환했다. Eq 7과 비교해 보면 다음 수식의 하늘색 부분만이 변했다. 여기서 shifting function r(t)가 tΦ이다.
r(t)가 정확히 어떻게 업데이트되는지는 실험 디테일에 따라 다르지만, t와 동일한 값은 갖지 않게 설정된다. 그림으로 도식화해서 보다시피 간단한 컨셉이지만 이를 통해 성능 개선을 가장 많이 이뤄냈다고 한다.
r(t)는 t>900일 때 990, 500 <t < 900일 때 950, t<=500일때 200을 사용했다. 참고로 3 step 실험에선 t = {999, 750, 500}을 사용했고, 2 step에선 t = {999, 666}로 진행했다.
Noise Correction
Noise correction은 학습과 무관하게 zero-SNR을 달성하기 위한 트릭 중 하나다. SNR은 Eq1. 의 분홍색 계수들을 의미했었다. 일반적으로는 t가 T까지 가도 SNR이 0을 달성하지 못하기 때문에 이를 고치기 위한 방법론들이 여럿 있다. 그 중 하나로 본 논문에선 t = T일 때의 함수를 그냥 수정해준다.
f는 Eq 1.의 변형형태로 Eq 4. 에서도 나왔듯이 x_t에서 x_t-1을 추정하는 함수였다.
이 f가 T인 시점에 SNR이 0이 안되니까, 임의로 t = T인 시점에 강제로 0이 될 수 있도록 한다.
5. Experiments
재밌거나 핵심 실험 결과들만 정리한다. 실험 백본은 EMU을 사용했고, 비교군 ADD실험 땐 StyleGAN-T discriminator를 사용했다.
정량적 결과는 FID와 CLIP, CompBench score를 사용했다. FID, CLIP score는 이미지 퀄리티와 prompt 간의 유사성을 평가한다. CompBench는 생성된 이미지의 색, 모양, 생성 오브젝트들 간의 퀄리티를 평가한다. 아래 표를 보면 Imagine Flash가 적은 step으로도 유의미한 결과를 내는 것을 확인했다. 참고로 backward distillation이 빠지면 Imagine Flash의 CLIP score가 하락한다.
Emu를 baseline로 두고 현 SOTA 방법론들과 비교한 결과다. 제법 선명하고, 깨끗하고, 색이 더 밝고 어두움이 분명한 걸 확인할 수 있다.
다음은 step 1, 2, 3으로 갈수록 distillation target이 더 선명해지고 high frequency detail이 나타나는 이미지가 되는 것을 볼 수 있다.
다음은 Imagine Flash에서 각각의 방법론을 빼가며 실험한 결과다.
- SRL이 없을 때 상당히 망가지는 걸 볼 수 있다.
- without discriminator를 보면 생각보다 멀쩡한 생성결과를 볼 수 있는데, backward distillation이 있기 때문에 영향을 덜 미치는 것으로 해석할 수 있다. (experiments에서 Imagine Flash에 discriminator를 추가해 비교하는 실험을 했는데, 이미지를 더 sharpen 하게 만드는 결과를 보였다.)
- noise correction보다 backward distillation이 생성 이미지의 선명함, 블러 등에 큰 영향을 주는 것을 알 수 있다.
diffusion 쪽 팔로우업을 요즘 너무 안 한 것 같아 읽어본 논문이다. 그간 이미지 생성 쪽에서 어떤 고민들이 있었으며 한계를 극복하기 위해 어느 지점까지 왔는지 알 수 있어 좋았다.
'논문리뷰' 카테고리의 다른 글
[Paper] DPO 논문 리뷰 (1) | 2023.11.13 |
---|