티스토리 뷰

 

[방법1]

 

PyTorch 공식 문서 및 Wikidocs에서 더 자세한 설명을 볼 수 있다.

 

import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'label'])
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

위 코드는 img_dir 폴더에 존재하는 이미지들과 annotations_file을 이용해 custom 데이터셋을 만드는 클래스를 선언한다.

이 클래스는 Dataset 클래스를 상속받기 때문에, 반드시 __init__(), __len__(), __getitem__() 함수를 오버라이딩해야 한다.

 

그리고 아래 코드와 같은 식으로 Dataset 객체를 선언해준다.

train_dataset = CustomDataset(annotations_file, img_dir, transform, target_transform)

 

이제 DataLoader를 이용해서 iterator 형식으로 데이터에 접근할 수 있도록 만든다.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    # 이외에도 많은 파라미터들을 설정할 수 있음
)

shuffle 파라미터는 epoch마다 데이터셋을 섞을지 여부를 의미하며, 학습 시에는 True로 설정하는 것이 권장된다.

이외에도 sampler, num_workers, collate_fn 등의 파라미터들이 존재한다.

 

for idx, (input, label) in enumerate(train_dataloader):
    logits = model(input)
    ...

이런 식으로 데이터를 가져와 사용할 수 있다. 

 


 

[방법2]

 

Dataset 클래스를 선언해주는 과정이 귀찮으므로, ImageFolder를 통해 더 간단하게 사용할 수 있다.

 

각 파일에 해당하는 label을 적어둔 annotation 파일이 따로 필요 없이,

dir_path 경로의 하위에 각 클래스별 폴더를 만들고, 그 안에 해당하는 이미지들을 위치시키면 된다.

dir_path
  ㄴ apple
       ㄴ a.jpg
       ㄴ b.jpg
       ㄴ ...
  ㄴ banana
       ㄴ 1.jpg
       ㄴ 2.jpg
       ㄴ ...
  ㄴ ...

이런 식으로 구성되어 있다면,

다음 코드를 통해서 apple 폴더 하위의 이미지들은 class 0에 해당하게 되고, banana 폴더 하위의 이미지들은 class 1에 해당하게 된다.

from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader


train_transform = transforms.Compose([
    transforms.ToTensor(),
    # Resize(), CenterCrop() 등 다양한 함수 활용 가능
])

train_dataset = ImageFolder(dir_path, transform=train_transform)
# 클래스 개수 확인 => len(train_dataset.classes)

클래스 개수를 출력해보면 dir_path 하위의 폴더 개수와 동일할 것이다. (빈 폴더가 없는 한)

아주 편리하다 :>

 

그리고 원래 사용하던 방식대로 DataLoader를 이용해 iterator 형식으로 데이터를 가져온다.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
)

for idx, (input, label) in enumerate(train_dataloader):
    logits = model(input)
    ...

 

반응형

댓글