#!/usr/bin/env python
"""Train a Sketch-RNN."""
import argparse
from enum import Enum
import os
import wget

import numpy as np
import torch as th
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms

import ttools
import ttools.interfaces
from ttools.modules import networks

import pydiffvg

import rendering
import losses
import data

LOG = ttools.get_logger(__name__)


BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
OUTPUT = os.path.join(BASE_DIR, "results", "sketch_rnn_diffvg")
OUTPUT_BASELINE = os.path.join(BASE_DIR, "results", "sketch_rnn")


class SketchRNN(th.nn.Module):
    class Encoder(th.nn.Module):
        def __init__(self, hidden_size=512, dropout=0.9, zdim=128,
                     num_layers=1):
            super(SketchRNN.Encoder, self).__init__()
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.zdim = zdim

            self.lstm = th.nn.LSTM(5, hidden_size, num_layers=self.num_layers,
                                   dropout=dropout, bidirectional=True,
                                   batch_first=True)

            # bidirectional model -> *2
            self.mu_predictor = th.nn.Linear(2*hidden_size, zdim)
            self.sigma_predictor = th.nn.Linear(2*hidden_size, zdim)

        def forward(self, sequences, hidden_and_cell=None):
            bs = sequences.shape[0]
            if hidden_and_cell is None:
                hidden = th.zeros(self.num_layers*2, bs, self.hidden_size).to(
                    sequences.device)
                cell = th.zeros(self.num_layers*2, bs, self.hidden_size).to(
                    sequences.device)
                hidden_and_cell = (hidden, cell)

            out, hidden_and_cell = self.lstm(sequences, hidden_and_cell)
            hidden = hidden_and_cell[0]

            # Concat the forward/backward states
            fc_input = th.cat([hidden[0], hidden[1]], 1)

            # VAE params
            mu = self.mu_predictor(fc_input)
            log_sigma = self.sigma_predictor(fc_input)

            # Sample a latent vector
            sigma = th.exp(log_sigma/2.0)
            z0 = th.randn_like(mu)  # independent noise for each batch element
            z = mu + sigma*z0

            # KL divergence needs mu/sigma
            return z, mu, log_sigma

    class Decoder(th.nn.Module):
        """
        The decoder outputs a sequence where each time step models (dx, dy)
        as a mixture of `num_gaussians` 2D Gaussians and the pen-state
        triplet as a categorical distribution.

        At each time step the model outputs:
          - 5 parameters for each Gaussian: mu_x, mu_y, sigma_x, sigma_y,
            rho_xy
          - 1 logit for each Gaussian (the mixture weight)
          - 3 logits for the pen-state triplet probabilities
        """

        def __init__(self, hidden_size=512, dropout=0.9, zdim=128,
                     num_layers=1, num_gaussians=20):
            super(SketchRNN.Decoder, self).__init__()
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.zdim = zdim
            self.num_gaussians = num_gaussians

            # Maps the latent vector to an initial cell/hidden vector
            self.hidden_cell_predictor = th.nn.Linear(zdim, 2*hidden_size)

            self.lstm = th.nn.LSTM(
                5 + zdim, hidden_size,
                num_layers=self.num_layers, dropout=dropout,
                batch_first=True)

            self.parameters_predictor = th.nn.Linear(
                hidden_size, num_gaussians + 5*num_gaussians + 3)

        def forward(self, inputs, z, hidden_and_cell=None):
            # Every step in the sequence takes the latent vector as input so
            # we replicate it here
            expanded_z = z.unsqueeze(1).repeat(1, inputs.shape[1], 1)
            inputs = th.cat([inputs, expanded_z], 2)

            bs, steps = inputs.shape[:2]
            if hidden_and_cell is None:
                # Initialize from latent vector
                hidden_and_cell = self.hidden_cell_predictor(th.tanh(z))
                hidden = hidden_and_cell[:, :self.hidden_size]
                hidden = hidden.unsqueeze(0).contiguous()
                cell = hidden_and_cell[:, self.hidden_size:]
                cell = cell.unsqueeze(0).contiguous()
                hidden_and_cell = (hidden, cell)

            outputs, hidden_and_cell = self.lstm(inputs, hidden_and_cell)
            hidden, cell = hidden_and_cell

            # We want the mixture parameters for every time step
            outputs = outputs.reshape(bs*steps, self.hidden_size)
            params = self.parameters_predictor(outputs).view(bs, steps, -1)

            pen_logits = params[..., -3:]
            gaussian_params = params[..., :-3]
            mixture_logits = gaussian_params[..., :self.num_gaussians]
            gaussian_params = gaussian_params[..., self.num_gaussians:].view(
                bs, steps, self.num_gaussians, -1)

            return pen_logits, mixture_logits, gaussian_params, hidden_and_cell

    def __init__(self, zdim=128, num_gaussians=20, encoder_dim=256,
                 decoder_dim=512):
        super(SketchRNN, self).__init__()
        self.encoder = SketchRNN.Encoder(zdim=zdim, hidden_size=encoder_dim)
        self.decoder = SketchRNN.Decoder(zdim=zdim, hidden_size=decoder_dim,
                                         num_gaussians=num_gaussians)

    def forward(self, sequences):
        # Encode the sequences as latent vectors
        # We skip the first time step since it is the same for all sequences:
        # the start-of-sequence token (0, 0, 1, 0, 0) in stroke-5 format
        # (dx, dy, pen_down, pen_up, end_of_sequence)
        z, mu, log_sigma = self.encoder(sequences[:, 1:])

        # Decode the latent vector into a model sequence
        # Do not process the last time step (it is an end-of-sequence token)
        pen_logits, mixture_logits, gaussian_params, hidden_and_cell = \
            self.decoder(sequences[:, :-1], z)

        return {
            "pen_logits": pen_logits,
            "mixture_logits": mixture_logits,
            "gaussian_params": gaussian_params,
            "z": z,
            "mu": mu,
            "log_sigma": log_sigma,
            "hidden_and_cell": hidden_and_cell,
        }

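    # Minimal sampling sketch (illustrative, not used by the training loop;
    # assumes `seqs` is a (batch, steps, 5) stroke-5 tensor such as a batch
    # from data.QuickDrawDataset):
    #
    #   model = SketchRNN()
    #   with th.no_grad():
    #       generated = model.sample(seqs, temperature=0.4)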
    def sample(self, sequences, temperature=1.0):
        # Compute a latent vector conditioned on a real sequence
        z, _, _ = self.encoder(sequences[:, 1:])

        start_of_seq = sequences[:, :1]

        max_steps = sequences.shape[1] - 1  # last step is an end-of-seq token

        output_sequences = th.zeros_like(sequences)
        output_sequences[:, 0] = start_of_seq.squeeze(1)

        current_input = start_of_seq
        hidden_and_cell = None
        for step in range(max_steps):
            pen_logits, mixture_logits, gaussian_params, hidden_and_cell = \
                self.decoder(current_input, z, hidden_and_cell=hidden_and_cell)

            # Pen and displacement state for the next step
            next_state = th.zeros_like(current_input)

            # Adjust temperature to control randomness: dividing the logits
            # sharpens the distributions as temperature -> 0 and leaves them
            # unchanged at temperature = 1
            mixture_logits = mixture_logits/temperature
            pen_logits = pen_logits/temperature

            # Select one of 3 pen states
            pen_distrib = \
                th.distributions.categorical.Categorical(logits=pen_logits)
            pen_state = pen_distrib.sample()

            # One-hot encoding of the state
            next_state[:, :, 2:].scatter_(2, pen_state.unsqueeze(-1),
                                          th.ones_like(next_state[:, :, 2:]))

            # Select one of the Gaussians from the mixture
            mixture_distrib = \
                th.distributions.categorical.Categorical(logits=mixture_logits)
            mixture_idx = mixture_distrib.sample()

            # Gather the parameters of the selected Gaussian
            mixture_idx = mixture_idx.unsqueeze(-1).unsqueeze(-1)
            mixture_idx = mixture_idx.repeat(1, 1, 1, 5)
            params = th.gather(gaussian_params, 2, mixture_idx).squeeze(2)

            # Sample a point from the selected Gaussian
            mu = params[..., :2]
            sigma_x = params[..., 2].exp()
            sigma_y = params[..., 3].exp()
            rho_xy = th.tanh(params[..., 4])
            # Build the full 2x2 covariance matrix; scaling it by the
            # temperature shrinks the noise for temperature < 1
            cov = th.zeros(params.shape[0], params.shape[1], 2, 2,
                           device=params.device)
            cov[..., 0, 0] = sigma_x.pow(2)*temperature
            cov[..., 1, 1] = sigma_y.pow(2)*temperature
            cov[..., 0, 1] = sigma_x*sigma_y*rho_xy*temperature
            cov[..., 1, 0] = sigma_x*sigma_y*rho_xy*temperature
            point_distrib = \
                th.distributions.multivariate_normal.MultivariateNormal(
                    mu, covariance_matrix=cov)
            point = point_distrib.sample()
            next_state[:, :, :2] = point

            # Commit step to output
            output_sequences[:, step + 1] = next_state.squeeze(1)

            # Prepare next recurrent step
            current_input = next_state

        return output_sequences


class SketchRNNCallback(ttools.callbacks.ImageDisplayCallback):
    """Simple callback that visualizes ground-truth and sampled sketches."""

    def visualized_image(self, batch, step_data, is_val=False):
        if not is_val:
            # No need to render training data
            return None

        with th.no_grad():
            # only display the first n drawings
            n = 8
            batch = batch[:n]

            out_im = rendering.stroke2diffvg(step_data["sample"][:n])
            im = rendering.stroke2diffvg(batch)
            im = th.cat([im, out_im], 2)

            return im

    def caption(self, batch, step_data, is_val=False):
        return "top: truth, bottom: sample"


class Interface(ttools.ModelInterface):
    def __init__(self, model, lr=1e-3, lr_decay=0.9999,
                 kl_weight=0.5, kl_min_weight=0.01, kl_decay=0.99995,
                 device="cpu", grad_clip=1.0, sampling_temperature=0.4):
        super(Interface, self).__init__()
        self.grad_clip = grad_clip
        self.sampling_temperature = sampling_temperature

        self.model = model
        self.device = device
        self.model.to(self.device)
        self.enc_opt = th.optim.Adam(self.model.encoder.parameters(), lr=lr)
        self.dec_opt = th.optim.Adam(self.model.decoder.parameters(), lr=lr)

        self.kl_weight = kl_weight
        self.kl_min_weight = kl_min_weight
        self.kl_decay = kl_decay
        self.kl_loss = losses.KLDivergence()

        self.schedulers = [
            th.optim.lr_scheduler.ExponentialLR(self.enc_opt, lr_decay),
            th.optim.lr_scheduler.ExponentialLR(self.dec_opt, lr_decay),
        ]

        self.reconstruction_loss = losses.GaussianMixtureReconstructionLoss()
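        # The objective assembled in training_step below is
        #   loss = kl_weight * kl_loss + reconstruction_loss
        # where kl_weight is annealed from kl_weight*kl_min_weight toward
        # kl_weight over training.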

    def optimizers(self):
        return [self.enc_opt, self.dec_opt]

    def training_step(self, batch):
        batch = batch.to(self.device)
        out = self.model(batch)

        kl_loss = self.kl_loss(
            out["mu"], out["log_sigma"])

        # The target to predict is the next sequence step
        targets = batch[:, 1:].to(self.device)

        # Scale the KL divergence weight
        try:
            state = self.enc_opt.state_dict()["param_groups"][0]["params"][0]
            optim_step = self.enc_opt.state_dict()["state"][state]["step"]
        except KeyError:
            optim_step = 0  # no step taken yet
        kl_scaling = 1.0 - (1.0 -
                            self.kl_min_weight)*(self.kl_decay**optim_step)
        kl_weight = self.kl_weight * kl_scaling
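        # Illustrative schedule with the defaults (kl_min_weight=0.01,
        # kl_decay=0.99995): kl_scaling starts at 0.01 and reaches roughly
        # 0.64 after 20k optimizer steps, approaching 1.0 afterwards.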

        reconstruction_loss = self.reconstruction_loss(
            out["pen_logits"], out["mixture_logits"],
            out["gaussian_params"], targets)
        # Use the annealed KL weight computed above
        loss = kl_loss*kl_weight + reconstruction_loss

        self.enc_opt.zero_grad()
        self.dec_opt.zero_grad()
        loss.backward()

        # clip gradients
        enc_nrm = th.nn.utils.clip_grad_norm_(
            self.model.encoder.parameters(), self.grad_clip)
        dec_nrm = th.nn.utils.clip_grad_norm_(
            self.model.decoder.parameters(), self.grad_clip)

        if enc_nrm > self.grad_clip:
            LOG.debug("Clipped encoder gradient (%.5f) to %.2f",
                      enc_nrm, self.grad_clip)

        if dec_nrm > self.grad_clip:
            LOG.debug("Clipped decoder gradient (%.5f) to %.2f",
                      dec_nrm, self.grad_clip)

        self.enc_opt.step()
        self.dec_opt.step()

        return {
            "loss": loss.item(),
            "kl_loss": kl_loss.item(),
            "kl_weight": kl_weight,
            "recons_loss": reconstruction_loss.item(),
            "lr": self.enc_opt.param_groups[0]["lr"],
        }

    def init_validation(self):
        return dict(sample=None)

    def validation_step(self, batch, running_data):
        # Switch to eval mode for dropout, batchnorm, etc.
        self.model.eval()
        with th.no_grad():
            sample = self.model.sample(
                batch.to(self.device), temperature=self.sampling_temperature)
            running_data["sample"] = sample
        self.model.train()
        return running_data


def train(args):
    th.manual_seed(0)
    np.random.seed(0)

    dataset = data.QuickDrawDataset(args.dataset)
    dataloader = DataLoader(
        dataset, batch_size=args.bs, num_workers=4, shuffle=True,
        pin_memory=False)

    val_dataset = [s for idx, s in enumerate(dataset) if idx < 8]
    val_dataloader = DataLoader(
        val_dataset, batch_size=8, num_workers=4, shuffle=False,
        pin_memory=False)

    model_params = {
        "zdim": args.zdim,
        "num_gaussians": args.num_gaussians,
        "encoder_dim": args.encoder_dim,
        "decoder_dim": args.decoder_dim,
    }
    model = SketchRNN(**model_params)
    model.train()

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(model, lr=args.lr, lr_decay=args.lr_decay,
                          kl_decay=args.kl_decay, kl_weight=args.kl_weight,
                          sampling_temperature=args.sampling_temperature,
                          device=device)

    chkpt = OUTPUT_BASELINE
    env_name = "sketch_rnn"

    # Resume from checkpoint, if any
    checkpointer = ttools.Checkpointer(
        chkpt, model, meta=model_params,
        optimizers=interface.optimizers(),
        schedulers=interface.schedulers)
    extras, meta = checkpointer.load_latest()
    epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0

    if meta is not None and meta != model_params:
        LOG.info("Checkpoint's metaparams differ "
                 "from CLI, aborting: %s and %s", meta, model_params)
        return

    trainer = ttools.Trainer(interface)

    # Add callbacks
    # (named `loss_keys` rather than `losses` to avoid shadowing the module)
    loss_keys = ["loss", "kl_loss", "recons_loss"]
    training_debug = ["lr", "kl_weight"]
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=loss_keys, val_keys=None))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=loss_keys, val_keys=None, env=env_name, port=args.port))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=training_debug, smoothing=0, val_keys=None, env=env_name,
        port=args.port))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer, max_files=2, interval=600, max_epochs=10))
    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))

    trainer.add_callback(SketchRNNCallback(
        env=env_name, win="samples", port=args.port, frequency=args.freq))

    # Start training
    trainer.train(dataloader, starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="cat.npz")

    # Training params
    parser.add_argument("--bs", type=int, default=100)
    parser.add_argument("--num_epochs", type=int, default=10000)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lr_decay", type=float, default=0.9999)
    parser.add_argument("--kl_weight", type=float, default=0.5)
    parser.add_argument("--kl_decay", type=float, default=0.99995)

    # Model configuration
    parser.add_argument("--zdim", type=int, default=128)
    parser.add_argument("--num_gaussians", type=int, default=20)
    parser.add_argument("--encoder_dim", type=int, default=256)
    parser.add_argument("--decoder_dim", type=int, default=512)

    parser.add_argument("--sampling_temperature", type=float, default=0.4,
                        help="controls sampling randomness. "
                             "0.0: deterministic, 1.0: unchanged")

    # Viz params
    parser.add_argument("--freq", type=int, default=100)
    parser.add_argument("--port", type=int, default=5000)

    args = parser.parse_args()

    pydiffvg.set_use_gpu(th.cuda.is_available())

    train(args)