diff --git a/apps/generative_models/mnist_vae.py b/apps/generative_models/mnist_vae.py index 884c147..c0da626 100644 --- a/apps/generative_models/mnist_vae.py +++ b/apps/generative_models/mnist_vae.py @@ -3,11 +3,11 @@ Usage: -* Train a model: +* Train a model: `python mnist_vae.py train` -* Generate samples from a trained model: +* Generate samples from a trained model: `python mnist_vae.py sample` @@ -26,7 +26,6 @@ import torchvision.transforms as transforms import ttools import ttools.interfaces -from ttools.modules import networks from modules import Flatten @@ -50,20 +49,21 @@ def _onehot(label): def render(canvas_width, canvas_height, shapes, shape_groups, samples=2): _render = pydiffvg.RenderFunction.apply - scene_args = pydiffvg.RenderFunction.serialize_scene(\ + scene_args = pydiffvg.RenderFunction.serialize_scene( canvas_width, canvas_height, shapes, shape_groups) - img = _render(canvas_width, # width - canvas_height, # height - samples, # num_samples_x - samples, # num_samples_y - 0, # seed - None, # background - *scene_args) + img = _render(canvas_width, + canvas_height, + samples, + samples, + 0, + None, + *scene_args) return img class MNISTCallback(ttools.callbacks.ImageDisplayCallback): """Simple callback that visualize generated images during training.""" + def visualized_image(self, batch, step_data, is_val=False): im = step_data["rendering"].detach().cpu() im = 0.5 + 0.5*im @@ -142,87 +142,17 @@ class VAEInterface(ttools.ModelInterface): ret["data_loss"] = data_loss.item() ret["auxdata"] = auxdata ret["rendering"] = rendering + ret["logvar"] = logvar.abs().max().item() return ret - # def init_validation(self): - # return {"count": 0, "loss": 0} - # - # def update_validation(self, batch, fwd, running_data): - # with th.no_grad(): - # ref = batch[1].to(self.device) - # loss = th.nn.functional.mse_loss(fwd, ref) - # n = ref.shape[0] - # - # return { - # "loss": running_data["loss"] + loss.item()*n, - # "count": running_data["count"] + n - # } - # - # def finalize_validation(self, running_data): - # return { - # "loss": running_data["loss"] / running_data["count"] - # } - - -class MNISTGenerator(th.nn.Module): - def __init__(self, imsize=28): - super(MNISTGenerator, self).__init__() - if imsize != 28: - raise NotImplementedError() - - mul = 2 - self.convnet = th.nn.Sequential( - # 4x4 - th.nn.ConvTranspose2d(16 + 1, mul*32, 4, padding=1, stride=2), - th.nn.LeakyReLU(inplace=True), - th.nn.Conv2d(mul*32, mul*32, 3, padding=1), - th.nn.LeakyReLU(inplace=True), - - # 8x8 - th.nn.ConvTranspose2d(mul*32, mul*64, 4, padding=1, stride=2), - th.nn.LeakyReLU(inplace=True), - th.nn.Conv2d(mul*64, mul*64, 3, padding=1), - th.nn.LeakyReLU(inplace=True), - - # 16x16 - th.nn.ConvTranspose2d(mul*64, mul*128, 4, padding=1, stride=2), - th.nn.LeakyReLU(inplace=True), - th.nn.Conv2d(mul*128, mul*128, 3, padding=1), - th.nn.LeakyReLU(inplace=True), - - # 32x32 - th.nn.Conv2d(mul*128, mul*128, 3, padding=1), - th.nn.LeakyReLU(inplace=True), - - th.nn.Conv2d(mul*128, mul*128, 3, padding=1), - th.nn.LeakyReLU(inplace=True), - - th.nn.Conv2d(mul*128, 1, 1), - # th.nn.Tanh(), - ) - - def forward(self, im, label): - bs = im.shape[0] - - # sample a hidden vector - z = th.randn(bs, 16, 4, 4).to(im.device) - - # make the model conditional - in_ = th.cat([z, label.float().view(bs, 1, 1, 1).repeat(1, 1, 4, 4)], 1) - - out = self.convnet(in_) - return out, None - class VectorMNISTVAE(th.nn.Module): def __init__(self, imsize=28, paths=4, segments=5, samples=2, zdim=128, - conditional=False, variational=True, raster=False, fc=False): + conditional=False, variational=True, raster=False, fc=False, + stroke_width=None): super(VectorMNISTVAE, self).__init__() - # if imsize != 28: - # raise NotImplementedError() - self.samples = samples self.imsize = imsize self.paths = paths @@ -231,6 +161,12 @@ class VectorMNISTVAE(th.nn.Module): self.conditional = conditional self.variational = variational + if stroke_width is None: + self.stroke_width = (1.0, 3.0) + LOG.warning("Setting default stroke with %s", self.stroke_width) + else: + self.stroke_width = stroke_width + ncond = 0 if self.conditional: # one hot encoded input for conditional model ncond = 10 @@ -278,8 +214,8 @@ class VectorMNISTVAE(th.nn.Module): th.nn.SELU(inplace=True), ) - self.raster = raster + if self.raster: self.raster_decoder = th.nn.Sequential( th.nn.Linear(nc, imsize*imsize), @@ -293,59 +229,21 @@ class VectorMNISTVAE(th.nn.Module): self.width_predictor = th.nn.Sequential( th.nn.Linear(nc, self.paths), - th.nn.Tanh() + th.nn.Sigmoid() ) self.alpha_predictor = th.nn.Sequential( th.nn.Linear(nc, self.paths), - th.nn.Tanh() + th.nn.Sigmoid() ) - self._reset_weights() - - def _reset_weights(self): - for n, p in self.encoder.named_parameters(): - if 'bias' in n: - p.data.zero_() - elif 'weight' in n: - th.nn.init.kaiming_normal_(p.data, nonlinearity="leaky_relu") - - th.nn.init.kaiming_normal_(self.mu_predictor.weight.data, nonlinearity="linear") - if self.variational: - th.nn.init.kaiming_normal_(self.logvar_predictor.weight.data, nonlinearity="linear") - - for n, p in self.decoder.named_parameters(): - if 'bias' in n: - p.data.zero_() - elif 'weight' in n: - th.nn.init.kaiming_normal_(p, nonlinearity="linear") - - if not self.raster: - for n, p in self.point_predictor.named_parameters(): - pass - # if 'bias' in n: - # p.data.zero_() - # if 'weight' in n: - # th.nn.init.orthogonal_(p) - - for n, p in self.width_predictor.named_parameters(): - if 'bias' in n: - p.data.zero_() - elif 'weight' in n: - th.nn.init.orthogonal_(p) - - for n, p in self.alpha_predictor.named_parameters(): - if 'bias' in n: - p.data.zero_() - elif 'weight' in n: - th.nn.init.orthogonal_(p) - def encode(self, im, label): bs, _, h, w = im.shape if self.conditional: label_onehot = _onehot(label) if not self.fc: - label_onehot = label_onehot.view(bs, 10, 1, 1).repeat(1, 1, h, w) + label_onehot = label_onehot.view( + bs, 10, 1, 1).repeat(1, 1, h, w) out = self.encoder(th.cat([im, label_onehot], 1)) else: out = self.encoder(th.cat([im.view(bs, -1), label_onehot], 1)) @@ -365,7 +263,9 @@ class VectorMNISTVAE(th.nn.Module): def _decode_features(self, z, label): if label is not None: - assert self.conditional, "decoding with an input label requires a conditional AE" + if not self.conditional: + raise ValueError("decoding with an input label " + "requires a conditional AE") label_onehot = _onehot(label) z = th.cat([z, label_onehot], 1) @@ -378,7 +278,8 @@ class VectorMNISTVAE(th.nn.Module): feats = self._decode_features(z, label) if self.raster: - out = self.raster_decoder(feats).view(bs, 1, self.imsize, self.imsize) + out = self.raster_decoder(feats).view( + bs, 1, self.imsize, self.imsize) return out, {} all_points = self.point_predictor(feats) @@ -389,7 +290,10 @@ class VectorMNISTVAE(th.nn.Module): if False: all_widths = th.ones(bs, self.paths) * 0.5 else: - all_widths = self.width_predictor(feats) * 1.5 + .25 + all_widths = self.width_predictor(feats) + min_width = self.stroke_width[0] + max_width = self.stroke_width[1] + all_widths = (max_width - min_width) * all_widths + min_width if False: all_alphas = th.ones(bs, self.paths) @@ -426,14 +330,16 @@ class VectorMNISTVAE(th.nn.Module): [shapes, shape_groups, (self.imsize, self.imsize)]) # Rasterize - out = render(self.imsize, self.imsize, shapes, shape_groups, samples=self.samples) + out = render(self.imsize, self.imsize, shapes, shape_groups, + samples=self.samples) # Torch format, discard alpha, make gray - out = out.permute(2, 0, 1).view(4, self.imsize, self.imsize)[:3].mean(0, keepdim=True) + out = out.permute(2, 0, 1).view( + 4, self.imsize, self.imsize)[:3].mean(0, keepdim=True) outputs.append(out) - output = th.stack(outputs).to(z.device) + output = th.stack(outputs).to(z.device) auxdata = { "points": all_points, @@ -446,8 +352,6 @@ class VectorMNISTVAE(th.nn.Module): return output, auxdata def forward(self, im, label): - bs = im.shape[0] - if self.variational: mu, logvar = self.encode(im, label) z = self.reparameterize(mu, logvar) @@ -457,9 +361,9 @@ class VectorMNISTVAE(th.nn.Module): logvar = None if self.conditional: - output, aux = self.decode(z, label=label) + output, aux = self.decode(z, label=label) else: - output, aux = self.decode(z) + output, aux = self.decode(z) aux["logvar"] = logvar aux["mu"] = mu @@ -467,250 +371,6 @@ class VectorMNISTVAE(th.nn.Module): return output, aux -class VectorMNISTGenerator(th.nn.Module): - def __init__(self, imsize=28, paths=4, segments=5, samples=2, conditional=False, - zdim=20, fc=False): - super(VectorMNISTGenerator, self).__init__() - if imsize != 28: - raise NotImplementedError() - - self.samples = samples - self.imsize = imsize - self.paths = paths - self.segments = segments - self.conditional = conditional - self.zdim = zdim - self.fc = fc - - ncond = 0 - if self.conditional: # one hot encoded input for conditional model - ncond = 10 - - nc = 1024 - self.trunk = th.nn.Sequential( - th.nn.Linear(zdim + ncond, nc), # noise + one-hot - th.nn.SELU(inplace=True), - - # th.nn.Linear(nc, nc), - # th.nn.SELU(inplace=True), - - th.nn.Linear(nc, nc), - th.nn.SELU(inplace=True), - - # th.nn.Linear(nc, nc), - # th.nn.SELU(inplace=True), - ) - - # 4 points bezier so n_segments -> 3*n_segments + 1 points - self.point_predictor = th.nn.Sequential( - th.nn.Linear(nc, 2*self.paths*(self.segments*3+1)), - # th.nn.Linear(nc, 2*self.paths*(self.segments*1+1)), - th.nn.Tanh() # bound spatial extent - ) - - self.width_predictor = th.nn.Sequential( - th.nn.Linear(nc, self.paths), - th.nn.Tanh() - ) - - self.alpha_predictor = th.nn.Sequential( - th.nn.Linear(nc, self.paths), - th.nn.Tanh() - ) - - # self.postprocessor = th.nn.Sequential( - # th.nn.Conv2d(1, 32, 3, padding=1), - # th.nn.LeakyReLU(inplace=True), - # th.nn.Conv2d(32, 1, 1), - # ) - self._reset_weights() - - def _reset_weights(self): - for n, p in self.trunk.named_parameters(): - if 'bias' in n: - p.data.zero_() - elif 'weight' in n: - th.nn.init.kaiming_normal_(p) - p.data.mul_(0.7) - # th.nn.init.kaiming_normal_(p, nonlinearity="leaky_relu") - - for n, p in self.point_predictor.named_parameters(): - # if 'bias' in n: - # p.data.zero_() - if 'weight' in n: - th.nn.init.orthogonal_(p) - # th.nn.init.kaiming_normal_(p, nonlinearity="tanh") - - for n, p in self.width_predictor.named_parameters(): - if 'bias' in n: - p.data.zero_() - elif 'weight' in n: - # th.nn.init.orthogonal_(p) - th.nn.init.kaiming_normal_(p, nonlinearity="tanh") - - for n, p in self.alpha_predictor.named_parameters(): - if 'bias' in n: - p.data.zero_() - elif 'weight' in n: - th.nn.init.kaiming_normal_(p, nonlinearity="tanh") - # th.nn.init.orthogonal_(p) - - def sample_z(self, bs): - return th.randn(bs, self.zdim) - - def gen_sample(self, z, label=None): - bs = z.shape[0] - if self.conditional: - if label is None: - raise ValueError("GAN is conditional, please provide a label") - - # One-hot encoding of the image label - label_onehot = _onehot(label) - - # get some embedding - in_ = th.cat([z, label_onehot.float()], 1) - else: - in_ = z - - feats = self.trunk(in_) - - all_points = self.point_predictor(feats) - all_points = all_points.view(bs, self.paths, -1, 2) - - if False: - all_alphas = th.ones(bs, self.paths) - else: - all_alphas = self.alpha_predictor(feats) - - # stroke size between 0.5 and 3.5 px - if False: - all_widths = th.ones(bs, self.paths) * 1 - else: - all_widths = self.width_predictor(feats) - all_widths = 1.5*all_widths + 0.5 - - all_points = all_points*(self.imsize//2) + self.imsize//2 - - # Process the batch sequentially - outputs = [] - for k in range(bs): - # Get point parameters from network - shapes = [] - shape_groups = [] - for p in range(self.paths): - points = all_points[k, p].contiguous().cpu() - # num_ctrl_pts = th.zeros(self.segments, dtype=th.int32)+0 - num_ctrl_pts = th.zeros(self.segments, dtype=th.int32)+2 - width = all_widths[k, p].cpu() - alpha = all_alphas[k, p].cpu() - color = th.cat([th.ones(3), alpha.view(1,)]) - path = pydiffvg.Path( - num_control_points=num_ctrl_pts, points=points, - stroke_width=width, is_closed=False) - shapes.append(path) - path_group = pydiffvg.ShapeGroup( - shape_ids=th.tensor([len(shapes) - 1]), - fill_color=None, - stroke_color=color) - shape_groups.append(path_group) - - # Rasterize - out = render(self.imsize, self.imsize, shapes, shape_groups, samples=self.samples) - - # Torch format, discard alpha, make gray - out = out.permute(2, 0, 1).view(4, self.imsize, self.imsize)[:3].mean(0, keepdim=True) - - outputs.append(out) - - output = th.stack(outputs).to(z.device) - aux_data = { - "points": all_points, - "raw_vector": output, - } - - # output = self.postprocessor(output) - - # map to [-1, 1] - output = output*2.0 - 1.0 - - return output, aux_data - - def forward(self, im, label): - bs = label.shape[0] - - # sample a hidden vector (same dim as the raster version) - z = self.sample_z(bs).to(im.device) - if args.conditional: - return self.gen_sample(z, label=label) - else: - return self.gen_sample(z) - - - -class Discriminator(th.nn.Module): - def __init__(self, conditional=False, fc=False): - super(Discriminator, self).__init__() - - self.conditional = conditional - - ncond = 0 - if self.conditional: # one hot encoded input for conditional model - ncond = 10 - - sn = th.nn.utils.spectral_norm - # sn = lambda x: x - - self.fc = fc - - mult = 2 - if self.fc: - self.net = th.nn.Sequential( - Flatten(), - th.nn.Linear(28*28 + ncond, mult*256), - th.nn.LeakyReLU(0.2, inplace=True), - - # th.nn.Linear(mult*256, mult*256, 4), - # th.nn.LeakyReLU(0.2, inplace=True), - # th.nn.Dropout(0.5), - - th.nn.Linear(mult*256, mult*256, 4), - th.nn.LeakyReLU(0.2, inplace=True), - - th.nn.Linear(mult*256*1*1, 1), - ) - else: - self.net = th.nn.Sequential( - th.nn.Conv2d(1 + ncond, mult*64, 4, padding=0, stride=2), - - th.nn.LeakyReLU(0.2, inplace=True), - # 16x16 - - sn(th.nn.Conv2d(mult*64, mult*128, 4, padding=0, stride=2)), - th.nn.LeakyReLU(0.2, inplace=True), - # 8x8 - - sn(th.nn.Conv2d(mult*128, mult*256, 4, padding=0, stride=2)), - th.nn.LeakyReLU(0.2, inplace=True), - # 4x4 - - Flatten(), - - th.nn.Linear(mult*256*1*1, 1), - ) - - self._reset_weights() - - def _reset_weights(self): - for n, p in self.net.named_parameters(): - if 'bias' in n: - p.data.zero_() - if 'weight' in n: - th.nn.init.kaiming_normal_(p, nonlinearity="leaky_relu") - - def forward(self, x): - out = self.net(x) - return out - class Dataset(th.utils.data.Dataset): def __init__(self, data_dir, imsize): super(Dataset, self).__init__() @@ -783,8 +443,8 @@ def train(args): extras, meta = checkpointer.load_latest() if meta is not None and meta != model_params: - LOG.info("Checkpoint's metaparams differ from CLI, aborting: %s and %s", - meta, model_params) + LOG.info(f"Checkpoint's metaparams differ from CLI, " + f"aborting: {meta} and {model_params}") # Hook interface if args.generator in ["vae", "ae"]: @@ -794,17 +454,18 @@ def train(args): else: LOG.info("Using an AE") interface = VAEInterface(model, lr=args.lr, cuda=args.cuda, - variational=variational, w_kld=args.kld_weight) + variational=variational, + w_kld=args.kld_weight) trainer = ttools.Trainer(interface) # Add callbacks - keys = ["loss_g", "loss_d"] + keys = [] if args.generator == "vae": - keys = ["kld", "data_loss", "loss"] + keys = ["kld", "data_loss", "loss", "logvar"] elif args.generator == "ae": keys = ["data_loss", "loss"] - port = 8097 + port = 8080 trainer.add_callback(ttools.callbacks.ProgressBarCallback( keys=keys, val_keys=keys)) trainer.add_callback(ttools.callbacks.VisdomLoggingCallback( @@ -979,8 +640,10 @@ if __name__ == "__main__": parser.add_argument("--cpu", dest="cuda", action="store_false", default=th.cuda.is_available(), help="if true, use CPU instead of GPU.") - parser.add_argument("--conditional", action="store_true", default=False) - parser.add_argument("--fc", action="store_true", default=False) + parser.add_argument("--no-conditional", dest="conditional", + action="store_false", default=True) + parser.add_argument("--no-fc", dest="fc", action="store_false", + default=True) parser.add_argument("--data_dir", default="mnist", help="path to download and store the data.") @@ -992,19 +655,19 @@ if __name__ == "__main__": "autoencoder") parser_train.add_argument("--freq", type=int, default=100, help="number of steps between visualizations") - parser_train.add_argument("--lr", type=float, default=1e-4, + parser_train.add_argument("--lr", type=float, default=5e-5, help="learning rate") parser_train.add_argument("--kld_weight", type=float, default=1.0, help="scalar weight for the KL divergence term.") parser_train.add_argument("--bs", type=int, default=8, help="batch size") - parser_train.add_argument("--num_epochs", type=int, + parser_train.add_argument("--num_epochs", default=50, type=int, help="max number of epochs") # Vector configs parser_train.add_argument("--paths", type=int, default=1, - help="number of unique vector paths to generate.") + help="number of vector paths to generate.") parser_train.add_argument("--segments", type=int, default=3, help="number of segments per vector path") - parser_train.add_argument("--samples", type=int, default=2, + parser_train.add_argument("--samples", type=int, default=4, help="number of samples in the MC rasterizer") parser_train.add_argument("--zdim", type=int, default=20, help="dimension of the latent space")