1. ํ์ดํ ์น์ Custom dataset / DataLoader
1.1 Custom Dataset ์ ์ฌ์ฉํ๋ ์ด์
- ๋ฐฉ๋ํ ๋ฐ์ดํฐ์ ์ --> ๋ฐ์ดํฐ๋ฅผ ํ ๋ฒ์ ๋ถ๋ฌ์ค๊ธฐ ์ฝ์ง ์์
- ๋ฐ์ดํฐ๋ฅผ ํ ๋ฒ์ ๋ถ๋ฅด์ง ์๊ณ ํ๋์ฉ๋ง ๋ถ๋ฌ์ ์ฐ๋ ๋ฐฉ์์ ํํด์ผ ํจ
- ๋ฐ๋ผ์ ๋ชจ๋ ๋ฐ์ดํฐ๋ฅผ ๋ถ๋ฌ๋๊ณ ์ฌ์ฉํ๋ ๊ธฐ์กด์ Dataset ๋ง๊ณ Custom Dataset ์ด ํ์ํ ๊ฒ
1.2 Dataset ์ด๋?
- from torch.utils.data import Dataset, DataLoader
- ํ์ดํ ์น์์ ์ง์ํ๋ ๊ธฐ๋ฅ์ด๋ค. dataset๊ณผ dataloader ๊ธฐ๋ฅ์ ๊ธฐ๋ฐ์ผ๋ก ๋ฏธ๋๋ฐฐ์น ํ์ต, ๋ฐ์ดํฐ ์ ํ, ๋ณ๋ ฌ ์ฒ๋ฆฌ๊น์ง ๊ฐ๋จํ๊ฒ ์ํํ ์ ์๋ค.
- ํ๋ํ ๋ ์ฌ์ฉ (?)
- Dataset ํด๋์ค๋ ์ ์ฒด Dataset์ ๊ตฌ์ฑํ๋ ๋จ๊ณ
- input์ผ๋ก ์ ์ฒด x (input feature)์ y (label)์ tensor๋ก ๋ฃ์ด์ฃผ๋ฉด ๋จ
- __len__ ๋ฉ์๋์ __getitem__ ๋ฉ์๋๋ฅผ ์ง๋ ์ด๋ค ๊ฒ๋ ๋ค Dataset์ด ๋ ์ ์์
โ Dataset ์ ๊ตฌ์ฑ / Custom Dataset ๋ ๋์ผํ ๊ตฌ์ฑ โ
- __init__(self) : ํ์ํ ๋ณ์๋ค์ ์ ์ธํ๋ ๋ฉ์๋. input์ผ๋ก ์ค๋ x์ y๋ฅผ load ํ๊ฑฐ๋, ํ์ผ๋ชฉ๋ก์ loadํ๋ค.
- __len__(self) : x๋ y ๋ ๊ธธ์ด๋ฅผ ๋๊ฒจ์ฃผ๋ ๋ฉ์๋.
- __getitem__(self, index) : index๋ฒ์งธ ๋ฐ์ดํฐ๋ฅผ return ํ๋ ๋ฉ์๋.
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self):
@@@@
def __getitem__(self, index):
@@@@@
def __len__(self):
@@@@@
๊ฐ์ธ ๋ฐ์ดํฐ (custom dataset)์ ๊ตฌ์ถํ๊ธฐ ์ํด์๋ ์์ ํํ๋ฅผ ๊ผญ ๊ธฐ์ตํ์!
1.2 DataLoader ๋?
- ๋ฐฐ์น์ฌ์ดํธ ํํ๋ก ๋ง๋ค์ด์ ์ฐ๋ฆฌ๊ฐ ์ค์ ๋ก ํ์ตํ ๋ ์ด์ฉํ ์ ์๊ฒ ํํ๋ฅผ ๋ง๋ค์ด์ฃผ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ
- ๋ฐ์ดํฐ๊ฐ ์์ฑ๋๋ฉด, ๋ฐฐ์นํํ๋ก ๋ง๋ค์ด์ค์ผํ๋๊น DataLoader ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ํ์ฉํ์ฌ ํํ๋ฅผ ๋ง๋ค์ด์ค๋ค.
- ์๋ DataLoader์ ์ ์ (๋งค๊ฐ๋ณ์)๋ฅผ ํ์ธํด๋ณผ ์ ์๋ค.
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
2. Custom Dataset ๊ฐ๋จํ๊ฒ ๋ง๋ค์ด๋ณด๊ธฐ
2.1 import
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
2.2 ์ฌ์ฉํ ๋ฐ์ดํฐ ๋ถ๋ฌ์ค๊ธฐ
ํด๋น ์ฝ๋์์๋ ๋ถ์ํ๋ ค๋ ๋ฐ์ดํฐ๊ฐ ์ด๋ฏธ์ง๋ผ๋ ๊ฐ์ ํ์, ๋ฐ์ดํฐ๋ฅผ numpyํํ๋ก ๋ง๋ค์ด ์ฃผ์๋ค.
train_images = np.random.randint(256, size=(20,32,32,3))
train_labels = np.random.randint(2, size=(20,1))
print(train_images.shape, train_labels.shape)
- numpy ํํ๋ก ๊ฐ์ง ๋ฐ์ดํฐ๋ฅผ ๋ฃ์ด์คฌ๋ค. (20, 32, 32, 3)
- ์ผ๋ฐ์ ์ธ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์ ํํ๋ (๊ฐฏ์, ์ด๋ฏธ์ง ํฌ๊ธฐ, ์ด๋ฏธ์ง์ ํฌ๊ธฐ, ์ฑ๋์) ๋ก ๊ตฌ์ฑ๋๋ค.
- 20๊ฐ, ํฌ๊ธฐ๊ฐ 32*32์ด๊ณ ์ฑ๋์ ์๊ฐ 3์ธ ์ด๋ฏธ์ง ๋ฐ์ดํฐ๊ฐ ์๋ค๊ณ ๊ฐ์ ํ๋ค.
- ๊ทธ์ ๋ฐ๋ฅธ ๋ ์ด๋ธ์ด 20๊ฐ๋ผ๊ณ train_labels์ ์ ์ฅํ๋ค.
2.3 Custom Dataset ๋ง๋ค์ด์ฃผ๊ธฐ
from torch.utils.data import Dataset, DataLoader ์์ ๋ถ๋ฌ์๋ Dataset์ ํด๋์ค๋ฅผ ์์๋ฐ์ ๋๋ง์ dataset ํด๋์ค๋ฅผ ๋ง๋ค์ด ์ฃผ์ด์ผ ํ๋ค. Custom Dataset์ ํํ๋ ๊ฑฐ์ ์ ํํ ๋์ด์๋ค๊ณ ์๊ฐํด์ผ ํ๋ค. ์์์ ์ธ๊ธํ Dataset์ ๊ตฌ์ฑ ์์๋ฅผ ํ์ธํ๊ณ ์๋ ์ฝ๋๋ฅผ ํ์ธํ์.
class TensorData(Dataset):
# ์ธ๋ถ์ ์๋ ๋ฐ์ดํฐ๋ฅผ ๊ฐ์ ธ์ค๊ธฐ ์ํด ์ธ๋ถ์์ ๋ฐ์ดํฐ๊ฐ ๋ค์ด์ฌ ์ ์๋๋ก, x_data, y_data ๋ณ์๋ฅผ ์ง์
def __init__(self, x_data, y_data):
#๋ค์ด์จ x๋ tensorํํ๋ก ๋ณํ
self.x_data = torch.FloatTensor(x_data)
# tensor data์ ํํ๋ (๋ฐฐ์น์ฌ์ด์ฆ, ์ฑ๋์ฌ์ด์ฆ, ์ด๋ฏธ์ง ๋๋น, ๋์ด)์ ํํ์
# ๋ฐ๋ผ์ ๋ค์ด์จ ๋ฐ์ดํฐ์ ํ์์ permuteํจ์๋ฅผ ํ์ฉํ์ฌ ๋ฐ๊พธ์ด์ฃผ์ด์ผํจ.
self.x_data = self.x_data.permute(0,3,1,2) # ์ธ๋ฑ์ค ๋ฒํธ๋ก ๋ฐ๊พธ์ด์ฃผ๋ ๊ฒ # ์ด๋ฏธ์ง ๊ฐ์, ์ฑ๋ ์, ์ด๋ฏธ์ง ๋๋น, ๋์ด
self.y_data = torch.LongTensor(y_data) # float tensor / long tensor ๋ก ์ซ์ ์์ฑ์ ์ ํด์ค ์ ์์
self.len = self.y_data.shape[0]
# x,y๋ฅผ ํํํํ๋ก ๋ฐ๊นฅ์ผ๋ก ๋ด๋ณด๋ด๊ธฐ
def __getitem__(self, index):
return self.x_data[index], self.y_data[index]
def __len__(self):
return self.len
- ์์๋ณด๊ธฐ ํ๋ ๋ฌธ๋ฒ์ฌํญ์ ์ฃผ์์ผ๋ก ๋ฌ์๋์๋ค.
2.4 ๋ฐ์ดํฐ ์์ฑ ๋ฐ ๋ฐฐ์นํํ๋ก ๋ณํ
# ์ธ์คํฐ์ค(๋ฐ์ดํฐ) ์์ฑ
train_data = TensorData(train_images, train_labels)
# ๋ง๋ค์ด์ง ๋ฐ์ดํฐ๊ฐ ๋ฐฐ์นํํ๋ก ๋ง๋ค์ด์ค์ผํ๋๊น DataLoader์๋ค๊ฐ ๋ฃ์ด์ค
train_loader = DataLoader(train_data, batch_size=10, shuffle=True)
- ์์์ DataLoader ๊ฐ ๋ฐฐ์นํํ๋ก ๋ณํํด์ฃผ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ผ๊ณ ํ ๊ฒ์ ๋ํด ์ดํด๊ฐ ๋ ๊ฒ์ด๋ค.
- ์์ฑ๋ ๋ฐ์ดํฐ๋ฅผ DataLoader ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ด์ฉํ์ฌ ์ ํ์ ์ ๋ฌด, ๋ฐฐ์น์ฌ์ด์ฆ ๋ฑ์ ์ง์ ํด์ค๋ค.
2.5 ๋ฐ์ดํฐ์ ์ฌ์ด์ฆ์ ์ค์ ๋ฐ์ดํฐ ํ์ธํด๋ณด๊ธฐ
์์์ tensor๋ฐ์ดํฐ ํํ์ ๋ง๊ฒ permute ํจ์๋ฅผ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ์ ํํ๋ฅผ ๋ณํํด์ฃผ์๋ค. ํํ๊ฐ ์ ๋ณํ๋์๋์ง ํ์ธํด๋ณด์.
# ๋ฐ์ดํฐ ์ฌ์ด์ฆ ํ์ธ
# ์ฑ๋์, ์ฌ์ด์ฆ, ์ฌ์ด์ฆ๋ก ์์๊ฐ ๋ฐ๋
# ์ด๊ธฐ ๋ฐ์ดํฐ๋ (50, 32,32,3) ์ด์์
train_data[0][0].size()
- ํ ์ ๋ฐ์ดํฐ์ ํํ์ ๋ง๊ฒ (์ฑ๋์, ํฌ๊ธฐ, ํฌ๊ธฐ) ๋ก ๋ฐ๋ ๊ฒ์ ํ์ธํ ์ ์๋ค.
์ด์ iter์ next ํจ์๋ฅผ ์ฌ์ฉํ์ฌ ์ค์ ๋ฐ์ดํฐ๋ฅผ ํ์ธํด๋ณด์.
# train_loader์์ ์ค์ ๊ฐ ํ๊ฐ ๋ฝ์์ ํ์ธ
dataiter = iter(train_loader)
images, labels = dataiter.next()
images.size()
- 10๊ฐ์ ๋ฐฐ์น, 3์ฑ๋, 32*32 ํฌ๊ธฐ๋ฅผ ๊ฐ์ง ์ด๋ฏธ์ง ๋ฐ์ดํฐ์์ ํ์ธ ํ ์ ์๋ค.
์ด์ ์ด๋ ๊ฒ dataLoader ๊ณผ์ ์ ๊ฑฐ์น ๋ฐ์ดํฐ๋ฅผ ํ์ฉํ์ฌ ๋ชจ๋ธ์ ๋ฃ๊ณ ์ ๊ฒฝ๋ง์ ๊ตฌํํ ์ ์๋ค. ์๋ ๋ธ๋ก๊ทธ ์์๋ DataLoader๊ณผ Dataset์ ๋ํ ๊ฐ๋ ์ ์ค๋ช ํ๊ณ , ๋ชจ๋ธ์ ๋๋ฆฌ๋ ๊ฐ๋จํ ์ฝ๋๊ฐ ๋์์๋ค. ์ฐธ๊ณ ํด์ ๊ณต๋ถํ๊ธธ ๋ฐ๋๋ค!
https://wingnim.tistory.com/33
์ฐธ๊ณ ์๋ฃ
https://www.youtube.com/watch?v=8PnxJ3s3Cwo&t=436s