가장 관리하기 좋은 방식은 CSV 또는 Excel 입니다.
| image_path | A | B | C | D | E | F | G |
|---|---|---|---|---|---|---|---|
| img1.jpg | 1 | 0 | 1 | 0 | 0 | 1 | 0 |
| img2.jpg | 0 | 1 | 0 | 1 | 0 | 0 | 0 |
| img3.jpg | 1 | 1 | 0 | 0 | 0 | 0 | 1 |
즉
이미지경로 + 7개의 라벨 벡터
형태입니다.
dataset/
images/
img1.jpg
img2.jpg
img3.jpg
labels.csv
import torch
from torch.utils.data import Dataset
import pandas as pd
from PIL import Image
import os
class MultiLabelDataset(Dataset):
def __init__(self, csv_file, img_dir, transform=None):
self.data = pd.read_csv(csv_file)
self.img_dir = img_dir
self.transform = transform
self.labels = self.data.iloc[:,1:].values
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.data.iloc[idx,0])
image = Image.open(img_path).convert("RGB")
label = torch.tensor(self.labels[idx], dtype=torch.float32)
if self.transform:
image = self.transform(image)
return image, label
VGG16 마지막 layer만 수정합니다.
출력 = 7
import torch
import torch.nn as nn
import torchvision.models as models
num_classes = 7
model = models.vgg16(pretrained=True)
model.classifier[6] = nn.Linear(4096, num_classes)
model = model.cuda()