1️⃣ 정답 데이터셋 만드는 방법 (추천 방식)

가장 관리하기 좋은 방식은 CSV 또는 Excel 입니다.

dataset.csv 예시

image_path A B C D E F G
img1.jpg 1 0 1 0 0 1 0
img2.jpg 0 1 0 1 0 0 0
img3.jpg 1 1 0 0 0 0 1

이미지경로 + 7개의 라벨 벡터

형태입니다.


2️⃣ 폴더 구조

dataset/
   images/
      img1.jpg
      img2.jpg
      img3.jpg
   labels.csv

3️⃣ PyTorch Dataset 코드

import torch
from torch.utils.data import Dataset
import pandas as pd
from PIL import Image
import os

class MultiLabelDataset(Dataset):
 
    def __init__(self, csv_file, img_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

        self.labels = self.data.iloc[:,1:].values

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):

        img_path = os.path.join(self.img_dir, self.data.iloc[idx,0])
        image = Image.open(img_path).convert("RGB")

        label = torch.tensor(self.labels[idx], dtype=torch.float32)

        if self.transform:
            image = self.transform(image)

        return image, label

4️⃣ VGG16 멀티레이블 모델

VGG16 마지막 layer만 수정합니다.

출력 = 7
import torch
import torch.nn as nn
import torchvision.models as models

num_classes = 7

model = models.vgg16(pretrained=True)

model.classifier[6] = nn.Linear(4096, num_classes)

model = model.cuda()

5️⃣ Loss 함수