214 lines
8.0 KiB
Python
214 lines
8.0 KiB
Python
"""A simple training interface using ttools."""
|
|
import argparse
|
|
import os
|
|
import logging
|
|
import random
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torchvision.datasets import MNIST
|
|
import torchvision.transforms as xforms
|
|
from torch.utils.data import DataLoader
|
|
|
|
import ttools
|
|
import ttools.interfaces
|
|
|
|
import pydiffvg
|
|
|
|
LOG = ttools.get_logger(__name__)

# Silence diffvg's per-frame render timing printouts.
pydiffvg.render_pytorch.print_timing = False

# Seed every RNG this script uses, for reproducibility.
# BUGFIX: `random` must be seeded too — the renderer seed in
# Generator.forward is drawn with random.randint, so leaving the
# stdlib RNG unseeded defeated the determinism setup below.
torch.manual_seed(123)
np.random.seed(123)
random.seed(123)
torch.backends.cudnn.deterministic = True

# Model / scene hyper-parameters.
latent_dim = 100   # size of the generator's latent (noise) input
img_size = 32      # square canvas resolution (MNIST is resized to match)
num_paths = 8      # number of vector paths per generated image
num_segments = 8   # Bezier segments per path
|
|
|
|
def weights_init_normal(m):
    """DCGAN weight init: N(0, 0.02) for convs; N(1, 0.02) weight and
    zero bias for BatchNorm2d. Other modules are left untouched.

    Intended to be used via ``model.apply(weights_init_normal)``.
    """
    layer_name = type(m).__name__
    if "Conv" in layer_name:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in layer_name:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
|
|
|
|
class VisdomImageCallback(ttools.callbacks.ImageDisplayCallback):
    """Displays real and generated samples side by side in visdom."""

    def visualized_image(self, batch, fwd_result):
        real = batch[0]
        fake = fwd_result.cpu()
        # Stack the two image batches along dim 2 so each grid cell shows
        # the real digit next to the generated one.
        return torch.cat([real, fake], dim=2)
|
|
|
|
# From https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/dcgan/dcgan.py
|
|
class Generator(torch.nn.Module):
    """Maps a latent vector to a rasterized vector-graphics image.

    An MLP predicts, for each of ``num_paths`` open Bezier paths:
    2*(num_segments+1) control-point coordinates, one stroke width and one
    stroke-color channel (all in [0, 1] thanks to the final Sigmoid). The
    paths are rasterized with diffvg into a single-channel
    ``img_size`` x ``img_size`` image per latent sample.
    """

    def __init__(self):
        super(Generator, self).__init__()

        # Per path: 2*(num_segments+1) coords + 1 width + 1 color channel.
        out_features = 2 * num_paths * (num_segments + 1) + num_paths + num_paths

        layers = []
        for in_f, out_f in ((latent_dim, 128), (128, 256), (256, 512), (512, 1024)):
            layers.append(torch.nn.Linear(in_f, out_f))
            layers.append(torch.nn.LeakyReLU(0.2, inplace=True))
        layers.append(torch.nn.Linear(1024, out_features))
        layers.append(torch.nn.Sigmoid())
        self.fc = torch.nn.Sequential(*layers)

    def forward(self, z):
        params = self.fc(z)
        render = pydiffvg.RenderFunction.apply

        imgs = []
        # Scenes must be built and rendered one sample at a time.
        for b in range(params.shape[0]):
            cursor = 0
            shapes = []
            shape_groups = []
            for _ in range(num_paths):
                # Control points, scaled from [0, 1] to canvas coordinates.
                n_coords = 2 * (num_segments + 1)
                points = img_size * params[b, cursor:cursor + n_coords].view(-1, 2).cpu()
                cursor += n_coords
                stroke_width = img_size * params[b, cursor].view(1).cpu()
                cursor += 1

                # Two control points between consecutive knots (cubic segments).
                num_control_points = torch.zeros(num_segments, dtype=torch.int32) + 2
                path = pydiffvg.Path(num_control_points=num_control_points,
                                     points=points,
                                     stroke_width=stroke_width,
                                     is_closed=False)
                shapes.append(path)

                # Only the first color channel is learned; the remaining
                # values (G, B, alpha) are fixed at (0, 0, 1).
                stroke_color = params[b, cursor].view(1).cpu()
                cursor += 1
                stroke_color = torch.cat([stroke_color, torch.tensor([0.0, 0.0, 1.0])])
                path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]),
                                                 fill_color=None,
                                                 stroke_color=stroke_color)
                shape_groups.append(path_group)

            scene_args = pydiffvg.RenderFunction.serialize_scene(
                img_size, img_size, shapes, shape_groups)
            img = render(img_size,                    # width
                         img_size,                    # height
                         2,                           # num_samples_x
                         2,                           # num_samples_y
                         random.randint(0, 1048576),  # seed
                         None,
                         *scene_args)
            img = img[:, :, :1]            # keep the learned channel only
            img = img.unsqueeze(0)         # HWC -> 1HWC
            img = img.permute(0, 3, 1, 2)  # NHWC -> NCHW
            imgs.append(img)

        return torch.cat(imgs, dim=0)
|
|
|
|
class Discriminator(torch.nn.Module):
    """DCGAN-style convolutional classifier: image -> probability of "real".

    Args:
        input_size (int, optional): height/width of the (square) input
            images. Defaults to the module-level ``img_size``, preserving
            the original behavior; passing it explicitly generalizes the
            model to other resolutions (must be divisible by 16).
    """

    def __init__(self, input_size=None):
        super(Discriminator, self).__init__()
        if input_size is None:
            # Backward-compatible default: follow the global canvas size.
            input_size = img_size

        def discriminator_block(in_filters, out_filters, bn=True):
            """Stride-2 conv block (halves resolution) + LeakyReLU + dropout."""
            block = [torch.nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                     torch.nn.LeakyReLU(0.2, inplace=True),
                     torch.nn.Dropout2d(0.25)]
            if bn:
                # NOTE(review): 0.8 is passed positionally, so it sets `eps`,
                # not `momentum` — a well-known quirk inherited from the
                # PyTorch-GAN reference implementation. Kept as-is for parity;
                # the likely intent was BatchNorm2d(out_filters, momentum=0.8).
                block.append(torch.nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = torch.nn.Sequential(
            *discriminator_block(1, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # Four stride-2 blocks downsample the input by 2**4.
        ds_size = input_size // 2 ** 4
        self.adv_layer = torch.nn.Sequential(
            torch.nn.Linear(128 * ds_size ** 2, 1),
            torch.nn.Sigmoid())

    def forward(self, img):
        """Return an (N, 1) tensor of realness probabilities in [0, 1]."""
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity
|
|
|
|
class MNISTInterface(ttools.interfaces.SGANInterface):
    """An adapter to run or train a model."""

    def __init__(self, gen, discrim, lr=2e-4):
        super(MNISTInterface, self).__init__(gen, discrim, lr, opt='adam')

    def forward(self, batch):
        # Draw a fresh standard-normal latent batch matching the batch size.
        bs = batch[0].shape[0]
        noise = torch.zeros([bs, latent_dim], device=self.device).normal_()
        return self.gen(noise)

    def _discriminator_input(self, batch, fwd_data, fake=False):
        # Fake samples are generator outputs (already on-device); real
        # samples come from the loader and must be moved to the device.
        if fake:
            return fwd_data
        return batch[0].to(self.device)
|
|
|
|
def train(args):
    """Train the vector-graphics GAN on MNIST.

    Args:
        args: parsed CLI namespace with ``data`` (dataset root directory),
            ``out`` (checkpoint/visualization directory), ``lr``,
            ``epochs`` and ``bs``.
    """

    # Setup training data.
    # BUGFIX: honor the positional `data` CLI argument instead of the
    # previously hard-coded "data/mnist" path, which silently ignored it.
    _xform = xforms.Compose([xforms.Resize([32, 32]), xforms.ToTensor()])
    data = MNIST(args.data, train=True, download=True, transform=_xform)

    # Initialize asynchronous dataloaders
    loader = DataLoader(data, batch_size=args.bs, num_workers=2)

    # Instantiate the models
    gen = Generator()
    discrim = Discriminator()

    # DCGAN weight initialization.
    gen.apply(weights_init_normal)
    discrim.apply(weights_init_normal)

    # Checkpointers to save/recall model parameters (one per network).
    checkpointer_gen = ttools.Checkpointer(
        os.path.join(args.out, "checkpoints"), model=gen, prefix="gen_")
    checkpointer_discrim = ttools.Checkpointer(
        os.path.join(args.out, "checkpoints"), model=discrim, prefix="discrim_")

    # Resume from a previous checkpoint, if any.
    checkpointer_gen.load_latest()
    checkpointer_discrim.load_latest()

    # Setup a training interface for the model
    interface = MNISTInterface(gen, discrim, lr=args.lr)

    # Create a training looper with the interface we defined
    trainer = ttools.Trainer(interface)

    # Adds several callbacks, that will be called by the trainer --------------
    # A periodic checkpointing operation
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer_gen))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer_discrim))
    # A simple progress bar
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=["loss_g", "loss_d", "loss"]))
    # A volatile logging using visdom
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=["loss_g", "loss_d", "loss"],
        port=8080, env="mnist_demo"))
    # Image
    trainer.add_callback(VisdomImageCallback(port=8080, env="mnist_demo"))
    # -------------------------------------------------------------------------

    # Start the training
    LOG.info("Training started, press Ctrl-C to interrupt.")
    trainer.train(loader, num_epochs=args.epochs)
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse arguments, enable debug logging, train.
    cli = argparse.ArgumentParser()
    # TODO: subparsers
    cli.add_argument("data", help="directory where we download and store the MNIST dataset.")
    cli.add_argument("out", help="directory where we write the checkpoints and visualizations.")
    cli.add_argument("--lr", type=float, default=1e-4, help="learning rate for the optimizer.")
    cli.add_argument("--epochs", type=int, default=500, help="number of epochs to train for.")
    cli.add_argument("--bs", type=int, default=64, help="number of elements per batch.")

    parsed = cli.parse_args()
    ttools.set_logger(True)  # activate debug prints
    train(parsed)
|