initial commit

apps/generative_models/train_gan.py (new executable file, 489 lines added)

@@ -0,0 +1,489 @@
#!/usr/bin/env python
"""Train a GAN.
|
||||
|
||||
Usage:
|
||||
|
||||
* Train a MNIST model:
|
||||
|
||||
`python train_gan.py`
|
||||
|
||||
* Train a Quickdraw model:
|
||||
|
||||
`python train_gan.py --task quickdraw`
|
||||
|
||||
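* Train the raster-only baseline:

`python train_gan.py --raster_only`
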
"""
import argparse
import os

import numpy as np
import torch as th
from torch.utils.data import DataLoader

import ttools
import ttools.interfaces

import losses
import data
import models

import pydiffvg

LOG = ttools.get_logger(__name__)


BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
OUTPUT = os.path.join(BASE_DIR, "results")


class Callback(ttools.callbacks.ImageDisplayCallback):
    """Simple callback that visualizes images."""
    def visualized_image(self, batch, step_data, is_val=False):
        if is_val:
            return

        gen = step_data["gen_image"][:16].detach()
        ref = step_data["gt_image"][:16].detach()

        # tensor to visualize: concatenate real and generated images
        vizdata = th.cat([ref, gen], 2)

        vector = step_data["vector_image"]
        if vector is not None:
            vector = vector[:16].detach()
            vizdata = th.cat([vizdata, vector], 2)

        vizdata = (vizdata + 1.0) * 0.5
        viz = th.clamp(vizdata, 0, 1)
        return viz

    def caption(self, batch, step_data, is_val=False):
        if step_data["vector_image"] is not None:
            s = "top: real, middle: raster, bottom: vector"
        else:
            s = "top: real, bottom: fake"
        return s


class Interface(ttools.ModelInterface):
    def __init__(self, generator, vect_generator,
                 discriminator, vect_discriminator,
                 lr=1e-4, lr_decay=0.9999,
                 gradient_penalty=10,
                 wgan_gp=False,
                 raster_resolution=32, device="cpu", grad_clip=1.0):
        super(Interface, self).__init__()

        self.wgan_gp = wgan_gp
        self.w_gradient_penalty = gradient_penalty

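        # WGAN-GP conventionally takes several critic (discriminator) steps
        # per generator step, 5 in the usual recipe; a standard GAN
        # alternates 1:1.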
        self.n_critic = 1
        if self.wgan_gp:
            self.n_critic = 5

        self.grad_clip = grad_clip
        self.raster_resolution = raster_resolution

        self.gen = generator
        self.vect_gen = vect_generator
        self.discrim = discriminator
        self.vect_discrim = vect_discriminator

        self.device = device
        self.gen.to(self.device)
        self.discrim.to(self.device)

        beta1 = 0.5
        beta2 = 0.9

        self.gen_opt = th.optim.Adam(
            self.gen.parameters(), lr=lr, betas=(beta1, beta2))
        self.discrim_opt = th.optim.Adam(
            self.discrim.parameters(), lr=lr, betas=(beta1, beta2))

        self.schedulers = [
            th.optim.lr_scheduler.ExponentialLR(self.gen_opt, lr_decay),
            th.optim.lr_scheduler.ExponentialLR(self.discrim_opt, lr_decay),
        ]

        self.optimizers = [self.gen_opt, self.discrim_opt]

        if self.vect_gen is not None:
            assert self.vect_discrim is not None

            self.vect_gen.to(self.device)
            self.vect_discrim.to(self.device)

            self.vect_gen_opt = th.optim.Adam(
                self.vect_gen.parameters(), lr=lr, betas=(beta1, beta2))
            self.vect_discrim_opt = th.optim.Adam(
                self.vect_discrim.parameters(), lr=lr, betas=(beta1, beta2))

            self.schedulers += [
                th.optim.lr_scheduler.ExponentialLR(self.vect_gen_opt,
                                                    lr_decay),
                th.optim.lr_scheduler.ExponentialLR(self.vect_discrim_opt,
                                                    lr_decay),
            ]

            self.optimizers += [self.vect_gen_opt, self.vect_discrim_opt]

        # include loss on alpha
        self.im_loss = losses.MultiscaleMSELoss(channels=4).to(self.device)

        self.iter = 0

        self.cross_entropy = th.nn.BCEWithLogitsLoss()
        self.mse = th.nn.MSELoss()

    def _gradient_penalty(self, discrim, fake, real):
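        """Gradient penalty regularizer for the discriminator.

        Computes lambda * E[(||grad_xhat D(xhat)||_2 - c)^2] at random
        interpolates xhat = eps * real + (1 - eps) * fake, eps ~ U(0, 1);
        here c = 0 (see the note below) and lambda is `w_gradient_penalty`.
        """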
        bs = real.size(0)
        epsilon = th.rand(bs, 1, 1, 1, device=real.device)
        epsilon = epsilon.expand_as(real)

        interpolation = epsilon * real.data + (1 - epsilon) * fake.data
        interpolation.requires_grad_(True)

        interpolation_logits = discrim(interpolation)
        grad_outputs = th.ones(interpolation_logits.size(), device=real.device)

        gradients = th.autograd.grad(outputs=interpolation_logits,
                                     inputs=interpolation,
                                     grad_outputs=grad_outputs,
                                     create_graph=True, retain_graph=True)[0]

        gradients = gradients.view(bs, -1)
        gradients_norm = th.sqrt(th.sum(gradients ** 2, dim=1) + 1e-12)

        # Zero-centered gradient penalty, [Thanh-Tung 2019]
        # https://openreview.net/pdf?id=ByxPYjC5KQ
        return self.w_gradient_penalty * ((gradients_norm - 0) ** 2).mean()

        # One-centered penalty of the original WGAN-GP:
        # return self.w_gradient_penalty * ((gradients_norm - 1) ** 2).mean()
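        # The active penalty is zero-centered: it pushes the gradient norm
        # at interpolates toward 0, which [Thanh-Tung 2019] argue gives
        # better generalization and stability than the one-centered form.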

    def _discriminator_step(self, discrim, opt, fake, real):
        """Take one discriminator step: drive fake toward 0, real toward 1.

        In WGAN-GP mode the discriminator acts as a critic instead: it
        widens the score gap between real and fake samples, regularized
        by the gradient penalty.
        """

        opt.zero_grad()

        # no backprop to gen
        fake = fake.detach()

        fake_pred = discrim(fake)
        real_pred = discrim(real)

        if self.wgan_gp:
            gradient_penalty = self._gradient_penalty(discrim, fake, real)
            loss_d = fake_pred.mean() - real_pred.mean() + gradient_penalty
            gradient_penalty = gradient_penalty.item()
        else:
            fake_loss = self.cross_entropy(fake_pred, th.zeros_like(fake_pred))
            real_loss = self.cross_entropy(real_pred, th.ones_like(real_pred))
            # fake_loss = self.mse(fake_pred, th.zeros_like(fake_pred))
            # real_loss = self.mse(real_pred, th.ones_like(real_pred))
            loss_d = 0.5 * (fake_loss + real_loss)
            gradient_penalty = None

        loss_d.backward()
        nrm = th.nn.utils.clip_grad_norm_(
            discrim.parameters(), self.grad_clip)
        if nrm > self.grad_clip:
            LOG.debug("Clipped discriminator gradient (%.5f) to %.2f",
                      nrm, self.grad_clip)

        opt.step()

        return loss_d.item(), gradient_penalty

    def _generator_step(self, gen, discrim, opt, fake):
        """Take one generator step: drive the discriminator's output on
        fake samples toward 1 (or maximize the critic score under WGAN-GP)."""

        opt.zero_grad()

        fake_pred = discrim(fake)

        if self.wgan_gp:
            loss_g = -fake_pred.mean()
        else:
            loss_g = self.cross_entropy(fake_pred, th.ones_like(fake_pred))
            # loss_g = self.mse(fake_pred, th.ones_like(fake_pred))

        loss_g.backward()

        # clip gradients
        nrm = th.nn.utils.clip_grad_norm_(
            gen.parameters(), self.grad_clip)
        if nrm > self.grad_clip:
            LOG.debug("Clipped generator gradient (%.5f) to %.2f",
                      nrm, self.grad_clip)

        opt.step()

        return loss_g.item()

    def training_step(self, batch):
        im = batch
        im = im.to(self.device)

        z = self.gen.sample_z(im.shape[0], device=self.device)

        generated = self.gen(z)

        vect_generated = None
        if self.vect_gen is not None:
            vect_generated = self.vect_gen(z)

        loss_g = None
        loss_d = None
        loss_g_vect = None
        loss_d_vect = None

        gp = None
        gp_vect = None

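        # Alternate phases: run `n_critic` discriminator updates, then a
        # single generator update (n_critic is 5 for WGAN-GP, 1 otherwise).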
        if self.iter < self.n_critic:  # Discriminator update
            self.iter += 1

            loss_d, gp = self._discriminator_step(
                self.discrim, self.discrim_opt, generated, im)

            if vect_generated is not None:
                loss_d_vect, gp_vect = self._discriminator_step(
                    self.vect_discrim, self.vect_discrim_opt, vect_generated,
                    im)

        else:  # Generator update
            self.iter = 0

            loss_g = self._generator_step(
                self.gen, self.discrim, self.gen_opt, generated)

            if vect_generated is not None:
                loss_g_vect = self._generator_step(
                    self.vect_gen, self.vect_discrim, self.vect_gen_opt,
                    vect_generated)

        return {
            "loss_g": loss_g,
            "loss_d": loss_d,
            "loss_g_vect": loss_g_vect,
            "loss_d_vect": loss_d_vect,
            "gp": gp,
            "gp_vect": gp_vect,
            "gt_image": im,
            "gen_image": generated,
            "vector_image": vect_generated,
            "lr": self.gen_opt.param_groups[0]["lr"],
        }

    def init_validation(self):
        return dict(sample=None)

    def validation_step(self, batch, running_data):
        # Switch to eval mode for dropout, batchnorm, etc. There is no
        # single `self.model` on this interface, so switch the generators
        # explicitly.
        self.gen.eval()
        if self.vect_gen is not None:
            self.vect_gen.eval()
        return running_data


def train(args):
    th.manual_seed(0)
    np.random.seed(0)

    color_output = False
    if args.task == "mnist":
        dataset = data.MNISTDataset(args.raster_resolution, train=True)
    elif args.task == "quickdraw":
        dataset = data.QuickDrawImageDataset(
            args.raster_resolution, train=True)
    else:
        raise NotImplementedError()

    dataloader = DataLoader(
        dataset, batch_size=args.bs, num_workers=args.workers, shuffle=True)

    val_dataloader = None

    model_params = {
        "zdim": args.zdim,
        "num_strokes": args.num_strokes,
        "imsize": args.raster_resolution,
        "stroke_width": args.stroke_width,
        "color_output": color_output,
    }
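    # model_params is shared by the raster Generator and by every vector
    # generator variant constructed below.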
    gen = models.Generator(**model_params)
    gen.train()

    discrim = models.Discriminator(color_output=color_output)
    discrim.train()

    if args.raster_only:
        vect_gen = None
        vect_discrim = None
    else:
        if args.generator == "fc":
            vect_gen = models.VectorGenerator(**model_params)
        elif args.generator == "bezier_fc":
            vect_gen = models.BezierVectorGenerator(**model_params)
        elif args.generator in ["rnn"]:
            vect_gen = models.RNNVectorGenerator(**model_params)
        elif args.generator in ["chain_rnn"]:
            vect_gen = models.ChainRNNVectorGenerator(**model_params)
        else:
            raise NotImplementedError()
        vect_gen.train()

        vect_discrim = models.Discriminator(color_output=color_output)
        vect_discrim.train()

    LOG.info("Model parameters:\n%s", model_params)

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(gen, vect_gen, discrim, vect_discrim,
                          raster_resolution=args.raster_resolution, lr=args.lr,
                          wgan_gp=args.wgan_gp,
                          lr_decay=args.lr_decay, device=device)

    env_name = args.task + "_gan"

    if args.raster_only:
        env_name += "_raster"
    else:
        env_name += "_vector"

    env_name += "_" + args.generator

    if args.wgan_gp:
        env_name += "_wgan"

    chkpt = os.path.join(OUTPUT, env_name)

    meta = {
        "model_params": model_params,
        "task": args.task,
        "generator": args.generator,
    }
    checkpointer = ttools.Checkpointer(
        chkpt, gen, meta=meta,
        optimizers=interface.optimizers,
        schedulers=interface.schedulers,
        prefix="g_")
    checkpointer_d = ttools.Checkpointer(
        chkpt, discrim,
        prefix="d_")

    # Resume from checkpoint, if any
    extras, _ = checkpointer.load_latest()
    checkpointer_d.load_latest()

    if not args.raster_only:
        checkpointer_vect = ttools.Checkpointer(
            chkpt, vect_gen, meta=meta,
            optimizers=interface.optimizers,
            schedulers=interface.schedulers,
            prefix="vect_g_")
        checkpointer_d_vect = ttools.Checkpointer(
            chkpt, vect_discrim,
            prefix="vect_d_")
        extras, _ = checkpointer_vect.load_latest()
        checkpointer_d_vect.load_latest()

    epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0

    # if meta is not None and meta["model_params"] != model_params:
    #     LOG.info("Checkpoint's metaparams differ "
    #              "from CLI, aborting: %s and %s", meta, model_params)

    trainer = ttools.Trainer(interface)

    # Add callbacks (loss_keys renamed from `losses` to avoid shadowing
    # the imported module)
    loss_keys = ["loss_g", "loss_d", "loss_g_vect", "loss_d_vect", "gp",
                 "gp_vect"]
    training_debug = ["lr"]

    trainer.add_callback(Callback(
        env=env_name, win="samples", port=args.port, frequency=args.freq))
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=loss_keys, val_keys=None))
    trainer.add_callback(ttools.callbacks.MultiPlotCallback(
        keys=loss_keys, val_keys=None, env=env_name, port=args.port,
        server=args.server, base_url=args.base_url,
        win="losses", frequency=args.freq))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=training_debug, smoothing=0, val_keys=None, env=env_name,
        server=args.server, base_url=args.base_url,
        port=args.port))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer, max_files=2, interval=600, max_epochs=10))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer_d, max_files=2, interval=600, max_epochs=10))

    if not args.raster_only:
        trainer.add_callback(ttools.callbacks.CheckpointingCallback(
            checkpointer_vect, max_files=2, interval=600, max_epochs=10))
        trainer.add_callback(ttools.callbacks.CheckpointingCallback(
            checkpointer_d_vect, max_files=2, interval=600, max_epochs=10))

    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))

    # Start training
    trainer.train(dataloader, starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--task",
                        default="mnist",
                        choices=["mnist", "quickdraw"])
    parser.add_argument("--generator",
                        default="bezier_fc",
                        choices=["bezier_fc", "fc", "rnn", "chain_rnn"],
                        help="model to use as generator")

    parser.add_argument("--raster_only", action="store_true", default=False,
                        help="if set, only train the raster baseline")

    parser.add_argument("--standard_gan", dest="wgan_gp", action="store_false",
                        default=True,
                        help="if set, use a standard GAN loss instead of "
                             "WGAN-GP")

    # Training params
    parser.add_argument("--bs", type=int, default=4, help="batch size")
    parser.add_argument("--workers", type=int, default=4,
                        help="number of dataloader threads")
    parser.add_argument("--num_epochs", type=int, default=200,
                        help="number of epochs to train for")
    parser.add_argument("--lr", type=float, default=1e-4,
                        help="learning rate")
    parser.add_argument("--lr_decay", type=float, default=0.9999,
                        help="exponential learning rate decay rate")

    # Model configuration
    parser.add_argument("--zdim", type=int, default=32,
                        help="latent space dimension")
    parser.add_argument("--stroke_width", type=float, nargs=2,
                        default=(0.5, 1.5),
                        help="min and max stroke width")
    parser.add_argument("--num_strokes", type=int, default=16,
                        help="number of strokes to generate")
    parser.add_argument("--raster_resolution", type=int, default=32,
                        help="raster canvas resolution on each side")

    # Viz params
    parser.add_argument("--freq", type=int, default=10,
                        help="visualization frequency")
    parser.add_argument("--port", type=int, default=8097,
                        help="visdom port")
    parser.add_argument("--server", default=None,
                        help="visdom server if not local")
    parser.add_argument("--base_url", default="", help="visdom entrypoint URL")

    args = parser.parse_args()
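    # Rasterize vector graphics on the CPU; the networks themselves still
    # run on CUDA when available (see train()).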
    pydiffvg.set_use_gpu(False)

    ttools.set_logger(False)

    train(args)