initial commit
apps/sketch_gan.py (new file, 213 lines)
@@ -0,0 +1,213 @@
"""A simple training interface using ttools."""

import argparse
import os
import logging
import random

import numpy as np
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as xforms
from torch.utils.data import DataLoader

import ttools
import ttools.interfaces

import pydiffvg

LOG = ttools.get_logger(__name__)

pydiffvg.render_pytorch.print_timing = False

torch.manual_seed(123)
np.random.seed(123)
torch.backends.cudnn.deterministic = True

latent_dim = 100
img_size = 32
num_paths = 8
num_segments = 8


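# These globals define the sketch parameterization: each generated sample is a
# set of num_paths open strokes, each made of num_segments Bezier segments,
# rasterized at img_size x img_size from a latent_dim-dimensional noise vector.

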
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


class VisdomImageCallback(ttools.callbacks.ImageDisplayCallback):
    def visualized_image(self, batch, fwd_result):
        # Stack the real digits above the rendered fakes for display.
        return torch.cat([batch[0], fwd_result.cpu()], dim=2)


# From https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/dcgan/dcgan.py
class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.fc = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, 128),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Linear(128, 256),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Linear(256, 512),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Linear(512, 1024),
            torch.nn.LeakyReLU(0.2, inplace=True),
            # Per path: 2 * (num_segments + 1) point coordinates, plus one
            # stroke width and one stroke intensity.
            torch.nn.Linear(1024, 2 * num_paths * (num_segments + 1) + num_paths + num_paths),
            torch.nn.Sigmoid()
        )

    def forward(self, z):
        out = self.fc(z)
        # Construct the vector paths and rasterize them, one batch element at a time.
        imgs = []
        for b in range(out.shape[0]):
            index = 0
            shapes = []
            shape_groups = []
            for i in range(num_paths):
                points = img_size * out[b, index: index + 2 * (num_segments + 1)].view(-1, 2).cpu()
                index += 2 * (num_segments + 1)
                stroke_width = img_size * out[b, index].view(1).cpu()
                index += 1

                num_control_points = torch.zeros(num_segments, dtype=torch.int32) + 2
                path = pydiffvg.Path(num_control_points=num_control_points,
                                     points=points,
                                     stroke_width=stroke_width,
                                     is_closed=False)
                shapes.append(path)

                stroke_color = out[b, index].view(1).cpu()
                index += 1
                stroke_color = torch.cat([stroke_color, torch.tensor([0.0, 0.0, 1.0])])
                path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]),
                                                 fill_color=None,
                                                 stroke_color=stroke_color)
                shape_groups.append(path_group)
            scene_args = pydiffvg.RenderFunction.serialize_scene(img_size, img_size, shapes, shape_groups)
            render = pydiffvg.RenderFunction.apply
            img = render(img_size,  # width
                         img_size,  # height
                         2,  # num_samples_x
                         2,  # num_samples_y
                         random.randint(0, 1048576),  # seed
                         None,
                         *scene_args)
            img = img[:, :, :1]  # keep a single channel
            # HWC -> NCHW: add a batch dimension, then move channels first.
            img = img.unsqueeze(0)
            img = img.permute(0, 3, 1, 2)
            imgs.append(img)
        img = torch.cat(imgs, dim=0)
        return img


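# A minimal shape check for the generator (illustrative only, not executed by
# this script; assumes pydiffvg has been built and imports successfully):
#
#     gen = Generator()
#     fake = gen(torch.randn(2, latent_dim))
#     assert fake.shape == (2, 1, img_size, img_size)

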
class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [torch.nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                     torch.nn.LeakyReLU(0.2, inplace=True),
                     torch.nn.Dropout2d(0.25)]
            if bn:
                block.append(torch.nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = torch.nn.Sequential(
            *discriminator_block(1, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # Height and width of the downsampled image: 32 // 2**4 = 2.
        ds_size = img_size // 2 ** 4
        self.adv_layer = torch.nn.Sequential(
            torch.nn.Linear(128 * ds_size ** 2, 1),
            torch.nn.Sigmoid())

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity


class MNISTInterface(ttools.interfaces.SGANInterface):
    """An adapter to run or train a model."""

    def __init__(self, gen, discrim, lr=2e-4):
        super(MNISTInterface, self).__init__(gen, discrim, lr, opt='adam')

    def forward(self, batch):
        # Sample a fresh batch of standard-normal latent vectors and generate sketches.
        return self.gen(torch.zeros([batch[0].shape[0], latent_dim], device=self.device).normal_())

    def _discriminator_input(self, batch, fwd_data, fake=False):
        if fake:
            return fwd_data
        else:
            return batch[0].to(self.device)


def train(args):
    """Train the MNIST sketch GAN."""

    # Set up the training data
    _xform = xforms.Compose([xforms.Resize([32, 32]), xforms.ToTensor()])
    data = MNIST(args.data, train=True, download=True, transform=_xform)

    # Initialize asynchronous dataloaders
    loader = DataLoader(data, batch_size=args.bs, num_workers=2)

    # Instantiate the models
    gen = Generator()
    discrim = Discriminator()

    gen.apply(weights_init_normal)
    discrim.apply(weights_init_normal)

    # Checkpointers to save/recall model parameters
    checkpointer_gen = ttools.Checkpointer(os.path.join(args.out, "checkpoints"), model=gen, prefix="gen_")
    checkpointer_discrim = ttools.Checkpointer(os.path.join(args.out, "checkpoints"), model=discrim, prefix="discrim_")

    # Resume from a previous checkpoint, if any
    checkpointer_gen.load_latest()
    checkpointer_discrim.load_latest()

    # Set up a training interface for the model
    interface = MNISTInterface(gen, discrim, lr=args.lr)

    # Create a training looper with the interface we defined
    trainer = ttools.Trainer(interface)

    # Add several callbacks, which will be called by the trainer --------------
    # A periodic checkpointing operation
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer_gen))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer_discrim))
    # A simple progress bar
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=["loss_g", "loss_d", "loss"]))
    # Volatile logging using visdom
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=["loss_g", "loss_d", "loss"],
        port=8080, env="mnist_demo"))
    # Image display
    trainer.add_callback(VisdomImageCallback(port=8080, env="mnist_demo"))
    # -------------------------------------------------------------------------

    # Start the training
    LOG.info("Training started, press Ctrl-C to interrupt.")
    trainer.train(loader, num_epochs=args.epochs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # TODO: subparsers
    parser.add_argument("data", help="directory where we download and store the MNIST dataset.")
    parser.add_argument("out", help="directory where we write the checkpoints and visualizations.")
    parser.add_argument("--lr", type=float, default=1e-4, help="learning rate for the optimizer.")
    parser.add_argument("--epochs", type=int, default=500, help="number of epochs to train for.")
    parser.add_argument("--bs", type=int, default=64, help="number of elements per batch.")
    args = parser.parse_args()
    ttools.set_logger(True)  # activate debug prints
    train(args)
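
# Example invocation (illustrative paths; the visdom callbacks above assume a
# visdom server listening on port 8080):
#
#     python apps/sketch_gan.py data/mnist results/sketch_gan --lr 2e-4 --bs 64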