pyhaya’s diary

機械学習系の記事をメインで書きます

Wantedlyでエンジニアインターンした話

この夏、初めてエンジニアとして就業型のインターンに参加したので感想を書いてみます。

Who am I ?

まずはお前誰?という感じだと思うのでざっくりと

  • 都内の理系大学院生(非情報系学科)
  • Python3 チョットできる
  • エンジニア未経験
  • AtCoder

こんな人です。見ての通り、エンジニアとしてはほぼ素人みたいな感じです。

なぜWantedlyでエンジニアインターンをしたのか

f:id:pyhaya:20190831163308j:plain
wantedlyのホームページより引用

最初、Wantedly Visitというサービスを使ってインターン先を探していたところ、当のWantedlyさんから一度話をしてみないかと誘われたのがきっかけです。

www.wantedly.com

(情報系でもないし、企業での開発経験もないのになぜ連絡が来たのかわかりませんでしたが、聞いてみたら「プロフィール見て面白そうな人だと思ったから」と言われました、、、(?))

正式にインターンが決まるまで

声がかかってから、Wantedlyのエンジニアの方と何回かあって話をしました。最初にあったときには、(さすがに私の経験がなかったので)何か作るか勉強した結果を見せてくれと言われました。そこで、私はMLチームを希望していたため、ML系を色々つまみ食いしました。

  • GCP使うといっていたのでCourseraでGoogle CloudのSpecializationを受けた。全部で5コースあって、1ヶ月ですべて修了した(Certificateが有料で、それが月額制だったから)。

www.coursera.org

  • Tensorflowを勉強するために論文を読んで実装して試すということをしていた。

github.com
このリポジトリは今でも結構いじっていて、画像系のモデルを色々実装してるので良かったら見てください。

これらを見せたらなんとかOKが出ました!!

インターン本番

課題

私がインターンで取り組んだ課題は、「ユーザーマッチの改善」です。Wantedly Peopleというアプリを使ったことがある人はわかるかも知れませんが、このアプリでは誰かの名刺を撮影したときに、その相手がWantedlyユーザーであればつながりを作るという機能があります。ただ、現状では本当はマッチすべきなのに読み取りミスなどによってマッチできないという例がいくつか存在していました。

私はこれを「ルールベース」、「画像ベース」の2つの方向から改善しようということに取り組みました。

ルールベースの改善

まず、これまではユーザーマッチシステムの定量的な評価をするためのまとまったデータがなかったので、BigQueryを叩いて必要なデータを集めてデータセットを作るということから始めました。そしてその自作データセットを使いながらシステムの性能、改善点を丁寧に評価していきました。次に、得られた改善点をもとにシステムロジックの変更に取り組みました。ユーザーマッチのシステムはGoで書かれており、私にGoの経験がなかったので、1日はコードとにらめっこしてGoの知識を得るということに費やしましたが、Goがシンプルな言語だったことが幸いして、次の日には少しずつ実装を始めることができました。

画像ベースの改善

これは全く新しい取り組みだったので、モデルの構築からサーバーを書くところまでやりました。これまではモデルを作って訓練して、性能を見て、で満足していたのですが、ここではレイテンシやサーバーが持つデータの容量も非常に重要で、モデルを書き換えたり、インフラチームと議論したりと大変でしたが楽しんで取り組めました。

訓練時にはGPUインスタンスを使って訓練するのですが、自分が個人開発で訓練しているときとは違い、条件を変えて何個も同時に訓練させたり出来たのは快感(?)でした。

日常生活

ランチ

Wantedlyは会社が白金にあるので、ランチが高いです(泣)。店構えからしてすごい、、、。

よく行ってた店
www.chisou-koujiya.com


最初はビビりながら行っていたのですが、日給が1.5万出ていたのでなんとかなりました(昔は日給が8000円だったと聞いて震え)。ランチは色んな人と行きましたが、みんなランチでも技術の話をしていて、ほんとに技術が好きなんだなとか思って聞いてました(レベル高くて入れない)。

話を聞いていて影響を受けることも多々あって、話に出てきた論文を探して読んでみたり、Rustを勉強し始めてみたり、知識の幅が一気に広がった気がします。

ML輪講

毎週水曜の18:00から機械学習系の論文を読む会があり、それに参加していました。

github.com

外部からの参加もWelcomeで、たまに他の企業から人が来て、自分の会社で何をやっているのかとか話を聞けて面白かったです。インターン後半に私のメンターをしてくださっている方が毎回この会で3本くらい論文紹介してて、この人(良い意味で)ヤバイなと思ってましたw。

就業時間

コアタイムはありましたが(忘れた)、みんな好きな時間に来て、好きな時間に帰るという感じだったように思います。9:00くらいから人が来始めて、10:30くらいに大体全員来る。18:30くらいから人が帰り始める、といった感じでした。

感想

初めてのエンジニアインターンだったため比較対象がないためなんとも言えませんが、私的にはWantedlyはすごくいい会社だと思いました。会社全体が自由な雰囲気に溢れていて、組織がフラットなので(誰が偉い人なのか最初2週間位全くわからなかった。。。)、誰かが何かを思いつくと色んな人がすぐに集まって議論が始まるというのは、スピード感もあってすごく魅力的に感じました。

Rustで練習がてら簡単な切符の予約システムを作ってみる

最近、Rustを始めました。しばらくはドキュメントを見ながら勉強していたのですが、飽きてきて、何か作りたいなと思い始めたので、(すごく)簡単な切符の予約システムを作ってみました。まだ初心者なのでGUIで操作できたり、コマンドラインで引数を与えて実行できるというような高尚なものでは無いです。どちらかというと、自分みたいに勉強したはいいけど何していいか全くわからないという人に、初心者でもこんなことができるということを知ってもらいたいというのが目的です。

リファクタリングの過程とかも書くので、結果だけみたい人は目次で飛んでください。

環境

準備

cargo new reserve

でプロジェクトを作ります。

大雑把な外枠を作る

src/main.rs

struct Request {
    start: String,
    destination: String,
    time: String,
    time_kind: String,
}

fn reserve(request: Request) -> bool {
    true
}

fn main() {
    let request = Request {
        start: "NewYork".to_string(),
        destination: "Chicago".to_string(),
        time: "18:00".to_string(),
        time_kind: "start".to_string(),
    };

    if reserve(request) {
        println!("Succeed in reserving that train");
    } else {
        println!("Failed to reserved that train");
    }
}

かなりざっくりしています。まず決めたのは、予約の詳細をRequestという名前の構造体で保持するということです。そして、予約可能かどうかを判定する関数reserveを定義しました。これは現時点ではただ単にtrueを返すだけの関数です。

改善点は山のようにあります。

  • Stringを直接渡しているので、文字列リテラル(&str)を渡すようにする。こうすれば.to_string()もなくせる
  • リクエストが妥当なものか確かめる仕組みを作る
  • reserveRequestのメソッドにする

メソッドを定義してまとめる

src/main.rs

struct Request<'a> {
    start: &'a str,
    destination: &'a str,
    time: &'a str,
    time_kind: &'a str,
}

impl<'a> Request<'a> {
    fn reserve(&self) -> bool {
        true
    }

    fn is_valid(&self) -> bool {
        if self.start == "" {
            println!("You need to determine start point");
            return false;
        } else if self.destination == "" {
            println!("You need to determine destination");
            return false;
        } else if self.time == "" && self.time_kind != "" {
            println!("Invalid time specification");
            return false;
        }

        true
    }
}

fn main() {
    let request = Request {
        start: "Tokyo",
        destination: "Kyoto",
        time: "18:00",
        time_kind: "start",
    };

    if request.is_valid() && request.reserve() {
        println!("Succeed in reserving that train");
    } else {
        println!("Failed to reserved that train");
    }
}

参照を渡すようにすることで、Rust特有の文法であるライフタイムが必要になりました('aとか書いてあるもの)。構造体の要素が参照なので、明示的にそのライフタイムを構造体自身と合わせてやる必要があります。すこしRustっぽくなってきました(?)が、肝心のreserveメソッドがなんの働きもしていません。次にココらへんを改善していきます。

  • 時刻表を作る
  • reserveメソッドを実装する
  • codeが長くなってきたのでRequestをモジュールにまとめる

リクエストをモジュールに分ける

src/reserve_request/mod.rs

use std::collections::HashMap;

pub struct Request<'a> {
    pub start: &'a str,
    pub destination: &'a str,
    pub time: &'a str,
    pub time_kind: &'a str,
}

impl<'a> Request<'a> {
    pub fn reserve(&self, timetable: HashMap<(&str, &str), &[&str]>) -> bool {
        let st = (self.start, self.destination);

        let times = timetable.get(&st);
        let default: &[&str] = &[];
        let result = match times {
            Some(r) => r,
            None => default,
        };

        for t in result {
            if *t == self.time {
                return true;
            }
        }

        false
    }

    pub fn is_valid(&self) -> bool {
        if self.start == "" {
            println!("You need to determine start point");
            return false;
        } else if self.destination == "" {
            println!("You need to determine destination");
            return false;
        } else if self.start == self.destination {
            println!("Invalid. The start point and destication is same");
        } else if self.time == "" && self.time_kind != "" {
            println!("Invalid time specification");
            return false;
        }

        true
    }
}

src以下にreserve_requestディレクトリを作り、その中にmod.rsを作ります。main.rsから使うメソッドや構造体にはpubをつけてパブリックにします。これをモジュールとして認識してもらうために、src以下にlib.rsを作ります。

src/lib.rs

pub mod reserve_request;

main.rsでは、useでreserve_requestをインポートします。

src/main.rs

use reserve::reserve_request;
use std::collections::HashMap;

fn main() {
    //TODO: save time table in the json file
    let time: &[&str] = &["12:00", "14:00", "18:00", "19:00"];
    let mut stations = HashMap::new();
    stations.insert(("Tokyo", "Kyoto"), time);

    let request = reserve_request::Request {
        start: "Tokyo",
        destination: "Kyoto",
        time: "18:00",
        time_kind: "start",
    };

    if request.is_valid() && request.reserve(stations) {
        println!("Succeed in reserving that train");
    } else {
        println!("Failed to reserved that train");
    }
}

main.rsでは、HashMapとして時刻表を保持しています。HashMapのキーは出発駅と到着駅、値が出発時刻です。これはあまりにしょぼいので、将来的にはJSONに書いて、それを読み込む形になるかと思います。

テストを書く

モジュールにも分けたことですし、reserve_requestにテストを追加します。Rustでは実際のコードと同じファイルにテストをかけるので、mod.rsにテストを書き込みます。

src/reserve_request/mod.rs

// 同じため省略

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn struct_type() {
        let request = Request {
            start: "Tokyo",
            destination: "Kyoto",
            time: "18:00",
            time_kind: "start",
        };

        assert!(request.start == "Tokyo");
        assert!(request.destination == "Kyoto");
        assert!(request.time == "18:00");
        assert!(request.time_kind == "start");
    }

    #[test]
    fn invalid_examples() {
        let start_is_lack = Request {
            start: "",
            destination: "Kyoto",
            time: "18:00",
            time_kind: "start",
        };

        assert!(!start_is_lack.is_valid());

        let destination_is_lack = Request {
            start: "Tokyo",
            destination: "",
            time: "18:00",
            time_kind: "start",
        };

        assert!(!destination_is_lack.is_valid());

        let start_is_destination = Request {
            start: "Tokyo",
            destination: "Tokyo",
            time: "18:00",
            time_kind: "start",
        };

        assert!(!start_is_destination.is_valid());

        let invalid_time_info = Request {
            start: "Tokyo",
            destination: "Kyoto",
            time: "",
            time_kind: "start",
        };

        assert!(!invalid_time_info.is_valid());
    }

    #[test]
    fn valid_examples() {
        let valid_example_1 = Request {
            start: "Tokyo",
            destination: "Kyoto",
            time: "18:00",
            time_kind: "start",
        };

        assert!(valid_example_1.is_valid());

        let valid_example_2 = Request {
            start: "Tokyo",
            destination: "Kyoto",
            time: "",
            time_kind: "",
        };

        assert!(valid_example_2.is_valid());
    }
}

Requstの関連関数をつくる

Rustではコンストラクタはなく、似たような役割のメソッドを関連関数と呼ぶみたいです(出典:ドキュメントの日本語訳)。今までは、main.rsで構造体を直接作って、形式が妥当かどうか確かめていましたが、関連関数を作って、そこから構造体を生成して妥当性チェックをするのがよいでしょう。

src/reserve_request/mod.rs

use std::collections::HashMap;

pub struct Request<'a> {
    pub start: &'a str,
    pub destination: &'a str,
    pub time: &'a str,
    pub time_kind: &'a str,
}

impl<'a> Request<'a> {
    pub fn new<'b>(
        start: &'b str,
        destination: &'b str,
        time: &'b str,
        time_kind: &'b str,
    ) -> Result<Request<'b>, &'b str> {
        let r = Request {
            start: start,
            destination: destination,
            time: time,
            time_kind: time_kind,
        };

        if r.is_valid() {
            Ok(r)
        } else {
            Err("Failed to construct request from this information")
        }
    }

    pub fn reserve(&self, timetable: HashMap<(&str, &str), &[&str]>) -> bool {
        let st = (self.start, self.destination);

        let times = timetable.get(&st);
        let default: &[&str] = &[];
        let result = match times {
            Some(r) => r,
            None => default,
        };

        for t in result {
            if *t == self.time {
                return true;
            }
        }

        false
    }

    fn is_valid(&self) -> bool {
        if self.start == "" {
            println!("You need to determine start point");
            return false;
        } else if self.destination == "" {
            println!("You need to determine destination");
            return false;
        } else if self.start == self.destination {
            println!("Invalid. The start point and destication is same");
            return false;
        } else if self.time == "" && self.time_kind != "" {
            println!("Invalid time specification");
            return false;
        }

        true
    }
}

// 以下テスト

このようにすることで、main.rsではRequest::new(...)で構造体を作れ、newの中で妥当性チェックを行えるので、is_validはプライベートにできます。

src/main.rs

use reserve::reserve_request;
use std::collections::HashMap;

fn main() {
    //TODO: save time table in the json file
    let time: &[&str] = &["12:00", "14:00", "18:00", "19:00"];
    let mut stations = HashMap::new();
    stations.insert(("Tokyo", "Kyoto"), time);

    let request = reserve_request::Request::new("Tokyo", "Kyoto", "18:00", "start").unwrap();

    if request.reserve(stations) {
        println!("Succeed in reserving that train");
    } else {
        println!("Failed to reserved that train");
    }
}

まとめ

ここまで来てもやはりかなり大雑把で、直す所だらけですが、これだけでも結構Rustという言語の良い勉強になったなと感じてます。最終的なコードも玄人から見たら「ここはこう書くべきではない」とかあると思うので、気づいたらバシバシ指摘していただけると、勉強になるので嬉しいです。

実践Rust入門[言語仕様から開発手法まで]

実践Rust入門[言語仕様から開発手法まで]

Tensorflowを使ってUNetを試す Version 2

UNetを構築してみようVersion 2です。

Version2?

実は過去にTensorflowでUNetを書いた!という記事を書いています。。。
pyhaya.hatenablog.com

なぜ同じ内容をもう一回書くのかというと、Tensorflowのバージョンアップに伴って上のコードが割と根本的なところで動かなくなることが確定しているから、全部書き直そう!となったからです。(ツライ)

Tensorflowの変更点

Tensorflowは現在、2.0のベータ版がGithubで公開されています。メジャーバージョンが変わるのでそれなりに大きな変更であることは覚悟していたのですが、Googleは私の想像を超えてきました。

以下のサイトが変更点を詳しく紹介してくださっています。

qiita.com

ざっくりまとめると

  • tf.placeholderなくなる
  • tf.Session()なくなる
  • tf.global_variable_initializer()なくなる


、、、はい。2系が主流になると過去記事の私のコードはすべての行でエラーが出るんじゃないかというレベルで壊れます。

なのでこれから主流になるtf.kerasを使ってコードを書き直します。

モデル

UNetがどんなネットワークかは前の記事を見ていただくとして、モデルを書きます。

from typing import Optional
import argparse
import tensorflow as tf


class conv_set:
    def __init__(self, filters: int):
        self.filters = filters

    def __call__(self, inputs: tf.Tensor) -> tf.Tensor:
        y = tf.keras.layers.Conv2D(
            self.filters, kernel_size=3, padding="SAME", activation="relu"
        )(inputs)
        y = tf.keras.layers.Conv2D(
            self.filters, kernel_size=3, padding="SAME", activation="relu"
        )(y)
        y = tf.keras.layers.BatchNormalization()(y)
        return y


class upsampling:
    def __init__(self, filters: int, cut: Optional[int] = 0):
        self.filters = filters
        self.cut = cut

    def __call__(self, inputs: tf.Tensor) -> tf.Tensor:
        upconv = tf.keras.layers.Conv2DTranspose(
            self.filters, kernel_size=2, strides=2
        )(inputs[0])

        conv_crop = tf.keras.layers.Cropping2D(self.cut)(inputs[1])
        concat = tf.keras.layers.concatenate([conv_crop, upconv])

        return concat


def UNet(args: "argparse.Namespace") -> tf.keras.Model:
    n_classes: int = args.n_classes
    decay: float = args.l2

    x = tf.keras.Input(shape=(224, 224, 3))

    # down sampling
    conv1 = conv_set(64)(x)
    max_pool1 = tf.keras.layers.MaxPool2D()(conv1)
    conv2 = conv_set(128)(max_pool1)
    max_pool2 = tf.keras.layers.MaxPool2D()(conv2)
    conv3 = conv_set(256)(max_pool2)
    max_pool3 = tf.keras.layers.MaxPool2D()(conv3)
    conv4 = conv_set(512)(max_pool3)
    max_pool4 = tf.keras.layers.MaxPool2D()(conv4)
    conv5 = conv_set(1024)(max_pool4)

    # up sampling
    concat1 = upsampling(512)([conv5, conv4])
    conv6 = conv_set(512)(concat1)
    concat2 = upsampling(256)([conv6, conv3])
    conv7 = conv_set(256)(concat2)
    concat3 = upsampling(128)([conv7, conv2])
    conv8 = conv_set(128)(concat3)
    concat4 = upsampling(64)([conv8, conv1])
    conv9 = conv_set(64)(concat4)

    output = tf.keras.layers.Conv2D(filters=n_classes, kernel_size=1)(conv9)
    output = tf.keras.layers.Softmax()(output)

    model = tf.keras.Model(inputs=x, outputs=output)
    for layer in model.layers:
        if "kernel_regularizer" in layer.__dict__:
            layer.kernel_regularizer = tf.keras.regularizers.l2(decay)

    if args.weights != "":
        model.load_weights(args.weights)

    return model

生のTensorflowを使ったときと比べて、かなりスッキリとモデルを記述できていることがわかります。UNetには同じような繰り返しが存在している(畳み込みx2 + BNやupsampling)のでそれはクラスを定義してまとめてしまっています。これはカスタムレイヤーを定義しているわけではなく、既存のレイヤーを組み合わせているだけなので、tf.keras.layers.Layerを継承してcallをoverrideすることはせず、単にクラスを作ってcallableになるように__call__メソッドを定義していることに注意してください。

逆にLayerクラスを継承する形で定義してしまうと、訓練時にこのクラスが持つ重みパラメータが訓練可能パラメータとして認識されません。

訓練した結果

動作確認のために、試しにネコ画像で訓練してみました。

  • learning rate: 0.001
  • batch size: 4
  • epoch: 100

訓練後に適当にネコ画像をネットから取ってきて入れてみたら下のようになり、そこそこ行けている感じです。

f:id:pyhaya:20190818193242p:plain

どれくらいうまく行けているかはまだ定量的には測れていません(validation data 増やすのキツイ、、、)。

リンク(GitHub)

今回紹介したモデルはGitHubに上げています。

github.com

TensorFlowでUNetを構築する

この記事では、Tensorflowを使ってUNetを構築し、最終的には画像から猫を認識するように訓練するやり方を紹介します。(この記事で紹介しているコードはTensorflow2系では動作しません。2系でも動くコードは別記事にしたので良かったら読んでください
Tensorflowを使ってUNetを試す Version 2 - pyhaya’s diary

環境

  • Ubuntu 18.04
  • Python 3.6.8
  • Tensorflow 1.12.0
  • GeForce RTX2080

UNetとは何か

セマンティック・セグメンテーション(semantic segmentation)

UNetというのは機械学習モデルの名前で、セマンティク・セグメンテーションを行うために使われます。セマンティック・セグメンテーションというのは、画像をピクセル単位でいくつかのクラスに分類する画像処理の手法です。例えば下のような人と馬の画像を処理すると右のように人(薄ピンク)と馬(ピンク)、そして背景(黒)をピクセル単位で分類します。

f:id:pyhaya:20190502214459p:plain
https://nicolovaligi.com/deep-learning-models-semantic-segmentation.html から引用
似たようなタスクとして、画像から物体を認識する場合に物体があると考えられる領域を長方形で認識して表示するものがあります(下図)。これと比較するとセマンティック・セグメンテーションではより高度な処理を行っていることがわかります。
f:id:pyhaya:20190502215611j:plain
https://gigazine.net/news/20140920-revolutionary-machine-vision/ から引用

UNetの構造

UNetは2015年にドイツの大学の研究グループが発表したネットワークです。名前の由来はそのネットワークの形状で、U字型をしているためにこのように呼ばれています。

f:id:pyhaya:20190502225111p:plain
論文から引用

arxiv.org

UNetは基本的には畳み込みニューラルネットワーク(CNN)で、その特徴は大きく分けて次のような2種類の処理に分けて考えることができます。

1. ダウンサンプリング
畳み込みでfeature mapを倍にしながらmax poolingで画像サイズを小さくしていく
2. アップサンプリング
transpose convolution*1で画像サイズをもとに戻していく。このときダウンサンプリング中のデータを加えながら処理を進めていく(図の灰色矢印)

論文では入力画像は大きさが572x572でチャネル数が1なのでグレースケールの画像になっています。そして最終的には出力が388x388でチャネル数が2になっています。これは少し説明が必要で、論文ではゼロパディングをしていないので、出力が入力よりも小さくなります。まず、この論文では一つの画像を一度にセグメンテーションするのではなく、いくつかの388x388の領域に分割し、最後に出力結果をつなぎ合わせて最終的なセグメンテーション結果とします。そして388xx388の大きさの領域をセグメンテーションするためにその周りを含めて572x572の領域をネットワークに入れます。処理したい画像領域が元の画像の端で、572x572に拡大できないときには、足りない部分を元の画像の端を鏡面とした鏡映操作をして補います。

f:id:pyhaya:20190502231618p:plain
論文より引用

出力のチャネル数2は判別するクラスの数によります。この場合には判別するクラスが2つとなっているため出力のチャネル数が2になっています。この2つのクラスをクラス1, クラス2と書くことにすると、第一チャネルはクラス1に分類される部分だけ1でほかはゼロ、そして第二チャネルはクラス2に分類されるピクセルだけ1でほかはゼロというようなOne-Hot表現になっています。

UNetのPython(Tensorflow)での実装

では、Tensorflowを使ってUNetを実装してみます。実装では
github.com
を参考にさせていただきました。

UNet本体

UNetの本体はTensorflowで書くと下のようにかけます。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import main

class UNet:
  def __init__(self, classes):
    self.IMAGE_DIR = './dataset/raw_images'
    self.SEGMENTED_DIR = './dataset/segmented_images'
    self.VALIDATION_DIR = './dataset/validation'
    self.classes = classes
    self.X = tf.placeholder(tf.float32, [None, 128, 128, 3]) 
    self.y = tf.placeholder(tf.int16, [None, 128, 128, self.classes])
    self.is_training = tf.placeholder(tf.bool)

  @staticmethod
  def conv2d(
    inputs, filters, kernel_size=3, activation=tf.nn.relu, l2_reg=None, 
    momentum=0.9, epsilon=0.001, is_training=False,
    ):
    """
    convolutional layer. If the l2_reg is a float number, L2 regularization is imposed.
    
    Parameters
    ----------
      inputs: tf.Tensor
      filters: Non-zero positive integer
        The number of the filter 
      activation: 
        The activation function. The default is tf.nn.relu
      l2_reg: None or float
        The strengthen of the L2 regularization
      is_training: tf.bool
        The default is False. If True, the batch normalization layer is added.
      momentum: float
        The hyper parameter of the batch normalization layer
      epsilon: float
        The hyper parameter of the batch normalization layer

    Returns
    -------
      layer: tf.Tensor
    """
    regularizer = tf.contrib.layers.l2_regularizer(scale=l2_reg) if l2_reg is not None else None
    layer = tf.layers.conv2d(
      inputs=inputs,
      filters=filters,
      kernel_size=kernel_size,
      padding='SAME',
      activation=activation,
      kernel_regularizer=regularizer
    )

    if is_training is not None:
      layer = tf.layers.batch_normalization(
        inputs=layer,
        axis=-1,
        momentum=momentum,
        epsilon=epsilon,
        center=True,
        scale=True,
        training=is_training
      )

    return layer

  @staticmethod
  def trans_conv(inputs, filters, activation=tf.nn.relu, kernel_size=2, strides=2, l2_reg=None):
    """
    transposed convolution layer.

    Parameters
    ---------- 
      inputs: tf.Tensor
      filters: int 
        the number of the filter
      activation: 
        the activation function. The default function is the ReLu.
      kernel_size: int
        the kernel size. Default = 2
      strides: int
        strides. Default = 2
      l2_reg: None or float 
        the strengthen of the L2 regularization.

    Returns
    -------
      layer: tf.Tensor
    """
    regularizer = tf.contrib.layers.l2_regularizer(scale=l2_reg) if l2_reg is not None else None

    layer = tf.layers.conv2d_transpose(
      inputs=inputs,
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      kernel_regularizer=regularizer
    )

    return layer

  @staticmethod
  def pooling(inputs):
    return tf.layers.max_pooling2d(inputs=inputs, pool_size=2, strides=2)


  def UNet(self, is_training, l2_reg=None):
    """
    UNet structure.

    Parameters
    ----------
      l2_reg: None or float
        The strengthen of the L2 regularization.
      is_training: tf.bool
        Whether the session is for training or validation.

    Returns
    -------
      outputs: tf.Tensor
    """
    conv1_1 = self.conv2d(self.X, filters=64, l2_reg=l2_reg, is_training=is_training)
    conv1_2 = self.conv2d(conv1_1, filters=64, l2_reg=l2_reg, is_training=is_training)
    pool1 = self.pooling(conv1_2)

    conv2_1 = self.conv2d(pool1, filters=128, l2_reg=l2_reg, is_training=is_training)
    conv2_2 = self.conv2d(conv2_1, filters=128, l2_reg=l2_reg, is_training=is_training)
    pool2 = self.pooling(conv2_2)

    conv3_1 = self.conv2d(pool2, filters=256, l2_reg=l2_reg, is_training=is_training)
    conv3_2 = self.conv2d(conv3_1, filters=256, l2_reg=l2_reg, is_training=is_training)
    pool3 = self.pooling(conv3_2)

    conv4_1 = self.conv2d(pool3, filters=512, l2_reg=l2_reg, is_training=is_training)
    conv4_2 = self.conv2d(conv4_1, filters=512, l2_reg=l2_reg, is_training=is_training)
    pool4 = self.pooling(conv4_2)

    conv5_1 = self.conv2d(pool4, filters=1024, l2_reg=l2_reg)
    conv5_2 = self.conv2d(conv5_1, filters=1024, l2_reg=l2_reg)
    concat1 = tf.concat([conv4_2, self.trans_conv(conv5_2, filters=512, l2_reg=l2_reg)], axis=3)

    conv6_1 = self.conv2d(concat1, filters=512, l2_reg=l2_reg)
    conv6_2 = self.conv2d(conv6_1, filters=512, l2_reg=l2_reg)
    concat2 = tf.concat([conv3_2, self.trans_conv(conv6_2, filters=256, l2_reg=l2_reg)], axis=3)

    conv7_1 = self.conv2d(concat2, filters=256, l2_reg=l2_reg)
    conv7_2 = self.conv2d(conv7_1, filters=256, l2_reg=l2_reg)
    concat3 = tf.concat([conv2_2, self.trans_conv(conv7_2, filters=128, l2_reg=l2_reg)], axis=3)

    conv8_1 = self.conv2d(concat3, filters=128, l2_reg=l2_reg)
    conv8_2 = self.conv2d(conv8_1, filters=128, l2_reg=l2_reg)
    concat4 = tf.concat([conv1_2, self.trans_conv(conv8_2, filters=64, l2_reg=l2_reg)], axis=3)

    conv9_1 = self.conv2d(concat4, filters=64, l2_reg=l2_reg)
    conv9_2 = self.conv2d(conv9_1, filters=64, l2_reg=l2_reg)
    outputs = self.conv2d(conv9_2, filters=self.classes, kernel_size=1, activation=None)

    return outputs

  def train(self, parser):
    """
    training operation
    argument of this function are given by functions in main.py

    Parameters
    ----------
      parser: 
        the paser that has some options
    """
    epoch = parser.epoch
    l2 = parser.l2
    batch_size = parser.batch_size
    train_val_rate = parser.train_rate

    output = self.UNet(l2_reg=l2, is_training=self.is_training)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.y, logits=output))
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_ops = tf.train.AdamOptimizer(parser.learning_rate).minimize(loss)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver(max_to_keep=100)
    all_train, all_val = main.load_data(self.IMAGE_DIR, self.SEGMENTED_DIR, n_class=2, train_val_rate=train_val_rate)
    with tf.Session() as sess:
      init.run()
      for e in range(epoch):
        data = main.generate_data(*all_train, batch_size)
        val_data = main.generate_data(*all_val, len(all_val[0]))
        for Input, Teacher in data:
          sess.run(train_ops, feed_dict={self.X: Input, self.y: Teacher, self.is_training: True})
          ls = loss.eval(feed_dict={self.X: Input, self.y: Teacher, self.is_training: None})
          for val_Input, val_Teacher in val_data:
            val_loss = loss.eval(feed_dict={self.X: val_Input, self.y: val_Teacher, self.is_training: None})

        print(f'epoch #{e + 1}, loss = {ls}, val loss = {val_loss}')
        if e % 100 == 0:
          saver.save(sess, f"./params/model_{e + 1}epochs.ckpt")

      self.validation(sess, output)

  def validation(self, sess, output):
    val_image = main.load_data(self.VALIDATION_DIR, '', n_class=2, train_val_rate=1)[0]
    data = main.generate_data(*val_image, batch_size=1)
    for Input, _ in data:
      result = sess.run(output, feed_dict={self.X: Input, self.is_training: None}) 
      break
    
    result = np.argmax(result[0], axis=2)
    ident = np.identity(3, dtype=np.int8)
    result = ident[result]*255

    plt.imshow((Input[0]*255).astype(np.int16))
    plt.imshow(result, alpha=0.2)
    plt.show()

長いけれども、やっていることは大したことなく、クラス内部に畳み込み層、transpose convolution層、プーリング層をメソッドとして定義しておいてUNetメソッドで本体を定義しています。

論文とこの実装は違っているところもあります。

trainメソッドで実際の学習を実行します。

学習データの作成

次に学習データをUNetに流し込む部分を書かなければいけませんが、その前に、学習データを作成する必要があります。画像のセグメンテーションにはlabelmeというフリーソフトを使いました。
github.com

GitHubに書いてあるインストール方法でインストールし、セグメンテーションしました。

f:id:pyhaya:20190504195754p:plainf:id:pyhaya:20190504195731p:plain
公開されているようなセマンティック・セグメンテーションのデータセットに比べると雑ですが、これで試してみます。

訓練してみる

上の要領で作ったデータを使って(76枚の画像データ)実際に訓練してみました。画像データが少ないのでそこまでうまくは行かないと思いますがこれでどの程度まで行くのか見てみます。

使ったパラメータは

  • 学習レート:0.0001
  • レーニングデータ:90%
  • バッチサイズ:20
  • L2正則化: 0.05
  • エポック数:100

のようになっています。結果を見るために適当なネコ画像を拾ってきて確認してみると下のようになっています。
f:id:pyhaya:20190512124258p:plain

猫の背中はよく認識できいますが、その他の部分はまだまだです。人間の視点から見ると猫といったら耳だろという感じですが、このモデルからしたら背中の方が認識しやすいようです。最もこれはこのモデルで判別するのが背景か猫の2択だけであるということも関係している可能性があります。つまり分類対象に犬などを入れたら状況は全然変わってくるでしょう(背中だけ見て犬猫を分類しろと言われたら難しい気がします)。

また、ここには載せていませんが、ロスを見ると完全に過学習しているような振る舞いをしておりやはりデータ数が足りないというのがネックになっています。今後はデータを増やすか水増しするかしていく予定です。

*1:畳み込みの逆の操作のようなもの(数学的な逆演算ではない)、日本語訳がわからない

Dockerを動かしてみる

Dockerを勉強し始めたので、学習記録としてまとめておきます。内容は基本的に
knowledge.sakura.ad.jp
のDocker入門で勉強したものを基礎としており、自分が引っかかったとことを付け足して書いています。

環境

Dockerのインストール

この記事を参考にしてインストールした。

qiita.com

Dockerイメージをダウンロードしてみる

まず最初に、軽量なWebサーバとして有名なNginxのDockerイメージを入れてみました。
hub.docker.com
まず、DockerHubのアカウントを作成しました。作成できたら、ここのIDとパスワードを使ってdockerでログインします。ログインをすることで、Docker Hubからイメージをpullすることができるようになります。(ログインしていない状態でpullしようとすると権限がありません、と怒られます。)

$ docker login
Authenticating with existing credentials...
Stored credentials invalid or expired
Login with your Docker ID to push and pull images from Docker Hub. If you don't have a Docker ID, head over to https://hub.docker.com to create one.
Username:      
Password: 

では、次にNginxのリポジトリをpullしてきます。

$ docker pull nginx
Using default tag: latest
latest: Pulling from library/nginx
27833a3ba0a5: Pull complete 
ea005e36e544: Pull complete 
d172c7f0578d: Pull complete 
Digest: sha256:e71b1bf4281f25533cf15e6e5f9be4dac74d2328152edf7ecde23abc54e16c1c
Status: Downloaded newer image for nginx:latest
$ docker images
REPOSITORY          TAG                 IMAGE ID            CREATED             SIZE
nginx               latest              27a188018e18        12 days ago         109MB

Nginxを動かしてみる

$ docker run -d --name nginx-container -p 8181:80 nginx

このようにすることで、Dockerイメージを動かすことができます。プロセスを確認すると、

$ docker ps -a
CONTAINER ID        IMAGE               COMMAND                  CREATED             STATUS              PORTS                  NAMES
21f11ea7bc28        nginx               "nginx -g 'daemon of…"   2 minutes ago       Up 2 minutes        0.0.0.0:8181->80/tcp   nginx-container

STATUSがUP、つまり現在動いていることがわかります。このとき、このプロセスが使っているのが、0.0.0.0:8181ポートであることがわかります。このポートにアクセスすれば「Welcome to nginx」のページが現れます。しかし、ファイアーウォールの設定でこのポートが閉じてしまっているときには接続がうまく行きませんので、開く必要があります。

$ sudo ufw status
状態:非アクティブ
$ sudo ufw enable    # ファイアウォールが有効になっていなかったので、有効にする
$ sudo ufw allow 8181
$ sudo ufw reload
$ sudo ufw status
状態: アクティブ

To                         Action      From
--                         ------      ----
8181                       ALLOW       Anywhere                           
8181 (v6)                  ALLOW       Anywhere (v6)  

これでアクセスすれば「Welcome to nginx」のページが現れます。

Dockerを停止する

$ docker stop nginx-container

Docker/Kubernetes 実践コンテナ開発入門

Docker/Kubernetes 実践コンテナ開発入門

ABC022-B Bumble Beeを解く

今回のエントリーはAtCoder Beginners Contestの過去問を扱います。今回扱うのは第22回のコンテストのB問題です。B問題にしては入力が大きく、計算量を意識するよい練習となります。

問題文

高橋君はマルハナバチ(Bumblebee)という種類のミツバチです。

今日も花の蜜を求めて異なるN個の花を訪れました。


高橋君がi番目に訪れた花の種類はA_iです。i 番目の花は、i>j かつi番目の花の種類とj番目の花の種類が同じになるようなjが存在すれば受粉します。

高橋君が訪れたN個の花の種類の情報が与えられるので、そのうちいくつの花が受粉したか求めてください。


なお、高橋君以外による受粉や自家受粉を考える必要はありません。


入力
入力は以下の形式で標準入力から与えられる

N
A1
A2
:
AN
  • 1行目には高橋君が訪れた花の個数を表す整数N(1\leq N\leq10^5)が与えられる。
  • 2行目からのN行のうちi行目にはi番目に高橋君が訪れた花の種類を表す整数A_i(1\leq A_i\leq10^5)が与えられる。


出力
受粉した花の個数を1行で出力せよ。出力の末尾にも改行を入れること。

この問題のリンクは
atcoder.jp

です。

考察

花を表す整数が「1, 2, 3, 2, 1」の場合には、4番目に現れる「2」と5番目に現れる「1」で受粉が起こります。なので、出力は2となります。

戦略1

この実験からすぐ思いつくのは、次のような戦略です。

入力を一つずつ受け取って、その番号がすでに一度でも出ていれば受粉する。これを1つずつ数えていけばよい

入力の個数はNなので、入力を一つ一つ受け取るのにNに比例した時間がかかります。加えて、入力一つ一つに対して既に出てきているかを調べる必要がありますが、これは最大でNのオーダーだとすると、全部で N^2のオーダーとなります。

N \leq 10^5なのでこれでは 10^{10}の時間がかかってしまうので、これでは制限時間に引っ掛かります。つまり、この戦略で行くには、数字がすでに出ているかを調べる操作を\log N以下の時間で済ませる方法を考える必要があります。

戦略2

もう1つ考えられる戦略が、計算をいくつかのパートに分割することです。具体的には、

  • 入力値を全部配列に入れる
  • 配列をソートする
  • それぞれの数字が何回出てきているか数える

という計算に分割します。配列をソートする操作はいろいろあり、ソートのための関数が言語に備え付けられている場合がほとんどです(計算量は速いものはN\log N)。ここでのネックはソートした配列にそれぞれの数字が何回出てきているかを調べる部分です。

その前に、「それぞれの数字が何回出てきているか数える」ことでなぜうまくいくか簡単に説明します。例えば、配列に「1」が2回出てきたとします。その時、2つの「1」のうち、片方は最初にもう片方はそのあとに出てきたはずです。後に出てきた「1」では受粉が起こるので、各数字が出てきている回数から1を引いた数をすべての数字に対して足せば、求める答えが出てくるはずです。

数えるのは、配列を最初から1回だけ見ていけばよいので(ソートされているため)、この部分の計算量は O(N)で済みそうです。

解答例(C++14)

戦略1

この戦略では、すでに出てきた数字をどのような形で保存するかが重要になります。vectorに入れてしまうと、新しい入力を受け取って、それがすでに出てきたか探索するためにO(N)の時間がかかってしまうので、上で述べたように間に合いません。

しかしmapを使えば、探索は O(\log N)なので間に合います。

#include <bits/stdc++.h>

using namespace std;

int main() {
	int n; cin >> n;
	map<int, int> m;
	int res = 0;    // 求める答え

	for (int i = 0; i < n; i++) {
		int a; cin >> a;
		if (m.find(a) != m.end()) {    // すでに出てきていればresを1増やす
			res++;
		}
		else {
			m[a] = 1;    // まだ出てきていなければmapに加える
		}
	}

	cout << res << endl;
}

戦略2

配列をソートしたあと、ループを回して各数字が何回出てきているか数えてもよいのですが、C++にはuniqueという、配列の重複要素を除いた要素を先頭に集めてくれる便利な関数があるので、これを使います。

#include <bits/stdc++.h>

using namespace std;

int main()
{
  int n;
  cin >> n;
  vector<int> v;

  for (int i = 0; i < n; i++)
  {
    int a;
    cin >> a;
    v.push_back(a);
  }

  sort(v.begin(), v.end());
  auto it = unique(v.begin(), v.end()); 
  v.erase(it, v.end());    // itから先は一度出た数字が並んでいるので、消去する。

  cout << n - v.size() << endl;
}

解答例(Python 3)

おまけとして、上のそれぞれの解法をPythonで実装したものも載せておきます。

戦略1

n = int(input())
m = dict()

res = 0
for i in range(n):
  a = int(input())
  if (a in m.keys()):
    res += 1
  else:
    m[a] = 1

print(res)

戦略2

n = int(input())
v = []
for i in range(n):
  a = int(input())
  v.append(a)

print(n - len(set(v)))

プログラミングコンテストチャレンジブック [第2版] ?問題解決のアルゴリズム活用力とコーディングテクニックを鍛える?

プログラミングコンテストチャレンジブック [第2版] ?問題解決のアルゴリズム活用力とコーディングテクニックを鍛える?

Tensorflowで犬猫画像分類する

最近Tensorflowを勉強していて、試しに定番の(?)犬猫の画像分類をしてみました。僕がやったことをまとめると

  • CNN
  • tf.kerasは使わない
  • TFRecordにデータを保存してそこからデータを引っ張り出してくる
  • もちろんBatch

こんな感じのことを書きます。なのでこの記事の位置づけは、画像解析手法を書くというよりかはTensorflowの使い方みたいな感じです。

環境

使ったデータ

データの出処

今回使ったデータはKaggleからダウンロードしました。
www.kaggle.com

データのディレクトリ関係

今回の分析でのディレクトリ構造は下のようになっています

.
├── log
└── training_set
|    ├── cats
|    └── dogs
└── tf_cnn.ipynb

catsディレクトリとdogsディレクトリにはそれぞれ約4000枚のJPEG画像が保存されています。中身を見てみると、必ずしも犬・猫のみが写ったものだけではなく、人も一緒に写っていたり、イラストだったりします。

TFRecordにデータを保存

今回はデータを一度TFRecordにバイナリ形式で保存します。まずは各画像ファイルへのパスとラベルをリストの中に収納します。ラベルは猫が1、犬が0となっています。

import numpy as np
import tensorflow as tf


cat_dir = './training_set/cats/'
dog_dir = './training_set/dogs/'

image_paths = []
labels = []

for fname in os.listdir(cat_dir):
    if '.jpg' in fname:
        image_paths.append(cat_dir + fname)
        labels.append(1)
        
for fname in os.listdir(dog_dir):
    if '.jpg' in fname:
        image_paths.append(dog_dir + fname)
        labels.append(0)

# シャッフルする
shuffle_ind = np.random.permutation(len(labels))
image_paths = np.array(image_paths)[shuffle_ind]
labels = np.array(labels)[shuffle_ind]

リストの後ろから1000個のファイルをテストデータとして切り分けてそれぞれ別ファイルに保存します。

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


from PIL import Image

# トレーニングデータの保存
with tf.python_io.TFRecordWriter('training_data.tfrecords') as writer:
    for fname, label in zip(image_paths[:-1000], labels[:-1000]):
        image = Image.open(fname)
        image_np = np.array(image)
        image_shape = image_np.shape
        image = open(fname, 'rb').read()

        feature = {
            'height' : _int64_feature(image_shape[0]),
            'width' : _int64_feature(image_shape[1]),
            'channel' : _int64_feature(image_shape[2]),
            'image_raw' : _bytes_feature(image),    # 画像はバイトとして保存する
            'label' : _int64_feature(label)
        }
        tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(tf_example.SerializeToString())

# テストデータの保存
with tf.python_io.TFRecordWriter('test_data.tfrecords') as writer:
    for fname, label in zip(image_paths[-1000:], labels[-1000:]):
        image = Image.open(fname)
        image_np = np.array(image)
        image_shape = image_np.shape
        image = open(fname, 'rb').read()

        feature = {
            'height' : _int64_feature(image_shape[0]),
            'width' : _int64_feature(image_shape[1]),
            'channel' : _int64_feature(image_shape[2]),
            'image_raw' : _bytes_feature(image),
            'label' : _int64_feature(label)
        }
        tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(tf_example.SerializeToString())

これでデータの保存はできました。

CNNによるモデルの構築

次に、画像の分類に用いるモデルをCNNで構築します。

tf.reset_default_graph()

X = tf.placeholder(tf.float32, shape=[None, 150, 150, 3])
y = tf.placeholder(tf.int32, shape=[None])

with tf.name_scope('layer1'):
    conv1 = tf.layers.conv2d(X, filters=32, kernel_size=4, strides=1, activation=tf.nn.relu, name='conv1')
    pool1 = tf.nn.max_pool(conv1, ksize=[1,2,2,1], strides=[1,2,2,1], padding='VALID', name='pool1')
    
with tf.name_scope('layer2'):
    conv2 = tf.layers.conv2d(pool1, filters=64, kernel_size=3, strides=1, activation=tf.nn.relu, name='conv2')
    pool2 = tf.nn.max_pool(conv2, ksize=[1,2,2,1], strides=[1,2,2,1], padding='VALID', name='pool2')
    
with tf.name_scope('layer3'):
    conv3 = tf.layers.conv2d(pool2, filters=128, kernel_size=3, strides=1, activation=tf.nn.relu, name='conv3')
    pool3 = tf.nn.max_pool(conv3, ksize=[1,2,2,1], strides=[1,2,2,1], padding='VALID', name='pool3')
    
with tf.name_scope('dense'):
    flatten = tf.reshape(pool3, shape=[-1, 32768], name='flatten')
    dense1 = tf.layers.dense(flatten, 512, activation=tf.nn.relu, name='dense1')
    dense2 = tf.layers.dense(dense1, 2, activation=None, name='dense2')
    output = tf.nn.softmax(dense2, name='output')
    
with tf.name_scope('train'):
    xentropy = tf.losses.sparse_softmax_cross_entropy(logits=dense2, labels=y)
    loss = tf.reduce_mean(xentropy)
    optimizer = tf.train.AdamOptimizer()
    training_op = optimizer.minimize(loss)

with tf.name_scope('eval'):
    correct = tf.nn.in_top_k(dense2, y, 1)
    acc = tf.reduce_mean(tf.cast(correct, tf.float32))
    
with tf.name_scope('save'):
    train_acc = tf.summary.scalar('train_acc', acc)
    valid_acc = tf.summary.scalar('valid_acc', acc)
    file_writer = tf.summary.FileWriter('./log/190401/', tf.get_default_graph())
    saver = tf.train.Saver()

入力画像の形状は(150, 150, 3)で、カラー画像です。3つの畳み込み層と3つのプーリング層を重ね、最後に全結合層で長さ2のベクトルを出力しています。出力ベクトルのそれぞれの要素は、画像が犬である確率と猫である確率をそれぞれ表しています。

訓練はAdamを用いて行います。そして訓練の途中結果とパラメータを保存するためにFileWriterとSaverを用意します。

訓練する

では実際に訓練をします。まず、TFRecordからデータを取り出すための準備をしておきます。

image_feature_description = {
    'height' : tf.FixedLenFeature([], tf.int64),
    'width' : tf.FixedLenFeature([], tf.int64),
    'channel' : tf.FixedLenFeature([], tf.int64),
    'image_raw' : tf.FixedLenFeature([], tf.string),
    'label' : tf.FixedLenFeature([], tf.int64),
}

def _parse_fun(example_proto):
    feature = tf.parse_single_example(example_proto, image_feature_description)
    feature['image_raw'] = tf.image.decode_jpeg(feature['image_raw'])
    feature['image_raw'] = tf.cast(feature['image_raw'], tf.float32) / 255.0    #floatにキャストしてから255で割って正規化
    feature['image_raw'] = tf.image.resize_images(feature['image_raw'], (150, 150))    #150x150にリサイズ
    
    feature['label'] = tf.cast(feature['label'], tf.int32)
    
    return feature

では実際に訓練します。

epochs = 31
batch_size = 500

with tf.Session() as sess:
    raw_image_dataset = tf.data.TFRecordDataset('training_data.tfrecords')
    test_dataset = tf.data.TFRecordDataset('test_data.tfrecords')
    
    parsed_image_dataset = raw_image_dataset.map(_parse_fun)
    test_dataset = test_dataset.map(_parse_fun).batch(100)
    batched_dataset = parsed_image_dataset.batch(batch_size)
    
    init = tf.global_variables_initializer()
    init.run()

    for epoch in range(epochs):
        iterator = batched_dataset.make_one_shot_iterator()
        test_iter = test_dataset.make_one_shot_iterator()
        while True:
            try:
                batched = iterator.get_next()
                batched_eval = sess.run(batched)
                X_batch = batched_eval['image_raw']
                y_batch = batched_eval['label']
                sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
            except tf.errors.OutOfRangeError:
                break
        
        if epoch % 5 == 0:
            print(f"finished epoch #{epoch}")
            test_data = test_iter.get_next()
            test_data_eval = sess.run(test_data)
            X_test = test_data_eval['image_raw']
            y_test = test_data_eval['label']
            train_acc_str = train_acc.eval(feed_dict={X: X_batch, y: y_batch})
            valid_acc_str = valid_acc.eval(feed_dict={X: X_test, y: y_test})
            file_writer.add_summary(train_acc_str, epoch)
            file_writer.add_summary(valid_acc_str, epoch)
            save_path = saver.save(sess, './log/190401/model_ckpt_{}.ckpt'.format(epoch))

file_writer.close()

訓練ではバッチサイズを500として30エポック訓練します。訓練では、イテレータを作っておいてバッチサイズずつデータを取り出して訓練します。イテレータで次のデータを取り出せなくなったらtf.error.OutOfRangeErrorを送出するので、それを受け取ってwhileループを抜けます。エポック数が5の倍数のときに訓練データとテストデータでの精度とモデルの重みパラメータを保存します。

訓練結果をTensorboardで見てみる

訓練したら、その結果をTensorboardで見てみます。シェルから

tensorboard --logdir=./log/190401/

と入力すると、アドレスが出てくるので、そこにアクセスすると、下のような画面が出てきます。

f:id:pyhaya:20190410212610p:plain

左側が訓練データの精度で右側がテストデータの精度です。これを見ると訓練データでは精度がほとんど1になっているのに対してテストデータでは精度が0.68にしかなっておらず、過学習していることがわかります。これを解消するには、モデルのパラメータを減らす方法やドロップアウトなどの正則化をかける方法、そして画像の水増しなどの方法があります。これらの方法についてはまた別の場所で書きます。

scikit-learnとTensorFlowによる実践機械学習

scikit-learnとTensorFlowによる実践機械学習