
[NLP] Classifying Restaurant Review Sentiment (3) (feat. Natural Language Processing with PyTorch) - Training and Evaluation, Inference, Analysis

๊ฐ์ž ๐Ÿฅ” 2021. 7. 22. 19:58
๋ฐ˜์‘ํ˜•

-- ๋ณธ ํฌ์ŠคํŒ…์€ ํŒŒ์ดํ† ์น˜๋กœ ๋ฐฐ์šฐ๋Š” ์ž์—ฐ์–ด ์ฒ˜๋ฆฌ (ํ•œ๋น›๋ฏธ๋””์–ด) ์ฑ…์„ ์ฐธ๊ณ ํ•ด์„œ ์ž‘์„ฑ๋œ ๊ธ€์ž…๋‹ˆ๋‹ค.
-- ์†Œ์Šค์ฝ”๋“œ ) https://github.com/rickiepark/nlp-with-pytorch

 


 

<PREVIOUS> 

https://didu-story.tistory.com/83?category=952805 

 


https://didu-story.tistory.com/86?category=952805 

 

 

▶ Classifying Restaurant Review Sentiment

์•ž์˜ (1) , (2) ํฌ์ŠคํŒ…์—์„œ ๋ฐ์ดํ„ฐ๋ฅผ ์ „์ฒ˜๋ฆฌํ•˜๊ณ , ๋ฐ์ดํ„ฐ๋ฅผ ํŒŒ์ดํ† ์น˜์—์„œ ํ™œ์šฉ ๊ฐ€๋Šฅํ•˜๊ฒŒ ๋งŒ๋“ค์–ด์ฃผ๋Š” ์—ฌ๋Ÿฌ๊ฐ€์ง€ ํด๋ž˜์Šค์— ๋Œ€ํ•ด์„œ ์‚ดํŽด๋ณด์•˜๋‹ค. (์—ฌ๊ธฐ ์ดํ•ดํ•˜๋Š”๋ฐ ๊ฐœ์˜ค๋ž˜๊ฑธ๋ฆผ ;;...)

์ด์ œ ๊ฐ„๋‹จํ•œ ํผ์…‰ํŠธ๋ก  ๋ชจ๋ธ์„ ํ™œ์šฉํ•ด์„œ ๋ณธ๊ฒฉ์ ์ธ ๊ฐ์„ฑ๋ถ„๋ฅ˜๋ฅผ ์ง„ํ–‰ํ•ด๋ณด์ž.

 

1. Defining the Perceptron Classifier

  • The ReviewClassifier class inherits from PyTorch's Module class and creates a single Linear layer that produces a single output.
  • The sigmoid function is used as the final nonlinear activation.
  • forward() method
    • Takes a parameter that optionally applies the sigmoid function.
    • For binary classification, binary cross-entropy loss (BCELoss) is the most appropriate choice, but applying a sigmoid and then the loss function separately can cause numerical stability issues.
    • PyTorch provides BCEWithLogitsLoss(), which works without the sigmoid and computes the loss in a numerically stable way.
import torch
import torch.nn as nn

class ReviewClassifier(nn.Module):
    """ A simple perceptron-based classifier """
    def __init__(self, num_features):
        """
        Args:
            num_features (int): size of the input feature vector
        """
        super(ReviewClassifier, self).__init__()
        self.fc1 = nn.Linear(in_features=num_features, 
                             out_features=1)

    def forward(self, x_in, apply_sigmoid=False):
        """ The forward pass of the classifier
        
        Args:
            x_in (torch.Tensor): an input data tensor 
                x_in.shape should be (batch, num_features)
            apply_sigmoid (bool): flag for the sigmoid activation
                should be False if used with the cross-entropy loss
        Returns:
            the resulting tensor. tensor.shape is (batch,)
        """
        y_out = self.fc1(x_in).squeeze()
        ## optionally apply the sigmoid function
        if apply_sigmoid:
            y_out = torch.sigmoid(y_out)
        return y_out
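The squeeze() call in forward() is what collapses the (batch, 1) output of the Linear layer into the (batch,) shape expected by the loss function. A quick shape check (the feature size of 10 and batch of 4 are arbitrary):

```python
import torch
import torch.nn as nn

fc1 = nn.Linear(in_features=10, out_features=1)
x_in = torch.randn(4, 10)       # a batch of 4 feature vectors
y_out = fc1(x_in).squeeze()     # (4, 1) -> (4,)
print(y_out.shape)              # torch.Size([4])
```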

 

2. ๋ชจ๋ธ ํ›ˆ๋ จ

2.1 ํผ์…‰ํŠธ๋ก  ๋ถ„๋ฅ˜๊ธฐ๋ฅผ ์œ„ํ•œ ํ•˜์ดํผ ํŒŒ๋ผ๋ฏธํ„ฐ์™€ ํ”„๋กœ๊ทธ๋žจ ์˜ต์…˜์„ ์„ค์ •

from argparse import Namespace

args = Namespace(
    # Data and path information
    frequency_cutoff=25,
    model_state_file='model.pth',
    review_csv='data/yelp/reviews_with_splits_lite.csv',
    # review_csv='data/yelp/reviews_with_splits_full.csv',
    save_dir='model_storage/ch3/yelp/',
    vectorizer_file='vectorizer.json',
    # No model hyperparameters
    # Training hyperparameters
    batch_size=128,
    early_stopping_criteria=5,
    learning_rate=0.001,
    num_epochs=100,
    seed=1337,
    # Runtime options
    catch_keyboard_interrupt=True,
    cuda=True,
    expand_filepaths_to_save_dir=True,
    reload_from_files=False,
)

 

2.2 ๋ฐ์ดํ„ฐ์…‹, ๋ชจ๋ธ, ์†์‹ค, ์˜ตํ‹ฐ๋งˆ์ด์ €, ํ›ˆ๋ จ์ƒํƒœ ๋”•์…”๋„ˆ๋ฆฌ ์ƒ์„ฑ

import torch
import torch.nn as nn
import torch.optim as optim

def make_train_state(args):
    return {'epoch_index': 0,
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'test_loss': -1,
            'test_acc': -1}

train_state = make_train_state(args)

if not torch.cuda.is_available():
    args.cuda = False
args.device = torch.device("cuda" if args.cuda else "cpu")

# dataset and Vectorizer
dataset = ReviewDataset.load_dataset_and_make_vectorizer(args.review_csv)
vectorizer = dataset.get_vectorizer()

# model
classifier = ReviewClassifier(num_features=len(vectorizer.review_vocab))
classifier = classifier.to(args.device)

# loss function and optimizer
loss_func = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(classifier.parameters(), lr=args.learning_rate)
  • args๊ฐ์ฒด๋ฅผ ๋งค๊ฐœ๋ณ€์ˆ˜๋กœ ๋ฐ›์•„์„œ ํ›ˆ๋ จ ์ƒํƒœ๋ฅผ ์ดˆ๊ธฐํ™”ํ•˜๋Š” ํ•จ์ˆ˜ ์ž‘์„ฑ (make_train_state)
  • ๋ฐ์ดํ„ฐ์…‹๊ณผ ๋ชจ๋ธ์ƒ์„ฑ
    • ReviewDataset ํด๋ž˜์Šค์—์„œ vectorizer๋ฅผ ์ƒ์„ฑํ•ด์ค€๋‹ค.
  • ์†์‹คํ•จ์ˆ˜๋Š” BCEWithLogistitsLoss() ์‚ฌ์šฉ
  • ์˜ตํ‹ฐ๋งˆ์ด์ €๋Š” Adam

 

2.3 The training loop

    # ์—ํฌํฌ ํšŸ์ˆ˜๋งŒํผ for๋ฌธ์„ ๋ฐ˜๋ณตํ•  ๊ฒƒ์ด๋‹ค. (args์—์„œ ์ •์˜)
    for epoch_index in range(args.num_epochs):
        train_state['epoch_index'] = epoch_index

        # ํ›ˆ๋ จ ์„ธํŠธ์— ๋Œ€ํ•œ ์ˆœํšŒ
        # ํ›ˆ๋ จ ์„ธํŠธ์™€ ๋ฐฐ์น˜ ์ œ๋„ˆ๋ ˆ์ดํ„ฐ ์ค€๋น„, ์†์‹ค๊ณผ ์ •ํ™•๋„๋ฅผ 0์œผ๋กœ ์„ค์ •
        dataset.set_split('train')
        batch_generator = generate_batches(dataset, 
                                           batch_size=args.batch_size, 
                                           device=args.device)
        running_loss = 0.0
        running_acc = 0.0
        classifier.train()

        for batch_index, batch_dict in enumerate(batch_generator):
            # ํ›ˆ๋ จ ๊ณผ์ •์€ 5๋‹จ๊ณ„๋กœ ์ด๋ฃจ์–ด์ง‘๋‹ˆ๋‹ค

            # --------------------------------------
            # ๋‹จ๊ณ„ 1. ๊ทธ๋ ˆ์ด๋””์–ธํŠธ๋ฅผ 0์œผ๋กœ ์ดˆ๊ธฐํ™”ํ•ฉ๋‹ˆ๋‹ค
            optimizer.zero_grad()

            # ๋‹จ๊ณ„ 2. ์ถœ๋ ฅ์„ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค
            y_pred = classifier(x_in=batch_dict['x_data'].float())

            # ๋‹จ๊ณ„ 3. ์†์‹ค์„ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค
            loss = loss_func(y_pred, batch_dict['y_target'].float())
            loss_t = loss.item()
            running_loss += (loss_t - running_loss) / (batch_index + 1)

            # ๋‹จ๊ณ„ 4. ์†์‹ค์„ ์‚ฌ์šฉํ•ด ๊ทธ๋ ˆ์ด๋””์–ธํŠธ๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค
            loss.backward()

            # ๋‹จ๊ณ„ 5. ์˜ตํ‹ฐ๋งˆ์ด์ €๋กœ ๊ฐ€์ค‘์น˜๋ฅผ ์—…๋ฐ์ดํŠธํ•ฉ๋‹ˆ๋‹ค
            optimizer.step()
            # -----------------------------------------
            
            # ์ •ํ™•๋„๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค
            acc_t = compute_accuracy(y_pred, batch_dict['y_target'])
            running_acc += (acc_t - running_acc) / (batch_index + 1)

            # ์ง„ํ–‰ ๋ฐ” ์—…๋ฐ์ดํŠธ
            train_bar.set_postfix(loss=running_loss, 
                                  acc=running_acc, 
                                  epoch=epoch_index)
            train_bar.update()

        train_state['train_loss'].append(running_loss)
        train_state['train_acc'].append(running_acc)

        # ๊ฒ€์ฆ ์„ธํŠธ์— ๋Œ€ํ•œ ์ˆœํšŒ

        # ๊ฒ€์ฆ ์„ธํŠธ์™€ ๋ฐฐ์น˜ ์ œ๋„ˆ๋ ˆ์ดํ„ฐ ์ค€๋น„, ์†์‹ค๊ณผ ์ •ํ™•๋„๋ฅผ 0์œผ๋กœ ์„ค์ •
        dataset.set_split('val')
        batch_generator = generate_batches(dataset, 
                                           batch_size=args.batch_size, 
                                           device=args.device)
        running_loss = 0.
        running_acc = 0.
        classifier.eval()

        for batch_index, batch_dict in enumerate(batch_generator):

            # ๋‹จ๊ณ„ 1. ์ถœ๋ ฅ์„ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค
            y_pred = classifier(x_in=batch_dict['x_data'].float())

            # ๋‹จ๊ณ„ 2. ์†์‹ค์„ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค
            loss = loss_func(y_pred, batch_dict['y_target'].float())
            loss_t = loss.item()
            running_loss += (loss_t - running_loss) / (batch_index + 1)

            # ๋‹จ๊ณ„ 3. ์ •ํ™•๋„๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค
            acc_t = compute_accuracy(y_pred, batch_dict['y_target'])
            running_acc += (acc_t - running_acc) / (batch_index + 1)
            
            val_bar.set_postfix(loss=running_loss, 
                                acc=running_acc, 
                                epoch=epoch_index)
            val_bar.update()

        train_state['val_loss'].append(running_loss)
        train_state['val_acc'].append(running_acc)
  • Inner for loop: iterates over the mini-batches.
    • Per mini-batch: prediction - loss computation - accuracy computation
  • Outer for loop: repeats the inner loop many times. In the inner loop, the loss is computed for each mini-batch and the optimizer updates the model parameters.
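The compute_accuracy helper used in the loops above is not defined in this post; it comes from the book's repository. A sketch consistent with how it is used here — accuracy as a percentage, with raw logits thresholded at 0.5 after a sigmoid (the threshold is my assumption):

```python
import torch

def compute_accuracy(y_pred, y_target):
    """Accuracy (%) for binary classification on raw logits (assumed 0.5 threshold)."""
    y_pred_labels = (torch.sigmoid(y_pred) > 0.5).long()
    n_correct = torch.eq(y_pred_labels, y_target.long()).sum().item()
    return n_correct / len(y_pred_labels) * 100

logits = torch.tensor([2.0, -3.0, 0.5, -0.1])   # predictions: 1, 0, 1, 0
targets = torch.tensor([1.0, 0.0, 0.0, 0.0])    # 3 out of 4 correct
print(compute_accuracy(logits, targets))        # 75.0
```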

 

 

2.4 Evaluation, Inference, and Analysis

2.4.1 Evaluating on the test data

  • Set the split to test instead of val
dataset.set_split('test')
batch_generator = generate_batches(dataset, 
                                   batch_size=args.batch_size, 
                                   device=args.device)
running_loss = 0.
running_acc = 0.
classifier.eval()

for batch_index, batch_dict in enumerate(batch_generator):

    # step 1. compute the output
    y_pred = classifier(x_in=batch_dict['x_data'].float())

    # step 2. compute the loss
    loss = loss_func(y_pred, batch_dict['y_target'].float())
    loss_batch = loss.item()
    running_loss += (loss_batch - running_loss) / (batch_index + 1)

    # step 3. compute the accuracy
    acc_batch = compute_accuracy(y_pred, batch_dict['y_target'])
    running_acc += (acc_batch - running_acc) / (batch_index + 1)
    
# test_loss and test_acc are scalars in train_state, so assign rather than append
train_state['test_loss'] = running_loss
train_state['test_acc'] = running_acc

print("Test loss: {:.3f}".format(train_state['test_loss']))
print("Test Accuracy: {:.2f}".format(train_state['test_acc']))
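The running_loss update used throughout (running += (x - running) / (i + 1)) is an incremental mean: after N batches it equals the plain average of the N batch losses, without storing them all. A quick check with made-up batch losses:

```python
losses = [0.9, 0.7, 0.4, 0.2]   # hypothetical per-batch losses

running = 0.0
for i, x in enumerate(losses):
    running += (x - running) / (i + 1)   # incremental mean update

# matches the ordinary average up to floating-point error
print(abs(running - sum(losses) / len(losses)) < 1e-9)  # True
```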

 

2.4.2 Classifying new data points by inference

# ์ •๊ทœ์‹์„ ์‚ฌ์šฉํ•˜์—ฌ text๋ฅผ ํ† ํฐํ™”
def preprocess_text(text):
    text = text.lower()
    text = re.sub(r"([.,!?])", r" \1 ", text)
    text = re.sub(r"[^a-zA-Z.,!?]+", r" ", text)
    return text
def predict_rating(review, classifier, vectorizer, decision_threshold=0.5):
    """ ๋ฆฌ๋ทฐ ์ ์ˆ˜ ์˜ˆ์ธกํ•˜๊ธฐ
    
    ๋งค๊ฐœ๋ณ€์ˆ˜:
        review (str): ๋ฆฌ๋ทฐ ํ…์ŠคํŠธ
        classifier (ReviewClassifier): ํ›ˆ๋ จ๋œ ๋ชจ๋ธ
        vectorizer (ReviewVectorizer): Vectorizer ๊ฐ์ฒด
        decision_threshold (float): ํด๋ž˜์Šค๋ฅผ ๋‚˜๋ˆŒ ๊ฒฐ์ • ๊ฒฝ๊ณ„
    """
    review = preprocess_text(review)
    
    vectorized_review = torch.tensor(vectorizer.vectorize(review))
    result = classifier(vectorized_review.view(1, -1))
    
    probability_value = torch.sigmoid(result).item()
    index = 1
    if probability_value < decision_threshold:
        index = 0

    return vectorizer.rating_vocab.lookup_index(index)
test_review = "this is a pretty awesome book"

#์œ„์˜ ๋ฌธ์žฅ์„ ์˜ˆ์ธกํ•ด์„œ ๋ถ„๋ฅ˜ํ•ด๋ณด๊ธฐ
classifier = classifier.cpu()
prediction = predict_rating(test_review, classifier, vectorizer, decision_threshold=0.5)
print("{} -> {}".format(test_review, prediction))
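As a sanity check on the regex preprocessing: preprocess_text pads punctuation with spaces and collapses everything except letters and basic punctuation into single spaces, so a raw review splits cleanly on whitespace:

```python
import re

def preprocess_text(text):
    text = text.lower()
    text = re.sub(r"([.,!?])", r" \1 ", text)      # pad punctuation with spaces
    text = re.sub(r"[^a-zA-Z.,!?]+", r" ", text)   # collapse everything else
    return text

print(preprocess_text("Great food!!").split())
# ['great', 'food', '!', '!']
```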

 

2.4.3 ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ๋ถ„์„

ํ›ˆ๋ จ์ด ๋๋‚œ ๋’ค ๋ชจ๋ธ์ด ์ž˜ ์ž‘๋™ํ•˜๋Š”์ง€ ์•Œ์•„๋ณด๊ธฐ ์œ„ํ•ด ๊ฐ€์ค‘์น˜๋ฅผ ๋ถ„์„ํ•ด๋ณด์ž.

# sort the weights
fc1_weights = classifier.fc1.weight.detach()[0]
_, indices = torch.sort(fc1_weights, dim=0, descending=True)
indices = indices.numpy().tolist()

# top 20 positive words
print("Words influential for positive reviews:")
print("--------------------------------------")
for i in range(20):
    print(vectorizer.review_vocab.lookup_index(indices[i]))
    
print("====\n\n\n")

# top 20 negative words
print("Words influential for negative reviews:")
print("--------------------------------------")
indices.reverse()
for i in range(20):
    print(vectorizer.review_vocab.lookup_index(indices[i]))
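torch.sort returns both the sorted values and the positions they came from, and it is the second return value that lets us map from weight magnitude back to a vocabulary index. A tiny illustration with made-up weights:

```python
import torch

weights = torch.tensor([0.1, 0.9, -0.5, 0.3])
values, indices = torch.sort(weights, dim=0, descending=True)
# values:  [0.9, 0.3, 0.1, -0.5]
print(indices.tolist())  # [1, 3, 0, 2] -> original positions, largest first
```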

 

 

๋ฐ˜์‘ํ˜•