fixes mnist_vae

This commit is contained in:
Michael Gharbi
2020-09-30 09:02:51 -07:00
committed by Tzu-Mao Li
parent 2ec688ebb7
commit afe2e674d6

View File

@@ -3,11 +3,11 @@
Usage:
* Train a model:
* Train a model:
`python mnist_vae.py train`
* Generate samples from a trained model:
* Generate samples from a trained model:
`python mnist_vae.py sample`
@@ -26,7 +26,6 @@ import torchvision.transforms as transforms
import ttools
import ttools.interfaces
from ttools.modules import networks
from modules import Flatten
@@ -50,20 +49,21 @@ def _onehot(label):
def render(canvas_width, canvas_height, shapes, shape_groups, samples=2):
_render = pydiffvg.RenderFunction.apply
scene_args = pydiffvg.RenderFunction.serialize_scene(\
scene_args = pydiffvg.RenderFunction.serialize_scene(
canvas_width, canvas_height, shapes, shape_groups)
img = _render(canvas_width, # width
canvas_height, # height
samples, # num_samples_x
samples, # num_samples_y
0, # seed
None, # background
*scene_args)
img = _render(canvas_width,
canvas_height,
samples,
samples,
0,
None,
*scene_args)
return img
class MNISTCallback(ttools.callbacks.ImageDisplayCallback):
"""Simple callback that visualize generated images during training."""
def visualized_image(self, batch, step_data, is_val=False):
im = step_data["rendering"].detach().cpu()
im = 0.5 + 0.5*im
@@ -142,87 +142,17 @@ class VAEInterface(ttools.ModelInterface):
ret["data_loss"] = data_loss.item()
ret["auxdata"] = auxdata
ret["rendering"] = rendering
ret["logvar"] = logvar.abs().max().item()
return ret
# def init_validation(self):
# return {"count": 0, "loss": 0}
#
# def update_validation(self, batch, fwd, running_data):
# with th.no_grad():
# ref = batch[1].to(self.device)
# loss = th.nn.functional.mse_loss(fwd, ref)
# n = ref.shape[0]
#
# return {
# "loss": running_data["loss"] + loss.item()*n,
# "count": running_data["count"] + n
# }
#
# def finalize_validation(self, running_data):
# return {
# "loss": running_data["loss"] / running_data["count"]
# }
class MNISTGenerator(th.nn.Module):
def __init__(self, imsize=28):
super(MNISTGenerator, self).__init__()
if imsize != 28:
raise NotImplementedError()
mul = 2
self.convnet = th.nn.Sequential(
# 4x4
th.nn.ConvTranspose2d(16 + 1, mul*32, 4, padding=1, stride=2),
th.nn.LeakyReLU(inplace=True),
th.nn.Conv2d(mul*32, mul*32, 3, padding=1),
th.nn.LeakyReLU(inplace=True),
# 8x8
th.nn.ConvTranspose2d(mul*32, mul*64, 4, padding=1, stride=2),
th.nn.LeakyReLU(inplace=True),
th.nn.Conv2d(mul*64, mul*64, 3, padding=1),
th.nn.LeakyReLU(inplace=True),
# 16x16
th.nn.ConvTranspose2d(mul*64, mul*128, 4, padding=1, stride=2),
th.nn.LeakyReLU(inplace=True),
th.nn.Conv2d(mul*128, mul*128, 3, padding=1),
th.nn.LeakyReLU(inplace=True),
# 32x32
th.nn.Conv2d(mul*128, mul*128, 3, padding=1),
th.nn.LeakyReLU(inplace=True),
th.nn.Conv2d(mul*128, mul*128, 3, padding=1),
th.nn.LeakyReLU(inplace=True),
th.nn.Conv2d(mul*128, 1, 1),
# th.nn.Tanh(),
)
def forward(self, im, label):
bs = im.shape[0]
# sample a hidden vector
z = th.randn(bs, 16, 4, 4).to(im.device)
# make the model conditional
in_ = th.cat([z, label.float().view(bs, 1, 1, 1).repeat(1, 1, 4, 4)], 1)
out = self.convnet(in_)
return out, None
class VectorMNISTVAE(th.nn.Module):
def __init__(self, imsize=28, paths=4, segments=5, samples=2, zdim=128,
conditional=False, variational=True, raster=False, fc=False):
conditional=False, variational=True, raster=False, fc=False,
stroke_width=None):
super(VectorMNISTVAE, self).__init__()
# if imsize != 28:
# raise NotImplementedError()
self.samples = samples
self.imsize = imsize
self.paths = paths
@@ -231,6 +161,12 @@ class VectorMNISTVAE(th.nn.Module):
self.conditional = conditional
self.variational = variational
if stroke_width is None:
self.stroke_width = (1.0, 3.0)
LOG.warning("Setting default stroke with %s", self.stroke_width)
else:
self.stroke_width = stroke_width
ncond = 0
if self.conditional: # one hot encoded input for conditional model
ncond = 10
@@ -278,8 +214,8 @@ class VectorMNISTVAE(th.nn.Module):
th.nn.SELU(inplace=True),
)
self.raster = raster
if self.raster:
self.raster_decoder = th.nn.Sequential(
th.nn.Linear(nc, imsize*imsize),
@@ -293,59 +229,21 @@ class VectorMNISTVAE(th.nn.Module):
self.width_predictor = th.nn.Sequential(
th.nn.Linear(nc, self.paths),
th.nn.Tanh()
th.nn.Sigmoid()
)
self.alpha_predictor = th.nn.Sequential(
th.nn.Linear(nc, self.paths),
th.nn.Tanh()
th.nn.Sigmoid()
)
self._reset_weights()
def _reset_weights(self):
for n, p in self.encoder.named_parameters():
if 'bias' in n:
p.data.zero_()
elif 'weight' in n:
th.nn.init.kaiming_normal_(p.data, nonlinearity="leaky_relu")
th.nn.init.kaiming_normal_(self.mu_predictor.weight.data, nonlinearity="linear")
if self.variational:
th.nn.init.kaiming_normal_(self.logvar_predictor.weight.data, nonlinearity="linear")
for n, p in self.decoder.named_parameters():
if 'bias' in n:
p.data.zero_()
elif 'weight' in n:
th.nn.init.kaiming_normal_(p, nonlinearity="linear")
if not self.raster:
for n, p in self.point_predictor.named_parameters():
pass
# if 'bias' in n:
# p.data.zero_()
# if 'weight' in n:
# th.nn.init.orthogonal_(p)
for n, p in self.width_predictor.named_parameters():
if 'bias' in n:
p.data.zero_()
elif 'weight' in n:
th.nn.init.orthogonal_(p)
for n, p in self.alpha_predictor.named_parameters():
if 'bias' in n:
p.data.zero_()
elif 'weight' in n:
th.nn.init.orthogonal_(p)
def encode(self, im, label):
bs, _, h, w = im.shape
if self.conditional:
label_onehot = _onehot(label)
if not self.fc:
label_onehot = label_onehot.view(bs, 10, 1, 1).repeat(1, 1, h, w)
label_onehot = label_onehot.view(
bs, 10, 1, 1).repeat(1, 1, h, w)
out = self.encoder(th.cat([im, label_onehot], 1))
else:
out = self.encoder(th.cat([im.view(bs, -1), label_onehot], 1))
@@ -365,7 +263,9 @@ class VectorMNISTVAE(th.nn.Module):
def _decode_features(self, z, label):
if label is not None:
assert self.conditional, "decoding with an input label requires a conditional AE"
if not self.conditional:
raise ValueError("decoding with an input label "
"requires a conditional AE")
label_onehot = _onehot(label)
z = th.cat([z, label_onehot], 1)
@@ -378,7 +278,8 @@ class VectorMNISTVAE(th.nn.Module):
feats = self._decode_features(z, label)
if self.raster:
out = self.raster_decoder(feats).view(bs, 1, self.imsize, self.imsize)
out = self.raster_decoder(feats).view(
bs, 1, self.imsize, self.imsize)
return out, {}
all_points = self.point_predictor(feats)
@@ -389,7 +290,10 @@ class VectorMNISTVAE(th.nn.Module):
if False:
all_widths = th.ones(bs, self.paths) * 0.5
else:
all_widths = self.width_predictor(feats) * 1.5 + .25
all_widths = self.width_predictor(feats)
min_width = self.stroke_width[0]
max_width = self.stroke_width[1]
all_widths = (max_width - min_width) * all_widths + min_width
if False:
all_alphas = th.ones(bs, self.paths)
@@ -426,14 +330,16 @@ class VectorMNISTVAE(th.nn.Module):
[shapes, shape_groups, (self.imsize, self.imsize)])
# Rasterize
out = render(self.imsize, self.imsize, shapes, shape_groups, samples=self.samples)
out = render(self.imsize, self.imsize, shapes, shape_groups,
samples=self.samples)
# Torch format, discard alpha, make gray
out = out.permute(2, 0, 1).view(4, self.imsize, self.imsize)[:3].mean(0, keepdim=True)
out = out.permute(2, 0, 1).view(
4, self.imsize, self.imsize)[:3].mean(0, keepdim=True)
outputs.append(out)
output = th.stack(outputs).to(z.device)
output = th.stack(outputs).to(z.device)
auxdata = {
"points": all_points,
@@ -446,8 +352,6 @@ class VectorMNISTVAE(th.nn.Module):
return output, auxdata
def forward(self, im, label):
bs = im.shape[0]
if self.variational:
mu, logvar = self.encode(im, label)
z = self.reparameterize(mu, logvar)
@@ -457,9 +361,9 @@ class VectorMNISTVAE(th.nn.Module):
logvar = None
if self.conditional:
output, aux = self.decode(z, label=label)
output, aux = self.decode(z, label=label)
else:
output, aux = self.decode(z)
output, aux = self.decode(z)
aux["logvar"] = logvar
aux["mu"] = mu
@@ -467,250 +371,6 @@ class VectorMNISTVAE(th.nn.Module):
return output, aux
class VectorMNISTGenerator(th.nn.Module):
def __init__(self, imsize=28, paths=4, segments=5, samples=2, conditional=False,
zdim=20, fc=False):
super(VectorMNISTGenerator, self).__init__()
if imsize != 28:
raise NotImplementedError()
self.samples = samples
self.imsize = imsize
self.paths = paths
self.segments = segments
self.conditional = conditional
self.zdim = zdim
self.fc = fc
ncond = 0
if self.conditional: # one hot encoded input for conditional model
ncond = 10
nc = 1024
self.trunk = th.nn.Sequential(
th.nn.Linear(zdim + ncond, nc), # noise + one-hot
th.nn.SELU(inplace=True),
# th.nn.Linear(nc, nc),
# th.nn.SELU(inplace=True),
th.nn.Linear(nc, nc),
th.nn.SELU(inplace=True),
# th.nn.Linear(nc, nc),
# th.nn.SELU(inplace=True),
)
# 4 points bezier so n_segments -> 3*n_segments + 1 points
self.point_predictor = th.nn.Sequential(
th.nn.Linear(nc, 2*self.paths*(self.segments*3+1)),
# th.nn.Linear(nc, 2*self.paths*(self.segments*1+1)),
th.nn.Tanh() # bound spatial extent
)
self.width_predictor = th.nn.Sequential(
th.nn.Linear(nc, self.paths),
th.nn.Tanh()
)
self.alpha_predictor = th.nn.Sequential(
th.nn.Linear(nc, self.paths),
th.nn.Tanh()
)
# self.postprocessor = th.nn.Sequential(
# th.nn.Conv2d(1, 32, 3, padding=1),
# th.nn.LeakyReLU(inplace=True),
# th.nn.Conv2d(32, 1, 1),
# )
self._reset_weights()
def _reset_weights(self):
for n, p in self.trunk.named_parameters():
if 'bias' in n:
p.data.zero_()
elif 'weight' in n:
th.nn.init.kaiming_normal_(p)
p.data.mul_(0.7)
# th.nn.init.kaiming_normal_(p, nonlinearity="leaky_relu")
for n, p in self.point_predictor.named_parameters():
# if 'bias' in n:
# p.data.zero_()
if 'weight' in n:
th.nn.init.orthogonal_(p)
# th.nn.init.kaiming_normal_(p, nonlinearity="tanh")
for n, p in self.width_predictor.named_parameters():
if 'bias' in n:
p.data.zero_()
elif 'weight' in n:
# th.nn.init.orthogonal_(p)
th.nn.init.kaiming_normal_(p, nonlinearity="tanh")
for n, p in self.alpha_predictor.named_parameters():
if 'bias' in n:
p.data.zero_()
elif 'weight' in n:
th.nn.init.kaiming_normal_(p, nonlinearity="tanh")
# th.nn.init.orthogonal_(p)
def sample_z(self, bs):
return th.randn(bs, self.zdim)
def gen_sample(self, z, label=None):
bs = z.shape[0]
if self.conditional:
if label is None:
raise ValueError("GAN is conditional, please provide a label")
# One-hot encoding of the image label
label_onehot = _onehot(label)
# get some embedding
in_ = th.cat([z, label_onehot.float()], 1)
else:
in_ = z
feats = self.trunk(in_)
all_points = self.point_predictor(feats)
all_points = all_points.view(bs, self.paths, -1, 2)
if False:
all_alphas = th.ones(bs, self.paths)
else:
all_alphas = self.alpha_predictor(feats)
# stroke size between 0.5 and 3.5 px
if False:
all_widths = th.ones(bs, self.paths) * 1
else:
all_widths = self.width_predictor(feats)
all_widths = 1.5*all_widths + 0.5
all_points = all_points*(self.imsize//2) + self.imsize//2
# Process the batch sequentially
outputs = []
for k in range(bs):
# Get point parameters from network
shapes = []
shape_groups = []
for p in range(self.paths):
points = all_points[k, p].contiguous().cpu()
# num_ctrl_pts = th.zeros(self.segments, dtype=th.int32)+0
num_ctrl_pts = th.zeros(self.segments, dtype=th.int32)+2
width = all_widths[k, p].cpu()
alpha = all_alphas[k, p].cpu()
color = th.cat([th.ones(3), alpha.view(1,)])
path = pydiffvg.Path(
num_control_points=num_ctrl_pts, points=points,
stroke_width=width, is_closed=False)
shapes.append(path)
path_group = pydiffvg.ShapeGroup(
shape_ids=th.tensor([len(shapes) - 1]),
fill_color=None,
stroke_color=color)
shape_groups.append(path_group)
# Rasterize
out = render(self.imsize, self.imsize, shapes, shape_groups, samples=self.samples)
# Torch format, discard alpha, make gray
out = out.permute(2, 0, 1).view(4, self.imsize, self.imsize)[:3].mean(0, keepdim=True)
outputs.append(out)
output = th.stack(outputs).to(z.device)
aux_data = {
"points": all_points,
"raw_vector": output,
}
# output = self.postprocessor(output)
# map to [-1, 1]
output = output*2.0 - 1.0
return output, aux_data
def forward(self, im, label):
bs = label.shape[0]
# sample a hidden vector (same dim as the raster version)
z = self.sample_z(bs).to(im.device)
if args.conditional:
return self.gen_sample(z, label=label)
else:
return self.gen_sample(z)
class Discriminator(th.nn.Module):
def __init__(self, conditional=False, fc=False):
super(Discriminator, self).__init__()
self.conditional = conditional
ncond = 0
if self.conditional: # one hot encoded input for conditional model
ncond = 10
sn = th.nn.utils.spectral_norm
# sn = lambda x: x
self.fc = fc
mult = 2
if self.fc:
self.net = th.nn.Sequential(
Flatten(),
th.nn.Linear(28*28 + ncond, mult*256),
th.nn.LeakyReLU(0.2, inplace=True),
# th.nn.Linear(mult*256, mult*256, 4),
# th.nn.LeakyReLU(0.2, inplace=True),
# th.nn.Dropout(0.5),
th.nn.Linear(mult*256, mult*256, 4),
th.nn.LeakyReLU(0.2, inplace=True),
th.nn.Linear(mult*256*1*1, 1),
)
else:
self.net = th.nn.Sequential(
th.nn.Conv2d(1 + ncond, mult*64, 4, padding=0, stride=2),
th.nn.LeakyReLU(0.2, inplace=True),
# 16x16
sn(th.nn.Conv2d(mult*64, mult*128, 4, padding=0, stride=2)),
th.nn.LeakyReLU(0.2, inplace=True),
# 8x8
sn(th.nn.Conv2d(mult*128, mult*256, 4, padding=0, stride=2)),
th.nn.LeakyReLU(0.2, inplace=True),
# 4x4
Flatten(),
th.nn.Linear(mult*256*1*1, 1),
)
self._reset_weights()
def _reset_weights(self):
for n, p in self.net.named_parameters():
if 'bias' in n:
p.data.zero_()
if 'weight' in n:
th.nn.init.kaiming_normal_(p, nonlinearity="leaky_relu")
def forward(self, x):
out = self.net(x)
return out
class Dataset(th.utils.data.Dataset):
def __init__(self, data_dir, imsize):
super(Dataset, self).__init__()
@@ -783,8 +443,8 @@ def train(args):
extras, meta = checkpointer.load_latest()
if meta is not None and meta != model_params:
LOG.info("Checkpoint's metaparams differ from CLI, aborting: %s and %s",
meta, model_params)
LOG.info(f"Checkpoint's metaparams differ from CLI, "
f"aborting: {meta} and {model_params}")
# Hook interface
if args.generator in ["vae", "ae"]:
@@ -794,17 +454,18 @@ def train(args):
else:
LOG.info("Using an AE")
interface = VAEInterface(model, lr=args.lr, cuda=args.cuda,
variational=variational, w_kld=args.kld_weight)
variational=variational,
w_kld=args.kld_weight)
trainer = ttools.Trainer(interface)
# Add callbacks
keys = ["loss_g", "loss_d"]
keys = []
if args.generator == "vae":
keys = ["kld", "data_loss", "loss"]
keys = ["kld", "data_loss", "loss", "logvar"]
elif args.generator == "ae":
keys = ["data_loss", "loss"]
port = 8097
port = 8080
trainer.add_callback(ttools.callbacks.ProgressBarCallback(
keys=keys, val_keys=keys))
trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
@@ -979,8 +640,10 @@ if __name__ == "__main__":
parser.add_argument("--cpu", dest="cuda", action="store_false",
default=th.cuda.is_available(),
help="if true, use CPU instead of GPU.")
parser.add_argument("--conditional", action="store_true", default=False)
parser.add_argument("--fc", action="store_true", default=False)
parser.add_argument("--no-conditional", dest="conditional",
action="store_false", default=True)
parser.add_argument("--no-fc", dest="fc", action="store_false",
default=True)
parser.add_argument("--data_dir", default="mnist",
help="path to download and store the data.")
@@ -992,19 +655,19 @@ if __name__ == "__main__":
"autoencoder")
parser_train.add_argument("--freq", type=int, default=100,
help="number of steps between visualizations")
parser_train.add_argument("--lr", type=float, default=1e-4,
parser_train.add_argument("--lr", type=float, default=5e-5,
help="learning rate")
parser_train.add_argument("--kld_weight", type=float, default=1.0,
help="scalar weight for the KL divergence term.")
parser_train.add_argument("--bs", type=int, default=8, help="batch size")
parser_train.add_argument("--num_epochs", type=int,
parser_train.add_argument("--num_epochs", default=50, type=int,
help="max number of epochs")
# Vector configs
parser_train.add_argument("--paths", type=int, default=1,
help="number of unique vector paths to generate.")
help="number of vector paths to generate.")
parser_train.add_argument("--segments", type=int, default=3,
help="number of segments per vector path")
parser_train.add_argument("--samples", type=int, default=2,
parser_train.add_argument("--samples", type=int, default=4,
help="number of samples in the MC rasterizer")
parser_train.add_argument("--zdim", type=int, default=20,
help="dimension of the latent space")