490 lines
16 KiB
Python
Executable File
490 lines
16 KiB
Python
Executable File
#!/usr/bin/env python
|
|
"""Train a GAN.
|
|
|
|
Usage:
|
|
|
|
* Train an MNIST model:
|
|
|
|
`python train_gan.py`
|
|
|
|
* Train a Quickdraw model:
|
|
|
|
`python train_gan.py --task quickdraw`
|
|
|
|
"""
|
|
import argparse
|
|
import os
|
|
|
|
import numpy as np
|
|
import torch as th
|
|
from torch.utils.data import DataLoader
|
|
|
|
import ttools
|
|
import ttools.interfaces
|
|
|
|
import losses
|
|
import data
|
|
import models
|
|
|
|
import pydiffvg
|
|
|
|
LOG = ttools.get_logger(__name__)

# Parent of the directory that contains this script.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.join(_SCRIPT_DIR, os.pardir)

# Root folder for training runs; train() writes checkpoints to OUTPUT/<env_name>.
OUTPUT = os.path.join(BASE_DIR, "results")
|
|
|
|
|
|
class Callback(ttools.callbacks.ImageDisplayCallback):
    """Image callback showing real images next to the GAN samples."""

    def visualized_image(self, batch, step_data, is_val=False):
        """Build the sample grid for a training step.

        The first 16 real images, raster samples and (when present) vector
        samples are concatenated along dim 2. Assuming NCHW tensors, this is
        the height axis, which matches the top/middle/bottom caption. Values
        are rescaled with (x + 1) / 2 and clamped to [0, 1], so the inputs
        are presumably in [-1, 1]. Nothing is shown for validation steps.
        """
        if is_val:
            return None

        panels = [
            step_data["gt_image"][:16].detach(),
            step_data["gen_image"][:16].detach(),
        ]
        vector_samples = step_data["vector_image"]
        if vector_samples is not None:
            panels.append(vector_samples[:16].detach())

        grid = th.cat(panels, 2)
        return th.clamp((grid + 1.0) * 0.5, 0, 1)

    def caption(self, batch, step_data, is_val=False):
        """Describe the row layout produced by ``visualized_image``."""
        has_vector = step_data["vector_image"] is not None
        return ("top: real, middle: raster, bottom: vector" if has_vector
                else "top: real, bottom: fake")
|
|
|
|
|
|
class Interface(ttools.ModelInterface):
    """Training interface for a raster GAN and an optional vector GAN.

    Both GANs are driven by the same latent samples ``z`` and compared
    against the same real batch. Each call to ``training_step`` performs
    either a discriminator update or a generator update. The discriminators
    are updated ``n_critic`` consecutive times (5 with WGAN-GP, 1 otherwise)
    before each generator update.

    Args:
        generator: raster generator; must expose ``sample_z(n, device=...)``.
        vect_generator: vector generator, or None to train the raster model
            only.
        discriminator: discriminator paired with ``generator``.
        vect_discriminator: discriminator paired with ``vect_generator``;
            required when ``vect_generator`` is not None.
        lr: initial Adam learning rate for every optimizer.
        lr_decay: ``gamma`` of the ExponentialLR schedulers.
        gradient_penalty: weight of the gradient penalty (WGAN-GP only).
        wgan_gp: use the Wasserstein loss with gradient penalty instead of
            the standard BCE GAN loss.
        raster_resolution: stored on the instance; not used by this class.
        device: device the models are moved to.
        grad_clip: maximum gradient norm for every generator/discriminator.

    Raises:
        ValueError: if ``vect_generator`` is given without
            ``vect_discriminator``.
    """

    def __init__(self, generator, vect_generator,
                 discriminator, vect_discriminator,
                 lr=1e-4, lr_decay=0.9999,
                 gradient_penalty=10,
                 wgan_gp=False,
                 raster_resolution=32, device="cpu", grad_clip=1.0):
        super().__init__()

        self.wgan_gp = wgan_gp
        self.w_gradient_penalty = gradient_penalty

        # Number of discriminator updates per generator update.
        self.n_critic = 5 if self.wgan_gp else 1

        self.grad_clip = grad_clip
        self.raster_resolution = raster_resolution

        self.gen = generator
        self.vect_gen = vect_generator
        self.discrim = discriminator
        self.vect_discrim = vect_discriminator

        self.device = device
        self.gen.to(self.device)
        self.discrim.to(self.device)

        # Adam (beta1, beta2) shared by all optimizers.
        betas = (0.5, 0.9)

        self.gen_opt = th.optim.Adam(
            self.gen.parameters(), lr=lr, betas=betas)
        self.discrim_opt = th.optim.Adam(
            self.discrim.parameters(), lr=lr, betas=betas)

        self.schedulers = [
            th.optim.lr_scheduler.ExponentialLR(self.gen_opt, lr_decay),
            th.optim.lr_scheduler.ExponentialLR(self.discrim_opt, lr_decay),
        ]
        self.optimizers = [self.gen_opt, self.discrim_opt]

        if self.vect_gen is not None:
            # Explicit check rather than `assert`, which is stripped under -O.
            if self.vect_discrim is None:
                raise ValueError(
                    "vect_discriminator is required when vect_generator "
                    "is given")

            self.vect_gen.to(self.device)
            self.vect_discrim.to(self.device)

            self.vect_gen_opt = th.optim.Adam(
                self.vect_gen.parameters(), lr=lr, betas=betas)
            self.vect_discrim_opt = th.optim.Adam(
                self.vect_discrim.parameters(), lr=lr, betas=betas)

            self.schedulers += [
                th.optim.lr_scheduler.ExponentialLR(self.vect_gen_opt,
                                                    lr_decay),
                th.optim.lr_scheduler.ExponentialLR(self.vect_discrim_opt,
                                                    lr_decay),
            ]
            self.optimizers += [self.vect_gen_opt, self.vect_discrim_opt]

        # include loss on alpha
        self.im_loss = losses.MultiscaleMSELoss(channels=4).to(self.device)

        # Counts consecutive discriminator updates; reset after a generator
        # update.
        self.iter = 0

        self.cross_entropy = th.nn.BCEWithLogitsLoss()
        self.mse = th.nn.MSELoss()

    def _gradient_penalty(self, discrim, fake, real):
        """Zero-centered gradient penalty on real/fake interpolations.

        Penalizes the squared L2 norm of the discriminator's input gradient
        at random points on the segments between paired real and fake
        samples, scaled by ``self.w_gradient_penalty``. Centering at 0
        (rather than at 1, as in WGAN-GP) follows [Thanh-Tung 2019]
        https://openreview.net/pdf?id=ByxPYjC5KQ

        Args:
            discrim: discriminator module.
            fake: generated batch; treated as a constant.
            real: real batch with the same shape as ``fake``.

        Returns:
            Scalar tensor, differentiable w.r.t. the discriminator parameters.
        """
        bs = real.size(0)

        # One mixing weight per sample, broadcast over the remaining dims.
        eps_shape = (bs,) + (1,) * (real.dim() - 1)
        epsilon = th.rand(*eps_shape, device=real.device, dtype=real.dtype)

        # Built from detached tensors, so this is a fresh leaf we can
        # differentiate against.
        interpolation = (epsilon * real.detach()
                         + (1 - epsilon) * fake.detach())
        interpolation.requires_grad_(True)

        interpolation_logits = discrim(interpolation)

        gradients = th.autograd.grad(
            outputs=interpolation_logits,
            inputs=interpolation,
            grad_outputs=th.ones_like(interpolation_logits),
            create_graph=True, retain_graph=True)[0]

        gradients = gradients.reshape(bs, -1)
        # The small constant keeps sqrt differentiable at a zero gradient.
        gradients_norm = th.sqrt(th.sum(gradients ** 2, dim=1) + 1e-12)

        # The standard WGAN-GP penalty would use (gradients_norm - 1) ** 2.
        return self.w_gradient_penalty * (gradients_norm ** 2).mean()

    def _clip_gradients(self, module, name):
        """Clip ``module``'s total gradient norm to ``self.grad_clip``.

        Logs at debug level when clipping happens.

        Returns:
            The total gradient norm before clipping.
        """
        nrm = th.nn.utils.clip_grad_norm_(module.parameters(), self.grad_clip)
        if nrm > self.grad_clip:
            LOG.debug("Clipped %s gradient (%.5f) to %.2f",
                      name, float(nrm), self.grad_clip)
        return nrm

    def _discriminator_step(self, discrim, opt, fake, real):
        """Try to classify fake as 0 and real as 1.

        Performs one optimizer step on ``discrim`` via ``opt``.

        Returns:
            ``(loss, gradient_penalty)``: the discriminator loss as a float,
            and the gradient penalty as a float (WGAN-GP) or None.
        """
        opt.zero_grad()

        # no backprop to gen
        fake = fake.detach()

        fake_pred = discrim(fake)
        real_pred = discrim(real)

        if self.wgan_gp:
            penalty = self._gradient_penalty(discrim, fake, real)
            loss_d = fake_pred.mean() - real_pred.mean() + penalty
            gradient_penalty = penalty.item()
        else:
            fake_loss = self.cross_entropy(fake_pred, th.zeros_like(fake_pred))
            real_loss = self.cross_entropy(real_pred, th.ones_like(real_pred))
            loss_d = 0.5 * (fake_loss + real_loss)
            gradient_penalty = None

        loss_d.backward()
        self._clip_gradients(discrim, "discriminator")
        opt.step()

        return loss_d.item(), gradient_penalty

    def _generator_step(self, gen, discrim, opt, fake):
        """Try to make the discriminator classify fake as 1.

        Performs one optimizer step on ``gen`` via ``opt``. ``fake`` must
        still be attached to ``gen``'s autograd graph.

        Returns:
            The generator loss as a float.
        """
        opt.zero_grad()

        fake_pred = discrim(fake)

        if self.wgan_gp:
            loss_g = -fake_pred.mean()
        else:
            loss_g = self.cross_entropy(fake_pred, th.ones_like(fake_pred))

        loss_g.backward()
        self._clip_gradients(gen, "generator")
        opt.step()

        return loss_g.item()

    def training_step(self, batch):
        """Run one discriminator or generator update on ``batch``.

        Returns:
            Dict with the losses of the update that ran (others are None),
            the real and generated images, and the generator learning rate.
        """
        im = batch.to(self.device)

        z = self.gen.sample_z(im.shape[0], device=self.device)

        update_discrim = self.iter < self.n_critic

        # Samples only need an autograd graph when the generators are
        # updated. The discriminator step detaches them anyway, so skipping
        # graph construction there saves memory and time.
        with th.set_grad_enabled(not update_discrim):
            generated = self.gen(z)
            vect_generated = None
            if self.vect_gen is not None:
                vect_generated = self.vect_gen(z)

        loss_g = loss_d = loss_g_vect = loss_d_vect = None
        gp = gp_vect = None

        if update_discrim:  # Discriminator update
            self.iter += 1
            loss_d, gp = self._discriminator_step(
                self.discrim, self.discrim_opt, generated, im)
            if vect_generated is not None:
                loss_d_vect, gp_vect = self._discriminator_step(
                    self.vect_discrim, self.vect_discrim_opt,
                    vect_generated, im)
        else:  # Generator update
            self.iter = 0
            loss_g = self._generator_step(
                self.gen, self.discrim, self.gen_opt, generated)
            if vect_generated is not None:
                loss_g_vect = self._generator_step(
                    self.vect_gen, self.vect_discrim, self.vect_gen_opt,
                    vect_generated)

        return {
            "loss_g": loss_g,
            "loss_d": loss_d,
            "loss_g_vect": loss_g_vect,
            "loss_d_vect": loss_d_vect,
            "gp": gp,
            "gp_vect": gp_vect,
            "gt_image": im,
            "gen_image": generated,
            "vector_image": vect_generated,
            "lr": self.gen_opt.param_groups[0]["lr"],
        }

    def init_validation(self):
        """Return the initial validation accumulator (nothing is tracked)."""
        return dict(sample=None)

    def validation_step(self, batch, running_data):
        """No-op validation step; returns ``running_data`` unchanged.

        This interface has no single ``model`` to switch to eval mode.
        Switching the generators/discriminators to eval here would also
        leave them in eval mode for subsequent training steps, because
        ``training_step`` never switches them back. They are therefore left
        untouched.
        """
        return running_data
|
|
|
|
|
|
def train(args):
    """Build dataset, models, checkpointers and callbacks, then train the GAN.

    Resumes from the latest checkpoints found under ``OUTPUT/<env_name>``,
    if any.

    Args:
        args: parsed command-line namespace (see the ``__main__`` parser).

    Raises:
        NotImplementedError: for an unknown ``args.task`` or
            ``args.generator``.
    """
    th.manual_seed(0)
    np.random.seed(0)

    color_output = False
    if args.task == "mnist":
        dataset = data.MNISTDataset(args.raster_resolution, train=True)
    elif args.task == "quickdraw":
        dataset = data.QuickDrawImageDataset(
            args.raster_resolution, train=True)
    else:
        raise NotImplementedError()

    dataloader = DataLoader(
        dataset, batch_size=args.bs, num_workers=args.workers, shuffle=True)
    val_dataloader = None

    model_params = {
        "zdim": args.zdim,
        "num_strokes": args.num_strokes,
        "imsize": args.raster_resolution,
        "stroke_width": args.stroke_width,
        "color_output": color_output,
    }
    gen = models.Generator(**model_params)
    gen.train()

    discrim = models.Discriminator(color_output=color_output)
    discrim.train()

    if args.raster_only:
        vect_gen = None
        vect_discrim = None
    else:
        if args.generator == "fc":
            vect_gen = models.VectorGenerator(**model_params)
        elif args.generator == "bezier_fc":
            vect_gen = models.BezierVectorGenerator(**model_params)
        elif args.generator == "rnn":
            vect_gen = models.RNNVectorGenerator(**model_params)
        elif args.generator == "chain_rnn":
            vect_gen = models.ChainRNNVectorGenerator(**model_params)
        else:
            raise NotImplementedError()
        vect_gen.train()

        vect_discrim = models.Discriminator(color_output=color_output)
        vect_discrim.train()

    LOG.info("Model parameters:\n%s", model_params)

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(gen, vect_gen, discrim, vect_discrim,
                          raster_resolution=args.raster_resolution, lr=args.lr,
                          wgan_gp=args.wgan_gp,
                          lr_decay=args.lr_decay, device=device)

    # The env name doubles as the checkpoint folder name; keep it stable so
    # existing runs can be resumed.
    env_name = args.task + "_gan"
    env_name += "_raster" if args.raster_only else "_vector"
    env_name += "_" + args.generator
    if args.wgan_gp:
        env_name += "_wgan"

    chkpt = os.path.join(OUTPUT, env_name)

    meta = {
        "model_params": model_params,
        "task": args.task,
        "generator": args.generator,
    }
    checkpointer = ttools.Checkpointer(
        chkpt, gen, meta=meta,
        optimizers=interface.optimizers,
        schedulers=interface.schedulers,
        prefix="g_")
    checkpointer_d = ttools.Checkpointer(
        chkpt, discrim,
        prefix="d_")
    checkpointers = [checkpointer, checkpointer_d]

    # Resume from checkpoint, if any
    extras, _ = checkpointer.load_latest()
    checkpointer_d.load_latest()

    if not args.raster_only:
        # NOTE(review): this checkpointer is handed the same full optimizer /
        # scheduler lists as `checkpointer`, and it is loaded last, so its
        # saved state presumably wins on resume.
        checkpointer_vect = ttools.Checkpointer(
            chkpt, vect_gen, meta=meta,
            optimizers=interface.optimizers,
            schedulers=interface.schedulers,
            prefix="vect_g_")
        checkpointer_d_vect = ttools.Checkpointer(
            chkpt, vect_discrim,
            prefix="vect_d_")
        extras, _ = checkpointer_vect.load_latest()
        checkpointer_d_vect.load_latest()
        checkpointers += [checkpointer_vect, checkpointer_d_vect]

    epoch = extras["epoch"] if extras and "epoch" in extras else 0

    trainer = ttools.Trainer(interface)

    # Named `loss_keys` (not `losses`) so the imported `losses` module is not
    # shadowed inside this function.
    loss_keys = ["loss_g", "loss_d", "loss_g_vect", "loss_d_vect", "gp",
                 "gp_vect"]
    training_debug = ["lr"]

    # Send the sample images to the same visdom server as the plots below.
    trainer.add_callback(Callback(
        env=env_name, win="samples", port=args.port, frequency=args.freq,
        server=args.server, base_url=args.base_url))
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=loss_keys, val_keys=None))
    trainer.add_callback(ttools.callbacks.MultiPlotCallback(
        keys=loss_keys, val_keys=None, env=env_name, port=args.port,
        server=args.server, base_url=args.base_url,
        win="losses", frequency=args.freq))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=training_debug, smoothing=0, val_keys=None, env=env_name,
        server=args.server, base_url=args.base_url,
        port=args.port))

    for ckpt in checkpointers:
        trainer.add_callback(ttools.callbacks.CheckpointingCallback(
            ckpt, max_files=2, interval=600, max_epochs=10))

    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))

    # Start training
    trainer.train(dataloader, starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument

    # Task / model selection
    add_arg("--task", default="mnist", choices=["mnist", "quickdraw"])
    add_arg("--generator", default="bezier_fc",
            choices=["bezier_fc", "fc", "rnn", "chain_rnn"],
            help="model to use as generator")
    add_arg("--raster_only", action="store_true", default=False,
            help="if true only train the raster baseline")
    # Stored as `wgan_gp`: WGAN-GP is on by default, --standard_gan disables it.
    add_arg("--standard_gan", dest="wgan_gp", action="store_false",
            default=True,
            help="if true, use regular GAN instead of WGAN")

    # Training params
    add_arg("--bs", type=int, default=4, help="batch size")
    add_arg("--workers", type=int, default=4,
            help="number of dataloader threads")
    add_arg("--num_epochs", type=int, default=200,
            help="number of epochs to train for")
    add_arg("--lr", type=float, default=1e-4, help="learning rate")
    add_arg("--lr_decay", type=float, default=0.9999,
            help="exponential learning rate decay rate")

    # Model configuration
    add_arg("--zdim", type=int, default=32, help="latent space dimension")
    add_arg("--stroke_width", type=float, nargs=2, default=(0.5, 1.5),
            help="min and max stroke width")
    add_arg("--num_strokes", type=int, default=16,
            help="number of strokes to generate")
    add_arg("--raster_resolution", type=int, default=32,
            help="raster canvas resolution on each side")

    # Viz params
    add_arg("--freq", type=int, default=10, help="visualization frequency")
    add_arg("--port", type=int, default=8097, help="visdom port")
    add_arg("--server", default=None, help="visdom server if not local.")
    add_arg("--base_url", default="", help="visdom entrypoint URL")

    cli_args = parser.parse_args()

    # NOTE(review): diffvg is forced onto the CPU here, while train() moves
    # the models to CUDA when available — confirm the vector generators
    # handle that device split.
    pydiffvg.set_use_gpu(False)
    ttools.set_logger(False)

    train(cli_args)
|