"""Evaluate a pretrained GAN model.
|
|
Usage:
|
|
|
|
`python eval_gan.py <path/to/model/folder>`, e.g.
|
|
`../results/quickdraw_gan_vector_bezier_fc_wgan`.
|
|
|
|
"""
|
|
import os
import argparse

import torch as th
import numpy as np
import ttools
import imageio
from subprocess import call

import pydiffvg

import models


LOG = ttools.get_logger(__name__)


def postprocess(im, invert=False):
    # Map the generator output from [-1, 1] to [0, 1], optionally invert it,
    # and convert the tensor to an image array.
    im = th.clamp((im + 1.0) / 2.0, 0, 1)
    if invert:
        im = (1.0 - im)
    im = ttools.tensor2image(im)
    return im


def imsave(im, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.imwrite(path, im)


def save_scene(scn, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    pydiffvg.save_svg(path, *scn, use_gamma=False)


def run(args):
    # Fix the random seeds so the evaluation is reproducible.
    th.manual_seed(0)
    np.random.seed(0)

    meta = ttools.Checkpointer.load_meta(args.model, "vect_g_")

    if meta is None:
        LOG.warning("Could not load metadata at %s, aborting.", args.model)
        return

    LOG.info("Loaded model %s with metadata:\n %s", args.model, meta)

    if args.output_dir is None:
        outdir = os.path.join(args.model, "eval")
    else:
        outdir = args.output_dir
    os.makedirs(outdir, exist_ok=True)

    model_params = meta["model_params"]
    if args.imsize is not None:
        LOG.info("Overriding output image size to: %dx%d", args.imsize,
                 args.imsize)
        # Rescale the stroke widths to match the new raster resolution.
        old_size = model_params["imsize"]
        scale = args.imsize * 1.0 / old_size
        model_params["imsize"] = args.imsize
        model_params["stroke_width"] = [w*scale for w in
                                        model_params["stroke_width"]]
        LOG.info("Overriding width to: %s", model_params["stroke_width"])

    # task = meta["task"]
    # Rebuild the generator architecture saved in the checkpoint metadata.
    generator = meta["generator"]
    if generator == "fc":
        model = models.VectorGenerator(**model_params)
    elif generator == "bezier_fc":
        model = models.BezierVectorGenerator(**model_params)
    elif generator in ["rnn"]:
        model = models.RNNVectorGenerator(**model_params)
    elif generator in ["chain_rnn"]:
        model = models.ChainRNNVectorGenerator(**model_params)
    else:
        raise NotImplementedError()
    model.eval()

device = "cpu"
|
|
if th.cuda.is_available():
|
|
device = "cuda"
|
|
|
|
model.to(device)
|
|
|
|
checkpointer = ttools.Checkpointer(
|
|
args.model, model, meta=meta, prefix="vect_g_")
|
|
checkpointer.load_latest()
|
|
|
|
LOG.info("Computing latent space interpolation")
|
|
for i in range(args.nsamples):
|
|
z0 = model.sample_z(1)
|
|
z1 = model.sample_z(1)
|
|
|
|
# interpolation
|
|
alpha = th.linspace(0, 1, args.nsteps).view(args.nsteps, 1).to(device)
|
|
alpha_video = th.linspace(0, 1, args.nframes).view(args.nframes, 1)
|
|
alpha_video = alpha_video.to(device)
|
|
|
|
length = [args.nsteps, args.nframes]
|
|
for idx, a in enumerate([alpha, alpha_video]):
|
|
_z0 = z0.repeat(length[idx], 1).to(device)
|
|
_z1 = z1.repeat(length[idx], 1).to(device)
|
|
batch = _z0*(1-a) + _z1*a
|
|
out = model(batch)
|
|
if idx == 0: # image viz
|
|
n, c, h, w = out.shape
|
|
out = out.permute(1, 2, 0, 3)
|
|
out = out.contiguous().view(1, c, h, w*n)
|
|
out = postprocess(out, invert=args.invert)
|
|
imsave(out, os.path.join(outdir,
|
|
"latent_interp", "%03d.png" % i))
|
|
|
|
scenes = model.get_vector(batch)
|
|
for scn_idx, scn in enumerate(scenes):
|
|
save_scene(scn, os.path.join(outdir, "latent_interp_svg",
|
|
"%03d" % i, "%03d.svg" %
|
|
scn_idx))
|
|
else: # video viz
|
|
anim_root = os.path.join(outdir,
|
|
"latent_interp_video", "%03d" % i)
|
|
LOG.info("Rendering animation %d", i)
|
|
for frame_idx, frame in enumerate(out):
|
|
LOG.info("frame %d", frame_idx)
|
|
frame = frame.unsqueeze(0)
|
|
frame = postprocess(frame, invert=args.invert)
|
|
imsave(frame, os.path.join(anim_root,
|
|
"frame%04d.png" % frame_idx))
|
|
call(["ffmpeg", "-framerate", "30", "-i",
|
|
os.path.join(anim_root, "frame%04d.png"), "-vb", "20M",
|
|
os.path.join(outdir,
|
|
"latent_interp_video", "%03d.mp4" % i)])
|
|
LOG.info(" saved %d", i)
|
|
|
|
LOG.info("Sampling latent space")
|
|
|
|
for i in range(args.nsamples):
|
|
n = 8
|
|
bs = n*n
|
|
z = model.sample_z(bs).to(device)
|
|
out = model(z)
|
|
_, c, h, w = out.shape
|
|
out = out.view(n, n, c, h, w).permute(2, 0, 3, 1, 4)
|
|
out = out.contiguous().view(1, c, h*n, w*n)
|
|
out = postprocess(out)
|
|
imsave(out, os.path.join(outdir, "samples_%03d.png" % i))
|
|
LOG.info(" saved %d", i)
|
|
|
|
LOG.info("output images saved to %s", outdir)
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("model")
    parser.add_argument("--output_dir", help="output directory for the "
                        "samples. Defaults to the model's path")
    parser.add_argument("--nsamples", default=16, type=int,
                        help="number of outputs to compute")
    parser.add_argument("--imsize", type=int,
                        help="if provided, override the raster output "
                        "resolution")
    parser.add_argument("--nsteps", default=9, type=int, help="number of "
                        "steps in the latent space interpolation")
    parser.add_argument("--nframes", default=120, type=int, help="number of "
                        "frames for the interpolation video")
    parser.add_argument("--invert", default=False, action="store_true",
                        help="if True, render black on white rather than "
                        "the opposite")

    args = parser.parse_args()

    pydiffvg.set_use_gpu(False)

    ttools.set_logger(False)

    run(args)