今回はGANを使った初期の画像の高解像度化モデルであるSRGANを実装してみたので紹介したいと思います。
実行環境
今回のモデルは以下のような環境で実装しています。
ソースコードは以下に公開しています。
github.com
検証に用いたデータセットはDIV2Kで以下からダウンロードしました。
data.vision.ee.ethz.ch
訓練に用いるために、各画像から20枚ずつ32x32のランダムクロップを行いました。
SRGANとは
SRGANとは、その名前にもあるようにGAN(敵対的生成ネットワーク)を使って、入力画像から解像度を高めた画像を生成することを目的に作られたモデルです。

モデルの実装
モデルはTensorflowのKeras APIを使って実装します。
import tensorflow as tf
from tensorflow.keras.layers import (
    Conv2D,
    Dense,
    PReLU,
    BatchNormalization,
    LeakyReLU,
    Flatten,
)
from tensorflow.keras import Sequential, Model


class BResidualBlock(Model):
    """Basic residual block of the SRGAN generator.

    conv(3x3, 64) -> BN -> PReLU -> conv(3x3, 64) -> BN, plus an identity
    skip connection from the block input to the block output.
    """

    def __init__(self):
        super(BResidualBlock, self).__init__()
        self.conv = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")
        self.bn = BatchNormalization(momentum=0.8)
        # shared_axes=[1, 2]: learn one PReLU slope per channel instead of
        # one per spatial position.
        self.prelu = PReLU(shared_axes=[1, 2])
        self.conv2 = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")
        self.bn2 = BatchNormalization(momentum=0.8)

    def call(self, input_tensor, training=True):
        x = self.conv(input_tensor)
        x = self.bn(x, training=training)
        x = self.prelu(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x += input_tensor  # identity skip connection
        return x


class ResidualBlock(Model):
    """Trunk of the generator: five BResidualBlocks followed by conv + BN,
    with a long skip connection around the whole stack.
    """

    def __init__(self):
        super(ResidualBlock, self).__init__()
        self.residual1 = BResidualBlock()
        self.residual2 = BResidualBlock()
        self.residual3 = BResidualBlock()
        self.residual4 = BResidualBlock()
        self.residual5 = BResidualBlock()
        self.conv = Conv2D(filters=64, kernel_size=3, padding="same")
        self.bn = BatchNormalization(momentum=0.8)

    def call(self, input_tensor, training=True):
        # Bug fix: the original dropped `training` on residual1 and on the
        # final BatchNormalization, leaving those layers permanently in
        # training mode; propagate the flag to every sub-layer.
        x = self.residual1(input_tensor, training=training)
        x = self.residual2(x, training=training)
        x = self.residual3(x, training=training)
        x = self.residual4(x, training=training)
        x = self.residual5(x, training=training)
        x = self.conv(x)
        x = self.bn(x, training=training)
        x += input_tensor  # long skip over the whole residual stack
        return x


class DiscriminatorBlock(Model):
    """Discriminator stage: two conv-BN-LeakyReLU layers, the second with
    stride 2 so each block halves the spatial resolution.

    Args:
        filters: number of output channels for both convolutions.
    """

    def __init__(self, filters=128):
        super(DiscriminatorBlock, self).__init__()
        self.filters = filters
        self.conv1 = Conv2D(filters=filters, kernel_size=3, strides=1, padding="same")
        self.bn1 = BatchNormalization(momentum=0.8)
        self.lrelu1 = LeakyReLU(alpha=0.2)
        self.conv2 = Conv2D(filters=filters, kernel_size=3, strides=2, padding="same")
        self.bn2 = BatchNormalization(momentum=0.8)
        self.lrelu2 = LeakyReLU(alpha=0.2)

    def call(self, input_tensor, training=True):
        x = self.conv1(input_tensor)
        x = self.bn1(x, training=training)
        x = self.lrelu1(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x = self.lrelu2(x)
        return x


class PixelShuffler(tf.keras.layers.Layer):
    """Sub-pixel convolution layer: rearranges a (H, W, 4C) tensor into
    (2H, 2W, C) via depth_to_space, i.e. a fixed 2x upscale."""

    def __init__(self):
        super(PixelShuffler, self).__init__()

    def call(self, input_tensor):
        x = tf.nn.depth_to_space(input_tensor, 2)
        return x


def make_generator():
    """Build the SRGAN generator (SRResNet): residual trunk followed by
    two PixelShuffler stages for a total 4x upscale, 3-channel output."""
    model = Sequential(
        [
            Conv2D(filters=64, kernel_size=9, padding="same"),
            PReLU(shared_axes=[1, 2]),
            ResidualBlock(),
            Conv2D(filters=256, kernel_size=3, padding="same"),
            PixelShuffler(),  # 2x upscale
            PReLU(shared_axes=[1, 2]),
            Conv2D(filters=256, kernel_size=3, padding="same"),
            PixelShuffler(),  # another 2x -> 4x total
            PReLU(shared_axes=[1, 2]),
            Conv2D(filters=3, kernel_size=9, padding="same"),
        ]
    )
    return model


def make_discriminator():
    """Build the SRGAN discriminator: strided conv stages doubling the
    channel count (64 -> 512), then Dense(1024) and a sigmoid real/fake
    probability head."""
    model = Sequential(
        [
            Conv2D(filters=64, kernel_size=3, padding="same"),
            LeakyReLU(alpha=0.2),
            Conv2D(filters=64, kernel_size=3, strides=2, padding="same"),
            BatchNormalization(momentum=0.8),
            LeakyReLU(alpha=0.2),
            DiscriminatorBlock(128),
            DiscriminatorBlock(256),
            DiscriminatorBlock(512),
            Flatten(),
            Dense(1024),
            LeakyReLU(alpha=0.2),
            Dense(1, activation="sigmoid"),
        ]
    )
    return model


def make_vgg(height, width):
    """Build a frozen VGG19 feature extractor (up to block5_conv4) used
    for the SRGAN content (perceptual) loss.

    Args:
        height: low-resolution input height; the extractor is built for
            the 4x-upscaled size (height * 4).
        width: low-resolution input width (upscaled likewise).
    """
    vgg = tf.keras.applications.vgg19.VGG19(include_top=False, weights="imagenet")
    partial_vgg = tf.keras.Model(
        inputs=vgg.input, outputs=vgg.get_layer("block5_conv4").output
    )
    partial_vgg.trainable = False  # feature extractor only, never trained
    partial_vgg.build(input_shape=(None, height * 4, width * 4, 3))
    return partial_vgg
トレーニングの実行
Generatorのトレーニング
このモデルでは、GeneratorとDiscriminatorを同時に訓練するのではなく、先にGeneratorだけを訓練しておきます。これによりGeneratorが望ましくない局所最適解(local optima)にハマるのを防ぎます。Generatorの訓練では損失関数にMeanSquaredErrorを使います。
class SRResNetTrainer:
    """Pre-trains the SRGAN generator (SRResNet) alone with a pixel-wise
    MSE loss before adversarial training.

    Args:
        epochs: number of passes over the training data.
        batch_size: mini-batch size.
        learning_rate: Adam learning rate for the generator.
        training_data_path: path to the training TFRecord file.
        validate_data_path: path to the validation TFRecord file.
        height: low-resolution crop height.
        width: low-resolution crop width.
        g_weight: optional path to generator weights to resume from.
        checkpoint_path: directory for weight checkpoints; empty string
            disables checkpointing.
        best_generator_loss: initial "best" validation loss; checkpoints
            overwrite generator_best only when validation beats this.
    """

    def __init__(
        self,
        epochs: int = 10000,
        batch_size: int = 32,
        learning_rate: float = 1e-4,
        training_data_path: str = "./datasets/train.tfrecords",
        validate_data_path: str = "./datasets/valid.tfrecords",
        height: int = 32,
        width: int = 32,
        g_weight: str = None,
        checkpoint_path: str = "./checkpoint",
        best_generator_loss: float = 1e9,
    ):
        self.epochs = epochs
        self.batch_size = batch_size
        self.generator = make_generator()
        if g_weight is not None and g_weight != "":
            print("Loading weights on generator...")
            self.generator.load_weights(g_weight)
        self.train_data, self.validate_data = prepare_from_tfrecords(
            train_data=training_data_path,
            validate_data=validate_data_path,
            height=height,
            width=width,
            batch_size=batch_size,
        )
        self.mse_loss = tf.keras.losses.MeanSquaredError()
        self.best_generator_loss = best_generator_loss
        self.generator_optimizer = tf.keras.optimizers.Adam(learning_rate)
        self.checkpoint_path = checkpoint_path
        # An empty checkpoint_path disables saving entirely.
        self.make_checkpoint = len(checkpoint_path) > 0
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = "./logs/" + current_time + "/train_generator"
        valid_log_dir = "./logs/" + current_time + "/valid_generator"
        self.train_summary_writer = tf.summary.create_file_writer(train_log_dir)
        self.valid_summary_writer = tf.summary.create_file_writer(valid_log_dir)

    @tf.function
    def train_step(self, lr: tf.Tensor, hr: tf.Tensor) -> tf.Tensor:
        """One optimization step: MSE between the generated SR image and
        the ground-truth HR image, applied to the generator."""
        with tf.GradientTape() as tape:
            # Fix: pass training=True explicitly so BatchNorm uses batch
            # statistics during training (the original relied on the
            # implicit default).
            generated_fake = self.generator(lr, training=True)
            # y_true first per the Keras convention (MSE is symmetric, so
            # this is a readability fix, not a behavior change).
            g_loss = self.mse_loss(hr, generated_fake)
        generator_grad = tape.gradient(g_loss, self.generator.trainable_variables)
        self.generator_optimizer.apply_gradients(
            grads_and_vars=zip(generator_grad, self.generator.trainable_variables)
        )
        return g_loss

    @tf.function
    def validation_step(self, lr: tf.Tensor, hr: tf.Tensor) -> tf.Tensor:
        """MSE on a validation batch; inference mode, no weight update."""
        generated_fake = self.generator(lr, training=False)
        g_loss = self.mse_loss(hr, generated_fake)
        return g_loss

    def train(self, start_epoch=0):
        """Run the training loop from start_epoch to self.epochs, logging
        to TensorBoard and checkpointing last/best generator weights."""
        for step in range(start_epoch, self.epochs):
            g_loss_train = []
            for images in tqdm(self.train_data):
                g_loss = self.train_step(images["low"], images["high"])
                g_loss_train.append(g_loss.numpy())
            g_loss_train_mean = np.mean(g_loss_train)
            with self.train_summary_writer.as_default():
                tf.summary.scalar("g_loss", g_loss_train_mean, step=step)
            print(
                f"Epoch {step+ 1}| Generator-Loss: {g_loss_train_mean:.3e},",
            )
            g_loss_valid = []
            for images in tqdm(self.validate_data):
                g_loss = self.validation_step(images["low"], images["high"])
                # Fix: convert to numpy like the training loop (the
                # original appended raw tensors here).
                g_loss_valid.append(g_loss.numpy())
            g_loss_valid_mean = np.mean(g_loss_valid)
            with self.valid_summary_writer.as_default():
                tf.summary.scalar("g_loss", g_loss_valid_mean, step=step)
            print(
                f"Validation| Generator-Loss: {g_loss_valid_mean:.3e},",
            )
            if self.make_checkpoint:
                # Always keep the latest weights; additionally track the
                # best validation loss seen so far.
                self.generator.save_weights(f"{self.checkpoint_path}/generator_last")
                if g_loss_valid_mean < self.best_generator_loss:
                    self.best_generator_loss = g_loss_valid_mean
                    self.generator.save_weights(
                        f"{self.checkpoint_path}/generator_best"
                    )
                    print("Model Saved")
GANの訓練
Generatorを訓練したら、その重みを使ってGANの訓練を行います。先ほどとは違い、Generatorの損失関数にはBinaryCrossentropyとcontent lossと呼ばれるものを使っています。このロスでは、正解の高解像度画像とGeneratorの生成した画像をそれぞれVGG19に入れ、その中間層の出力を比べたときの二乗誤差を計算します。
class SRGANTrainer:
    """Adversarial fine-tuning of the (pre-trained) SRGAN generator
    together with the discriminator, using VGG content loss plus a
    weighted adversarial loss.

    Args:
        epochs: number of passes over the training data.
        batch_size: mini-batch size.
        learning_rate: Adam learning rate for both optimizers.
        height: low-resolution crop height.
        width: low-resolution crop width.
        g_weight: optional path to generator weights to resume from.
        d_weight: optional path to discriminator weights to resume from.
        training_data_path: path to the training TFRecord file.
        validate_data_path: path to the validation TFRecord file.
        checkpoint_path: directory for weight checkpoints; empty string
            disables checkpointing.
        best_generator_loss: initial "best" validation loss; best
            checkpoints are written only when validation beats this.
    """

    def __init__(
        self,
        epochs: int = 100,
        batch_size: int = 16,
        learning_rate: float = 1e-4,
        height: int = 32,
        width: int = 32,
        g_weight: str = None,
        d_weight: str = None,
        training_data_path: str = "./datasets/train.tfrecords",
        validate_data_path: str = "./datasets/valid.tfrecords",
        checkpoint_path: str = "./checkpoints",
        best_generator_loss: float = 1e9,
    ):
        # -----------------------------
        # Hyper-parameters
        # -----------------------------
        self.epochs = epochs
        self.batch_size = batch_size

        # -----------------------------
        # Model
        # -----------------------------
        self.generator = make_generator()
        self.discriminator = make_discriminator()
        self.vgg = make_vgg(height=height, width=width)
        if g_weight is not None and g_weight != "":
            print("Loading weights on generator...")
            self.generator.load_weights(g_weight)
        if d_weight is not None and d_weight != "":
            print("Loading weights on discriminator...")
            self.discriminator.load_weights(d_weight)

        # -----------------------------
        # Data
        # -----------------------------
        self.train_data, self.validate_data = prepare_from_tfrecords(
            train_data=training_data_path,
            validate_data=validate_data_path,
            height=height,
            width=width,
            batch_size=batch_size,
        )

        # -----------------------------
        # Loss
        # -----------------------------
        # from_logits=False: the discriminator ends in a sigmoid.
        self.discriminator_loss_fn = tf.keras.losses.BinaryCrossentropy(
            from_logits=False
        )
        self.mse_loss = tf.keras.losses.MeanSquaredError()
        self.bce_loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)
        self.best_generator_loss = best_generator_loss

        # -----------------------------
        # Optimizer
        # -----------------------------
        self.discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate)
        self.generator_optimizer = tf.keras.optimizers.Adam(learning_rate)

        # -----------------------------
        # Summary Writer
        # -----------------------------
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = "./logs/" + current_time + "/train"
        valid_log_dir = "./logs/" + current_time + "/valid"
        self.train_summary_writer = tf.summary.create_file_writer(train_log_dir)
        self.valid_summary_writer = tf.summary.create_file_writer(valid_log_dir)
        self.checkpoint_path = checkpoint_path
        self.make_checkpoint = len(checkpoint_path) > 0

    @tf.function
    def _content_loss(self, sr: tf.Tensor, hr: tf.Tensor) -> tf.Tensor:
        """VGG perceptual loss: MSE between block5_conv4 features of the
        generated image and the ground truth.

        Images are assumed to be in [-1, 1]; they are mapped back to
        [0, 255] and run through VGG19 preprocessing first. The 1/12.75
        scaling follows the SRGAN paper's feature rescaling.
        """
        sr = (sr + 1) * 127.5
        hr = (hr + 1) * 127.5
        sr = tf.keras.applications.vgg19.preprocess_input(sr)
        hr = tf.keras.applications.vgg19.preprocess_input(hr)
        sr_vgg = self.vgg(sr) / 12.75
        hr_vgg = self.vgg(hr) / 12.75
        return self.mse_loss(hr_vgg, sr_vgg)

    def _adversarial_loss(self, output):
        """Generator-side GAN loss: BCE between all-ones labels and the
        discriminator's probabilities on generated images."""
        return self.bce_loss(tf.ones_like(output), output)

    @tf.function
    def train_step(self, lr: tf.Tensor, hr: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
        """One adversarial step; returns (generator_loss, discriminator_loss)."""
        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            generated_fake = self.generator(lr, training=True)
            real = self.discriminator(hr, training=True)
            fake = self.discriminator(generated_fake, training=True)
            # Bug fix: Keras losses take (y_true, y_pred); the original
            # passed predictions first, swapping labels and outputs.
            d_loss = self.discriminator_loss_fn(tf.ones_like(real), real)
            d_loss += self.discriminator_loss_fn(tf.zeros_like(fake), fake)
            g_loss = self._content_loss(generated_fake, hr)
            # Bug fix: the adversarial term must be computed on the
            # discriminator's output `fake`, not on the raw generated
            # image. 1e-3 is the SRGAN paper's adversarial weight.
            g_loss += self._adversarial_loss(fake) * 1e-3
        # Gradients are taken outside the tape context: nothing here needs
        # recording, and it keeps the tapes' lifetime minimal.
        discriminator_grad = d_tape.gradient(
            d_loss, self.discriminator.trainable_variables
        )
        generator_grad = g_tape.gradient(g_loss, self.generator.trainable_variables)
        self.discriminator_optimizer.apply_gradients(
            grads_and_vars=zip(
                discriminator_grad, self.discriminator.trainable_variables
            )
        )
        self.generator_optimizer.apply_gradients(
            grads_and_vars=zip(generator_grad, self.generator.trainable_variables)
        )
        return g_loss, d_loss

    @tf.function
    def validation_step(self, lr: tf.Tensor, hr: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
        """Compute both losses on a validation batch; inference mode,
        no weight update."""
        generated_fake = self.generator(lr, training=False)
        real = self.discriminator(hr, training=False)
        fake = self.discriminator(generated_fake, training=False)
        d_loss = self.discriminator_loss_fn(tf.ones_like(real), real)
        d_loss += self.discriminator_loss_fn(tf.zeros_like(fake), fake)
        g_loss = self._content_loss(generated_fake, hr)
        g_loss += self._adversarial_loss(fake) * 1e-3
        return g_loss, d_loss

    def train(self, start_epoch=0):
        """Run the adversarial training loop from start_epoch to
        self.epochs, logging to TensorBoard and checkpointing last/best
        weights for both networks."""
        for step in range(start_epoch, self.epochs):
            d_loss_train = []
            g_loss_train = []
            for images in tqdm(self.train_data):
                g_loss, d_loss = self.train_step(images["low"], images["high"])
                g_loss_train.append(g_loss.numpy())
                d_loss_train.append(d_loss.numpy())
            g_loss_train_mean = np.mean(g_loss_train)
            d_loss_train_mean = np.mean(d_loss_train)
            with self.train_summary_writer.as_default():
                tf.summary.scalar("g_loss", g_loss_train_mean, step=step)
                tf.summary.scalar("d_loss", d_loss_train_mean, step=step)
            print(
                f"Epoch {step+ 1}| Generator-Loss: {g_loss_train_mean:.3e},",
                f"Discriminator-Loss: {d_loss_train_mean:.3e}",
            )
            d_loss_valid = []
            g_loss_valid = []
            for images in tqdm(self.validate_data):
                g_loss, d_loss = self.validation_step(images["low"], images["high"])
                # Fix: convert to numpy like the training loop (the
                # original appended raw tensors here).
                d_loss_valid.append(d_loss.numpy())
                g_loss_valid.append(g_loss.numpy())
            g_loss_valid_mean = np.mean(g_loss_valid)
            d_loss_valid_mean = np.mean(d_loss_valid)
            with self.valid_summary_writer.as_default():
                tf.summary.scalar("g_loss", g_loss_valid_mean, step=step)
                tf.summary.scalar("d_loss", d_loss_valid_mean, step=step)
            print(
                f"Validation| Generator-Loss: {g_loss_valid_mean:.3e},",
                f"Discriminator-Loss: {d_loss_valid_mean:.3e}",
            )
            if self.make_checkpoint:
                # Always keep the latest weights; additionally track the
                # best generator validation loss seen so far.
                self.generator.save_weights(f"{self.checkpoint_path}/generator_last")
                self.discriminator.save_weights(
                    f"{self.checkpoint_path}/discriminator_last"
                )
                if g_loss_valid_mean < self.best_generator_loss:
                    self.best_generator_loss = g_loss_valid_mean
                    self.generator.save_weights(
                        f"{self.checkpoint_path}/generator_best"
                    )
                    self.discriminator.save_weights(
                        f"{self.checkpoint_path}/discriminator_best"
                    )
                    print("Model Saved")
訓練結果
訓練したモデルを使ってテストデータの高解像度化を行ってみました。ところどころノイズは乗っていますが、概ねうまく高解像度化できているように見えます。