initial commit
apps/generative_models/.gitignore (new file, +1)
@@ -0,0 +1 @@
.gdb_history

apps/generative_models/README.md (new file, +5)
@@ -0,0 +1,5 @@
# Usage

For the GAN models, see `train_gan.py`. Generate samples from a pretrained model with `eval_gan.py`.

For the VAE models, see `mnist_vae.py`.
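As a quick reference (not part of the committed files), the evaluation entry point documents its own invocation in `eval_gan.py` below; it expects the path to a model checkpoint folder produced by training:

    python eval_gan.py ../results/quickdraw_gan_vector_bezier_fc_wgan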

apps/generative_models/__init__.py (new file, empty)

apps/generative_models/data.py (new file, +229)
@@ -0,0 +1,229 @@
|
||||
import os
|
||||
import time
|
||||
import torch as th
|
||||
import numpy as np
|
||||
import torchvision.datasets as dset
|
||||
import torchvision.transforms as transforms
|
||||
import imageio
|
||||
|
||||
import ttools
|
||||
import rendering
|
||||
|
||||
BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
|
||||
DATA = os.path.join(BASE_DIR, "data")
|
||||
|
||||
LOG = ttools.get_logger(__name__)
|
||||
|
||||
|
||||
class QuickDrawImageDataset(th.utils.data.Dataset):
|
||||
BASE_DATA_URL = \
|
||||
"https://console.cloud.google.com/storage/browser/_details/quickdraw_dataset/full/numpy_bitmap/cat.npy"
|
||||
"""
|
||||
Args:
|
||||
imsize(int): side length, in pixels, of the square output images.
|
||||
"""
|
||||
def __init__(self, imsize, train=True):
|
||||
super(QuickDrawImageDataset, self).__init__()
|
||||
file = os.path.join(DATA, "cat.npy")
|
||||
|
||||
self.imsize = imsize
|
||||
|
||||
if not os.path.exists(file):
|
||||
msg = "Dataset file %s does not exist, please download"
|
||||
" it from %s" % (file, QuickDrawImageDataset.BASE_DATA_URL)
|
||||
LOG.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
self.data = np.load(file, allow_pickle=True, encoding="latin1")
|
||||
|
||||
def __len__(self):
|
||||
return self.data.shape[0]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
im = np.reshape(self.data[idx], (1, 1, 28, 28))
|
||||
im = th.from_numpy(im).float() / 255.0
|
||||
im = th.nn.functional.interpolate(im, size=(self.imsize, self.imsize))
|
||||
|
||||
# Bring it to [-1, 1]
|
||||
im = th.clamp(im, 0, 1)
|
||||
im -= 0.5
|
||||
im /= 0.5
|
||||
|
||||
return im.squeeze(0)
|
||||
|
||||
|
||||
class QuickDrawDataset(th.utils.data.Dataset):
|
||||
BASE_DATA_URL = \
|
||||
"https://storage.cloud.google.com/quickdraw_dataset/sketchrnn"
|
||||
|
||||
"""
|
||||
Args:
|
||||
spatial_limit(int): maximum spatial extent in pixels.
|
||||
"""
|
||||
def __init__(self, dataset, mode="train",
|
||||
max_seq_length=250,
|
||||
spatial_limit=1000):
|
||||
super(QuickDrawDataset, self).__init__()
|
||||
file = os.path.join(DATA, "sketchrnn_"+dataset)
|
||||
remote = os.path.join(QuickDrawDataset.BASE_DATA_URL, dataset)
|
||||
|
||||
self.max_seq_length = max_seq_length
|
||||
self.spatial_limit = spatial_limit
|
||||
|
||||
if mode not in ["train", "test", "valid"]:
|
||||
raise ValueError("Allowed data modes are 'train', 'test' and 'valid'.")
|
||||
|
||||
if not os.path.exists(file):
|
||||
msg = "Dataset file %s does not exist, please download"
|
||||
" it from %s" % (file, remote)
|
||||
LOG.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
data = np.load(file, allow_pickle=True, encoding="latin1")[mode]
|
||||
data = self.purify(data)
|
||||
data = self.normalize(data)
|
||||
|
||||
# Length of longest sequence in the dataset
|
||||
self.nmax = max([len(seq) for seq in data])
|
||||
self.sketches = data
|
||||
|
||||
def __repr__(self):
|
||||
return "Dataset with %d sequences of max length %d" % \
|
||||
(len(self.sketches), self.nmax)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sketches)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Return the idx-th stroke in 5-D format, padded to length (Nmax+2).
|
||||
|
||||
The first and last elements of the sequence are fixed to the "start-"
and "end-of-sequence" tokens.
|
||||
|
||||
dx, dy, + 3 numbers for one-hot encoding of state:
|
||||
1 0 0: pen touching paper till next point
|
||||
0 1 0: pen lifted from paper after current point
|
||||
0 0 1: drawing has ended, the remaining points (including the
current one) will not be drawn
|
||||
"""
|
||||
sample_data = self.sketches[idx]
|
||||
|
||||
# Allow two extra slots for start/end of sequence tokens
|
||||
sample = np.zeros((self.nmax+2, 5), dtype=np.float32)
|
||||
|
||||
n = sample_data.shape[0]
|
||||
|
||||
# normalize dx, dy
|
||||
deltas = sample_data[:, :2]
|
||||
# Absolute coordinates
|
||||
positions = deltas[..., :2].cumsum(0)
|
||||
maxi = np.abs(positions).max() + 1e-8
|
||||
deltas = deltas / (1.1 * maxi) # leave some margin on edges
|
||||
|
||||
# fill in dx, dy coordinates
|
||||
sample[1:n+1, :2] = deltas
|
||||
|
||||
# on paper indicator: 0 means touching paper in the 3d format, flip it
|
||||
sample[1:n+1, 2] = 1 - sample_data[:, 2]
|
||||
|
||||
# off-paper indicator, complement of previous flag
|
||||
sample[1:n+1, 3] = 1 - sample[1:n+1, 2]
|
||||
|
||||
# fill with end of sequence tokens for the remainder
|
||||
sample[n+1:, 4] = 1
|
||||
|
||||
# Start of sequence token
|
||||
sample[0] = [0, 0, 1, 0, 0]
|
||||
|
||||
return sample
|
||||
|
||||
def purify(self, strokes):
|
||||
"""removes to small or too long sequences + removes large gaps"""
|
||||
data = []
|
||||
for seq in strokes:
|
||||
if seq.shape[0] <= self.max_seq_length:
|
||||
# and seq.shape[0] > 10:
|
||||
|
||||
# Limit large spatial gaps
|
||||
seq = np.minimum(seq, self.spatial_limit)
|
||||
seq = np.maximum(seq, -self.spatial_limit)
|
||||
seq = np.array(seq, dtype=np.float32)
|
||||
data.append(seq)
|
||||
return data
|
||||
|
||||
def calculate_normalizing_scale_factor(self, strokes):
|
||||
"""Calculate the normalizing factor explained in appendix of
|
||||
sketch-rnn."""
|
||||
data = []
|
||||
for i, stroke_i in enumerate(strokes):
|
||||
for j, pt in enumerate(strokes[i]):
|
||||
data.append(pt[0])
|
||||
data.append(pt[1])
|
||||
data = np.array(data)
|
||||
return np.std(data)
|
||||
|
||||
def normalize(self, strokes):
|
||||
"""Normalize entire dataset (delta_x, delta_y) by the scaling
|
||||
factor."""
|
||||
data = []
|
||||
scale_factor = self.calculate_normalizing_scale_factor(strokes)
|
||||
for seq in strokes:
|
||||
seq[:, 0:2] /= scale_factor
|
||||
data.append(seq)
|
||||
return data
|
||||
|
||||
|
||||
class FixedLengthQuickDrawDataset(QuickDrawDataset):
|
||||
"""A variant of the QuickDraw dataset where the strokes are represented as
|
||||
a fixed-length sequence of triplets (dx, dy, opacity), where opacity is 0 or 1.
|
||||
"""
|
||||
def __init__(self, *args, canvas_size=64, **kwargs):
|
||||
super(FixedLengthQuickDrawDataset, self).__init__(*args, **kwargs)
|
||||
self.canvas_size = canvas_size
|
||||
|
||||
def __getitem__(self, idx):
|
||||
sample = super(FixedLengthQuickDrawDataset, self).__getitem__(idx)
|
||||
|
||||
# Build a stroke opacity channel from the pen-down state; dx, dy remain unchanged
|
||||
strokes = sample[:, :3]
|
||||
|
||||
im = np.zeros((1, 1))
|
||||
|
||||
# render image
|
||||
# start = time.time()
|
||||
im = rendering.opacityStroke2diffvg(
|
||||
th.from_numpy(strokes).unsqueeze(0), canvas_size=self.canvas_size,
|
||||
relative=True, debug=False)
|
||||
im = im.squeeze(0).numpy()
|
||||
# elapsed = (time.time() - start)*1000
|
||||
# print("item %d pipeline gt rendering took %.2fms" % (idx, elapsed))
|
||||
|
||||
return strokes, im
|
||||
|
||||
|
||||
class MNISTDataset(th.utils.data.Dataset):
|
||||
def __init__(self, imsize, train=True):
|
||||
super(MNISTDataset, self).__init__()
|
||||
self.mnist = dset.MNIST(root=os.path.join(DATA, "mnist"),
|
||||
train=train,
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.Resize((imsize, imsize)),
|
||||
transforms.ToTensor(),
|
||||
]))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.mnist)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
im, label = self.mnist[idx]
|
||||
|
||||
# make sure data uses [0, 1] range
|
||||
im -= im.min()
|
||||
im /= im.max() + 1e-8
|
||||
|
||||
# Bring it to [-1, 1]
|
||||
im -= 0.5
|
||||
im /= 0.5
|
||||
return im
|
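A minimal usage sketch for the datasets above (not part of the commit). It mirrors how `sketch_rnn.py` builds its loaders and assumes `data/sketchrnn_cat.npz` has already been downloaded from the URL in `QuickDrawDataset.BASE_DATA_URL`:

    from torch.utils.data import DataLoader

    import data

    # 5-D stroke sequences, padded to the longest drawing in the split
    dataset = data.QuickDrawDataset("cat.npz", mode="train")
    loader = DataLoader(dataset, batch_size=100, shuffle=True)

    batch = next(iter(loader))  # float32 tensor of shape (100, nmax + 2, 5)
    print(dataset, batch.shape)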

apps/generative_models/eval_gan.py (new file, +182)
@@ -0,0 +1,182 @@
|
||||
"""Evaluate a pretrained GAN model.
|
||||
Usage:
|
||||
|
||||
`python eval_gan.py <path/to/model/folder>`, e.g.
|
||||
`../results/quickdraw_gan_vector_bezier_fc_wgan`.
|
||||
|
||||
"""
|
||||
import os
|
||||
import argparse
|
||||
import torch as th
|
||||
import numpy as np
|
||||
import ttools
|
||||
import imageio
|
||||
from subprocess import call
|
||||
|
||||
import pydiffvg
|
||||
|
||||
import models
|
||||
|
||||
|
||||
LOG = ttools.get_logger(__name__)
|
||||
|
||||
|
||||
def postprocess(im, invert=False):
|
||||
im = th.clamp((im + 1.0) / 2.0, 0, 1)
|
||||
if invert:
|
||||
im = (1.0 - im)
|
||||
im = ttools.tensor2image(im)
|
||||
return im
|
||||
|
||||
|
||||
def imsave(im, path):
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
imageio.imwrite(path, im)
|
||||
|
||||
|
||||
def save_scene(scn, path):
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
pydiffvg.save_svg(path, *scn, use_gamma=False)
|
||||
|
||||
|
||||
def run(args):
|
||||
th.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
meta = ttools.Checkpointer.load_meta(args.model, "vect_g_")
|
||||
|
||||
if meta is None:
|
||||
LOG.warning("Could not load metadata at %s, aborting.", args.model)
|
||||
return
|
||||
|
||||
LOG.info("Loaded model %s with metadata:\n %s", args.model, meta)
|
||||
|
||||
if args.output_dir is None:
|
||||
outdir = os.path.join(args.model, "eval")
|
||||
else:
|
||||
outdir = args.output_dir
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
|
||||
model_params = meta["model_params"]
|
||||
if args.imsize is not None:
|
||||
LOG.info("Overriding output image size to: %dx%d", args.imsize,
|
||||
args.imsize)
|
||||
old_size = model_params["imsize"]
|
||||
scale = args.imsize * 1.0 / old_size
|
||||
model_params["imsize"] = args.imsize
|
||||
model_params["stroke_width"] = [w*scale for w in
|
||||
model_params["stroke_width"]]
|
||||
LOG.info("Overriding width to: %s", model_params["stroke_width"])
|
||||
|
||||
# task = meta["task"]
|
||||
generator = meta["generator"]
|
||||
if generator == "fc":
|
||||
model = models.VectorGenerator(**model_params)
|
||||
elif generator == "bezier_fc":
|
||||
model = models.BezierVectorGenerator(**model_params)
|
||||
elif generator in ["rnn"]:
|
||||
model = models.RNNVectorGenerator(**model_params)
|
||||
elif generator in ["chain_rnn"]:
|
||||
model = models.ChainRNNVectorGenerator(**model_params)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
model.eval()
|
||||
|
||||
device = "cpu"
|
||||
if th.cuda.is_available():
|
||||
device = "cuda"
|
||||
|
||||
model.to(device)
|
||||
|
||||
checkpointer = ttools.Checkpointer(
|
||||
args.model, model, meta=meta, prefix="vect_g_")
|
||||
checkpointer.load_latest()
|
||||
|
||||
LOG.info("Computing latent space interpolation")
|
||||
for i in range(args.nsamples):
|
||||
z0 = model.sample_z(1)
|
||||
z1 = model.sample_z(1)
|
||||
|
||||
# interpolation
|
||||
alpha = th.linspace(0, 1, args.nsteps).view(args.nsteps, 1).to(device)
|
||||
alpha_video = th.linspace(0, 1, args.nframes).view(args.nframes, 1)
|
||||
alpha_video = alpha_video.to(device)
|
||||
|
||||
length = [args.nsteps, args.nframes]
|
||||
for idx, a in enumerate([alpha, alpha_video]):
|
||||
_z0 = z0.repeat(length[idx], 1).to(device)
|
||||
_z1 = z1.repeat(length[idx], 1).to(device)
|
||||
batch = _z0*(1-a) + _z1*a
|
||||
out = model(batch)
|
||||
if idx == 0: # image viz
|
||||
n, c, h, w = out.shape
|
||||
out = out.permute(1, 2, 0, 3)
|
||||
out = out.contiguous().view(1, c, h, w*n)
|
||||
out = postprocess(out, invert=args.invert)
|
||||
imsave(out, os.path.join(outdir,
|
||||
"latent_interp", "%03d.png" % i))
|
||||
|
||||
scenes = model.get_vector(batch)
|
||||
for scn_idx, scn in enumerate(scenes):
|
||||
save_scene(scn, os.path.join(outdir, "latent_interp_svg",
|
||||
"%03d" % i, "%03d.svg" %
|
||||
scn_idx))
|
||||
else: # video viz
|
||||
anim_root = os.path.join(outdir,
|
||||
"latent_interp_video", "%03d" % i)
|
||||
LOG.info("Rendering animation %d", i)
|
||||
for frame_idx, frame in enumerate(out):
|
||||
LOG.info("frame %d", frame_idx)
|
||||
frame = frame.unsqueeze(0)
|
||||
frame = postprocess(frame, invert=args.invert)
|
||||
imsave(frame, os.path.join(anim_root,
|
||||
"frame%04d.png" % frame_idx))
|
||||
call(["ffmpeg", "-framerate", "30", "-i",
|
||||
os.path.join(anim_root, "frame%04d.png"), "-vb", "20M",
|
||||
os.path.join(outdir,
|
||||
"latent_interp_video", "%03d.mp4" % i)])
|
||||
LOG.info(" saved %d", i)
|
||||
|
||||
LOG.info("Sampling latent space")
|
||||
|
||||
for i in range(args.nsamples):
|
||||
n = 8
|
||||
bs = n*n
|
||||
z = model.sample_z(bs).to(device)
|
||||
out = model(z)
|
||||
_, c, h, w = out.shape
|
||||
out = out.view(n, n, c, h, w).permute(2, 0, 3, 1, 4)
|
||||
out = out.contiguous().view(1, c, h*n, w*n)
|
||||
out = postprocess(out)
|
||||
imsave(out, os.path.join(outdir, "samples_%03d.png" % i))
|
||||
LOG.info(" saved %d", i)
|
||||
|
||||
LOG.info("output images saved to %s", outdir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("model")
|
||||
parser.add_argument("--output_dir", help="output directory for "
|
||||
" the samples. Defaults to the model's path")
|
||||
parser.add_argument("--nsamples", default=16, type=int,
|
||||
help="number of output to compute")
|
||||
parser.add_argument("--imsize", type=int,
|
||||
help="if provided, override the raster output "
|
||||
"resolution")
|
||||
parser.add_argument("--nsteps", default=9, type=int, help="number of "
|
||||
"interpolation steps for the interpolation")
|
||||
parser.add_argument("--nframes", default=120, type=int, help="number of "
|
||||
"frames for the interpolation video")
|
||||
parser.add_argument("--invert", default=False, action="store_true",
|
||||
help="if True, render black on white rather than the"
|
||||
" opposite")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
pydiffvg.set_use_gpu(False)
|
||||
|
||||
ttools.set_logger(False)
|
||||
|
||||
run(args)
|

apps/generative_models/losses.py (new file, +99)
@@ -0,0 +1,99 @@
|
||||
"""Losses for the generative models and baselines."""
|
||||
import torch as th
|
||||
import numpy as np
|
||||
|
||||
import ttools.modules.image_operators as imops
|
||||
|
||||
|
||||
class KLDivergence(th.nn.Module):
|
||||
"""
|
||||
Args:
|
||||
min_value(float): the loss is clipped so that values below this
number do not affect the optimization.
|
||||
"""
|
||||
def __init__(self, min_value=0.2):
|
||||
super(KLDivergence, self).__init__()
|
||||
self.min_value = min_value
|
||||
|
||||
def forward(self, mu, log_sigma):
|
||||
loss = -0.5 * (1.0 + log_sigma - mu.pow(2) - log_sigma.exp())
|
||||
loss = loss.mean()
|
||||
loss = th.max(loss, self.min_value*th.ones_like(loss))
|
||||
return loss
|
||||
|
||||
|
||||
class MultiscaleMSELoss(th.nn.Module):
|
||||
def __init__(self, channels=3):
|
||||
super(MultiscaleMSELoss, self).__init__()
|
||||
self.blur = imops.GaussianBlur(1, channels=channels)
|
||||
|
||||
def forward(self, im, target):
|
||||
bs, c, h, w = im.shape
|
||||
num_levels = max(int(np.ceil(np.log2(h))) - 2, 1)
|
||||
|
||||
losses = []
|
||||
for lvl in range(num_levels):
|
||||
loss = th.nn.functional.mse_loss(im, target)
|
||||
losses.append(loss)
|
||||
im = th.nn.functional.interpolate(self.blur(im),
|
||||
scale_factor=0.5,
|
||||
mode="nearest")
|
||||
target = th.nn.functional.interpolate(self.blur(target),
|
||||
scale_factor=0.5,
|
||||
mode="nearest")
|
||||
|
||||
losses = th.stack(losses)
|
||||
return losses.sum()
|
||||
|
||||
|
||||
def gaussian_pdfs(dx, dy, params):
|
||||
"""Returns the pdf at (dx, dy) for each Gaussian in the mixture.
|
||||
"""
|
||||
dx = dx.unsqueeze(-1) # replicate dx, dy to evaluate all pdfs at once
|
||||
dy = dy.unsqueeze(-1)
|
||||
|
||||
mu_x = params[..., 0]
|
||||
mu_y = params[..., 1]
|
||||
sigma_x = params[..., 2].exp()
|
||||
sigma_y = params[..., 3].exp()
|
||||
rho_xy = th.tanh(params[..., 4])
|
||||
|
||||
x = ((dx-mu_x) / sigma_x).pow(2)
|
||||
y = ((dy-mu_y) / sigma_y).pow(2)
|
||||
|
||||
xy = (dx-mu_x)*(dy-mu_y) / (sigma_x * sigma_y)
|
||||
arg = x + y - 2.0*rho_xy*xy
|
||||
pdf = th.exp(-arg / (2*(1.0 - rho_xy.pow(2))))
|
||||
norm = 2.0 * np.pi * sigma_x * sigma_y * (1.0 - rho_xy.pow(2)).sqrt()
|
||||
|
||||
return pdf / norm
|
||||
|
||||
|
||||
class GaussianMixtureReconstructionLoss(th.nn.Module):
|
||||
"""
|
||||
Args:
    eps(float): small constant that avoids taking the log of zero.
|
||||
"""
|
||||
def __init__(self, eps=1e-5):
|
||||
super(GaussianMixtureReconstructionLoss, self).__init__()
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, pen_logits, mixture_logits, gaussian_params, targets):
|
||||
dx = targets[..., 0]
|
||||
dy = targets[..., 1]
|
||||
pen_state = targets[..., 2:].argmax(-1) # target index
|
||||
|
||||
# Likelihood loss on the stroke position
|
||||
# No need to predict accurate pen position for end-of-sequence tokens
|
||||
valid_stroke = (targets[..., -1] != 1.0).float()
|
||||
mixture_weights = th.nn.functional.softmax(mixture_logits, -1)
|
||||
pdfs = gaussian_pdfs(dx, dy, gaussian_params)
|
||||
position_loss = - th.log(self.eps + (pdfs * mixture_weights).sum(-1))
|
||||
|
||||
# Normalize by the number of valid (non end-of-sequence) steps
|
||||
position_loss = (position_loss*valid_stroke).sum() / valid_stroke.sum()
|
||||
|
||||
# Classification loss for the stroke mode
|
||||
pen_loss = th.nn.functional.cross_entropy(pen_logits.view(-1, 3),
|
||||
pen_state.view(-1))
|
||||
|
||||
return position_loss + pen_loss
|
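A shape sketch for the two Sketch-RNN losses above (not part of the commit; the sizes are hypothetical and chosen only to illustrate the expected tensor layout):

    import torch as th

    import losses

    bs, steps, num_gaussians, zdim = 8, 30, 20, 128

    kl = losses.KLDivergence()
    recons = losses.GaussianMixtureReconstructionLoss()

    mu, log_sigma = th.randn(bs, zdim), th.randn(bs, zdim)
    pen_logits = th.randn(bs, steps, 3)
    mixture_logits = th.randn(bs, steps, num_gaussians)
    gaussian_params = th.randn(bs, steps, num_gaussians, 5)
    targets = th.zeros(bs, steps, 5)
    targets[..., 2] = 1  # pen touching paper at every step

    loss = kl(mu, log_sigma) + recons(pen_logits, mixture_logits,
                                      gaussian_params, targets)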

apps/generative_models/mnist_vae.py (new file, +1026; diff suppressed because it is too large)

apps/generative_models/models.py (new file, +484)
@@ -0,0 +1,484 @@
|
||||
"""Collection of generative models."""
|
||||
|
||||
import torch as th
|
||||
import ttools
|
||||
|
||||
import rendering
|
||||
import modules
|
||||
|
||||
LOG = ttools.get_logger(__name__)
|
||||
|
||||
|
||||
class BaseModel(th.nn.Module):
|
||||
def sample_z(self, bs, device="cpu"):
|
||||
return th.randn(bs, self.zdim).to(device)
|
||||
|
||||
|
||||
class BaseVectorModel(BaseModel):
|
||||
def get_vector(self, z):
|
||||
_, scenes = self._forward(z)
|
||||
return scenes
|
||||
|
||||
def _forward(self, x):
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(self, z):
|
||||
# Only return the raster
|
||||
return self._forward(z)[0]
|
||||
|
||||
|
||||
class BezierVectorGenerator(BaseVectorModel):
|
||||
NUM_SEGMENTS = 2
|
||||
def __init__(self, num_strokes=4,
|
||||
zdim=128, width=32, imsize=32,
|
||||
color_output=False,
|
||||
stroke_width=None):
|
||||
super(BezierVectorGenerator, self).__init__()
|
||||
|
||||
if stroke_width is None:
|
||||
self.stroke_width = (0.5, 3.0)
|
||||
LOG.warning("Setting default stroke with %s", self.stroke_width)
|
||||
else:
|
||||
self.stroke_width = stroke_width
|
||||
|
||||
self.imsize = imsize
|
||||
self.num_strokes = num_strokes
|
||||
self.zdim = zdim
|
||||
|
||||
self.trunk = th.nn.Sequential(
|
||||
th.nn.Linear(zdim, width),
|
||||
th.nn.SELU(inplace=True),
|
||||
|
||||
th.nn.Linear(width, 2*width),
|
||||
th.nn.SELU(inplace=True),
|
||||
|
||||
th.nn.Linear(2*width, 4*width),
|
||||
th.nn.SELU(inplace=True),
|
||||
|
||||
th.nn.Linear(4*width, 8*width),
|
||||
th.nn.SELU(inplace=True),
|
||||
)
|
||||
|
||||
# Cubic Bezier strokes: n_segments segments need 3*n_segments + 1 control points
|
||||
self.point_predictor = th.nn.Sequential(
|
||||
th.nn.Linear(8*width,
|
||||
2*self.num_strokes*(
|
||||
BezierVectorGenerator.NUM_SEGMENTS*3 + 1)),
|
||||
th.nn.Tanh() # bound spatial extent
|
||||
)
|
||||
|
||||
self.width_predictor = th.nn.Sequential(
|
||||
th.nn.Linear(8*width, self.num_strokes),
|
||||
th.nn.Sigmoid()
|
||||
)
|
||||
|
||||
self.alpha_predictor = th.nn.Sequential(
|
||||
th.nn.Linear(8*width, self.num_strokes),
|
||||
th.nn.Sigmoid()
|
||||
)
|
||||
|
||||
self.color_predictor = None
|
||||
if color_output:
|
||||
self.color_predictor = th.nn.Sequential(
|
||||
th.nn.Linear(8*width, 3*self.num_strokes),
|
||||
th.nn.Sigmoid()
|
||||
)
|
||||
|
||||
def _forward(self, z):
|
||||
bs = z.shape[0]
|
||||
|
||||
feats = self.trunk(z)
|
||||
all_points = self.point_predictor(feats)
|
||||
all_alphas = self.alpha_predictor(feats)
|
||||
|
||||
if self.color_predictor:
|
||||
all_colors = self.color_predictor(feats)
|
||||
all_colors = all_colors.view(bs, self.num_strokes, 3)
|
||||
else:
|
||||
all_colors = None
|
||||
|
||||
all_widths = self.width_predictor(feats)
|
||||
min_width = self.stroke_width[0]
|
||||
max_width = self.stroke_width[1]
|
||||
all_widths = (max_width - min_width) * all_widths + min_width
|
||||
|
||||
all_points = all_points.view(
|
||||
bs, self.num_strokes, BezierVectorGenerator.NUM_SEGMENTS*3+1, 2)
|
||||
|
||||
output, scenes = rendering.bezier_render(all_points, all_widths, all_alphas,
|
||||
colors=all_colors,
|
||||
canvas_size=self.imsize)
|
||||
|
||||
# map to [-1, 1]
|
||||
output = output*2.0 - 1.0
|
||||
|
||||
return output, scenes
|
||||
|
||||
|
||||
class VectorGenerator(BaseVectorModel):
|
||||
def __init__(self, num_strokes=4,
|
||||
zdim=128, width=32, imsize=32,
|
||||
color_output=False,
|
||||
stroke_width=None):
|
||||
super(VectorGenerator, self).__init__()
|
||||
|
||||
if stroke_width is None:
|
||||
self.stroke_width = (0.5, 3.0)
|
||||
LOG.warning("Setting default stroke with %s", self.stroke_width)
|
||||
else:
|
||||
self.stroke_width = stroke_width
|
||||
|
||||
self.imsize = imsize
|
||||
self.num_strokes = num_strokes
|
||||
self.zdim = zdim
|
||||
|
||||
self.trunk = th.nn.Sequential(
|
||||
th.nn.Linear(zdim, width),
|
||||
th.nn.SELU(inplace=True),
|
||||
|
||||
th.nn.Linear(width, 2*width),
|
||||
th.nn.SELU(inplace=True),
|
||||
|
||||
th.nn.Linear(2*width, 4*width),
|
||||
th.nn.SELU(inplace=True),
|
||||
|
||||
th.nn.Linear(4*width, 8*width),
|
||||
th.nn.SELU(inplace=True),
|
||||
)
|
||||
|
||||
# Straight-line strokes: 2 endpoints (x, y) per stroke
|
||||
self.point_predictor = th.nn.Sequential(
|
||||
th.nn.Linear(8*width, 2*(self.num_strokes*2)),
|
||||
th.nn.Tanh() # bound spatial extent
|
||||
)
|
||||
|
||||
self.width_predictor = th.nn.Sequential(
|
||||
th.nn.Linear(8*width, self.num_strokes),
|
||||
th.nn.Sigmoid()
|
||||
)
|
||||
|
||||
self.alpha_predictor = th.nn.Sequential(
|
||||
th.nn.Linear(8*width, self.num_strokes),
|
||||
th.nn.Sigmoid()
|
||||
)
|
||||
|
||||
self.color_predictor = None
|
||||
if color_output:
|
||||
self.color_predictor = th.nn.Sequential(
|
||||
th.nn.Linear(8*width, 3*self.num_strokes),
|
||||
th.nn.Sigmoid()
|
||||
)
|
||||
|
||||
def _forward(self, z):
|
||||
bs = z.shape[0]
|
||||
|
||||
feats = self.trunk(z)
|
||||
|
||||
all_points = self.point_predictor(feats)
|
||||
|
||||
all_alphas = self.alpha_predictor(feats)
|
||||
|
||||
if self.color_predictor:
|
||||
all_colors = self.color_predictor(feats)
|
||||
all_colors = all_colors.view(bs, self.num_strokes, 3)
|
||||
else:
|
||||
all_colors = None
|
||||
|
||||
all_widths = self.width_predictor(feats)
|
||||
min_width = self.stroke_width[0]
|
||||
max_width = self.stroke_width[1]
|
||||
all_widths = (max_width - min_width) * all_widths + min_width
|
||||
|
||||
all_points = all_points.view(bs, self.num_strokes, 2, 2)
|
||||
output, scenes = rendering.line_render(all_points, all_widths, all_alphas,
|
||||
colors=all_colors,
|
||||
canvas_size=self.imsize)
|
||||
|
||||
# map to [-1, 1]
|
||||
output = output*2.0 - 1.0
|
||||
|
||||
return output, scenes
|
||||
|
||||
|
||||
class RNNVectorGenerator(BaseVectorModel):
|
||||
def __init__(self, num_strokes=64,
|
||||
zdim=128, width=32, imsize=32,
|
||||
hidden_size=512, dropout=0.9,
|
||||
color_output=False,
|
||||
num_layers=3, stroke_width=None):
|
||||
super(RNNVectorGenerator, self).__init__()
|
||||
|
||||
|
||||
if stroke_width is None:
|
||||
self.stroke_width = (0.5, 3.0)
|
||||
LOG.warning("Setting default stroke with %s", self.stroke_width)
|
||||
else:
|
||||
self.stroke_width = stroke_width
|
||||
|
||||
self.num_layers = num_layers
|
||||
self.imsize = imsize
|
||||
self.num_strokes = num_strokes
|
||||
self.hidden_size = hidden_size
|
||||
self.zdim = zdim
|
||||
|
||||
self.hidden_cell_predictor = th.nn.Linear(
|
||||
zdim, 2*hidden_size*num_layers)
|
||||
|
||||
self.lstm = th.nn.LSTM(
|
||||
zdim, hidden_size,
|
||||
num_layers=self.num_layers, dropout=dropout,
|
||||
batch_first=True)
|
||||
|
||||
# Straight-line strokes: 2 endpoints (x, y) per time step
|
||||
self.point_predictor = th.nn.Sequential(
|
||||
th.nn.Linear(hidden_size, 2*2), # 2 points, (x,y)
|
||||
th.nn.Tanh() # bound spatial extent
|
||||
)
|
||||
|
||||
self.width_predictor = th.nn.Sequential(
|
||||
th.nn.Linear(hidden_size, 1),
|
||||
th.nn.Sigmoid()
|
||||
)
|
||||
|
||||
self.alpha_predictor = th.nn.Sequential(
|
||||
th.nn.Linear(hidden_size, 1),
|
||||
th.nn.Sigmoid()
|
||||
)
|
||||
|
||||
def _forward(self, z, hidden_and_cell=None):
|
||||
steps = self.num_strokes
|
||||
|
||||
# z is passed at each step, duplicate it
|
||||
bs = z.shape[0]
|
||||
expanded_z = z.unsqueeze(1).repeat(1, steps, 1)
|
||||
|
||||
# First step in the RNN
|
||||
if hidden_and_cell is None:
|
||||
# Initialize from latent vector
|
||||
hidden_and_cell = self.hidden_cell_predictor(th.tanh(z))
|
||||
hidden = hidden_and_cell[:, :self.hidden_size*self.num_layers]
|
||||
hidden = hidden.view(-1, self.num_layers, self.hidden_size)
|
||||
hidden = hidden.permute(1, 0, 2).contiguous()
|
||||
cell = hidden_and_cell[:, self.hidden_size*self.num_layers:]
|
||||
cell = cell.view(-1, self.num_layers, self.hidden_size)
|
||||
cell = cell.permute(1, 0, 2).contiguous()
|
||||
hidden_and_cell = (hidden, cell)
|
||||
|
||||
feats, hidden_and_cell = self.lstm(expanded_z, hidden_and_cell)
|
||||
hidden, cell = hidden_and_cell
|
||||
|
||||
feats = feats.reshape(bs*steps, self.hidden_size)
|
||||
|
||||
all_points = self.point_predictor(feats).view(bs, steps, 2, 2)
|
||||
all_alphas = self.alpha_predictor(feats).view(bs, steps)
|
||||
all_widths = self.width_predictor(feats).view(bs, steps)
|
||||
|
||||
min_width = self.stroke_width[0]
|
||||
max_width = self.stroke_width[1]
|
||||
all_widths = (max_width - min_width) * all_widths + min_width
|
||||
|
||||
output, scenes = rendering.line_render(all_points, all_widths, all_alphas,
|
||||
canvas_size=self.imsize)
|
||||
|
||||
# map to [-1, 1]
|
||||
output = output*2.0 - 1.0
|
||||
|
||||
return output, scenes
|
||||
|
||||
|
||||
class ChainRNNVectorGenerator(BaseVectorModel):
|
||||
"""Strokes form a single long chain."""
|
||||
def __init__(self, num_strokes=64,
|
||||
zdim=128, width=32, imsize=32,
|
||||
hidden_size=512, dropout=0.9,
|
||||
color_output=False,
|
||||
num_layers=3, stroke_width=None):
|
||||
super(ChainRNNVectorGenerator, self).__init__()
|
||||
|
||||
if stroke_width is None:
|
||||
self.stroke_width = (0.5, 3.0)
|
||||
LOG.warning("Setting default stroke with %s", self.stroke_width)
|
||||
else:
|
||||
self.stroke_width = stroke_width
|
||||
|
||||
self.num_layers = num_layers
|
||||
self.imsize = imsize
|
||||
self.num_strokes = num_strokes
|
||||
self.hidden_size = hidden_size
|
||||
self.zdim = zdim
|
||||
|
||||
self.hidden_cell_predictor = th.nn.Linear(
|
||||
zdim, 2*hidden_size*num_layers)
|
||||
|
||||
self.lstm = th.nn.LSTM(
|
||||
zdim, hidden_size,
|
||||
num_layers=self.num_layers, dropout=dropout,
|
||||
batch_first=True)
|
||||
|
||||
# Chained straight lines: 1 new endpoint (x, y) per time step
|
||||
self.point_predictor = th.nn.Sequential(
|
||||
th.nn.Linear(hidden_size, 2), # 1 point, (x,y)
|
||||
th.nn.Tanh() # bound spatial extent
|
||||
)
|
||||
|
||||
self.width_predictor = th.nn.Sequential(
|
||||
th.nn.Linear(hidden_size, 1),
|
||||
th.nn.Sigmoid()
|
||||
)
|
||||
|
||||
self.alpha_predictor = th.nn.Sequential(
|
||||
th.nn.Linear(hidden_size, 1),
|
||||
th.nn.Sigmoid()
|
||||
)
|
||||
|
||||
def _forward(self, z, hidden_and_cell=None):
|
||||
steps = self.num_strokes
|
||||
|
||||
# z is passed at each step, duplicate it
|
||||
bs = z.shape[0]
|
||||
expanded_z = z.unsqueeze(1).repeat(1, steps, 1)
|
||||
|
||||
# First step in the RNN
|
||||
if hidden_and_cell is None:
|
||||
# Initialize from latent vector
|
||||
hidden_and_cell = self.hidden_cell_predictor(th.tanh(z))
|
||||
hidden = hidden_and_cell[:, :self.hidden_size*self.num_layers]
|
||||
hidden = hidden.view(-1, self.num_layers, self.hidden_size)
|
||||
hidden = hidden.permute(1, 0, 2).contiguous()
|
||||
cell = hidden_and_cell[:, self.hidden_size*self.num_layers:]
|
||||
cell = cell.view(-1, self.num_layers, self.hidden_size)
|
||||
cell = cell.permute(1, 0, 2).contiguous()
|
||||
hidden_and_cell = (hidden, cell)
|
||||
|
||||
feats, hidden_and_cell = self.lstm(expanded_z, hidden_and_cell)
|
||||
hidden, cell = hidden_and_cell
|
||||
|
||||
feats = feats.reshape(bs*steps, self.hidden_size)
|
||||
|
||||
# Construct the chain
|
||||
end_points = self.point_predictor(feats).view(bs, steps, 1, 2)
|
||||
# Each segment starts where the previous one ended; the first
# segment starts at the canvas center.
start_points = th.cat([
th.zeros(bs, 1, 1, 2, device=feats.device),
end_points[:, :-1, :, :]], 1)
|
||||
all_points = th.cat([start_points, end_points], 2)
|
||||
|
||||
all_alphas = self.alpha_predictor(feats).view(bs, steps)
|
||||
all_widths = self.width_predictor(feats).view(bs, steps)
|
||||
|
||||
min_width = self.stroke_width[0]
|
||||
max_width = self.stroke_width[1]
|
||||
all_widths = (max_width - min_width) * all_widths + min_width
|
||||
|
||||
output, scenes = rendering.line_render(all_points, all_widths, all_alphas,
|
||||
canvas_size=self.imsize)
|
||||
|
||||
# map to [-1, 1]
|
||||
output = output*2.0 - 1.0
|
||||
|
||||
return output, scenes
|
||||
|
||||
|
||||
class Generator(BaseModel):
|
||||
def __init__(self, width=64, imsize=32, zdim=128,
|
||||
stroke_width=None,
|
||||
color_output=False,
|
||||
num_strokes=4):
|
||||
super(Generator, self).__init__()
|
||||
assert imsize == 32
|
||||
|
||||
self.imsize = imsize
|
||||
self.zdim = zdim
|
||||
|
||||
num_in_chans = self.zdim // (2*2)
|
||||
num_out_chans = 3 if color_output else 1
|
||||
|
||||
self.net = th.nn.Sequential(
|
||||
th.nn.ConvTranspose2d(num_in_chans, width*8, 4, padding=1,
|
||||
stride=2),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
th.nn.Conv2d(width*8, width*8, 3, padding=1),
|
||||
th.nn.BatchNorm2d(width*8),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
# 4x4
|
||||
|
||||
th.nn.ConvTranspose2d(8*width, 4*width, 4, padding=1, stride=2),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
th.nn.Conv2d(4*width, 4*width, 3, padding=1),
|
||||
th.nn.BatchNorm2d(width*4),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
# 8x8
|
||||
|
||||
th.nn.ConvTranspose2d(4*width, 2*width, 4, padding=1, stride=2),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
th.nn.Conv2d(2*width, 2*width, 3, padding=1),
|
||||
th.nn.BatchNorm2d(width*2),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
# 16x16
|
||||
|
||||
th.nn.ConvTranspose2d(2*width, width, 4, padding=1, stride=2),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
th.nn.Conv2d(width, width, 3, padding=1),
|
||||
th.nn.BatchNorm2d(width),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
# 32x32
|
||||
|
||||
th.nn.Conv2d(width, width, 3, padding=1),
|
||||
th.nn.BatchNorm2d(width),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
th.nn.Conv2d(width, width, 3, padding=1),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
th.nn.Conv2d(width, num_out_chans, 1),
|
||||
|
||||
th.nn.Tanh(),
|
||||
)
|
||||
|
||||
def forward(self, z):
|
||||
bs = z.shape[0]
|
||||
num_in_chans = self.zdim // (2*2)
|
||||
raster = self.net(z.view(bs, num_in_chans, 2, 2))
|
||||
return raster
|
||||
|
||||
|
||||
class Discriminator(th.nn.Module):
|
||||
def __init__(self, conditional=False, width=64, color_output=False):
|
||||
super(Discriminator, self).__init__()
|
||||
|
||||
self.conditional = conditional
|
||||
|
||||
sn = th.nn.utils.spectral_norm
|
||||
|
||||
num_chan_in = 3 if color_output else 1
|
||||
|
||||
self.net = th.nn.Sequential(
|
||||
th.nn.Conv2d(num_chan_in, width, 3, padding=1),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
th.nn.Conv2d(width, 2*width, 4, padding=1, stride=2),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
# 16x16
|
||||
|
||||
sn(th.nn.Conv2d(2*width, 2*width, 3, padding=1)),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
sn(th.nn.Conv2d(2*width, 4*width, 4, padding=1, stride=2)),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
# 8x8
|
||||
|
||||
sn(th.nn.Conv2d(4*width, 4*width, 3, padding=1)),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
sn(th.nn.Conv2d(4*width, width*4, 4, padding=1, stride=2)),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
# 4x4
|
||||
|
||||
sn(th.nn.Conv2d(4*width, 4*width, 3, padding=1)),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
sn(th.nn.Conv2d(4*width, width*4, 4, padding=1, stride=2)),
|
||||
th.nn.LeakyReLU(0.2, inplace=True),
|
||||
# 2x2
|
||||
|
||||
modules.Flatten(),
|
||||
th.nn.Linear(width*4*2*2, 1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.net(x)
|
||||
return out
|
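A minimal sampling sketch (not part of the commit) mirroring what `eval_gan.py` does after loading a checkpoint; here the weights are random, and `pydiffvg` must be built and importable:

    import torch as th

    import pydiffvg
    import models

    pydiffvg.set_use_gpu(False)

    model = models.BezierVectorGenerator(num_strokes=4, zdim=128, imsize=32)
    model.eval()

    with th.no_grad():
        z = model.sample_z(8)
        raster = model(z)             # (8, 1, 32, 32) images in [-1, 1]
        scenes = model.get_vector(z)  # one pydiffvg scene per sample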

apps/generative_models/modules.py (new file, +11)
@@ -0,0 +1,11 @@
|
||||
"""Helper modules to build our networks."""
|
||||
import torch as th
|
||||
|
||||
|
||||
class Flatten(th.nn.Module):
|
||||
def __init__(self):
|
||||
super(Flatten, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
bs = x.shape[0]
|
||||
return x.view(bs, -1)
|

apps/generative_models/rendering.py (new file, +307)
@@ -0,0 +1,307 @@
|
||||
import os
|
||||
import torch as th
|
||||
import torch.multiprocessing as mp
|
||||
import threading as mt
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
import ttools
|
||||
|
||||
import pydiffvg
|
||||
import time
|
||||
|
||||
|
||||
def render(canvas_width, canvas_height, shapes, shape_groups, samples=2,
|
||||
seed=None):
|
||||
if seed is None:
|
||||
seed = random.randint(0, 1000000)
|
||||
_render = pydiffvg.RenderFunction.apply
|
||||
scene_args = pydiffvg.RenderFunction.serialize_scene(
|
||||
canvas_width, canvas_height, shapes, shape_groups)
|
||||
img = _render(canvas_width, canvas_height, samples, samples,
|
||||
seed, # seed
|
||||
None, # background image
|
||||
*scene_args)
|
||||
return img
|
||||
|
||||
|
||||
def opacityStroke2diffvg(strokes, canvas_size=128, debug=False, relative=True,
|
||||
force_cpu=True):
|
||||
|
||||
dev = strokes.device
|
||||
if force_cpu:
|
||||
strokes = strokes.to("cpu")
|
||||
|
||||
|
||||
# pydiffvg.set_use_gpu(False)
|
||||
# if strokes.is_cuda:
|
||||
# pydiffvg.set_use_gpu(True)
|
||||
|
||||
"""Rasterize strokes given in (dx, dy, opacity) sequence format."""
|
||||
bs, nsegs, dims = strokes.shape
|
||||
out = []
|
||||
|
||||
start = time.time()
|
||||
for batch_idx, stroke in enumerate(strokes):
|
||||
|
||||
if relative: # integrate (dx, dy) deltas into absolute coordinates
|
||||
all_points = stroke[..., :2].cumsum(0)
|
||||
else:
|
||||
all_points = stroke[..., :2]
|
||||
|
||||
all_opacities = stroke[..., 2]
|
||||
|
||||
# Transform from [-1, 1] to canvas coordinates
|
||||
# Make sure points are in canvas
|
||||
all_points = 0.5*(all_points + 1.0) * canvas_size
|
||||
# all_points = th.clamp(0.5*(all_points + 1.0), 0, 1) * canvas_size
|
||||
|
||||
# Avoid overlapping points
|
||||
eps = 1e-4
|
||||
all_points = all_points + eps*th.randn_like(all_points)
|
||||
|
||||
shapes = []
|
||||
shape_groups = []
|
||||
|
||||
for start_idx in range(0, nsegs-1):
|
||||
points = all_points[start_idx:start_idx+2].contiguous().float()
|
||||
opacity = all_opacities[start_idx]
|
||||
|
||||
num_ctrl_pts = th.zeros(points.shape[0] - 1, dtype=th.int32)
|
||||
width = th.ones(1)
|
||||
|
||||
path = pydiffvg.Path(
|
||||
num_control_points=num_ctrl_pts, points=points,
|
||||
stroke_width=width, is_closed=False)
|
||||
|
||||
shapes.append(path)
|
||||
|
||||
color = th.cat([th.ones(3, device=opacity.device),
|
||||
opacity.unsqueeze(0)], 0)
|
||||
path_group = pydiffvg.ShapeGroup(
|
||||
shape_ids=th.tensor([len(shapes) - 1]),
|
||||
fill_color=None,
|
||||
stroke_color=color)
|
||||
shape_groups.append(path_group)
|
||||
|
||||
# Rasterize only if there are shapes
|
||||
if shapes:
|
||||
inner_start = time.time()
|
||||
out.append(render(canvas_size, canvas_size, shapes, shape_groups,
|
||||
samples=4))
|
||||
if debug:
|
||||
inner_elapsed = time.time() - inner_start
|
||||
print("diffvg call took %.2fms" % inner_elapsed)
|
||||
else:
|
||||
out.append(th.zeros(canvas_size, canvas_size, 4,
|
||||
device=strokes.device))
|
||||
|
||||
if debug:
|
||||
elapsed = (time.time() - start)*1000
|
||||
print("rendering took %.2fms" % elapsed)
|
||||
images = th.stack(out, 0).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
# Return data on the same device as input
|
||||
return images.to(dev)
|
||||
|
||||
|
||||
def stroke2diffvg(strokes, canvas_size=128):
|
||||
"""Rasterize strokes given some sequential data."""
|
||||
bs, nsegs, dims = strokes.shape
|
||||
out = []
|
||||
for stroke_idx, stroke in enumerate(strokes):
|
||||
end_of_stroke = stroke[:, 4] == 1
|
||||
last = end_of_stroke.cpu().numpy().argmax()
|
||||
stroke = stroke[:last+1, :]
|
||||
# stroke = stroke[~end_of_stroke]
|
||||
# TODO: stop at the first end of stroke
|
||||
# import ipdb; ipdb.set_trace()
|
||||
split_idx = stroke[:, 3].nonzero().squeeze(1)
|
||||
|
||||
# Absolute coordinates
|
||||
all_points = stroke[..., :2].cumsum(0)
|
||||
|
||||
# Transform to canvas coordinates
|
||||
all_points[..., 0] += 0.5
|
||||
all_points[..., 0] *= canvas_size
|
||||
all_points[..., 1] += 0.5
|
||||
all_points[..., 1] *= canvas_size
|
||||
|
||||
# Make sure points are in canvas
|
||||
all_points[..., :2] = th.clamp(all_points[..., :2], 0, canvas_size)
|
||||
|
||||
shape_groups = []
|
||||
shapes = []
|
||||
start_idx = 0
|
||||
|
||||
for count, end_idx in enumerate(split_idx):
|
||||
points = all_points[start_idx:end_idx+1].contiguous().float()
|
||||
|
||||
if points.shape[0] <= 2: # we need at least 2 points for a line
|
||||
continue
|
||||
|
||||
num_ctrl_pts = th.zeros(points.shape[0] - 1, dtype=th.int32)
|
||||
width = th.ones(1)
|
||||
path = pydiffvg.Path(
|
||||
num_control_points=num_ctrl_pts, points=points,
|
||||
stroke_width=width, is_closed=False)
|
||||
|
||||
start_idx = end_idx+1
|
||||
shapes.append(path)
|
||||
|
||||
color = th.ones(4, 1)
|
||||
path_group = pydiffvg.ShapeGroup(
|
||||
shape_ids=th.tensor([len(shapes) - 1]),
|
||||
fill_color=None,
|
||||
stroke_color=color)
|
||||
shape_groups.append(path_group)
|
||||
|
||||
# Rasterize
|
||||
if shapes:
|
||||
# draw only if there are shapes
|
||||
out.append(render(canvas_size, canvas_size, shapes, shape_groups, samples=2))
|
||||
else:
|
||||
out.append(th.zeros(canvas_size, canvas_size, 4,
|
||||
device=strokes.device))
|
||||
|
||||
return th.stack(out, 0).permute(0, 3, 1, 2)[:, :3].contiguous()
|
||||
|
||||
|
||||
def line_render(all_points, all_widths, all_alphas, force_cpu=True,
|
||||
canvas_size=32, colors=None):
|
||||
dev = all_points.device
|
||||
if force_cpu:
|
||||
all_points = all_points.to("cpu")
|
||||
all_widths = all_widths.to("cpu")
|
||||
all_alphas = all_alphas.to("cpu")
|
||||
|
||||
if colors is not None:
|
||||
colors = colors.to("cpu")
|
||||
|
||||
all_points = 0.5*(all_points + 1.0) * canvas_size
|
||||
|
||||
eps = 1e-4
|
||||
all_points = all_points + eps*th.randn_like(all_points)
|
||||
|
||||
bs, num_segments, _, _ = all_points.shape
|
||||
n_out = 3 if colors is not None else 1
|
||||
output = th.zeros(bs, n_out, canvas_size, canvas_size,
|
||||
device=all_points.device)
|
||||
|
||||
scenes = []
|
||||
for k in range(bs):
|
||||
shapes = []
|
||||
shape_groups = []
|
||||
for p in range(num_segments):
|
||||
points = all_points[k, p].contiguous().cpu()
|
||||
num_ctrl_pts = th.zeros(1, dtype=th.int32)
|
||||
width = all_widths[k, p].cpu()
|
||||
alpha = all_alphas[k, p].cpu()
|
||||
if colors is not None:
|
||||
color = colors[k, p]
|
||||
else:
|
||||
color = th.ones(3, device=alpha.device)
|
||||
|
||||
color = th.cat([color, alpha.view(1,)])
|
||||
|
||||
path = pydiffvg.Path(
|
||||
num_control_points=num_ctrl_pts, points=points,
|
||||
stroke_width=width, is_closed=False)
|
||||
shapes.append(path)
|
||||
path_group = pydiffvg.ShapeGroup(
|
||||
shape_ids=th.tensor([len(shapes) - 1]),
|
||||
fill_color=None,
|
||||
stroke_color=color)
|
||||
shape_groups.append(path_group)
|
||||
|
||||
# Rasterize
|
||||
scenes.append((canvas_size, canvas_size, shapes, shape_groups))
|
||||
raster = render(canvas_size, canvas_size, shapes, shape_groups,
|
||||
samples=2)
|
||||
raster = raster.permute(2, 0, 1).view(4, canvas_size, canvas_size)
|
||||
|
||||
alpha = raster[3:4]
|
||||
if colors is not None: # color output
|
||||
image = raster[:3]
|
||||
alpha = alpha.repeat(3, 1, 1)
|
||||
else:
|
||||
image = raster[:1]
|
||||
|
||||
# alpha compositing
|
||||
image = image*alpha
|
||||
output[k] = image
|
||||
|
||||
output = output.to(dev)
|
||||
|
||||
return output, scenes
|
||||
|
||||
|
||||
def bezier_render(all_points, all_widths, all_alphas, force_cpu=True,
|
||||
canvas_size=32, colors=None):
|
||||
dev = all_points.device
|
||||
if force_cpu:
|
||||
all_points = all_points.to("cpu")
|
||||
all_widths = all_widths.to("cpu")
|
||||
all_alphas = all_alphas.to("cpu")
|
||||
|
||||
if colors is not None:
|
||||
colors = colors.to("cpu")
|
||||
|
||||
all_points = 0.5*(all_points + 1.0) * canvas_size
|
||||
|
||||
eps = 1e-4
|
||||
all_points = all_points + eps*th.randn_like(all_points)
|
||||
|
||||
bs, num_strokes, num_pts, _ = all_points.shape
|
||||
num_segments = (num_pts - 1) // 3
|
||||
n_out = 3 if colors is not None else 1
|
||||
output = th.zeros(bs, n_out, canvas_size, canvas_size,
|
||||
device=all_points.device)
|
||||
|
||||
scenes = []
|
||||
for k in range(bs):
|
||||
shapes = []
|
||||
shape_groups = []
|
||||
for p in range(num_strokes):
|
||||
points = all_points[k, p].contiguous().cpu()
|
||||
# bezier
|
||||
num_ctrl_pts = th.zeros(num_segments, dtype=th.int32) + 2
|
||||
width = all_widths[k, p].cpu()
|
||||
alpha = all_alphas[k, p].cpu()
|
||||
if colors is not None:
|
||||
color = colors[k, p]
|
||||
else:
|
||||
color = th.ones(3, device=alpha.device)
|
||||
|
||||
color = th.cat([color, alpha.view(1,)])
|
||||
|
||||
path = pydiffvg.Path(
|
||||
num_control_points=num_ctrl_pts, points=points,
|
||||
stroke_width=width, is_closed=False)
|
||||
shapes.append(path)
|
||||
path_group = pydiffvg.ShapeGroup(
|
||||
shape_ids=th.tensor([len(shapes) - 1]),
|
||||
fill_color=None,
|
||||
stroke_color=color)
|
||||
shape_groups.append(path_group)
|
||||
|
||||
# Rasterize
|
||||
scenes.append((canvas_size, canvas_size, shapes, shape_groups))
|
||||
raster = render(canvas_size, canvas_size, shapes, shape_groups,
|
||||
samples=2)
|
||||
raster = raster.permute(2, 0, 1).view(4, canvas_size, canvas_size)
|
||||
|
||||
alpha = raster[3:4]
|
||||
if colors is not None: # color output
|
||||
image = raster[:3]
|
||||
alpha = alpha.repeat(3, 1, 1)
|
||||
else:
|
||||
image = raster[:1]
|
||||
|
||||
# alpha compositing
|
||||
image = image*alpha
|
||||
output[k] = image
|
||||
|
||||
output = output.to(dev)
|
||||
|
||||
return output, scenes
|
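A small rasterization sketch for `line_render` (not part of the commit; the inputs are hypothetical, the call mirrors `VectorGenerator` in models.py, and `pydiffvg` must be available):

    import torch as th

    import rendering

    bs, num_strokes = 1, 3
    points = th.rand(bs, num_strokes, 2, 2) * 2 - 1  # endpoints in [-1, 1]
    widths = th.full((bs, num_strokes), 1.5)
    alphas = th.ones(bs, num_strokes)

    raster, scenes = rendering.line_render(points, widths, alphas,
                                           canvas_size=64)
    # raster: (1, 1, 64, 64) grayscale in [0, 1];
    # each scene can be saved with pydiffvg.save_svg("out.svg", *scenes[0])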

apps/generative_models/sketch_rnn.py (new executable file, +461)
@@ -0,0 +1,461 @@
|
||||
#!/usr/bin/env python
|
||||
"""Train a Sketch-RNN."""
|
||||
import argparse
|
||||
from enum import Enum
|
||||
import os
|
||||
import wget
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from torch.utils.data import DataLoader
|
||||
import torchvision.datasets as dset
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
import ttools
|
||||
import ttools.interfaces
|
||||
from ttools.modules import networks
|
||||
|
||||
import pydiffvg
|
||||
|
||||
import rendering
|
||||
import losses
|
||||
import data
|
||||
|
||||
LOG = ttools.get_logger(__name__)
|
||||
|
||||
|
||||
BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
|
||||
OUTPUT = os.path.join(BASE_DIR, "results", "sketch_rnn_diffvg")
|
||||
OUTPUT_BASELINE = os.path.join(BASE_DIR, "results", "sketch_rnn")
|
||||
|
||||
|
||||
class SketchRNN(th.nn.Module):
|
||||
class Encoder(th.nn.Module):
|
||||
def __init__(self, hidden_size=512, dropout=0.9, zdim=128,
|
||||
num_layers=1):
|
||||
super(SketchRNN.Encoder, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.zdim = zdim
|
||||
|
||||
self.lstm = th.nn.LSTM(5, hidden_size, num_layers=self.num_layers,
|
||||
dropout=dropout, bidirectional=True,
|
||||
batch_first=True)
|
||||
|
||||
# bidirectional model -> *2
|
||||
self.mu_predictor = th.nn.Linear(2*hidden_size, zdim)
|
||||
self.sigma_predictor = th.nn.Linear(2*hidden_size, zdim)
|
||||
|
||||
def forward(self, sequences, hidden_and_cell=None):
|
||||
bs = sequences.shape[0]
|
||||
if hidden_and_cell is None:
|
||||
hidden = th.zeros(self.num_layers*2, bs, self.hidden_size).to(
|
||||
sequences.device)
|
||||
cell = th.zeros(self.num_layers*2, bs, self.hidden_size).to(
|
||||
sequences.device)
|
||||
hidden_and_cell = (hidden, cell)
|
||||
|
||||
out, hidden_and_cell = self.lstm(sequences, hidden_and_cell)
|
||||
hidden = hidden_and_cell[0]
|
||||
|
||||
# Concat the forward/backward states
|
||||
fc_input = th.cat([hidden[0], hidden[1]], 1)
|
||||
|
||||
# VAE params
|
||||
mu = self.mu_predictor(fc_input)
|
||||
log_sigma = self.sigma_predictor(fc_input)
|
||||
|
||||
# Sample a latent vector
|
||||
sigma = th.exp(log_sigma/2.0)
|
||||
z0 = th.randn(self.zdim, device=mu.device)
|
||||
z = mu + sigma*z0
|
||||
|
||||
# KL divergence needs mu/sigma
|
||||
return z, mu, log_sigma
|
||||
|
||||
class Decoder(th.nn.Module):
|
||||
"""
|
||||
The decoder outputs a sequence where each time step models (dx, dy) as
|
||||
a mixture of `num_gaussians` 2D Gaussians and the state triplet is a
|
||||
categorical distribution.
|
||||
|
||||
The model outputs at each time step:
|
||||
- 5 parameters for each Gaussian: mu_x, mu_y, sigma_x, sigma_y,
|
||||
rho_xy
|
||||
- 1 logit for each Gaussian (the mixture weight)
|
||||
- 3 logits for the state triplet probabilities
|
||||
"""
|
||||
def __init__(self, hidden_size=512, dropout=0.9, zdim=128,
|
||||
num_layers=1, num_gaussians=20):
|
||||
super(SketchRNN.Decoder, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.zdim = zdim
|
||||
self.num_gaussians = num_gaussians
|
||||
|
||||
# Maps the latent vector to an initial cell/hidden vector
|
||||
self.hidden_cell_predictor = th.nn.Linear(zdim, 2*hidden_size)
|
||||
|
||||
self.lstm = th.nn.LSTM(
|
||||
5 + zdim, hidden_size,
|
||||
num_layers=self.num_layers, dropout=dropout,
|
||||
batch_first=True)
|
||||
|
||||
self.parameters_predictor = th.nn.Linear(
|
||||
hidden_size, num_gaussians + 5*num_gaussians + 3)
|
||||
|
||||
def forward(self, inputs, z, hidden_and_cell=None):
|
||||
# Every step in the sequence takes the latent vector as input so we
|
||||
# replicate it here
|
||||
expanded_z = z.unsqueeze(1).repeat(1, inputs.shape[1], 1)
|
||||
inputs = th.cat([inputs, expanded_z], 2)
|
||||
|
||||
bs, steps = inputs.shape[:2]
|
||||
if hidden_and_cell is None:
|
||||
# Initialize from latent vector
|
||||
hidden_and_cell = self.hidden_cell_predictor(th.tanh(z))
|
||||
hidden = hidden_and_cell[:, :self.hidden_size]
|
||||
hidden = hidden.unsqueeze(0).contiguous()
|
||||
cell = hidden_and_cell[:, self.hidden_size:]
|
||||
cell = cell.unsqueeze(0).contiguous()
|
||||
hidden_and_cell = (hidden, cell)
|
||||
|
||||
outputs, hidden_and_cell = self.lstm(inputs, hidden_and_cell)
|
||||
hidden, cell = hidden_and_cell
|
||||
|
||||
# if self.training:
|
||||
# At train time we want parameters for each time step
|
||||
outputs = outputs.reshape(bs*steps, self.hidden_size)
|
||||
params = self.parameters_predictor(outputs).view(bs, steps, -1)
|
||||
|
||||
pen_logits = params[..., -3:]
|
||||
gaussian_params = params[..., :-3]
|
||||
mixture_logits = gaussian_params[..., :self.num_gaussians]
|
||||
gaussian_params = gaussian_params[..., self.num_gaussians:].view(
|
||||
bs, steps, self.num_gaussians, -1)
|
||||
|
||||
return pen_logits, mixture_logits, gaussian_params, hidden_and_cell
|
||||
|
||||
def __init__(self, zdim=128, num_gaussians=20, encoder_dim=256,
|
||||
decoder_dim=512):
|
||||
super(SketchRNN, self).__init__()
|
||||
self.encoder = SketchRNN.Encoder(zdim=zdim, hidden_size=encoder_dim)
|
||||
self.decoder = SketchRNN.Decoder(zdim=zdim, hidden_size=decoder_dim,
|
||||
num_gaussians=num_gaussians)
|
||||
|
||||
def forward(self, sequences):
|
||||
# Encode the sequences as latent vectors
|
||||
# We skip the first time step since it is the same for all sequences:
|
||||
# (0, 0, 1, 0, 0)
|
||||
z, mu, log_sigma = self.encoder(sequences[:, 1:])
|
||||
|
||||
# Decode the latent vector into a model sequence
|
||||
# Do not process the last time step (it is an end-of-sequence token)
|
||||
pen_logits, mixture_logits, gaussian_params, hidden_and_cell = \
|
||||
self.decoder(sequences[:, :-1], z)
|
||||
|
||||
return {
|
||||
"pen_logits": pen_logits,
|
||||
"mixture_logits": mixture_logits,
|
||||
"gaussian_params": gaussian_params,
|
||||
"z": z,
|
||||
"mu": mu,
|
||||
"log_sigma": log_sigma,
|
||||
"hidden_and_cell": hidden_and_cell,
|
||||
}
|
||||
|
||||
def sample(self, sequences, temperature=1.0):
|
||||
# Compute a latent vector conditioned on a real sequence
|
||||
z, _, _ = self.encoder(sequences[:, 1:])
|
||||
|
||||
start_of_seq = sequences[:, :1]
|
||||
|
||||
max_steps = sequences.shape[1] - 1 # last step is an end-of-seq token
|
||||
|
||||
output_sequences = th.zeros_like(sequences)
|
||||
output_sequences[:, 0] = start_of_seq.squeeze(1)
|
||||
|
||||
current_input = start_of_seq
|
||||
hidden_and_cell = None
|
||||
for step in range(max_steps):
|
||||
pen_logits, mixture_logits, gaussian_params, hidden_and_cell = \
|
||||
self.decoder(current_input, z, hidden_and_cell=hidden_and_cell)
|
||||
|
||||
# Pen and displacement state for the next step
|
||||
next_state = th.zeros_like(current_input)
|
||||
|
||||
# Adjust temperature to control randomness
|
||||
mixture_logits = mixture_logits*temperature
|
||||
pen_logits = pen_logits*temperature
|
||||
|
||||
# Select one of 3 pen states
|
||||
pen_distrib = \
|
||||
th.distributions.categorical.Categorical(logits=pen_logits)
|
||||
pen_state = pen_distrib.sample()
|
||||
|
||||
# One-hot encoding of the state
|
||||
next_state[:, :, 2:].scatter_(2, pen_state.unsqueeze(-1),
|
||||
th.ones_like(next_state[:, :, 2:]))
|
||||
|
||||
# Select one of the Gaussians from the mixture
|
||||
mixture_distrib = \
|
||||
th.distributions.categorical.Categorical(logits=mixture_logits)
|
||||
mixture_idx = mixture_distrib.sample()
|
||||
|
||||
# select the Gaussian parameter
|
||||
mixture_idx = mixture_idx.unsqueeze(-1).unsqueeze(-1)
|
||||
mixture_idx = mixture_idx.repeat(1, 1, 1, 5)
|
||||
params = th.gather(gaussian_params, 2, mixture_idx).squeeze(2)
|
||||
|
||||
# Sample a point from the selected Gaussian
|
||||
mu = params[..., :2]
|
||||
sigma_x = params[..., 2].exp()
|
||||
sigma_y = params[..., 3].exp()
|
||||
rho_xy = th.tanh(params[..., 4])
|
||||
cov = th.zeros(params.shape[0], params.shape[1], 2, 2,
|
||||
device=params.device)
|
||||
cov[..., 0, 0] = sigma_x.pow(2)*temperature
|
||||
cov[..., 1, 1] = sigma_y.pow(2)*temperature
|
||||
cov[..., 1, 0] = sigma_x*sigma_y*rho_xy*temperature
|
||||
point_distrib = \
|
||||
th.distributions.multivariate_normal.MultivariateNormal(
|
||||
mu, scale_tril=cov)
|
||||
point = point_distrib.sample()
|
||||
next_state[:, :, :2] = point
|
||||
|
||||
# Commit step to output
|
||||
output_sequences[:, step + 1] = next_state.squeeze(1)
|
||||
|
||||
# Prepare next recurrent step
|
||||
current_input = next_state
|
||||
|
||||
return output_sequences
|
||||
|
||||
|
||||
class SketchRNNCallback(ttools.callbacks.ImageDisplayCallback):
|
||||
"""Simple callback that visualize images."""
|
||||
def visualized_image(self, batch, step_data, is_val=False):
|
||||
if not is_val:
|
||||
# No need to render training data
|
||||
return None
|
||||
|
||||
with th.no_grad():
|
||||
# only display the first n drawings
|
||||
n = 8
|
||||
batch = batch[:n]
|
||||
|
||||
out_im = rendering.stroke2diffvg(step_data["sample"][:n])
|
||||
im = rendering.stroke2diffvg(batch)
|
||||
im = th.cat([im, out_im], 2)
|
||||
|
||||
return im
|
||||
|
||||
def caption(self, batch, step_data, is_val=False):
|
||||
if is_val:
|
||||
return "top: truth, bottom: sample"
|
||||
else:
|
||||
return "top: truth, bottom: sample"
|
||||
|
||||
|
||||
class Interface(ttools.ModelInterface):
|
||||
def __init__(self, model, lr=1e-3, lr_decay=0.9999,
|
||||
kl_weight=0.5, kl_min_weight=0.01, kl_decay=0.99995,
|
||||
device="cpu", grad_clip=1.0, sampling_temperature=0.4):
|
||||
super(Interface, self).__init__()
|
||||
self.grad_clip = grad_clip
|
||||
self.sampling_temperature = sampling_temperature
|
||||
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.model.to(self.device)
|
||||
self.enc_opt = th.optim.Adam(self.model.encoder.parameters(), lr=lr)
|
||||
self.dec_opt = th.optim.Adam(self.model.decoder.parameters(), lr=lr)
|
||||
|
||||
self.kl_weight = kl_weight
|
||||
self.kl_min_weight = kl_min_weight
|
||||
self.kl_decay = kl_decay
|
||||
self.kl_loss = losses.KLDivergence()
|
||||
|
||||
self.schedulers = [
|
||||
th.optim.lr_scheduler.ExponentialLR(self.enc_opt, lr_decay),
|
||||
th.optim.lr_scheduler.ExponentialLR(self.dec_opt, lr_decay),
|
||||
]
|
||||
|
||||
self.reconstruction_loss = losses.GaussianMixtureReconstructionLoss()
|
||||
|
||||
def optimizers(self):
|
||||
return [self.enc_opt, self.dec_opt]
|
||||
|
||||
def training_step(self, batch):
|
||||
batch = batch.to(self.device)
|
||||
out = self.model(batch)
|
||||
|
||||
kl_loss = self.kl_loss(
|
||||
out["mu"], out["log_sigma"])
|
||||
|
||||
# The target to predict is the next sequence step
|
||||
targets = batch[:, 1:].to(self.device)
|
||||
|
||||
# Scale the KL divergence weight
|
||||
try:
|
||||
state = self.enc_opt.state_dict()["param_groups"][0]["params"][0]
|
||||
optim_step = self.enc_opt.state_dict()["state"][state]["step"]
|
||||
except KeyError:
|
||||
optim_step = 0 # no step taken yet
|
||||
kl_scaling = 1.0 - (1.0 -
|
||||
self.kl_min_weight)*(self.kl_decay**optim_step)
|
||||
kl_weight = self.kl_weight * kl_scaling
|
||||
|
||||
reconstruction_loss = self.reconstruction_loss(
|
||||
out["pen_logits"], out["mixture_logits"],
|
||||
out["gaussian_params"], targets)
|
||||
loss = kl_loss*kl_weight + reconstruction_loss
|
||||
|
||||
self.enc_opt.zero_grad()
|
||||
self.dec_opt.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
# clip gradients
|
||||
enc_nrm = th.nn.utils.clip_grad_norm_(
|
||||
self.model.encoder.parameters(), self.grad_clip)
|
||||
dec_nrm = th.nn.utils.clip_grad_norm_(
|
||||
self.model.decoder.parameters(), self.grad_clip)
|
||||
|
||||
if enc_nrm > self.grad_clip:
|
||||
LOG.debug("Clipped encoder gradient (%.5f) to %.2f",
|
||||
enc_nrm, self.grad_clip)
|
||||
|
||||
if dec_nrm > self.grad_clip:
|
||||
LOG.debug("Clipped decoder gradient (%.5f) to %.2f",
|
||||
dec_nrm, self.grad_clip)
|
||||
|
||||
self.enc_opt.step()
|
||||
self.dec_opt.step()
|
||||
|
||||
return {
|
||||
"loss": loss.item(),
|
||||
"kl_loss": kl_loss.item(),
|
||||
"kl_weight": kl_weight,
|
||||
"recons_loss": reconstruction_loss.item(),
|
||||
"lr": self.enc_opt.param_groups[0]["lr"],
|
||||
}
|
||||
|
||||
def init_validation(self):
|
||||
return dict(sample=None)
|
||||
|
||||
def validation_step(self, batch, running_data):
|
||||
# Switch to eval mode for dropout, batchnorm, etc
|
||||
self.model.eval()
|
||||
with th.no_grad():
|
||||
sample = self.model.sample(
|
||||
batch.to(self.device), temperature=self.sampling_temperature)
|
||||
running_data["sample"] = sample
|
||||
self.model.train()
|
||||
return running_data
|
||||
|
||||
|
||||
def train(args):
|
||||
th.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
dataset = data.QuickDrawDataset(args.dataset)
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_size=args.bs, num_workers=4, shuffle=True,
|
||||
pin_memory=False)
|
||||
|
||||
val_dataset = [s for idx, s in enumerate(dataset) if idx < 8]
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset, batch_size=8, num_workers=4, shuffle=False,
|
||||
pin_memory=False)
|
||||
|
||||
model_params = {
|
||||
"zdim": args.zdim,
|
||||
"num_gaussians": args.num_gaussians,
|
||||
"encoder_dim": args.encoder_dim,
|
||||
"decoder_dim": args.decoder_dim,
|
||||
}
|
||||
model = SketchRNN(**model_params)
|
||||
model.train()
|
||||
|
||||
device = "cpu"
|
||||
if th.cuda.is_available():
|
||||
device = "cuda"
|
||||
LOG.info("Using CUDA")
|
||||
|
||||
interface = Interface(model, lr=args.lr, lr_decay=args.lr_decay,
|
||||
kl_decay=args.kl_decay, kl_weight=args.kl_weight,
|
||||
sampling_temperature=args.sampling_temperature,
|
||||
device=device)
|
||||
|
||||
chkpt = OUTPUT_BASELINE
|
||||
env_name = "sketch_rnn"
|
||||
|
||||
# Resume from checkpoint, if any
|
||||
checkpointer = ttools.Checkpointer(
|
||||
chkpt, model, meta=model_params,
|
||||
optimizers=interface.optimizers(),
|
||||
schedulers=interface.schedulers)
|
||||
extras, meta = checkpointer.load_latest()
|
||||
epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0
|
||||
|
||||
if meta is not None and meta != model_params:
|
||||
LOG.info("Checkpoint's metaparams differ "
|
||||
"from CLI, aborting: %s and %s", meta, model_params)
|
||||
|
||||
trainer = ttools.Trainer(interface)
|
||||
|
||||
# Add callbacks
|
||||
losses = ["loss", "kl_loss", "recons_loss"]
|
||||
training_debug = ["lr", "kl_weight"]
|
||||
trainer.add_callback(ttools.callbacks.ProgressBarCallback(
|
||||
keys=losses, val_keys=None))
|
||||
trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
|
||||
keys=losses, val_keys=None, env=env_name, port=args.port))
|
||||
trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
|
||||
keys=training_debug, smoothing=0, val_keys=None, env=env_name,
|
||||
port=args.port))
|
||||
trainer.add_callback(ttools.callbacks.CheckpointingCallback(
|
||||
checkpointer, max_files=2, interval=600, max_epochs=10))
|
||||
trainer.add_callback(
|
||||
ttools.callbacks.LRSchedulerCallback(interface.schedulers))
|
||||
|
||||
trainer.add_callback(SketchRNNCallback(
|
||||
env=env_name, win="samples", port=args.port, frequency=args.freq))
|
||||
|
||||
# Start training
|
||||
trainer.train(dataloader, starting_epoch=epoch,
|
||||
val_dataloader=val_dataloader,
|
||||
num_epochs=args.num_epochs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--dataset", default="cat.npz")
|
||||
|
||||
# Training params
|
||||
parser.add_argument("--bs", type=int, default=100)
|
||||
parser.add_argument("--num_epochs", type=int, default=10000)
|
||||
parser.add_argument("--lr", type=float, default=1e-4)
|
||||
parser.add_argument("--lr_decay", type=float, default=0.9999)
|
||||
parser.add_argument("--kl_weight", type=float, default=0.5)
|
||||
parser.add_argument("--kl_decay", type=float, default=0.99995)
|
||||
|
||||
# Model configuration
|
||||
parser.add_argument("--zdim", type=int, default=128)
|
||||
parser.add_argument("--num_gaussians", type=int, default=20)
|
||||
parser.add_argument("--encoder_dim", type=int, default=256)
|
||||
parser.add_argument("--decoder_dim", type=int, default=512)
|
||||
|
||||
parser.add_argument("--sampling_temperature", type=float, default=0.4,
|
||||
help="controls sampling randomness. "
|
||||
"0.0: deterministic, 1.0: unchanged")
|
||||
|
||||
# Viz params
|
||||
parser.add_argument("--freq", type=int, default=100)
|
||||
parser.add_argument("--port", type=int, default=5000)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
pydiffvg.set_use_gpu(th.cuda.is_available())
|
||||
|
||||
train(args)
|
524
apps/generative_models/sketch_vae.py
Executable file
@@ -0,0 +1,524 @@
|
||||
#!/usr/bin/env python
|
||||
"""Train a Sketch-VAE."""
|
||||
import argparse
|
||||
from enum import Enum
|
||||
import os
|
||||
import wget
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from torch.utils.data import DataLoader
|
||||
import torchvision.datasets as dset
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
import ttools
|
||||
import ttools.interfaces
|
||||
from ttools.modules import networks
|
||||
|
||||
import rendering
|
||||
import losses
|
||||
import modules
|
||||
import data
|
||||
|
||||
import pydiffvg
|
||||
|
||||
LOG = ttools.get_logger(__name__)
|
||||
|
||||
|
||||
BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
|
||||
OUTPUT = os.path.join(BASE_DIR, "results")
|
||||
|
||||
|
||||
class SketchVAE(th.nn.Module):
|
||||
class ImageEncoder(th.nn.Module):
|
||||
def __init__(self, image_size=64, width=64, zdim=128):
|
||||
super(SketchVAE.ImageEncoder, self).__init__()
|
||||
self.zdim = zdim
|
||||
|
||||
self.net = th.nn.Sequential(
|
||||
th.nn.Conv2d(4, width, 5, padding=2),
|
||||
th.nn.InstanceNorm2d(width),
|
||||
th.nn.ReLU(inplace=True),
|
||||
# 64x64
|
||||
|
||||
th.nn.Conv2d(width, width, 5, padding=2),
|
||||
th.nn.InstanceNorm2d(width),
|
||||
th.nn.ReLU(inplace=True),
|
||||
# 64x64
|
||||
|
||||
th.nn.Conv2d(width, 2*width, 5, stride=1, padding=2),
|
||||
th.nn.InstanceNorm2d(2*width),
|
||||
th.nn.ReLU(inplace=True),
|
||||
# 64x64
|
||||
|
||||
th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2),
|
||||
th.nn.InstanceNorm2d(2*width),
|
||||
th.nn.ReLU(inplace=True),
|
||||
# 32x32
|
||||
|
||||
th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2),
|
||||
th.nn.InstanceNorm2d(2*width),
|
||||
th.nn.ReLU(inplace=True),
|
||||
# 16x16
|
||||
|
||||
th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2),
|
||||
th.nn.InstanceNorm2d(2*width),
|
||||
th.nn.ReLU(inplace=True),
|
||||
# 8x8
|
||||
|
||||
th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2),
|
||||
th.nn.InstanceNorm2d(2*width),
|
||||
th.nn.ReLU(inplace=True),
|
||||
# 4x4
|
||||
|
||||
modules.Flatten(),
|
||||
th.nn.Linear(4*4*2*width, 2*zdim)
|
||||
)
|
||||
|
||||
def forward(self, images):
|
||||
features = self.net(images)
|
||||
|
||||
# VAE params
|
||||
mu = features[:, :self.zdim]
|
||||
log_sigma = features[:, self.zdim:]
|
||||
|
||||
# Sample a latent vector
|
||||
sigma = th.exp(log_sigma/2.0)
|
||||
z0 = th.randn_like(mu)  # one independent noise sample per batch element
|
||||
z = mu + sigma*z0
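# Reparameterization trick: z = mu + sigma*eps, eps ~ N(0, I), keeps the
# sampling step differentiable with respect to mu and log_sigma.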
|
||||
|
||||
# KL divergence needs mu/sigma
|
||||
return z, mu, log_sigma
|
||||
|
||||
class ImageDecoder(th.nn.Module):
|
||||
""""""
|
||||
def __init__(self, zdim=128, image_size=64, width=64):
|
||||
super(SketchVAE.ImageDecoder, self).__init__()
|
||||
self.zdim = zdim
|
||||
self.width = width
|
||||
|
||||
self.embedding = th.nn.Linear(zdim, 4*4*2*width)
|
||||
|
||||
self.net = th.nn.Sequential(
|
||||
th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2),
|
||||
th.nn.InstanceNorm2d(2*width),
|
||||
th.nn.ReLU(inplace=True),
|
||||
# 8x8
|
||||
|
||||
th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2),
|
||||
th.nn.InstanceNorm2d(2*width),
|
||||
th.nn.ReLU(inplace=True),
|
||||
# 16x16
|
||||
|
||||
th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2),
|
||||
th.nn.InstanceNorm2d(2*width),
|
||||
th.nn.ReLU(inplace=True),
|
||||
# 32x32
|
||||
|
||||
th.nn.Conv2d(2*width, 2*width, 5, padding=2, stride=1),
|
||||
th.nn.InstanceNorm2d(2*width),
|
||||
th.nn.ReLU(inplace=True),
|
||||
# 32x32
|
||||
|
||||
th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2),
|
||||
th.nn.InstanceNorm2d(2*width),
|
||||
th.nn.ReLU(inplace=True),
|
||||
# 64x64
|
||||
|
||||
th.nn.Conv2d(2*width, width, 5, padding=2, stride=1),
|
||||
th.nn.InstanceNorm2d(width),
|
||||
th.nn.ReLU(inplace=True),
|
||||
# 64x64
|
||||
|
||||
th.nn.ConvTranspose2d(width, width, 5, padding=2, stride=1),
|
||||
th.nn.InstanceNorm2d(width),
|
||||
th.nn.ReLU(inplace=True),
|
||||
# 64x64
|
||||
|
||||
th.nn.Conv2d(width, width, 5, padding=2, stride=1),
|
||||
th.nn.InstanceNorm2d(width),
|
||||
th.nn.ReLU(inplace=True),
|
||||
# 64x64
|
||||
|
||||
th.nn.Conv2d(width, 4, 5, padding=2, stride=1),
|
||||
)
|
||||
|
||||
def forward(self, z):
|
||||
bs = z.shape[0]
|
||||
im = self.embedding(z).view(bs, 2*self.width, 4, 4)
|
||||
out = self.net(im)
|
||||
return out
|
||||
|
||||
class SketchDecoder(th.nn.Module):
|
||||
"""
|
||||
The decoder outputs a sequence where each time step models (dx, dy,
|
||||
opacity).
|
||||
"""
|
||||
def __init__(self, sequence_length, hidden_size=512, dropout=0.9,
|
||||
zdim=128, num_layers=3):
|
||||
super(SketchVAE.SketchDecoder, self).__init__()
|
||||
self.sequence_length = sequence_length
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.zdim = zdim
|
||||
|
||||
# Maps the latent vector to an initial cell/hidden vector
|
||||
self.hidden_cell_predictor = th.nn.Linear(zdim, 2*hidden_size*num_layers)
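# 2*hidden_size*num_layers outputs: one initial hidden state and one initial
# cell state per LSTM layer (split apart and reshaped in forward()).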
|
||||
|
||||
self.lstm = th.nn.LSTM(
|
||||
zdim, hidden_size,
|
||||
num_layers=self.num_layers, dropout=dropout,
|
||||
batch_first=True)
|
||||
|
||||
self.dxdy_predictor = th.nn.Sequential(
|
||||
th.nn.Linear(hidden_size, 2),
|
||||
th.nn.Tanh(),
|
||||
)
|
||||
self.opacity_predictor = th.nn.Sequential(
|
||||
th.nn.Linear(hidden_size, 1),
|
||||
th.nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, z, hidden_and_cell=None):
|
||||
# Every step in the sequence takes the latent vector as input so we
|
||||
# replicate it here
|
||||
bs = z.shape[0]
|
||||
steps = self.sequence_length - 1 # no need to predict the start of sequence
|
||||
expanded_z = z.unsqueeze(1).repeat(1, steps, 1)
|
||||
|
||||
if hidden_and_cell is None:
|
||||
# Initialize from latent vector
|
||||
hidden_and_cell = self.hidden_cell_predictor(
|
||||
th.tanh(z))
|
||||
hidden = hidden_and_cell[:, :self.hidden_size*self.num_layers]
|
||||
hidden = hidden.view(-1, self.num_layers, self.hidden_size)
|
||||
hidden = hidden.permute(1, 0, 2).contiguous()
|
||||
# hidden = hidden.unsqueeze(1).contiguous()
|
||||
cell = hidden_and_cell[:, self.hidden_size*self.num_layers:]
|
||||
cell = cell.view(-1, self.num_layers, self.hidden_size)
|
||||
cell = cell.permute(1, 0, 2).contiguous()
|
||||
# cell = cell.unsqueeze(1).contiguous()
|
||||
hidden_and_cell = (hidden, cell)
|
||||
|
||||
outputs, hidden_and_cell = self.lstm(expanded_z, hidden_and_cell)
|
||||
hidden, cell = hidden_and_cell
|
||||
|
||||
dxdy = self.dxdy_predictor(
|
||||
outputs.reshape(bs*steps, self.hidden_size)).view(bs, steps, -1)
|
||||
|
||||
opacity = self.opacity_predictor(
|
||||
outputs.reshape(bs*steps, self.hidden_size)).view(bs, steps, -1)
|
||||
|
||||
strokes = th.cat([dxdy, opacity], -1)
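# strokes: (bs, steps, 3), one (dx, dy, opacity) triplet per time step, with
# dx/dy in [-1, 1] (tanh) and opacity in [0, 1] (sigmoid).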
|
||||
|
||||
return strokes
|
||||
|
||||
def __init__(self, sequence_length, zdim=128, image_size=64):
|
||||
super(SketchVAE, self).__init__()
|
||||
self.im_encoder = SketchVAE.ImageEncoder(
|
||||
zdim=zdim, image_size=image_size)
|
||||
self.im_decoder = SketchVAE.ImageDecoder(
|
||||
zdim=zdim, image_size=image_size)
|
||||
self.sketch_decoder = SketchVAE.SketchDecoder(
|
||||
sequence_length, zdim=zdim)
|
||||
|
||||
def forward(self, images):
|
||||
# Encode the images as latent vectors
|
||||
z, mu, log_sigma = self.im_encoder(images)
|
||||
decoded_im = self.im_decoder(z)
|
||||
decoded_sketch = self.sketch_decoder(z)
|
||||
|
||||
return {
|
||||
"decoded_im": decoded_im,
|
||||
"decoded_sketch": decoded_sketch,
|
||||
"z": z,
|
||||
"mu": mu,
|
||||
"log_sigma": log_sigma,
|
||||
}
|
||||
|
||||
|
||||
class SketchVAECallback(ttools.callbacks.ImageDisplayCallback):
|
||||
"""Simple callback that visualize images."""
|
||||
def visualized_image(self, batch, step_data, is_val=False):
|
||||
if is_val:
|
||||
return None
|
||||
|
||||
# only display the first n drawings
|
||||
n = 8
|
||||
gt = step_data["gt_image"][:n].detach()
|
||||
vae_im = step_data["vae_image"][:n].detach()
|
||||
sketch_im = step_data["sketch_image"][:n].detach()
|
||||
|
||||
rendering = th.cat([gt, vae_im, sketch_im], 2)
|
||||
rendering = th.clamp(rendering, 0, 1)
|
||||
alpha = rendering[:, 3:4]
|
||||
rendering = rendering[:, :3] * alpha
|
||||
|
||||
return rendering
|
||||
|
||||
def caption(self, batch, step_data, is_val=False):
|
||||
if is_val:
|
||||
return ""
|
||||
else:
|
||||
return "top: truth, middle: vae sample, output: rnn-output"
|
||||
|
||||
|
||||
|
||||
|
||||
class Interface(ttools.ModelInterface):
|
||||
def __init__(self, model, lr=1e-4, lr_decay=0.9999,
|
||||
kl_weight=0.5, kl_min_weight=0.01, kl_decay=0.99995,
|
||||
raster_resolution=64, absolute_coords=False,
|
||||
device="cpu", grad_clip=1.0):
|
||||
super(Interface, self).__init__()
|
||||
|
||||
self.grad_clip = grad_clip
|
||||
self.raster_resolution = raster_resolution
|
||||
self.absolute_coords = absolute_coords
|
||||
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.model.to(self.device)
|
||||
self.im_enc_opt = th.optim.Adam(
|
||||
self.model.im_encoder.parameters(), lr=lr)
|
||||
self.im_dec_opt = th.optim.Adam(
|
||||
self.model.im_decoder.parameters(), lr=lr)
|
||||
self.sketch_dec_opt = th.optim.Adam(
|
||||
self.model.sketch_decoder.parameters(), lr=lr)
|
||||
|
||||
self.kl_weight = kl_weight
|
||||
self.kl_min_weight = kl_min_weight
|
||||
self.kl_decay = kl_decay
|
||||
self.kl_loss = losses.KLDivergence()
|
||||
|
||||
self.schedulers = [
|
||||
th.optim.lr_scheduler.ExponentialLR(self.im_enc_opt, lr_decay),
|
||||
th.optim.lr_scheduler.ExponentialLR(self.im_dec_opt, lr_decay),
|
||||
th.optim.lr_scheduler.ExponentialLR(self.sketch_dec_opt, lr_decay),
|
||||
]
|
||||
|
||||
# include loss on alpha
|
||||
self.im_loss = losses.MultiscaleMSELoss(channels=4).to(self.device)
|
||||
|
||||
def optimizers(self):
|
||||
return [self.im_enc_opt, self.im_dec_opt, self.sketch_dec_opt]
|
||||
|
||||
def kl_scaling(self):
|
||||
# Scale the KL divergence weight
|
||||
try:
|
||||
state = self.im_enc_opt.state_dict()["param_groups"][0]["params"][0]
|
||||
optim_step = self.im_enc_opt.state_dict()["state"][state]["step"]
|
||||
except KeyError:
|
||||
optim_step = 0 # no step taken yet
|
||||
kl_scaling = 1.0 - (1.0 -
|
||||
self.kl_min_weight)*(self.kl_decay**optim_step)
|
||||
return kl_scaling
|
||||
|
||||
def training_step(self, batch):
|
||||
gt_strokes, gt_im = batch
|
||||
gt_strokes = gt_strokes.to(self.device)
|
||||
gt_im = gt_im.to(self.device)
|
||||
|
||||
out = self.model(gt_im)
|
||||
|
||||
kl_loss = self.kl_loss(
|
||||
out["mu"], out["log_sigma"])
|
||||
kl_weight = self.kl_weight * self.kl_scaling()
|
||||
|
||||
# add start of sequence
|
||||
sos = gt_strokes[:, :1]
|
||||
sketch = th.cat([sos, out["decoded_sketch"]], 1)
|
||||
|
||||
vae_im = out["decoded_im"]
|
||||
|
||||
# start = time.time()
|
||||
sketch_im = rendering.opacityStroke2diffvg(
|
||||
sketch, canvas_size=self.raster_resolution, debug=False,
|
||||
force_cpu=True, relative=not self.absolute_coords)
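# The predicted strokes are rasterized with diffvg, so the image-space loss
# below backpropagates through the renderer into the sketch decoder.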
|
||||
# elapsed = (time.time() - start)*1000
|
||||
# print("out rendering took %.2fms" % elapsed)
|
||||
|
||||
vae_im_loss = self.im_loss(vae_im, gt_im)
|
||||
sketch_im_loss = self.im_loss(sketch_im, gt_im)
|
||||
|
||||
# vae_im_loss = th.nn.functional.mse_loss(vae_im, gt_im)
|
||||
# sketch_im_loss = th.nn.functional.mse_loss(sketch_im, gt_im)
|
||||
|
||||
loss = vae_im_loss + kl_loss*kl_weight + sketch_im_loss
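# Total objective: raster VAE reconstruction, annealed KL regularizer, and
# reconstruction of the same target image from the diffvg-rendered strokes.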
|
||||
|
||||
self.im_enc_opt.zero_grad()
|
||||
self.im_dec_opt.zero_grad()
|
||||
self.sketch_dec_opt.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
# clip gradients
|
||||
enc_nrm = th.nn.utils.clip_grad_norm_(
|
||||
self.model.im_encoder.parameters(), self.grad_clip)
|
||||
dec_nrm = th.nn.utils.clip_grad_norm_(
|
||||
self.model.im_decoder.parameters(), self.grad_clip)
|
||||
sketch_dec_nrm = th.nn.utils.clip_grad_norm_(
|
||||
self.model.sketch_decoder.parameters(), self.grad_clip)
|
||||
|
||||
if enc_nrm > self.grad_clip:
|
||||
LOG.debug("Clipped encoder gradient (%.5f) to %.2f",
|
||||
enc_nrm, self.grad_clip)
|
||||
|
||||
if dec_nrm > self.grad_clip:
|
||||
LOG.debug("Clipped decoder gradient (%.5f) to %.2f",
|
||||
dec_nrm, self.grad_clip)
|
||||
|
||||
if sketch_dec_nrm > self.grad_clip:
|
||||
LOG.debug("Clipped sketch decoder gradient (%.5f) to %.2f",
|
||||
sketch_dec_nrm, self.grad_clip)
|
||||
|
||||
self.im_enc_opt.step()
|
||||
self.im_dec_opt.step()
|
||||
self.sketch_dec_opt.step()
|
||||
|
||||
return {
|
||||
"vae_image": vae_im,
|
||||
"sketch_image": sketch_im,
|
||||
"gt_image": gt_im,
|
||||
"loss": loss.item(),
|
||||
"vae_im_loss": vae_im_loss.item(),
|
||||
"sketch_im_loss": sketch_im_loss.item(),
|
||||
"kl_loss": kl_loss.item(),
|
||||
"kl_weight": kl_weight,
|
||||
"lr": self.im_enc_opt.param_groups[0]["lr"],
|
||||
}
|
||||
|
||||
def init_validation(self):
|
||||
return dict(sample=None)
|
||||
|
||||
def validation_step(self, batch, running_data):
|
||||
# Switch to eval mode for dropout, batchnorm, etc
|
||||
# self.model.eval()
|
||||
# with th.no_grad():
|
||||
# # sample = self.model.sample(
|
||||
# # batch.to(self.device), temperature=self.sampling_temperature)
|
||||
# # running_data["sample"] = sample
|
||||
# self.model.train()
|
||||
return running_data
|
||||
|
||||
|
||||
def train(args):
|
||||
th.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
dataset = data.FixedLengthQuickDrawDataset(
|
||||
args.dataset, max_seq_length=args.sequence_length,
|
||||
canvas_size=args.raster_resolution)
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_size=args.bs, num_workers=args.workers, shuffle=True)
|
||||
|
||||
# val_dataset = [s for idx, s in enumerate(dataset) if idx < 8]
|
||||
# val_dataloader = DataLoader(
|
||||
# val_dataset, batch_size=8, num_workers=4, shuffle=False)
|
||||
|
||||
val_dataloader = None
|
||||
|
||||
model_params = {
|
||||
"zdim": args.zdim,
|
||||
"sequence_length": args.sequence_length,
|
||||
"image_size": args.raster_resolution,
|
||||
# "encoder_dim": args.encoder_dim,
|
||||
# "decoder_dim": args.decoder_dim,
|
||||
}
|
||||
model = SketchVAE(**model_params)
|
||||
model.train()
|
||||
|
||||
LOG.info("Model parameters:\n%s", model_params)
|
||||
|
||||
device = "cpu"
|
||||
if th.cuda.is_available():
|
||||
device = "cuda"
|
||||
LOG.info("Using CUDA")
|
||||
|
||||
interface = Interface(model, raster_resolution=args.raster_resolution,
|
||||
lr=args.lr, lr_decay=args.lr_decay,
|
||||
kl_decay=args.kl_decay, kl_weight=args.kl_weight,
|
||||
absolute_coords=args.absolute_coordinates,
|
||||
device=device)
|
||||
|
||||
env_name = "sketch_vae"
|
||||
if args.custom_name is not None:
|
||||
env_name += "_" + args.custom_name
|
||||
|
||||
if args.absolute_coordinates:
|
||||
env_name += "_abs_coords"
|
||||
|
||||
chkpt = os.path.join(OUTPUT, env_name)
|
||||
|
||||
# Resume from checkpoint, if any
|
||||
checkpointer = ttools.Checkpointer(
|
||||
chkpt, model, meta=model_params,
|
||||
optimizers=interface.optimizers(),
|
||||
schedulers=interface.schedulers)
|
||||
extras, meta = checkpointer.load_latest()
|
||||
epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0
|
||||
|
||||
if meta is not None and meta != model_params:
|
||||
LOG.info("Checkpoint's metaparams differ "
|
||||
"from CLI, aborting: %s and %s", meta, model_params)
|
||||
|
||||
trainer = ttools.Trainer(interface)
|
||||
|
||||
# Add callbacks
|
||||
losses = ["loss", "kl_loss", "vae_im_loss", "sketch_im_loss"]
|
||||
training_debug = ["lr", "kl_weight"]
|
||||
trainer.add_callback(ttools.callbacks.ProgressBarCallback(
|
||||
keys=losses, val_keys=None))
|
||||
trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
|
||||
keys=losses, val_keys=None, env=env_name, port=args.port))
|
||||
trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
|
||||
keys=training_debug, smoothing=0, val_keys=None, env=env_name,
|
||||
port=args.port))
|
||||
trainer.add_callback(ttools.callbacks.CheckpointingCallback(
|
||||
checkpointer, max_files=2, interval=600, max_epochs=10))
|
||||
trainer.add_callback(
|
||||
ttools.callbacks.LRSchedulerCallback(interface.schedulers))
|
||||
|
||||
trainer.add_callback(SketchVAECallback(
|
||||
env=env_name, win="samples", port=args.port, frequency=args.freq))
|
||||
|
||||
# Start training
|
||||
trainer.train(dataloader, starting_epoch=epoch,
|
||||
val_dataloader=val_dataloader,
|
||||
num_epochs=args.num_epochs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--dataset", default="cat.npz")
|
||||
|
||||
parser.add_argument("--absolute_coordinates", action="store_true",
|
||||
default=False)
|
||||
|
||||
parser.add_argument("--custom_name")
|
||||
|
||||
# Training params
|
||||
parser.add_argument("--bs", type=int, default=1)
|
||||
parser.add_argument("--workers", type=int, default=0)
|
||||
parser.add_argument("--num_epochs", type=int, default=10000)
|
||||
parser.add_argument("--lr", type=float, default=1e-4)
|
||||
parser.add_argument("--lr_decay", type=float, default=0.9999)
|
||||
parser.add_argument("--kl_weight", type=float, default=0.5)
|
||||
parser.add_argument("--kl_decay", type=float, default=0.99995)
|
||||
|
||||
# Model configuration
|
||||
parser.add_argument("--zdim", type=int, default=128)
|
||||
parser.add_argument("--sequence_length", type=int, default=50)
|
||||
parser.add_argument("--raster_resolution", type=int, default=64)
|
||||
# parser.add_argument("--encoder_dim", type=int, default=256)
|
||||
# parser.add_argument("--decoder_dim", type=int, default=512)
|
||||
|
||||
# Viz params
|
||||
parser.add_argument("--freq", type=int, default=10)
|
||||
parser.add_argument("--port", type=int, default=5000)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
pydiffvg.set_use_gpu(False)
|
||||
|
||||
train(args)
|
489
apps/generative_models/train_gan.py
Executable file
@@ -0,0 +1,489 @@
|
||||
#!/usr/bin/env python
|
||||
"""Train a GAN.
|
||||
|
||||
Usage:
|
||||
|
||||
* Train a MNIST model:
|
||||
|
||||
`python train_gan.py`
|
||||
|
||||
* Train a Quickdraw model:
|
||||
|
||||
`python train_gan.py --task quickdraw`
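
* Useful flags: `--standard_gan` switches from the default WGAN-GP objective
  to the standard GAN loss, `--raster_only` trains only the raster baseline.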
|
||||
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import ttools
|
||||
import ttools.interfaces
|
||||
|
||||
import losses
|
||||
import data
|
||||
import models
|
||||
|
||||
import pydiffvg
|
||||
|
||||
LOG = ttools.get_logger(__name__)
|
||||
|
||||
|
||||
BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
|
||||
OUTPUT = os.path.join(BASE_DIR, "results")
|
||||
|
||||
|
||||
class Callback(ttools.callbacks.ImageDisplayCallback):
|
||||
"""Simple callback that visualize images."""
|
||||
def visualized_image(self, batch, step_data, is_val=False):
|
||||
if is_val:
|
||||
return
|
||||
|
||||
gen = step_data["gen_image"][:16].detach()
|
||||
ref = step_data["gt_image"][:16].detach()
|
||||
|
||||
# tensor to visualize, concatenate images
|
||||
vizdata = th.cat([ref, gen], 2)
|
||||
|
||||
vector = step_data["vector_image"]
|
||||
if vector is not None:
|
||||
vector = vector[:16].detach()
|
||||
vizdata = th.cat([vizdata, vector], 2)
|
||||
|
||||
vizdata = (vizdata + 1.0) * 0.5
|
||||
viz = th.clamp(vizdata, 0, 1)
|
||||
return viz
|
||||
|
||||
def caption(self, batch, step_data, is_val=False):
|
||||
if step_data["vector_image"] is not None:
|
||||
s = "top: real, middle: raster, bottom: vector"
|
||||
else:
|
||||
s = "top: real, bottom: fake"
|
||||
return s
|
||||
|
||||
|
||||
class Interface(ttools.ModelInterface):
|
||||
def __init__(self, generator, vect_generator,
|
||||
discriminator, vect_discriminator,
|
||||
lr=1e-4, lr_decay=0.9999,
|
||||
gradient_penalty=10,
|
||||
wgan_gp=False,
|
||||
raster_resolution=32, device="cpu", grad_clip=1.0):
|
||||
super(Interface, self).__init__()
|
||||
|
||||
self.wgan_gp = wgan_gp
|
||||
self.w_gradient_penalty = gradient_penalty
|
||||
|
||||
self.n_critic = 1
|
||||
if self.wgan_gp:
|
||||
self.n_critic = 5
|
||||
|
||||
self.grad_clip = grad_clip
|
||||
self.raster_resolution = raster_resolution
|
||||
|
||||
self.gen = generator
|
||||
self.vect_gen = vect_generator
|
||||
self.discrim = discriminator
|
||||
self.vect_discrim = vect_discriminator
|
||||
|
||||
self.device = device
|
||||
self.gen.to(self.device)
|
||||
self.discrim.to(self.device)
|
||||
|
||||
beta1 = 0.5
|
||||
beta2 = 0.9
|
||||
|
||||
self.gen_opt = th.optim.Adam(
|
||||
self.gen.parameters(), lr=lr, betas=(beta1, beta2))
|
||||
self.discrim_opt = th.optim.Adam(
|
||||
self.discrim.parameters(), lr=lr, betas=(beta1, beta2))
|
||||
|
||||
self.schedulers = [
|
||||
th.optim.lr_scheduler.ExponentialLR(self.gen_opt, lr_decay),
|
||||
th.optim.lr_scheduler.ExponentialLR(self.discrim_opt, lr_decay),
|
||||
]
|
||||
|
||||
self.optimizers = [self.gen_opt, self.discrim_opt]
|
||||
|
||||
if self.vect_gen is not None:
|
||||
assert self.vect_discrim is not None
|
||||
|
||||
self.vect_gen.to(self.device)
|
||||
self.vect_discrim.to(self.device)
|
||||
|
||||
self.vect_gen_opt = th.optim.Adam(
|
||||
self.vect_gen.parameters(), lr=lr, betas=(beta1, beta2))
|
||||
self.vect_discrim_opt = th.optim.Adam(
|
||||
self.vect_discrim.parameters(), lr=lr, betas=(beta1, beta2))
|
||||
|
||||
self.schedulers += [
|
||||
th.optim.lr_scheduler.ExponentialLR(self.vect_gen_opt,
|
||||
lr_decay),
|
||||
th.optim.lr_scheduler.ExponentialLR(self.vect_discrim_opt,
|
||||
lr_decay),
|
||||
]
|
||||
|
||||
self.optimizers += [self.vect_gen_opt, self.vect_discrim_opt]
|
||||
|
||||
# include loss on alpha
|
||||
self.im_loss = losses.MultiscaleMSELoss(channels=4).to(self.device)
|
||||
|
||||
self.iter = 0
|
||||
|
||||
self.cross_entropy = th.nn.BCEWithLogitsLoss()
|
||||
self.mse = th.nn.MSELoss()
|
||||
|
||||
def _gradient_penalty(self, discrim, fake, real):
|
||||
bs = real.size(0)
|
||||
epsilon = th.rand(bs, 1, 1, 1, device=real.device)
|
||||
epsilon = epsilon.expand_as(real)
|
||||
|
||||
interpolation = epsilon * real.data + (1 - epsilon) * fake.data
|
||||
interpolation = th.autograd.Variable(interpolation, requires_grad=True)
|
||||
|
||||
interpolation_logits = discrim(interpolation)
|
||||
grad_outputs = th.ones(interpolation_logits.size(), device=real.device)
|
||||
|
||||
gradients = th.autograd.grad(outputs=interpolation_logits,
|
||||
inputs=interpolation,
|
||||
grad_outputs=grad_outputs,
|
||||
create_graph=True, retain_graph=True)[0]
|
||||
|
||||
gradients = gradients.view(bs, -1)
|
||||
gradients_norm = th.sqrt(th.sum(gradients ** 2, dim=1) + 1e-12)
|
||||
|
||||
# [Thanh-Tung 2019] https://openreview.net/pdf?id=ByxPYjC5KQ
|
||||
return self.w_gradient_penalty * ((gradients_norm - 0) ** 2).mean()
|
||||
|
||||
# return self.w_gradient_penalty * ((gradients_norm - 1) ** 2).mean()
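# The active penalty above is zero-centered (pulls the gradient norm towards
# 0); the commented-out variant is the original WGAN-GP penalty, centered at 1.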
|
||||
|
||||
def _discriminator_step(self, discrim, opt, fake, real):
|
||||
"""Try to classify fake as 0 and real as 1."""
|
||||
|
||||
opt.zero_grad()
|
||||
|
||||
# no backprop to gen
|
||||
fake = fake.detach()
|
||||
|
||||
fake_pred = discrim(fake)
|
||||
real_pred = discrim(real)
|
||||
|
||||
if self.wgan_gp:
|
||||
gradient_penalty = self._gradient_penalty(discrim, fake, real)
|
||||
loss_d = fake_pred.mean() - real_pred.mean() + gradient_penalty
|
||||
gradient_penalty = gradient_penalty.item()
|
||||
else:
|
||||
fake_loss = self.cross_entropy(fake_pred, th.zeros_like(fake_pred))
|
||||
real_loss = self.cross_entropy(real_pred, th.ones_like(real_pred))
|
||||
# fake_loss = self.mse(fake_pred, th.zeros_like(fake_pred))
|
||||
# real_loss = self.mse(real_pred, th.ones_like(real_pred))
|
||||
loss_d = 0.5*(fake_loss + real_loss)
|
||||
gradient_penalty = None
|
||||
|
||||
loss_d.backward()
|
||||
nrm = th.nn.utils.clip_grad_norm_(
|
||||
discrim.parameters(), self.grad_clip)
|
||||
if nrm > self.grad_clip:
|
||||
LOG.debug("Clipped discriminator gradient (%.5f) to %.2f",
|
||||
nrm, self.grad_clip)
|
||||
|
||||
opt.step()
|
||||
|
||||
return loss_d.item(), gradient_penalty
|
||||
|
||||
def _generator_step(self, gen, discrim, opt, fake):
|
||||
"""Try to classify fake as 1."""
|
||||
|
||||
opt.zero_grad()
|
||||
|
||||
fake_pred = discrim(fake)
|
||||
|
||||
if self.wgan_gp:
|
||||
loss_g = -fake_pred.mean()
|
||||
else:
|
||||
loss_g = self.cross_entropy(fake_pred, th.ones_like(fake_pred))
|
||||
# loss_g = self.mse(fake_pred, th.ones_like(fake_pred))
|
||||
|
||||
loss_g.backward()
|
||||
|
||||
# clip gradients
|
||||
nrm = th.nn.utils.clip_grad_norm_(
|
||||
gen.parameters(), self.grad_clip)
|
||||
if nrm > self.grad_clip:
|
||||
LOG.debug("Clipped generator gradient (%.5f) to %.2f",
|
||||
nrm, self.grad_clip)
|
||||
|
||||
opt.step()
|
||||
|
||||
return loss_g.item()
|
||||
|
||||
def training_step(self, batch):
|
||||
im = batch
|
||||
im = im.to(self.device)
|
||||
|
||||
z = self.gen.sample_z(im.shape[0], device=self.device)
|
||||
|
||||
generated = self.gen(z)
|
||||
|
||||
vect_generated = None
|
||||
if self.vect_gen is not None:
|
||||
vect_generated = self.vect_gen(z)
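# Both generators share the same latent batch z, so the raster and vector
# samples correspond to the same codes in the visualization callback.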
|
||||
|
||||
loss_g = None
|
||||
loss_d = None
|
||||
loss_g_vect = None
|
||||
loss_d_vect = None
|
||||
|
||||
gp = None
|
||||
gp_vect = None
|
||||
|
||||
if self.iter < self.n_critic: # Discriminator update
|
||||
self.iter += 1
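# With WGAN-GP, n_critic = 5, so the critic takes several updates per
# generator update; otherwise the two updates simply alternate.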
|
||||
|
||||
loss_d, gp = self._discriminator_step(
|
||||
self.discrim, self.discrim_opt, generated, im)
|
||||
|
||||
if vect_generated is not None:
|
||||
loss_d_vect, gp_vect = self._discriminator_step(
|
||||
self.vect_discrim, self.vect_discrim_opt, vect_generated, im)
|
||||
|
||||
else: # Generator update
|
||||
self.iter = 0
|
||||
|
||||
loss_g = self._generator_step(
|
||||
self.gen, self.discrim, self.gen_opt, generated)
|
||||
|
||||
if vect_generated is not None:
|
||||
loss_g_vect = self._generator_step(
|
||||
self.vect_gen, self.vect_discrim, self.vect_gen_opt, vect_generated)
|
||||
|
||||
return {
|
||||
"loss_g": loss_g,
|
||||
"loss_d": loss_d,
|
||||
"loss_g_vect": loss_g_vect,
|
||||
"loss_d_vect": loss_d_vect,
|
||||
"gp": gp,
|
||||
"gp_vect": gp_vect,
|
||||
"gt_image": im,
|
||||
"gen_image": generated,
|
||||
"vector_image": vect_generated,
|
||||
"lr": self.gen_opt.param_groups[0]["lr"],
|
||||
}
|
||||
|
||||
def init_validation(self):
|
||||
return dict(sample=None)
|
||||
|
||||
def validation_step(self, batch, running_data):
|
||||
# Switch to eval mode for dropout, batchnorm, etc
|
||||
self.gen.eval()  # this interface stores the generator as `gen`, not `model`
|
||||
return running_data
|
||||
|
||||
|
||||
def train(args):
|
||||
th.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
color_output = False
|
||||
if args.task == "mnist":
|
||||
dataset = data.MNISTDataset(args.raster_resolution, train=True)
|
||||
elif args.task == "quickdraw":
|
||||
dataset = data.QuickDrawImageDataset(
|
||||
args.raster_resolution, train=True)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_size=args.bs, num_workers=args.workers, shuffle=True)
|
||||
|
||||
val_dataloader = None
|
||||
|
||||
model_params = {
|
||||
"zdim": args.zdim,
|
||||
"num_strokes": args.num_strokes,
|
||||
"imsize": args.raster_resolution,
|
||||
"stroke_width": args.stroke_width,
|
||||
"color_output": color_output,
|
||||
}
|
||||
gen = models.Generator(**model_params)
|
||||
gen.train()
|
||||
|
||||
discrim = models.Discriminator(color_output=color_output)
|
||||
discrim.train()
|
||||
|
||||
if args.raster_only:
|
||||
vect_gen = None
|
||||
vect_discrim = None
|
||||
else:
|
||||
if args.generator == "fc":
|
||||
vect_gen = models.VectorGenerator(**model_params)
|
||||
elif args.generator == "bezier_fc":
|
||||
vect_gen = models.BezierVectorGenerator(**model_params)
|
||||
elif args.generator in ["rnn"]:
|
||||
vect_gen = models.RNNVectorGenerator(**model_params)
|
||||
elif args.generator in ["chain_rnn"]:
|
||||
vect_gen = models.ChainRNNVectorGenerator(**model_params)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
vect_gen.train()
|
||||
|
||||
vect_discrim = models.Discriminator(color_output=color_output)
|
||||
vect_discrim.train()
|
||||
|
||||
LOG.info("Model parameters:\n%s", model_params)
|
||||
|
||||
device = "cpu"
|
||||
if th.cuda.is_available():
|
||||
device = "cuda"
|
||||
LOG.info("Using CUDA")
|
||||
|
||||
interface = Interface(gen, vect_gen, discrim, vect_discrim,
|
||||
raster_resolution=args.raster_resolution, lr=args.lr,
|
||||
wgan_gp=args.wgan_gp,
|
||||
lr_decay=args.lr_decay, device=device)
|
||||
|
||||
env_name = args.task + "_gan"
|
||||
|
||||
if args.raster_only:
|
||||
env_name += "_raster"
|
||||
else:
|
||||
env_name += "_vector"
|
||||
|
||||
env_name += "_" + args.generator
|
||||
|
||||
if args.wgan_gp:
|
||||
env_name += "_wgan"
|
||||
|
||||
chkpt = os.path.join(OUTPUT, env_name)
|
||||
|
||||
meta = {
|
||||
"model_params": model_params,
|
||||
"task": args.task,
|
||||
"generator": args.generator,
|
||||
}
|
||||
checkpointer = ttools.Checkpointer(
|
||||
chkpt, gen, meta=meta,
|
||||
optimizers=interface.optimizers,
|
||||
schedulers=interface.schedulers,
|
||||
prefix="g_")
|
||||
checkpointer_d = ttools.Checkpointer(
|
||||
chkpt, discrim,
|
||||
prefix="d_")
|
||||
|
||||
# Resume from checkpoint, if any
|
||||
extras, _ = checkpointer.load_latest()
|
||||
checkpointer_d.load_latest()
|
||||
|
||||
if not args.raster_only:
|
||||
checkpointer_vect = ttools.Checkpointer(
|
||||
chkpt, vect_gen, meta=meta,
|
||||
optimizers=interface.optimizers,
|
||||
schedulers=interface.schedulers,
|
||||
prefix="vect_g_")
|
||||
checkpointer_d_vect = ttools.Checkpointer(
|
||||
chkpt, vect_discrim,
|
||||
prefix="vect_d_")
|
||||
extras, _ = checkpointer_vect.load_latest()
|
||||
checkpointer_d_vect.load_latest()
|
||||
|
||||
epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0
|
||||
|
||||
# if meta is not None and meta["model_parameters"] != model_params:
|
||||
# LOG.info("Checkpoint's metaparams differ "
|
||||
# "from CLI, aborting: %s and %s", meta, model_params)
|
||||
|
||||
trainer = ttools.Trainer(interface)
|
||||
|
||||
# Add callbacks
|
||||
losses = ["loss_g", "loss_d", "loss_g_vect", "loss_d_vect", "gp",
|
||||
"gp_vect"]
|
||||
training_debug = ["lr"]
|
||||
|
||||
trainer.add_callback(Callback(
|
||||
env=env_name, win="samples", port=args.port, frequency=args.freq))
|
||||
trainer.add_callback(ttools.callbacks.ProgressBarCallback(
|
||||
keys=losses, val_keys=None))
|
||||
trainer.add_callback(ttools.callbacks.MultiPlotCallback(
|
||||
keys=losses, val_keys=None, env=env_name, port=args.port,
|
||||
server=args.server, base_url=args.base_url,
|
||||
win="losses", frequency=args.freq))
|
||||
trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
|
||||
keys=training_debug, smoothing=0, val_keys=None, env=env_name,
|
||||
server=args.server, base_url=args.base_url,
|
||||
port=args.port))
|
||||
trainer.add_callback(ttools.callbacks.CheckpointingCallback(
|
||||
checkpointer, max_files=2, interval=600, max_epochs=10))
|
||||
trainer.add_callback(ttools.callbacks.CheckpointingCallback(
|
||||
checkpointer_d, max_files=2, interval=600, max_epochs=10))
|
||||
|
||||
if not args.raster_only:
|
||||
trainer.add_callback(ttools.callbacks.CheckpointingCallback(
|
||||
checkpointer_vect, max_files=2, interval=600, max_epochs=10))
|
||||
trainer.add_callback(ttools.callbacks.CheckpointingCallback(
|
||||
checkpointer_d_vect, max_files=2, interval=600, max_epochs=10))
|
||||
|
||||
trainer.add_callback(
|
||||
ttools.callbacks.LRSchedulerCallback(interface.schedulers))
|
||||
|
||||
# Start training
|
||||
trainer.train(dataloader, starting_epoch=epoch,
|
||||
val_dataloader=val_dataloader,
|
||||
num_epochs=args.num_epochs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task",
|
||||
default="mnist",
|
||||
choices=["mnist", "quickdraw"])
|
||||
parser.add_argument("--generator",
|
||||
default="bezier_fc",
|
||||
choices=["bezier_fc", "fc", "rnn", "chain_rnn"],
|
||||
help="model to use as generator")
|
||||
|
||||
parser.add_argument("--raster_only", action="store_true", default=False,
|
||||
help="if true only train the raster baseline")
|
||||
|
||||
parser.add_argument("--standard_gan", dest="wgan_gp", action="store_false",
|
||||
default=True,
|
||||
help="if true, use regular GAN instead of WGAN")
|
||||
|
||||
# Training params
|
||||
parser.add_argument("--bs", type=int, default=4, help="batch size")
|
||||
parser.add_argument("--workers", type=int, default=4,
|
||||
help="number of dataloader threads")
|
||||
parser.add_argument("--num_epochs", type=int, default=200,
|
||||
help="number of epochs to train for")
|
||||
parser.add_argument("--lr", type=float, default=1e-4,
|
||||
help="learning rate")
|
||||
parser.add_argument("--lr_decay", type=float, default=0.9999,
|
||||
help="exponential learning rate decay rate")
|
||||
|
||||
# Model configuration
|
||||
parser.add_argument("--zdim", type=int, default=32,
|
||||
help="latent space dimension")
|
||||
parser.add_argument("--stroke_width", type=float, nargs=2,
|
||||
default=(0.5, 1.5),
|
||||
help="min and max stroke width")
|
||||
parser.add_argument("--num_strokes", type=int, default=16,
|
||||
help="number of strokes to generate")
|
||||
parser.add_argument("--raster_resolution", type=int, default=32,
|
||||
help="raster canvas resolution on each side")
|
||||
|
||||
# Viz params
|
||||
parser.add_argument("--freq", type=int, default=10,
|
||||
help="visualization frequency")
|
||||
parser.add_argument("--port", type=int, default=8097,
|
||||
help="visdom port")
|
||||
parser.add_argument("--server", default=None,
|
||||
help="visdom server if not local.")
|
||||
parser.add_argument("--base_url", default="", help="visdom entrypoint URL")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
pydiffvg.set_use_gpu(False)
|
||||
|
||||
ttools.set_logger(False)
|
||||
|
||||
train(args)
|