AI study/์ž์—ฐ์–ด ์ฒ˜๋ฆฌ (NLP)

[NLP] Classifying Surnames with an MLP (1) (feat. ํŒŒ์ดํ† ์น˜๋กœ ๋ฐฐ์šฐ๋Š” ์ž์—ฐ์–ด ์ฒ˜๋ฆฌ)

๊ฐ์ž ๐Ÿฅ” 2021. 7. 23. 15:54
๋ฐ˜์‘ํ˜•

-- ๋ณธ ํฌ์ŠคํŒ…์€ ํŒŒ์ดํ† ์น˜๋กœ ๋ฐฐ์šฐ๋Š” ์ž์—ฐ์–ด ์ฒ˜๋ฆฌ (ํ•œ๋น›๋ฏธ๋””์–ด) ์ฑ…์„ ์ฐธ๊ณ ํ•ด์„œ ์ž‘์„ฑ๋œ ๊ธ€์ž…๋‹ˆ๋‹ค.
-- ์†Œ์Šค์ฝ”๋“œ ) https://github.com/rickiepark/nlp-with-pytorch

 


 

1. The Surname Dataset

  • A dataset of 10,000 surnames collected from 18 nationalities
  • Highly imbalanced
    • The top 3 classes account for 60% of the data
    • 27% English, 21% Russian, 14% Arabic
    • The frequencies of the remaining 15 nationalities keep decreasing
    • This is also partly a property of the languages themselves (the more widely a language is used, the more surnames it contributes)
  • There is a meaningful, intuitive relationship between country of origin and surname orthography
    • In other words, some surname spellings are closely tied to nationality

 

2. ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ

์•„๋ž˜์™€ ๊ฐ™์ด ์ฒ˜๋ฆฌํ•œ ๋ฐ์ดํ„ฐ๋ฅผ ์‚ฌ์šฉํ–ˆ๋‹ค. ํ•ด๋‹น ์ฝ”๋“œ์—์„œ๋Š” ์•„๋ž˜ ๋ฐฉ์‹์„ ๊ฑฐ์ณ์ ธ์„œ ๋ฏธ๋ฆฌ ์ฒ˜๋ฆฌ๋œ ๋ฐ์ดํ„ฐ๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค. ๋ฐ์ดํ„ฐ๋ฅผ ๋‚˜๋ˆ„๋Š” ์ฝ”๋“œ๋Š” ์•„๋ž˜

  1. Reducing the imbalance
    • In the original dataset, more than 70% of the names are Russian
    • Presumably either the sampling was biased or Russian simply has many unique surnames
    • The over-represented class is "subsampled" by randomly selecting a subset of the Russian surnames
  2. Grouping by nationality and splitting into 3 sets
    • train = 70%
    • validation = 15%
    • test = 15%
  3. Keeping the label distribution even across the sets
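The steps above can be sketched with a toy frame (the raw data, the subsampling amount, and the column names here are made-up stand-ins, not the book's actual preprocessing code or numbers):

```python
import pandas as pd

# Toy stand-in for the raw data: one heavily over-represented class
raw_df = pd.DataFrame({
    'surname': [f'name{i}' for i in range(130)],
    'nationality': ['Russian'] * 100 + ['English'] * 20 + ['Arabic'] * 10,
})

# 1. Reduce the imbalance: randomly subsample the dominant class
russian = raw_df[raw_df.nationality == 'Russian'].sample(n=30, random_state=0)
balanced_df = pd.concat([russian, raw_df[raw_df.nationality != 'Russian']])

# 2-3. Split 70/15/15 within each nationality so the label
#      distribution stays even across the three sets
rows = []
for _, group in balanced_df.groupby('nationality'):
    group = group.sample(frac=1.0, random_state=0)  # shuffle inside the class
    n = len(group)
    n_train, n_val = int(0.7 * n), int(0.15 * n)
    splits = (['train'] * n_train + ['val'] * n_val
              + ['test'] * (n - n_train - n_val))
    rows.append(group.assign(split=splits))
final_df = pd.concat(rows)
```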

โ–ถ A quick look at the data

# Read raw data
surnames = pd.read_csv(args.raw_dataset_csv, header=0)
surnames.head()

  • This is what the surnames dataset looks like.
# Unique classes
set(surnames.nationality)

 

  • 18๊ฐœ์˜ ๊ตญ๊ฐ€๋Š” ์œ„์™€ ๊ฐ™๋‹ค.
# Write split data to file
final_surnames = pd.DataFrame(final_list)
final_surnames.split.value_counts()

  • ๋ฐ์ดํ„ฐ์˜ ๋ถ„ํฌ๋Š” 70: 15: 15 ๋กœ ๋‚˜๋ˆ„์–ด์ ธ ์žˆ๋Š” ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค.

 

3. Imports

All required libraries are imported first.

from argparse import Namespace
from collections import Counter
import json
import os
import string

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import tqdm

 

4. Building a Custom Dataset (SurnameDataset)

  • ํŒŒ์ดํ† ์น˜์˜ Dataset ์ƒ์†
  • Dataset์ด ๊ตฌ์กฐ๋ฅผ ๋”ฐ๋ฆ„
    • __len__ ๋ฉ”์„œ๋“œ : ๋ฐ์ดํ„ฐ์…‹์˜ ๊ธธ์ด๋ฅผ ๋ฐ˜ํ™˜ํ•œ๋‹ค.
    • __getitem__ ๋ฉ”์„œ๋“œ : ๋ฒกํ„ฐ๋กœ ๋ฐ”๊พผ ์„ฑ์”จ์™€ ๊ตญ์ ์— ํ•ด๋‹นํ•˜๋Š” ์ธ๋ฑ์Šค๋ฅผ ๋ฐ˜ํ™˜ํ•œ๋‹ค.
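The two required methods can be seen in the smallest possible map-style dataset (a toy example for illustration, not part of the book's code):

```python
from torch.utils.data import Dataset

class MinimalDataset(Dataset):
    """Smallest map-style dataset: a length plus index-based access."""
    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]

ds = MinimalDataset(['Kim', 'Smith', 'Ivanov'])
```

SurnameDataset below implements exactly this contract, with the extra machinery needed to switch between splits and to vectorize each row.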

โ–ถ Walking through the SurnameDataset class

The class is too long to digest at once, so I will break it apart and explain it piece by piece. The full class code is shown at the end.

1. The __init__ method

    def __init__(self, surname_df, vectorizer):
        """
        Args:
            surname_df (pandas.DataFrame): the dataset
            vectorizer (SurnameVectorizer): a SurnameVectorizer object
        """
        self.surname_df = surname_df
        self._vectorizer = vectorizer

        self.train_df = self.surname_df[self.surname_df.split=='train']
        self.train_size = len(self.train_df)

        self.val_df = self.surname_df[self.surname_df.split=='val']
        self.validation_size = len(self.val_df)

        self.test_df = self.surname_df[self.surname_df.split=='test']
        self.test_size = len(self.test_df)

        self._lookup_dict = {'train': (self.train_df, self.train_size),
                             'val': (self.val_df, self.validation_size),
                             'test': (self.test_df, self.test_size)}

        self.set_split('train')
        
        # Compute the class weights
        # Collect the nationality counts of surname_df as a dict
        class_counts = surname_df.nationality.value_counts().to_dict()
        def sort_key(item):
            # _vectorizer is the SurnameVectorizer object
            # nationality_vocab is the Vocabulary object mapping nationalities to integers (see SurnameVectorizer)
            # lookup_token is the method that retrieves the index for a token (see Vocabulary)
            return self._vectorizer.nationality_vocab.lookup_token(item[0])
        # Sort the (nationality, count) pairs by their vocabulary index
        sorted_counts = sorted(class_counts.items(), key=sort_key)
        # Derive the class weights from the frequencies
        frequencies = [count for _, count in sorted_counts]
        self.class_weights = 1.0 / torch.tensor(frequencies, dtype=torch.float32)
  • ์ƒ์„ฑ์ž ๋ฉ”์„œ๋“œ
  • ๋งค๊ฐœ๋ณ€์ˆ˜๋กœ surname_df (๋ฐ์ดํ„ฐ์…‹)๊ณผ vectorizer(SurnameVectorizer ๊ฐ์ฒด) ๋ฅผ ๋ฐ›์Œ
  • train_df / train_size , val_df / val_size , test_df / test_size ์„ ์ •์˜
  • ๋‚ด๋ถ€์— def sort_key(item) ๋ฉ”์„œ๋“œ ์กด์žฌ
    • ํด๋ž˜์Šค ๊ฐ€์ค‘์น˜๋ฅผ ๊ตฌํ•˜๊ธฐ ์œ„ํ•ด์„œ ์‚ฌ์šฉ
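The inverse-frequency weighting at the end of __init__ can be tried standalone (the counts below are hypothetical; the real code additionally sorts the counts by vocabulary index first, so each weight lines up with its label index):

```python
import torch

# Hypothetical class counts, already ordered by label index
class_counts = {'english': 2700, 'russian': 2100, 'arabic': 1400}
frequencies = list(class_counts.values())

# Inverse-frequency weighting: rarer classes get larger weights,
# which the loss function can later use to counter the imbalance
class_weights = 1.0 / torch.tensor(frequencies, dtype=torch.float32)
```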

2. The load_dataset_and_make_vectorizer method

    @classmethod
    def load_dataset_and_make_vectorizer(cls, surname_csv):
        # cls refers to the SurnameDataset class itself
        """ Loads the dataset and creates a new SurnameVectorizer object
        
        Args:
            surname_csv (str): location of the dataset
        Returns:
            an instance of SurnameDataset
        """
        surname_df = pd.read_csv(surname_csv)
        train_surname_df = surname_df[surname_df.split=='train']
        # SurnameVectorizer.from_dataframe builds surname_vocab and nationality_vocab
        return cls(surname_df, SurnameVectorizer.from_dataframe(train_surname_df))
  • ๋ฐ˜ํ™˜๊ฐ’์ด SurnameDataset์˜ ์ธ์Šคํ„ด์Šค
    • ์•„๋ž˜์—์„œ ๋ณด๋ฉด ๋ฐ์ดํ„ฐ์…‹์„ ์ƒˆ๋กœ ๋งŒ๋“ค ๋•Œ ํ•ด๋‹น ๋ฉ”์„œ๋“œ๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค.
    • dataset = SurnameDataset.load_dataset_and_make_vectorizer(train_df)
    • ํ•ด๋‹น ๋ฐ˜ํ™˜๊ฐ’์€ surname_vocab / nationality_vocab ์ด๋ฏ€๋กœ ๋ฐ์ดํ„ฐ์…‹์ด ๋งŒ๋“ค์–ด์ง€๋Š” ๊ฒƒ์ด๋‹ค.
    • ๊ทธ๋ž˜์„œ SurnameDataset์˜ ์ธ์Šคํ„ด์Šค๋ผ๊ณ  ํ•œ ๊ฒƒ์ด๋‹ค.

3. The load_dataset_and_load_vectorizer method

    @classmethod
    def load_dataset_and_load_vectorizer(cls, surname_csv, vectorizer_filepath):
        """Loads the dataset and the corresponding SurnameVectorizer object.
        Used when a cached SurnameVectorizer object is being reused.
        
        Args:
            surname_csv (str): location of the dataset
            vectorizer_filepath (str): location of the saved SurnameVectorizer object
        Returns:
            an instance of SurnameDataset
        """
        surname_df = pd.read_csv(surname_csv)
        # load_vectorizer_only: loads a SurnameVectorizer object from a file.
        # Its return value is SurnameVectorizer.from_serializable(json.load(fp)),
        # i.e. an instance of SurnameVectorizer.
        vectorizer = cls.load_vectorizer_only(vectorizer_filepath)
        # The vectorizer then becomes part of the returned Dataset instance
        return cls(surname_df, vectorizer)
  • It takes vectorizer_filepath, restores the surname_vocab / nationality_vocab to rebuild a SurnameVectorizer instance, and the final return statement wraps everything into an instance of SurnameDataset (cls refers to SurnameDataset).

4. The load_vectorizer_only / save_vectorizer methods

  • load_vectorizer_only loads the vectorizer object from a file; think of it as restoring the object that defines surname_vocab / nationality_vocab
  • save_vectorizer writes the vocab to a file as a serialized dict
    @staticmethod
    def load_vectorizer_only(vectorizer_filepath):
        """Static method that loads a SurnameVectorizer object from a file
        
        Args:
            vectorizer_filepath (str): location of the serialized SurnameVectorizer object
        Returns:
            an instance of SurnameVectorizer
        """
        with open(vectorizer_filepath) as fp:
            # from_serializable turns the serialized dicts back into vocab objects
            # (surname_vocab / nationality_vocab) and returns a vectorizer instance
            return SurnameVectorizer.from_serializable(json.load(fp))

    def save_vectorizer(self, vectorizer_filepath):
        """ Saves the SurnameVectorizer object to disk as JSON
        
        Args:
            vectorizer_filepath (str): location to save the SurnameVectorizer object
        """
        with open(vectorizer_filepath, "w") as fp:
            # to_serializable turns the vocab into a serializable dict
            json.dump(self._vectorizer.to_serializable(), fp)
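The save/load pair boils down to a JSON round trip. A stand-in sketch using a plain dict in place of the real to_serializable() output (the dict keys below are assumptions for illustration, not the book's exact serialization format):

```python
import json
import os
import tempfile

# Stand-in for vectorizer.to_serializable(): everything reduced
# to JSON-friendly dicts
serialized = {
    'surname_vocab': {'token_to_idx': {'a': 0, 'b': 1}},
    'nationality_vocab': {'token_to_idx': {'english': 0, 'russian': 1}},
}

path = os.path.join(tempfile.mkdtemp(), 'vectorizer.json')

# save_vectorizer: dump the serializable dict to disk
with open(path, 'w') as fp:
    json.dump(serialized, fp)

# load_vectorizer_only: read it back; from_serializable would then
# rebuild the Vocabulary objects from this dict
with open(path) as fp:
    restored = json.load(fp)
```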

 

5. get_vectorizer / set_split

    def get_vectorizer(self):
        """ Returns the vectorizer object """
        return self._vectorizer

    def set_split(self, split="train"):
        """ Selects the split using a column in the DataFrame
        
        Args:
            split (str): one of "train", "val", or "test"
        """
        self._target_split = split
        self._target_df, self._target_size = self._lookup_dict[split]
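set_split is just a dictionary lookup that swaps which DataFrame __len__ and __getitem__ will see. Stripped down to its essence (toy frames, not the real dataset):

```python
import pandas as pd

train_df = pd.DataFrame({'surname': ['a', 'b', 'c']})
val_df = pd.DataFrame({'surname': ['d']})

lookup = {'train': (train_df, len(train_df)),
          'val': (val_df, len(val_df))}

# set_split('val') in the class performs exactly this reassignment
target_df, target_size = lookup['val']
```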

 

6. __len__ / __getitem__

  • Dataset์˜ ๊ธฐ๋ณธ ๊ตฌ์„ฑ โ˜…
  • __getitem__ : ๋ฒกํ„ฐ๋กœ ๋ฐ”๊พผ ์„ฑ์”จ์™€ ๊ตญ์ ์— ํ•ด๋‹นํ•˜๋Š” ์ธ๋ฑ์Šค๋ฅด ๋ฐ˜ํ™˜ํ•ด์ค€๋‹ค.
    def __len__(self):
        return self._target_size

    def __getitem__(self, index):
        """ The main entry point method of a PyTorch dataset
        
        Args:
            index (int): index of the data point
        Returns:
            a dictionary holding the data point's features (x_surname) and label (y_nationality)
        """
        row = self._target_df.iloc[index]

        surname_vector = \
            self._vectorizer.vectorize(row.surname)

        nationality_index = \
            self._vectorizer.nationality_vocab.lookup_token(row.nationality)

        return {'x_surname': surname_vector,
                'y_nationality': nationality_index}

 

7. The get_num_batches method

    def get_num_batches(self, batch_size):
        """ Returns the number of batches the dataset can provide for a given batch size
        
        Args:
            batch_size (int)
        Returns:
            the number of batches
        """
        return len(self) // batch_size
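Because generate_batches further down uses drop_last=True, this integer division matches the DataLoader's behavior: any partial final batch is discarded. For example, with hypothetical sizes:

```python
# 1000 samples with batch size 64: 15 full batches, 40 leftovers dropped
num_samples, batch_size = 1000, 64
num_batches = num_samples // batch_size
```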

 

โ–ผ Full SurnameDataset code

class SurnameDataset(Dataset):
    def __init__(self, surname_df, vectorizer):
        """
        Args:
            surname_df (pandas.DataFrame): the dataset
            vectorizer (SurnameVectorizer): a SurnameVectorizer object
        """
        self.surname_df = surname_df
        self._vectorizer = vectorizer

        self.train_df = self.surname_df[self.surname_df.split=='train']
        self.train_size = len(self.train_df)

        self.val_df = self.surname_df[self.surname_df.split=='val']
        self.validation_size = len(self.val_df)

        self.test_df = self.surname_df[self.surname_df.split=='test']
        self.test_size = len(self.test_df)

        self._lookup_dict = {'train': (self.train_df, self.train_size),
                             'val': (self.val_df, self.validation_size),
                             'test': (self.test_df, self.test_size)}

        self.set_split('train')
        
        # Compute the class weights
        # Collect the nationality counts of surname_df as a dict
        class_counts = surname_df.nationality.value_counts().to_dict()
        def sort_key(item):
            # _vectorizer is the SurnameVectorizer object
            # nationality_vocab is the Vocabulary object mapping nationalities to integers (see SurnameVectorizer)
            # lookup_token is the method that retrieves the index for a token (see Vocabulary)
            return self._vectorizer.nationality_vocab.lookup_token(item[0])
        # Sort the (nationality, count) pairs by their vocabulary index
        sorted_counts = sorted(class_counts.items(), key=sort_key)
        # Derive the class weights from the frequencies
        frequencies = [count for _, count in sorted_counts]
        self.class_weights = 1.0 / torch.tensor(frequencies, dtype=torch.float32)

    @classmethod
    def load_dataset_and_make_vectorizer(cls, surname_csv):
        # cls refers to the SurnameDataset class itself
        """ Loads the dataset and creates a new SurnameVectorizer object
        
        Args:
            surname_csv (str): location of the dataset
        Returns:
            an instance of SurnameDataset
        """
        surname_df = pd.read_csv(surname_csv)
        train_surname_df = surname_df[surname_df.split=='train']
        # SurnameVectorizer.from_dataframe builds surname_vocab and nationality_vocab.
        # Creating a dataset from scratch uses this method:
        # dataset = SurnameDataset.load_dataset_and_make_vectorizer(train_csv)
        # which builds the vocabularies and so returns an instance of the dataset.
        return cls(surname_df, SurnameVectorizer.from_dataframe(train_surname_df))

    @classmethod
    def load_dataset_and_load_vectorizer(cls, surname_csv, vectorizer_filepath):
        """Loads the dataset and the corresponding SurnameVectorizer object.
        Used when a cached SurnameVectorizer object is being reused.
        
        Args:
            surname_csv (str): location of the dataset
            vectorizer_filepath (str): location of the saved SurnameVectorizer object
        Returns:
            an instance of SurnameDataset
        """
        surname_df = pd.read_csv(surname_csv)
        # load_vectorizer_only: loads a SurnameVectorizer object from a file.
        # Its return value is SurnameVectorizer.from_serializable(json.load(fp)),
        # i.e. an instance of SurnameVectorizer.
        vectorizer = cls.load_vectorizer_only(vectorizer_filepath)
        # The vectorizer then becomes part of the returned Dataset instance
        return cls(surname_df, vectorizer)

    @staticmethod
    def load_vectorizer_only(vectorizer_filepath):
        """Static method that loads a SurnameVectorizer object from a file
        
        Args:
            vectorizer_filepath (str): location of the serialized SurnameVectorizer object
        Returns:
            an instance of SurnameVectorizer
        """
        with open(vectorizer_filepath) as fp:
            # from_serializable turns the serialized dicts back into vocab objects
            # (surname_vocab / nationality_vocab) and returns a vectorizer instance
            return SurnameVectorizer.from_serializable(json.load(fp))

    def save_vectorizer(self, vectorizer_filepath):
        """ Saves the SurnameVectorizer object to disk as JSON
        
        Args:
            vectorizer_filepath (str): location to save the SurnameVectorizer object
        """
        with open(vectorizer_filepath, "w") as fp:
            # to_serializable turns the vocab into a serializable dict
            json.dump(self._vectorizer.to_serializable(), fp)

    def get_vectorizer(self):
        """ Returns the vectorizer object """
        return self._vectorizer

    def set_split(self, split="train"):
        """ Selects the split using a column in the DataFrame
        
        Args:
            split (str): one of "train", "val", or "test"
        """
        self._target_split = split
        self._target_df, self._target_size = self._lookup_dict[split]

    def __len__(self):
        return self._target_size

    def __getitem__(self, index):
        """ The main entry point method of a PyTorch dataset
        
        Args:
            index (int): index of the data point
        Returns:
            a dictionary holding the data point's features (x_surname) and label (y_nationality)
        """
        row = self._target_df.iloc[index]

        surname_vector = \
            self._vectorizer.vectorize(row.surname)

        nationality_index = \
            self._vectorizer.nationality_vocab.lookup_token(row.nationality)

        return {'x_surname': surname_vector,
                'y_nationality': nationality_index}

    def get_num_batches(self, batch_size):
        """ Returns the number of batches the dataset can provide for a given batch size
        
        Args:
            batch_size (int)
        Returns:
            the number of batches
        """
        return len(self) // batch_size

    
def generate_batches(dataset, batch_size, shuffle=True,
                     drop_last=True, device="cpu"): 
    """
    Generator function wrapping the PyTorch DataLoader.
    Moves each tensor to the specified device.
    """
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=shuffle, drop_last=drop_last)

    for data_dict in dataloader:
        out_data_dict = {}
        for name, tensor in data_dict.items():
            out_data_dict[name] = data_dict[name].to(device)
        yield out_data_dict
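A quick sanity check of generate_batches with a toy map-style dataset (the toy class and its sizes are my own; any Dataset returning dicts works, since the default collate stacks each key). The function is repeated here so the snippet runs standalone:

```python
import torch
from torch.utils.data import Dataset, DataLoader

def generate_batches(dataset, batch_size, shuffle=True,
                     drop_last=True, device="cpu"):
    """Wrap a DataLoader and move every tensor in each batch to `device`."""
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=shuffle, drop_last=drop_last)
    for data_dict in dataloader:
        yield {name: tensor.to(device) for name, tensor in data_dict.items()}

class ToyDataset(Dataset):
    """Returns the same dict shape as SurnameDataset.__getitem__."""
    def __len__(self):
        return 10

    def __getitem__(self, index):
        return {'x_surname': torch.zeros(5), 'y_nationality': index % 3}

# 10 samples, batch size 4, drop_last=True -> 2 batches, 2 samples dropped
batches = list(generate_batches(ToyDataset(), batch_size=4, shuffle=False))
```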

 

<NEXT>

์ด๋ฒˆ ํฌ์ŠคํŒ…์—์„œ๋Š” ์„ฑ์”จ ๋ฐ์ดํ„ฐ์…‹์„ ํŒŒ์ดํ† ์น˜๋กœ ์‚ฌ์šฉํ•˜๊ธฐ์œ„ํ•ด custom Dataset ํด๋ž˜์Šค๋ฅผ ๋งŒ๋“  ์ฝ”๋“œ๋ฅผ ์‚ดํŽด๋ณด์•˜๋‹ค. ๋‹ค์Œ ํฌ์ŠคํŒ…์—์„œ๋Š” ํ•ด๋‹น custom Dataset์—์„œ ํ•„์š”๋กœ ํ•˜๋Š” ์•„๋ž˜ ๋‘ ๊ฐœ์˜ ํด๋ž˜์Šค๋ฅผ ๋‹ค๋ค„๋ณด๊ฒ ๋‹ค.

  • Vocabulary class : used to map each token to a corresponding integer
  • Vectorizer class : applies the Vocabulary to convert a surname string into a vector
๋ฐ˜์‘ํ˜•