-- ๋ณธ ํฌ์คํ
์ ํ์ดํ ์น๋ก ๋ฐฐ์ฐ๋ ์์ฐ์ด ์ฒ๋ฆฌ (ํ๋น๋ฏธ๋์ด) ์ฑ
์ ์ฐธ๊ณ ํด์ ์์ฑ๋ ๊ธ์
๋๋ค.
-- ์์ค์ฝ๋ ) https://github.com/rickiepark/nlp-with-pytorch
1. ์ฑ์จ ๋ฐ์ดํฐ์
- 18๊ฐ ๊ตญ์ ์ ์ฑ์จ 10,000๊ฐ๋ฅผ ๋ชจ์ ์ฑ์จ ๋ฐ์ดํฐ์
- ๋งค์ธ ๋ถ๊ท ํํจ
- ์ต์์ ํด๋์ค 3๊ฐ๊ฐ ๋ฐ์ดํฐ์ 60%๋ฅผ ์ฐจ์ง
- 27%๊ฐ ์์ด, 21%๊ฐ ๋ฌ์์์ด, 14%๊ฐ ์๋์ด
- ๋๋จธ์ง 15๊ฐ๊ตญ์ ๋น๋๋ ๊ณ์ ๊ฐ์
- ์ธ์ด ์์ฒด์ผ ์์ฑ์ด๊ธฐ๋ ํจ (๋ง์ด ์ฌ์ฉํ๋ ์ธ์ด์ผ์๋ก ๋ง์ ์ ๋ฐ์)
- ์ถ์ ๊ตญ๊ฐ์ ์ฑ์จ ๋ง์ถค๋ฒ ์ฌ์ด์ ์๋ฏธ๊ฐ ์๊ณ ์ง๊ด์ ์ธ ๊ด๊ณ๊ฐ ์์
- ์ฆ, ๊ตญ์ ๊ณผ ๊ด๊ณ๊ฐ ์๋ ์ฑ์จ๊ฐ ์กด์ฌ
2. ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ
์๋์ ๊ฐ์ด ์ฒ๋ฆฌํ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ๋ค. ํด๋น ์ฝ๋์์๋ ์๋ ๋ฐฉ์์ ๊ฑฐ์ณ์ ธ์ ๋ฏธ๋ฆฌ ์ฒ๋ฆฌ๋ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ๋ค. ๋ฐ์ดํฐ๋ฅผ ๋๋๋ ์ฝ๋๋ ์๋
- ๋ถ๊ท ํ ์ค์ด๊ธฐ
- ์๋ณธ ๋ฐ์ดํฐ์ ์ 70% ์ด์์ด ๋ฌ์์ ์ด๋ฆ
- ์ํ๋ง์ด ํธํฅ๋์๊ฑฐ๋ ๋ฌ์์์ ๊ณ ์ ํ ์ฑ์จ๊ฐ ๋ง๊ธฐ ๋๋ฌธ์ผ๋ก ์ถ์
- ๋ฌ์์ ์ฑ์จ์ ๋ถ๋ถ์งํฉ์ ๋๋คํ๊ฒ ์ ํํ์ฌ ํธ์ค๋ ํด๋์ค๋ฅผ "์๋ธ์ํ๋ง" ํด์ค
- ๊ตญ์ ์ ๊ธฐ๋ฐ์ผ๋ก ๋ชจ์ 3์ธํธ๋ก split
- train = 70%
- validation = 15%
- test = 15%
- ์ธํธ๊ฐ ๋ ์ด๋ธ ๋ถํฌ๋ฅผ ๊ณ ๋ฅด๊ฒ ์ ์ง ์ํด
โถ ๊ฐ๋จํ๊ฒ ๋ฐ์ดํฐ ์ดํด๋ณด๊ธฐ
# Read raw data
surnames = pd.read_csv(args.raw_dataset_csv, header=0)
surnames.head()
- surnames ๋ฐ์ดํฐ์ ์ ํํ๋ ์ด๋ ๊ฒ ์๊ฒผ๋ค.
# Unique classes
set(surnames.nationality)
- 18๊ฐ์ ๊ตญ๊ฐ๋ ์์ ๊ฐ๋ค.
# Write split data to file
final_surnames = pd.DataFrame(final_list)
final_surnames.split.value_counts()
- ๋ฐ์ดํฐ์ ๋ถํฌ๋ 70: 15: 15 ๋ก ๋๋์ด์ ธ ์๋ ๊ฒ์ ํ์ธํ ์ ์๋ค.
3. import
ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ๋ชจ๋ ์ํฌํธ ํด์ฃผ์๋ค.
from argparse import Namespace
from collections import Counter
import json
import os
import string
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import tqdm
4. Custom Dataset ๋ง๋ค๊ธฐ (SurnameDataset)
- ํ์ดํ ์น์ Dataset ์์
- Dataset์ด ๊ตฌ์กฐ๋ฅผ ๋ฐ๋ฆ
- __len__ ๋ฉ์๋ : ๋ฐ์ดํฐ์ ์ ๊ธธ์ด๋ฅผ ๋ฐํํ๋ค.
- __getitem__ ๋ฉ์๋ : ๋ฒกํฐ๋ก ๋ฐ๊พผ ์ฑ์จ์ ๊ตญ์ ์ ํด๋นํ๋ ์ธ๋ฑ์ค๋ฅผ ๋ฐํํ๋ค.
โถ SurnameDataset ์ class ์ค๋ช
class ์ฝ๋๊ฐ ๋๋ฌด ๊ธธ์ด์ ์ดํดํ๊ธฐ ์ํด ํ๋์ฉ ๋ฏ์ด์ ์ค๋ช ํ ์์ ์ด๋ค. ๋ง์ง๋ง์ <๋๋ณด๊ธฐ>๋ฅผ ํด๋ฆญํ๋ฉด class๋ก ์ด์ด์ง ์ ์ฒด ์ฝ๋๋ฅผ ๋ณผ ์ ์๋ค.
1. __init__ ๋ฉ์๋
def __init__(self, surname_df, vectorizer):
"""
๋งค๊ฐ๋ณ์:
surname_df (pandas.DataFrame): ๋ฐ์ดํฐ์
vectorizer (SurnameVectorizer): SurnameVectorizer ๊ฐ์ฒด
"""
self.surname_df = surname_df
self._vectorizer = vectorizer
self.train_df = self.surname_df[self.surname_df.split=='train']
self.train_size = len(self.train_df)
self.val_df = self.surname_df[self.surname_df.split=='val']
self.validation_size = len(self.val_df)
self.test_df = self.surname_df[self.surname_df.split=='test']
self.test_size = len(self.test_df)
self._lookup_dict = {'train': (self.train_df, self.train_size),
'val': (self.val_df, self.validation_size),
'test': (self.test_df, self.test_size)}
self.set_split('train')
# ํด๋์ค ๊ฐ์ค์น ๊ตฌํ๊ธฐ
# surnaem_df์ nationality ๊ฐฏ์๋ฅผ dict ํํ๋ก ๋ฃ์ด์ค
class_counts = surname_df.nationality.value_counts().to_dict()
def sort_key(item):
# _vectorizer = SurnameVectorizer ๊ฐ์ฒด๋ฅผ ์๋ฏธํจ
# nationality_vocab = ๊ตญ์ ์ ์ ์์ ๋งคํํ๋ Vocabulary ๊ฐ์ฒด (class SurnameVectorizer)
# look_up_token = ํ ํฐ์ ๋์ํ๋ ์ธ๋ฑ์ค๋ฅผ ์ถ์ถํ๋ ๋ฉ์๋ (class vocabulary)
return self._vectorizer.nationality_vocab.lookup_token(item[0])
# ์ธ๋ฑ์ค๋ฅผ key๋ก ์ง์ ํด์ dic์์ ์์ผ๋ก ๊บผ๋ด์์ ์ ๋ ฌ
sorted_counts = sorted(class_counts.items(), key=sort_key)
# ๋น๋์๋ฅผ ๊ธฐ์ค์ผ๋ก ํด๋์ค์ ๊ฐ์ค์น ๊ตฌํ๊ธฐ
frequencies = [count for _, count in sorted_counts]
self.class_weights = 1.0 / torch.tensor(frequencies, dtype=torch.float32)
- ์์ฑ์ ๋ฉ์๋
- ๋งค๊ฐ๋ณ์๋ก surname_df (๋ฐ์ดํฐ์ )๊ณผ vectorizer(SurnameVectorizer ๊ฐ์ฒด) ๋ฅผ ๋ฐ์
- train_df / train_size , val_df / val_size , test_df / test_size ์ ์ ์
- ๋ด๋ถ์ def sort_key(item) ๋ฉ์๋ ์กด์ฌ
- ํด๋์ค ๊ฐ์ค์น๋ฅผ ๊ตฌํ๊ธฐ ์ํด์ ์ฌ์ฉ
2. load_dataset_and_make_vectorizer ๋ฉ์๋
@classmethod
def load_dataset_and_make_vectorizer(cls, surname_csv):
# cls = SurnameDataset class ๋ฅผ ๋ฐ์
""" ๋ฐ์ดํฐ์
์ ๋ก๋ํ๊ณ ์๋ก์ด SurnameVectorizer ๊ฐ์ฒด๋ฅผ ๋ง๋ญ๋๋ค
๋งค๊ฐ๋ณ์:
review_csv (str): ๋ฐ์ดํฐ์
์ ์์น
๋ฐํ๊ฐ:
SurnameDataset์ ์ธ์คํด์ค
"""
surname_df = pd.read_csv(surname_csv)
train_surname_df = surname_df[surname_df.split=='train']
# SurnameVectorizer.from_data ์ ๋ฐํ๊ฐ : surname_vocab, nationality_vocab
return cls(surname_df, SurnameVectorizer.from_dataframe(train_surname_df))
- ๋ฐํ๊ฐ์ด SurnameDataset์ ์ธ์คํด์ค
- ์๋์์ ๋ณด๋ฉด ๋ฐ์ดํฐ์ ์ ์๋ก ๋ง๋ค ๋ ํด๋น ๋ฉ์๋๋ฅผ ์ฌ์ฉํ๋ค.
- dataset = SurnameDataset.load_dataset_and_make_vectorizer(train_df)
- ํด๋น ๋ฐํ๊ฐ์ surname_vocab / nationality_vocab ์ด๋ฏ๋ก ๋ฐ์ดํฐ์ ์ด ๋ง๋ค์ด์ง๋ ๊ฒ์ด๋ค.
- ๊ทธ๋์ SurnameDataset์ ์ธ์คํด์ค๋ผ๊ณ ํ ๊ฒ์ด๋ค.
3. load_dataset_and_load_vectorizer ๋ฉ์๋
@classmethod
def load_dataset_and_load_vectorizer(cls, surname_csv, vectorizer_filepath):
"""๋ฐ์ดํฐ์
์ ๋ก๋ํ๊ณ ์๋ก์ด SurnameVectorizer ๊ฐ์ฒด๋ฅผ ๋ง๋ญ๋๋ค.
์บ์๋ SurnameVectorizer ๊ฐ์ฒด๋ฅผ ์ฌ์ฌ์ฉํ ๋ ์ฌ์ฉํฉ๋๋ค.
๋งค๊ฐ๋ณ์:
surname_csv (str): ๋ฐ์ดํฐ์
์ ์์น
vectorizer_filepath (str): SurnameVectorizer ๊ฐ์ฒด์ ์ ์ฅ ์์น
๋ฐํ๊ฐ:
SurnameDataset์ ์ธ์คํด์ค
"""
surname_df = pd.read_csv(surname_csv)
# load_vectorizer_only ๋ฉ์๋ : ํ์ผ์์ surnamevectorizer ๊ฐ์ฒด๋ฅผ ๋ก๋ํ๋ ๋ฉ์๋
# ๋ฐํ๊ฐ์ด SurnameVectorizer.from_serializable(json.load(fp))
# SurnameVectorizer.from_serializable : surnameVectorizer ์ ์ธ์คํด์ค๋ฅผ ๋ฐํ
vectorizer = cls.load_vectorizer_only(vectorizer_filepath)
# ๊ฒฐ๊ตญ vectorizer ๊ฐ Dataset์ ์ธ์คํด์ค๊ฐ ๋์ด ๋ฐํ๋จ.
return cls(surname_df, vectorizer)
- vectorizer_filepath๋ฅผ ๋ฐ์์ SurnameVectorizer ์์ surname_vocab / nationality_vocab ์ด ๋์ด SurnameVectorizer ์ ์ธ์คํด์ค๊ฐ ๋๊ณ ๋ง์ง๋ง return ๋จ๊ณ์์ ๊ฒฐ๊ตญ SurnameDataset ์ ์ธ์คํด์ค๋ก ๋ฐํ๋๋ค. ( cls๊ฐ SurnameDataset์ ์๋ฏธ)
4. load_vectorizer_only / save_vectoriezer ๋ฉ์๋
- load_vectorizer_only ๋ ํ์ผ์์ vectoriezer ๊ฐ์ฒด๋ฅผ ๋ก๋ํ๋ ๋ฉ์๋๋ก, surname_vocab / nationality_vocab์ ์ ์ํ๋ ๊ฐ์ฒด๋ผ๊ณ ์๊ฐ
- save_vectorier ์ ํ์ผ์์ ์ง๋ ฌํ๋ dict๋ก vocab์ ๋ง๋ค์ด์ฃผ๋ ๋ฉ์๋๋ผ๊ณ ์ดํดํ๋ฉด ๋จ
@staticmethod
def load_vectorizer_only(vectorizer_filepath):
"""ํ์ผ์์ SurnameVectorizer ๊ฐ์ฒด๋ฅผ ๋ก๋ํ๋ ์ ์ ๋ฉ์๋
๋งค๊ฐ๋ณ์:
vectorizer_filepath (str): ์ง๋ ฌํ๋ SurnameVectorizer ๊ฐ์ฒด์ ์์น
๋ฐํ๊ฐ:
SurnameVectorizer์ ์ธ์คํด์ค
"""
with open(vectorizer_filepath) as fp:
# from_serializable : filepath์์ vocab์ ๋ฐ์์์ vocab์ผ๋ก ๋ง๋ค์ด์ฃผ๋ฉด์ vectorizer ์ ์ธ์คํด์ค๋ก ๋ฐํ
# from_serializable๋ ์ง๋ ฌํ๋ ๊ฒ๋ค์ vocab์ผ๋ก ๋ฐํํ๋ ๋ฉ์๋๋ผ๊ณ ์ดํดํ๋ฉด ๋จ
# ๋ฐํ : surname_vocab / nationality_vocab
return SurnameVectorizer.from_serializable(json.load(fp))
def save_vectorizer(self, vectorizer_filepath):
""" SurnameVectorizer ๊ฐ์ฒด๋ฅผ json ํํ๋ก ๋์คํฌ์ ์ ์ฅํฉ๋๋ค
๋งค๊ฐ๋ณ์:
vectorizer_filepath (str): SurnameVectorizer ๊ฐ์ฒด์ ์ ์ฅ ์์น
"""
with open(vectorizer_filepath, "w") as fp:
# to_serializable : ์ง๋ ฌํ๋ dic์ผ๋ก vocab์ ๋ง๋ค์ด์ฃผ๋ ๋ฉ์๋๋ผ๊ณ ์ดํดํ๋ฉด ๋จ
json.dump(self._vectorizer.to_serializable(), fp)
5. get_vectorizer / set_split
def get_vectorizer(self):
""" ๋ฒกํฐ ๋ณํ ๊ฐ์ฒด๋ฅผ ๋ฐํํฉ๋๋ค """
return self._vectorizer
def set_split(self, split="train"):
""" ๋ฐ์ดํฐํ๋ ์์ ์๋ ์ด์ ์ฌ์ฉํด ๋ถํ ์ธํธ๋ฅผ ์ ํํฉ๋๋ค
๋งค๊ฐ๋ณ์:
split (str): "train", "val", "test" ์ค ํ๋
"""
self._target_split = split
self._target_df, self._target_size = self._lookup_dict[split]
6. __len__ / __getitem__
- Dataset์ ๊ธฐ๋ณธ ๊ตฌ์ฑ โ
- __getitem__ : ๋ฒกํฐ๋ก ๋ฐ๊พผ ์ฑ์จ์ ๊ตญ์ ์ ํด๋นํ๋ ์ธ๋ฑ์ค๋ฅด ๋ฐํํด์ค๋ค.
def __len__(self):
return self._target_size
def __getitem__(self, index):
""" ํ์ดํ ์น ๋ฐ์ดํฐ์
์ ์ฃผ์ ์ง์
๋ฉ์๋
๋งค๊ฐ๋ณ์:
index (int): ๋ฐ์ดํฐ ํฌ์ธํธ์ ์ธ๋ฑ์ค
๋ฐํ๊ฐ:
๋ฐ์ดํฐ ํฌ์ธํธ์ ํน์ฑ(x_surname)๊ณผ ๋ ์ด๋ธ(y_nationality)๋ก ์ด๋ฃจ์ด์ง ๋์
๋๋ฆฌ
"""
row = self._target_df.iloc[index]
surname_vector = \
self._vectorizer.vectorize(row.surname)
nationality_index = \
self._vectorizer.nationality_vocab.lookup_token(row.nationality)
return {'x_surname': surname_vector,
'y_nationality': nationality_index}
7. get_num_batches ๋ฉ์๋
def get_num_batches(self, batch_size):
""" ๋ฐฐ์น ํฌ๊ธฐ๊ฐ ์ฃผ์ด์ง๋ฉด ๋ฐ์ดํฐ์
์ผ๋ก ๋ง๋ค ์ ์๋ ๋ฐฐ์น ๊ฐ์๋ฅผ ๋ฐํํฉ๋๋ค
๋งค๊ฐ๋ณ์:
batch_size (int)
๋ฐํ๊ฐ:
๋ฐฐ์น ๊ฐ์
"""
return len(self) // batch_size
โผ SurnameDataset ์ ์ฒด ์ฝ๋ ๋ณด๊ธฐ
class SurnameDataset(Dataset):
def __init__(self, surname_df, vectorizer):
"""
๋งค๊ฐ๋ณ์:
surname_df (pandas.DataFrame): ๋ฐ์ดํฐ์
vectorizer (SurnameVectorizer): SurnameVectorizer ๊ฐ์ฒด
"""
self.surname_df = surname_df
self._vectorizer = vectorizer
self.train_df = self.surname_df[self.surname_df.split=='train']
self.train_size = len(self.train_df)
self.val_df = self.surname_df[self.surname_df.split=='val']
self.validation_size = len(self.val_df)
self.test_df = self.surname_df[self.surname_df.split=='test']
self.test_size = len(self.test_df)
self._lookup_dict = {'train': (self.train_df, self.train_size),
'val': (self.val_df, self.validation_size),
'test': (self.test_df, self.test_size)}
self.set_split('train')
# ํด๋์ค ๊ฐ์ค์น ๊ตฌํ๊ธฐ
# surnaem_df์ nationality ๊ฐฏ์๋ฅผ dict ํํ๋ก ๋ฃ์ด์ค
class_counts = surname_df.nationality.value_counts().to_dict()
def sort_key(item):
# _vectorizer = SurnameVectorizer ๊ฐ์ฒด๋ฅผ ์๋ฏธํจ
# nationality_vocab = ๊ตญ์ ์ ์ ์์ ๋งคํํ๋ Vocabulary ๊ฐ์ฒด (class SurnameVectorizer)
# look_up_token = ํ ํฐ์ ๋์ํ๋ ์ธ๋ฑ์ค๋ฅผ ์ถ์ถํ๋ ๋ฉ์๋ (class vocabulary)
return self._vectorizer.nationality_vocab.lookup_token(item[0])
# ์ธ๋ฑ์ค๋ฅผ key๋ก ์ง์ ํด์ dic์์ ์์ผ๋ก ๊บผ๋ด์์ ์ ๋ ฌ
sorted_counts = sorted(class_counts.items(), key=sort_key)
# ๋น๋์๋ฅผ ๊ธฐ์ค์ผ๋ก ํด๋์ค์ ๊ฐ์ค์น ๊ตฌํ๊ธฐ
frequencies = [count for _, count in sorted_counts]
self.class_weights = 1.0 / torch.tensor(frequencies, dtype=torch.float32)
@classmethod
def load_dataset_and_make_vectorizer(cls, surname_csv):
# cls = SurnameDataset class ๋ฅผ ๋ฐ์
""" ๋ฐ์ดํฐ์
์ ๋ก๋ํ๊ณ ์๋ก์ด SurnameVectorizer ๊ฐ์ฒด๋ฅผ ๋ง๋ญ๋๋ค
๋งค๊ฐ๋ณ์:
review_csv (str): ๋ฐ์ดํฐ์
์ ์์น
๋ฐํ๊ฐ:
SurnameDataset์ ์ธ์คํด์ค
"""
surname_df = pd.read_csv(surname_csv)
train_surname_df = surname_df[surname_df.split=='train']
# SurnameVectorizer.from_data ์ ๋ฐํ๊ฐ : surname_vocab, nationality_vocab
# ์๋์์ ๋ณด๋ฉด ๋ฐ์ดํฐ์
์ ์๋ก ๋ง๋ค๋ ํด๋น ๋ฉ์๋๋ฅผ ์ฌ์ฉํ๋๋ฐ,
# dataset = surnameDataset.load_dataset_and_make_vectorizer(train) ์ด๋ ๊ฒ ํจ์ผ๋ก์จ
# surname_vocab, nationality vocab์ด ์์ฑ๋๋ ๊ฑฐ๋๊น ๋ฐ์ดํฐ์
์ด ๋ง๋ค์ด์ง๋ ๊ฑฐ์. ์ฆ dataset์ ์ธ์คํด์ค
return cls(surname_df, SurnameVectorizer.from_dataframe(train_surname_df))
@classmethod
def load_dataset_and_load_vectorizer(cls, surname_csv, vectorizer_filepath):
"""๋ฐ์ดํฐ์
์ ๋ก๋ํ๊ณ ์๋ก์ด SurnameVectorizer ๊ฐ์ฒด๋ฅผ ๋ง๋ญ๋๋ค.
์บ์๋ SurnameVectorizer ๊ฐ์ฒด๋ฅผ ์ฌ์ฌ์ฉํ ๋ ์ฌ์ฉํฉ๋๋ค.
๋งค๊ฐ๋ณ์:
surname_csv (str): ๋ฐ์ดํฐ์
์ ์์น
vectorizer_filepath (str): SurnameVectorizer ๊ฐ์ฒด์ ์ ์ฅ ์์น
๋ฐํ๊ฐ:
SurnameDataset์ ์ธ์คํด์ค
"""
surname_df = pd.read_csv(surname_csv)
# load_vectorizer_only ๋ฉ์๋ : ํ์ผ์์ surnamevectorizer ๊ฐ์ฒด๋ฅผ ๋ก๋ํ๋ ๋ฉ์๋
# ๋ฐํ๊ฐ์ด SurnameVectorizer.from_serializable(json.load(fp))
# SurnameVectorizer.from_serializable : surnameVectorizer ์ ์ธ์คํด์ค๋ฅผ ๋ฐํ
vectorizer = cls.load_vectorizer_only(vectorizer_filepath)
# ๊ฒฐ๊ตญ vectorizer ๊ฐ Dataset์ ์ธ์คํด์ค๊ฐ ๋์ด ๋ฐํ๋จ.
return cls(surname_df, vectorizer)
@staticmethod
def load_vectorizer_only(vectorizer_filepath):
"""ํ์ผ์์ SurnameVectorizer ๊ฐ์ฒด๋ฅผ ๋ก๋ํ๋ ์ ์ ๋ฉ์๋
๋งค๊ฐ๋ณ์:
vectorizer_filepath (str): ์ง๋ ฌํ๋ SurnameVectorizer ๊ฐ์ฒด์ ์์น
๋ฐํ๊ฐ:
SurnameVectorizer์ ์ธ์คํด์ค
"""
with open(vectorizer_filepath) as fp:
# from_serializable : filepath์์ vocab์ ๋ฐ์์์ vocab์ผ๋ก ๋ง๋ค์ด์ฃผ๋ฉด์ vectorizer ์ ์ธ์คํด์ค๋ก ๋ฐํ
# from_serializable๋ ์ง๋ ฌํ๋ ๊ฒ๋ค์ vocab์ผ๋ก ๋ฐํํ๋ ๋ฉ์๋๋ผ๊ณ ์ดํดํ๋ฉด ๋จ
# ๋ฐํ : surname_vocab / nationality_vocab
return SurnameVectorizer.from_serializable(json.load(fp))
def save_vectorizer(self, vectorizer_filepath):
""" SurnameVectorizer ๊ฐ์ฒด๋ฅผ json ํํ๋ก ๋์คํฌ์ ์ ์ฅํฉ๋๋ค
๋งค๊ฐ๋ณ์:
vectorizer_filepath (str): SurnameVectorizer ๊ฐ์ฒด์ ์ ์ฅ ์์น
"""
with open(vectorizer_filepath, "w") as fp:
# to_serializable : ์ง๋ ฌํ๋ dic์ผ๋ก vocab์ ๋ง๋ค์ด์ฃผ๋ ๋ฉ์๋๋ผ๊ณ ์ดํดํ๋ฉด ๋จ
json.dump(self._vectorizer.to_serializable(), fp)
def get_vectorizer(self):
""" ๋ฒกํฐ ๋ณํ ๊ฐ์ฒด๋ฅผ ๋ฐํํฉ๋๋ค """
return self._vectorizer
def set_split(self, split="train"):
""" ๋ฐ์ดํฐํ๋ ์์ ์๋ ์ด์ ์ฌ์ฉํด ๋ถํ ์ธํธ๋ฅผ ์ ํํฉ๋๋ค
๋งค๊ฐ๋ณ์:
split (str): "train", "val", "test" ์ค ํ๋
"""
self._target_split = split
self._target_df, self._target_size = self._lookup_dict[split]
def __len__(self):
return self._target_size
def __getitem__(self, index):
""" ํ์ดํ ์น ๋ฐ์ดํฐ์
์ ์ฃผ์ ์ง์
๋ฉ์๋
๋งค๊ฐ๋ณ์:
index (int): ๋ฐ์ดํฐ ํฌ์ธํธ์ ์ธ๋ฑ์ค
๋ฐํ๊ฐ:
๋ฐ์ดํฐ ํฌ์ธํธ์ ํน์ฑ(x_surname)๊ณผ ๋ ์ด๋ธ(y_nationality)๋ก ์ด๋ฃจ์ด์ง ๋์
๋๋ฆฌ
"""
row = self._target_df.iloc[index]
surname_vector = \
self._vectorizer.vectorize(row.surname)
nationality_index = \
self._vectorizer.nationality_vocab.lookup_token(row.nationality)
return {'x_surname': surname_vector,
'y_nationality': nationality_index}
def get_num_batches(self, batch_size):
""" ๋ฐฐ์น ํฌ๊ธฐ๊ฐ ์ฃผ์ด์ง๋ฉด ๋ฐ์ดํฐ์
์ผ๋ก ๋ง๋ค ์ ์๋ ๋ฐฐ์น ๊ฐ์๋ฅผ ๋ฐํํฉ๋๋ค
๋งค๊ฐ๋ณ์:
batch_size (int)
๋ฐํ๊ฐ:
๋ฐฐ์น ๊ฐ์
"""
return len(self) // batch_size
def generate_batches(dataset, batch_size, shuffle=True,
drop_last=True, device="cpu"):
"""
ํ์ดํ ์น DataLoader๋ฅผ ๊ฐ์ธ๊ณ ์๋ ์ ๋๋ ์ดํฐ ํจ์.
๊ฑฑ ํ
์๋ฅผ ์ง์ ๋ ์ฅ์น๋ก ์ด๋ํฉ๋๋ค.
"""
dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
shuffle=shuffle, drop_last=drop_last)
for data_dict in dataloader:
out_data_dict = {}
for name, tensor in data_dict.items():
out_data_dict[name] = data_dict[name].to(device)
yield out_data_dict
<NEXT>
์ด๋ฒ ํฌ์คํ ์์๋ ์ฑ์จ ๋ฐ์ดํฐ์ ์ ํ์ดํ ์น๋ก ์ฌ์ฉํ๊ธฐ์ํด custom Dataset ํด๋์ค๋ฅผ ๋ง๋ ์ฝ๋๋ฅผ ์ดํด๋ณด์๋ค. ๋ค์ ํฌ์คํ ์์๋ ํด๋น custom Dataset์์ ํ์๋ก ํ๋ ์๋ ๋ ๊ฐ์ ํด๋์ค๋ฅผ ๋ค๋ค๋ณด๊ฒ ๋ค.
- Vocabulary class : ๋จ์ด๋ฅผ ํด๋น ์ ์๋ก ๋งคํํ๋๋ฐ ์ฌ์ฉํ๋ ํด๋์ค
- Vectorizer class : vocabulary๋ฅผ ์ ์ฉํ์ฌ ์ฑ์จ ๋ฌธ์์ด์ ๋ฒกํฐ๋ก ๋ฐ๊พธ๋ ํด๋์ค