AI Exploration 20: Sketch Colorization with cGAN
GAN and cGAN¶
GAN:
- Generator
Takes noise z (blue) as input, transforms it into an internal representation (black), and generates fake data G(z) (red).
- Discriminator
Receives the real data x and the Generator's fake data G(z), computes D(x) and D(G(z)) (purple), and classifies each as real or fake.

cGAN:
- Generator
Receives noise z (blue) together with extra information y (green); the two are combined inside the Generator, transformed into a representation (black), and fake data G(z|y) is generated. When training on datasets such as MNIST or CIFAR-10, y is the label information, usually fed in as a one-hot vector.
- Discriminator
Receives the real data x and the Generator's fake data G(z|y), each accompanied by the same y information, and classifies real versus fake. For datasets such as MNIST or CIFAR-10, each real pair x and y must match (an image of "7" must carry label 7), and likewise the y fed to the Generator and the y fed to the Discriminator must denote the same label.
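To make the conditioning concrete, here is a minimal, dependency-free sketch (plain Python, illustrative dimensions only) of how the Generator's input is formed in a cGAN: the noise vector z and the one-hot label y are simply concatenated.

```python
# Shape-only sketch of the cGAN Generator input: noise z concatenated
# with a one-hot label y (dimensions are illustrative, not prescriptive).
z_dim, n_classes = 100, 10

z = [0.0] * z_dim                                        # noise vector z
y = [1.0 if k == 7 else 0.0 for k in range(n_classes)]   # one-hot label y = 7

g_input = z + y                            # the Generator sees [z | y] as one vector
assert len(g_input) == z_dim + n_classes   # 110-dimensional input
```

The Discriminator is conditioned the same way: y is fed in alongside its image input, so both networks always see the label together with the data.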
In [1]:
import tensorflow as tf

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    # Limit TensorFlow to 6GB of memory on the first GPU
    try:
        tf.config.experimental.set_virtual_device_configuration(
            gpus[0],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=6144)])
    except RuntimeError as e:
        print(e)
Building the Generator¶
- pip install tensorflow-datasets
In [2]:
import tensorflow_datasets as tfds
mnist, info = tfds.load(
"mnist", split="train", with_info=True
)
fig = tfds.show_examples(mnist, info)
In [3]:
import tensorflow as tf

BATCH_SIZE = 128

def gan_preprocessing(data):
    image = data["image"]
    image = tf.cast(image, tf.float32)
    image = (image / 127.5) - 1
    return image

def cgan_preprocessing(data):
    image = data["image"]
    image = tf.cast(image, tf.float32)
    image = (image / 127.5) - 1
    label = tf.one_hot(data["label"], 10)
    return image, label

gan_datasets = mnist.map(gan_preprocessing).shuffle(1000).batch(BATCH_SIZE)
cgan_datasets = mnist.map(cgan_preprocessing).shuffle(100).batch(BATCH_SIZE)
print("✅")
✅
In [4]:
import matplotlib.pyplot as plt

for i, j in cgan_datasets:
    break

# Check that image i and label j match:
# the digit in the image must agree with the label, and pixel values must lie in [-1, 1].
print("Label :", j[0])
print("Image Min/Max :", i.numpy().min(), i.numpy().max())
plt.imshow(i.numpy()[0,...,0], plt.cm.gray)
Label : tf.Tensor([0. 0. 0. 0. 0. 0. 0. 1. 0. 0.], shape=(10,), dtype=float32)
Image Min/Max : -1.0 1.0
Out[4]:
<matplotlib.image.AxesImage at 0x7f26700ebd50>
GAN Generator¶
In [5]:
from tensorflow.keras import layers, Input, Model

class GeneratorGAN(Model):
    def __init__(self):
        super(GeneratorGAN, self).__init__()
        # Four fully-connected layers; all but the last use ReLU activation
        self.dense_1 = layers.Dense(128, activation='relu')
        self.dense_2 = layers.Dense(256, activation='relu')
        self.dense_3 = layers.Dense(512, activation='relu')
        self.dense_4 = layers.Dense(28*28*1, activation='tanh')
        self.reshape = layers.Reshape((28, 28, 1))

    def call(self, noise):
        out = self.dense_1(noise)
        out = self.dense_2(out)
        out = self.dense_3(out)
        out = self.dense_4(out)
        return self.reshape(out)

print("✅")
✅
Building the cGAN Generator¶
In [6]:
class GeneratorCGAN(Model):
    def __init__(self):
        super(GeneratorCGAN, self).__init__()
        # The noise input and the label input each pass through one
        # fully-connected layer with ReLU activation (dense_z, dense_y)
        self.dense_z = layers.Dense(256, activation='relu')
        self.dense_y = layers.Dense(256, activation='relu')
        self.combined_dense = layers.Dense(512, activation='relu')
        self.final_dense = layers.Dense(28 * 28 * 1, activation='tanh')
        self.reshape = layers.Reshape((28, 28, 1))

    def call(self, noise, label):
        noise = self.dense_z(noise)
        label = self.dense_y(label)
        # The two results are concatenated and passed through one more
        # fully-connected layer with ReLU activation (tf.concat, combined_dense)
        out = self.combined_dense(tf.concat([noise, label], axis=-1))
        # A final fully-connected layer with tanh activation produces a
        # 28*28-dimensional result, reshaped to a (28,28,1) image (final_dense, reshape)
        out = self.final_dense(out)
        return self.reshape(out)

print("✅")
✅
Building the Discriminator¶
Building the GAN Discriminator¶
In [7]:
class DiscriminatorGAN(Model):
    def __init__(self):
        super(DiscriminatorGAN, self).__init__()
        self.flatten = layers.Flatten()
        self.blocks = []
        for f in [512, 256, 128, 1]:
            self.blocks.append(
                layers.Dense(f, activation=None if f==1 else "relu")
            )

    def call(self, x):
        x = self.flatten(x)
        for block in self.blocks:
            x = block(x)
        return x

print("✅")
✅
Building the cGAN Discriminator¶
- Maxout
When connecting two layers, pass the input through several fully-connected layers in parallel and take the maximum of their outputs.
In [8]:
class Maxout(layers.Layer):
    # Create `pieces` fully-connected projections of `units` dimensions each,
    # then output the element-wise maximum over the pieces.
    def __init__(self, units, pieces):
        super(Maxout, self).__init__()
        self.dense = layers.Dense(units*pieces, activation="relu")
        self.dropout = layers.Dropout(.5)
        # Reshape to (pieces, units) so the maximum is taken over the pieces axis
        self.reshape = layers.Reshape((pieces, units))

    def call(self, x):
        x = self.dense(x)
        x = self.dropout(x)
        x = self.reshape(x)
        return tf.math.reduce_max(x, axis=1)

print("✅")
✅
In [9]:
class DiscriminatorCGAN(Model):
    def __init__(self):
        super(DiscriminatorCGAN, self).__init__()
        self.flatten = layers.Flatten()
        self.image_block = Maxout(240, 5)
        self.label_block = Maxout(50, 5)
        self.combine_block = Maxout(240, 4)
        self.dense = layers.Dense(1, activation=None)

    def call(self, image, label):
        image = self.flatten(image)
        image = self.image_block(image)
        label = self.label_block(label)
        x = layers.Concatenate()([image, label])
        x = self.combine_block(x)
        return self.dense(x)

print("✅")
✅
Training and Testing¶
In [10]:
from tensorflow.keras import optimizers, losses

bce = losses.BinaryCrossentropy(from_logits=True)

def generator_loss(fake_output):
    return bce(tf.ones_like(fake_output), fake_output)

def discriminator_loss(real_output, fake_output):
    return bce(tf.ones_like(real_output), real_output) + bce(tf.zeros_like(fake_output), fake_output)

gene_opt = optimizers.Adam(1e-4)
disc_opt = optimizers.Adam(1e-4)
print("✅")
✅
Training a GAN on MNIST¶
In [11]:
gan_generator = GeneratorGAN()
gan_discriminator = DiscriminatorGAN()

@tf.function()
def gan_step(real_images):
    noise = tf.random.normal([real_images.shape[0], 100])
    with tf.GradientTape(persistent=True) as tape:
        # Generate fake images with the Generator
        fake_images = gan_generator(noise)
        # Discriminate real and fake images with the Discriminator
        real_out = gan_discriminator(real_images)
        fake_out = gan_discriminator(fake_images)
        # Compute each loss
        gene_loss = generator_loss(fake_out)
        disc_loss = discriminator_loss(real_out, fake_out)
    # Compute gradients
    gene_grad = tape.gradient(gene_loss, gan_generator.trainable_variables)
    disc_grad = tape.gradient(disc_loss, gan_discriminator.trainable_variables)
    # Apply the updates
    gene_opt.apply_gradients(zip(gene_grad, gan_generator.trainable_variables))
    disc_opt.apply_gradients(zip(disc_grad, gan_discriminator.trainable_variables))
    return gene_loss, disc_loss

print("✅")
✅
In [12]:
EPOCHS = 10
for epoch in range(1, EPOCHS+1):
    for i, images in enumerate(gan_datasets):
        gene_loss, disc_loss = gan_step(images)
        if (i+1) % 100 == 0:
            print(f"[{epoch}/{EPOCHS} EPOCHS, {i+1} ITER] G:{gene_loss}, D:{disc_loss}")
[1/10 EPOCHS, 100 ITER] G:2.240199089050293, D:0.1363147348165512
[1/10 EPOCHS, 200 ITER] G:2.643437147140503, D:0.10650718212127686
[1/10 EPOCHS, 300 ITER] G:2.5449469089508057, D:0.1548902988433838
[1/10 EPOCHS, 400 ITER] G:2.8896124362945557, D:0.08236328512430191
[2/10 EPOCHS, 100 ITER] G:3.6674177646636963, D:0.058024849742650986
[2/10 EPOCHS, 200 ITER] G:2.766561985015869, D:0.14555320143699646
[2/10 EPOCHS, 300 ITER] G:2.529735565185547, D:0.1973763406276703
[2/10 EPOCHS, 400 ITER] G:2.6738710403442383, D:0.15558688342571259
[3/10 EPOCHS, 100 ITER] G:3.2016184329986572, D:0.2673863172531128
[3/10 EPOCHS, 200 ITER] G:4.122986316680908, D:0.1851266324520111
[3/10 EPOCHS, 300 ITER] G:3.4892544746398926, D:0.22411954402923584
[3/10 EPOCHS, 400 ITER] G:2.525853157043457, D:0.20478767156600952
[4/10 EPOCHS, 100 ITER] G:2.2505226135253906, D:0.19485771656036377
[4/10 EPOCHS, 200 ITER] G:2.706967353820801, D:0.1406809389591217
[4/10 EPOCHS, 300 ITER] G:3.407163143157959, D:0.09906414896249771
[4/10 EPOCHS, 400 ITER] G:4.038682460784912, D:0.14888018369674683
[5/10 EPOCHS, 100 ITER] G:4.0816802978515625, D:0.0911555141210556
[5/10 EPOCHS, 200 ITER] G:2.704348087310791, D:0.1466299295425415
[5/10 EPOCHS, 300 ITER] G:4.224481105804443, D:0.12020998448133469
[5/10 EPOCHS, 400 ITER] G:3.869384288787842, D:0.09422999620437622
[6/10 EPOCHS, 100 ITER] G:3.5552425384521484, D:0.059337738901376724
[6/10 EPOCHS, 200 ITER] G:4.10662317276001, D:0.09973932802677155
[6/10 EPOCHS, 300 ITER] G:3.0945067405700684, D:0.08878493309020996
[6/10 EPOCHS, 400 ITER] G:5.6324143409729, D:0.18964596092700958
[7/10 EPOCHS, 100 ITER] G:4.257791042327881, D:0.08427975326776505
[7/10 EPOCHS, 200 ITER] G:3.550205707550049, D:0.041241783648729324
[7/10 EPOCHS, 300 ITER] G:3.884168863296509, D:0.030784696340560913
[7/10 EPOCHS, 400 ITER] G:4.196020126342773, D:0.026394512504339218
[8/10 EPOCHS, 100 ITER] G:6.192587852478027, D:0.004940925166010857
[8/10 EPOCHS, 200 ITER] G:5.3129496574401855, D:0.010079941712319851
[8/10 EPOCHS, 300 ITER] G:4.186598300933838, D:0.06414561718702316
[8/10 EPOCHS, 400 ITER] G:5.328157901763916, D:0.08880393952131271
[9/10 EPOCHS, 100 ITER] G:4.083330154418945, D:0.0561496764421463
[9/10 EPOCHS, 200 ITER] G:4.414109230041504, D:0.05515436828136444
[9/10 EPOCHS, 300 ITER] G:4.784689426422119, D:0.019647054374217987
[9/10 EPOCHS, 400 ITER] G:3.9072790145874023, D:0.028399493545293808
[10/10 EPOCHS, 100 ITER] G:3.4736249446868896, D:0.11244657635688782
[10/10 EPOCHS, 200 ITER] G:5.187185287475586, D:0.010640854947268963
[10/10 EPOCHS, 300 ITER] G:5.757209777832031, D:0.13463236391544342
[10/10 EPOCHS, 400 ITER] G:4.209407806396484, D:0.04772578924894333
In [13]:
import numpy as np
noise = tf.random.normal([10, 100])
output = gan_generator(noise)
output = np.squeeze(output.numpy())
plt.figure(figsize=(15,6))
for i in range(1, 11):
    plt.subplot(2,5,i)
    plt.imshow(output[i-1])
Downloading the pretrained GAN weights (GAN_500)¶
- mkdir -p ~/aiffel/conditional_generation/gan
- wget https://aiffelstaticprd.blob.core.windows.net/media/documents/GAN_500.zip
- mv GAN_500.zip ~/aiffel/conditional_generation/gan
- cd ~/aiffel/conditional_generation/gan && unzip GAN_500.zip
In [14]:
import os
weight_path = os.getenv('HOME')+'/aiffel/conditional_generation/gan/GAN_500'
noise = tf.random.normal([10, 100])
gan_generator = GeneratorGAN()
gan_generator.load_weights(weight_path)
output = gan_generator(noise)
output = np.squeeze(output.numpy())
plt.figure(figsize=(15,6))
for i in range(1, 11):
    plt.subplot(2,5,i)
    plt.imshow(output[i-1])
Training a cGAN on MNIST¶
In [15]:
cgan_generator = GeneratorCGAN()
cgan_discriminator = DiscriminatorCGAN()

@tf.function()
def cgan_step(real_images, labels):
    noise = tf.random.normal([real_images.shape[0], 100])
    with tf.GradientTape(persistent=True) as tape:
        fake_images = cgan_generator(noise, labels)
        real_out = cgan_discriminator(real_images, labels)
        fake_out = cgan_discriminator(fake_images, labels)
        gene_loss = generator_loss(fake_out)
        disc_loss = discriminator_loss(real_out, fake_out)
    gene_grad = tape.gradient(gene_loss, cgan_generator.trainable_variables)
    disc_grad = tape.gradient(disc_loss, cgan_discriminator.trainable_variables)
    gene_opt.apply_gradients(zip(gene_grad, cgan_generator.trainable_variables))
    disc_opt.apply_gradients(zip(disc_grad, cgan_discriminator.trainable_variables))
    return gene_loss, disc_loss

EPOCHS = 1
for epoch in range(1, EPOCHS+1):
    for i, (images, labels) in enumerate(cgan_datasets):
        gene_loss, disc_loss = cgan_step(images, labels)
        if (i+1) % 100 == 0:
            print(f"[{epoch}/{EPOCHS} EPOCHS, {i} ITER] G:{gene_loss}, D:{disc_loss}")
[1/1 EPOCHS, 99 ITER] G:5.0402069091796875, D:0.013017227873206139
[1/1 EPOCHS, 199 ITER] G:4.192083835601807, D:0.036856137216091156
[1/1 EPOCHS, 299 ITER] G:3.177192449569702, D:0.12768149375915527
[1/1 EPOCHS, 399 ITER] G:4.604147434234619, D:0.05097862333059311
Downloading the pretrained cGAN weights (CGAN_500)¶
- mkdir -p ~/aiffel/conditional_generation/cgan
- wget https://aiffelstaticprd.blob.core.windows.net/media/documents/CGAN_500.zip
- mv CGAN_500.zip ~/aiffel/conditional_generation/cgan
- cd ~/aiffel/conditional_generation/cgan && unzip CGAN_500.zip
In [16]:
number = 7
weight_path = os.getenv('HOME')+'/aiffel/conditional_generation/cgan/CGAN_500'
noise = tf.random.normal([10, 100])
label = tf.one_hot(number, 10)
label = tf.expand_dims(label, axis=0)
label = tf.repeat(label, 10, axis=0)
generator = GeneratorCGAN()
generator.load_weights(weight_path)
output = generator(noise, label)
output = np.squeeze(output.numpy())
plt.figure(figsize=(15,6))
for i in range(1, 11):
    plt.subplot(2,5,i)
    plt.imshow(output[i-1])
Pix2Pix - Feeding Images into the GAN Model¶
Pix2Pix Generator¶
- Each Encoder layer is connected to the corresponding Decoder layer (skip connections), so the Decoder can draw on extra information from the Encoder to better generate the translated image.
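The skip connections above can be sketched purely in terms of shapes: at each decoder stage, the upsampled decoder feature map is concatenated channel-wise with the encoder feature map of the same spatial size (the sizes below are illustrative, not the exact layer sizes).

```python
# Shape-only sketch of one U-Net skip connection (illustrative sizes).
dec_shape = (64, 64, 128)   # upsampled decoder feature map (H, W, C)
enc_shape = (64, 64, 128)   # encoder feature map from the same depth (H, W, C)

assert dec_shape[:2] == enc_shape[:2]   # spatial sizes must match to concatenate

# Channel-wise concatenation: channels add, spatial size is unchanged.
skip_shape = (dec_shape[0], dec_shape[1], dec_shape[2] + enc_shape[2])
assert skip_shape == (64, 64, 256)
```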
Pix2Pix Discriminator¶
- When an image enters the Discriminator, it passes through convolution layers to produce a final output of probability values, and that result is not a single value but a grid of values (the Prediction in the figure above holds 16 values). The blue dashed box on the input image marks the receptive field used to compute one of those outputs: each output value judges real versus fake for only that part of the image (the blue dashed region), not the whole image.
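With the layer settings used by the Discriminator defined later in this notebook (filter size 4 throughout; three stride-2 "same" convolutions, then two stride-1 convolutions each preceded by 1 pixel of zero padding), the size of that prediction grid can be computed directly. This is a sketch of the arithmetic, not the notebook's own code:

```python
# Output size of one conv layer: (in + 2*pad - filter) // stride + 1
def conv_out(in_size, filter_size, stride, pad):
    return (in_size + 2*pad - filter_size) // stride + 1

size = 256   # input image width/height
# (stride, pad) per conv layer; the "same" stride-2 convs pad 1 pixel per side here
for stride, pad in [(2, 1), (2, 1), (2, 1), (1, 1), (1, 1)]:
    size = conv_out(size, 4, stride, pad)

assert size == 30   # a 30x30 grid: each value judges one patch of the input
```

Each of the 30×30 = 900 values is a separate real/fake decision over its own receptive field, which is what makes this a PatchGAN-style discriminator.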
Colorizing a Sketch¶
- wget https://aiffelstaticprd.blob.core.windows.net/media/documents/sketch2pokemon.zip
- mv sketch2pokemon.zip ~/aiffel/conditional_generation
- cd ~/aiffel/conditional_generation && unzip sketch2pokemon.zip
1. Preparing the Data¶
In [17]:
import os
data_path = os.getenv('HOME')+'/aiffel/conditional_generation/pokemon_pix2pix_dataset/train/'
print("number of train examples :", len(os.listdir(data_path)))
number of train examples : 830
In [18]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
plt.figure(figsize=(20,15))
for i in range(1, 7):
    f = data_path + os.listdir(data_path)[np.random.randint(800)]
    img = cv2.imread(f, cv2.IMREAD_COLOR)
    plt.subplot(3,2,i)
    plt.imshow(img)
In [19]:
f = data_path + os.listdir(data_path)[0]
img = cv2.imread(f, cv2.IMREAD_COLOR)
print(img.shape)
(256, 512, 3)
In [20]:
import tensorflow as tf
def normalize(x):
    x = tf.cast(x, tf.float32)
    return (x/127.5) - 1

def denormalize(x):
    x = (x+1)*127.5
    x = x.numpy()
    return x.astype(np.uint8)

def load_img(img_path):
    img = tf.io.read_file(img_path)
    img = tf.image.decode_image(img, 3)
    w = tf.shape(img)[1] // 2
    sketch = img[:, :w, :]
    sketch = tf.cast(sketch, tf.float32)
    colored = img[:, w:, :]
    colored = tf.cast(colored, tf.float32)
    return normalize(sketch), normalize(colored)
f = data_path + os.listdir(data_path)[1]
sketch, colored = load_img(f)
plt.figure(figsize=(10,7))
plt.subplot(1,2,1); plt.imshow(denormalize(sketch))
plt.subplot(1,2,2); plt.imshow(denormalize(colored))
Out[20]:
<matplotlib.image.AxesImage at 0x7f2648184710>
In [21]:
from tensorflow import image

@tf.function()  # use @tf.function() for fast TensorFlow graph execution
def apply_augmentation(sketch, colored):
    # Concatenate the two images along the channel axis (tf.concat);
    # two 3-channel images become one 6-channel tensor.
    stacked = tf.concat([sketch, colored], axis=-1)

    # Apply 30 pixels of padding, choosing reflection padding or constant
    # padding with 50% probability each (tf.pad).
    _pad = tf.constant([[30,30],[30,30],[0,0]])
    if tf.random.uniform(()) < .5:
        padded = tf.pad(stacked, _pad, "REFLECT")
    else:
        padded = tf.pad(stacked, _pad, "CONSTANT", constant_values=1.)

    # Randomly crop a (256,256,6) region from the result (tf.image.random_crop).
    out = image.random_crop(padded, size=[256, 256, 6])

    # Flip horizontally with 50% probability (tf.image.random_flip_left_right).
    out = image.random_flip_left_right(out)
    # Flip vertically with 50% probability (tf.image.random_flip_up_down).
    out = image.random_flip_up_down(out)

    # Rotate by a random multiple of 90 degrees with 50% probability (tf.image.rot90).
    if tf.random.uniform(()) < .5:
        degree = tf.random.uniform([], minval=1, maxval=4, dtype=tf.int32)
        out = image.rot90(out, k=degree)

    return out[...,:3], out[...,3:]

print("✅")
✅
In [22]:
plt.figure(figsize=(15,13))
img_n = 1
for i in range(1, 13, 2):
    augmented_sketch, augmented_colored = apply_augmentation(sketch, colored)
    plt.subplot(3,4,i)
    plt.imshow(denormalize(augmented_sketch)); plt.title(f"Image {img_n}")
    plt.subplot(3,4,i+1)
    plt.imshow(denormalize(augmented_colored)); plt.title(f"Image {img_n}")
    img_n += 1
In [23]:
from tensorflow import data
def get_train(img_path):
    sketch, colored = load_img(img_path)
    sketch, colored = apply_augmentation(sketch, colored)
    return sketch, colored
train_images = data.Dataset.list_files(data_path + "*.jpg")
train_images = train_images.map(get_train).shuffle(100).batch(4)
sample = train_images.take(1)
sample = list(sample.as_numpy_iterator())
sketch, colored = (sample[0][0]+1)*127.5, (sample[0][1]+1)*127.5
plt.figure(figsize=(10,5))
plt.subplot(1,2,1); plt.imshow(sketch[0].astype(np.uint8))
plt.subplot(1,2,2); plt.imshow(colored[0].astype(np.uint8))
Out[23]:
<matplotlib.image.AxesImage at 0x7f25900f2c90>
2. Building the Generator¶
In [24]:
from tensorflow.keras import layers, Input, Model

class EncodeBlock(layers.Layer):
    def __init__(self, n_filters, use_bn=True):
        super(EncodeBlock, self).__init__()
        self.use_bn = use_bn
        self.conv = layers.Conv2D(n_filters, 4, 2, "same", use_bias=False)
        self.batchnorm = layers.BatchNormalization()
        self.lrelu = layers.LeakyReLU(0.2)

    def call(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.batchnorm(x)
        return self.lrelu(x)

print("✅")
✅
In [25]:
class Encoder(layers.Layer):
    def __init__(self):
        super(Encoder, self).__init__()
        filters = [64,128,256,512,512,512,512,512]
        self.blocks = []
        for i, f in enumerate(filters):
            if i == 0:
                self.blocks.append(EncodeBlock(f, use_bn=False))
            else:
                self.blocks.append(EncodeBlock(f))

    def call(self, x):
        for block in self.blocks:
            x = block(x)
        return x

    def get_summary(self, input_shape=(256,256,3)):
        inputs = Input(input_shape)
        return Model(inputs, self.call(inputs)).summary()

print("✅")
✅
In [26]:
Encoder().get_summary()
Model: "model"
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 256, 256, 3)]     0
encode_block (EncodeBlock)   (None, 128, 128, 64)      3072
encode_block_1 (EncodeBlock) (None, 64, 64, 128)       131584
encode_block_2 (EncodeBlock) (None, 32, 32, 256)       525312
encode_block_3 (EncodeBlock) (None, 16, 16, 512)       2099200
encode_block_4 (EncodeBlock) (None, 8, 8, 512)         4196352
encode_block_5 (EncodeBlock) (None, 4, 4, 512)         4196352
encode_block_6 (EncodeBlock) (None, 2, 2, 512)         4196352
encode_block_7 (EncodeBlock) (None, 1, 1, 512)         4196352
=================================================================
Total params: 19,544,576
Trainable params: 19,538,688
Non-trainable params: 5,888
In [27]:
class DecodeBlock(layers.Layer):
    def __init__(self, f, dropout=True):
        super(DecodeBlock, self).__init__()
        self.dropout = dropout
        self.Transconv = layers.Conv2DTranspose(f, 4, 2, "same", use_bias=False)
        self.batchnorm = layers.BatchNormalization()
        self.relu = layers.ReLU()

    def call(self, x):
        x = self.Transconv(x)
        x = self.batchnorm(x)
        if self.dropout:
            x = layers.Dropout(.5)(x)
        return self.relu(x)

class Decoder(layers.Layer):
    def __init__(self):
        super(Decoder, self).__init__()
        filters = [512,512,512,512,256,128,64]
        self.blocks = []
        for i, f in enumerate(filters):
            if i < 3:
                self.blocks.append(DecodeBlock(f))
            else:
                self.blocks.append(DecodeBlock(f, dropout=False))
        self.blocks.append(layers.Conv2DTranspose(3, 4, 2, "same", use_bias=False))

    def call(self, x):
        for block in self.blocks:
            x = block(x)
        return x

    def get_summary(self, input_shape=(1,1,256)):
        inputs = Input(input_shape)
        return Model(inputs, self.call(inputs)).summary()

print("✅")
✅
In [28]:
Decoder().get_summary()
Model: "model_1"
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         [(None, 1, 1, 256)]       0
decode_block (DecodeBlock)   (None, 2, 2, 512)         2099200
decode_block_1 (DecodeBlock) (None, 4, 4, 512)         4196352
decode_block_2 (DecodeBlock) (None, 8, 8, 512)         4196352
decode_block_3 (DecodeBlock) (None, 16, 16, 512)       4196352
decode_block_4 (DecodeBlock) (None, 32, 32, 256)       2098176
decode_block_5 (DecodeBlock) (None, 64, 64, 128)       524800
decode_block_6 (DecodeBlock) (None, 128, 128, 64)      131328
conv2d_transpose_7 (Conv2DTr (None, 256, 256, 3)       3072
=================================================================
Total params: 17,445,632
Trainable params: 17,440,640
Non-trainable params: 4,992
In [29]:
class EncoderDecoderGenerator(Model):
    def __init__(self):
        super(EncoderDecoderGenerator, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def call(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def get_summary(self, input_shape=(256,256,3)):
        inputs = Input(input_shape)
        return Model(inputs, self.call(inputs)).summary()

EncoderDecoderGenerator().get_summary()
Model: "model_2"
Layer (type)                 Output Shape              Param #
=================================================================
input_3 (InputLayer)         [(None, 256, 256, 3)]     0
encoder_1 (Encoder)          (None, 1, 1, 512)         19544576
decoder_1 (Decoder)          (None, 256, 256, 3)       19542784
=================================================================
Total params: 39,087,360
Trainable params: 39,076,480
Non-trainable params: 10,880
3. Rebuilding the Generator¶
In [30]:
class EncodeBlock(layers.Layer):
    def __init__(self, n_filters, use_bn=True):
        super(EncodeBlock, self).__init__()
        self.use_bn = use_bn
        self.conv = layers.Conv2D(n_filters, 4, 2, "same", use_bias=False)
        self.batchnorm = layers.BatchNormalization()
        self.lrelu = layers.LeakyReLU(0.2)

    def call(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.batchnorm(x)
        return self.lrelu(x)

class DecodeBlock(layers.Layer):
    def __init__(self, f, dropout=True):
        super(DecodeBlock, self).__init__()
        self.dropout = dropout
        self.Transconv = layers.Conv2DTranspose(f, 4, 2, "same", use_bias=False)
        self.batchnorm = layers.BatchNormalization()
        self.relu = layers.ReLU()

    def call(self, x):
        x = self.Transconv(x)
        x = self.batchnorm(x)
        if self.dropout:
            x = layers.Dropout(.5)(x)
        return self.relu(x)

print("✅")
✅
In [31]:
class UNetGenerator(Model):
    def __init__(self):
        super(UNetGenerator, self).__init__()
        encode_filters = [64,128,256,512,512,512,512,512]
        decode_filters = [512,512,512,512,256,128,64]

        self.encode_blocks = []
        for i, f in enumerate(encode_filters):
            if i == 0:
                self.encode_blocks.append(EncodeBlock(f, use_bn=False))
            else:
                self.encode_blocks.append(EncodeBlock(f))

        self.decode_blocks = []
        for i, f in enumerate(decode_filters):
            if i < 3:
                self.decode_blocks.append(DecodeBlock(f))
            else:
                self.decode_blocks.append(DecodeBlock(f, dropout=False))

        self.last_conv = layers.Conv2DTranspose(3, 4, 2, "same", use_bias=False)

    def call(self, x):
        features = []
        for block in self.encode_blocks:
            x = block(x)
            features.append(x)

        features = features[:-1]
        for block, feat in zip(self.decode_blocks, features[::-1]):
            x = block(x)
            x = layers.Concatenate()([x, feat])

        x = self.last_conv(x)
        return x

    def get_summary(self, input_shape=(256,256,3)):
        inputs = Input(input_shape)
        return Model(inputs, self.call(inputs)).summary()

print("✅")
✅
In [32]:
UNetGenerator().get_summary()
Model: "model_3"
Layer (type)                          Output Shape           Param #   Connected to
===================================================================================
input_4 (InputLayer)                  [(None, 256, 256, 3)]  0
encode_block_16 (EncodeBlock)         (None, 128, 128, 64)   3072      input_4
encode_block_17 (EncodeBlock)         (None, 64, 64, 128)    131584    encode_block_16
encode_block_18 (EncodeBlock)         (None, 32, 32, 256)    525312    encode_block_17
encode_block_19 (EncodeBlock)         (None, 16, 16, 512)    2099200   encode_block_18
encode_block_20 (EncodeBlock)         (None, 8, 8, 512)      4196352   encode_block_19
encode_block_21 (EncodeBlock)         (None, 4, 4, 512)      4196352   encode_block_20
encode_block_22 (EncodeBlock)         (None, 2, 2, 512)      4196352   encode_block_21
encode_block_23 (EncodeBlock)         (None, 1, 1, 512)      4196352   encode_block_22
decode_block_14 (DecodeBlock)         (None, 2, 2, 512)      4196352   encode_block_23
concatenate (Concatenate)             (None, 2, 2, 1024)     0         decode_block_14, encode_block_22
decode_block_15 (DecodeBlock)         (None, 4, 4, 512)      8390656   concatenate
concatenate_1 (Concatenate)           (None, 4, 4, 1024)     0         decode_block_15, encode_block_21
decode_block_16 (DecodeBlock)         (None, 8, 8, 512)      8390656   concatenate_1
concatenate_2 (Concatenate)           (None, 8, 8, 1024)     0         decode_block_16, encode_block_20
decode_block_17 (DecodeBlock)         (None, 16, 16, 512)    8390656   concatenate_2
concatenate_3 (Concatenate)           (None, 16, 16, 1024)   0         decode_block_17, encode_block_19
decode_block_18 (DecodeBlock)         (None, 32, 32, 256)    4195328   concatenate_3
concatenate_4 (Concatenate)           (None, 32, 32, 512)    0         decode_block_18, encode_block_18
decode_block_19 (DecodeBlock)         (None, 64, 64, 128)    1049088   concatenate_4
concatenate_5 (Concatenate)           (None, 64, 64, 256)    0         decode_block_19, encode_block_17
decode_block_20 (DecodeBlock)         (None, 128, 128, 64)   262400    concatenate_5
concatenate_6 (Concatenate)           (None, 128, 128, 128)  0         decode_block_20, encode_block_16
conv2d_transpose_23 (Conv2DTranspose) (None, 256, 256, 3)    6144      concatenate_6
===================================================================================
Total params: 54,425,856
Trainable params: 54,414,976
Non-trainable params: 10,880
4. Discriminator 구성¶
In [33]:
class DiscBlock(layers.Layer):
# 필터의 수(n_filters), 필터가 순회하는 간격(stride), 출력 feature map의 크기를 조절할 수 있도록 하는 패딩 설정(custom_pad),
# BatchNorm의 사용 여부(use_bn), 활성화 함수 사용 여부(act)
def __init__(self, n_filters, stride=2, custom_pad=False, use_bn=True, act=True):
super(DiscBlock, self).__init__()
self.custom_pad = custom_pad
self.use_bn = use_bn
self.act = act
if custom_pad:
self.padding = layers.ZeroPadding2D()
self.conv = layers.Conv2D(n_filters, 4, stride, "valid", use_bias=False)
else:
self.conv = layers.Conv2D(n_filters, 4, stride, "same", use_bias=False)
self.batchnorm = layers.BatchNormalization() if use_bn else None
self.lrelu = layers.LeakyReLU(0.2) if act else None
def call(self, x):
if self.custom_pad:
x = self.padding(x)
x = self.conv(x)
else:
x = self.conv(x)
if self.use_bn:
x = self.batchnorm(x)
if self.act:
x = self.lrelu(x)
return x
print("✅")
✅
Question: if an input of size (width, height, channel) = (128, 128, 32) enters a block created with DiscBlock(n_filters=64, stride=1, custom_pad=True, use_bn=True, act=True), what is the output size?¶
- Passing the (128,128,32) input through layers.ZeroPadding2D() pads both sides of the width and height by 1 each, so both grow by 2. Output: (130,130,32)
- Passing it through a convolution layer with filter size 4 and stride 1 and no padding then shrinks the width and height by 3 each, per OutSize = (InSize + 2*PadSize - FilterSize)/Stride + 1. The number of channels equals the number of filters used. Output: (127,127,64)
- The remaining layers (BatchNorm, LeakyReLU) do not affect the output size.
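The steps above can be checked directly with that formula (a quick sanity check, not part of the notebook's model code):

```python
def conv_out(in_size, filter_size, stride, pad):
    # OutSize = (InSize + 2*PadSize - FilterSize) // Stride + 1
    return (in_size + 2*pad - filter_size) // stride + 1

padded = 128 + 2                  # ZeroPadding2D adds 1 pixel to each side
out = conv_out(padded, 4, 1, 0)   # 4x4 filter, stride 1, no further padding
assert out == 127                 # matches the (127,127,64) output above
```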
In [34]:
inputs = Input((128,128,32))
out = layers.ZeroPadding2D()(inputs)
out = layers.Conv2D(64, 4, 1, "valid", use_bias=False)(out)
out = layers.BatchNormalization()(out)
out = layers.LeakyReLU(0.2)(out)
Model(inputs, out).summary()
Model: "model_4"
Layer (type)                 Output Shape              Param #
=================================================================
input_5 (InputLayer)         [(None, 128, 128, 32)]    0
zero_padding2d (ZeroPadding2 (None, 130, 130, 32)      0
conv2d_24 (Conv2D)           (None, 127, 127, 64)      32768
batch_normalization_45 (Batc (None, 127, 127, 64)      256
leaky_re_lu_24 (LeakyReLU)   (None, 127, 127, 64)      0
=================================================================
Total params: 33,024
Trainable params: 32,896
Non-trainable params: 128
In [35]:
class Discriminator(Model):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.block1 = layers.Concatenate()
        self.block2 = DiscBlock(n_filters=64, stride=2, custom_pad=False, use_bn=False, act=True)
        self.block3 = DiscBlock(n_filters=128, stride=2, custom_pad=False, use_bn=True, act=True)
        self.block4 = DiscBlock(n_filters=256, stride=2, custom_pad=False, use_bn=True, act=True)
        self.block5 = DiscBlock(n_filters=512, stride=1, custom_pad=True, use_bn=True, act=True)
        self.block6 = DiscBlock(n_filters=1, stride=1, custom_pad=True, use_bn=False, act=False)
        self.sigmoid = layers.Activation("sigmoid")

        # Equivalent loop-based construction:
        # filters = [64,128,256,512,1]
        # self.blocks = [layers.Concatenate()]
        # for i, f in enumerate(filters):
        #     self.blocks.append(DiscBlock(
        #         n_filters=f,
        #         stride=2 if i<3 else 1,
        #         custom_pad=False if i<3 else True,
        #         use_bn=False if i==0 or i==4 else True,
        #         act=True if i<4 else False
        #     ))

    def call(self, x, y):
        out = self.block1([x, y])
        out = self.block2(out)
        out = self.block3(out)
        out = self.block4(out)
        out = self.block5(out)
        out = self.block6(out)
        return self.sigmoid(out)

    def get_summary(self, x_shape=(256,256,3), y_shape=(256,256,3)):
        x, y = Input(x_shape), Input(y_shape)
        return Model((x, y), self.call(x, y)).summary()

print("✅")
✅
In [36]:
Discriminator().get_summary()
Model: "model_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape          Param #     Connected to
==================================================================================================
input_6 (InputLayer)            [(None, 256, 256, 3)] 0
__________________________________________________________________________________________________
input_7 (InputLayer)            [(None, 256, 256, 3)] 0
__________________________________________________________________________________________________
concatenate_7 (Concatenate)     (None, 256, 256, 6)   0           input_6[0][0]
                                                                  input_7[0][0]
__________________________________________________________________________________________________
disc_block (DiscBlock)          (None, 128, 128, 64)  6144        concatenate_7[0][0]
__________________________________________________________________________________________________
disc_block_1 (DiscBlock)        (None, 64, 64, 128)   131584      disc_block[0][0]
__________________________________________________________________________________________________
disc_block_2 (DiscBlock)        (None, 32, 32, 256)   525312      disc_block_1[0][0]
__________________________________________________________________________________________________
disc_block_3 (DiscBlock)        (None, 31, 31, 512)   2099200     disc_block_2[0][0]
__________________________________________________________________________________________________
disc_block_4 (DiscBlock)        (None, 30, 30, 1)     8192        disc_block_3[0][0]
__________________________________________________________________________________________________
activation (Activation)         (None, 30, 30, 1)     0           disc_block_4[0][0]
==================================================================================================
Total params: 2,770,432
Trainable params: 2,768,640
Non-trainable params: 1,792
__________________________________________________________________________________________________
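The (30, 30, 1) output means each final unit scores one patch of the input pair rather than the whole image. How large that patch is can be computed layer by layer; the standalone sketch below restates the five DiscBlocks configured above as (filter_size, stride) pairs, all with 4×4 filters:

```python
# (filter_size, stride) for the five conv layers of the discriminator above
layers_cfg = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]

rf, jump = 1, 1  # receptive field and cumulative stride ("jump") seen from the output
for k, s in layers_cfg:
    rf += (k - 1) * jump
    jump *= s

print(rf)  # 70 -> each of the 30x30 outputs looks at a 70x70 patch of the input
```

A receptive field of 70 matches the 70×70 PatchGAN discriminator described in the pix2pix paper.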
In [37]:
x = tf.random.normal([1,256,256,3])
y = tf.random.uniform([1,256,256,3])
disc_out = Discriminator()(x, y)
plt.imshow(disc_out[0, ... ,0])
plt.colorbar()
Out[37]:
<matplotlib.colorbar.Colorbar at 0x7f2589b0a550>
5. Training and Testing¶
In [38]:
from tensorflow.keras import losses

bce = losses.BinaryCrossentropy(from_logits=False)
mae = losses.MeanAbsoluteError()

def get_gene_loss(fake_output, real_output, fake_disc):
    l1_loss = mae(real_output, fake_output)
    gene_loss = bce(tf.ones_like(fake_disc), fake_disc)
    return gene_loss, l1_loss

def get_disc_loss(fake_disc, real_disc):
    return bce(tf.zeros_like(fake_disc), fake_disc) + bce(tf.ones_like(real_disc), real_disc)
print("✅")
✅
The Generator's loss function (get_gene_loss above) takes three inputs. Of these, fake_disc is the value obtained by feeding the fake image produced by the Generator into the Discriminator; it is compared against "1" (meaning real) via tf.ones_like(). In addition, the MAE (Mean Absolute Error) between the generated fake image (fake_output) and the real image (real_output) is computed as the L1 loss.

The Discriminator's loss function (get_disc_loss above) takes two inputs: the Discriminator's outputs for the fake and real images, respectively. Since the Discriminator must correctly identify real images, real_disc is compared against a tensor filled with "1"s, and fake_disc against a tensor filled with "0"s.
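Written out by hand, the two losses are just binary cross-entropy against all-ones or all-zeros targets plus a mean absolute error. The numpy sketch below mirrors the TensorFlow code (the helper names bce and l1 are illustrative, not part of Keras):

```python
import numpy as np

def bce(target, pred, eps=1e-7):
    # binary cross-entropy on probabilities (i.e. from_logits=False)
    pred = np.clip(pred, eps, 1 - eps)
    return -np.mean(target * np.log(pred) + (1 - target) * np.log(1 - pred))

def l1(real, fake):
    # mean absolute error between real and generated images
    return np.mean(np.abs(real - fake))

fake_disc = np.array([0.3, 0.6])   # Discriminator scores for generated images
real_disc = np.array([0.9, 0.8])   # Discriminator scores for real images

gene_loss = bce(np.ones_like(fake_disc), fake_disc)   # G wants fakes scored as 1
disc_loss = bce(np.zeros_like(fake_disc), fake_disc) + bce(np.ones_like(real_disc), real_disc)
print(gene_loss, disc_loss)
```

When the Discriminator scores fakes close to 1, gene_loss approaches 0 while the first term of disc_loss grows, which is exactly the adversarial tension the training loop exploits.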
In [39]:
from tensorflow.keras import optimizers
gene_opt = optimizers.Adam(2e-4, beta_1=.5, beta_2=.999)
disc_opt = optimizers.Adam(2e-4, beta_1=.5, beta_2=.999)
print("✅")
✅
In [40]:
@tf.function
def train_step(sketch, real_colored):
    with tf.GradientTape() as gene_tape, tf.GradientTape() as disc_tape:
        # Generator prediction
        fake_colored = generator(sketch, training=True)
        # Discriminator predictions
        fake_disc = discriminator(sketch, fake_colored, training=True)
        real_disc = discriminator(sketch, real_colored, training=True)
        # Generator loss
        gene_loss, l1_loss = get_gene_loss(fake_colored, real_colored, fake_disc)
        gene_total_loss = gene_loss + (100 * l1_loss)  ## <===== L1 loss weighted by λ=100
        # Discriminator loss
        disc_loss = get_disc_loss(fake_disc, real_disc)

    gene_gradient = gene_tape.gradient(gene_total_loss, generator.trainable_variables)
    disc_gradient = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    gene_opt.apply_gradients(zip(gene_gradient, generator.trainable_variables))
    disc_opt.apply_gradients(zip(disc_gradient, discriminator.trainable_variables))
    return gene_loss, l1_loss, disc_loss
print("✅")
✅
In [41]:
EPOCHS = 10

generator = UNetGenerator()
discriminator = Discriminator()

for epoch in range(1, EPOCHS+1):
    for i, (sketch, colored) in enumerate(train_images):
        g_loss, l1_loss, d_loss = train_step(sketch, colored)

        # print the losses every 10 steps
        if (i+1) % 10 == 0:
            print(f"EPOCH[{epoch}] - STEP[{i+1}] \
                \nGenerator_loss:{g_loss.numpy():.4f} \
                \nL1_loss:{l1_loss.numpy():.4f} \
                \nDiscriminator_loss:{d_loss.numpy():.4f}", end="\n\n")
WARNING:tensorflow:Unresolved object in checkpoint: (root).dropout_z
WARNING:tensorflow:Unresolved object in checkpoint: (root).dropout_y
WARNING:tensorflow:Unresolved object in checkpoint: (root).dropout_x
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
EPOCH[1] - STEP[10]
Generator_loss:1.0010
L1_loss:0.4576
Discriminator_loss:1.1311

EPOCH[1] - STEP[20]
Generator_loss:0.5860
L1_loss:0.3019
Discriminator_loss:1.5537

...(per-step training log truncated)...

EPOCH[10] - STEP[200]
Generator_loss:0.5680
L1_loss:0.2222
Discriminator_loss:0.9808
In [42]:
test_ind = 1
f = data_path + os.listdir(data_path)[test_ind]
sketch, colored = load_img(f)
pred = generator(tf.expand_dims(sketch, 0))
pred = denormalize(pred)
plt.figure(figsize=(20,10))
plt.subplot(1,3,1); plt.imshow(denormalize(sketch))
plt.subplot(1,3,2); plt.imshow(pred[0])
plt.subplot(1,3,3); plt.imshow(denormalize(colored))
Out[42]:
<matplotlib.image.AxesImage at 0x7f2588cd0890>