Model Implementation
The models are implemented with TensorFlow's Keras API.
import datetime

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
    BatchNormalization,
    Conv2D,
    Dense,
    Flatten,
    LeakyReLU,
    PReLU,
)
from tqdm import tqdm
class BResidualBlock(Model):
    """Basic residual block: two 3x3 convolutions with BN and PReLU."""

    def __init__(self):
        super(BResidualBlock, self).__init__()
        self.conv = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")
        self.bn = BatchNormalization(momentum=0.8)
        self.prelu = PReLU(shared_axes=[1, 2])
        self.conv2 = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")
        self.bn2 = BatchNormalization(momentum=0.8)

    def call(self, input_tensor, training=True):
        x = self.conv(input_tensor)
        x = self.bn(x, training=training)
        x = self.prelu(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x += input_tensor  # skip connection
        return x
class ResidualBlock(Model):
    """Five stacked BResidualBlocks plus Conv + BN, wrapped in a long skip connection."""

    def __init__(self):
        super(ResidualBlock, self).__init__()
        self.residual1 = BResidualBlock()
        self.residual2 = BResidualBlock()
        self.residual3 = BResidualBlock()
        self.residual4 = BResidualBlock()
        self.residual5 = BResidualBlock()
        self.conv = Conv2D(filters=64, kernel_size=3, padding="same")
        self.bn = BatchNormalization(momentum=0.8)

    def call(self, input_tensor, training=True):
        x = self.residual1(input_tensor, training=training)
        x = self.residual2(x, training=training)
        x = self.residual3(x, training=training)
        x = self.residual4(x, training=training)
        x = self.residual5(x, training=training)
        x = self.conv(x)
        x = self.bn(x, training=training)
        x += input_tensor  # long skip connection around all residual blocks
        return x
class DiscriminatorBlock(Model):
    def __init__(self, filters=128):
        super(DiscriminatorBlock, self).__init__()
        self.filters = filters
        self.conv1 = Conv2D(filters=filters, kernel_size=3, strides=1, padding="same")
        self.bn1 = BatchNormalization(momentum=0.8)
        self.lrelu1 = LeakyReLU(alpha=0.2)
        self.conv2 = Conv2D(filters=filters, kernel_size=3, strides=2, padding="same")
        self.bn2 = BatchNormalization(momentum=0.8)
        self.lrelu2 = LeakyReLU(alpha=0.2)

    def call(self, input_tensor, training=True):
        x = self.conv1(input_tensor)
        x = self.bn1(x, training=training)
        x = self.lrelu1(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x = self.lrelu2(x)
        return x
class PixelShuffler(tf.keras.layers.Layer):
    """Sub-pixel upsampling: rearranges channels into a 2x larger spatial grid."""

    def __init__(self):
        super(PixelShuffler, self).__init__()

    def call(self, input_tensor):
        x = tf.nn.depth_to_space(input_tensor, 2)
        return x
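As a quick sanity check (a minimal sketch; the shapes are illustrative), depth_to_space trades channel depth for spatial resolution:

x = tf.random.normal((1, 32, 32, 256))
print(PixelShuffler()(x).shape)  # (1, 64, 64, 64): 4x fewer channels, 2x height and width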
def make_generator():
    # SRResNet-style generator: feature extraction, residual blocks,
    # then two PixelShuffler stages for 4x upscaling in total.
    model = Sequential(
        [
            Conv2D(filters=64, kernel_size=9, padding="same"),
            PReLU(shared_axes=[1, 2]),
            ResidualBlock(),
            Conv2D(filters=256, kernel_size=3, padding="same"),
            PixelShuffler(),  # 2x
            PReLU(shared_axes=[1, 2]),
            Conv2D(filters=256, kernel_size=3, padding="same"),
            PixelShuffler(),  # another 2x
            PReLU(shared_axes=[1, 2]),
            Conv2D(filters=3, kernel_size=9, padding="same"),
        ]
    )
    return model
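Because the generator is fully convolutional, it accepts any input size; a minimal sketch (with an assumed 32x32 input) verifying that the two PixelShuffler stages give 4x upscaling:

generator = make_generator()
lr_image = tf.random.normal((1, 32, 32, 3))  # dummy low-resolution RGB batch
print(generator(lr_image).shape)  # (1, 128, 128, 3)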
def make_discriminator():
    model = Sequential(
        [
            Conv2D(filters=64, kernel_size=3, padding="same"),
            LeakyReLU(alpha=0.2),
            Conv2D(filters=64, kernel_size=3, strides=2, padding="same"),
            BatchNormalization(momentum=0.8),
            LeakyReLU(alpha=0.2),
            DiscriminatorBlock(128),
            DiscriminatorBlock(256),
            DiscriminatorBlock(512),
            Flatten(),
            Dense(1024),
            LeakyReLU(alpha=0.2),
            Dense(1, activation="sigmoid"),
        ]
    )
    return model
def make_vgg(height, width):
    # Feature extractor for the content loss: VGG19 truncated at block5_conv4.
    vgg = tf.keras.applications.vgg19.VGG19(include_top=False, weights="imagenet")
    partial_vgg = tf.keras.Model(
        inputs=vgg.input, outputs=vgg.get_layer("block5_conv4").output
    )
    partial_vgg.trainable = False
    partial_vgg.build(input_shape=(None, height * 4, width * 4, 3))
    return partial_vgg
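A minimal smoke test for the remaining two models (a sketch; the 32x32 patch size matches the trainer defaults below):

discriminator = make_discriminator()
vgg = make_vgg(height=32, width=32)
hr_image = tf.random.normal((1, 128, 128, 3))  # dummy 4x-scale patch in [-1, 1]
print(discriminator(hr_image).shape)  # (1, 1): real/fake probability
features = vgg(tf.keras.applications.vgg19.preprocess_input((hr_image + 1) * 127.5))
print(features.shape)  # (1, 8, 8, 512): block5_conv4 features used by the content loss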
Training the Generator
Rather than training the generator and discriminator together from the start, the generator is pretrained on its own. This helps keep the generator from getting stuck in a poor local optimum. Pretraining uses MeanSquaredError (pixel-wise MSE between the generated and ground-truth high-resolution images) as the loss function.
class SRResNetTrainer:
    def __init__(
        self,
        epochs: int = 10000,
        batch_size: int = 32,
        learning_rate: float = 1e-4,
        training_data_path: str = "./datasets/train.tfrecords",
        validate_data_path: str = "./datasets/valid.tfrecords",
        height: int = 32,
        width: int = 32,
        g_weight: str = None,
        checkpoint_path: str = "./checkpoint",
        best_generator_loss: float = 1e9,
    ):
        self.epochs = epochs
        self.batch_size = batch_size
        self.generator = make_generator()
        if g_weight is not None and g_weight != "":
            print("Loading weights on generator...")
            self.generator.load_weights(g_weight)
        self.train_data, self.validate_data = prepare_from_tfrecords(
            train_data=training_data_path,
            validate_data=validate_data_path,
            height=height,
            width=width,
            batch_size=batch_size,
        )
        self.mse_loss = tf.keras.losses.MeanSquaredError()
        self.best_generator_loss = best_generator_loss
        self.generator_optimizer = tf.keras.optimizers.Adam(learning_rate)
        self.checkpoint_path = checkpoint_path
        self.make_checkpoint = len(checkpoint_path) > 0
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = "./logs/" + current_time + "/train_generator"
        valid_log_dir = "./logs/" + current_time + "/valid_generator"
        self.train_summary_writer = tf.summary.create_file_writer(train_log_dir)
        self.valid_summary_writer = tf.summary.create_file_writer(valid_log_dir)
    @tf.function
    def train_step(self, lr: tf.Tensor, hr: tf.Tensor):
        with tf.GradientTape() as tape:
            generated_fake = self.generator(lr, training=True)
            g_loss = self.mse_loss(hr, generated_fake)  # pixel-wise MSE (y_true first)
        generator_grad = tape.gradient(g_loss, self.generator.trainable_variables)
        self.generator_optimizer.apply_gradients(
            grads_and_vars=zip(generator_grad, self.generator.trainable_variables)
        )
        return g_loss

    @tf.function
    def validation_step(self, lr: tf.Tensor, hr: tf.Tensor):
        generated_fake = self.generator(lr, training=False)
        g_loss = self.mse_loss(hr, generated_fake)
        return g_loss
    def train(self, start_epoch=0):
        for step in range(start_epoch, self.epochs):
            g_loss_train = []
            for images in tqdm(self.train_data):
                g_loss = self.train_step(images["low"], images["high"])
                g_loss_train.append(g_loss.numpy())
            g_loss_train_mean = np.mean(g_loss_train)
            with self.train_summary_writer.as_default():
                tf.summary.scalar("g_loss", g_loss_train_mean, step=step)
            print(f"Epoch {step + 1}| Generator-Loss: {g_loss_train_mean:.3e}")
            g_loss_valid = []
            for images in tqdm(self.validate_data):
                g_loss = self.validation_step(images["low"], images["high"])
                g_loss_valid.append(g_loss.numpy())
            g_loss_valid_mean = np.mean(g_loss_valid)
            with self.valid_summary_writer.as_default():
                tf.summary.scalar("g_loss", g_loss_valid_mean, step=step)
            print(f"Validation| Generator-Loss: {g_loss_valid_mean:.3e}")
            if self.make_checkpoint:
                self.generator.save_weights(f"{self.checkpoint_path}/generator_last")
                if g_loss_valid_mean < self.best_generator_loss:
                    self.best_generator_loss = g_loss_valid_mean
                    self.generator.save_weights(
                        f"{self.checkpoint_path}/generator_best"
                    )
                    print("Model Saved")
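Pretraining then comes down to instantiating the trainer and calling train() (a minimal sketch; the paths and hyperparameters are the defaults from the constructor):

trainer = SRResNetTrainer(
    epochs=1000,
    batch_size=32,
    checkpoint_path="./checkpoint",
)
trainer.train()  # best validation weights end up in ./checkpoint/generator_best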
Training the GAN
Once the generator has been pretrained, its weights are used to initialize the GAN training. Unlike before, the generator's loss function now combines BinaryCrossentropy (the adversarial loss) with what is called a content loss: the ground-truth high-resolution image and the generator's output are each fed through VGG19, and the mean squared error between their intermediate-layer feature maps is computed.
class SRGANTrainer:
    def __init__(
        self,
        epochs: int = 100,
        batch_size: int = 16,
        learning_rate: float = 1e-4,
        height: int = 32,
        width: int = 32,
        g_weight: str = None,
        d_weight: str = None,
        training_data_path: str = "./datasets/train.tfrecords",
        validate_data_path: str = "./datasets/valid.tfrecords",
        checkpoint_path: str = "./checkpoints",
        best_generator_loss: float = 1e9,
    ):
        self.epochs = epochs
        self.batch_size = batch_size
        self.generator = make_generator()
        self.discriminator = make_discriminator()
        self.vgg = make_vgg(height=height, width=width)
        if g_weight is not None and g_weight != "":
            print("Loading weights on generator...")
            self.generator.load_weights(g_weight)
        if d_weight is not None and d_weight != "":
            print("Loading weights on discriminator...")
            self.discriminator.load_weights(d_weight)
        self.train_data, self.validate_data = prepare_from_tfrecords(
            train_data=training_data_path,
            validate_data=validate_data_path,
            height=height,
            width=width,
            batch_size=batch_size,
        )
        self.discriminator_loss_fn = tf.keras.losses.BinaryCrossentropy(
            from_logits=False
        )
        self.mse_loss = tf.keras.losses.MeanSquaredError()
        self.bce_loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)
        self.best_generator_loss = best_generator_loss
        self.discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate)
        self.generator_optimizer = tf.keras.optimizers.Adam(learning_rate)
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = "./logs/" + current_time + "/train"
        valid_log_dir = "./logs/" + current_time + "/valid"
        self.train_summary_writer = tf.summary.create_file_writer(train_log_dir)
        self.valid_summary_writer = tf.summary.create_file_writer(valid_log_dir)
        self.checkpoint_path = checkpoint_path
        self.make_checkpoint = len(checkpoint_path) > 0
    @tf.function
    def _content_loss(self, sr: tf.Tensor, hr: tf.Tensor):
        # Map [-1, 1] images back to [0, 255] before VGG preprocessing.
        sr = (sr + 1) * 127.5
        hr = (hr + 1) * 127.5
        sr = tf.keras.applications.vgg19.preprocess_input(sr)
        hr = tf.keras.applications.vgg19.preprocess_input(hr)
        # Rescale the feature maps by 1/12.75, as in the SRGAN paper.
        sr_vgg = self.vgg(sr) / 12.75
        hr_vgg = self.vgg(hr) / 12.75
        return self.mse_loss(hr_vgg, sr_vgg)

    def _adversarial_loss(self, output):
        # The generator wants the discriminator to output 1 for its images.
        return self.bce_loss(tf.ones_like(output), output)
    @tf.function
    def train_step(self, lr: tf.Tensor, hr: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            generated_fake = self.generator(lr, training=True)
            real = self.discriminator(hr, training=True)
            fake = self.discriminator(generated_fake, training=True)
            # BinaryCrossentropy expects (y_true, y_pred): label real
            # patches as 1 and generated patches as 0.
            d_loss = self.discriminator_loss_fn(tf.ones_like(real), real)
            d_loss += self.discriminator_loss_fn(tf.zeros_like(fake), fake)
            # Perceptual loss: content loss plus a weighted adversarial term
            # computed from the discriminator's output on the fake images.
            g_loss = self._content_loss(generated_fake, hr)
            g_loss += self._adversarial_loss(fake) * 1e-3
        discriminator_grad = d_tape.gradient(
            d_loss, self.discriminator.trainable_variables
        )
        generator_grad = g_tape.gradient(g_loss, self.generator.trainable_variables)
        self.discriminator_optimizer.apply_gradients(
            grads_and_vars=zip(
                discriminator_grad, self.discriminator.trainable_variables
            )
        )
        self.generator_optimizer.apply_gradients(
            grads_and_vars=zip(generator_grad, self.generator.trainable_variables)
        )
        return g_loss, d_loss
    @tf.function
    def validation_step(self, lr: tf.Tensor, hr: tf.Tensor):
        generated_fake = self.generator(lr, training=False)
        real = self.discriminator(hr, training=False)
        fake = self.discriminator(generated_fake, training=False)
        d_loss = self.discriminator_loss_fn(tf.ones_like(real), real)
        d_loss += self.discriminator_loss_fn(tf.zeros_like(fake), fake)
        g_loss = self._content_loss(generated_fake, hr)
        g_loss += self._adversarial_loss(fake) * 1e-3
        return g_loss, d_loss
    def train(self, start_epoch=0):
        for step in range(start_epoch, self.epochs):
            d_loss_train = []
            g_loss_train = []
            for images in tqdm(self.train_data):
                g_loss, d_loss = self.train_step(images["low"], images["high"])
                g_loss_train.append(g_loss.numpy())
                d_loss_train.append(d_loss.numpy())
            g_loss_train_mean = np.mean(g_loss_train)
            d_loss_train_mean = np.mean(d_loss_train)
            with self.train_summary_writer.as_default():
                tf.summary.scalar("g_loss", g_loss_train_mean, step=step)
                tf.summary.scalar("d_loss", d_loss_train_mean, step=step)
            print(
                f"Epoch {step + 1}| Generator-Loss: {g_loss_train_mean:.3e},",
                f"Discriminator-Loss: {d_loss_train_mean:.3e}",
            )
            d_loss_valid = []
            g_loss_valid = []
            for images in tqdm(self.validate_data):
                g_loss, d_loss = self.validation_step(images["low"], images["high"])
                d_loss_valid.append(d_loss.numpy())
                g_loss_valid.append(g_loss.numpy())
            g_loss_valid_mean = np.mean(g_loss_valid)
            d_loss_valid_mean = np.mean(d_loss_valid)
            with self.valid_summary_writer.as_default():
                tf.summary.scalar("g_loss", g_loss_valid_mean, step=step)
                tf.summary.scalar("d_loss", d_loss_valid_mean, step=step)
            print(
                f"Validation| Generator-Loss: {g_loss_valid_mean:.3e},",
                f"Discriminator-Loss: {d_loss_valid_mean:.3e}",
            )
            if self.make_checkpoint:
                self.generator.save_weights(f"{self.checkpoint_path}/generator_last")
                self.discriminator.save_weights(
                    f"{self.checkpoint_path}/discriminator_last"
                )
                if g_loss_valid_mean < self.best_generator_loss:
                    self.best_generator_loss = g_loss_valid_mean
                    self.generator.save_weights(
                        f"{self.checkpoint_path}/generator_best"
                    )
                    self.discriminator.save_weights(
                        f"{self.checkpoint_path}/discriminator_best"
                    )
                    print("Model Saved")
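GAN training is then kicked off the same way, pointing g_weight at the pretrained SRResNet weights (a minimal sketch; the paths follow the defaults used above):

trainer = SRGANTrainer(
    epochs=100,
    batch_size=16,
    g_weight="./checkpoint/generator_best",  # weights from the SRResNet pretraining
    checkpoint_path="./checkpoints",
)
trainer.train(start_epoch=0)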