Potato
์•ˆ๋…•ํ•˜์„ธ์š”, ๊ฐ์žก๋‹ˆ๋‹ค?๐Ÿฅ” ^___^ ๐Ÿ˜บ github ๋ฐ”๋กœ๊ฐ€๊ธฐ ๐Ÿ‘‰๐Ÿป

AI study/๋”ฅ๋Ÿฌ๋‹ ํ”„๋ ˆ์ž„์›Œํฌ

[Pytorch] ํŒŒ์ดํ† ์น˜์˜ Custom dataset๊ณผ DataLoader ์ดํ•ดํ•˜๊ธฐ

๊ฐ์ž ๐Ÿฅ” 2021. 7. 22. 17:20
๋ฐ˜์‘ํ˜•

1. ํŒŒ์ดํ† ์น˜์˜ Custom dataset / DataLoader 

1.1 Custom Dataset ์„ ์‚ฌ์šฉํ•˜๋Š” ์ด์œ 

  • ๋ฐฉ๋Œ€ํ•œ ๋ฐ์ดํ„ฐ์˜ ์–‘ --> ๋ฐ์ดํ„ฐ๋ฅผ ํ•œ ๋ฒˆ์— ๋ถˆ๋Ÿฌ์˜ค๊ธฐ ์‰ฝ์ง€ ์•Š์Œ
  • ๋ฐ์ดํ„ฐ๋ฅผ ํ•œ ๋ฒˆ์— ๋ถ€๋ฅด์ง€ ์•Š๊ณ  ํ•˜๋‚˜์”ฉ๋งŒ ๋ถˆ๋Ÿฌ์„œ ์“ฐ๋Š” ๋ฐฉ์‹์„ ํƒํ•ด์•ผ ํ•จ
  • ๋”ฐ๋ผ์„œ ๋ชจ๋“  ๋ฐ์ดํ„ฐ๋ฅผ ๋ถˆ๋Ÿฌ๋†“๊ณ  ์‚ฌ์šฉํ•˜๋Š” ๊ธฐ์กด์˜ Dataset ๋ง๊ณ  Custom Dataset ์ด ํ•„์š”ํ•œ ๊ฒƒ

1.2 Dataset ์ด๋ž€?

  • from torch.utils.data import Dataset, DataLoader 
    • ํŒŒ์ดํ† ์น˜์—์„œ ์ง€์›ํ•˜๋Š” ๊ธฐ๋Šฅ์ด๋‹ค. dataset๊ณผ dataloader ๊ธฐ๋Šฅ์„ ๊ธฐ๋ฐ˜์œผ๋กœ ๋ฏธ๋‹ˆ๋ฐฐ์น˜ ํ•™์Šต, ๋ฐ์ดํ„ฐ ์…”ํ”Œ, ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ๊นŒ์ง€ ๊ฐ„๋‹จํ•˜๊ฒŒ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ๋‹ค.
  • ํŠœ๋‹ํ•  ๋•Œ ์‚ฌ์šฉ (?)
  • Dataset ํด๋ž˜์Šค๋Š” ์ „์ฒด Dataset์„ ๊ตฌ์„ฑํ•˜๋Š” ๋‹จ๊ณ„
  • input์œผ๋กœ ์ „์ฒด x (input feature)์™€ y (label)์„ tensor๋กœ ๋„ฃ์–ด์ฃผ๋ฉด ๋จ
  • __len__ ๋ฉ”์„œ๋“œ์™€ __getitem__ ๋ฉ”์„œ๋“œ๋ฅผ ์ง€๋‹Œ ์–ด๋–ค ๊ฒƒ๋„ ๋‹ค Dataset์ด ๋  ์ˆ˜ ์žˆ์Œ

โ˜… Dataset ์˜ ๊ตฌ์„ฑ / Custom Dataset ๋„ ๋™์ผํ•œ ๊ตฌ์„ฑ โ˜…

  • __init__(self) : ํ•„์š”ํ•œ ๋ณ€์ˆ˜๋“ค์„ ์„ ์–ธํ•˜๋Š” ๋ฉ”์„œ๋“œ. input์œผ๋กœ ์˜ค๋Š” x์™€ y๋ฅผ load ํ•˜๊ฑฐ๋‚˜, ํŒŒ์ผ๋ชฉ๋ก์„ loadํ•œ๋‹ค.
  • __len__(self) : x๋‚˜ y ๋Š” ๊ธธ์ด๋ฅผ ๋„˜๊ฒจ์ฃผ๋Š” ๋ฉ”์„œ๋“œ.
  • __getitem__(self, index) : index๋ฒˆ์งธ ๋ฐ์ดํ„ฐ๋ฅผ return ํ•˜๋Š” ๋ฉ”์„œ๋“œ.
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self):
        @@@@
    def __getitem__(self, index):
        @@@@@
    def __len__(self):
        @@@@@

๊ฐœ์ธ ๋ฐ์ดํ„ฐ (custom dataset)์„ ๊ตฌ์ถ•ํ•˜๊ธฐ ์œ„ํ•ด์„œ๋Š” ์œ„์˜ ํ˜•ํƒœ๋ฅผ ๊ผญ ๊ธฐ์–ตํ•˜์ž!

1.2 DataLoader ๋ž€?

  • ๋ฐฐ์น˜์‚ฌ์ดํŠธ ํ˜•ํƒœ๋กœ ๋งŒ๋“ค์–ด์„œ ์šฐ๋ฆฌ๊ฐ€ ์‹ค์ œ๋กœ ํ•™์Šตํ•  ๋•Œ ์ด์šฉํ•  ์ˆ˜ ์žˆ๊ฒŒ ํ˜•ํƒœ๋ฅผ ๋งŒ๋“ค์–ด์ฃผ๋Š” ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ
  • ๋ฐ์ดํ„ฐ๊ฐ€ ์ƒ์„ฑ๋˜๋ฉด, ๋ฐฐ์น˜ํ˜•ํƒœ๋กœ ๋งŒ๋“ค์–ด์ค˜์•ผํ•˜๋‹ˆ๊นŒ DataLoader ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ํ™œ์šฉํ•˜์—ฌ ํ˜•ํƒœ๋ฅผ ๋งŒ๋“ค์–ด์ค€๋‹ค. 
  • ์•„๋ž˜ DataLoader์˜ ์ •์˜ (๋งค๊ฐœ๋ณ€์ˆ˜)๋ฅผ ํ™•์ธํ•ด๋ณผ ์ˆ˜ ์žˆ๋‹ค.
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

2. Custom Dataset ๊ฐ„๋‹จํ•˜๊ฒŒ ๋งŒ๋“ค์–ด๋ณด๊ธฐ

2.1 import

import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np

 

2.2 ์‚ฌ์šฉํ•  ๋ฐ์ดํ„ฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

ํ•ด๋‹น ์ฝ”๋“œ์—์„œ๋Š” ๋ถ„์„ํ•˜๋ ค๋Š” ๋ฐ์ดํ„ฐ๊ฐ€ ์ด๋ฏธ์ง€๋ผ๋Š” ๊ฐ€์ •ํ•˜์—, ๋ฐ์ดํ„ฐ๋ฅผ numpyํ˜•ํƒœ๋กœ ๋งŒ๋“ค์–ด ์ฃผ์—ˆ๋‹ค.

train_images = np.random.randint(256, size=(20,32,32,3))
train_labels = np.random.randint(2, size=(20,1))
print(train_images.shape, train_labels.shape)

  • numpy ํ˜•ํƒœ๋กœ ๊ฐ€์งœ ๋ฐ์ดํ„ฐ๋ฅผ ๋„ฃ์–ด์คฌ๋‹ค. (20, 32, 32, 3)
    • ์ผ๋ฐ˜์ ์ธ ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ์˜ ํ˜•ํƒœ๋Š” (๊ฐฏ์ˆ˜, ์ด๋ฏธ์ง€ ํฌ๊ธฐ, ์ด๋ฏธ์ง€์˜ ํฌ๊ธฐ, ์ฑ„๋„์ˆ˜) ๋กœ ๊ตฌ์„ฑ๋œ๋‹ค.
  • 20๊ฐœ,  ํฌ๊ธฐ๊ฐ€ 32*32์ด๊ณ  ์ฑ„๋„์˜ ์ˆ˜๊ฐ€ 3์ธ ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ๊ฐ€ ์žˆ๋‹ค๊ณ  ๊ฐ€์ •ํ–ˆ๋‹ค.
  • ๊ทธ์— ๋”ฐ๋ฅธ ๋ ˆ์ด๋ธ”์ด 20๊ฐœ๋ผ๊ณ  train_labels์— ์ €์žฅํ–ˆ๋‹ค.

 

2.3 Custom Dataset ๋งŒ๋“ค์–ด์ฃผ๊ธฐ

from torch.utils.data import Dataset, DataLoader ์—์„œ ๋ถˆ๋Ÿฌ์™”๋˜ Dataset์˜ ํด๋ž˜์Šค๋ฅผ ์ƒ์†๋ฐ›์„ ๋‚˜๋งŒ์˜ dataset ํด๋ž˜์Šค๋ฅผ ๋งŒ๋“ค์–ด ์ฃผ์–ด์•ผ ํ•œ๋‹ค. Custom Dataset์˜ ํ˜•ํƒœ๋Š” ๊ฑฐ์˜ ์ •ํ˜•ํ™” ๋˜์–ด์žˆ๋‹ค๊ณ  ์ƒ๊ฐํ•ด์•ผ ํ•œ๋‹ค. ์œ„์—์„œ ์–ธ๊ธ‰ํ•œ Dataset์˜ ๊ตฌ์„ฑ ์š”์†Œ๋ฅผ ํ™•์ธํ•˜๊ณ  ์•„๋ž˜ ์ฝ”๋“œ๋ฅผ ํ™•์ธํ•˜์ž.

class TensorData(Dataset):
    # ์™ธ๋ถ€์— ์žˆ๋Š” ๋ฐ์ดํ„ฐ๋ฅผ ๊ฐ€์ ธ์˜ค๊ธฐ ์œ„ํ•ด ์™ธ๋ถ€์—์„œ ๋ฐ์ดํ„ฐ๊ฐ€ ๋“ค์–ด์˜ฌ ์ˆ˜ ์žˆ๋„๋ก, x_data, y_data ๋ณ€์ˆ˜๋ฅผ ์ง€์ •
    def __init__(self, x_data, y_data):
        #๋“ค์–ด์˜จ x๋Š” tensorํ˜•ํƒœ๋กœ ๋ณ€ํ™˜
        self.x_data = torch.FloatTensor(x_data)
        # tensor data์˜ ํ˜•ํƒœ๋Š” (๋ฐฐ์น˜์‚ฌ์ด์ฆˆ, ์ฑ„๋„์‚ฌ์ด์ฆˆ, ์ด๋ฏธ์ง€ ๋„ˆ๋น„, ๋†’์ด)์˜ ํ˜•ํƒœ์ž„
        # ๋”ฐ๋ผ์„œ ๋“ค์–ด์˜จ ๋ฐ์ดํ„ฐ์˜ ํ˜•์‹์„ permuteํ•จ์ˆ˜๋ฅผ ํ™œ์šฉํ•˜์—ฌ ๋ฐ”๊พธ์–ด์ฃผ์–ด์•ผํ•จ.
        self.x_data = self.x_data.permute(0,3,1,2)  # ์ธ๋ฑ์Šค ๋ฒˆํ˜ธ๋กœ ๋ฐ”๊พธ์–ด์ฃผ๋Š” ๊ฒƒ # ์ด๋ฏธ์ง€ ๊ฐœ์ˆ˜, ์ฑ„๋„ ์ˆ˜, ์ด๋ฏธ์ง€ ๋„ˆ๋น„, ๋†’์ด
        self.y_data = torch.LongTensor(y_data) # float tensor / long tensor ๋กœ ์ˆซ์ž ์†์„ฑ์„ ์ •ํ•ด์ค„ ์ˆ˜ ์žˆ์Œ
        self.len = self.y_data.shape[0]

    # x,y๋ฅผ ํŠœํ”Œํ˜•ํƒœ๋กœ ๋ฐ”๊นฅ์œผ๋กœ ๋‚ด๋ณด๋‚ด๊ธฐ
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len
  • ์•Œ์•„๋ณด๊ธฐ ํž˜๋“  ๋ฌธ๋ฒ•์‚ฌํ•ญ์€ ์ฃผ์„์œผ๋กœ ๋‹ฌ์•„๋†“์•˜๋‹ค. 

 

2.4 ๋ฐ์ดํ„ฐ ์ƒ์„ฑ ๋ฐ ๋ฐฐ์น˜ํ˜•ํƒœ๋กœ ๋ณ€ํ™˜

# ์ธ์Šคํ„ฐ์Šค(๋ฐ์ดํ„ฐ) ์ƒ์„ฑ
train_data = TensorData(train_images, train_labels)

# ๋งŒ๋“ค์–ด์ง„ ๋ฐ์ดํ„ฐ๊ฐ€ ๋ฐฐ์น˜ํ˜•ํƒœ๋กœ ๋งŒ๋“ค์–ด์ค˜์•ผํ•˜๋‹ˆ๊นŒ DataLoader์—๋‹ค๊ฐ€ ๋„ฃ์–ด์คŒ
train_loader = DataLoader(train_data, batch_size=10, shuffle=True)
  • ์•ž์—์„œ DataLoader ๊ฐ€ ๋ฐฐ์น˜ํ˜•ํƒœ๋กœ ๋ณ€ํ™˜ํ•ด์ฃผ๋Š” ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ผ๊ณ  ํ•œ ๊ฒƒ์— ๋Œ€ํ•ด ์ดํ•ด๊ฐ€ ๋  ๊ฒƒ์ด๋‹ค.
  • ์ƒ์„ฑ๋œ ๋ฐ์ดํ„ฐ๋ฅผ DataLoader ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ์ด์šฉํ•˜์—ฌ ์…”ํ”Œ์˜ ์œ ๋ฌด, ๋ฐฐ์น˜์‚ฌ์ด์ฆˆ ๋“ฑ์„ ์ง€์ •ํ•ด์ค€๋‹ค. 

 

2.5 ๋ฐ์ดํ„ฐ์˜ ์‚ฌ์ด์ฆˆ์™€ ์‹ค์ œ ๋ฐ์ดํ„ฐ ํ™•์ธํ•ด๋ณด๊ธฐ

์•ž์—์„œ tensor๋ฐ์ดํ„ฐ ํ˜•ํƒœ์— ๋งž๊ฒŒ permute ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ฐ์ดํ„ฐ์˜ ํ˜•ํƒœ๋ฅผ ๋ณ€ํ™˜ํ•ด์ฃผ์—ˆ๋‹ค. ํ˜•ํƒœ๊ฐ€ ์ž˜ ๋ณ€ํ™˜๋˜์—ˆ๋Š”์ง€ ํ™•์ธํ•ด๋ณด์ž.

# ๋ฐ์ดํ„ฐ ์‚ฌ์ด์ฆˆ ํ™•์ธ
# ์ฑ„๋„์ˆ˜, ์‚ฌ์ด์ฆˆ, ์‚ฌ์ด์ฆˆ๋กœ ์ˆœ์„œ๊ฐ€ ๋ฐ”๋€œ
# ์ดˆ๊ธฐ ๋ฐ์ดํ„ฐ๋Š” (50, 32,32,3) ์ด์—ˆ์Œ
train_data[0][0].size()

  • ํ…์„œ ๋ฐ์ดํ„ฐ์˜ ํ˜•ํƒœ์— ๋งž๊ฒŒ (์ฑ„๋„์ˆ˜, ํฌ๊ธฐ, ํฌ๊ธฐ) ๋กœ ๋ฐ”๋€ ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค.

์ด์ œ iter์™€ next ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์‹ค์ œ ๋ฐ์ดํ„ฐ๋ฅผ ํ™•์ธํ•ด๋ณด์ž.

# train_loader์•ˆ์˜ ์‹ค์ œ๊ฐ’ ํ•œ๊ฐœ ๋ฝ‘์•„์„œ ํ™•์ธ 
dataiter = iter(train_loader)
images, labels = dataiter.next()
images.size()

  • 10๊ฐœ์˜ ๋ฐฐ์น˜, 3์ฑ„๋„, 32*32 ํฌ๊ธฐ๋ฅผ ๊ฐ€์ง„ ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ์ž„์„ ํ™•์ธ ํ•  ์ˆ˜ ์žˆ๋‹ค.

 

์ด์ œ ์ด๋ ‡๊ฒŒ dataLoader ๊ณผ์ •์„ ๊ฑฐ์นœ ๋ฐ์ดํ„ฐ๋ฅผ ํ™œ์šฉํ•˜์—ฌ ๋ชจ๋ธ์— ๋„ฃ๊ณ  ์‹ ๊ฒฝ๋ง์„ ๊ตฌํ˜„ํ•  ์ˆ˜ ์žˆ๋‹ค. ์•„๋ž˜ ๋ธ”๋กœ๊ทธ ์—์„œ๋Š” DataLoader๊ณผ Dataset์— ๋Œ€ํ•œ ๊ฐœ๋…์„ ์„ค๋ช…ํ•˜๊ณ , ๋ชจ๋ธ์„ ๋Œ๋ฆฌ๋Š” ๊ฐ„๋‹จํ•œ ์ฝ”๋“œ๊ฐ€ ๋‚˜์™€์žˆ๋‹ค. ์ฐธ๊ณ ํ•ด์„œ ๊ณต๋ถ€ํ•˜๊ธธ ๋ฐ”๋ž€๋‹ค! 

https://wikidocs.net/57165

 

์œ„ํ‚ค๋…์Šค

์˜จ๋ผ์ธ ์ฑ…์„ ์ œ์ž‘ ๊ณต์œ ํ•˜๋Š” ํ”Œ๋žซํผ ์„œ๋น„์Šค

wikidocs.net

https://wingnim.tistory.com/33

 

Pytorch ๋จธ์‹ ๋Ÿฌ๋‹ ํŠœํ† ๋ฆฌ์–ผ ๊ฐ•์˜ 8 (PyTorch DataLoader)

2018/07/02 - [Programming Project/Pytorch Tutorials] - Pytorch ๋จธ์‹ ๋Ÿฌ๋‹ ํŠœํ† ๋ฆฌ์–ผ ๊ฐ•์˜ 1 (Overview) 2018/07/02 - [Programming Project/Pytorch Tutorials] - Pytorch ๋จธ์‹ ๋Ÿฌ๋‹ ํŠœํ† ๋ฆฌ์–ผ ๊ฐ•์˜ 2 (Linear Mod..

wingnim.tistory.com

 

 

์ฐธ๊ณ ์ž๋ฃŒ

https://www.youtube.com/watch?v=8PnxJ3s3Cwo&t=436s 

 

๋ฐ˜์‘ํ˜•