

[NLP] MLP๋กœ ์„ฑ์”จ ๋ถ„๋ฅ˜ํ•˜๊ธฐ (1) (feat.ํŒŒ์ดํ† ์น˜๋กœ ๋ฐฐ์šฐ๋Š” ์ž์—ฐ์–ด์ฒ˜๋ฆฌ)

๊ฐ์ž ๐Ÿฅ” 2021. 7. 23. 15:54
๋ฐ˜์‘ํ˜•

-- ๋ณธ ํฌ์ŠคํŒ…์€ ํŒŒ์ดํ† ์น˜๋กœ ๋ฐฐ์šฐ๋Š” ์ž์—ฐ์–ด ์ฒ˜๋ฆฌ (ํ•œ๋น›๋ฏธ๋””์–ด) ์ฑ…์„ ์ฐธ๊ณ ํ•ด์„œ ์ž‘์„ฑ๋œ ๊ธ€์ž…๋‹ˆ๋‹ค.
-- ์†Œ์Šค์ฝ”๋“œ ) https://github.com/rickiepark/nlp-with-pytorch

 


 

1. ์„ฑ์”จ ๋ฐ์ดํ„ฐ์…‹

  • 18๊ฐœ ๊ตญ์ ์˜ ์„ฑ์”จ 10,000๊ฐœ๋ฅผ ๋ชจ์€ ์„ฑ์”จ ๋ฐ์ดํ„ฐ์…‹
  • ๋งค์šธ ๋ถˆ๊ท ํ˜•ํ•จ
    • ์ตœ์ƒ์œ„ ํด๋ž˜์Šค 3๊ฐœ๊ฐ€ ๋ฐ์ดํ„ฐ์˜ 60%๋ฅผ ์ฐจ์ง€
    • 27%๊ฐ€ ์˜์–ด, 21%๊ฐ€ ๋Ÿฌ์‹œ์•„์–ด, 14%๊ฐ€ ์•„๋ž์–ด
    • ๋‚˜๋จธ์ง€ 15๊ฐœ๊ตญ์˜ ๋นˆ๋„๋Š” ๊ณ„์† ๊ฐ์†Œ
    • ์–ธ์–ด ์ž์ฒด์œผ ์†์„ฑ์ด๊ธฐ๋„ ํ•จ (๋งŽ์ด ์‚ฌ์šฉํ•˜๋Š” ์–ธ์–ด์ผ์ˆ˜๋ก ๋งŽ์„ ์ˆ˜ ๋ฐ–์—)
  • ์ถœ์‹ ๊ตญ๊ฐ€์™€ ์„ฑ์”จ ๋งž์ถค๋ฒ• ์‚ฌ์ด์— ์˜๋ฏธ๊ฐ€ ์žˆ๊ณ  ์ง๊ด€์ ์ธ ๊ด€๊ณ„๊ฐ€ ์žˆ์Œ
    • ์ฆ‰, ๊ตญ์ ๊ณผ ๊ด€๊ณ„๊ฐ€ ์žˆ๋Š” ์„ฑ์”จ๊ฐ€ ์กด์žฌ

 

2. ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ

์•„๋ž˜์™€ ๊ฐ™์ด ์ฒ˜๋ฆฌํ•œ ๋ฐ์ดํ„ฐ๋ฅผ ์‚ฌ์šฉํ–ˆ๋‹ค. ํ•ด๋‹น ์ฝ”๋“œ์—์„œ๋Š” ์•„๋ž˜ ๋ฐฉ์‹์„ ๊ฑฐ์ณ์ ธ์„œ ๋ฏธ๋ฆฌ ์ฒ˜๋ฆฌ๋œ ๋ฐ์ดํ„ฐ๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค. ๋ฐ์ดํ„ฐ๋ฅผ ๋‚˜๋ˆ„๋Š” ์ฝ”๋“œ๋Š” ์•„๋ž˜

  1. ๋ถˆ๊ท ํ˜• ์ค„์ด๊ธฐ 
    • ์›๋ณธ ๋ฐ์ดํ„ฐ์…‹์€ 70% ์ด์ƒ์ด ๋Ÿฌ์‹œ์•„ ์ด๋ฆ„
    • ์ƒ˜ํ”Œ๋ง์ด ํŽธํ–ฅ๋˜์—ˆ๊ฑฐ๋‚˜ ๋Ÿฌ์‹œ์•„์— ๊ณ ์œ ํ•œ ์„ฑ์”จ๊ฐ€ ๋งŽ๊ธฐ ๋•Œ๋ฌธ์œผ๋กœ ์ถ”์ •
    • ๋Ÿฌ์‹œ์•„ ์„ฑ์”จ์˜ ๋ถ€๋ถ„์ง‘ํ•ฉ์„ ๋žœ๋คํ•˜๊ฒŒ ์„ ํƒํ•˜์—ฌ ํŽธ์ค‘๋œ ํด๋ž˜์Šค๋ฅผ "์„œ๋ธŒ์ƒ˜ํ”Œ๋ง" ํ•ด์คŒ
  2. ๊ตญ์ ์„ ๊ธฐ๋ฐ˜์œผ๋กœ ๋ชจ์•„ 3์„ธํŠธ๋กœ split
    • train = 70%
    • validation = 15%
    • test = 15%
  3. ์„ธํŠธ๊ฐ„ ๋ ˆ์ด๋ธ” ๋ถ„ํฌ๋ฅผ ๊ณ ๋ฅด๊ฒŒ ์œ ์ง€ ์‹œํ‚ด

โ–ถ ๊ฐ„๋‹จํ•˜๊ฒŒ ๋ฐ์ดํ„ฐ ์‚ดํŽด๋ณด๊ธฐ

# Read raw data
surnames = pd.read_csv(args.raw_dataset_csv, header=0)
surnames.head()

  • The head() call above shows what the surnames dataset looks like.
# Unique classes
set(surnames.nationality)

 

  • 18๊ฐœ์˜ ๊ตญ๊ฐ€๋Š” ์œ„์™€ ๊ฐ™๋‹ค.
# Write split data to file
final_surnames = pd.DataFrame(final_list)
final_surnames.split.value_counts()

  • ๋ฐ์ดํ„ฐ์˜ ๋ถ„ํฌ๋Š” 70: 15: 15 ๋กœ ๋‚˜๋ˆ„์–ด์ ธ ์žˆ๋Š” ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค.

 

3. Imports

ํ•„์š”ํ•œ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ๋ชจ๋‘ ์ž„ํฌํŠธ ํ•ด์ฃผ์—ˆ๋‹ค.

from argparse import Namespace
from collections import Counter
import json
import os
import string

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import tqdm

 

4. Building a Custom Dataset (SurnameDataset)

  • ํŒŒ์ดํ† ์น˜์˜ Dataset ์ƒ์†
  • Dataset์ด ๊ตฌ์กฐ๋ฅผ ๋”ฐ๋ฆ„
    • __len__ ๋ฉ”์„œ๋“œ : ๋ฐ์ดํ„ฐ์…‹์˜ ๊ธธ์ด๋ฅผ ๋ฐ˜ํ™˜ํ•œ๋‹ค.
    • __getitem__ ๋ฉ”์„œ๋“œ : ๋ฒกํ„ฐ๋กœ ๋ฐ”๊พผ ์„ฑ์”จ์™€ ๊ตญ์ ์— ํ•ด๋‹นํ•˜๋Š” ์ธ๋ฑ์Šค๋ฅผ ๋ฐ˜ํ™˜ํ•œ๋‹ค.

โ–ถ SurnameDataset ์˜ class ์„ค๋ช…

class ์ฝ”๋“œ๊ฐ€ ๋„ˆ๋ฌด ๊ธธ์–ด์„œ ์ดํ•ดํ•˜๊ธฐ ์œ„ํ•ด ํ•˜๋‚˜์”ฉ ๋œฏ์–ด์„œ ์„ค๋ช…ํ•  ์˜ˆ์ •์ด๋‹ค. ๋งˆ์ง€๋ง‰์— <๋”๋ณด๊ธฐ>๋ฅผ ํด๋ฆญํ•˜๋ฉด class๋กœ ์ด์–ด์ง„ ์ „์ฒด ์ฝ”๋“œ๋ฅผ ๋ณผ ์ˆ˜ ์žˆ๋‹ค. 

1. The __init__ method

    def __init__(self, surname_df, vectorizer):
        """
        Args:
            surname_df (pandas.DataFrame): the dataset
            vectorizer (SurnameVectorizer): a SurnameVectorizer object
        """
        self.surname_df = surname_df
        self._vectorizer = vectorizer

        self.train_df = self.surname_df[self.surname_df.split=='train']
        self.train_size = len(self.train_df)

        self.val_df = self.surname_df[self.surname_df.split=='val']
        self.validation_size = len(self.val_df)

        self.test_df = self.surname_df[self.surname_df.split=='test']
        self.test_size = len(self.test_df)

        self._lookup_dict = {'train': (self.train_df, self.train_size),
                             'val': (self.val_df, self.validation_size),
                             'test': (self.test_df, self.test_size)}

        self.set_split('train')
        
        # Compute class weights:
        # count each nationality in surname_df and store the counts as a dict
        class_counts = surname_df.nationality.value_counts().to_dict()
        def sort_key(item):
            # _vectorizer = the SurnameVectorizer object
            # nationality_vocab = Vocabulary object mapping nationalities to integers (class SurnameVectorizer)
            # lookup_token = method returning the index for a given token (class Vocabulary)
            return self._vectorizer.nationality_vocab.lookup_token(item[0])
        # sort the (nationality, count) pairs by their vocabulary index
        sorted_counts = sorted(class_counts.items(), key=sort_key)
        # derive per-class weights from the frequencies (inverse frequency)
        frequencies = [count for _, count in sorted_counts]
        self.class_weights = 1.0 / torch.tensor(frequencies, dtype=torch.float32)
  • ์ƒ์„ฑ์ž ๋ฉ”์„œ๋“œ
  • ๋งค๊ฐœ๋ณ€์ˆ˜๋กœ surname_df (๋ฐ์ดํ„ฐ์…‹)๊ณผ vectorizer(SurnameVectorizer ๊ฐ์ฒด) ๋ฅผ ๋ฐ›์Œ
  • train_df / train_size , val_df / val_size , test_df / test_size ์„ ์ •์˜
  • ๋‚ด๋ถ€์— def sort_key(item) ๋ฉ”์„œ๋“œ ์กด์žฌ
    • ํด๋ž˜์Šค ๊ฐ€์ค‘์น˜๋ฅผ ๊ตฌํ•˜๊ธฐ ์œ„ํ•ด์„œ ์‚ฌ์šฉ

2. The load_dataset_and_make_vectorizer method

    @classmethod
    def load_dataset_and_make_vectorizer(cls, surname_csv):
        # cls refers to the SurnameDataset class itself
        """Loads the dataset and makes a new SurnameVectorizer object
        
        Args:
            surname_csv (str): location of the dataset
        Returns:
            an instance of SurnameDataset
        """
        surname_df = pd.read_csv(surname_csv)
        train_surname_df = surname_df[surname_df.split=='train']
        # SurnameVectorizer.from_dataframe builds the surname_vocab and nationality_vocab
        return cls(surname_df, SurnameVectorizer.from_dataframe(train_surname_df))
  • ๋ฐ˜ํ™˜๊ฐ’์ด SurnameDataset์˜ ์ธ์Šคํ„ด์Šค
    • ์•„๋ž˜์—์„œ ๋ณด๋ฉด ๋ฐ์ดํ„ฐ์…‹์„ ์ƒˆ๋กœ ๋งŒ๋“ค ๋•Œ ํ•ด๋‹น ๋ฉ”์„œ๋“œ๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค.
    • dataset = SurnameDataset.load_dataset_and_make_vectorizer(train_df)
    • ํ•ด๋‹น ๋ฐ˜ํ™˜๊ฐ’์€ surname_vocab / nationality_vocab ์ด๋ฏ€๋กœ ๋ฐ์ดํ„ฐ์…‹์ด ๋งŒ๋“ค์–ด์ง€๋Š” ๊ฒƒ์ด๋‹ค.
    • ๊ทธ๋ž˜์„œ SurnameDataset์˜ ์ธ์Šคํ„ด์Šค๋ผ๊ณ  ํ•œ ๊ฒƒ์ด๋‹ค.

3. The load_dataset_and_load_vectorizer method

    @classmethod
    def load_dataset_and_load_vectorizer(cls, surname_csv, vectorizer_filepath):
        """Loads the dataset and the corresponding SurnameVectorizer object.
        Used when the SurnameVectorizer has been cached for reuse.
        
        Args:
            surname_csv (str): location of the dataset
            vectorizer_filepath (str): location of the saved SurnameVectorizer object
        Returns:
            an instance of SurnameDataset
        """
        surname_df = pd.read_csv(surname_csv)
        # load_vectorizer_only: loads a SurnameVectorizer object from a file;
        # its return value is SurnameVectorizer.from_serializable(json.load(fp)),
        # and from_serializable returns an instance of SurnameVectorizer
        vectorizer = cls.load_vectorizer_only(vectorizer_filepath)
        # the loaded vectorizer is wrapped into a SurnameDataset instance and returned
        return cls(surname_df, vectorizer)
  • Takes vectorizer_filepath, restores the surname_vocab / nationality_vocab into a SurnameVectorizer instance, and the final return wraps everything into a SurnameDataset instance (cls refers to SurnameDataset).

4. The load_vectorizer_only / save_vectorizer methods

  • load_vectorizer_only loads a vectorizer object from a file; think of it as restoring the surname_vocab / nationality_vocab.
  • save_vectorizer does the reverse: it writes the vectorizer's vocabularies to disk as a serialized dict.
    @staticmethod
    def load_vectorizer_only(vectorizer_filepath):
        """Static method that loads a SurnameVectorizer object from a file
        
        Args:
            vectorizer_filepath (str): location of the serialized SurnameVectorizer object
        Returns:
            an instance of SurnameVectorizer
        """
        with open(vectorizer_filepath) as fp:
            # from_serializable: reads the vocab contents from the file,
            # rebuilds them into Vocabulary objects, and returns a vectorizer instance
            # (restoring surname_vocab / nationality_vocab)
            return SurnameVectorizer.from_serializable(json.load(fp))

    def save_vectorizer(self, vectorizer_filepath):
        """Saves the SurnameVectorizer object to disk as JSON
        
        Args:
            vectorizer_filepath (str): location to save the SurnameVectorizer object
        """
        with open(vectorizer_filepath, "w") as fp:
            # to_serializable: converts the vocabs into a serializable dict
            json.dump(self._vectorizer.to_serializable(), fp)
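The save/load pair boils down to a JSON round trip of the vectorizer's serializable dict. A minimal stand-in (the dict contents here are hypothetical, just two tiny vocabularies):

```python
import json
import os
import tempfile

# A fake "to_serializable()" result: the two vocabularies as plain dicts.
serializable = {"surname_vocab": {"a": 0, "b": 1},
                "nationality_vocab": {"english": 0, "russian": 1}}

path = os.path.join(tempfile.mkdtemp(), "vectorizer.json")
with open(path, "w") as fp:      # what save_vectorizer does
    json.dump(serializable, fp)
with open(path) as fp:           # what load_vectorizer_only does
    restored = json.load(fp)
print(restored == serializable)  # -> True
```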

 

5. The get_vectorizer / set_split methods

    def get_vectorizer(self):
        """ Returns the vectorizer object """
        return self._vectorizer

    def set_split(self, split="train"):
        """ Selects the split set using a column in the DataFrame 
        
        Args:
            split (str): one of "train", "val", or "test"
        """
        self._target_split = split
        self._target_df, self._target_size = self._lookup_dict[split]
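set_split is just a dictionary lookup prepared in __init__: the split name maps to the matching DataFrame slice and its size. A toy version of the mechanism (toy rows, not the real dataset):

```python
import pandas as pd

df = pd.DataFrame({"surname": ["Smith", "Ivanov", "Rossi", "Sato"],
                   "split":   ["train", "train", "val", "test"]})
# the _lookup_dict built in __init__: split name -> (DataFrame slice, its size)
lookup = {name: (part, len(part)) for name, part in df.groupby("split")}

target_df, target_size = lookup["train"]   # what set_split("train") stores
print(target_size)  # -> 2
```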

 

6. The __len__ / __getitem__ methods

  • Dataset์˜ ๊ธฐ๋ณธ ๊ตฌ์„ฑ โ˜…
  • __getitem__ : ๋ฒกํ„ฐ๋กœ ๋ฐ”๊พผ ์„ฑ์”จ์™€ ๊ตญ์ ์— ํ•ด๋‹นํ•˜๋Š” ์ธ๋ฑ์Šค๋ฅด ๋ฐ˜ํ™˜ํ•ด์ค€๋‹ค.
    def __len__(self):
        return self._target_size

    def __getitem__(self, index):
        """ The main entry point for a PyTorch dataset
        
        Args:
            index (int): index of the data point
        Returns:
            a dictionary holding the data point's features (x_surname) and label (y_nationality)
        """
        row = self._target_df.iloc[index]

        surname_vector = \
            self._vectorizer.vectorize(row.surname)

        nationality_index = \
            self._vectorizer.nationality_vocab.lookup_token(row.nationality)

        return {'x_surname': surname_vector,
                'y_nationality': nationality_index}
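The contract of __getitem__ can be mimicked without the vectorizer: an integer index goes in, and a dict with 'x_surname' and 'y_nationality' comes out. Here vectorization is faked with the name's length (purely for illustration; the real code produces a one-hot style vector):

```python
rows = [("Smith", "english"), ("Ivanov", "russian")]
nationality_to_index = {"english": 0, "russian": 1}

def getitem(index):
    surname, nationality = rows[index]
    return {"x_surname": [float(len(surname))],        # stand-in for vectorize()
            "y_nationality": nationality_to_index[nationality]}

print(getitem(1))  # -> {'x_surname': [6.0], 'y_nationality': 1}
```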

 

7. The get_num_batches method

    def get_num_batches(self, batch_size):
        """ Given a batch size, returns the number of batches the dataset can produce
        
        Args:
            batch_size (int)
        Returns:
            number of batches
        """
        return len(self) // batch_size
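The batch count is plain floor division, which matches a DataLoader configured with drop_last=True (the numbers below are just an example):

```python
dataset_len, batch_size = 10000, 64
num_batches = dataset_len // batch_size          # partial final batch is dropped
leftover = dataset_len - num_batches * batch_size
print(num_batches, leftover)  # -> 156 16
```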

 

โ–ผ  SurnameDataset ์ „์ฒด ์ฝ”๋“œ ๋ณด๊ธฐ

๋”๋ณด๊ธฐ
class SurnameDataset(Dataset):
    def __init__(self, surname_df, vectorizer):
        """
        Args:
            surname_df (pandas.DataFrame): the dataset
            vectorizer (SurnameVectorizer): a SurnameVectorizer object
        """
        self.surname_df = surname_df
        self._vectorizer = vectorizer

        self.train_df = self.surname_df[self.surname_df.split=='train']
        self.train_size = len(self.train_df)

        self.val_df = self.surname_df[self.surname_df.split=='val']
        self.validation_size = len(self.val_df)

        self.test_df = self.surname_df[self.surname_df.split=='test']
        self.test_size = len(self.test_df)

        self._lookup_dict = {'train': (self.train_df, self.train_size),
                             'val': (self.val_df, self.validation_size),
                             'test': (self.test_df, self.test_size)}

        self.set_split('train')
        
        # Compute class weights:
        # count each nationality in surname_df and store the counts as a dict
        class_counts = surname_df.nationality.value_counts().to_dict()
        def sort_key(item):
            # _vectorizer = the SurnameVectorizer object
            # nationality_vocab = Vocabulary object mapping nationalities to integers (class SurnameVectorizer)
            # lookup_token = method returning the index for a given token (class Vocabulary)
            return self._vectorizer.nationality_vocab.lookup_token(item[0])
        # sort the (nationality, count) pairs by their vocabulary index
        sorted_counts = sorted(class_counts.items(), key=sort_key)
        # derive per-class weights from the frequencies (inverse frequency)
        frequencies = [count for _, count in sorted_counts]
        self.class_weights = 1.0 / torch.tensor(frequencies, dtype=torch.float32)

    @classmethod
    def load_dataset_and_make_vectorizer(cls, surname_csv):
        # cls refers to the SurnameDataset class itself
        """Loads the dataset and makes a new SurnameVectorizer object
        
        Args:
            surname_csv (str): location of the dataset
        Returns:
            an instance of SurnameDataset
        """
        surname_df = pd.read_csv(surname_csv)
        train_surname_df = surname_df[surname_df.split=='train']
        # SurnameVectorizer.from_dataframe builds the surname_vocab and nationality_vocab.
        # This method is used when creating the dataset from scratch, e.g.
        # dataset = SurnameDataset.load_dataset_and_make_vectorizer(train_csv);
        # the vocabularies are created along with the dataset instance.
        return cls(surname_df, SurnameVectorizer.from_dataframe(train_surname_df))

    @classmethod
    def load_dataset_and_load_vectorizer(cls, surname_csv, vectorizer_filepath):
        """Loads the dataset and the corresponding SurnameVectorizer object.
        Used when the SurnameVectorizer has been cached for reuse.
        
        Args:
            surname_csv (str): location of the dataset
            vectorizer_filepath (str): location of the saved SurnameVectorizer object
        Returns:
            an instance of SurnameDataset
        """
        surname_df = pd.read_csv(surname_csv)
        # load_vectorizer_only: loads a SurnameVectorizer object from a file;
        # its return value is SurnameVectorizer.from_serializable(json.load(fp)),
        # and from_serializable returns an instance of SurnameVectorizer
        vectorizer = cls.load_vectorizer_only(vectorizer_filepath)
        # the loaded vectorizer is wrapped into a SurnameDataset instance and returned
        return cls(surname_df, vectorizer)

    @staticmethod
    def load_vectorizer_only(vectorizer_filepath):
        """Static method that loads a SurnameVectorizer object from a file
        
        Args:
            vectorizer_filepath (str): location of the serialized SurnameVectorizer object
        Returns:
            an instance of SurnameVectorizer
        """
        with open(vectorizer_filepath) as fp:
            # from_serializable: reads the vocab contents from the file,
            # rebuilds them into Vocabulary objects, and returns a vectorizer instance
            # (restoring surname_vocab / nationality_vocab)
            return SurnameVectorizer.from_serializable(json.load(fp))

    def save_vectorizer(self, vectorizer_filepath):
        """Saves the SurnameVectorizer object to disk as JSON
        
        Args:
            vectorizer_filepath (str): location to save the SurnameVectorizer object
        """
        with open(vectorizer_filepath, "w") as fp:
            # to_serializable: converts the vocabs into a serializable dict
            json.dump(self._vectorizer.to_serializable(), fp)

    def get_vectorizer(self):
        """ Returns the vectorizer object """
        return self._vectorizer

    def set_split(self, split="train"):
        """ Selects the split set using a column in the DataFrame 
        
        Args:
            split (str): one of "train", "val", or "test"
        """
        self._target_split = split
        self._target_df, self._target_size = self._lookup_dict[split]

    def __len__(self):
        return self._target_size

    def __getitem__(self, index):
        """ The main entry point for a PyTorch dataset
        
        Args:
            index (int): index of the data point
        Returns:
            a dictionary holding the data point's features (x_surname) and label (y_nationality)
        """
        row = self._target_df.iloc[index]

        surname_vector = \
            self._vectorizer.vectorize(row.surname)

        nationality_index = \
            self._vectorizer.nationality_vocab.lookup_token(row.nationality)

        return {'x_surname': surname_vector,
                'y_nationality': nationality_index}

    def get_num_batches(self, batch_size):
        """ Given a batch size, returns the number of batches the dataset can produce
        
        Args:
            batch_size (int)
        Returns:
            number of batches
        """
        return len(self) // batch_size

    
def generate_batches(dataset, batch_size, shuffle=True,
                     drop_last=True, device="cpu"): 
    """
    A generator function wrapping the PyTorch DataLoader.
    Moves each tensor to the specified device.
    """
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=shuffle, drop_last=drop_last)

    for data_dict in dataloader:
        out_data_dict = {}
        for name, tensor in data_dict.items():
            out_data_dict[name] = data_dict[name].to(device)
        yield out_data_dict
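generate_batches wraps a DataLoader and moves each tensor in the batch dict to the target device. A torch-free sketch of the same batching behavior (the names here are assumptions, and shuffling and the device move are omitted):

```python
def generate_batches_sketch(items, batch_size):
    """Yield dict-of-columns batches; like drop_last=True, partial batches are skipped."""
    n_full = len(items) // batch_size
    for i in range(n_full):
        batch = items[i * batch_size:(i + 1) * batch_size]
        yield {"x_surname": [x for x, _ in batch],
               "y_nationality": [y for _, y in batch]}

data = [([1.0], 0), ([2.0], 1), ([3.0], 0), ([4.0], 1), ([5.0], 0)]
batches = list(generate_batches_sketch(data, batch_size=2))
print(len(batches))  # -> 2 (the final partial batch is dropped)
```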

 

<NEXT>

์ด๋ฒˆ ํฌ์ŠคํŒ…์—์„œ๋Š” ์„ฑ์”จ ๋ฐ์ดํ„ฐ์…‹์„ ํŒŒ์ดํ† ์น˜๋กœ ์‚ฌ์šฉํ•˜๊ธฐ์œ„ํ•ด custom Dataset ํด๋ž˜์Šค๋ฅผ ๋งŒ๋“  ์ฝ”๋“œ๋ฅผ ์‚ดํŽด๋ณด์•˜๋‹ค. ๋‹ค์Œ ํฌ์ŠคํŒ…์—์„œ๋Š” ํ•ด๋‹น custom Dataset์—์„œ ํ•„์š”๋กœ ํ•˜๋Š” ์•„๋ž˜ ๋‘ ๊ฐœ์˜ ํด๋ž˜์Šค๋ฅผ ๋‹ค๋ค„๋ณด๊ฒ ๋‹ค.

  • Vocabulary class : maps each token to a corresponding integer
  • Vectorizer class : applies the vocabularies to convert a surname string into a vector
๋ฐ˜์‘ํ˜•