pyhaya’s diary

機械学習系の記事をメインで書きます

SRGANで画像の高解像度化

今回はGANを使った初期の画像の高解像度化モデルであるSRGANを実装してみたので紹介したいと思います。

実行環境

今回のモデルは以下のような環境で実装しています。

ソースコードは以下に公開しています。
github.com


検証に用いたデータセットはDIV2Kで以下からダウンロードしました。
data.vision.ee.ethz.ch

訓練に用いるために、各画像から20枚ずつ32x32のランダムクロップを行いました。

SRGANとは

SRGANとは、その名前にもあるようにGAN(敵対的生成ネットワーク)を使って、入力画像に対して解像度が上がった画像を生成しようという目的で作られたモデルです。

f:id:pyhaya:20210822222355p:plain
https://arxiv.org/pdf/1609.04802.pdf から引用

モデルの実装

モデルはTensorflowのKeras APIを使って実装します。

import tensorflow as tf
from tensorflow.keras.layers import (
    Conv2D,
    Dense,
    PReLU,
    BatchNormalization,
    LeakyReLU,
    Flatten,
)
from tensorflow.keras import Sequential, Model


class BResidualBlock(Model):
    """Basic residual block of the SRGAN generator.

    Two 3x3/stride-1 convolutions with batch normalization, a PReLU
    between them, and an identity skip connection added at the end
    (no activation after the addition, as in the SRGAN paper).
    """

    def __init__(self):
        super(BResidualBlock, self).__init__()

        # First conv -> BN -> PReLU stage.  PReLU parameters are shared
        # across the spatial axes so only one slope per channel is learned.
        self.conv = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")
        self.bn = BatchNormalization(momentum=0.8)
        self.prelu = PReLU(shared_axes=[1, 2])

        # Second conv -> BN stage; no activation before the skip-add.
        self.conv2 = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")
        self.bn2 = BatchNormalization(momentum=0.8)

    def call(self, input_tensor, training=True):
        """Run the block; `training` selects the BatchNorm mode."""
        out = self.prelu(self.bn(self.conv(input_tensor), training=training))
        out = self.bn2(self.conv2(out), training=training)
        # Identity (residual) connection.
        return out + input_tensor


class ResidualBlock(Model):
    """Residual trunk of the generator: five BResidualBlocks followed by a
    conv + BN, with a long skip connection from the trunk input.
    """

    def __init__(self):
        super(ResidualBlock, self).__init__()

        self.residual1 = BResidualBlock()
        self.residual2 = BResidualBlock()
        self.residual3 = BResidualBlock()
        self.residual4 = BResidualBlock()
        self.residual5 = BResidualBlock()

        self.conv = Conv2D(filters=64, kernel_size=3, padding="same")
        self.bn = BatchNormalization(momentum=0.8)

    def call(self, input_tensor, training=True):
        """Run the trunk; `training` selects the BatchNorm mode.

        Bug fix: `training` is now propagated to the first residual block
        and to the final BatchNormalization — previously those two calls
        omitted it, so their BN layers could run in the wrong mode while
        blocks 2-5 honored the flag.
        """
        x = self.residual1(input_tensor, training=training)
        x = self.residual2(x, training=training)
        x = self.residual3(x, training=training)
        x = self.residual4(x, training=training)
        x = self.residual5(x, training=training)

        x = self.conv(x)
        x = self.bn(x, training=training)

        # Long skip connection over the whole trunk.
        x += input_tensor

        return x


class DiscriminatorBlock(Model):
    """One down-sampling stage of the SRGAN discriminator.

    Conv(stride 1) -> BN -> LeakyReLU, then Conv(stride 2) -> BN ->
    LeakyReLU; the second convolution halves the spatial resolution.
    """

    def __init__(self, filters=128):
        super(DiscriminatorBlock, self).__init__()
        # Number of output channels for both convolutions in this stage.
        self.filters = filters

        self.conv1 = Conv2D(filters=filters, kernel_size=3, strides=1, padding="same")
        self.bn1 = BatchNormalization(momentum=0.8)
        self.lrelu1 = LeakyReLU(alpha=0.2)
        self.conv2 = Conv2D(filters=filters, kernel_size=3, strides=2, padding="same")
        self.bn2 = BatchNormalization(momentum=0.8)
        self.lrelu2 = LeakyReLU(alpha=0.2)

    def call(self, input_tensor, training=True):
        """Run the stage; `training` selects the BatchNorm mode."""
        hidden = self.lrelu1(self.bn1(self.conv1(input_tensor), training=training))
        return self.lrelu2(self.bn2(self.conv2(hidden), training=training))


class PixelShuffler(tf.keras.layers.Layer):
    """Pixel-shuffle (depth-to-space) upsampling layer.

    Rearranges channel blocks into spatial positions via
    ``tf.nn.depth_to_space``, multiplying height and width by ``scale``
    and dividing the channel count by ``scale ** 2``.

    Generalization: the upscale factor, previously hard-coded to 2, is
    now a constructor parameter with a default of 2 so existing callers
    are unaffected.
    """

    def __init__(self, scale: int = 2):
        super(PixelShuffler, self).__init__()
        # Spatial upscale factor used by depth_to_space.
        self.scale = scale

    def call(self, input_tensor):
        """Return `input_tensor` upsampled by the configured factor."""
        return tf.nn.depth_to_space(input_tensor, self.scale)


def make_generator():
    """Build the SRGAN generator (SRResNet).

    Layout: a 9x9 feature-extraction conv + PReLU, the residual trunk,
    two identical pixel-shuffle upsampling stages (each doubling the
    resolution, 4x total), and a final 9x9 conv producing a 3-channel
    image.
    """
    model = Sequential()

    # Initial feature extraction.
    model.add(Conv2D(filters=64, kernel_size=9, padding="same"))
    model.add(PReLU(shared_axes=[1, 2]))

    # Residual trunk.
    model.add(ResidualBlock())

    # Two 2x upsampling stages -> 4x overall.
    for _ in range(2):
        model.add(Conv2D(filters=256, kernel_size=3, padding="same"))
        model.add(PixelShuffler())
        model.add(PReLU(shared_axes=[1, 2]))

    # Reconstruction to an RGB image.
    model.add(Conv2D(filters=3, kernel_size=9, padding="same"))

    return model


def make_discriminator():
    """Build the SRGAN discriminator.

    A stem of two convolutions, three down-sampling DiscriminatorBlocks
    with growing channel counts, then a dense head ending in a sigmoid
    real/fake probability.
    """
    model = Sequential()

    # Stem: stride-1 conv, then a stride-2 conv with BN.
    model.add(Conv2D(filters=64, kernel_size=3, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2D(filters=64, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))

    # Three down-sampling stages with increasing width.
    for n_filters in (128, 256, 512):
        model.add(DiscriminatorBlock(n_filters))

    # Classification head: single sigmoid output (probability of "real").
    model.add(Flatten())
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation="sigmoid"))

    return model

def make_vgg(height, width):
    """Return a frozen VGG19 feature extractor for the content loss.

    The extractor outputs the activations of layer ``block5_conv4``.
    ``height``/``width`` are the low-resolution patch sizes; the model is
    built for the generator's 4x-upscaled output, hence the ``* 4``.
    """
    backbone = tf.keras.applications.vgg19.VGG19(include_top=False, weights="imagenet")
    feature_extractor = tf.keras.Model(
        inputs=backbone.input, outputs=backbone.get_layer("block5_conv4").output
    )
    # The loss network is fixed: only generator/discriminator are trained.
    feature_extractor.trainable = False
    feature_extractor.build(input_shape=(None, height * 4, width * 4, 3))

    return feature_extractor

トレーニングの実行

Generatorのトレーニング

このモデルでは、GeneratorとDiscriminatorを同時に訓練するのではなく、先にGeneratorを訓練しておきます。これによりGeneratorが "local optima" にハマるのを防ぎます。Generatorの訓練では損失関数にMeanSquaredErrorを使います。

class SRResNetTrainer:
    """Pre-trains the SRGAN generator (SRResNet) alone with pixel-wise MSE.

    This warm-up phase runs before adversarial training so the generator
    starts GAN training from a reasonable initialization.

    NOTE(review): this class uses module-level names not visible in this
    chunk (`prepare_from_tfrecords`, `datetime`, `tqdm`, `np`) — confirm
    they are imported/defined earlier in the file.
    """

    def __init__(
        self,
        epochs: int = 10000,
        batch_size: int = 32,
        learning_rate: float = 1e-4,
        training_data_path: str = "./datasets/train.tfrecords",
        validate_data_path: str = "./datasets/valid.tfrecords",
        height: int = 32,
        width: int = 32,
        g_weight: str = None,
        checkpoint_path: str = "./checkpoint",
        best_generator_loss: float = 1e9,
    ):
        self.epochs = epochs
        self.batch_size = batch_size

        # Optionally resume from previously saved generator weights.
        self.generator = make_generator()
        if g_weight is not None and g_weight != "":
            print("Loading weights on generator...")
            self.generator.load_weights(g_weight)

        # Train/validation datasets of {"low": ..., "high": ...} batches.
        self.train_data, self.validate_data = prepare_from_tfrecords(
            train_data=training_data_path,
            validate_data=validate_data_path,
            height=height,
            width=width,
            batch_size=batch_size,
        )
        self.mse_loss = tf.keras.losses.MeanSquaredError()
        # Best validation loss seen so far; seeded high (or from a resumed
        # run) so the first improvement triggers a "best" checkpoint.
        self.best_generator_loss = best_generator_loss
        self.generator_optimizer = tf.keras.optimizers.Adam(learning_rate)

        # Empty checkpoint_path disables checkpointing entirely.
        self.checkpoint_path = checkpoint_path
        self.make_checkpoint = len(checkpoint_path) > 0

        # TensorBoard writers, one log dir per run (timestamped).
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = "./logs/" + current_time + "/train_generator"
        valid_log_dir = "./logs/" + current_time + "/valid_generator"
        self.train_summary_writer = tf.summary.create_file_writer(train_log_dir)
        self.valid_summary_writer = tf.summary.create_file_writer(valid_log_dir)

    @tf.function
    def train_step(self, lr: tf.Tensor, hr: tf.Tensor):
        """One gradient step: MSE between generated and real HR images.

        NOTE(review): the generator is called without an explicit
        `training=True`; confirm the BatchNorm layers run in training mode
        as intended here.
        """
        with tf.GradientTape() as tape:
            generated_fake = self.generator(lr)
            g_loss = self.mse_loss(generated_fake, hr)

        generator_grad = tape.gradient(g_loss, self.generator.trainable_variables)
        self.generator_optimizer.apply_gradients(
            grads_and_vars=zip(generator_grad, self.generator.trainable_variables)
        )

        return g_loss

    @tf.function
    def validation_step(self, lr: tf.Tensor, hr: tf.Tensor):
        """Compute the MSE loss on a validation batch (no weight update)."""
        generated_fake = self.generator(lr)
        g_loss = self.mse_loss(generated_fake, hr)

        return g_loss

    def train(self, start_epoch=0):
        """Run the training loop from `start_epoch` to `self.epochs`.

        Each epoch: train over all batches, log/print the mean loss,
        validate, then (if enabled) save "last" weights every epoch and
        "best" weights whenever validation improves.
        """
        for step in range(start_epoch, self.epochs):
            g_loss_train = []
            for images in tqdm(self.train_data):
                g_loss = self.train_step(images["low"], images["high"])
                g_loss_train.append(g_loss.numpy())

            g_loss_train_mean = np.mean(g_loss_train)

            with self.train_summary_writer.as_default():
                tf.summary.scalar("g_loss", g_loss_train_mean, step=step)

            print(
                f"Epoch {step+ 1}| Generator-Loss: {g_loss_train_mean:.3e},",
            )

            g_loss_valid = []
            for images in tqdm(self.validate_data):
                g_loss = self.validation_step(images["low"], images["high"])
                # Tensors (not .numpy()) are appended here; np.mean below
                # still reduces them to a scalar.
                g_loss_valid.append(g_loss)

            g_loss_valid_mean = np.mean(g_loss_valid)

            with self.valid_summary_writer.as_default():
                tf.summary.scalar("g_loss", g_loss_valid_mean, step=step)

            print(
                f"Validation| Generator-Loss: {g_loss_valid_mean:.3e},",
            )

            if self.make_checkpoint:
                self.generator.save_weights(f"{self.checkpoint_path}/generator_last")

                if g_loss_valid_mean < self.best_generator_loss:
                    self.best_generator_loss = g_loss_valid_mean
                    self.generator.save_weights(
                        f"{self.checkpoint_path}/generator_best"
                    )

                    print("Model Saved")
GANの訓練

Generatorを訓練したら、その重みを使ってGANの訓練を行います。先ほどとは違い、Generatorの損失関数にはBinaryCrossentropyとcontent lossと呼ばれるものを使っています。このロスでは、正解の高解像度画像とGeneratorの生成した画像をそれぞれVGG19に入れ、その中間層の出力を比べたときの二乗誤差を計算します。

class SRGANTrainer:
    """Adversarial training phase of SRGAN.

    The generator is optimized with a VGG19-based content (perceptual)
    loss plus a small (1e-3) adversarial term; the discriminator learns
    to separate real HR images from generated ones.

    NOTE(review): this class uses module-level names not visible in this
    chunk (`prepare_from_tfrecords`, `datetime`, `tqdm`, `np`) — confirm
    they are imported/defined earlier in the file.
    """

    def __init__(
        self,
        epochs: int = 100,
        batch_size: int = 16,
        learning_rate: float = 1e-4,
        height: int = 32,
        width: int = 32,
        g_weight: str = None,
        d_weight: str = None,
        training_data_path: str = "./datasets/train.tfrecords",
        validate_data_path: str = "./datasets/valid.tfrecords",
        checkpoint_path: str = "./checkpoints",
        best_generator_loss: float = 1e9,
    ):
        # -----------------------------
        # Hyper-parameters
        # -----------------------------
        self.epochs = epochs
        self.batch_size = batch_size

        # -----------------------------
        # Model
        # -----------------------------
        self.generator = make_generator()
        self.discriminator = make_discriminator()
        self.vgg = make_vgg(height=height, width=width)

        # Resume from pre-trained weights (the generator is normally
        # warm-started from the SRResNet pre-training phase).
        if g_weight is not None and g_weight != "":
            print("Loading weights on generator...")
            self.generator.load_weights(g_weight)
        if d_weight is not None and d_weight != "":
            print("Loading weights on discriminator...")
            self.discriminator.load_weights(d_weight)

        # -----------------------------
        # Data
        # -----------------------------
        # Train/validation datasets of {"low": ..., "high": ...} batches.
        self.train_data, self.validate_data = prepare_from_tfrecords(
            train_data=training_data_path,
            validate_data=validate_data_path,
            height=height,
            width=width,
            batch_size=batch_size,
        )

        # -----------------------------
        # Loss
        # -----------------------------
        # from_logits=False because the discriminator ends in a sigmoid.
        self.discriminator_loss_fn = tf.keras.losses.BinaryCrossentropy(
            from_logits=False
        )
        self.mse_loss = tf.keras.losses.MeanSquaredError()
        self.bce_loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)

        # Best validation generator loss so far (drives "best" checkpoints).
        self.best_generator_loss = best_generator_loss

        # -----------------------------
        # Optimizer
        # -----------------------------
        self.discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate)
        self.generator_optimizer = tf.keras.optimizers.Adam(learning_rate)

        # -----------------------------
        # Summary Writer
        # -----------------------------
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = "./logs/" + current_time + "/train"
        valid_log_dir = "./logs/" + current_time + "/valid"
        self.train_summary_writer = tf.summary.create_file_writer(train_log_dir)
        self.valid_summary_writer = tf.summary.create_file_writer(valid_log_dir)

        # Empty checkpoint_path disables checkpointing entirely.
        self.checkpoint_path = checkpoint_path
        self.make_checkpoint = len(checkpoint_path) > 0

    @tf.function
    def _content_loss(self, sr: tf.Tensor, hr: tf.Tensor):
        """VGG19 perceptual loss between super-resolved and real HR images.

        Inputs are assumed to lie in [-1, 1] and are mapped back to
        [0, 255] before VGG preprocessing. The 1/12.75 feature rescaling
        follows the SRGAN paper.
        """
        sr = (sr + 1) * 127.5
        hr = (hr + 1) * 127.5

        sr = tf.keras.applications.vgg19.preprocess_input(sr)
        hr = tf.keras.applications.vgg19.preprocess_input(hr)
        sr_vgg = self.vgg(sr) / 12.75
        hr_vgg = self.vgg(hr) / 12.75

        return self.mse_loss(hr_vgg, sr_vgg)

    def _adversarial_loss(self, output):
        """BCE pushing discriminator predictions `output` toward 1 (real)."""
        return self.bce_loss(tf.ones_like(output), output)

    @tf.function
    def train_step(self, lr: tf.Tensor, hr: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
        """One simultaneous optimization step for generator and discriminator.

        Returns ``(g_loss, d_loss)``.
        """
        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            generated_fake = self.generator(lr)

            real = self.discriminator(hr)
            fake = self.discriminator(generated_fake)

            # Keras losses take (y_true, y_pred): label real images 1 and
            # generated images 0.  (Bug fix: the arguments were previously
            # swapped, passing predictions as y_true and labels as y_pred.)
            d_loss = self.discriminator_loss_fn(tf.ones_like(real), real)
            d_loss += self.discriminator_loss_fn(tf.zeros_like(fake), fake)

            g_loss = self._content_loss(generated_fake, hr)
            # Bug fix: the adversarial term must use the discriminator's
            # prediction for the fake batch, not the image tensor itself.
            g_loss += self._adversarial_loss(fake) * 1e-3

        discriminator_grad = d_tape.gradient(
            d_loss, self.discriminator.trainable_variables
        )
        generator_grad = g_tape.gradient(g_loss, self.generator.trainable_variables)
        self.discriminator_optimizer.apply_gradients(
            grads_and_vars=zip(
                discriminator_grad, self.discriminator.trainable_variables
            )
        )
        self.generator_optimizer.apply_gradients(
            grads_and_vars=zip(generator_grad, self.generator.trainable_variables)
        )

        return g_loss, d_loss

    @tf.function
    def validation_step(self, lr: tf.Tensor, hr: tf.Tensor):
        """Compute both losses on a validation batch (no weight update)."""
        generated_fake = self.generator(lr)
        real = self.discriminator(hr)
        fake = self.discriminator(generated_fake)

        # Same (y_true, y_pred) ordering and adversarial input as train_step.
        d_loss = self.discriminator_loss_fn(tf.ones_like(real), real)
        d_loss += self.discriminator_loss_fn(tf.zeros_like(fake), fake)

        g_loss = self._content_loss(generated_fake, hr)
        g_loss += self._adversarial_loss(fake) * 1e-3

        return g_loss, d_loss

    def train(self, start_epoch):
        """Run the adversarial training loop from `start_epoch` to `self.epochs`.

        Each epoch: train over all batches, log/print mean losses,
        validate, then (if enabled) save "last" weights every epoch and
        "best" weights whenever the validation generator loss improves.
        """
        for step in range(start_epoch, self.epochs):
            d_loss_train = []
            g_loss_train = []
            for images in tqdm(self.train_data):
                g_loss, d_loss = self.train_step(images["low"], images["high"])
                g_loss_train.append(g_loss.numpy())
                d_loss_train.append(d_loss.numpy())

            g_loss_train_mean = np.mean(g_loss_train)
            d_loss_train_mean = np.mean(d_loss_train)

            with self.train_summary_writer.as_default():
                tf.summary.scalar("g_loss", g_loss_train_mean, step=step)
                tf.summary.scalar("d_loss", d_loss_train_mean, step=step)

            print(
                f"Epoch {step+ 1}| Generator-Loss: {g_loss_train_mean:.3e},",
                f"Discriminator-Loss: {d_loss_train_mean:.3e}",
            )

            d_loss_valid = []
            g_loss_valid = []
            for images in tqdm(self.validate_data):
                g_loss, d_loss = self.validation_step(images["low"], images["high"])
                d_loss_valid.append(d_loss)
                g_loss_valid.append(g_loss)

            g_loss_valid_mean = np.mean(g_loss_valid)
            d_loss_valid_mean = np.mean(d_loss_valid)

            with self.valid_summary_writer.as_default():
                tf.summary.scalar("g_loss", g_loss_valid_mean, step=step)
                tf.summary.scalar("d_loss", d_loss_valid_mean, step=step)

            print(
                f"Validation| Generator-Loss: {g_loss_valid_mean:.3e},",
                f"Discriminator-Loss: {d_loss_valid_mean:.3e}",
            )

            if self.make_checkpoint:
                self.generator.save_weights(f"{self.checkpoint_path}/generator_last")
                self.discriminator.save_weights(
                    f"{self.checkpoint_path}/discriminator_last"
                )

                if g_loss_valid_mean < self.best_generator_loss:
                    self.best_generator_loss = g_loss_valid_mean
                    self.generator.save_weights(
                        f"{self.checkpoint_path}/generator_best"
                    )
                    self.discriminator.save_weights(
                        f"{self.checkpoint_path}/discriminator_best"
                    )

                    print("Model Saved")

訓練結果

訓練したモデルを使ってテストデータの高解像度化を行ってみました。ところどころノイズは載っていますが概ねよく高解像度化できているように見えます。

f:id:pyhaya:20210905214807p:plain