[NLP] BERT code

2 minute read

BERT code - python

이번 포스팅에서는 본격적인 Pre-trained models 설명으로 넘어가기 전에 제가 작성해보았던 BERT 임베딩 코드를 공유해볼까 합니다. 영어 같은 경우는 example이 많아서 초보자도 사용하기 어렵지 않은데, 한국어 바탕 임베딩은 생각보다 코드 관련 자료가 많지 않아서 개인적으로 어려움을 좀 겪었습니다. KoBERT와 BERT multilingual version pre-trained model을 사용해 임베딩하는 코드이며, 튜토리얼 등을 참고하여 만들긴했지만 실습용으로 가볍게 짠 코드라 다소 지저분(ㅜㅜ)할 수는 있습니다. 감안하고 봐주세요!

## Import packages
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torch
from transformers import BertModel, DistilBertModel, BertTokenizer
from tokenization_kobert import KoBertTokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

## KoBERT (처음 써보는 분은 KoBERT git clone 필수!!)
class Embed_Kor:
    def __init__(self, pretrain_ver='monologg/kobert'):
        self.ver = pretrain_ver
        self.tokenizer = KoBertTokenizer.from_pretrained(self.ver)
        self.model, self.vocab = get_pytorch_kobert_model()
        
    ## Tokenization
    def tokenization_kor(self, sent):
        marked_text = '[CLS]' + sent + '[SEP]'
        tokenized_text = self.tokenizer.tokenize(marked_text)
        input_ids = self.tokenizer.convert_tokens_to_ids(tokenized_text)
        att_mask = self.tokenizer.get_special_tokens_mask(input_ids[1:-1])
        type_ids = self.tokenizer.create_token_type_ids_from_sequences(input_ids[1:-1])
        
        ## 512가 max_token 수라서, 그걸 넘어가면 앞 512개만 자르기
        ## 여기서는 편의상 앞을 잘랐지만, 보통 중간 부분을 잘라서 많이 쓴다고 합니다
        if len(input_ids) > 512 :
          input_ids = input_ids[len(input_ids)-512:]
          att_mask = att_mask[len(att_mask)-512:]
          type_ids = type_ids[len(type_ids)-512:]

        return input_ids, att_mask, type_ids
            
    ## Embedding
    def _transformer_kor(self, sent):
        input_ids, att_mask, type_ids = self.tokenization_kor(sent)
        input_ids = torch.LongTensor([input_ids])
        att_mask = torch.LongTensor([att_mask])
        type_ids = torch.LongTensor([type_ids])
        sequence_output, pooled_output = self.model(input_ids, att_mask, type_ids)
        final_embed = sequence_output[0]
        return final_embed

    if __name__=="__main__":
        print('Transformer-Korean ver. ready')

## BERT Multilingual ver.
class Embed_multi:
    def __init__(self, pretrain_ver='bert-base-multilingual-uncased', unit='sentence'):
        self.ver = pretrain_ver
        self.unit = unit
        self.tokenizer = BertTokenizer.from_pretrained(self.ver)
        self.model = BertModel.from_pretrained(self.ver, output_hidden_states = True)

    def tokenization_multi(self, sent):
        marked_text = '[CLS]' + sent + '[SEP]'
        tokenized_text = self.tokenizer.tokenize(marked_text)
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
        segment_ids = [1] * len(tokenized_text)
        
        ## 여기서는 중간 부분을 잘라서 써보았네용
        if len(indexed_tokens) > 512 :
          cut_end = (len(indexed_tokens)-512)//2
          cut_start = len(indexed_tokens) - 512 - cut_end
          indexed_tokens = indexed_tokens[cut_start:(len(indexed_tokens)-cut_end)]
          segment_ids = segment_ids[cut_start:(len(segment_ids)-cut_end)]

        return indexed_tokens, segment_ids

    def _transformer_multi(self, sent):
        indexed_tokens, segment_ids = self.tokenization_multi(sent)
        tokens_tensor = torch.tensor([indexed_tokens])
        segments_tensors = torch.tensor([segment_ids])
        with torch.no_grad():
            outputs = self.model(tokens_tensor, segments_tensors)
        hidden_states = outputs[2]
        
        ## Last 4 layers SUM 방식으로 Word embedding
        ## 여러 방식 종류에 대해서는 이전 How does BERT work? 포스팅 뒷부분 참조
        if self.unit == 'word':
            token_embeddings = torch.stack(hidden_states, dim=0)
            token_embeddings = torch.squeeze(token_embeddings, dim=1)
            token_embeddings = token_embeddings.permute(1,0,2)
            token_vecs_sum = []
            for token in token_embeddings:
                sum_vec = torch.sum((token[-4:]), dim=0)
                token_vecs_sum.append(sum_vec)
            return token_vecs_sum
                
        ## Sentence embedding
        elif self.unit == 'sentence':
            token_vecs = hidden_states[-2][0]
            sent_embed = torch.mean(token_vecs, dim=0)
            return sent_embed

    if __name__=="__main__":
        print('Transformer-Multilingual ver. ready')

위 코드는 한 개 단위 문장 임베딩을 기본으로 두고 짠 코드입니다. 그래서 일반화 하기에는 무리가 좀 있기는 합니다만 (직접 활용하려면 코드 일부를 고쳐야할 수 있습니당), KoBERT 및 BERT multilingual version과 같은 Pre-trained model을 어떻게 사용하여 Embedding을 뽑아내는 것인지 그 대강의 흐름을 이해하는 데 도움이 되었으면 좋겠습니다.

더 자세한 내용과 기타 환경 세팅은 이 곳 이나 이 곳 을 참고해주세요!