pyhaya’s diary

機械学習系の記事をメインで書きます

Tensorflowを使ってUNetを試す Version 2

UNetを構築してみようVersion 2です。

Version2?

実は過去にTensorflowでUNetを書いた!という記事を書いています。。。
pyhaya.hatenablog.com

なぜ同じ内容をもう一回書くのかというと、Tensorflowのバージョンアップに伴って上のコードが割と根本的なところで動かなくなることが確定しているから、全部書き直そう!となったからです。(ツライ)

Tensorflowの変更点

Tensorflowは現在、2.0のベータ版がGithubで公開されています。メジャーバージョンが変わるのでそれなりに大きな変更であることは覚悟していたのですが、Googleは私の想像を超えてきました。

以下のサイトが変更点を詳しく紹介してくださっています。

qiita.com

ざっくりまとめると

  • tf.placeholderなくなる
  • tf.Session()なくなる
  • tf.global_variable_initializer()なくなる


、、、はい。2系が主流になると過去記事の私のコードはすべての行でエラーが出るんじゃないかというレベルで壊れます。

なのでこれから主流になるtf.kerasを使ってコードを書き直します。

モデル

UNetがどんなネットワークかは前の記事を見ていただくとして、モデルを書きます。

from typing import Optional
import argparse
import tensorflow as tf


class conv_set:
    def __init__(self, filters: int):
        self.filters = filters

    def __call__(self, inputs: tf.Tensor) -> tf.Tensor:
        y = tf.keras.layers.Conv2D(
            self.filters, kernel_size=3, padding="SAME", activation="relu"
        )(inputs)
        y = tf.keras.layers.Conv2D(
            self.filters, kernel_size=3, padding="SAME", activation="relu"
        )(y)
        y = tf.keras.layers.BatchNormalization()(y)
        return y


class upsampling:
    def __init__(self, filters: int, cut: Optional[int] = 0):
        self.filters = filters
        self.cut = cut

    def __call__(self, inputs: tf.Tensor) -> tf.Tensor:
        upconv = tf.keras.layers.Conv2DTranspose(
            self.filters, kernel_size=2, strides=2
        )(inputs[0])

        conv_crop = tf.keras.layers.Cropping2D(self.cut)(inputs[1])
        concat = tf.keras.layers.concatenate([conv_crop, upconv])

        return concat


def UNet(args: "argparse.Namespace") -> tf.keras.Model:
    n_classes: int = args.n_classes
    decay: float = args.l2

    x = tf.keras.Input(shape=(224, 224, 3))

    # down sampling
    conv1 = conv_set(64)(x)
    max_pool1 = tf.keras.layers.MaxPool2D()(conv1)
    conv2 = conv_set(128)(max_pool1)
    max_pool2 = tf.keras.layers.MaxPool2D()(conv2)
    conv3 = conv_set(256)(max_pool2)
    max_pool3 = tf.keras.layers.MaxPool2D()(conv3)
    conv4 = conv_set(512)(max_pool3)
    max_pool4 = tf.keras.layers.MaxPool2D()(conv4)
    conv5 = conv_set(1024)(max_pool4)

    # up sampling
    concat1 = upsampling(512)([conv5, conv4])
    conv6 = conv_set(512)(concat1)
    concat2 = upsampling(256)([conv6, conv3])
    conv7 = conv_set(256)(concat2)
    concat3 = upsampling(128)([conv7, conv2])
    conv8 = conv_set(128)(concat3)
    concat4 = upsampling(64)([conv8, conv1])
    conv9 = conv_set(64)(concat4)

    output = tf.keras.layers.Conv2D(filters=n_classes, kernel_size=1)(conv9)
    output = tf.keras.layers.Softmax()(output)

    model = tf.keras.Model(inputs=x, outputs=output)
    for layer in model.layers:
        if "kernel_regularizer" in layer.__dict__:
            layer.kernel_regularizer = tf.keras.regularizers.l2(decay)

    if args.weights != "":
        model.load_weights(args.weights)

    return model

生のTensorflowを使ったときと比べて、かなりスッキリとモデルを記述できていることがわかります。UNetには同じような繰り返しが存在している(畳み込みx2 + BNやupsampling)のでそれはクラスを定義してまとめてしまっています。これはカスタムレイヤーを定義しているわけではなく、既存のレイヤーを組み合わせているだけなので、tf.keras.layers.Layerを継承してcallをoverrideすることはせず、単にクラスを作ってcallableになるように__call__メソッドを定義していることに注意してください。

逆にLayerクラスを継承する形で定義してしまうと、訓練時にこのクラスが持つ重みパラメータが訓練可能パラメータとして認識されません。

訓練した結果

動作確認のために、試しにネコ画像で訓練してみました。

  • learning rate: 0.001
  • batch size: 4
  • epoch: 100

訓練後に適当にネコ画像をネットから取ってきて入れてみたら下のようになり、そこそこ行けている感じです。

f:id:pyhaya:20190818193242p:plain

どれくらいうまく行けているかはまだ定量的には測れていません(validation data 増やすのキツイ、、、)。

リンク(GitHub)

今回紹介したモデルはGitHubに上げています。

github.com