pyhaya’s diary

機械学習系の記事をメインで書きます

BERTによる日本語文書分類

自分の備忘録的な感じで雑にまとめます。やることは単純に自然言語からその文書のカテゴリーを分類するタスクです。

テーマ的には新しいものでもなく、すでにたくさんの良記事がある分野なので、詳しく説明するというよりやったことと結果を淡々と書いていくという感じにしてます。

実験環境


pyproject.toml

パッケージ管理にはRyeを使用。

[project]
name = "project"
version = "0.1.0"
description = "Add a short description here"
dependencies = [
    "transformers~=4.29.2",
    "pandas~=2.0.2",  # 無くても多分動く
    "numpy~=1.24.3",
    "torch~=2.0.1",
    "fugashi~=1.2.1",
    "ipadic~=1.0.0",
    "tqdm~=4.65.0",
    "scikit-learn~=1.2.2",
    "unidic-lite~=1.0.8",
    "sentencepiece~=0.1.99",
    "datasets~=2.12.0",
]
readme = "README.md"
requires-python = ">= 3.8"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.rye]
managed = true
dev-dependencies = ["flake8~=5.0.4", "black~=23.3.0", "ipython~=8.12.2"]

[tool.hatch.metadata]
allow-direct-references = true


データセットの確認

全部で7,367本のニュース記事が入ったデータセット。カテゴリは全9種類。使えるfeatureは以下の5つ

  • url
  • title
  • date
  • content
  • category

今回は軽く試したいだけなのでtitleからcategoryを当てに行くというタスク設定にしている。どのカテゴリもおおよそ690記事ほどが収録されている(category ID = 4 だけは400記事ほどと少ない)。データの偏りはあまり気にせずaccuracyで評価したら良さそう。

実験コード

from datasets import load_dataset
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.modeling_utils import PreTrainedModel
from transformers import (
    AutoTokenizer,
    BertTokenizer,
    BertForSequenceClassification,
    DebertaV2Tokenizer,
    DebertaForSequenceClassification,
    RobertaTokenizer,
    RobertaForSequenceClassification,
    DistilBertForSequenceClassification,
    AlbertForSequenceClassification,
)
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# PRETRAINED = "cl-tohoku/bert-large-japanese-v2"
# PRETRAINED = "ku-nlp/deberta-v2-large-japanese"
# PRETRAINED = "rinna/japanese-roberta-base"
# PRETRAINED = "line-corporation/line-distilbert-base-japanese"
PRETRAINED = "ken11/albert-base-japanese-v1"
DATASET = "shunk031/livedoor-news-corpus"
CLASS_NUM = 9
NUM_EPOCHS = 10
BATCH_SIZE = 128
LEARNING_RATE = 1e-5


def train(
    dataset: dict[str, Dataset],
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    optimizer: torch.optim.Optimizer,
    device: str = "cpu",
):
    model.to(device)

    train_loader = DataLoader(dataset["train"], batch_size=BATCH_SIZE, shuffle=True)
    valid_loader = DataLoader(
        dataset["validation"], batch_size=BATCH_SIZE, shuffle=True
    )

    best_loss = 1e9
    for epoch in range(NUM_EPOCHS):
        model.train()

        for data in tqdm(train_loader):
            inputs = tokenizer(data["title"], padding=True, return_tensors="pt").to(
                device
            )
            category = data["category"].to(device)

            optimizer.zero_grad()

            outputs = model(**inputs, labels=category)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

        print("Evaluate train data:")
        _ = eval_loss_acc(train_loader, model, tokenizer, device)
        print("Evaluate valid data:")
        val_loss, _ = eval_loss_acc(valid_loader, model, tokenizer, device)

        if val_loss < best_loss:
            print(f"Loss updated {best_loss:.4f} → {val_loss:.4f}. Save best model")
            best_loss = val_loss
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                f"./outputs/bert_checkpoint_best.pt",
            )
        else:
            print(f"Loss did not improve from {best_loss:.4f}")


def eval_loss_acc(
    loader: DataLoader,
    model: BertForSequenceClassification,
    tokenizer: BertTokenizer,
    device: str = "cpu",
) -> tuple[float, float]:
    model.to(device)
    model.eval()

    loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data in tqdm(loader):
            inputs = tokenizer(data["title"], padding=True, return_tensors="pt").to(
                device
            )
            category = data["category"].to(device)

            outputs = model(**inputs, labels=category)
            loss += outputs.loss
            correct += (outputs.logits.argmax(dim=-1) == category).sum().item()
            total += len(category)

    print(f"loss: {loss / len(loader):.4f}, accuracy: {correct / total:.4f}")

    return loss / len(loader), correct / total


if __name__ == "__main__":
    dataset = load_dataset(
        DATASET,
        train_ratio=0.8,
        val_ratio=0.1,
        test_ratio=0.1,
        random_state=42,
        shuffle=True,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(
        PRETRAINED, use_fast=False, trust_remote_code=True
    )
    # ここを色々変えていく
    model = AlbertForSequenceClassification.from_pretrained(
        PRETRAINED, num_labels=CLASS_NUM, ignore_mismatched_sizes=True
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    print("Start training...")
    train(dataset, model, tokenizer, optimizer, device)

    print("Evaluate test data:")
    model.load_state_dict(
        torch.load("./outputs/bert_checkpoint_best.pt")["model_state_dict"]
    )

    test_loader = DataLoader(dataset["test"], batch_size=BATCH_SIZE, shuffle=True)
    eval_loss_acc(test_loader, model, tokenizer, device)

結果

いくつかのモデルを試して、タイトルだけでどれくらいの性能が出るのかぱっと確認する。すべての実験で共通している部分は、

  • epoch数:10
  • learning rate:1e-5

結果は以下の表の通り。ここでlossと言っているのは "CrossEntropyLoss" を指している。
github.com

model batch size loss accuracy time per epoch
cl-tohoku/bert-base-japanese-v2 128 0.5129 0.8410 1 min
cl-tohoku/bert-large-japanese-v2 32 0.4356 0.8628 3 min
rinna/japanese-roberta-base 128 0.4787 0.8655 45 sec
line-corporation/line-distilbert-base-japanese 128 0.3512 0.8872 25 sec
ken11/albert-base-japanese-v1 128 0.4384 0.8668 40 sec
ku-nlp/deberta-v2-base-japanese 32 0.6503 0.7976 1 min
ku-nlp/deberta-v2-large-japanese 16 0.7890 0.7405 3.5 min

感想

transformersライブラリが便利&使ったデータセットが小さかったのでパッと性能を比較して感覚を掴めたのは良かったと思ってます。あと、DeBERTaV2であまり性能が出なかったのは意外で、2epoch目くらいでサチったのでデータセットの大きさがモデルの大きさに比べてあまりに小さすぎたのかななどと想像してます。(ハイパラチューンもろくにしてないのでそこらへんの可能性も十分ありますが)

もしよかったらTwitterのフォローもお願いします!
twitter.com

新卒データサイエンティストが 1 年目にやったこと

私は 2022 年 4 月にウォンテッドリーに新卒として入社しました。私は大学時代は実は物理学を専攻していたため、仕事は最初は慣れないことばかりで苦労しました。そこで、新卒 2 年目に入って少し落ち着いたタイミングで 1 年間の業務内容を振り返り、この記事を執筆することにしました。

この記事の内容は技術書典14で頒布したWANTEDLY TECHBOOK 12に掲載されている記事になります。TECHBOOK 12には他にも良い記事が沢山ありますので気になった方はぜひ!
techbookfest.org

1 年間で何をやってきたか

私はこの 1 年間、メインの仕事である推薦システムの改善はもちろん、推薦システムの改善を行うための知識を学ぶために様々なことに取り組んできました。メインタスクの説明の前にそれらについて軽く触れておくと、社内の機械学習勉強会に毎週参加したり、国内外の学会(JSAI、RecSys、DEIM)に参加してブログを執筆したりしていました(興味のある方は「RecSys2022 Wantedly」と検索してみてください。)。これらの取り組みは進歩の早い機械学習という領域で最新の知見を得てキャッチアップしていくことを1つの目的としていて、大学で情報科学を専攻していなかった私にとっては知識を増やす上で非常に有意義でした。

社内勉強会の話はまた後で触れるとして、私がやってきたメインの仕事に話を戻します。私はこの1年でやってきた仕事は、ざっくりとまとめると以下のようになります。

  • 仮説検証によるユーザ課題の定義
  • 実装
  • テスト(オフラインテスト・オンラインテスト)

これらは独立の作業ではなく、以下のように1つの施策の中でグルグルと繰り返していきます。

ユーザが持っている課題感についてなんとなくわかっている状態からはじめて、まずそのざっくりしたものを仮説検証によって解像度を上げていき、具体的な問題に落とし込みます。このフェーズでユーザがどんな問題になぜ困っているのかを理解できると、課題に対するソリューションをある程度絞ることができます。それを実際にコード上に実装するのが次のフェーズです。この段階では検討したソリューションが本当にユーザ課題を解決できるか確信度が十分には高くないので、ログデータを使って課題の解決が可能かどうか定量的・定性的に評価します(オフラインテスト)。ここで確信度が上がればユーザに実際に出してみて検証を行います(オンラインテスト)。テストがうまく行っても行かなくても、結果からは何かしら新しいユーザに対する知見が得られるのでそこから問題に取り組み直したり新しい課題を定義したりして施策を回していきます。

データサイエンティストが課題の分析・ソリューション検討・実装・テスト・テスト分析の一連のタスクをこなすのは、世間一般のデータサイエンティストとはもしかしたら乖離があるかもしれませんが、ウォンテッドリーでは一貫してデータサイエンティストが担当することによってプロダクトを使ってくれるユーザに対する理解が進み、推薦システムの改善にもとても良い効果があると感じています。

この 1 年間で苦労したこと

1 年間様々な施策に携わってきた中で、当たり前ですが様々な困難に直面しました。その困難は新卒1年目なら誰でも直面しそうなものから、私の場合には大学時代に専攻していた分野の違いによるものまで様々ありました。ここではそれらのいくつかを書いてみたいと思います。

仮説検証でハマったこと

仮説検証はデータサイエンティストのタスクとして最も頻度が高い部類に入るものだと思います。データを使ってユーザ課題を見つけて解決したいというときに、最後まで仮説を持たずに分析を進めることはありません。というより無理です。データを見てそこから仮説を立てて徐々にユーザの持っている課題や解決方法を明らかにしていきます。具体的には私達のチームでは多くの場合以下のような手順でユーザ課題を定義するための仮説検証を進めていきました。

1. ユーザの行動をサンプリングして見たり、軽くデータを見たりしてユーザがなぜどんな問題に困っているかについて仮説を挙げていく
2. 仮説をストーリとして整理する
3. ストーリに沿ってデータを使って検証する

最初のうちは 1 番の仮説だしの部分で苦労しました。ここは会社に入って間もない人はみなそうだと思うのですが、その会社特有の知識が圧倒的に足りない(&私の場合にはドメイン知識もほぼ皆無だった)ので、仮説の質が良くなかったためでした。ここを克服する方法は、今になって振り返ってみても大して近道は無いような気がしていますが、基本的な数字や生データをしっかり見ることが重要だと思っています。ここで言っている「基本的な数字」と言っているのはプロダクトを使っている月間ユーザ数などを指していて、これを把握するだけでぐっとユーザに対する解像度が上がります。

生データを見ることの重要性は、言わずもがなといった感じなのですが最初できていなかったところでした。私は大学時代に物理の実験系の研究室に所属していて普段から生データをチェックしてデータに問題がないか確かめる癖はついていたと思っていたのですが、業務で触れる生データは研究室の実験で得られるデータと比べて桁違いに多いため、ついつい生データではなく統計量をみるという方向に逃げてしまっていました。今になっては理解していますが、分析を行うときには、様々な統計的手法を使ってなめされたデータを見るよりも、生データを見てどのセグメントにどんな特徴がありそうか定性的にでも感覚を掴むほうがその後の分析のスピードが早くなります。

また、良い仮説を出すのに苦労していたときに注目していたもう1つの観点が仮説の網羅性でした。仮説検証はなにか課題があってそれを解決するために用いる手段です。なので出した仮説によって目的が果たせそうかどうかというのはとても重要な要素です。私はこの課題を克服するために仮説を木構造に分解して捉えるということをやってみました。


実装でハマったこと

実装で一番苦労したのが、コードの変更容易性や可読性といったコード品質を気にかけながら実装を進める部分でした。情報系のバックグラウンドを持っている方ならばこのあたりは当たり前なことだと思うのですが、入社するまで私はコードの品質にあまり留意してこなかったので入社してから苦労しました。研究室時代には実験の分析用にコードを書くことがほとんどで、使い捨てだったり、使い回すにしても自分しか使わないか後輩がたまに使うかなぐらいでした。研究室のほかメンバーもバリバリコードを書くような人種ではなかったので、動けば OK、動かなかったらエクセル使えばみたいな感じでした。

ここまでで十分想像がついたと思いますが、こういうコード品質とは無縁のコードを書くのに慣れていると、本番に投入できるような品質のコードを書くとなったときにとても苦労します。私の場合にはここでハマったときに最初、なるべくコードの変更が無いように実装をすればいいのではと安直に考えて取り組んでいたのですがすぐにこの方法は良くないと感じるようになりました。この方法だとコードの品質は上がりませんが複雑さは上がることが多いためです。例として、データに対してなにか処理をするメソッド process_data があったとします。そして今、新しく別の処理をするメソッドを作りたいがその中身は process_data と似ている部分が多かったとします。このとき、コードの変更量を最小限にするというポリシーで動くと以下のようなコードが出来上がるかもしれません。

def process_data(data):
    # ...

    if isinstance(data, str):
        data = json.loads(data)

    elif isinstance(data, dict):  # ← 追加!
        # ...
        if "name" in data:
            process_name(data["name"])
    elif isinstance(data, list):  # ← 追加!
        # ...
    elif isinstance(data, str):  # ← 追加!
        # ...

明らかにコード品質を悪化させていることがわかります。では最善の実装は何かと言ったら時と場合によるというのが正しいでしょう。既存のメソッドと新しく実装するメソッドで共通処理に名前がつけられるなら切り出してもいですし、上の例だと引数の型で分岐させているので新しく実装する部分で入力の型を揃えるという方法もあると思います。process_code というメソッドの名前の抽象度が高すぎるのが原因なのでより具体的な名前をつけ直して、新しいロジック用に新しいメソッドを定義するという方法もあるでしょう。結局いちばん大事なのは思考を機械的にすることなく、コードの品質を保持する・上げるためにどうしたらいいか考え続けることで、それからは本を読んで勉強したりレビューしてもらって考えたりすることを意識的にするようしています。

テストでハマったこと

エンジニアリングの文脈でテストというと指すものが人によって様々だと思いますが、ここではオフラインテスト・オンラインテストについて書きます。オフラインテストというのは、考えた施策が実際にユーザに良い効果をもたらしそうかどうかを過去のログデータなどを使って検証するフェーズです。ここで良さそうということになったら実際にユーザに出して効果があるか検証します(オンラインテスト)。見たい効果を正しく得るにはテスト設計がとても重要ですし、テスト結果が得られたときそれを正しく解釈するにも様々な知識が必要になります。テストをする、となったときに考える必要のあることはぱっと挙げただけでも、

  • テストの対象となるユーザはどんなユーザ?
  • テスト方法はどうする?
  • テスト期間はどのくらいに設定すべき?
  • どの指標がどのような条件を満たしていたら施策の効果があったと判断する?

など多くあります。

特に苦労したのが、「テスト対象ユーザ」の決定部分でした。施策は明確なターゲットユーザのもとに行われるものですが、介入がそのターゲットユーザだけに閉じるとは限りません。例として Wantedly Visit を考えてみます。Wantedly Visit では学生も中途ユーザも同一のプラットフォームで気になる募集を探して話を聞きに行くことができるのですが、ここで学生をターゲットとして体験を良くしたい場合を考えたとします。手段は色々考えられますが、機械学習モデルに新しい特徴量を追加して学生が話しを聞きに行きたいと思える募集をもっと出すようにしようとすると、このモデルの挙動は学生に対して変化するのはもちろん、中途ユーザに対しても変化します。このような状況でテストを学生ユーザだけに限定して進めてしまうと、極端な話、学生の利用は激増して中途ユーザは誰も利用しなくなるという状態になっていても気づけず、学生の結果だけ見て本番環境全体に適用させてしまうという大事故を引き起こしかねません。

そのため、ターゲットユーザはどんなユーザか、介入の影響を受けるユーザはどんなユーザか、の両方を把握しテストの通過条件をそれぞれのユーザセグメントに対して適切に設定する必要があります。このあたりの知見は(少なくとも自分は)大学での研究活動では縁があまりなかったのと、自分の設計があっているのか間違っているのか自己判断が難しいので苦労しました。なのでテスト設計のレビューを積極的にしてもらって、フィードバックから学んでいくというのを中心に、足りない知識を本で補うというような学び方をしていました。一番参考になっていたのは、よくカバ本と呼ばれている、「A/B テスト実践ガイド」でした。A/B テスト実践ガイドでは「テスト時のランダム化単位をしっかり考えて設計することの重要性」や「テストの信頼性を上げるためのガードレールメトリクスの重要性」など、基本的ですが大事なことを多く学べました。

やってみて良かったと感じたこと

ここまでは割と課題にぶつかってそれに対してどうしてきたかという課題ドリブンな話をしてきたのですが、ここでは自分の成長のために色々試してみたことのうち、やってよかったと思っていることをいくつか紹介します。やってみてよかったことで共通しているのは何かしらの「アウトプットをすること」です。普段勉強しているとなにかとインプットが多くなると思いますが、インプットしただけだと理解したつもりでいたけど実は自分の中でちゃんと使える形で整理できていないということがよく起こります。アウトプットして第三者に見せるときには人は必ず情報を整理したり足りない情報を調べたりします。この過程でわかったつもりでいたことに気づけたりして効率的に自身の成長に繋げられたと思っています。

社内勉強会への参加

ウォンテッドリー社内には有志による様々な勉強会が開かれており、私はデータサイエンティストで開催している機械学習輪講会に毎週参加していました。この輪講会では機械学習関連のブログを読んで他のメンバーに紹介したり、論文を読んで議論したりしていました。この勉強会に参加することは、自分で勉強したり情報収集するよりもはるかに効率的に自身の知識を広げられるので参加してみてよかったと思っています。(この機械学習勉強会はオープンで興味のある方は参加もできますのでぜひ https://github.com/wantedly/machine-learning-round-table

毎週勉強会のために論文やブログを読んで発表できるようにまとめたりする作業や、他の人の発表を聞いて新しい知識を得ることはもちろんですが、私の場合にはこの勉強会をきっかけにアクションにつなげることができたりしていて参加してよかったと思っています。ある回で機械学習エンジニアの OSS コミットに関するスライドが紹介されたのですが、それをきっかけに私自身が OSS にコミットしてみたいというモチベーションがわきました。また、別の機会で自分が発表するトピックを探している際に、偶然 SPyQL という OSS を見つけました。これは Python で書かれた SQL のようなもので、CLI 上で CSV ファイルを分析するのに使うことができるツールです。コードを読んでいたら SQL をパースする部分に改善の余地があることを見つけたので勢いでコミットすることにしました。結果、私の提案はメンテナーに受け入れられて無事マージされるところまで行くことができました。更に派生して、この経験は社内 LT での発表やブログの執筆にもつながったので、これ1つで様々な経験・知識が得られたと思っています。

機械学習論文の追試

こちらも元をたどれば上に書いた社内勉強会から派生したものなのですが、最近、機械学習モデルを実際に実装してみることで読むだけではわからない新たな知見を得られるのではないかと考えて、自分で論文で提案されているモデルの実装、追試をしてみました。リポジトリはオープンにしているので興味のある方は見てみてください。

In SIGIR. 165–174.

やってみてわかったことですが、実際に自分で実装してみることで論文の中身の理解を深めることができたと感じています。論文ではよくわからなかった式がコードを読むことで意図を理解できたり、逆に論文で展開されていたロジックはコードにするとこんな感じになるのか、など相互に不足していた知識を補い合うような効果があリました。また、自分で実験環境を持てるので、論文には書いていなかったけど気になる事があったときに実際に回してみて確認することができるのもとても良かったです。特に、なかなか論文の中には書いていないが実務で使うとなったときに必要な情報である、学習時間やリソースの使用量を把握することもできたのは大きな収穫だったと思います。

次にやりたいこと

最後に、私が 1 年間データサイエンティストとしてやってきて、これから何に挑戦したいと思っているかについて少しだけ書きたいと思います。上に書いたように、この1年間でアウトプットの重要性について理解できたのですが、そのアウトプット先は社内での発表やブログに限られており、得られるフィードバックも限定的でした。なのでこれからは外部発表にも参加してどんどん自分の知見を伝えていき、そこで生まれるコミュニケーションを通じてまた新たな知見を得たり、人とのつながりを作ったりということをしていきたいと考えています。

また、実務に直接関連する部分では、自分のできる仮説検証の範囲を広げたいということを考えています。この 1 年間で私が行ってきた仮説検証は、上の例にも出したような「なぜ」に答えるためのものが大半を占めていました。これはつまり、ある程度はっきりした課題があってその解像度を上げることで解決につなげていくという仕事を任されていたことが多かったためです。今後はさらに前段の「ユーザが何に困っているのか」を明らかにするための仮説検証をどんどんやってきたいと思っています。また、これに関連して自分で考えて進めていくという点でもこれからもっとできるようにしていきたいなと思っています。

1 年目に読んで良かった本
  • リーダブルコード
  • コンサル 1 年目に学ぶこと
  • イシューからはじめよ
  • A/B テスト実践ガイド
  • 効果検証入門
  • 因果推論の科学

よかったらTwitterのフォローもお願いします!
twitter.com

論文の再現実験: Neural Graph Collaborative Filteringを実装してみた

本記事では「Neural Graph Collaborative Filtering (NGCF)」と呼ばれる、グラフ構造を使った推薦システムの手法を自分で実装して論文の再現実験をしてみたことについて書きます。近年の推薦システムの研究では、ユーザーとアイテムの特徴量だけでなく、ユーザーとアイテム間の関係性を考慮した手法が注目されています。NGCFもその一つで、ユーザーとアイテムのグラフ構造を学習することで、より高い精度で推薦を行うことができます。本記事では、まずNGCFの手法や特徴について解説し、その後に実装と再現実験の方法を紹介します。

Neural Graph Collaborative Filteringとは?

NGCFは、SIGIR'19 に採択された論文「Neural Graph Collaborative Filtering」で提案されました。

NGCFは、ユーザとアイテムの相互作用を表すグラフ構造を利用して推薦を行う機械学習モデルです。具体的に言うと、グラフ畳み込みネットワーク (Graph Convolutional Network, GCN) を用いて、ノード間の情報を交換することで、高度な特徴表現を学習することができます。このようにして学習された特徴表現を用いて、推薦が行われます。

NGCFの特徴

NGCFは、以下のような特徴を持っています。

  • ユーザとアイテムの関係をより詳細にモデル化できる
  • 非線形な相互作用を捉えることができる
  • ユーザとアイテムの特徴表現を同時に学習することができる

ただ、欠点として単一のユーザーとアイテム間インタラクションのみしか扱えないということがよく言われていて、最近では複数の行動を取り入れたGNNベースのモデルが発展してきています(MBGCN etc.)。

実装

github.com

実装にはPyTorchを使い、実験はNvidia Tesla T4を用いて行いました。ハイパラやデータセットは論文と同一のものを用いました。Poetryを使っている方であれば

poetry install
poetry run python main.py

だけで実行できますのでぜひ試してみてください。モデルの肝の部分は↓のように実装できます。

class NGCF(nn.Module):
    # ...
    def forward(...):
        # ...
        ego_embeddings = torch.cat(
            [self.embedding_dict["user_emb"], self.embedding_dict["item_emb"]], 0
        )
        all_embeddings = [ego_embeddings]

        for k in range(len(self.layers)):
            # Eq. (7) in the paper
            side_embeddings = torch.sparse.mm(A_hat, ego_embeddings)
            sum_embeddings = (
                torch.matmul(
                    side_embeddings + ego_embeddings, self.weight_dict[f"W_gc_{k}"]
                )
                + self.weight_dict[f"b_gc_{k}"]
            )
            bi_embeddings = torch.mul(ego_embeddings, side_embeddings)
            bi_embeddings = (
                torch.matmul(bi_embeddings, self.weight_dict[f"W_bi_{k}"])
                + self.weight_dict[f"b_bi_{k}"]
            )

            ego_embeddings = nn.LeakyReLU(negative_slope=0.2)(
                sum_embeddings + bi_embeddings
            )
            ego_embeddings = nn.Dropout(self.mess_dropout[k])(ego_embeddings)
            norm_embeddings = F.normalize(ego_embeddings, p=2, dim=1)

            all_embeddings += [norm_embeddings]

        # Eq. (9) in the paper
        all_embeddings = torch.cat(all_embeddings, 1)

ego_embedding がグラフ上を伝搬させるもので、全ユーザ・アイテムのembeddingをひとまとめにした行列になっています。A_hat はインタラクション行列から作成したラプラシアン行列 (詳しくは論文の(8)式を見てください)です。これらを使って for文の中で各層での畳込み操作をしています。最終出力で得られるユーザ・アイテムのembeddingは各層のそのユーザ・アイテムのembeddingをくっつけたものになっています。

結果

400 epochs学習させたときのRecall変化は以下のようになりました。

400 epochs後の各データセットでの指標は、

Dataset Recall@20 NDCG@20 Precision@20
Gowalla 0.1538 0.1295 0.0472
Amazon 0.0311 0.0239 0.0132

論文にはRecallとNDCGの結果しか載っていないのですが、Precisionは少し他と比べると低いようです。

また、モデルの性能を比較するためにMatrix Factorization + BPR Loss に対しても実験してみました。データセットはGowallaでbatch sizeや学習率はNGCFと同じにしました。結果 Recall@20 は 0.1400 となりNGCFよりかは9 %ほど評価値が低いという結果になりました。MFでの論文の値は Recall@20 = 0.1291 なので論文よりかは高い値が出ています。何回か実験してみるとEarly stopのかかるところが実験によって結構ばらつき、結果として評価値がそこそこブレることがわかりました。なのでEarly stopの基準を論文で設定されているものより多くすればMFの性能は論文で示されている値よりもよいものになりそうです。しかし、MFとNGCFの評価値の差は10 %近くあるのでこのブレを考慮してもNGCFの方が優れたモデルであるということはできるでしょう(ちゃんとは検証できてないですが)。

ただ、学習が進む速度は当たり前ですが遥かにMFのほうが早く、 Gowalla だと

という感じで、思ったよりNGCF遅いなといった感じでした。なので例えば実務でMFを使っていたとして、10 %精度が良くなります、でも学習時間は15倍に伸びます、CPUだともっと遅くなるのでGPUも使いたいです、はちょっとう〜んとなると思います。

感想

今回はじめてグラフ系のモデルを実装・実験してみたのですが、学習に思ったより時間がかかるという印象でした。Amazon Bookのデータセットでは400 epoch回すのに4日ほどかかりました。NGCFをより効率化したモデルであるLightGCNについても今後実装をしてみたいと思いました。

あと余談ですがNGCFはGPUメモリ思ったより使わないな、というのは結構驚きでした。↓はGowallaデータセットで訓練している最中の の結果ですが、1 GBしか使っていなかったので最初は三度見くらいしました。

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.103.01   Driver Version: 470.103.01   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            On   | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P0    26W /  70W |   1090MiB / 15109MiB |      9%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     27394      C   ...OdQNlrNe-py3.8/bin/python     1087MiB |
+-----------------------------------------------------------------------------+

おまけ


計算メモ1



計算メモ2


データサイエンティストがOSSにコミットしてみた話

OSSにコミットしてみたいと考えて、見つけたデータ分析ツールにコミットすることができたのでその体験談みたいなことを書きます。

  • SPyQLというツールの紹介
  • どうやって見つけたか
  • どうやって作業を進めていったか

について書こうと思います。

イントロ

OSSへのコミットは、自分のコードを書く能力を上げることができるだけでなく、ツールを使っている他のユーザーにも喜ばれるという経験ができるので興味のあるエンジニアは多いのではないでしょうか。一方でなんとなく敷居が高く、一歩を踏み出すことができない領域だと感じている人も多いと思います。そういう自分もその一人でした。自分はデータサイエンティストとして働いているのですが、なんとなくデータサイエンティストはいわゆるソフトウェアエンジニアとは距離があるように感じていて、データサイエンティストがOSSにコミットなんて。。。などと思っていた部分もあったのだと思います。

そのようなときに、社内の勉強会でばんくしさんのスライドが紹介されていて、OSSコミットにソフトウェアエンジニアもデータサイエンティストも関係ない、むしろデータサイエンティストにしかできない貢献もあると気付かされました。
speakerdeck.com

このあたりから自分の中でOSSになんでもいいからコミットしたいという思いが芽生え、コミットできるリポジトリを探していました。そして、社内の機械学習勉強会で紹介するブログ記事を探しているときにSPyQLというツールの存在を知りました。

github.com

この記事では、SPyQLがどんなツールで、自分がこのツールのどこに貢献したのかについて書きたいと思います。

SPyQLとは

SPyQLは一言でいうとCLI上でSQLっぽい言語を使ってデータ分析ができるツールです。

github.com

spyql "SELECT * FROM csv('dammy.csv') LIMIT 1 TO json(indent=2)"

のようにCSVを分析して結果をJSONに保存するといった使い方ができます。また、matplotcliと組み合わせることによって分析結果を即座にプロットすることもできます。

何をしたか

自分が今回変更を加えたのは、パーサの最初の部分です。例えば、

SELECT
    user_id
FROM
    log
WHERE
    created_at >= '2023-01-01'

{
    "SELECT": "user_id",
    "FROM": "log",
    "WHERE": "created_at >= '2023-01-01'",
    "GROUP BY": None,
    "HAVING": None,
    "ORDER BY": None,
    # ...
}

のように変換する部分です。

自分が変更を加える前は、この部分は正規表現を使ったマッチングを使って行われており、可読性もパフォーマンスも向上の余地が大きいのでは無いかと思いました。Issueを作ってリポジトリオーナーに提案してみたところ、よさそうという返事をもらえたので早速コードを書いてPull Requestを作りました。


(英語は自分で書きつつ、怪しいところはDeepL先生に聞いて書きました)

やり取りをする中で気づいた学び

今回のPull Requestを完成させる過程では、リポジトリオーナーと様々なディスカッションがありました。例えば、コードのこの部分はこうしたほうがもっとパフォーマンスが出るんじゃない?ここまでやってしまうと可読性が落ちるよね、といったことです。このあたりは単純にコードの書き方で学びがあった他、どうやったらすれ違いなくコミュニケーションを取れるかといった点でも学びが大きかったように思っています。というのも、普段の業務でコードを書いてレビューしてもらうときには、レビュー相手は当然互いをよく知っている間柄なので暗黙的にこういうことだよね、といった感じで少し雑なコミュニケーションでも齟齬無く伝わりますが、今回のようなOSSへのコミットの場面では丁寧にコミュニケーションを取らないと、すぐに認識がずれてしまうからです。それはコミュニケーション相手があまり良く知らない人であることがほとんどであり、コミュニケーションの地盤を共有していないためです。

そのため、ディスカッションを進めるときには少し冗長になってもいいので「あなたが言っているのはこういうことか?」といった確認をするようにしていました。その結果、自分もリポジトリオーナーも納得する形でコードを完成させられたと感じています。

まとめ

  • OSSコミットは難しくない(のもある)
  • 勉強会に参加すると自分で集めるよりも効率的に情報を集められる
  • 互いをよく知らない場合のコミュニケーションでは思い込みをなるべく排除することがより大切

flutter build apkしたら~/.gradle/caches/transforms-3/... (そのようなファイルやディレクトリはありません)と怒られた話

タイトルのとおりですが、flutter buildができなくなりました。きっかけはbuild時に 「~/.gradle/caches のバージョンが〜」というwarningが出ていたのでcacheディレクトリだから消せばいいかと考えて消したことでした。

ネットで情報を探していると

  • ~/.gradle/caches を消せばいい
  • ~/.AndroidStudio4.0/system/caches を消せばいい
  • Android Studioを開いて、「File → Invalidate cache & restart」をすればいい

など色々あったのですが、全部うまく行きませんでした。結局、

rm -rf ~/.gradle

で .gradle ディレクトリごと消したらビルドできるようになりました。.gradleディレクトリ内のcaches以外のものがcachesに影響していたのでしょうか?

flutter何もわからない。。。

実験タスクでのデータ管理をちょっと改善した話

TL;DR

実験した結果をクラウドストレージから手元に落として来て分析するときに、ファイルが多いといろいろ辛みがあった。スクリプトを使って使うファイルの管理をしたらデータの分析の効率が少し改善した。

背景

自分は業務でよく、機械学習ジョブを回してその結果を分析するということをやります。その際、ジョブ自体はクラウド上で行い、結果の詳細な分析はデータを手元に落として行うことが多いです。このように、やることによって場所を変えるのは、クラウド上の実行環境はあくまで「実行環境」なのでデバッグや分析をするために毎回ツールをセットアップする必要があったり、分析した結果を可視化するのが難しいという理由からです。

しかし、データを手元に落としてきたはきたで別のツラミがあります。例として、ジョブの実行結果がクラウドストレージに吐き出されていて、それを分析のためにダウンロードするといった状況を考えると、パッと思いつくだけでも以下のようなツラミが生じる可能性があります。

  • 複数のジョブを回す場合に、雑にデータをホイホイダウンロードしているとファイルが多くなったときにどのファイルがどのジョブのものかわからなくなる
  • 1つの実験から複数ファイルが出力として得られるときに、どのファイルが同じ実験から得られたファイルなのか対応がわからなくなる
  • 最初気にしていなかった実験条件が知りたくなったときにストレージのどのファイルを見たら調べられるかわからなくなる

タスクが明確でやりたいことが全部はっきりしているプロジェクトであったら上に挙げたツラミはもしかしたら生じる余地はないのかもしれません。しかし、多くのプロジェクトは多かれ少なかれ「不確かさ」をはらんでいるはずで、そのようなとき、上のツラミが生じることは十分あり得ると思います。この記事では、何らかの方法で少しでもこれを軽減できないかと考えて自分が考え出した1つの解決策を紹介します。

ナイーブな解決策と問題点

上のような問題点を解決する方法として真っ先に思いつくのは、「ファイルをダウンロードして、その後にリンクをどこかにメモしておく」という方法かと思います。しかし、この方法の場合には「リンクをどこかにメモしておく」というステップはスキップ可能で、ダウンロード→分析という過程が簡単に取れてしまいます。したがってこの方法ではツラミが生じる余地を十分減らせていないでしょう。

自分が考えた解決策

そこで私は、「ファイルをダウンロードする」と「リンクをどこかにメモしておく」の順番を逆にすることで問題を軽減できないかと考えました。つまり、必要なデータのリンクをどこかで一元管理しておいて、データのダウンロードはそれを使って行うという方法にすればデータの管理がしやすいのではないかと考えました。

以下が現状で私が使っているデータ管理の形式です。データ分析はPythonでやることが多いので、データ管理にもPythonを使って書いています。

resources = [
	{
		"bucket": "bucket1",
		"blob": "project1",
		"subblob": "experiment1-yyyymmd'd'",
		"files": [
			"parent_dir/file1",
			"file2",
			"file3",
		],
		"destination": "path/to/save/experiment1/yyyymmd'd'",
		"skil_if_exist": True,
		"description": "〇〇を✗✗にして実験した結果",
	},
	{
		"bucket": "bucket1",
		"blob": "project1",
		"subblob": "experiment1-yyyymmdd",
		"files": [
			"parent_dir/file1",
			"file2",
			"file3",
		],
		"destination": "path/to/save/experiment1/yyyymmdd",
		"skil_if_exist": True,
		"description": "〇〇を△△にして実験した結果",
	}
]

この形式にして実際に分析をしてみて感じた利点は、

  • 同じ実験のファイルが一目でわかりやすい
  • ファイル名とパス、ファイルの説明が一箇所にまとまっている
  • 特定の実験で追加のファイルが必要になったときに変更が容易
  • この形式と決めておけば実際のダウンロードに使うスクリプトは使い回せるので楽

だと思っています。

また、まだ試してはいませんが、ダウンロードしたファイルを分析するために読み込むときにも resources を利用できれば、例えば似た名前のファイルを間違って読み込んで分析した、といったミスも減らせる可能性もあります。

この手法のテンプレートはGitHubで公開していますので興味のある人は見てみてください。

github.com

他の手法との比較

Pythonで何かしらのジョブ実行をサポートするものはいくつもあります。その代表例とも言えるのがAirflowです。

airflow.apache.org

Airflowでは上のリンクにあるように、GCSへの多様な操作をできるようになっています。Airflowと比較すると上で紹介した方法というのは機能性という点では劣ってしまいますが、あえて(半分無理やり)Airflowに対する利点を挙げると以下のようなものがあると思います。

  • 実験データ管理だけしたいときには紹介した方法は必要最低限の機能が備わっている
    • Airflowはワークフローエンジンなので実験データ管理以外にも様々な機能がありすぎる
  • データを分析する部分は自由度が大きい
    • 分析部分についてはDAGを作って実行というより試行錯誤しながらやりたいことが多い

補足

実際にダウンロードする際に使うコード
def _download_resource(
    files: list[str],
    destination_dir: str,
    gcs_base: str,
) -> None:
    targets = " ".join([gcs_base + f for f in files])

    result = subprocess.call(
        [
            "gsutil",
            "-m",
            "cp",
            targets,
            destination_dir,
        ]
    )
    if result != 0:
        raise RuntimeError("Script Failed")

def _build_source_and_destination(
    project: str, resource: dict[str, Any]
) -> tuple[str, str]:
    source = f"gs://{resource['bucket']}/{resource['blob']}"
    if resource["subblob"] != "":
        source += f"/{resource['subblob']}/"

    destination = f"{config.REPO_ROOT}/data/{project}/{resource['destination']}"

    return source, destination


def download(project: str, resources: list[dict[str, Any]]):
    for rs in resources:
        source, destination = _build_source_and_destination(project, rs)

        _download_resource(
            rs["files"],
            destination,
            source,
        )

例えばGoogle Cloud StorageからCLI経由でダウンロードする際には、以下のようなコードを使うことができます。(権限があればCloud SDKを使う方法もあります)

「因果推論の科学」を読んだ

最近話題になっていた「因果推論の科学」という本を読んだので感想みたいなものをつらつら書いてみました。


この本はジューディア・パールが2018年に書いた "The Book of Why: The New Science of Cause and Effect " の邦訳版です。英語圏ではすでに非常に評判が良いらしくAmazonでも1,500近くのレビューがついて平均4.4の評価となっています。

著者のジューディア・パールはベイジアンネットワークの研究で有名な方で、チューリング賞も受賞されています(ソース: Wikipedia)。

簡単にまとめると

この本がベースとしているのが「因果のはしご」という考え方です。因果のはしごは3段からなり、下から順に以下のように説明されています。

  • 関連付け:現実を観察してその中に規則性を見つけ、予測に用いる
  • 介入:ある行動をしたときに結果がどうなるか予測する
  • 反事実:現実とは異なる状況を仮定したら結果はどのようになるか予測する

最初にこのはしごに沿ってどう統計学が発展してきたかという歴史的な話がされます。その中で、現状の深層学習を使ったAIはまだはしごの1段目にいるので「強いAI」を作るには因果関係を組み込んではしごを登る必要があるということが書かれています。では因果関係はどうやって定式化されるのか?という話が続いて因果ダイヤグラムという道具が出てきます。そして因果ダイヤグラムを使って介入や反事実がどう表現されるか、具体例を多く交えながら(かつ少ない数式で)説明されています。

データは何も教えてくれない

ちょっと過激なこのセクションタイトルは第一章に出てくる言葉です。

データを見れば、たとえば、ある薬を服用した人が服用しなかった人よりも早く回復したことだけはわかるかもしれない。しかし、「なぜそうなったか」という理由はわからない。もしかすると、その薬を服用した人は、そうする金銭的余裕があったからそうしたまでで、服用しなかったとしても、結局は同じくらい早く回復したかもしれない。

わたしたちがデータを使って知りたいことは、ある行動が結果に対してどのような影響をもたらしたか、つまり因果関係であることが多いです。しかしデータ単体では因果関係を明らかにすることはできません(因果のはしごの1段目の問には答えられるが2段目以上の問には答えられない)。上の例でいうと、薬を服用した人と服用していない人のそれぞれに対する回復までにかかる時間のデータがあって、そこから「薬に効果があったか」という問に答えを出したいとき、分析者の頭の中には下のような因果ダイヤグラムがあります(本当はもっと様々な要因がダイヤグラムには現れるはずですがここでは上の引用に挙げられている要因だけを考えます)。

データに加えて上のような「モデル」を組み合わせることで「薬に効果があったか」という問に対する1つの答えを得ることができます。このモデルでは、経済力が交絡因子であると考えているので経済力を固定してバッグドアを閉じという分析をすることになります。

このように、データから因果のはしごの2段目の問に答えるにはデータの背後にある因果関係をどのようにモデリングするかが重要になります。

データの生成過程を知ることの重要性

データそのものよりもその生成過程が重要であるという事実を、本ではモンティー・ホールのパラドックスを使ってわかりやすく説明しています。モンティー・ホールのパラドックスは一種のくじ引き的な状況で生じるパラドックスです。3つの扉がありそのうち1つの後ろには新車が置かれていて、参加者はその扉を選べば新車がもらえるという状況です。参加者が開ける扉を選択したあとにくじを作った主催者は参加者が選ばなかった扉の中から新車の扉以外を開き、参加者に選択した扉を変更できることを告げます。このとき参加者は選択した扉を変えるべきかというのが問題です。

直感的に考えると選択を変更しても変更しなくても確率は変わらない気がしますが、実際には変更しない場合に新車がもらえる確率は1/3、変更した場合には2/3となります。本の中ではこの状況を少し変えたバージョンも挙げて比較しています。参加者が最初に扉を選んだあとに主催者が扉を開けるとき、開ける扉は完全にランダムで新車がある扉も開ける可能性があるというバージョンです。実はこちらのバージョンでは参加者が選択を変更してもしなくても新車がもらえる確率は1/3です。

この2つの状況からデータを取ろうとすると両方とも「参加者の選択した扉」、「新車の位置」、「主催者の開いたドア」に関するデータが取れます。データを比較すればアレンジバージョンの方では参加者が選択を変えても新車を得られる確率が変わらないという事実はわかりますが、なぜそうなっているのかはわかりません。両者の状況では実は、データの生成過程が異なります。両者で因果ダイヤグラムを書くと前者では新車の位置と主催者の扉選択の間に因果関係がありますが後者ではありません。つまり、本の中の一節を引用するならば

情報をどのようにして得たかは、情報そのものと同じくらい重要

ということです。

モデルを作る難しさ

一方で「モデルを正しく作る」という作業も簡単なものではありません。本の中で

因果関係の存在は、因果関係が存在するという前提で状況を見ていないと発見できない

と書いてあるように、正しい分析のためにはデータが生成される過程を理解し、不確かな部分には仮説を立てることが重要です。余談ですが、本の中では因果推論の研究が長い間統計学で敬遠されてきた理由を述べていて下のような一節が出てきます。

クローは、無視された理由をこう推測している。パス解析は、「あらかじめ用意された手順にただ従えばいい、というものではなかった。パス解析を行う者は、まず自分で仮説を立て、複数の因果関係をまとめた適切なダイヤグラムを作成する必要があった」。クローの指摘は本質を突いている。あらゆる因果推論がそうであるように、パス解析にもまた科学的思考が不可欠だ。ところが統計学では、科学的思考は敬遠され、むしろ決められた手順に従うことをよしとする場合が多い。自らの科学的知識が試されるような手法は敬遠され、データの数値を使って決まった手順で計算すればいいという手法が好まれるのである。

「科学的知識」はデータサイエンスの文脈では「ドメイン知識」と言われることが多いかもしれません。やはり適切なモデルを作るための手順書のようなものは存在しないので粘り強く仮説検証をすることが重要ということがわかりますね。

感想

自分は因果推論に関しては初心者で、入門書としてこの本を読んでみました。中身は数学書のように数式がいっぱい出て来るというわけでもなく(そもそも縦書きですし)、様々な実例を挙げて説明がなされているのでとてもわかり易かったです。

実は過去にPythonを使った因果推論の参考書をちょこっと読んだことがあったのですが、その中ではバックドア基準はこういうもんだから受け入れろという感じの書き方だったのでいまいち理解できていませんでした。しかしこの本では文章でですが重要な部分がちゃんと説明されていたので、消化不良を起こさずに理解できた気がしています。