initial commit
This commit is contained in:
182
apps/generative_models/eval_gan.py
Normal file
182
apps/generative_models/eval_gan.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""Evaluate a pretrained GAN model.
|
||||
Usage:
|
||||
|
||||
`python eval_gan.py <path/to/model/folder>`, e.g.
|
||||
`../results/quickdraw_gan_vector_bezier_fc_wgan`.
|
||||
|
||||
"""
|
||||
import os
|
||||
import argparse
|
||||
import torch as th
|
||||
import numpy as np
|
||||
import ttools
|
||||
import imageio
|
||||
from subprocess import call
|
||||
|
||||
import pydiffvg
|
||||
|
||||
import models
|
||||
|
||||
|
||||
LOG = ttools.get_logger(__name__)
|
||||
|
||||
|
||||
def postprocess(im, invert=False):
|
||||
im = th.clamp((im + 1.0) / 2.0, 0, 1)
|
||||
if invert:
|
||||
im = (1.0 - im)
|
||||
im = ttools.tensor2image(im)
|
||||
return im
|
||||
|
||||
|
||||
def imsave(im, path):
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
imageio.imwrite(path, im)
|
||||
|
||||
|
||||
def save_scene(scn, path):
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
pydiffvg.save_svg(path, *scn, use_gamma=False)
|
||||
|
||||
|
||||
def run(args):
|
||||
th.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
meta = ttools.Checkpointer.load_meta(args.model, "vect_g_")
|
||||
|
||||
if meta is None:
|
||||
LOG.warning("Could not load metadata at %s, aborting.", args.model)
|
||||
return
|
||||
|
||||
LOG.info("Loaded model %s with metadata:\n %s", args.model, meta)
|
||||
|
||||
if args.output_dir is None:
|
||||
outdir = os.path.join(args.model, "eval")
|
||||
else:
|
||||
outdir = args.output_dir
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
|
||||
model_params = meta["model_params"]
|
||||
if args.imsize is not None:
|
||||
LOG.info("Overriding output image size to: %dx%d", args.imsize,
|
||||
args.imsize)
|
||||
old_size = model_params["imsize"]
|
||||
scale = args.imsize * 1.0 / old_size
|
||||
model_params["imsize"] = args.imsize
|
||||
model_params["stroke_width"] = [w*scale for w in
|
||||
model_params["stroke_width"]]
|
||||
LOG.info("Overriding width to: %s", model_params["stroke_width"])
|
||||
|
||||
# task = meta["task"]
|
||||
generator = meta["generator"]
|
||||
if generator == "fc":
|
||||
model = models.VectorGenerator(**model_params)
|
||||
elif generator == "bezier_fc":
|
||||
model = models.BezierVectorGenerator(**model_params)
|
||||
elif generator in ["rnn"]:
|
||||
model = models.RNNVectorGenerator(**model_params)
|
||||
elif generator in ["chain_rnn"]:
|
||||
model = models.ChainRNNVectorGenerator(**model_params)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
model.eval()
|
||||
|
||||
device = "cpu"
|
||||
if th.cuda.is_available():
|
||||
device = "cuda"
|
||||
|
||||
model.to(device)
|
||||
|
||||
checkpointer = ttools.Checkpointer(
|
||||
args.model, model, meta=meta, prefix="vect_g_")
|
||||
checkpointer.load_latest()
|
||||
|
||||
LOG.info("Computing latent space interpolation")
|
||||
for i in range(args.nsamples):
|
||||
z0 = model.sample_z(1)
|
||||
z1 = model.sample_z(1)
|
||||
|
||||
# interpolation
|
||||
alpha = th.linspace(0, 1, args.nsteps).view(args.nsteps, 1).to(device)
|
||||
alpha_video = th.linspace(0, 1, args.nframes).view(args.nframes, 1)
|
||||
alpha_video = alpha_video.to(device)
|
||||
|
||||
length = [args.nsteps, args.nframes]
|
||||
for idx, a in enumerate([alpha, alpha_video]):
|
||||
_z0 = z0.repeat(length[idx], 1).to(device)
|
||||
_z1 = z1.repeat(length[idx], 1).to(device)
|
||||
batch = _z0*(1-a) + _z1*a
|
||||
out = model(batch)
|
||||
if idx == 0: # image viz
|
||||
n, c, h, w = out.shape
|
||||
out = out.permute(1, 2, 0, 3)
|
||||
out = out.contiguous().view(1, c, h, w*n)
|
||||
out = postprocess(out, invert=args.invert)
|
||||
imsave(out, os.path.join(outdir,
|
||||
"latent_interp", "%03d.png" % i))
|
||||
|
||||
scenes = model.get_vector(batch)
|
||||
for scn_idx, scn in enumerate(scenes):
|
||||
save_scene(scn, os.path.join(outdir, "latent_interp_svg",
|
||||
"%03d" % i, "%03d.svg" %
|
||||
scn_idx))
|
||||
else: # video viz
|
||||
anim_root = os.path.join(outdir,
|
||||
"latent_interp_video", "%03d" % i)
|
||||
LOG.info("Rendering animation %d", i)
|
||||
for frame_idx, frame in enumerate(out):
|
||||
LOG.info("frame %d", frame_idx)
|
||||
frame = frame.unsqueeze(0)
|
||||
frame = postprocess(frame, invert=args.invert)
|
||||
imsave(frame, os.path.join(anim_root,
|
||||
"frame%04d.png" % frame_idx))
|
||||
call(["ffmpeg", "-framerate", "30", "-i",
|
||||
os.path.join(anim_root, "frame%04d.png"), "-vb", "20M",
|
||||
os.path.join(outdir,
|
||||
"latent_interp_video", "%03d.mp4" % i)])
|
||||
LOG.info(" saved %d", i)
|
||||
|
||||
LOG.info("Sampling latent space")
|
||||
|
||||
for i in range(args.nsamples):
|
||||
n = 8
|
||||
bs = n*n
|
||||
z = model.sample_z(bs).to(device)
|
||||
out = model(z)
|
||||
_, c, h, w = out.shape
|
||||
out = out.view(n, n, c, h, w).permute(2, 0, 3, 1, 4)
|
||||
out = out.contiguous().view(1, c, h*n, w*n)
|
||||
out = postprocess(out)
|
||||
imsave(out, os.path.join(outdir, "samples_%03d.png" % i))
|
||||
LOG.info(" saved %d", i)
|
||||
|
||||
LOG.info("output images saved to %s", outdir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("model")
|
||||
parser.add_argument("--output_dir", help="output directory for "
|
||||
" the samples. Defaults to the model's path")
|
||||
parser.add_argument("--nsamples", default=16, type=int,
|
||||
help="number of output to compute")
|
||||
parser.add_argument("--imsize", type=int,
|
||||
help="if provided, override the raster output "
|
||||
"resolution")
|
||||
parser.add_argument("--nsteps", default=9, type=int, help="number of "
|
||||
"interpolation steps for the interpolation")
|
||||
parser.add_argument("--nframes", default=120, type=int, help="number of "
|
||||
"frames for the interpolation video")
|
||||
parser.add_argument("--invert", default=False, action="store_true",
|
||||
help="if True, render black on white rather than the"
|
||||
" opposite")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
pydiffvg.set_use_gpu(False)
|
||||
|
||||
ttools.set_logger(False)
|
||||
|
||||
run(args)
|
Reference in New Issue
Block a user