최근에 distillation 학습을 하면서 오랜만에 OOM(Out Of Memory)를 맛봤다. 중간 hidden state를 끌어와서 추가 loss를 계산하는 것 때문이기도 했고, output의 모든 텐서를 끌어와서 계산하는 항도 있었기 때문이다. 이참에 블로그에 기초 모델 메모리를 정리해볼까 한다. 

 


0. 모델 크기에 따른 메모리 계산 

간단하다. pretrain시 FP32로 모델을 올려 학습시킨다고 가정해보자. 그러면 fp32 == 4byte다. 만약 1.3B 모델을 GPU에 올려서 학습하고 싶다. 그렇다면, 

  • 1.3B (billion이다) x 4byte = 5.2GB 이다. 

모델을 단순히 올리기 위해서 5.2GB의 메모리가 필요하다. 

 

그렇다면 FP16으로 모델을 올린다면? (실제 학습에선 FP16 -> FP32로 재변환 후 matrix 연산을 한다. 일단 모델을 올리는 것만 생각하고 가정해본다.)

  • 16bit == 2byte
  • 1.3B x 2byte = 2.6GB 다. 

 

만약 quantization을 진행해서 int8, int4로 모델을 올린다면? 이때는 선형적으로 메모리가 줄어들지 않는다. 이정도의 precision을 가져갈 땐, embedding layer에 적용되지는 않기 때문이다. 예를 들어서 token 수가 42,000 인 tokenizer에 hidden dim 768라고 한다면 다음과 같이 임베딩 레이어는 최대 fp16으로 볼 수 있다. (positional embedding은 뺐다.) 

  • 42000 x 768 ~= 0.03B
  • 0.03B x 2byte == 0.06 GB

그럼 1.3B에서 임베딩을 뺀 나머지 파라미터만이 quantization을 할 수 있다. int8 qunatization을 했다고 하면 1byte다. 

  • 남은 파라미터 수: 1.3B - 0.03B == 1.27B 
  • total) 0.03B x 2byte + 1.27B x 1byte 

1.3B에서 메모리 파라미터는 작은 비중을 차지하는 것을 볼 수 있다. 

모델 크기가 줄어들수록 모델에서 임베딩이 차지하는 비중은 커진다. 300M (모델 크기는 파라미터의 갯수로 나타낸다. MB 아니다. Million이다.) 모델인데 위 예시로 든 1.3B모델과 같은 임베딩이라면 quantization해서 줄어드는 모델 메모리의 비중이 줄어들 것이다. 

 

 

1. 학습에서의 모델 메모리 

학습에서 추가적으로 올라가는 메모리는 뭐가 있을까? 1) 데이터를 학습시키니까 데이터도 올라갈테고, 2) 학습 중 forward, backward에서 행렬 연산이 사용되기도 한다. 

 

데이터 학습과 연관된 하이퍼파라미터엔 batch_size와, LM의 경우 tokenizer의 size도 있다. token_size가 42,000으로 고정일 때 batch size가 늘어날수록 어떻게 메모리를 추가 점유하는지는 단순 input만을 계산할 수 없다. 학습 연산 과정 메모리에도 당연하지만 batch size가 관여한다. 

 

forward 과정에서는 backward 시 사용해야 하기 때문에 각 layer의 output이 메모리를 점유한다. hidden state, attention layer의 output이 여기에 해당한다. 

 

backward를 하면 gradient를 계산해야 한다. 이때도 행렬연산이기 때문에 GPU메모리가 필요하다. 그렇기에 optimizer도 GPU에 올라가는 것이다. 익숙한 Adam optimizer를 예로 하면 미분값과 이계도 미분값을 기준으로 계산하기 때문에, 메모리 두배 이벤트다. 

 

여기에 연산용으로 미리 잡아놓는 메모리까지 하면 많은 GPU 메모리가 필요함을 알 수 있다. 

 

 

 

 

우선 layer의 중간 output값들이 점유하는 메모리를 계산해보자. 어텐션은 Q, K, V 3가지 요소로 구성되어 있으니 3배 이벤트다.

  • Q(또는 K, V)의 matrix shape == batch size x max_len x hidden state

여기서 조금 더 헉스러운건 attention score도 저장한다는 거다. attention을 생각하면, 각 Q와 K의 점수를 계산한 행렬이 나오는걸 알 수 있다. 

  • attention score shape == batch_size x max_len x max_len x hidden_state 

 

 

새삼 max_len이 긴 요즘 LLM을 생각하면 얼마나 많은 GPU가 필요할지 체감이 된다. 

 

 

2. 실제 계산 

계산상의 편의를 위해 precision은 fp32로 고정하고, BERT-small의 하이퍼파라미터를 기준으로 계산해보자. 

  • H=512
  • L=4 

학습시 batch size 는 B, max_len은 S라고 표기하겠다. 

 

1) 우선 모델 사이즈는 27M 정도다. fp32이면 약 108MB 메모리다. 

 

2) layer의 Output은 몇 메모리를 점유할까? 

  • query, key value) batch size  B x max_len  S x hidden state H x 3(키 쿼리 벨류 3개니까) ~= B x S x 512 x 3 x 4byte
  • attention score) B x S x S x H ~= B x S**2 x 512 x 4byte

3) optimizer 메모리는?

  • 모든 레이어에 대해서 이전과 어떻게 변했는지를 계산해야 하기 때문에 모델의 전체 파라미터의 gradient를 위한 optimizer pararmeter를 저장한다. 
    • Adamw 계열은 위에서 언급했듯이 2배이벤트다. 
  • ~= 27M x 4Byte x 2(배) ~= 216MB 

 

4) backpropagation을 하면?

  • 모델의 전체 파라미터를 backprop하기 때문에, 그만큼 든다. 즉 거의 모델 크기만큼이다.
  • ~= 108MB 

 

여기까지 더해보면 108 + 1536BS + (1024BS x 4byte) + 216 + 108 == 432 MB + 2560*B*S *4byte 이다.

만약 S = 512, B = 64라면? 거의 32GB가 필요하다. 

 

참 많이도 든다! 

 


 

bert-small의 작은 모델로 대충 어느정도 메모리가 학습에 필요한지를 기반으로 플로우를 그려보았다. 

 

+

계산 자체에서 뭔가 잘못 곱셈을 했을 수도 있다. (역산 안해봤다.) 

 

 

 

 

 

728x90

+ Recent posts