#!/usr/bin/env python
"""Train a Sketch-VAE."""
import argparse
from enum import Enum
import os
import wget
import time

import numpy as np
import torch as th
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms

import ttools
import ttools.interfaces
from ttools.modules import networks

# Local modules from this repository
import rendering
import losses
import modules
import data

# Differentiable vector graphics rasterizer
import pydiffvg

LOG = ttools.get_logger(__name__)


BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
OUTPUT = os.path.join(BASE_DIR, "results")


class SketchVAE(th.nn.Module):
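    """VAE over hand-drawn sketches.

    An image encoder maps a raster drawing to a latent code z; from the same
    code, an image decoder reconstructs the raster and an RNN sketch decoder
    emits a (dx, dy, opacity) stroke sequence.
    """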

    class ImageEncoder(th.nn.Module):
        """Encodes a 4-channel (RGBA) raster into a latent distribution."""

        def __init__(self, image_size=64, width=64, zdim=128):
            super(SketchVAE.ImageEncoder, self).__init__()
            self.zdim = zdim

            self.net = th.nn.Sequential(
                th.nn.Conv2d(4, width, 5, padding=2),
                th.nn.InstanceNorm2d(width),
                th.nn.ReLU(inplace=True),
                # 64x64

                th.nn.Conv2d(width, width, 5, padding=2),
                th.nn.InstanceNorm2d(width),
                th.nn.ReLU(inplace=True),
                # 64x64

                th.nn.Conv2d(width, 2*width, 5, stride=1, padding=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 64x64

                th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 32x32

                th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 16x16

                th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 8x8

                th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 4x4

                modules.Flatten(),
                th.nn.Linear(4*4*2*width, 2*zdim)
            )

        def forward(self, images):
            features = self.net(images)

            # VAE params
            mu = features[:, :self.zdim]
            log_sigma = features[:, self.zdim:]

            # Reparameterization trick: sample a latent vector with
            # independent noise per batch element (randn_like; the original
            # shared a single noise vector across the batch).
            sigma = th.exp(log_sigma/2.0)
            z0 = th.randn_like(mu)
            z = mu + sigma*z0

            # KL divergence needs mu/sigma
            return z, mu, log_sigma
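
        # Note: with z = mu + sigma*eps, eps ~ N(0, I), and log_sigma the
        # log-variance, the KL term against a unit Gaussian has the closed
        # form KL = -0.5 * sum(1 + log_sigma - mu^2 - exp(log_sigma)),
        # which is what losses.KLDivergence is expected to compute.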

    class ImageDecoder(th.nn.Module):
        """Decodes a latent vector back into a 4-channel (RGBA) raster."""

        def __init__(self, zdim=128, image_size=64, width=64):
            super(SketchVAE.ImageDecoder, self).__init__()
            self.zdim = zdim
            self.width = width

            self.embedding = th.nn.Linear(zdim, 4*4*2*width)

            self.net = th.nn.Sequential(
                # Each k=4, s=2, p=1 transposed conv doubles the resolution:
                # H_out = (H_in - 1)*2 - 2*1 + 4 = 2*H_in
                th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 8x8

                th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 16x16

                th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 32x32

                th.nn.Conv2d(2*width, 2*width, 5, padding=2, stride=1),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 32x32

                th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 64x64

                th.nn.Conv2d(2*width, width, 5, padding=2, stride=1),
                th.nn.InstanceNorm2d(width),
                th.nn.ReLU(inplace=True),
                # 64x64

                th.nn.ConvTranspose2d(width, width, 5, padding=2, stride=1),
                th.nn.InstanceNorm2d(width),
                th.nn.ReLU(inplace=True),
                # 64x64

                th.nn.Conv2d(width, width, 5, padding=2, stride=1),
                th.nn.InstanceNorm2d(width),
                th.nn.ReLU(inplace=True),
                # 64x64

                th.nn.Conv2d(width, 4, 5, padding=2, stride=1),
            )

        def forward(self, z):
            bs = z.shape[0]
            im = self.embedding(z).view(bs, 2*self.width, 4, 4)
            out = self.net(im)
            return out

    class SketchDecoder(th.nn.Module):
        """Decodes a latent vector into a stroke sequence.

        The decoder outputs a sequence where each time step models (dx, dy,
        opacity).
        """

        def __init__(self, sequence_length, hidden_size=512, dropout=0.9,
                     zdim=128, num_layers=3):
            super(SketchVAE.SketchDecoder, self).__init__()
            self.sequence_length = sequence_length
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.zdim = zdim

            # Maps the latent vector to an initial cell/hidden vector
            self.hidden_cell_predictor = th.nn.Linear(
                zdim, 2*hidden_size*num_layers)

            self.lstm = th.nn.LSTM(
                zdim, hidden_size,
                num_layers=self.num_layers, dropout=dropout,
                batch_first=True)

            self.dxdy_predictor = th.nn.Sequential(
                th.nn.Linear(hidden_size, 2),
                th.nn.Tanh(),
            )

            self.opacity_predictor = th.nn.Sequential(
                th.nn.Linear(hidden_size, 1),
                th.nn.Sigmoid(),
            )

        def forward(self, z, hidden_and_cell=None):
            # Every step in the sequence takes the latent vector as input,
            # so we replicate it here.
            bs = z.shape[0]
            steps = self.sequence_length - 1  # no need to predict the start of sequence
            expanded_z = z.unsqueeze(1).repeat(1, steps, 1)

            if hidden_and_cell is None:
                # Initialize the LSTM state from the latent vector; the LSTM
                # expects (num_layers, batch, hidden_size) tensors.
                hidden_and_cell = self.hidden_cell_predictor(th.tanh(z))
                hidden = hidden_and_cell[:, :self.hidden_size*self.num_layers]
                hidden = hidden.view(-1, self.num_layers, self.hidden_size)
                hidden = hidden.permute(1, 0, 2).contiguous()
                cell = hidden_and_cell[:, self.hidden_size*self.num_layers:]
                cell = cell.view(-1, self.num_layers, self.hidden_size)
                cell = cell.permute(1, 0, 2).contiguous()
                hidden_and_cell = (hidden, cell)

            outputs, hidden_and_cell = self.lstm(expanded_z, hidden_and_cell)
            hidden, cell = hidden_and_cell

            dxdy = self.dxdy_predictor(
                outputs.reshape(bs*steps, self.hidden_size)).view(bs, steps, -1)

            opacity = self.opacity_predictor(
                outputs.reshape(bs*steps, self.hidden_size)).view(bs, steps, -1)

            strokes = th.cat([dxdy, opacity], -1)

            return strokes

    def __init__(self, sequence_length, zdim=128, image_size=64):
        super(SketchVAE, self).__init__()
        self.im_encoder = SketchVAE.ImageEncoder(
            zdim=zdim, image_size=image_size)
        self.im_decoder = SketchVAE.ImageDecoder(
            zdim=zdim, image_size=image_size)
        self.sketch_decoder = SketchVAE.SketchDecoder(
            sequence_length, zdim=zdim)

    def forward(self, images):
        # Encode the images as latent vectors
        z, mu, log_sigma = self.im_encoder(images)
        decoded_im = self.im_decoder(z)
        decoded_sketch = self.sketch_decoder(z)

        return {
            "decoded_im": decoded_im,
            "decoded_sketch": decoded_sketch,
            "z": z,
            "mu": mu,
            "log_sigma": log_sigma,
        }
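

# A minimal shape sanity check (hypothetical usage, not executed on import):
#
#   model = SketchVAE(sequence_length=50, zdim=128, image_size=64)
#   images = th.rand(2, 4, 64, 64)   # batch of RGBA drawings in [0, 1]
#   out = model(images)
#   out["decoded_im"].shape          # (2, 4, 64, 64)
#   out["decoded_sketch"].shape      # (2, 49, 3), i.e. (dx, dy, opacity)
#                                    # for sequence_length - 1 steps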


class SketchVAECallback(ttools.callbacks.ImageDisplayCallback):
    """Simple callback that visualizes images."""

    def visualized_image(self, batch, step_data, is_val=False):
        if is_val:
            return None

        # only display the first n drawings
        n = 8
        gt = step_data["gt_image"][:n].detach()
        vae_im = step_data["vae_image"][:n].detach()
        sketch_im = step_data["sketch_image"][:n].detach()

        # Stack ground truth, VAE reconstruction and rendered sketch
        # vertically, then composite over black using the alpha channel.
        rendering = th.cat([gt, vae_im, sketch_im], 2)
        rendering = th.clamp(rendering, 0, 1)
        alpha = rendering[:, 3:4]
        rendering = rendering[:, :3] * alpha

        return rendering

    def caption(self, batch, step_data, is_val=False):
        if is_val:
            return ""
        else:
            return "top: ground truth, middle: vae sample, bottom: rnn output"


class Interface(ttools.ModelInterface):
    def __init__(self, model, lr=1e-4, lr_decay=0.9999,
                 kl_weight=0.5, kl_min_weight=0.01, kl_decay=0.99995,
                 raster_resolution=64, absolute_coords=False,
                 device="cpu", grad_clip=1.0):
        super(Interface, self).__init__()

        self.grad_clip = grad_clip
        self.raster_resolution = raster_resolution
        self.absolute_coords = absolute_coords

        self.model = model
        self.device = device
        self.model.to(self.device)

        # One Adam optimizer per submodule so each can be stepped separately.
        self.im_enc_opt = th.optim.Adam(
            self.model.im_encoder.parameters(), lr=lr)
        self.im_dec_opt = th.optim.Adam(
            self.model.im_decoder.parameters(), lr=lr)
        self.sketch_dec_opt = th.optim.Adam(
            self.model.sketch_decoder.parameters(), lr=lr)

        self.kl_weight = kl_weight
        self.kl_min_weight = kl_min_weight
        self.kl_decay = kl_decay
        self.kl_loss = losses.KLDivergence()

        self.schedulers = [
            th.optim.lr_scheduler.ExponentialLR(self.im_enc_opt, lr_decay),
            th.optim.lr_scheduler.ExponentialLR(self.im_dec_opt, lr_decay),
            th.optim.lr_scheduler.ExponentialLR(self.sketch_dec_opt, lr_decay),
        ]

        # include loss on alpha
        self.im_loss = losses.MultiscaleMSELoss(channels=4).to(self.device)

    def optimizers(self):
        return [self.im_enc_opt, self.im_dec_opt, self.sketch_dec_opt]

    def kl_scaling(self):
        # Anneal the KL weight: starts at kl_min_weight when no optimizer
        # step has been taken and approaches 1.0 as training progresses.
        try:
            state = self.im_enc_opt.state_dict()["param_groups"][0]["params"][0]
            optim_step = self.im_enc_opt.state_dict()["state"][state]["step"]
        except KeyError:
            optim_step = 0  # no step taken yet
        kl_scaling = 1.0 - (1.0 -
                            self.kl_min_weight)*(self.kl_decay**optim_step)
        return kl_scaling
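
    # With the defaults (kl_min_weight=0.01, kl_decay=0.99995) the scaling is
    # ~0.01 at step 0, ~0.40 after 10k steps, and tends to 1.0 in the limit:
    # standard KL annealing.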

    def training_step(self, batch):
        gt_strokes, gt_im = batch
        gt_strokes = gt_strokes.to(self.device)
        gt_im = gt_im.to(self.device)

        out = self.model(gt_im)

        kl_loss = self.kl_loss(
            out["mu"], out["log_sigma"])
        kl_weight = self.kl_weight * self.kl_scaling()

        # add start of sequence
        sos = gt_strokes[:, :1]
        sketch = th.cat([sos, out["decoded_sketch"]], 1)

        vae_im = out["decoded_im"]

        # Rasterize the predicted strokes with diffvg so the image loss can
        # backpropagate through the renderer to the sketch decoder.
        # start = time.time()
        sketch_im = rendering.opacityStroke2diffvg(
            sketch, canvas_size=self.raster_resolution, debug=False,
            force_cpu=True, relative=not self.absolute_coords)
        # elapsed = (time.time() - start)*1000
        # print("out rendering took %.2fms" % elapsed)

        vae_im_loss = self.im_loss(vae_im, gt_im)
        sketch_im_loss = self.im_loss(sketch_im, gt_im)

        # vae_im_loss = th.nn.functional.mse_loss(vae_im, gt_im)
        # sketch_im_loss = th.nn.functional.mse_loss(sketch_im, gt_im)

        loss = vae_im_loss + kl_loss*kl_weight + sketch_im_loss

        self.im_enc_opt.zero_grad()
        self.im_dec_opt.zero_grad()
        self.sketch_dec_opt.zero_grad()
        loss.backward()

        # clip gradients
        enc_nrm = th.nn.utils.clip_grad_norm_(
            self.model.im_encoder.parameters(), self.grad_clip)
        dec_nrm = th.nn.utils.clip_grad_norm_(
            self.model.im_decoder.parameters(), self.grad_clip)
        sketch_dec_nrm = th.nn.utils.clip_grad_norm_(
            self.model.sketch_decoder.parameters(), self.grad_clip)

        if enc_nrm > self.grad_clip:
            LOG.debug("Clipped encoder gradient (%.5f) to %.2f",
                      enc_nrm, self.grad_clip)

        if dec_nrm > self.grad_clip:
            LOG.debug("Clipped decoder gradient (%.5f) to %.2f",
                      dec_nrm, self.grad_clip)

        if sketch_dec_nrm > self.grad_clip:
            LOG.debug("Clipped sketch decoder gradient (%.5f) to %.2f",
                      sketch_dec_nrm, self.grad_clip)

        self.im_enc_opt.step()
        self.im_dec_opt.step()
        self.sketch_dec_opt.step()

        return {
            "vae_image": vae_im,
            "sketch_image": sketch_im,
            "gt_image": gt_im,
            "loss": loss.item(),
            "vae_im_loss": vae_im_loss.item(),
            "sketch_im_loss": sketch_im_loss.item(),
            "kl_loss": kl_loss.item(),
            "kl_weight": kl_weight,
            "lr": self.im_enc_opt.param_groups[0]["lr"],
        }

    def init_validation(self):
        return dict(sample=None)

    def validation_step(self, batch, running_data):
        # Validation sampling is disabled for now; the commented code below
        # sketches what it would look like.
        # self.model.eval()
        # with th.no_grad():
        #     sample = self.model.sample(
        #         batch.to(self.device), temperature=self.sampling_temperature)
        #     running_data["sample"] = sample
        # self.model.train()
        return running_data
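

# Each dataset sample is expected to be a (strokes, image) pair: `strokes` is
# a (sequence_length, 3) tensor of (dx, dy, opacity) steps whose first entry
# is the start-of-sequence token, and `image` is the corresponding 4-channel
# (RGBA) raster; see data.FixedLengthQuickDrawDataset for the exact format.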
def train(args):
    th.manual_seed(0)
    np.random.seed(0)

    dataset = data.FixedLengthQuickDrawDataset(
        args.dataset, max_seq_length=args.sequence_length,
        canvas_size=args.raster_resolution)
    dataloader = DataLoader(
        dataset, batch_size=args.bs, num_workers=args.workers, shuffle=True)

    # val_dataset = [s for idx, s in enumerate(dataset) if idx < 8]
    # val_dataloader = DataLoader(
    #     val_dataset, batch_size=8, num_workers=4, shuffle=False)
    val_dataloader = None

    model_params = {
        "zdim": args.zdim,
        "sequence_length": args.sequence_length,
        "image_size": args.raster_resolution,
        # "encoder_dim": args.encoder_dim,
        # "decoder_dim": args.decoder_dim,
    }
    model = SketchVAE(**model_params)
    model.train()

    LOG.info("Model parameters:\n%s", model_params)

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(model, raster_resolution=args.raster_resolution,
                          lr=args.lr, lr_decay=args.lr_decay,
                          kl_decay=args.kl_decay, kl_weight=args.kl_weight,
                          absolute_coords=args.absolute_coordinates,
                          device=device)

    env_name = "sketch_vae"
    if args.custom_name is not None:
        env_name += "_" + args.custom_name

    if args.absolute_coordinates:
        env_name += "_abs_coords"

    chkpt = os.path.join(OUTPUT, env_name)

    # Resume from checkpoint, if any
    checkpointer = ttools.Checkpointer(
        chkpt, model, meta=model_params,
        optimizers=interface.optimizers(),
        schedulers=interface.schedulers)
    extras, meta = checkpointer.load_latest()
    epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0

    if meta is not None and meta != model_params:
        LOG.error("Checkpoint's metaparams differ "
                  "from CLI, aborting: %s and %s", meta, model_params)
        return

    trainer = ttools.Trainer(interface)

    # Add callbacks
    loss_keys = ["loss", "kl_loss", "vae_im_loss", "sketch_im_loss"]
    training_debug = ["lr", "kl_weight"]
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=loss_keys, val_keys=None))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=loss_keys, val_keys=None, env=env_name, port=args.port))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=training_debug, smoothing=0, val_keys=None, env=env_name,
        port=args.port))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer, max_files=2, interval=600, max_epochs=10))
    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))

    trainer.add_callback(SketchVAECallback(
        env=env_name, win="samples", port=args.port, frequency=args.freq))

    # Start training
    trainer.train(dataloader, starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="cat.npz")

    parser.add_argument("--absolute_coordinates", action="store_true",
                        default=False)

    parser.add_argument("--custom_name")

    # Training params
    parser.add_argument("--bs", type=int, default=1)
    parser.add_argument("--workers", type=int, default=0)
    parser.add_argument("--num_epochs", type=int, default=10000)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lr_decay", type=float, default=0.9999)
    parser.add_argument("--kl_weight", type=float, default=0.5)
    parser.add_argument("--kl_decay", type=float, default=0.99995)

    # Model configuration
    parser.add_argument("--zdim", type=int, default=128)
    parser.add_argument("--sequence_length", type=int, default=50)
    parser.add_argument("--raster_resolution", type=int, default=64)
    # parser.add_argument("--encoder_dim", type=int, default=256)
    # parser.add_argument("--decoder_dim", type=int, default=512)

    # Viz params
    parser.add_argument("--freq", type=int, default=10)
    parser.add_argument("--port", type=int, default=5000)

    args = parser.parse_args()

    # The stroke renderer runs on CPU (cf. force_cpu=True in training_step)
    pydiffvg.set_use_gpu(False)

    train(args)
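
# Example invocation (hypothetical; the script name and flag values are
# placeholders, and a Visdom server must be listening on --port):
#
#   python sketch_vae.py --dataset cat.npz --bs 32 --workers 4 --port 5000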