<aside> 📢 <Table of Contents> train.py

디버깅

실험 예시

정리

</aside>

안녕하세요 여러분!

이번 포스팅은 Hydra 시리즈의 마지막으로서, 간단한 ML 실험 실습 시간입니다 🦾

이전 포스팅에서 설명 드린 Hydra 설정을 기반으로 실습이 진행됩니다🙂

train.py

이미지 분류에 대한 Task, 즉 모델 학습은 train.py 을 통해 진행됩니다.

따라서, 해당 파일의 코드를 한 번 살펴보도록 하겠습니다.

# ./train.py

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from config_schemas.config_schema import setup_config

@hydra.main(config_path="configs", config_name="config", version_base="1.2")
def train(config: DictConfig) -> None:
    data_module = instantiate(config.data_module)
    task = instantiate(config.task)

    # Create the logger
    tb_logger = TensorBoardLogger("tb_logs", name="cifar10")

    # Create the checkpoint callback
    checkpoint_callback = ModelCheckpoint(
        monitor="validation_accuracy",
        dirpath="checkpoints/",
        filename="cifar10-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="max",
    )

    # Create the trainer
    trainer: Trainer = instantiate(config.trainer, logger=tb_logger, callbacks=[checkpoint_callback])

    # Fit the model
    trainer.fit(model=task, datamodule=data_module)

    # Test the model
    trainer.test(model=task, datamodule=data_module)

# Run the train function
if __name__ == "__main__":
    setup_config()
    train()

본 코드는 크게 두 가지 메인 로직이 존재합니다.

setup_config 메소드

이전 포스팅에서 설명드린 Total task 컴포넌트의 설정을 config store에 저장하는 역할을 수행합니다.

해당 컴포넌트는 Task, Data Module, 그리고 Trainer 컴포넌트로 구성되어 있다고 설명 드렸습니다.

따라서, 해당 컴포넌트의 설정을 연쇄적으로 저장합니다.

마찬가지로, Task의 경우도 여러 컴포넌트로 구성되므로, 하위 계층의 컴포넌트 설정도 똑같이 저장합니다.

결과적으로, 전체 컴포넌트의 설정 정보를 config store에 저장하게 됩니다.