pyhaya’s diary

機械学習系の記事をメインで書きます

BERTによる日本語文書分類

自分の備忘録的な感じで雑にまとめます。やることは単純に自然言語からその文書のカテゴリーを分類するタスクです。

テーマ的には新しいものでもなく、すでにたくさんの良記事がある分野なので、詳しく説明するというよりやったことと結果を淡々と書いていくという感じにしてます。

実験環境


pyproject.toml

パッケージ管理にはRyeを使用。

[project]
name = "project"
version = "0.1.0"
description = "Add a short description here"
dependencies = [
    "transformers~=4.29.2",
    "pandas~=2.0.2",  # 無くても多分動く
    "numpy~=1.24.3",
    "torch~=2.0.1",
    "fugashi~=1.2.1",
    "ipadic~=1.0.0",
    "tqdm~=4.65.0",
    "scikit-learn~=1.2.2",
    "unidic-lite~=1.0.8",
    "sentencepiece~=0.1.99",
    "datasets~=2.12.0",
]
readme = "README.md"
requires-python = ">= 3.8"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.rye]
managed = true
dev-dependencies = ["flake8~=5.0.4", "black~=23.3.0", "ipython~=8.12.2"]

[tool.hatch.metadata]
allow-direct-references = true


データセットの確認

全部で7,367本のニュース記事が入ったデータセット。カテゴリは全9種類。使えるfeatureは以下の5つ

  • url
  • title
  • date
  • content
  • category

今回は軽く試したいだけなのでtitleからcategoryを当てに行くというタスク設定にしている。どのカテゴリもおおよそ690記事ほどが収録されている(category ID = 4 だけは400記事ほどと少ない)。データの偏りはあまり気にせずaccuracyで評価したら良さそう。

実験コード

from datasets import load_dataset
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.modeling_utils import PreTrainedModel
from transformers import (
    AutoTokenizer,
    BertTokenizer,
    BertForSequenceClassification,
    DebertaV2Tokenizer,
    DebertaForSequenceClassification,
    RobertaTokenizer,
    RobertaForSequenceClassification,
    DistilBertForSequenceClassification,
    AlbertForSequenceClassification,
)
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# PRETRAINED = "cl-tohoku/bert-large-japanese-v2"
# PRETRAINED = "ku-nlp/deberta-v2-large-japanese"
# PRETRAINED = "rinna/japanese-roberta-base"
# PRETRAINED = "line-corporation/line-distilbert-base-japanese"
PRETRAINED = "ken11/albert-base-japanese-v1"
DATASET = "shunk031/livedoor-news-corpus"
CLASS_NUM = 9
NUM_EPOCHS = 10
BATCH_SIZE = 128
LEARNING_RATE = 1e-5


def train(
    dataset: dict[str, Dataset],
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    optimizer: torch.optim.Optimizer,
    device: str = "cpu",
):
    model.to(device)

    train_loader = DataLoader(dataset["train"], batch_size=BATCH_SIZE, shuffle=True)
    valid_loader = DataLoader(
        dataset["validation"], batch_size=BATCH_SIZE, shuffle=True
    )

    best_loss = 1e9
    for epoch in range(NUM_EPOCHS):
        model.train()

        for data in tqdm(train_loader):
            inputs = tokenizer(data["title"], padding=True, return_tensors="pt").to(
                device
            )
            category = data["category"].to(device)

            optimizer.zero_grad()

            outputs = model(**inputs, labels=category)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

        print("Evaluate train data:")
        _ = eval_loss_acc(train_loader, model, tokenizer, device)
        print("Evaluate valid data:")
        val_loss, _ = eval_loss_acc(valid_loader, model, tokenizer, device)

        if val_loss < best_loss:
            print(f"Loss updated {best_loss:.4f} → {val_loss:.4f}. Save best model")
            best_loss = val_loss
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                f"./outputs/bert_checkpoint_best.pt",
            )
        else:
            print(f"Loss did not improve from {best_loss:.4f}")


def eval_loss_acc(
    loader: DataLoader,
    model: BertForSequenceClassification,
    tokenizer: BertTokenizer,
    device: str = "cpu",
) -> tuple[float, float]:
    model.to(device)
    model.eval()

    loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data in tqdm(loader):
            inputs = tokenizer(data["title"], padding=True, return_tensors="pt").to(
                device
            )
            category = data["category"].to(device)

            outputs = model(**inputs, labels=category)
            loss += outputs.loss
            correct += (outputs.logits.argmax(dim=-1) == category).sum().item()
            total += len(category)

    print(f"loss: {loss / len(loader):.4f}, accuracy: {correct / total:.4f}")

    return loss / len(loader), correct / total


if __name__ == "__main__":
    dataset = load_dataset(
        DATASET,
        train_ratio=0.8,
        val_ratio=0.1,
        test_ratio=0.1,
        random_state=42,
        shuffle=True,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(
        PRETRAINED, use_fast=False, trust_remote_code=True
    )
    # ここを色々変えていく
    model = AlbertForSequenceClassification.from_pretrained(
        PRETRAINED, num_labels=CLASS_NUM, ignore_mismatched_sizes=True
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    print("Start training...")
    train(dataset, model, tokenizer, optimizer, device)

    print("Evaluate test data:")
    model.load_state_dict(
        torch.load("./outputs/bert_checkpoint_best.pt")["model_state_dict"]
    )

    test_loader = DataLoader(dataset["test"], batch_size=BATCH_SIZE, shuffle=True)
    eval_loss_acc(test_loader, model, tokenizer, device)

結果

いくつかのモデルを試して、タイトルだけでどれくらいの性能が出るのかぱっと確認する。すべての実験で共通している部分は、

  • epoch数:10
  • learning rate:1e-5

結果は以下の表の通り。ここでlossと言っているのは "CrossEntropyLoss" を指している。
github.com

model batch size loss accuracy time per epoch
cl-tohoku/bert-base-japanese-v2 128 0.5129 0.8410 1 min
cl-tohoku/bert-large-japanese-v2 32 0.4356 0.8628 3 min
rinna/japanese-roberta-base 128 0.4787 0.8655 45 sec
line-corporation/line-distilbert-base-japanese 128 0.3512 0.8872 25 sec
ken11/albert-base-japanese-v1 128 0.4384 0.8668 40 sec
ku-nlp/deberta-v2-base-japanese 32 0.6503 0.7976 1 min
ku-nlp/deberta-v2-large-japanese 16 0.7890 0.7405 3.5 min

感想

transformersライブラリが便利&使ったデータセットが小さかったのでパッと性能を比較して感覚を掴めたのは良かったと思ってます。あと、DeBERTaV2であまり性能が出なかったのは意外で、2epoch目くらいでサチったのでデータセットの大きさがモデルの大きさに比べてあまりに小さすぎたのかななどと想像してます。(ハイパラチューンもろくにしてないのでそこらへんの可能性も十分ありますが)

もしよかったらTwitterのフォローもお願いします!
twitter.com