pyhaya’s diary

プログラミング、特にPythonについての記事を書きます。Djangoや機械学習などホットな話題をわかりやすく説明していきたいと思います。

TensorFlowでUNetを構築する

この記事では、Tensorflowを使ってUNetを構築し、最終的には画像から猫を認識するように訓練するやり方を紹介します。(この記事で紹介しているコードはTensorflow2系では動作しません。2系でも動くコードは別記事にしたので良かったら読んでください
Tensorflowを使ってUNetを試す Version 2 - pyhaya’s diary

環境

  • Ubuntu 18.04
  • Python 3.6.8
  • Tensorflow 1.12.0
  • GeForce RTX2080

UNetとは何か

セマンティック・セグメンテーション(semantic segmentation)

UNetというのは機械学習モデルの名前で、セマンティク・セグメンテーションを行うために使われます。セマンティック・セグメンテーションというのは、画像をピクセル単位でいくつかのクラスに分類する画像処理の手法です。例えば下のような人と馬の画像を処理すると右のように人(薄ピンク)と馬(ピンク)、そして背景(黒)をピクセル単位で分類します。

f:id:pyhaya:20190502214459p:plain
https://nicolovaligi.com/deep-learning-models-semantic-segmentation.html から引用
似たようなタスクとして、画像から物体を認識する場合に物体があると考えられる領域を長方形で認識して表示するものがあります(下図)。これと比較するとセマンティック・セグメンテーションではより高度な処理を行っていることがわかります。
f:id:pyhaya:20190502215611j:plain
https://gigazine.net/news/20140920-revolutionary-machine-vision/ から引用

UNetの構造

UNetは2015年にドイツの大学の研究グループが発表したネットワークです。名前の由来はそのネットワークの形状で、U字型をしているためにこのように呼ばれています。

f:id:pyhaya:20190502225111p:plain
論文から引用

arxiv.org

UNetは基本的には畳み込みニューラルネットワーク(CNN)で、その特徴は大きく分けて次のような2種類の処理に分けて考えることができます。

1. ダウンサンプリング
畳み込みでfeature mapを倍にしながらmax poolingで画像サイズを小さくしていく
2. アップサンプリング
transpose convolution*1で画像サイズをもとに戻していく。このときダウンサンプリング中のデータを加えながら処理を進めていく(図の灰色矢印)

論文では入力画像は大きさが572x572でチャネル数が1なのでグレースケールの画像になっています。そして最終的には出力が388x388でチャネル数が2になっています。これは少し説明が必要で、論文ではゼロパディングをしていないので、出力が入力よりも小さくなります。まず、この論文では一つの画像を一度にセグメンテーションするのではなく、いくつかの388x388の領域に分割し、最後に出力結果をつなぎ合わせて最終的なセグメンテーション結果とします。そして388xx388の大きさの領域をセグメンテーションするためにその周りを含めて572x572の領域をネットワークに入れます。処理したい画像領域が元の画像の端で、572x572に拡大できないときには、足りない部分を元の画像の端を鏡面とした鏡映操作をして補います。

f:id:pyhaya:20190502231618p:plain
論文より引用

出力のチャネル数2は判別するクラスの数によります。この場合には判別するクラスが2つとなっているため出力のチャネル数が2になっています。この2つのクラスをクラス1, クラス2と書くことにすると、第一チャネルはクラス1に分類される部分だけ1でほかはゼロ、そして第二チャネルはクラス2に分類されるピクセルだけ1でほかはゼロというようなOne-Hot表現になっています。

UNetのPython(Tensorflow)での実装

では、Tensorflowを使ってUNetを実装してみます。実装では
github.com
を参考にさせていただきました。

UNet本体

UNetの本体はTensorflowで書くと下のようにかけます。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import main

class UNet:
  def __init__(self, classes):
    self.IMAGE_DIR = './dataset/raw_images'
    self.SEGMENTED_DIR = './dataset/segmented_images'
    self.VALIDATION_DIR = './dataset/validation'
    self.classes = classes
    self.X = tf.placeholder(tf.float32, [None, 128, 128, 3]) 
    self.y = tf.placeholder(tf.int16, [None, 128, 128, self.classes])
    self.is_training = tf.placeholder(tf.bool)

  @staticmethod
  def conv2d(
    inputs, filters, kernel_size=3, activation=tf.nn.relu, l2_reg=None, 
    momentum=0.9, epsilon=0.001, is_training=False,
    ):
    """
    convolutional layer. If the l2_reg is a float number, L2 regularization is imposed.
    
    Parameters
    ----------
      inputs: tf.Tensor
      filters: Non-zero positive integer
        The number of the filter 
      activation: 
        The activation function. The default is tf.nn.relu
      l2_reg: None or float
        The strengthen of the L2 regularization
      is_training: tf.bool
        The default is False. If True, the batch normalization layer is added.
      momentum: float
        The hyper parameter of the batch normalization layer
      epsilon: float
        The hyper parameter of the batch normalization layer

    Returns
    -------
      layer: tf.Tensor
    """
    regularizer = tf.contrib.layers.l2_regularizer(scale=l2_reg) if l2_reg is not None else None
    layer = tf.layers.conv2d(
      inputs=inputs,
      filters=filters,
      kernel_size=kernel_size,
      padding='SAME',
      activation=activation,
      kernel_regularizer=regularizer
    )

    if is_training is not None:
      layer = tf.layers.batch_normalization(
        inputs=layer,
        axis=-1,
        momentum=momentum,
        epsilon=epsilon,
        center=True,
        scale=True,
        training=is_training
      )

    return layer

  @staticmethod
  def trans_conv(inputs, filters, activation=tf.nn.relu, kernel_size=2, strides=2, l2_reg=None):
    """
    transposed convolution layer.

    Parameters
    ---------- 
      inputs: tf.Tensor
      filters: int 
        the number of the filter
      activation: 
        the activation function. The default function is the ReLu.
      kernel_size: int
        the kernel size. Default = 2
      strides: int
        strides. Default = 2
      l2_reg: None or float 
        the strengthen of the L2 regularization.

    Returns
    -------
      layer: tf.Tensor
    """
    regularizer = tf.contrib.layers.l2_regularizer(scale=l2_reg) if l2_reg is not None else None

    layer = tf.layers.conv2d_transpose(
      inputs=inputs,
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      kernel_regularizer=regularizer
    )

    return layer

  @staticmethod
  def pooling(inputs):
    return tf.layers.max_pooling2d(inputs=inputs, pool_size=2, strides=2)


  def UNet(self, is_training, l2_reg=None):
    """
    UNet structure.

    Parameters
    ----------
      l2_reg: None or float
        The strengthen of the L2 regularization.
      is_training: tf.bool
        Whether the session is for training or validation.

    Returns
    -------
      outputs: tf.Tensor
    """
    conv1_1 = self.conv2d(self.X, filters=64, l2_reg=l2_reg, is_training=is_training)
    conv1_2 = self.conv2d(conv1_1, filters=64, l2_reg=l2_reg, is_training=is_training)
    pool1 = self.pooling(conv1_2)

    conv2_1 = self.conv2d(pool1, filters=128, l2_reg=l2_reg, is_training=is_training)
    conv2_2 = self.conv2d(conv2_1, filters=128, l2_reg=l2_reg, is_training=is_training)
    pool2 = self.pooling(conv2_2)

    conv3_1 = self.conv2d(pool2, filters=256, l2_reg=l2_reg, is_training=is_training)
    conv3_2 = self.conv2d(conv3_1, filters=256, l2_reg=l2_reg, is_training=is_training)
    pool3 = self.pooling(conv3_2)

    conv4_1 = self.conv2d(pool3, filters=512, l2_reg=l2_reg, is_training=is_training)
    conv4_2 = self.conv2d(conv4_1, filters=512, l2_reg=l2_reg, is_training=is_training)
    pool4 = self.pooling(conv4_2)

    conv5_1 = self.conv2d(pool4, filters=1024, l2_reg=l2_reg)
    conv5_2 = self.conv2d(conv5_1, filters=1024, l2_reg=l2_reg)
    concat1 = tf.concat([conv4_2, self.trans_conv(conv5_2, filters=512, l2_reg=l2_reg)], axis=3)

    conv6_1 = self.conv2d(concat1, filters=512, l2_reg=l2_reg)
    conv6_2 = self.conv2d(conv6_1, filters=512, l2_reg=l2_reg)
    concat2 = tf.concat([conv3_2, self.trans_conv(conv6_2, filters=256, l2_reg=l2_reg)], axis=3)

    conv7_1 = self.conv2d(concat2, filters=256, l2_reg=l2_reg)
    conv7_2 = self.conv2d(conv7_1, filters=256, l2_reg=l2_reg)
    concat3 = tf.concat([conv2_2, self.trans_conv(conv7_2, filters=128, l2_reg=l2_reg)], axis=3)

    conv8_1 = self.conv2d(concat3, filters=128, l2_reg=l2_reg)
    conv8_2 = self.conv2d(conv8_1, filters=128, l2_reg=l2_reg)
    concat4 = tf.concat([conv1_2, self.trans_conv(conv8_2, filters=64, l2_reg=l2_reg)], axis=3)

    conv9_1 = self.conv2d(concat4, filters=64, l2_reg=l2_reg)
    conv9_2 = self.conv2d(conv9_1, filters=64, l2_reg=l2_reg)
    outputs = self.conv2d(conv9_2, filters=self.classes, kernel_size=1, activation=None)

    return outputs

  def train(self, parser):
    """
    training operation
    argument of this function are given by functions in main.py

    Parameters
    ----------
      parser: 
        the paser that has some options
    """
    epoch = parser.epoch
    l2 = parser.l2
    batch_size = parser.batch_size
    train_val_rate = parser.train_rate

    output = self.UNet(l2_reg=l2, is_training=self.is_training)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.y, logits=output))
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_ops = tf.train.AdamOptimizer(parser.learning_rate).minimize(loss)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver(max_to_keep=100)
    all_train, all_val = main.load_data(self.IMAGE_DIR, self.SEGMENTED_DIR, n_class=2, train_val_rate=train_val_rate)
    with tf.Session() as sess:
      init.run()
      for e in range(epoch):
        data = main.generate_data(*all_train, batch_size)
        val_data = main.generate_data(*all_val, len(all_val[0]))
        for Input, Teacher in data:
          sess.run(train_ops, feed_dict={self.X: Input, self.y: Teacher, self.is_training: True})
          ls = loss.eval(feed_dict={self.X: Input, self.y: Teacher, self.is_training: None})
          for val_Input, val_Teacher in val_data:
            val_loss = loss.eval(feed_dict={self.X: val_Input, self.y: val_Teacher, self.is_training: None})

        print(f'epoch #{e + 1}, loss = {ls}, val loss = {val_loss}')
        if e % 100 == 0:
          saver.save(sess, f"./params/model_{e + 1}epochs.ckpt")

      self.validation(sess, output)

  def validation(self, sess, output):
    val_image = main.load_data(self.VALIDATION_DIR, '', n_class=2, train_val_rate=1)[0]
    data = main.generate_data(*val_image, batch_size=1)
    for Input, _ in data:
      result = sess.run(output, feed_dict={self.X: Input, self.is_training: None}) 
      break
    
    result = np.argmax(result[0], axis=2)
    ident = np.identity(3, dtype=np.int8)
    result = ident[result]*255

    plt.imshow((Input[0]*255).astype(np.int16))
    plt.imshow(result, alpha=0.2)
    plt.show()

長いけれども、やっていることは大したことなく、クラス内部に畳み込み層、transpose convolution層、プーリング層をメソッドとして定義しておいてUNetメソッドで本体を定義しています。

論文とこの実装は違っているところもあります。

trainメソッドで実際の学習を実行します。

学習データの作成

次に学習データをUNetに流し込む部分を書かなければいけませんが、その前に、学習データを作成する必要があります。画像のセグメンテーションにはlabelmeというフリーソフトを使いました。
github.com

GitHubに書いてあるインストール方法でインストールし、セグメンテーションしました。

f:id:pyhaya:20190504195754p:plainf:id:pyhaya:20190504195731p:plain
公開されているようなセマンティック・セグメンテーションのデータセットに比べると雑ですが、これで試してみます。

訓練してみる

上の要領で作ったデータを使って(76枚の画像データ)実際に訓練してみました。画像データが少ないのでそこまでうまくは行かないと思いますがこれでどの程度まで行くのか見てみます。

使ったパラメータは

  • 学習レート:0.0001
  • レーニングデータ:90%
  • バッチサイズ:20
  • L2正則化: 0.05
  • エポック数:100

のようになっています。結果を見るために適当なネコ画像を拾ってきて確認してみると下のようになっています。
f:id:pyhaya:20190512124258p:plain

猫の背中はよく認識できいますが、その他の部分はまだまだです。人間の視点から見ると猫といったら耳だろという感じですが、このモデルからしたら背中の方が認識しやすいようです。最もこれはこのモデルで判別するのが背景か猫の2択だけであるということも関係している可能性があります。つまり分類対象に犬などを入れたら状況は全然変わってくるでしょう(背中だけ見て犬猫を分類しろと言われたら難しい気がします)。

また、ここには載せていませんが、ロスを見ると完全に過学習しているような振る舞いをしておりやはりデータ数が足りないというのがネックになっています。今後はデータを増やすか水増しするかしていく予定です。

*1:畳み込みの逆の操作のようなもの(数学的な逆演算ではない)、日本語訳がわからない