towardsdatascience.com/everything-you-need-to-know-about-saving-weights-in-pytorch-572651f3f8detowardsdatascience.com/everything-you-need-to-know-about-saving-weights-in-pytorch-572651f3f8de
towards data science의 위 글을 번역한 글입니다!
오역한 부분이나 자연스러운 표현을 위해 의역한 부분이 있을 수 있습니다.
별(*)로 시작하는 문장, 문단은 제가 추가한 해설입니다.
잘못된 내용에 대한 댓글로 부탁드립니다. :)
딥러닝 실무자들은 모델 학습을 마친 뒤 어떤 행동을 할까요? 우리는 휴식을 취합니다! 이건 농담이고, 우리는 학습된 웨이트를 저장하거나 혹은 모델을 추가적으로 학습시키기 위해 전체 모델을 저장하거나 혹은 학습된 모델을 사용합니다.
다음으로 여러분이 알고 싶은 것은 언제 학습된 웨이트만을 저장하고, 언제 전체 모델을 저장하는지입니다.
이 글에서 우리는 이 질문에 대한 답을 찾아볼까 합니다.
모델의 아키텍처와 PyTorch의 웨이트를 저장하는 방법에 대해 매우 간단하게 설명해보겠습니다.
우리는 또한 주어진 어떤 PyTorch 모델의 다른 모듈, 정확히는 nn.Modules에 접근하는 방법을 배웁니다.
그리고 이 캐글 커널을 포크해 코드를 가지고 놀아보세요! 시작해봅시다!
우리는 PyTorch를 사용해 코딩의 필수 요소들을 import해옵니다.
import torch
import torch.nn as nn
다음으로, 우리는 CNN 기반의 모델을 정의합니다.
class NeuralNet(nn.Module):
def __init__(self):
super(NeuralNet, self).__init__()
self.sequential = nn.Sequential(nn.Conv2d(1, 32, 5),
nn.Conv2d(32, 64, 5),
nn.Dropout(0.3))
self.layer1 = nn.Conv2d(64, 128, 5)
self.layer2 = nn.Conv2d(128, 256, 5)
self.fc = nn.Linear(256*34*34, 128)
def forward(self, x):
output = self.sequential(x)
output = self.layer1(output)
output = self.layer2(output)
output = output.view(output.size()[0], -1)
output = self.fc(output)
return output
모델을 초기화하고 모델을 print 해 안에 무엇이 있는지 확인해봅시다.
model = NeuralNet()
print(model)
Model을 print 하면 Model의 구조를 볼 수 있습니다. 하지만 우리는 딥러닝 실무자들이기 때문에 더 깊이 알아보도록 하겠습니다.
우리는 우리의 모델 안에 정확히 무엇이 있는지 확실히 이해할 필요가 있습니다.
모델의 모든 학습 가능한 파라미터에 이름과 함께 접근할 수 있는 방법이 있습니다. 한편, torch.nn.Parameter는 Tensor의 하위 클래스로 torch.nn.Module과 함께 사용하면 자동으로 parameters() 혹은 named_parameters() iterator와 같은 파라미터 목록에 추가됩니다. 반면 torch.nn.Tensor에 추가하는 것은 아무런 효과를 갖지 않습니다. 자세한 건 뒤에서 알아보죠!
다시 모델의 모든 파라미터를 출력하는 것으로 돌아가 봅시다.
for name, param in model.named_parameters():
print(f'name:{name}')
print(type(param))
print(f'param.shape:{param.shape}')
print(f'param.requries_grad:{param.requires_grad}')
print('=====')
자 여기서 무슨 일이 일어났나요? named_parameters() 함수 안으로 들어가 봅시다.
model.named_parameters()는 그 자체로 generator입니다. 이것은 이름과 매개 변수 자체인 이름(name)과 파라미터(param)를 반환합니다. 여기서 반환되는 파라미터는 일종의 텐서로, torch.nn.Parameter 클래스입니다. Param은 tensor의 한 종류이기 때문에 shape과 requires_grad 속성을 갖습니다. param.shape는 단순히 tensor의 크기를, param.requires_grad는 파라미터를 학습할 수 있는지 없는지를 말해주는 boolean 자료형의 속성입니다. 모델의 모든 파라미터의 requires_grad가 True이기 때문에, 그것은 모든 파라미터가 학습 가능하고 학습 동안 업데이트될 것임을 의미합니다. 특정 파라미터에 대해 False로 설정한 경우 해당 파라미터의 가중치는 모델 학습 시 업데이트되지 않습니다.
즉, requires_grad는 모델의 특정 레이어 세트를 학습 혹은 고정하려고 할 때 변경할 수 있는 플래그입니다.
이제 모델의 마지막 레이어를 제외한 모든 레이어를 고정해봅시다. 모델의 모든 파라미터의 이름을 확인하다 보면, 우리는 마지막 레이어의 이름이 'fully connected'의 약자인 'fc'인 것을 확인할 수 있습니다.
따라서 'fc.weight' 또는 'fc.bias'라는 이름을 갖는 파라미터를 제외하고는 모두 얼려봅시다.
for name, param in model.named_parameters():
if name in ['fc.weight', 'fc.bias']:
param.requires_grad = True
else:
param.requires_grad = False
for name, param in model.named_parameters():
print(name, ':', param.requires_grad)
우리가 원하는 대로 성공적으로 바뀌었습니다!
따라서 우리는 모델의 원하는 파라미터의 requires_grad 플래그를 바꾸는 방법을 배웠습니다. 또한 우리는 특정 파라미터, 레이어의 가중치를 학습 또는 고정시키려는 상황에서 위와 같은 방법을 사용하는 것이 매우 유용함을 배웠습니다.
이제부터 우리는 모델의 웨이트, 파라미터를 저장하는 널리 알려진 방법 2가지를 배워보겠습니다.
1. torch.save(model.state_dict(), 'weights_path_name.pth')
: 이 방법은 오직 모델의 웨이트만 저장합니다.
2. torch.save(model, 'model_path_name.pth')
: 이 방법은 모델의 웨이트는 물론 모델의 구조까지 전체 모델을 저장합니다.
What Is state_dict() And Where To Use It?
가장 먼저 우리는 state_dict 구문을 작성하는 방법을 알아보겠습니다. 아주 간단하고 쉽습니다.
model.state_dict()
이건 파이썬의 Ordered dictionary입니다.
하지만, 이것을 print 하면 아마 카오스가 발생할 것입니다. 따라서 여기서는 전체 모델에 대한 state_dict는 출력하지 않겠습니다. 하지만 여러분들은 한 번 프린트해서 화면에 띄워보세요!
이제 주제에서 살짝 벗어나 보겠습니다.
help(model)을 출력해보면 모델이 nn.Module의 인스턴스임을 알 수 있습니다.
help(model)
이것은 파이썬의 isinstance 함수를 사용해서도 확인할 수 있습니다.
isinstance(model, nn.Module)
model.fc 또한 nn.Module의 인스턴스일까요?
isinstance(model.fc, nn.Module)
맞습니다!그렇다면 fc는 정확히 무엇이며, 어디에서 온 것일까요?우리는 다음과 같은 코드를 통해 모델 내의 모든 nn.Module 객체를 확인할 수 있습니다.
for name, child in model.named_children():
print('name: ', name)
print(f'isinstance({name}, nn.Module):', isinstance(child, nn.Module))
print('=====')
어떤 nn.Module 객체에 적용된 named_children() 함수는 nn.Module 객체를 포함한 그것의 직계 자식들을 반환합니다. 위의 코드의 결과를 살펴보면, 우리는 'sequential', 'layer1', 'layer2', , 'fc'가 모델의 모든 자식들이며 모두 nn.Module 객체라는 것을 알 수 있습니다. 이제 우리는 'fc'가 어디에서 온 것인지 알 수 있습니다.
그리고 state.dict()는 nn.Module 객체에 대해 동작하며 그것의 모든 직계 자식을 반환합니다.
그럼 모델의 'fc' 레이어에 대해 state_dict 함수를 작동시켜보죠!
for key in model.fc.state_dict():
print('key: ', key)
param = model.fc.state_dict()[key]
print('param.shape: ', param.shape)
print('param.requires_grad: ', param.requires_grad)
print('param.shape, param.requires_grad: ', param.shape, param.requires_grad)
print('isinstance(param, nn.Module) ', isinstance(param, nn.Module))
print('isinstance(param, nn.Parameter) ', isinstance(param, nn.Parameter))
print('isinstance(param, torch.Tensor): ', isinstance(param, torch.Tensor))
print('=====')
model.fc.state_dict() 혹은 어떠한 nn.Module.state_dict()는 ordered dictionary라는 것을 기억하세요. 따라서 그것을 반복하면(iterating) nn.Module 객체가 아닌, shape과 requires_grad 속성을 지닌 파라미터 torch.Tensor에 접근할 수 있는 키(*딕셔너리의 키)를 얻을 수 있습니다.
따라서 우리가 모델과 같은 nn.Module 객체의 state_dict를 저장하면 torch.Tensor가 저장되는 것임을 기억해야 합니다!
아래의 코드가 전체 모델에 대한 state_dict를 저장하는 방법입니다.
torch.save(model.state_dict(), 'weights_only.pth')
이 방법은 작업 중인 디렉토리에 'weights_only.pth' 파일을 생성합니다. 이 파일은 ordered dictionary로 모델의 모든 레이어에 대한 torch.Tensor 객체를 보관합니다.
우리는 이제 저장된 웨이트를 불러와보겠습니다. 하지만 그러기 전에 우리는 먼저 모델의 아키텍처를 정의해야 합니다. 저장된 정보는 모델이 아니라 웨이트일 뿐이기 때문에 먼저 모델을 정의한 다음 웨이트를 불러오는 것이 좋습니다.
model_new = NeuralNet()
model_new.load_state_dict(torch.load('weights_only.pth'))
정의된 모델에 웨이트를 로드한 뒤, model_new의 모든 레이어에 대한 requires_grad 속성을 확인해봅시다.
for name, param in model.named_parameters():
print(name, ':', param.requires_grad)
잠깐! 모든 레이어에 대한 설정한 requires_grad 플래그가 어떻게 바뀌었나요? 모든 requires_grad 플래그는 True로 리셋된 것으로 보입니다.
사실 우리는 애초에 파라미터의 required_grad 플래그를 저장하지 않았습니다. state_dict는 각각의 레이어에 대응하는 파라미터 텐서라는 것을 기억해야 합니다. 그것은 requires_grad 속성을 저장하지 않습니다.
따라서 우리는 추가적인 학습을 진행하기 전에 모든 파라미터의 requires_grad 속성을 필요한 대로 바꾸어주어야 합니다.
How to Save The Entire Model And When To Do It?
우리에겐 전체 모델을 저장할 수 있는 두 번째 방법도 있습니다. 전체 모델이란 웨이트를 포함한 모델의 구조를 의미합니다.
우리는 마지막 레이어를 제외하고 모두 고정한 시점부터 다시 시작하고, 전체 모델을 저장합니다.
torch.save(model,'entire_model.pth')
이 코드는 'entire_model.pth' 파일을 현재 작업 중인 디렉토리에 생성하고 이것은 웨이트를 포함한 모델 아키텍처를 보관합니다.
이제 우리는 저장된 모델을 불러와보겠습니다. 이 경우에는 저장된 파일에 모델 구조가 저장되어 있기 때문에 모델 구조를 정의해줄 필요가 없습니다.
model_new = torch.load('entire_model.pth')
모델이 로드되면, model_new의 모든 레이어에 대해 requires_grad 속성을 확인해봅시다.
for name, param in model.named_parameters():
print(name, ':', param.requires_grad)
정확히 우리가 기대하던 결과 아닌가요? :D
즉, 우리가 전체 모델을 저장한 경우에는 nn.Module 객체와 모든 파라미터의 requires_grad flag 역시 저장됩니다.
요약
우리는 이 글에서 많은 것들을 공부했습니다.
1. 모델 혹은 model.layer2, model.fc와 같은 nn.Module 객체에 named_parameters()를 적용하면 각각의 파라미터와 그들의 이름을 반환합니다. 이 파라미터들은 nn.Parameter(torch.Tensor의 하위 클래스) 객체이고 따라서 그들은 shape과 requires_grad 속성을 가지고 있습니다.
2. nn.Parameter 객체의 requires_grad 속성은 특정 파라미터를 학습시킬지 혹은 고정시킬지를 결정합니다. 예를 들어 우리가 모델의 layer1의 파라미터를 고정시키고자 한다면, 다음과 같은 코드를 사용할 수 있습니다.
for param in model.layer1.parameters():
param.requires_grad = False
3. nn.Module 객체에 named_children() 함수를 적용하면 그것의 모든 직계 자식들을 반환합니다.
4. 모델 혹은 model.layer2, model.fc와 같은 nn.Module 객체의 state_dict()는 각각의 파라미터 tensor에 대응하는 파라미터 값을 가진 ordered dictionary입니다. 이 ordered dictionary의 key는 각각의 파라미터 tensor에 접근하는 데 쓰이며, 파라미터의 이름입니다.
5. nn.Module 객체의 state_dict를 저장하는 것은 모델 구조가 아닌 객체의 다양한 파라미터의 웨이트를 저장하는 것일 뿐입니다. 또한 웨이트의 require_grad 속성도 포함되지 않습니다. 따라서 state_dict를 로드하기 전에 반드시 모델을 먼저 정의해야 합니다.
6. 웨이트는 물론 모델 구조를 포함한 전체 모델을 저장하는 것도 가능합니다. nn.Module 객체를 저장하는 것이기 때문에 requires_grad 속성 역시 저장됩니다. 저장된 파일이 모델 구조를 가지고 있기 때문에 파일을 로드하기 전에 모델 구조를 정의할 필요가 없습니다.
7. state_dict를 저장하는 것은 오직 모델의 웨이트만을 저장하고 싶을 때 사용되는 방법입니다. 이 방법은 requires_grad 플래그를 저장하지 않는 반면, 전체 모델을 저장하면 웨이트와 모든 파라미터에 대한 requires_grad 속성을 포함한 모델 구조가 저장됩니다.
8. state_dict와 전체 모델은 모두 추후에 활용(inference) 하기 위해 저장됩니다.
저는 다른 사람의 블로그를 읽고 많은 것을 배웠기 때문에 가능한 많이 저의 지식을 글로 써 나누고자 이 글을 작성하고 있습니다. 따라서 아래 댓글란에 피드백을 남겨주세요. 또한 저는 블로그에 글을 쓰는 것이 처음이기 때문에 글을 발전시킬 수 있는 방법을 알려주시면 감사하겠습니다. :D
* 얼핏 보면 다 아는 내용 같은데 사실은 정확하게 알지 못했던 내용을 짚어 설명해준 글이었습니다. 저 역시 state_dict()를 저장하는 것과 모델 전체를 저장하는 것의 차이를 잘 몰랐었는데 이번 기회를 통해 확실하게 그 차이를 알게 되었네요. 이번 포스팅은 여기서 마무리하도록 하겠습니다. 잘못된 부분에 대한 지적은 댓글로 부탁드립니다.😋
'컴린이 탈출기 > Machine Learning' 카테고리의 다른 글
Object Detection에서는 Train/Valid set을 어떻게 나눌까? -Stratified Group KFold (feat. sklearn) (0) | 2022.03.23 |
---|