initial commit

Tzu-Mao Li
2020-09-03 22:30:30 -04:00
commit 413a3e5cee
148 changed files with 138536 additions and 0 deletions

apps/generative_models/.gitignore (new file)

@@ -0,0 +1 @@
.gdb_history


@@ -0,0 +1,5 @@
# Usage
For the GAN models, see `train_gan.py`. Generate samples from a pretrained model with `eval_gan.py`.
For the VAE models, see `mnist_vae.py`.
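
For example, assuming training has produced a checkpoint folder such as `../results/quickdraw_gan_vector_bezier_fc_wgan` (the exact path depends on your configuration):

```bash
python train_gan.py
python eval_gan.py ../results/quickdraw_gan_vector_bezier_fc_wgan
```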

@@ -0,0 +1,229 @@
import os
import time
import torch as th
import numpy as np
import torchvision.datasets as dset
import torchvision.transforms as transforms
import imageio
import ttools
import rendering
BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
DATA = os.path.join(BASE_DIR, "data")
LOG = ttools.get_logger(__name__)
class QuickDrawImageDataset(th.utils.data.Dataset):
"""QuickDraw bitmap dataset (28x28 grayscale drawings).
Args:
imsize(int): side length, in pixels, to which the images are resized.
"""
BASE_DATA_URL = \
"https://console.cloud.google.com/storage/browser/_details/quickdraw_dataset/full/numpy_bitmap/cat.npy"
def __init__(self, imsize, train=True):
super(QuickDrawImageDataset, self).__init__()
file = os.path.join(DATA, "cat.npy")
self.imsize = imsize
if not os.path.exists(file):
msg = "Dataset file %s does not exist, please download"
" it from %s" % (file, QuickDrawImageDataset.BASE_DATA_URL)
LOG.error(msg)
raise RuntimeError(msg)
self.data = np.load(file, allow_pickle=True, encoding="latin1")
def __len__(self):
return self.data.shape[0]
def __getitem__(self, idx):
im = np.reshape(self.data[idx], (1, 1, 28, 28))
im = th.from_numpy(im).float() / 255.0
im = th.nn.functional.interpolate(im, size=(self.imsize, self.imsize))
# Bring it to [-1, 1]
im = th.clamp(im, 0, 1)
im -= 0.5
im /= 0.5
return im.squeeze(0)
class QuickDrawDataset(th.utils.data.Dataset):
"""QuickDraw stroke dataset in stroke-5 format.
Args:
dataset(str): name of the QuickDraw .npz category file, e.g. "cat.npz".
mode(str): one of "train", "test", "valid".
max_seq_length(int): sequences longer than this are discarded.
spatial_limit(int): maximum spatial extent in pixels; larger offsets
are clipped.
"""
BASE_DATA_URL = \
"https://storage.cloud.google.com/quickdraw_dataset/sketchrnn"
def __init__(self, dataset, mode="train",
max_seq_length=250,
spatial_limit=1000):
super(QuickDrawDataset, self).__init__()
file = os.path.join(DATA, "sketchrnn_"+dataset)
remote = os.path.join(QuickDrawDataset.BASE_DATA_URL, dataset)
self.max_seq_length = max_seq_length
self.spatial_limit = spatial_limit
if mode not in ["train", "test", "valid"]:
raise ValueError("Allowed data modes are 'train', 'test' and"
" 'valid'.")
if not os.path.exists(file):
msg = "Dataset file %s does not exist, please download"
" it from %s" % (file, remote)
LOG.error(msg)
raise RuntimeError(msg)
data = np.load(file, allow_pickle=True, encoding="latin1")[mode]
data = self.purify(data)
data = self.normalize(data)
# Length of longest sequence in the dataset
self.nmax = max([len(seq) for seq in data])
self.sketches = data
def __repr__(self):
return "Dataset with %d sequences of max length %d" % \
(len(self.sketches), self.nmax)
def __len__(self):
return len(self.sketches)
def __getitem__(self, idx):
"""Return the idx-th stroke in 5-D format, padded to length (Nmax+2).
The first and last element of the sequence are fixed to "start-" and
"end-of-sequence" token.
dx, dy, + 3 numbers for one-hot encoding of state:
1 0 0: pen touching paper till next point
0 1 0: pen lifted from paper after current point
0 0 1: drawing has ended, subsequent points (including the current
one) will not be drawn
"""
sample_data = self.sketches[idx]
# Allow two extra slots for start/end of sequence tokens
sample = np.zeros((self.nmax+2, 5), dtype=np.float32)
n = sample_data.shape[0]
# normalize dx, dy
deltas = sample_data[:, :2]
# Absolute coordinates
positions = deltas[..., :2].cumsum(0)
maxi = np.abs(positions).max() + 1e-8
deltas = deltas / (1.1 * maxi) # leave some margin on edges
# fill in dx, dy coordinates
sample[1:n+1, :2] = deltas
# pen-down indicator: in the stroke-3 format 0 means touching paper, so flip it
sample[1:n+1, 2] = 1 - sample_data[:, 2]
# off-paper indicator, complement of previous flag
sample[1:n+1, 3] = 1 - sample[1:n+1, 2]
# fill with end of sequence tokens for the remainder
sample[n+1:, 4] = 1
# Start of sequence token
sample[0] = [0, 0, 1, 0, 0]
return sample
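# Illustrative example: a 2-point sketch, padded with nmax == 3, becomes
# [  0,   0, 1, 0, 0]  <- start-of-sequence token
# [dx1, dy1, 1, 0, 0]  <- pen touching paper until the next point
# [dx2, dy2, 0, 1, 0]  <- pen lifted after this point
# [  0,   0, 0, 0, 1]  <- end-of-sequence padding
# [  0,   0, 0, 0, 1]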
def purify(self, strokes):
"""removes to small or too long sequences + removes large gaps"""
data = []
for seq in strokes:
if seq.shape[0] <= self.max_seq_length:
# and seq.shape[0] > 10:
# Limit large spatial gaps
seq = np.minimum(seq, self.spatial_limit)
seq = np.maximum(seq, -self.spatial_limit)
seq = np.array(seq, dtype=np.float32)
data.append(seq)
return data
def calculate_normalizing_scale_factor(self, strokes):
"""Calculate the normalizing factor explained in appendix of
sketch-rnn."""
data = []
for i, stroke_i in enumerate(strokes):
for j, pt in enumerate(strokes[i]):
data.append(pt[0])
data.append(pt[1])
data = np.array(data)
return np.std(data)
def normalize(self, strokes):
"""Normalize entire dataset (delta_x, delta_y) by the scaling
factor."""
data = []
scale_factor = self.calculate_normalizing_scale_factor(strokes)
for seq in strokes:
seq[:, 0:2] /= scale_factor
data.append(seq)
return data
class FixedLengthQuickDrawDataset(QuickDrawDataset):
"""A variant of the QuickDraw dataset where the strokes are represented as
a fixed-length sequence of triplets (dx, dy, opacity), with opacity in {0, 1}.
"""
def __init__(self, *args, canvas_size=64, **kwargs):
super(FixedLengthQuickDrawDataset, self).__init__(*args, **kwargs)
self.canvas_size = canvas_size
def __getitem__(self, idx):
sample = super(FixedLengthQuickDrawDataset, self).__getitem__(idx)
# Build a stroke opacity channel from the pen-down state; dx, dy are unchanged
strokes = sample[:, :3]
im = np.zeros((1, 1))
# render image
# start = time.time()
im = rendering.opacityStroke2diffvg(
th.from_numpy(strokes).unsqueeze(0), canvas_size=self.canvas_size,
relative=True, debug=False)
im = im.squeeze(0).numpy()
# elapsed = (time.time() - start)*1000
# print("item %d pipeline gt rendering took %.2fms" % (idx, elapsed))
return strokes, im
class MNISTDataset(th.utils.data.Dataset):
def __init__(self, imsize, train=True):
super(MNISTDataset, self).__init__()
self.mnist = dset.MNIST(root=os.path.join(DATA, "mnist"),
train=train,
download=True,
transform=transforms.Compose([
transforms.Resize((imsize, imsize)),
transforms.ToTensor(),
]))
def __len__(self):
return len(self.mnist)
def __getitem__(self, idx):
im, label = self.mnist[idx]
# make sure data uses [0, 1] range
im -= im.min()
im /= im.max() + 1e-8
# Bring it to [-1, 1]
im -= 0.5
im /= 0.5
return im


@@ -0,0 +1,182 @@
"""Evaluate a pretrained GAN model.
Usage:
`python eval_gan.py <path/to/model/folder>`, e.g.
`../results/quickdraw_gan_vector_bezier_fc_wgan`.
"""
import os
import argparse
import torch as th
import numpy as np
import ttools
import imageio
from subprocess import call
import pydiffvg
import models
LOG = ttools.get_logger(__name__)
def postprocess(im, invert=False):
im = th.clamp((im + 1.0) / 2.0, 0, 1)
if invert:
im = (1.0 - im)
im = ttools.tensor2image(im)
return im
def imsave(im, path):
os.makedirs(os.path.dirname(path), exist_ok=True)
imageio.imwrite(path, im)
def save_scene(scn, path):
os.makedirs(os.path.dirname(path), exist_ok=True)
pydiffvg.save_svg(path, *scn, use_gamma=False)
def run(args):
th.manual_seed(0)
np.random.seed(0)
meta = ttools.Checkpointer.load_meta(args.model, "vect_g_")
if meta is None:
LOG.warning("Could not load metadata at %s, aborting.", args.model)
return
LOG.info("Loaded model %s with metadata:\n %s", args.model, meta)
if args.output_dir is None:
outdir = os.path.join(args.model, "eval")
else:
outdir = args.output_dir
os.makedirs(outdir, exist_ok=True)
model_params = meta["model_params"]
if args.imsize is not None:
LOG.info("Overriding output image size to: %dx%d", args.imsize,
args.imsize)
old_size = model_params["imsize"]
scale = args.imsize * 1.0 / old_size
model_params["imsize"] = args.imsize
model_params["stroke_width"] = [w*scale for w in
model_params["stroke_width"]]
LOG.info("Overriding width to: %s", model_params["stroke_width"])
# task = meta["task"]
generator = meta["generator"]
if generator == "fc":
model = models.VectorGenerator(**model_params)
elif generator == "bezier_fc":
model = models.BezierVectorGenerator(**model_params)
elif generator in ["rnn"]:
model = models.RNNVectorGenerator(**model_params)
elif generator in ["chain_rnn"]:
model = models.ChainRNNVectorGenerator(**model_params)
else:
raise NotImplementedError()
model.eval()
device = "cpu"
if th.cuda.is_available():
device = "cuda"
model.to(device)
checkpointer = ttools.Checkpointer(
args.model, model, meta=meta, prefix="vect_g_")
checkpointer.load_latest()
LOG.info("Computing latent space interpolation")
for i in range(args.nsamples):
z0 = model.sample_z(1)
z1 = model.sample_z(1)
# linear interpolation in latent space: z(a) = (1 - a)*z0 + a*z1
alpha = th.linspace(0, 1, args.nsteps).view(args.nsteps, 1).to(device)
alpha_video = th.linspace(0, 1, args.nframes).view(args.nframes, 1)
alpha_video = alpha_video.to(device)
length = [args.nsteps, args.nframes]
for idx, a in enumerate([alpha, alpha_video]):
_z0 = z0.repeat(length[idx], 1).to(device)
_z1 = z1.repeat(length[idx], 1).to(device)
batch = _z0*(1-a) + _z1*a
out = model(batch)
if idx == 0: # image viz
n, c, h, w = out.shape
out = out.permute(1, 2, 0, 3)
out = out.contiguous().view(1, c, h, w*n)
out = postprocess(out, invert=args.invert)
imsave(out, os.path.join(outdir,
"latent_interp", "%03d.png" % i))
scenes = model.get_vector(batch)
for scn_idx, scn in enumerate(scenes):
save_scene(scn, os.path.join(outdir, "latent_interp_svg",
"%03d" % i, "%03d.svg" %
scn_idx))
else: # video viz
anim_root = os.path.join(outdir,
"latent_interp_video", "%03d" % i)
LOG.info("Rendering animation %d", i)
for frame_idx, frame in enumerate(out):
LOG.info("frame %d", frame_idx)
frame = frame.unsqueeze(0)
frame = postprocess(frame, invert=args.invert)
imsave(frame, os.path.join(anim_root,
"frame%04d.png" % frame_idx))
call(["ffmpeg", "-framerate", "30", "-i",
os.path.join(anim_root, "frame%04d.png"), "-vb", "20M",
os.path.join(outdir,
"latent_interp_video", "%03d.mp4" % i)])
LOG.info(" saved %d", i)
LOG.info("Sampling latent space")
for i in range(args.nsamples):
n = 8
bs = n*n
z = model.sample_z(bs).to(device)
out = model(z)
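# Tile the n*n samples into a single image grid:
# (n*n, c, h, w) -> (n, n, c, h, w) -> (c, n*h, n*w)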
_, c, h, w = out.shape
out = out.view(n, n, c, h, w).permute(2, 0, 3, 1, 4)
out = out.contiguous().view(1, c, h*n, w*n)
out = postprocess(out)
imsave(out, os.path.join(outdir, "samples_%03d.png" % i))
LOG.info(" saved %d", i)
LOG.info("output images saved to %s", outdir)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("model")
parser.add_argument("--output_dir", help="output directory for "
" the samples. Defaults to the model's path")
parser.add_argument("--nsamples", default=16, type=int,
help="number of output to compute")
parser.add_argument("--imsize", type=int,
help="if provided, override the raster output "
"resolution")
parser.add_argument("--nsteps", default=9, type=int, help="number of "
"interpolation steps for the interpolation")
parser.add_argument("--nframes", default=120, type=int, help="number of "
"frames for the interpolation video")
parser.add_argument("--invert", default=False, action="store_true",
help="if True, render black on white rather than the"
" opposite")
args = parser.parse_args()
pydiffvg.set_use_gpu(False)
ttools.set_logger(False)
run(args)


@@ -0,0 +1,99 @@
"""Losses for the generative models and baselines."""
import torch as th
import numpy as np
import ttools.modules.image_operators as imops
class KLDivergence(th.nn.Module):
"""
Args:
min_value(float): the loss is clipped so that values below this
threshold do not affect the optimization (a "free bits" floor).
"""
def __init__(self, min_value=0.2):
super(KLDivergence, self).__init__()
self.min_value = min_value
def forward(self, mu, log_sigma):
loss = -0.5 * (1.0 + log_sigma - mu.pow(2) - log_sigma.exp())
loss = loss.mean()
loss = th.max(loss, self.min_value*th.ones_like(loss))
return loss
class MultiscaleMSELoss(th.nn.Module):
def __init__(self, channels=3):
super(MultiscaleMSELoss, self).__init__()
self.blur = imops.GaussianBlur(1, channels=channels)
def forward(self, im, target):
bs, c, h, w = im.shape
num_levels = max(int(np.ceil(np.log2(h))) - 2, 1)
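# e.g. a 64x64 input yields max(ceil(log2(64)) - 2, 1) = 4 pyramid
# levels (64, 32, 16, 8), each contributing one MSE term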
losses = []
for lvl in range(num_levels):
loss = th.nn.functional.mse_loss(im, target)
losses.append(loss)
im = th.nn.functional.interpolate(self.blur(im),
scale_factor=0.5,
mode="nearest")
target = th.nn.functional.interpolate(self.blur(target),
scale_factor=0.5,
mode="nearest")
losses = th.stack(losses)
return losses.sum()
def gaussian_pdfs(dx, dy, params):
"""Returns the pdf at (dx, dy) for each Gaussian in the mixture.
"""
dx = dx.unsqueeze(-1) # replicate dx, dy to evaluate all pdfs at once
dy = dy.unsqueeze(-1)
mu_x = params[..., 0]
mu_y = params[..., 1]
sigma_x = params[..., 2].exp()
sigma_y = params[..., 3].exp()
rho_xy = th.tanh(params[..., 4])
x = ((dx-mu_x) / sigma_x).pow(2)
y = ((dy-mu_y) / sigma_y).pow(2)
xy = (dx-mu_x)*(dy-mu_y) / (sigma_x * sigma_y)
arg = x + y - 2.0*rho_xy*xy
pdf = th.exp(-arg / (2*(1.0 - rho_xy.pow(2))))
norm = 2.0 * np.pi * sigma_x * sigma_y * (1.0 - rho_xy.pow(2)).sqrt()
return pdf / norm
class GaussianMixtureReconstructionLoss(th.nn.Module):
"""
Args:
"""
def __init__(self, eps=1e-5):
super(GaussianMixtureReconstructionLoss, self).__init__()
self.eps = eps
def forward(self, pen_logits, mixture_logits, gaussian_params, targets):
dx = targets[..., 0]
dy = targets[..., 1]
pen_state = targets[..., 2:].argmax(-1) # target index
# Likelihood loss on the stroke position
# No need to predict accurate pen position for end-of-sequence tokens
valid_stroke = (targets[..., -1] != 1.0).float()
mixture_weights = th.nn.functional.softmax(mixture_logits, -1)
pdfs = gaussian_pdfs(dx, dy, gaussian_params)
position_loss = - th.log(self.eps + (pdfs * mixture_weights).sum(-1))
# average only over the actual non-empty steps
position_loss = (position_loss*valid_stroke).sum() / valid_stroke.sum()
# Classification loss for the stroke mode
pen_loss = th.nn.functional.cross_entropy(pen_logits.view(-1, 3),
pen_state.view(-1))
return position_loss + pen_loss

File diff suppressed because it is too large


@@ -0,0 +1,484 @@
"""Collection of generative models."""
import torch as th
import ttools
import rendering
import modules
LOG = ttools.get_logger(__name__)
class BaseModel(th.nn.Module):
def sample_z(self, bs, device="cpu"):
return th.randn(bs, self.zdim).to(device)
class BaseVectorModel(BaseModel):
def get_vector(self, z):
_, scenes = self._forward(z)
return scenes
def _forward(self, x):
raise NotImplementedError()
def forward(self, z):
# Only return the raster
return self._forward(z)[0]
class BezierVectorGenerator(BaseVectorModel):
NUM_SEGMENTS = 2
def __init__(self, num_strokes=4,
zdim=128, width=32, imsize=32,
color_output=False,
stroke_width=None):
super(BezierVectorGenerator, self).__init__()
if stroke_width is None:
self.stroke_width = (0.5, 3.0)
LOG.warning("Setting default stroke with %s", self.stroke_width)
else:
self.stroke_width = stroke_width
self.imsize = imsize
self.num_strokes = num_strokes
self.zdim = zdim
self.trunk = th.nn.Sequential(
th.nn.Linear(zdim, width),
th.nn.SELU(inplace=True),
th.nn.Linear(width, 2*width),
th.nn.SELU(inplace=True),
th.nn.Linear(2*width, 4*width),
th.nn.SELU(inplace=True),
th.nn.Linear(4*width, 8*width),
th.nn.SELU(inplace=True),
)
# Cubic Bezier chains: n_segments segments need 3*n_segments + 1 control
# points (e.g. NUM_SEGMENTS=2 -> 7 points per stroke)
self.point_predictor = th.nn.Sequential(
th.nn.Linear(8*width,
2*self.num_strokes*(
BezierVectorGenerator.NUM_SEGMENTS*3 + 1)),
th.nn.Tanh() # bound spatial extent
)
self.width_predictor = th.nn.Sequential(
th.nn.Linear(8*width, self.num_strokes),
th.nn.Sigmoid()
)
self.alpha_predictor = th.nn.Sequential(
th.nn.Linear(8*width, self.num_strokes),
th.nn.Sigmoid()
)
self.color_predictor = None
if color_output:
self.color_predictor = th.nn.Sequential(
th.nn.Linear(8*width, 3*self.num_strokes),
th.nn.Sigmoid()
)
def _forward(self, z):
bs = z.shape[0]
feats = self.trunk(z)
all_points = self.point_predictor(feats)
all_alphas = self.alpha_predictor(feats)
if self.color_predictor:
all_colors = self.color_predictor(feats)
all_colors = all_colors.view(bs, self.num_strokes, 3)
else:
all_colors = None
all_widths = self.width_predictor(feats)
min_width = self.stroke_width[0]
max_width = self.stroke_width[1]
all_widths = (max_width - min_width) * all_widths + min_width
all_points = all_points.view(
bs, self.num_strokes, BezierVectorGenerator.NUM_SEGMENTS*3+1, 2)
output, scenes = rendering.bezier_render(all_points, all_widths, all_alphas,
colors=all_colors,
canvas_size=self.imsize)
# map to [-1, 1]
output = output*2.0 - 1.0
return output, scenes
class VectorGenerator(BaseVectorModel):
def __init__(self, num_strokes=4,
zdim=128, width=32, imsize=32,
color_output=False,
stroke_width=None):
super(VectorGenerator, self).__init__()
if stroke_width is None:
self.stroke_width = (0.5, 3.0)
LOG.warning("Setting default stroke with %s", self.stroke_width)
else:
self.stroke_width = stroke_width
self.imsize = imsize
self.num_strokes = num_strokes
self.zdim = zdim
self.trunk = th.nn.Sequential(
th.nn.Linear(zdim, width),
th.nn.SELU(inplace=True),
th.nn.Linear(width, 2*width),
th.nn.SELU(inplace=True),
th.nn.Linear(2*width, 4*width),
th.nn.SELU(inplace=True),
th.nn.Linear(4*width, 8*width),
th.nn.SELU(inplace=True),
)
# straight line segments: each stroke is defined by its 2 endpoints
self.point_predictor = th.nn.Sequential(
th.nn.Linear(8*width, 2*(self.num_strokes*2)),
th.nn.Tanh() # bound spatial extent
)
self.width_predictor = th.nn.Sequential(
th.nn.Linear(8*width, self.num_strokes),
th.nn.Sigmoid()
)
self.alpha_predictor = th.nn.Sequential(
th.nn.Linear(8*width, self.num_strokes),
th.nn.Sigmoid()
)
self.color_predictor = None
if color_output:
self.color_predictor = th.nn.Sequential(
th.nn.Linear(8*width, 3*self.num_strokes),
th.nn.Sigmoid()
)
def _forward(self, z):
bs = z.shape[0]
feats = self.trunk(z)
all_points = self.point_predictor(feats)
all_alphas = self.alpha_predictor(feats)
if self.color_predictor:
all_colors = self.color_predictor(feats)
all_colors = all_colors.view(bs, self.num_strokes, 3)
else:
all_colors = None
all_widths = self.width_predictor(feats)
min_width = self.stroke_width[0]
max_width = self.stroke_width[1]
all_widths = (max_width - min_width) * all_widths + min_width
all_points = all_points.view(bs, self.num_strokes, 2, 2)
output, scenes = rendering.line_render(all_points, all_widths, all_alphas,
colors=all_colors,
canvas_size=self.imsize)
# map to [-1, 1]
output = output*2.0 - 1.0
return output, scenes
class RNNVectorGenerator(BaseVectorModel):
def __init__(self, num_strokes=64,
zdim=128, width=32, imsize=32,
hidden_size=512, dropout=0.9,
color_output=False,
num_layers=3, stroke_width=None):
super(RNNVectorGenerator, self).__init__()
if stroke_width is None:
self.stroke_width = (0.5, 3.0)
LOG.warning("Setting default stroke with %s", self.stroke_width)
else:
self.stroke_width = stroke_width
self.num_layers = num_layers
self.imsize = imsize
self.num_strokes = num_strokes
self.hidden_size = hidden_size
self.zdim = zdim
self.hidden_cell_predictor = th.nn.Linear(
zdim, 2*hidden_size*num_layers)
self.lstm = th.nn.LSTM(
zdim, hidden_size,
num_layers=self.num_layers, dropout=dropout,
batch_first=True)
# straight line segments: one line (2 endpoints) per time step
self.point_predictor = th.nn.Sequential(
th.nn.Linear(hidden_size, 2*2), # 2 points, (x,y)
th.nn.Tanh() # bound spatial extent
)
self.width_predictor = th.nn.Sequential(
th.nn.Linear(hidden_size, 1),
th.nn.Sigmoid()
)
self.alpha_predictor = th.nn.Sequential(
th.nn.Linear(hidden_size, 1),
th.nn.Sigmoid()
)
def _forward(self, z, hidden_and_cell=None):
steps = self.num_strokes
# z is passed at each step, duplicate it
bs = z.shape[0]
expanded_z = z.unsqueeze(1).repeat(1, steps, 1)
# First step in the RNN
if hidden_and_cell is None:
# Initialize from latent vector
hidden_and_cell = self.hidden_cell_predictor(th.tanh(z))
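# The predicted vector packs [hidden | cell]; each half is reshaped
# below to (num_layers, bs, hidden_size), the layout th.nn.LSTM expects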
hidden = hidden_and_cell[:, :self.hidden_size*self.num_layers]
hidden = hidden.view(-1, self.num_layers, self.hidden_size)
hidden = hidden.permute(1, 0, 2).contiguous()
cell = hidden_and_cell[:, self.hidden_size*self.num_layers:]
cell = cell.view(-1, self.num_layers, self.hidden_size)
cell = cell.permute(1, 0, 2).contiguous()
hidden_and_cell = (hidden, cell)
feats, hidden_and_cell = self.lstm(expanded_z, hidden_and_cell)
hidden, cell = hidden_and_cell
feats = feats.reshape(bs*steps, self.hidden_size)
all_points = self.point_predictor(feats).view(bs, steps, 2, 2)
all_alphas = self.alpha_predictor(feats).view(bs, steps)
all_widths = self.width_predictor(feats).view(bs, steps)
min_width = self.stroke_width[0]
max_width = self.stroke_width[1]
all_widths = (max_width - min_width) * all_widths + min_width
output, scenes = rendering.line_render(all_points, all_widths, all_alphas,
canvas_size=self.imsize)
# map to [-1, 1]
output = output*2.0 - 1.0
return output, scenes
class ChainRNNVectorGenerator(BaseVectorModel):
"""Strokes form a single long chain."""
def __init__(self, num_strokes=64,
zdim=128, width=32, imsize=32,
hidden_size=512, dropout=0.9,
color_output=False,
num_layers=3, stroke_width=None):
super(ChainRNNVectorGenerator, self).__init__()
if stroke_width is None:
self.stroke_width = (0.5, 3.0)
LOG.warning("Setting default stroke with %s", self.stroke_width)
else:
self.stroke_width = stroke_width
self.num_layers = num_layers
self.imsize = imsize
self.num_strokes = num_strokes
self.hidden_size = hidden_size
self.zdim = zdim
self.hidden_cell_predictor = th.nn.Linear(
zdim, 2*hidden_size*num_layers)
self.lstm = th.nn.LSTM(
zdim, hidden_size,
num_layers=self.num_layers, dropout=dropout,
batch_first=True)
# chained straight lines: predict one new endpoint per step; each
# stroke starts at the previous stroke's endpoint
self.point_predictor = th.nn.Sequential(
th.nn.Linear(hidden_size, 2), # 1 point, (x,y)
th.nn.Tanh() # bound spatial extent
)
self.width_predictor = th.nn.Sequential(
th.nn.Linear(hidden_size, 1),
th.nn.Sigmoid()
)
self.alpha_predictor = th.nn.Sequential(
th.nn.Linear(hidden_size, 1),
th.nn.Sigmoid()
)
def _forward(self, z, hidden_and_cell=None):
steps = self.num_strokes
# z is passed at each step, duplicate it
bs = z.shape[0]
expanded_z = z.unsqueeze(1).repeat(1, steps, 1)
# First step in the RNN
if hidden_and_cell is None:
# Initialize from latent vector
hidden_and_cell = self.hidden_cell_predictor(th.tanh(z))
hidden = hidden_and_cell[:, :self.hidden_size*self.num_layers]
hidden = hidden.view(-1, self.num_layers, self.hidden_size)
hidden = hidden.permute(1, 0, 2).contiguous()
cell = hidden_and_cell[:, self.hidden_size*self.num_layers:]
cell = cell.view(-1, self.num_layers, self.hidden_size)
cell = cell.permute(1, 0, 2).contiguous()
hidden_and_cell = (hidden, cell)
feats, hidden_and_cell = self.lstm(expanded_z, hidden_and_cell)
hidden, cell = hidden_and_cell
feats = feats.reshape(bs*steps, self.hidden_size)
# Construct the chain
end_points = self.point_predictor(feats).view(bs, steps, 1, 2)
start_points = th.cat([
# the first stroke starts at the canvas center; stroke i then starts
# at the endpoint of stroke i-1 so the strokes form a single chain
th.zeros(bs, 1, 1, 2, device=feats.device),
end_points[:, :-1, :, :]], 1)
all_points = th.cat([start_points, end_points], 2)
all_alphas = self.alpha_predictor(feats).view(bs, steps)
all_widths = self.width_predictor(feats).view(bs, steps)
min_width = self.stroke_width[0]
max_width = self.stroke_width[1]
all_widths = (max_width - min_width) * all_widths + min_width
output, scenes = rendering.line_render(all_points, all_widths, all_alphas,
canvas_size=self.imsize)
# map to [-1, 1]
output = output*2.0 - 1.0
return output, scenes
class Generator(BaseModel):
def __init__(self, width=64, imsize=32, zdim=128,
stroke_width=None,
color_output=False,
num_strokes=4):
super(Generator, self).__init__()
assert imsize == 32
self.imsize = imsize
self.zdim = zdim
num_in_chans = self.zdim // (2*2)
num_out_chans = 3 if color_output else 1
self.net = th.nn.Sequential(
th.nn.ConvTranspose2d(num_in_chans, width*8, 4, padding=1,
stride=2),
th.nn.LeakyReLU(0.2, inplace=True),
th.nn.Conv2d(width*8, width*8, 3, padding=1),
th.nn.BatchNorm2d(width*8),
th.nn.LeakyReLU(0.2, inplace=True),
# 4x4
th.nn.ConvTranspose2d(8*width, 4*width, 4, padding=1, stride=2),
th.nn.LeakyReLU(0.2, inplace=True),
th.nn.Conv2d(4*width, 4*width, 3, padding=1),
th.nn.BatchNorm2d(width*4),
th.nn.LeakyReLU(0.2, inplace=True),
# 8x8
th.nn.ConvTranspose2d(4*width, 2*width, 4, padding=1, stride=2),
th.nn.LeakyReLU(0.2, inplace=True),
th.nn.Conv2d(2*width, 2*width, 3, padding=1),
th.nn.BatchNorm2d(width*2),
th.nn.LeakyReLU(0.2, inplace=True),
# 16x16
th.nn.ConvTranspose2d(2*width, width, 4, padding=1, stride=2),
th.nn.LeakyReLU(0.2, inplace=True),
th.nn.Conv2d(width, width, 3, padding=1),
th.nn.BatchNorm2d(width),
th.nn.LeakyReLU(0.2, inplace=True),
# 32x32
th.nn.Conv2d(width, width, 3, padding=1),
th.nn.BatchNorm2d(width),
th.nn.LeakyReLU(0.2, inplace=True),
th.nn.Conv2d(width, width, 3, padding=1),
th.nn.LeakyReLU(0.2, inplace=True),
th.nn.Conv2d(width, num_out_chans, 1),
th.nn.Tanh(),
)
def forward(self, z):
bs = z.shape[0]
num_in_chans = self.zdim // (2*2)
raster = self.net(z.view(bs, num_in_chans, 2, 2))
return raster
class Discriminator(th.nn.Module):
def __init__(self, conditional=False, width=64, color_output=False):
super(Discriminator, self).__init__()
self.conditional = conditional
sn = th.nn.utils.spectral_norm
num_chan_in = 3 if color_output else 1
self.net = th.nn.Sequential(
th.nn.Conv2d(num_chan_in, width, 3, padding=1),
th.nn.LeakyReLU(0.2, inplace=True),
th.nn.Conv2d(width, 2*width, 4, padding=1, stride=2),
th.nn.LeakyReLU(0.2, inplace=True),
# 16x16
sn(th.nn.Conv2d(2*width, 2*width, 3, padding=1)),
th.nn.LeakyReLU(0.2, inplace=True),
sn(th.nn.Conv2d(2*width, 4*width, 4, padding=1, stride=2)),
th.nn.LeakyReLU(0.2, inplace=True),
# 8x8
sn(th.nn.Conv2d(4*width, 4*width, 3, padding=1)),
th.nn.LeakyReLU(0.2, inplace=True),
sn(th.nn.Conv2d(4*width, width*4, 4, padding=1, stride=2)),
th.nn.LeakyReLU(0.2, inplace=True),
# 4x4
sn(th.nn.Conv2d(4*width, 4*width, 3, padding=1)),
th.nn.LeakyReLU(0.2, inplace=True),
sn(th.nn.Conv2d(4*width, width*4, 4, padding=1, stride=2)),
th.nn.LeakyReLU(0.2, inplace=True),
# 2x2
modules.Flatten(),
th.nn.Linear(width*4*2*2, 1),
)
def forward(self, x):
out = self.net(x)
return out


@@ -0,0 +1,11 @@
"""Helper modules to build our networks."""
import torch as th
class Flatten(th.nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, x):
bs = x.shape[0]
return x.view(bs, -1)
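# Note: recent PyTorch releases ship an equivalent th.nn.Flatten module;
# this small custom module keeps the code independent of that API.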


@@ -0,0 +1,307 @@
import os
import torch as th
import torch.multiprocessing as mp
import threading as mt
import numpy as np
import random
import ttools
import pydiffvg
import time
def render(canvas_width, canvas_height, shapes, shape_groups, samples=2,
seed=None):
if seed is None:
seed = random.randint(0, 1000000)
_render = pydiffvg.RenderFunction.apply
scene_args = pydiffvg.RenderFunction.serialize_scene(
canvas_width, canvas_height, shapes, shape_groups)
img = _render(canvas_width, canvas_height, samples, samples,
seed, # seed
None, # background image
*scene_args)
return img
def opacityStroke2diffvg(strokes, canvas_size=128, debug=False, relative=True,
force_cpu=True):
"""Rasterize strokes given in (dx, dy, opacity) sequence format.
Args:
strokes(th.Tensor): (bs, nsegs, 3) tensor; the first two channels are
relative (dx, dy) offsets when `relative` is True, absolute
coordinates otherwise.
"""
dev = strokes.device
if force_cpu:
strokes = strokes.to("cpu")
bs, nsegs, dims = strokes.shape
out = []
start = time.time()
for batch_idx, stroke in enumerate(strokes):
if relative: # integrate the (dx, dy) offsets into absolute coordinates
all_points = stroke[..., :2].cumsum(0)
else:
all_points = stroke[..., :2]
all_opacities = stroke[..., 2]
# Transform from [-1, 1] to canvas coordinates
# Make sure points are in canvas
all_points = 0.5*(all_points + 1.0) * canvas_size
# all_points = th.clamp(0.5*(all_points + 1.0), 0, 1) * canvas_size
# Avoid overlapping points
eps = 1e-4
all_points = all_points + eps*th.randn_like(all_points)
shapes = []
shape_groups = []
for start_idx in range(0, nsegs-1):
points = all_points[start_idx:start_idx+2].contiguous().float()
opacity = all_opacities[start_idx]
num_ctrl_pts = th.zeros(points.shape[0] - 1, dtype=th.int32)
width = th.ones(1)
path = pydiffvg.Path(
num_control_points=num_ctrl_pts, points=points,
stroke_width=width, is_closed=False)
shapes.append(path)
color = th.cat([th.ones(3, device=opacity.device),
opacity.unsqueeze(0)], 0)
path_group = pydiffvg.ShapeGroup(
shape_ids=th.tensor([len(shapes) - 1]),
fill_color=None,
stroke_color=color)
shape_groups.append(path_group)
# Rasterize only if there are shapes
if shapes:
inner_start = time.time()
out.append(render(canvas_size, canvas_size, shapes, shape_groups,
samples=4))
if debug:
inner_elapsed = time.time() - inner_start
print("diffvg call took %.2fms" % inner_elapsed)
else:
out.append(th.zeros(canvas_size, canvas_size, 4,
device=strokes.device))
if debug:
elapsed = (time.time() - start)*1000
print("rendering took %.2fms" % elapsed)
images = th.stack(out, 0).permute(0, 3, 1, 2).contiguous()
# Return data on the same device as input
return images.to(dev)
def stroke2diffvg(strokes, canvas_size=128):
"""Rasterize strokes given some sequential data."""
bs, nsegs, dims = strokes.shape
out = []
for stroke_idx, stroke in enumerate(strokes):
end_of_stroke = stroke[:, 4] == 1
last = end_of_stroke.cpu().numpy().argmax()
stroke = stroke[:last+1, :]
# TODO: stop at the first end-of-stroke token
split_idx = stroke[:, 3].nonzero().squeeze(1)
# Absolute coordinates
all_points = stroke[..., :2].cumsum(0)
# Transform to canvas coordinates
all_points[..., 0] += 0.5
all_points[..., 0] *= canvas_size
all_points[..., 1] += 0.5
all_points[..., 1] *= canvas_size
# Make sure points are in canvas
all_points[..., :2] = th.clamp(all_points[..., :2], 0, canvas_size)
shape_groups = []
shapes = []
start_idx = 0
for count, end_idx in enumerate(split_idx):
points = all_points[start_idx:end_idx+1].contiguous().float()
if points.shape[0] < 2: # we need at least 2 points for a line
continue
num_ctrl_pts = th.zeros(points.shape[0] - 1, dtype=th.int32)
width = th.ones(1)
path = pydiffvg.Path(
num_control_points=num_ctrl_pts, points=points,
stroke_width=width, is_closed=False)
start_idx = end_idx+1
shapes.append(path)
color = th.ones(4, 1)
path_group = pydiffvg.ShapeGroup(
shape_ids=th.tensor([len(shapes) - 1]),
fill_color=None,
stroke_color=color)
shape_groups.append(path_group)
# Rasterize
if shapes:
# draw only if there are shapes
out.append(render(canvas_size, canvas_size, shapes, shape_groups, samples=2))
else:
out.append(th.zeros(canvas_size, canvas_size, 4,
device=strokes.device))
return th.stack(out, 0).permute(0, 3, 1, 2)[:, :3].contiguous()
def line_render(all_points, all_widths, all_alphas, force_cpu=True,
canvas_size=32, colors=None):
dev = all_points.device
if force_cpu:
all_points = all_points.to("cpu")
all_widths = all_widths.to("cpu")
all_alphas = all_alphas.to("cpu")
if colors is not None:
colors = colors.to("cpu")
all_points = 0.5*(all_points + 1.0) * canvas_size
eps = 1e-4
all_points = all_points + eps*th.randn_like(all_points)
bs, num_segments, _, _ = all_points.shape
n_out = 3 if colors is not None else 1
output = th.zeros(bs, n_out, canvas_size, canvas_size,
device=all_points.device)
scenes = []
for k in range(bs):
shapes = []
shape_groups = []
for p in range(num_segments):
points = all_points[k, p].contiguous().cpu()
num_ctrl_pts = th.zeros(1, dtype=th.int32)
width = all_widths[k, p].cpu()
alpha = all_alphas[k, p].cpu()
if colors is not None:
color = colors[k, p]
else:
color = th.ones(3, device=alpha.device)
color = th.cat([color, alpha.view(1,)])
path = pydiffvg.Path(
num_control_points=num_ctrl_pts, points=points,
stroke_width=width, is_closed=False)
shapes.append(path)
path_group = pydiffvg.ShapeGroup(
shape_ids=th.tensor([len(shapes) - 1]),
fill_color=None,
stroke_color=color)
shape_groups.append(path_group)
# Rasterize
scenes.append((canvas_size, canvas_size, shapes, shape_groups))
raster = render(canvas_size, canvas_size, shapes, shape_groups,
samples=2)
raster = raster.permute(2, 0, 1).view(4, canvas_size, canvas_size)
alpha = raster[3:4]
if colors is not None: # color output
image = raster[:3]
alpha = alpha.repeat(3, 1, 1)
else:
image = raster[:1]
# alpha compositing
image = image*alpha
output[k] = image
output = output.to(dev)
return output, scenes
def bezier_render(all_points, all_widths, all_alphas, force_cpu=True,
canvas_size=32, colors=None):
dev = all_points.device
if force_cpu:
all_points = all_points.to("cpu")
all_widths = all_widths.to("cpu")
all_alphas = all_alphas.to("cpu")
if colors is not None:
colors = colors.to("cpu")
all_points = 0.5*(all_points + 1.0) * canvas_size
eps = 1e-4
all_points = all_points + eps*th.randn_like(all_points)
bs, num_strokes, num_pts, _ = all_points.shape
num_segments = (num_pts - 1) // 3
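# e.g. 7 points -> 2 cubic segments; each Path below declares 2 control
# points per segment (num_control_points = 2), which is how pydiffvg
# encodes cubic Bezier segments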
n_out = 3 if colors is not None else 1
output = th.zeros(bs, n_out, canvas_size, canvas_size,
device=all_points.device)
scenes = []
for k in range(bs):
shapes = []
shape_groups = []
for p in range(num_strokes):
points = all_points[k, p].contiguous().cpu()
# bezier
num_ctrl_pts = th.zeros(num_segments, dtype=th.int32) + 2
width = all_widths[k, p].cpu()
alpha = all_alphas[k, p].cpu()
if colors is not None:
color = colors[k, p]
else:
color = th.ones(3, device=alpha.device)
color = th.cat([color, alpha.view(1,)])
path = pydiffvg.Path(
num_control_points=num_ctrl_pts, points=points,
stroke_width=width, is_closed=False)
shapes.append(path)
path_group = pydiffvg.ShapeGroup(
shape_ids=th.tensor([len(shapes) - 1]),
fill_color=None,
stroke_color=color)
shape_groups.append(path_group)
# Rasterize
scenes.append((canvas_size, canvas_size, shapes, shape_groups))
raster = render(canvas_size, canvas_size, shapes, shape_groups,
samples=2)
raster = raster.permute(2, 0, 1).view(4, canvas_size, canvas_size)
alpha = raster[3:4]
if colors is not None: # color output
image = raster[:3]
alpha = alpha.repeat(3, 1, 1)
else:
image = raster[:1]
# alpha compositing
image = image*alpha
output[k] = image
output = output.to(dev)
return output, scenes


@@ -0,0 +1,461 @@
#!/usr/bin/env python
"""Train a Sketch-RNN."""
import argparse
from enum import Enum
import os
import wget
import numpy as np
import torch as th
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
import ttools
import ttools.interfaces
from ttools.modules import networks
import pydiffvg
import rendering
import losses
import data
LOG = ttools.get_logger(__name__)
BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
OUTPUT = os.path.join(BASE_DIR, "results", "sketch_rnn_diffvg")
OUTPUT_BASELINE = os.path.join(BASE_DIR, "results", "sketch_rnn")
class SketchRNN(th.nn.Module):
class Encoder(th.nn.Module):
def __init__(self, hidden_size=512, dropout=0.9, zdim=128,
num_layers=1):
super(SketchRNN.Encoder, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.zdim = zdim
self.lstm = th.nn.LSTM(5, hidden_size, num_layers=self.num_layers,
dropout=dropout, bidirectional=True,
batch_first=True)
# bidirectional model -> *2
self.mu_predictor = th.nn.Linear(2*hidden_size, zdim)
self.sigma_predictor = th.nn.Linear(2*hidden_size, zdim)
def forward(self, sequences, hidden_and_cell=None):
bs = sequences.shape[0]
if hidden_and_cell is None:
hidden = th.zeros(self.num_layers*2, bs, self.hidden_size).to(
sequences.device)
cell = th.zeros(self.num_layers*2, bs, self.hidden_size).to(
sequences.device)
hidden_and_cell = (hidden, cell)
out, hidden_and_cell = self.lstm(sequences, hidden_and_cell)
hidden = hidden_and_cell[0]
# Concat the forward/backward states
fc_input = th.cat([hidden[0], hidden[1]], 1)
# VAE params
mu = self.mu_predictor(fc_input)
log_sigma = self.sigma_predictor(fc_input)
# Sample a latent vector with the reparameterization trick:
# z = mu + sigma*eps, with eps ~ N(0, I), one noise sample per element
sigma = th.exp(log_sigma/2.0)
z0 = th.randn_like(mu)
z = mu + sigma*z0
# KL divergence needs mu/sigma
return z, mu, log_sigma
class Decoder(th.nn.Module):
"""
The decoder outputs a sequence where each time step models (dx, dy) as
a mixture of `num_gaussians` 2D Gaussians and the state triplet is a
categorical distribution.
The model outputs at each time step:
- 5 parameters for each Gaussian: mu_x, mu_y, sigma_x, sigma_y,
rho_xy
- 1 logit for each Gaussian (the mixture weight)
- 3 logits for the state triplet probabilities
"""
def __init__(self, hidden_size=512, dropout=0.9, zdim=128,
num_layers=1, num_gaussians=20):
super(SketchRNN.Decoder, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.zdim = zdim
self.num_gaussians = num_gaussians
# Maps the latent vector to an initial cell/hidden vector
self.hidden_cell_predictor = th.nn.Linear(zdim, 2*hidden_size)
self.lstm = th.nn.LSTM(
5 + zdim, hidden_size,
num_layers=self.num_layers, dropout=dropout,
batch_first=True)
self.parameters_predictor = th.nn.Linear(
hidden_size, num_gaussians + 5*num_gaussians + 3)
def forward(self, inputs, z, hidden_and_cell=None):
# Every step in the sequence takes the latent vector as input so we
# replicate it here
expanded_z = z.unsqueeze(1).repeat(1, inputs.shape[1], 1)
inputs = th.cat([inputs, expanded_z], 2)
bs, steps = inputs.shape[:2]
if hidden_and_cell is None:
# Initialize from latent vector
hidden_and_cell = self.hidden_cell_predictor(th.tanh(z))
hidden = hidden_and_cell[:, :self.hidden_size]
hidden = hidden.unsqueeze(0).contiguous()
cell = hidden_and_cell[:, self.hidden_size:]
cell = cell.unsqueeze(0).contiguous()
hidden_and_cell = (hidden, cell)
outputs, hidden_and_cell = self.lstm(inputs, hidden_and_cell)
hidden, cell = hidden_and_cell
# At train time we want parameters for each time step
outputs = outputs.reshape(bs*steps, self.hidden_size)
params = self.parameters_predictor(outputs).view(bs, steps, -1)
pen_logits = params[..., -3:]
gaussian_params = params[..., :-3]
mixture_logits = gaussian_params[..., :self.num_gaussians]
gaussian_params = gaussian_params[..., self.num_gaussians:].view(
bs, steps, self.num_gaussians, -1)
return pen_logits, mixture_logits, gaussian_params, hidden_and_cell
def __init__(self, zdim=128, num_gaussians=20, encoder_dim=256,
decoder_dim=512):
super(SketchRNN, self).__init__()
self.encoder = SketchRNN.Encoder(zdim=zdim, hidden_size=encoder_dim)
self.decoder = SketchRNN.Decoder(zdim=zdim, hidden_size=decoder_dim,
num_gaussians=num_gaussians)
def forward(self, sequences):
# Encode the sequences as latent vectors
# We skip the first time step since it is the same for all sequences:
# (0, 0, 1, 0, 0)
z, mu, log_sigma = self.encoder(sequences[:, 1:])
# Decode the latent vector into a model sequence
# Do not process the last time step (it is an end-of-sequence token)
pen_logits, mixture_logits, gaussian_params, hidden_and_cell = \
self.decoder(sequences[:, :-1], z)
return {
"pen_logits": pen_logits,
"mixture_logits": mixture_logits,
"gaussian_params": gaussian_params,
"z": z,
"mu": mu,
"log_sigma": log_sigma,
"hidden_and_cell": hidden_and_cell,
}
def sample(self, sequences, temperature=1.0):
# Compute a latent vector conditioned on a real sequence
z, _, _ = self.encoder(sequences[:, 1:])
start_of_seq = sequences[:, :1]
max_steps = sequences.shape[1] - 1 # last step is an end-of-seq token
output_sequences = th.zeros_like(sequences)
output_sequences[:, 0] = start_of_seq.squeeze(1)
current_input = start_of_seq
hidden_and_cell = None
for step in range(max_steps):
pen_logits, mixture_logits, gaussian_params, hidden_and_cell = \
self.decoder(current_input, z, hidden_and_cell=hidden_and_cell)
# Pen and displacement state for the next step
next_state = th.zeros_like(current_input)
# Adjust temperature to control randomness
mixture_logits = mixture_logits*temperature
pen_logits = pen_logits*temperature
# Select one of 3 pen states
pen_distrib = \
th.distributions.categorical.Categorical(logits=pen_logits)
pen_state = pen_distrib.sample()
# One-hot encoding of the state
next_state[:, :, 2:].scatter_(2, pen_state.unsqueeze(-1),
th.ones_like(next_state[:, :, 2:]))
# Select one of the Gaussians from the mixture
mixture_distrib = \
th.distributions.categorical.Categorical(logits=mixture_logits)
mixture_idx = mixture_distrib.sample()
# select the Gaussian parameter
mixture_idx = mixture_idx.unsqueeze(-1).unsqueeze(-1)
mixture_idx = mixture_idx.repeat(1, 1, 1, 5)
params = th.gather(gaussian_params, 2, mixture_idx).squeeze(2)
# Sample a point from the selected Gaussian
mu = params[..., :2]
sigma_x = params[..., 2].exp()
sigma_y = params[..., 3].exp()
rho_xy = th.tanh(params[..., 4])
# scale_tril expects a lower-triangular Cholesky factor of the
# covariance; build it from sigma_x, sigma_y and rho_xy, scaled by the
# sampling temperature
scale_tril = th.zeros(params.shape[0], params.shape[1], 2, 2,
device=params.device)
temp_scale = np.sqrt(temperature)
scale_tril[..., 0, 0] = sigma_x * temp_scale
scale_tril[..., 1, 0] = sigma_y * rho_xy * temp_scale
scale_tril[..., 1, 1] = sigma_y * (1.0 - rho_xy.pow(2)).sqrt() * temp_scale
point_distrib = \
th.distributions.multivariate_normal.MultivariateNormal(
mu, scale_tril=scale_tril)
point = point_distrib.sample()
next_state[:, :, :2] = point
# Commit step to output
output_sequences[:, step + 1] = next_state.squeeze(1)
# Prepare next recurrent step
current_input = next_state
return output_sequences
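# Illustrative usage (assumes a trained model and a stroke-5 batch
# `seqs` of shape (bs, nmax+2, 5)):
# samples = model.sample(seqs, temperature=0.4)
# images = rendering.stroke2diffvg(samples, canvas_size=128)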
class SketchRNNCallback(ttools.callbacks.ImageDisplayCallback):
"""Simple callback that visualize images."""
def visualized_image(self, batch, step_data, is_val=False):
if not is_val:
# No need to render training data
return None
with th.no_grad():
# only display the first n drawings
n = 8
batch = batch[:n]
out_im = rendering.stroke2diffvg(step_data["sample"][:n])
im = rendering.stroke2diffvg(batch)
im = th.cat([im, out_im], 2)
return im
def caption(self, batch, step_data, is_val=False):
return "top: truth, bottom: sample"
class Interface(ttools.ModelInterface):
def __init__(self, model, lr=1e-3, lr_decay=0.9999,
kl_weight=0.5, kl_min_weight=0.01, kl_decay=0.99995,
device="cpu", grad_clip=1.0, sampling_temperature=0.4):
super(Interface, self).__init__()
self.grad_clip = grad_clip
self.sampling_temperature = sampling_temperature
self.model = model
self.device = device
self.model.to(self.device)
self.enc_opt = th.optim.Adam(self.model.encoder.parameters(), lr=lr)
self.dec_opt = th.optim.Adam(self.model.decoder.parameters(), lr=lr)
self.kl_weight = kl_weight
self.kl_min_weight = kl_min_weight
self.kl_decay = kl_decay
self.kl_loss = losses.KLDivergence()
self.schedulers = [
th.optim.lr_scheduler.ExponentialLR(self.enc_opt, lr_decay),
th.optim.lr_scheduler.ExponentialLR(self.dec_opt, lr_decay),
]
self.reconstruction_loss = losses.GaussianMixtureReconstructionLoss()
def optimizers(self):
return [self.enc_opt, self.dec_opt]
def training_step(self, batch):
batch = batch.to(self.device)
out = self.model(batch)
kl_loss = self.kl_loss(
out["mu"], out["log_sigma"])
# The target to predict is the next sequence step
targets = batch[:, 1:].to(self.device)
# Scale the KL divergence weight
try:
state = self.enc_opt.state_dict()["param_groups"][0]["params"][0]
optim_step = self.enc_opt.state_dict()["state"][state]["step"]
except KeyError:
optim_step = 0 # no step taken yet
kl_scaling = 1.0 - (1.0 -
self.kl_min_weight)*(self.kl_decay**optim_step)
kl_weight = self.kl_weight * kl_scaling
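# kl_scaling anneals from kl_min_weight at step 0 toward 1.0, e.g. with
# kl_decay=0.99995 the scaling is roughly 0.64 after 20k steps
# (illustrative numbers)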
reconstruction_loss = self.reconstruction_loss(
out["pen_logits"], out["mixture_logits"],
out["gaussian_params"], targets)
loss = kl_loss*kl_weight + reconstruction_loss
self.enc_opt.zero_grad()
self.dec_opt.zero_grad()
loss.backward()
# clip gradients
enc_nrm = th.nn.utils.clip_grad_norm_(
self.model.encoder.parameters(), self.grad_clip)
dec_nrm = th.nn.utils.clip_grad_norm_(
self.model.decoder.parameters(), self.grad_clip)
if enc_nrm > self.grad_clip:
LOG.debug("Clipped encoder gradient (%.5f) to %.2f",
enc_nrm, self.grad_clip)
if dec_nrm > self.grad_clip:
LOG.debug("Clipped decoder gradient (%.5f) to %.2f",
dec_nrm, self.grad_clip)
self.enc_opt.step()
self.dec_opt.step()
return {
"loss": loss.item(),
"kl_loss": kl_loss.item(),
"kl_weight": kl_weight,
"recons_loss": reconstruction_loss.item(),
"lr": self.enc_opt.param_groups[0]["lr"],
}
def init_validation(self):
return dict(sample=None)
def validation_step(self, batch, running_data):
# Switch to eval mode for dropout, batchnorm, etc
self.model.eval()
with th.no_grad():
sample = self.model.sample(
batch.to(self.device), temperature=self.sampling_temperature)
running_data["sample"] = sample
self.model.train()
return running_data
def train(args):
th.manual_seed(0)
np.random.seed(0)
dataset = data.QuickDrawDataset(args.dataset)
dataloader = DataLoader(
dataset, batch_size=args.bs, num_workers=4, shuffle=True,
pin_memory=False)
val_dataset = [s for idx, s in enumerate(dataset) if idx < 8]
val_dataloader = DataLoader(
val_dataset, batch_size=8, num_workers=4, shuffle=False,
pin_memory=False)
model_params = {
"zdim": args.zdim,
"num_gaussians": args.num_gaussians,
"encoder_dim": args.encoder_dim,
"decoder_dim": args.decoder_dim,
}
model = SketchRNN(**model_params)
model.train()
device = "cpu"
if th.cuda.is_available():
device = "cuda"
LOG.info("Using CUDA")
interface = Interface(model, lr=args.lr, lr_decay=args.lr_decay,
kl_decay=args.kl_decay, kl_weight=args.kl_weight,
sampling_temperature=args.sampling_temperature,
device=device)
chkpt = OUTPUT_BASELINE
env_name = "sketch_rnn"
# Resume from checkpoint, if any
checkpointer = ttools.Checkpointer(
chkpt, model, meta=model_params,
optimizers=interface.optimizers(),
schedulers=interface.schedulers)
extras, meta = checkpointer.load_latest()
epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0
if meta is not None and meta != model_params:
LOG.error("Checkpoint's metaparams differ "
"from CLI, aborting: %s and %s", meta, model_params)
return
trainer = ttools.Trainer(interface)
# Add callbacks
losses = ["loss", "kl_loss", "recons_loss"]
training_debug = ["lr", "kl_weight"]
trainer.add_callback(ttools.callbacks.ProgressBarCallback(
keys=losses, val_keys=None))
trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
keys=losses, val_keys=None, env=env_name, port=args.port))
trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
keys=training_debug, smoothing=0, val_keys=None, env=env_name,
port=args.port))
trainer.add_callback(ttools.callbacks.CheckpointingCallback(
checkpointer, max_files=2, interval=600, max_epochs=10))
trainer.add_callback(
ttools.callbacks.LRSchedulerCallback(interface.schedulers))
trainer.add_callback(SketchRNNCallback(
env=env_name, win="samples", port=args.port, frequency=args.freq))
# Start training
trainer.train(dataloader, starting_epoch=epoch,
val_dataloader=val_dataloader,
num_epochs=args.num_epochs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", default="cat.npz")
# Training params
parser.add_argument("--bs", type=int, default=100)
parser.add_argument("--num_epochs", type=int, default=10000)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--lr_decay", type=float, default=0.9999)
parser.add_argument("--kl_weight", type=float, default=0.5)
parser.add_argument("--kl_decay", type=float, default=0.99995)
# Model configuration
parser.add_argument("--zdim", type=int, default=128)
parser.add_argument("--num_gaussians", type=int, default=20)
parser.add_argument("--encoder_dim", type=int, default=256)
parser.add_argument("--decoder_dim", type=int, default=512)
parser.add_argument("--sampling_temperature", type=float, default=0.4,
help="controls sampling randomness. "
"0.0: deterministic, 1.0: unchanged")
# Viz params
parser.add_argument("--freq", type=int, default=100)
parser.add_argument("--port", type=int, default=5000)
args = parser.parse_args()
pydiffvg.set_use_gpu(th.cuda.is_available())
train(args)


@@ -0,0 +1,524 @@
#!/usr/bin/env python
"""Train a Sketch-VAE."""
import argparse
from enum import Enum
import os
import wget
import time
import numpy as np
import torch as th
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
import ttools
import ttools.interfaces
from ttools.modules import networks
import rendering
import losses
import modules
import data
import pydiffvg
LOG = ttools.get_logger(__name__)
BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
OUTPUT = os.path.join(BASE_DIR, "results")
class SketchVAE(th.nn.Module):
class ImageEncoder(th.nn.Module):
def __init__(self, image_size=64, width=64, zdim=128):
super(SketchVAE.ImageEncoder, self).__init__()
self.zdim = zdim
self.net = th.nn.Sequential(
# input: 4 x 64 x 64
th.nn.Conv2d(4, width, 5, padding=2),
th.nn.InstanceNorm2d(width),
th.nn.ReLU(inplace=True),
th.nn.Conv2d(width, width, 5, padding=2),
th.nn.InstanceNorm2d(width),
th.nn.ReLU(inplace=True),
th.nn.Conv2d(width, 2*width, 5, stride=1, padding=2),
th.nn.InstanceNorm2d(2*width),
th.nn.ReLU(inplace=True),
# 64x64
th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2),
th.nn.InstanceNorm2d(2*width),
th.nn.ReLU(inplace=True),
# 32x32
th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2),
th.nn.InstanceNorm2d(2*width),
th.nn.ReLU(inplace=True),
# 16x16
th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2),
th.nn.InstanceNorm2d(2*width),
th.nn.ReLU(inplace=True),
# 8x8
th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2),
th.nn.InstanceNorm2d(2*width),
th.nn.ReLU(inplace=True),
# 4x4
modules.Flatten(),
th.nn.Linear(4*4*2*width, 2*zdim)
)
def forward(self, images):
features = self.net(images)
# VAE params
mu = features[:, :self.zdim]
log_sigma = features[:, self.zdim:]
# Sample a latent vector with the reparameterization trick
sigma = th.exp(log_sigma/2.0)
z0 = th.randn_like(mu)
z = mu + sigma*z0
# KL divergence needs mu/sigma
return z, mu, log_sigma
class ImageDecoder(th.nn.Module):
""""""
def __init__(self, zdim=128, image_size=64, width=64):
super(SketchVAE.ImageDecoder, self).__init__()
self.zdim = zdim
self.width = width
self.embedding = th.nn.Linear(zdim, 4*4*2*width)
self.net = th.nn.Sequential(
# input: 2*width x 4 x 4
th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2),
th.nn.InstanceNorm2d(2*width),
th.nn.ReLU(inplace=True),
# 8x8
th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2),
th.nn.InstanceNorm2d(2*width),
th.nn.ReLU(inplace=True),
# 16x16
th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2),
th.nn.InstanceNorm2d(2*width),
th.nn.ReLU(inplace=True),
# 32x32
th.nn.Conv2d(2*width, 2*width, 5, padding=2, stride=1),
th.nn.InstanceNorm2d(2*width),
th.nn.ReLU(inplace=True),
th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2),
th.nn.InstanceNorm2d(2*width),
th.nn.ReLU(inplace=True),
# 64x64
th.nn.Conv2d(2*width, width, 5, padding=2, stride=1),
th.nn.InstanceNorm2d(width),
th.nn.ReLU(inplace=True),
th.nn.ConvTranspose2d(width, width, 5, padding=2, stride=1),
th.nn.InstanceNorm2d(width),
th.nn.ReLU(inplace=True),
th.nn.Conv2d(width, width, 5, padding=2, stride=1),
th.nn.InstanceNorm2d(width),
th.nn.ReLU(inplace=True),
th.nn.Conv2d(width, 4, 5, padding=2, stride=1),
)
def forward(self, z):
bs = z.shape[0]
im = self.embedding(z).view(bs, 2*self.width, 4, 4)
out = self.net(im)
return out
class SketchDecoder(th.nn.Module):
"""
The decoder outputs a sequence where each time step models (dx, dy,
opacity).
"""
def __init__(self, sequence_length, hidden_size=512, dropout=0.9,
zdim=128, num_layers=3):
super(SketchVAE.SketchDecoder, self).__init__()
self.sequence_length = sequence_length
self.hidden_size = hidden_size
self.num_layers = num_layers
self.zdim = zdim
# Maps the latent vector to an initial cell/hidden vector
self.hidden_cell_predictor = th.nn.Linear(zdim, 2*hidden_size*num_layers)
self.lstm = th.nn.LSTM(
zdim, hidden_size,
num_layers=self.num_layers, dropout=dropout,
batch_first=True)
self.dxdy_predictor = th.nn.Sequential(
th.nn.Linear(hidden_size, 2),
th.nn.Tanh(),
)
self.opacity_predictor = th.nn.Sequential(
th.nn.Linear(hidden_size, 1),
th.nn.Sigmoid(),
)
def forward(self, z, hidden_and_cell=None):
# Every step in the sequence takes the latent vector as input so we
# replicate it here
bs = z.shape[0]
steps = self.sequence_length - 1 # no need to predict the start of sequence
expanded_z = z.unsqueeze(1).repeat(1, steps, 1)
if hidden_and_cell is None:
# Initialize from latent vector
hidden_and_cell = self.hidden_cell_predictor(
th.tanh(z))
hidden = hidden_and_cell[:, :self.hidden_size*self.num_layers]
hidden = hidden.view(-1, self.num_layers, self.hidden_size)
hidden = hidden.permute(1, 0, 2).contiguous()
# hidden = hidden.unsqueeze(1).contiguous()
cell = hidden_and_cell[:, self.hidden_size*self.num_layers:]
cell = cell.view(-1, self.num_layers, self.hidden_size)
cell = cell.permute(1, 0, 2).contiguous()
# cell = cell.unsqueeze(1).contiguous()
hidden_and_cell = (hidden, cell)
outputs, hidden_and_cell = self.lstm(expanded_z, hidden_and_cell)
hidden, cell = hidden_and_cell
dxdy = self.dxdy_predictor(
outputs.reshape(bs*steps, self.hidden_size)).view(bs, steps, -1)
opacity = self.opacity_predictor(
outputs.reshape(bs*steps, self.hidden_size)).view(bs, steps, -1)
strokes = th.cat([dxdy, opacity], -1)
return strokes
def __init__(self, sequence_length, zdim=128, image_size=64):
super(SketchVAE, self).__init__()
self.im_encoder = SketchVAE.ImageEncoder(
zdim=zdim, image_size=image_size)
self.im_decoder = SketchVAE.ImageDecoder(
zdim=zdim, image_size=image_size)
self.sketch_decoder = SketchVAE.SketchDecoder(
sequence_length, zdim=zdim)
def forward(self, images):
# Encode the images as latent vectors
z, mu, log_sigma = self.im_encoder(images)
decoded_im = self.im_decoder(z)
decoded_sketch = self.sketch_decoder(z)
return {
"decoded_im": decoded_im,
"decoded_sketch": decoded_sketch,
"z": z,
"mu": mu,
"log_sigma": log_sigma,
}
class SketchVAECallback(ttools.callbacks.ImageDisplayCallback):
"""Simple callback that visualize images."""
def visualized_image(self, batch, step_data, is_val=False):
if is_val:
return None
# only display the first n drawings
n = 8
gt = step_data["gt_image"][:n].detach()
vae_im = step_data["vae_image"][:n].detach()
sketch_im = step_data["sketch_image"][:n].detach()
# use a local name that does not shadow the `rendering` module
viz = th.cat([gt, vae_im, sketch_im], 2)
viz = th.clamp(viz, 0, 1)
alpha = viz[:, 3:4]
viz = viz[:, :3] * alpha
return viz
def caption(self, batch, step_data, is_val=False):
if is_val:
return ""
else:
return "top: truth, middle: vae sample, output: rnn-output"
class Interface(ttools.ModelInterface):
def __init__(self, model, lr=1e-4, lr_decay=0.9999,
kl_weight=0.5, kl_min_weight=0.01, kl_decay=0.99995,
raster_resolution=64, absolute_coords=False,
device="cpu", grad_clip=1.0):
super(Interface, self).__init__()
self.grad_clip = grad_clip
self.raster_resolution = raster_resolution
self.absolute_coords = absolute_coords
self.model = model
self.device = device
self.model.to(self.device)
self.im_enc_opt = th.optim.Adam(
self.model.im_encoder.parameters(), lr=lr)
self.im_dec_opt = th.optim.Adam(
self.model.im_decoder.parameters(), lr=lr)
self.sketch_dec_opt = th.optim.Adam(
self.model.sketch_decoder.parameters(), lr=lr)
self.kl_weight = kl_weight
self.kl_min_weight = kl_min_weight
self.kl_decay = kl_decay
self.kl_loss = losses.KLDivergence()
self.schedulers = [
th.optim.lr_scheduler.ExponentialLR(self.im_enc_opt, lr_decay),
th.optim.lr_scheduler.ExponentialLR(self.im_dec_opt, lr_decay),
th.optim.lr_scheduler.ExponentialLR(self.sketch_dec_opt, lr_decay),
]
# include loss on alpha
self.im_loss = losses.MultiscaleMSELoss(channels=4).to(self.device)
def optimizers(self):
return [self.im_enc_opt, self.im_dec_opt, self.sketch_dec_opt]
def kl_scaling(self):
# Scale the KL divergence weight
try:
state = self.im_enc_opt.state_dict()["param_groups"][0]["params"][0]
optim_step = self.im_enc_opt.state_dict()["state"][state]["step"]
except KeyError:
optim_step = 0 # no step taken yet
kl_scaling = 1.0 - (1.0 -
self.kl_min_weight)*(self.kl_decay**optim_step)
return kl_scaling
def training_step(self, batch):
gt_strokes, gt_im = batch
gt_strokes = gt_strokes.to(self.device)
gt_im = gt_im.to(self.device)
out = self.model(gt_im)
kl_loss = self.kl_loss(
out["mu"], out["log_sigma"])
kl_weight = self.kl_weight * self.kl_scaling()
# add start of sequence
sos = gt_strokes[:, :1]
sketch = th.cat([sos, out["decoded_sketch"]], 1)
vae_im = out["decoded_im"]
# start = time.time()
sketch_im = rendering.opacityStroke2diffvg(
sketch, canvas_size=self.raster_resolution, debug=False,
force_cpu=True, relative=not self.absolute_coords)
# elapsed = (time.time() - start)*1000
# print("out rendering took %.2fms" % elapsed)
vae_im_loss = self.im_loss(vae_im, gt_im)
sketch_im_loss = self.im_loss(sketch_im, gt_im)
# vae_im_loss = th.nn.functional.mse_loss(vae_im, gt_im)
# sketch_im_loss = th.nn.functional.mse_loss(sketch_im, gt_im)
loss = vae_im_loss + kl_loss*kl_weight + sketch_im_loss
self.im_enc_opt.zero_grad()
self.im_dec_opt.zero_grad()
self.sketch_dec_opt.zero_grad()
loss.backward()
# clip gradients
enc_nrm = th.nn.utils.clip_grad_norm_(
self.model.im_encoder.parameters(), self.grad_clip)
dec_nrm = th.nn.utils.clip_grad_norm_(
self.model.im_decoder.parameters(), self.grad_clip)
sketch_dec_nrm = th.nn.utils.clip_grad_norm_(
self.model.sketch_decoder.parameters(), self.grad_clip)
if enc_nrm > self.grad_clip:
LOG.debug("Clipped encoder gradient (%.5f) to %.2f",
enc_nrm, self.grad_clip)
if dec_nrm > self.grad_clip:
LOG.debug("Clipped decoder gradient (%.5f) to %.2f",
dec_nrm, self.grad_clip)
if sketch_dec_nrm > self.grad_clip:
LOG.debug("Clipped sketch decoder gradient (%.5f) to %.2f",
sketch_dec_nrm, self.grad_clip)
self.im_enc_opt.step()
self.im_dec_opt.step()
self.sketch_dec_opt.step()
return {
"vae_image": vae_im,
"sketch_image": sketch_im,
"gt_image": gt_im,
"loss": loss.item(),
"vae_im_loss": vae_im_loss.item(),
"sketch_im_loss": sketch_im_loss.item(),
"kl_loss": kl_loss.item(),
"kl_weight": kl_weight,
"lr": self.im_enc_opt.param_groups[0]["lr"],
}
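# Minimal smoke test for the step above (a sketch; the tensor shapes are
# assumptions matching how gt_strokes/gt_im are used here, not a documented
# API):
#   model = SketchVAE(zdim=128, sequence_length=50, image_size=64)
#   interface = Interface(model, raster_resolution=64)
#   strokes = th.zeros(2, 51, 5)   # (batch, seq + start token, 5-D strokes)
#   ims = th.zeros(2, 4, 64, 64)   # RGBA raster targets
#   out = interface.training_step((strokes, ims))
#   print(out["loss"], out["kl_loss"], out["kl_weight"])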
def init_validation(self):
return dict(sample=None)
def validation_step(self, batch, running_data):
# Switch to eval mode for dropout, batchnorm, etc
# self.model.eval()
# with th.no_grad():
# # sample = self.model.sample(
# # batch.to(self.device), temperature=self.sampling_temperature)
# # running_data["sample"] = sample
# self.model.train()
return running_data
def train(args):
th.manual_seed(0)
np.random.seed(0)
dataset = data.FixedLengthQuickDrawDataset(
args.dataset, max_seq_length=args.sequence_length,
canvas_size=args.raster_resolution)
dataloader = DataLoader(
dataset, batch_size=args.bs, num_workers=args.workers, shuffle=True)
# val_dataset = [s for idx, s in enumerate(dataset) if idx < 8]
# val_dataloader = DataLoader(
# val_dataset, batch_size=8, num_workers=4, shuffle=False)
val_dataloader = None
model_params = {
"zdim": args.zdim,
"sequence_length": args.sequence_length,
"image_size": args.raster_resolution,
# "encoder_dim": args.encoder_dim,
# "decoder_dim": args.decoder_dim,
}
model = SketchVAE(**model_params)
model.train()
LOG.info("Model parameters:\n%s", model_params)
device = "cpu"
if th.cuda.is_available():
device = "cuda"
LOG.info("Using CUDA")
interface = Interface(model, raster_resolution=args.raster_resolution,
lr=args.lr, lr_decay=args.lr_decay,
kl_decay=args.kl_decay, kl_weight=args.kl_weight,
absolute_coords=args.absolute_coordinates,
device=device)
env_name = "sketch_vae"
if args.custom_name is not None:
env_name += "_" + args.custom_name
if args.absolute_coordinates:
env_name += "_abs_coords"
chkpt = os.path.join(OUTPUT, env_name)
# Resume from checkpoint, if any
checkpointer = ttools.Checkpointer(
chkpt, model, meta=model_params,
optimizers=interface.optimizers(),
schedulers=interface.schedulers)
extras, meta = checkpointer.load_latest()
epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0
if meta is not None and meta != model_params:
LOG.error("Checkpoint's metaparams differ "
"from CLI, aborting: %s and %s", meta, model_params)
raise RuntimeError("Checkpoint metaparams do not match the CLI arguments")
trainer = ttools.Trainer(interface)
# Add callbacks
losses = ["loss", "kl_loss", "vae_im_loss", "sketch_im_loss"]
training_debug = ["lr", "kl_weight"]
trainer.add_callback(ttools.callbacks.ProgressBarCallback(
keys=losses, val_keys=None))
trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
keys=losses, val_keys=None, env=env_name, port=args.port))
trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
keys=training_debug, smoothing=0, val_keys=None, env=env_name,
port=args.port))
trainer.add_callback(ttools.callbacks.CheckpointingCallback(
checkpointer, max_files=2, interval=600, max_epochs=10))
trainer.add_callback(
ttools.callbacks.LRSchedulerCallback(interface.schedulers))
trainer.add_callback(SketchVAECallback(
env=env_name, win="samples", port=args.port, frequency=args.freq))
# Start training
trainer.train(dataloader, starting_epoch=epoch,
val_dataloader=val_dataloader,
num_epochs=args.num_epochs)
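# Example invocation (a sketch; assumes the sketchrnn cat dataset has been
# downloaded to the data/ directory, and that this file is run as a script):
#   python <this_script>.py --dataset cat.npz --bs 8 --raster_resolution 64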
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", default="cat.npz")
parser.add_argument("--absolute_coordinates", action="store_true",
default=False)
parser.add_argument("--custom_name")
# Training params
parser.add_argument("--bs", type=int, default=1)
parser.add_argument("--workers", type=int, default=0)
parser.add_argument("--num_epochs", type=int, default=10000)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--lr_decay", type=float, default=0.9999)
parser.add_argument("--kl_weight", type=float, default=0.5)
parser.add_argument("--kl_decay", type=float, default=0.99995)
# Model configuration
parser.add_argument("--zdim", type=int, default=128)
parser.add_argument("--sequence_length", type=int, default=50)
parser.add_argument("--raster_resolution", type=int, default=64)
# parser.add_argument("--encoder_dim", type=int, default=256)
# parser.add_argument("--decoder_dim", type=int, default=512)
# Viz params
parser.add_argument("--freq", type=int, default=10)
parser.add_argument("--port", type=int, default=5000)
args = parser.parse_args()
pydiffvg.set_use_gpu(False)
train(args)

View File

@@ -0,0 +1,489 @@
#!/usr/bin/env python
"""Train a GAN.
Usage:
* Train a MNIST model:
`python train_gan.py`
* Train a Quickdraw model:
`python train_gan.py --task quickdraw`
"""
import argparse
import os
import numpy as np
import torch as th
from torch.utils.data import DataLoader
import ttools
import ttools.interfaces
import losses
import data
import models
import pydiffvg
LOG = ttools.get_logger(__name__)
BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
OUTPUT = os.path.join(BASE_DIR, "results")
class Callback(ttools.callbacks.ImageDisplayCallback):
"""Simple callback that visualize images."""
def visualized_image(self, batch, step_data, is_val=False):
if is_val:
return
gen = step_data["gen_image"][:16].detach()
ref = step_data["gt_image"][:16].detach()
# tensor to visualize, concatenate images
vizdata = th.cat([ref, gen], 2)
vector = step_data["vector_image"]
if vector is not None:
vector = vector[:16].detach()
vizdata = th.cat([vizdata, vector], 2)
# map from [-1, 1] back to [0, 1] for display
vizdata = (vizdata + 1.0) * 0.5
viz = th.clamp(vizdata, 0, 1)
return viz
def caption(self, batch, step_data, is_val=False):
if step_data["vector_image"] is not None:
s = "top: real, middle: raster, bottom: vector"
else:
s = "top: real, bottom: fake"
return s
class Interface(ttools.ModelInterface):
def __init__(self, generator, vect_generator,
discriminator, vect_discriminator,
lr=1e-4, lr_decay=0.9999,
gradient_penalty=10,
wgan_gp=False,
raster_resolution=32, device="cpu", grad_clip=1.0):
super(Interface, self).__init__()
self.wgan_gp = wgan_gp
self.w_gradient_penalty = gradient_penalty
self.n_critic = 1
if self.wgan_gp:
self.n_critic = 5
self.grad_clip = grad_clip
self.raster_resolution = raster_resolution
self.gen = generator
self.vect_gen = vect_generator
self.discrim = discriminator
self.vect_discrim = vect_discriminator
self.device = device
self.gen.to(self.device)
self.discrim.to(self.device)
beta1 = 0.5
beta2 = 0.9
self.gen_opt = th.optim.Adam(
self.gen.parameters(), lr=lr, betas=(beta1, beta2))
self.discrim_opt = th.optim.Adam(
self.discrim.parameters(), lr=lr, betas=(beta1, beta2))
self.schedulers = [
th.optim.lr_scheduler.ExponentialLR(self.gen_opt, lr_decay),
th.optim.lr_scheduler.ExponentialLR(self.discrim_opt, lr_decay),
]
self.optimizers = [self.gen_opt, self.discrim_opt]
if self.vect_gen is not None:
assert self.vect_discrim is not None
self.vect_gen.to(self.device)
self.vect_discrim.to(self.device)
self.vect_gen_opt = th.optim.Adam(
self.vect_gen.parameters(), lr=lr, betas=(beta1, beta2))
self.vect_discrim_opt = th.optim.Adam(
self.vect_discrim.parameters(), lr=lr, betas=(beta1, beta2))
self.schedulers += [
th.optim.lr_scheduler.ExponentialLR(self.vect_gen_opt,
lr_decay),
th.optim.lr_scheduler.ExponentialLR(self.vect_discrim_opt,
lr_decay),
]
self.optimizers += [self.vect_gen_opt, self.vect_discrim_opt]
# include loss on alpha
self.im_loss = losses.MultiscaleMSELoss(channels=4).to(self.device)
self.iter = 0
self.cross_entropy = th.nn.BCEWithLogitsLoss()
self.mse = th.nn.MSELoss()
def _gradient_penalty(self, discrim, fake, real):
bs = real.size(0)
epsilon = th.rand(bs, 1, 1, 1, device=real.device)
epsilon = epsilon.expand_as(real)
interpolation = epsilon * real.data + (1 - epsilon) * fake.data
# th.autograd.Variable is deprecated; request gradients directly
interpolation.requires_grad_(True)
interpolation_logits = discrim(interpolation)
grad_outputs = th.ones(interpolation_logits.size(), device=real.device)
gradients = th.autograd.grad(outputs=interpolation_logits,
inputs=interpolation,
grad_outputs=grad_outputs,
create_graph=True, retain_graph=True)[0]
gradients = gradients.view(bs, -1)
gradients_norm = th.sqrt(th.sum(gradients ** 2, dim=1) + 1e-12)
# Zero-centered penalty, [Thanh-Tung 2019] https://openreview.net/pdf?id=ByxPYjC5KQ
return self.w_gradient_penalty * ((gradients_norm - 0) ** 2).mean()
# return self.w_gradient_penalty * ((gradients_norm - 1) ** 2).mean()
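# The penalty above is the zero-centered variant:
#   GP = lambda * E[ ||grad_xhat D(xhat)||_2 ** 2 ],
# where xhat = eps*x_real + (1-eps)*x_fake is sampled on segments between
# real and fake samples. The original WGAN-GP [Gulrajani et al. 2017]
# instead pulls the norm towards 1, i.e. lambda * E[(||grad|| - 1)**2],
# which is the commented-out line above.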
def _discriminator_step(self, discrim, opt, fake, real):
"""Try to classify fake as 0 and real as 1."""
opt.zero_grad()
# no backprop to gen
fake = fake.detach()
fake_pred = discrim(fake)
real_pred = discrim(real)
if self.wgan_gp:
gradient_penalty = self._gradient_penalty(discrim, fake, real)
loss_d = fake_pred.mean() - real_pred.mean() + gradient_penalty
gradient_penalty = gradient_penalty.item()
else:
fake_loss = self.cross_entropy(fake_pred, th.zeros_like(fake_pred))
real_loss = self.cross_entropy(real_pred, th.ones_like(real_pred))
# fake_loss = self.mse(fake_pred, th.zeros_like(fake_pred))
# real_loss = self.mse(real_pred, th.ones_like(real_pred))
loss_d = 0.5*(fake_loss + real_loss)
gradient_penalty = None
loss_d.backward()
nrm = th.nn.utils.clip_grad_norm_(
discrim.parameters(), self.grad_clip)
if nrm > self.grad_clip:
LOG.debug("Clipped discriminator gradient (%.5f) to %.2f",
nrm, self.grad_clip)
opt.step()
return loss_d.item(), gradient_penalty
def _generator_step(self, gen, discrim, opt, fake):
"""Try to classify fake as 1."""
opt.zero_grad()
fake_pred = discrim(fake)
if self.wgan_gp:
loss_g = -fake_pred.mean()
else:
loss_g = self.cross_entropy(fake_pred, th.ones_like(fake_pred))
# loss_g = self.mse(fake_pred, th.ones_like(fake_pred))
loss_g.backward()
# clip gradients
nrm = th.nn.utils.clip_grad_norm_(
gen.parameters(), self.grad_clip)
if nrm > self.grad_clip:
LOG.debug("Clipped generator gradient (%.5f) to %.2f",
nrm, self.grad_clip)
opt.step()
return loss_g.item()
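# Note: the BCE branch above is the standard non-saturating generator
# loss, -log D(G(z)); the WGAN branch maximizes the critic score directly.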
def training_step(self, batch):
im = batch
im = im.to(self.device)
z = self.gen.sample_z(im.shape[0], device=self.device)
generated = self.gen(z)
vect_generated = None
if self.vect_gen is not None:
vect_generated = self.vect_gen(z)
loss_g = None
loss_d = None
loss_g_vect = None
loss_d_vect = None
gp = None
gp_vect = None
if self.iter < self.n_critic: # Discriminator update
self.iter += 1
loss_d, gp = self._discriminator_step(
self.discrim, self.discrim_opt, generated, im)
if vect_generated is not None:
loss_d_vect, gp_vect = self._discriminator_step(
self.vect_discrim, self.vect_discrim_opt, vect_generated, im)
else: # Generator update
self.iter = 0
loss_g = self._generator_step(
self.gen, self.discrim, self.gen_opt, generated)
if vect_generated is not None:
loss_g_vect = self._generator_step(
self.vect_gen, self.vect_discrim, self.vect_gen_opt, vect_generated)
return {
"loss_g": loss_g,
"loss_d": loss_d,
"loss_g_vect": loss_g_vect,
"loss_d_vect": loss_d_vect,
"gp": gp,
"gp_vect": gp_vect,
"gt_image": im,
"gen_image": generated,
"vector_image": vect_generated,
"lr": self.gen_opt.param_groups[0]["lr"],
}
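# Update schedule implied by self.iter above: the discriminator takes
# n_critic steps for every generator step (5:1 with wgan_gp, 1:1
# otherwise). E.g. over six consecutive batches with wgan_gp the update
# sequence is D, D, D, D, D, G.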
def init_validation(self):
return dict(sample=None)
def validation_step(self, batch, running_data):
# No validation set is used; keep this a no-op. (This interface
# holds self.gen/self.discrim rather than a single self.model.)
return running_data
def train(args):
th.manual_seed(0)
np.random.seed(0)
color_output = False
if args.task == "mnist":
dataset = data.MNISTDataset(args.raster_resolution, train=True)
elif args.task == "quickdraw":
dataset = data.QuickDrawImageDataset(
args.raster_resolution, train=True)
else:
raise NotImplementedError()
dataloader = DataLoader(
dataset, batch_size=args.bs, num_workers=args.workers, shuffle=True)
val_dataloader = None
model_params = {
"zdim": args.zdim,
"num_strokes": args.num_strokes,
"imsize": args.raster_resolution,
"stroke_width": args.stroke_width,
"color_output": color_output,
}
gen = models.Generator(**model_params)
gen.train()
discrim = models.Discriminator(color_output=color_output)
discrim.train()
if args.raster_only:
vect_gen = None
vect_discrim = None
else:
if args.generator == "fc":
vect_gen = models.VectorGenerator(**model_params)
elif args.generator == "bezier_fc":
vect_gen = models.BezierVectorGenerator(**model_params)
elif args.generator in ["rnn"]:
vect_gen = models.RNNVectorGenerator(**model_params)
elif args.generator in ["chain_rnn"]:
vect_gen = models.ChainRNNVectorGenerator(**model_params)
else:
raise NotImplementedError()
vect_gen.train()
vect_discrim = models.Discriminator(color_output=color_output)
vect_discrim.train()
LOG.info("Model parameters:\n%s", model_params)
device = "cpu"
if th.cuda.is_available():
device = "cuda"
LOG.info("Using CUDA")
interface = Interface(gen, vect_gen, discrim, vect_discrim,
raster_resolution=args.raster_resolution, lr=args.lr,
wgan_gp=args.wgan_gp,
lr_decay=args.lr_decay, device=device)
env_name = args.task + "_gan"
if args.raster_only:
env_name += "_raster"
else:
env_name += "_vector"
env_name += "_" + args.generator
if args.wgan_gp:
env_name += "_wgan"
chkpt = os.path.join(OUTPUT, env_name)
meta = {
"model_params": model_params,
"task": args.task,
"generator": args.generator,
}
checkpointer = ttools.Checkpointer(
chkpt, gen, meta=meta,
optimizers=interface.optimizers,
schedulers=interface.schedulers,
prefix="g_")
checkpointer_d = ttools.Checkpointer(
chkpt, discrim,
prefix="d_")
# Resume from checkpoint, if any
extras, _ = checkpointer.load_latest()
checkpointer_d.load_latest()
if not args.raster_only:
checkpointer_vect = ttools.Checkpointer(
chkpt, vect_gen, meta=meta,
optimizers=interface.optimizers,
schedulers=interface.schedulers,
prefix="vect_g_")
checkpointer_d_vect = ttools.Checkpointer(
chkpt, vect_discrim,
prefix="vect_d_")
extras, _ = checkpointer_vect.load_latest()
checkpointer_d_vect.load_latest()
epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0
# if meta is not None and meta["model_parameters"] != model_params:
# LOG.info("Checkpoint's metaparams differ "
# "from CLI, aborting: %s and %s", meta, model_params)
trainer = ttools.Trainer(interface)
# Add callbacks
losses = ["loss_g", "loss_d", "loss_g_vect", "loss_d_vect", "gp",
"gp_vect"]
training_debug = ["lr"]
trainer.add_callback(Callback(
env=env_name, win="samples", port=args.port, frequency=args.freq))
trainer.add_callback(ttools.callbacks.ProgressBarCallback(
keys=losses, val_keys=None))
trainer.add_callback(ttools.callbacks.MultiPlotCallback(
keys=losses, val_keys=None, env=env_name, port=args.port,
server=args.server, base_url=args.base_url,
win="losses", frequency=args.freq))
trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
keys=training_debug, smoothing=0, val_keys=None, env=env_name,
server=args.server, base_url=args.base_url,
port=args.port))
trainer.add_callback(ttools.callbacks.CheckpointingCallback(
checkpointer, max_files=2, interval=600, max_epochs=10))
trainer.add_callback(ttools.callbacks.CheckpointingCallback(
checkpointer_d, max_files=2, interval=600, max_epochs=10))
if not args.raster_only:
trainer.add_callback(ttools.callbacks.CheckpointingCallback(
checkpointer_vect, max_files=2, interval=600, max_epochs=10))
trainer.add_callback(ttools.callbacks.CheckpointingCallback(
checkpointer_d_vect, max_files=2, interval=600, max_epochs=10))
trainer.add_callback(
ttools.callbacks.LRSchedulerCallback(interface.schedulers))
# Start training
trainer.train(dataloader, starting_epoch=epoch,
val_dataloader=val_dataloader,
num_epochs=args.num_epochs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--task",
default="mnist",
choices=["mnist", "quickdraw"])
parser.add_argument("--generator",
default="bezier_fc",
choices=["bezier_fc", "fc", "rnn", "chain_rnn"],
help="model to use as generator")
parser.add_argument("--raster_only", action="store_true", default=False,
help="if true only train the raster baseline")
parser.add_argument("--standard_gan", dest="wgan_gp", action="store_false",
default=True,
help="if true, use regular GAN instead of WGAN")
# Training params
parser.add_argument("--bs", type=int, default=4, help="batch size")
parser.add_argument("--workers", type=int, default=4,
help="number of dataloader threads")
parser.add_argument("--num_epochs", type=int, default=200,
help="number of epochs to train for")
parser.add_argument("--lr", type=float, default=1e-4,
help="learning rate")
parser.add_argument("--lr_decay", type=float, default=0.9999,
help="exponential learning rate decay rate")
# Model configuration
parser.add_argument("--zdim", type=int, default=32,
help="latent space dimension")
parser.add_argument("--stroke_width", type=float, nargs=2,
default=(0.5, 1.5),
help="min and max stroke width")
parser.add_argument("--num_strokes", type=int, default=16,
help="number of strokes to generate")
parser.add_argument("--raster_resolution", type=int, default=32,
help="raster canvas resolution on each side")
# Viz params
parser.add_argument("--freq", type=int, default=10,
help="visualization frequency")
parser.add_argument("--port", type=int, default=8097,
help="visdom port")
parser.add_argument("--server", default=None,
help="visdom server if not local.")
parser.add_argument("--base_url", default="", help="visdom entrypoint URL")
args = parser.parse_args()
pydiffvg.set_use_gpu(False)
ttools.set_logger(False)
train(args)
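# Further example invocations (sketches built from the flags above):
#   python train_gan.py --task quickdraw --generator rnn
#   python train_gan.py --standard_gan  # plain GAN loss instead of WGAN-GP
#   python train_gan.py --raster_only   # skip the vector-graphics generator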