fixes mnist_vae
This commit is contained in:
committed by
Tzu-Mao Li
parent
2ec688ebb7
commit
afe2e674d6
@@ -3,11 +3,11 @@
|
|||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
* Train a model:
|
* Train a model:
|
||||||
|
|
||||||
`python mnist_vae.py train`
|
`python mnist_vae.py train`
|
||||||
|
|
||||||
* Generate samples from a trained model:
|
* Generate samples from a trained model:
|
||||||
|
|
||||||
`python mnist_vae.py sample`
|
`python mnist_vae.py sample`
|
||||||
|
|
||||||
@@ -26,7 +26,6 @@ import torchvision.transforms as transforms
|
|||||||
|
|
||||||
import ttools
|
import ttools
|
||||||
import ttools.interfaces
|
import ttools.interfaces
|
||||||
from ttools.modules import networks
|
|
||||||
|
|
||||||
from modules import Flatten
|
from modules import Flatten
|
||||||
|
|
||||||
@@ -50,20 +49,21 @@ def _onehot(label):
|
|||||||
|
|
||||||
def render(canvas_width, canvas_height, shapes, shape_groups, samples=2):
|
def render(canvas_width, canvas_height, shapes, shape_groups, samples=2):
|
||||||
_render = pydiffvg.RenderFunction.apply
|
_render = pydiffvg.RenderFunction.apply
|
||||||
scene_args = pydiffvg.RenderFunction.serialize_scene(\
|
scene_args = pydiffvg.RenderFunction.serialize_scene(
|
||||||
canvas_width, canvas_height, shapes, shape_groups)
|
canvas_width, canvas_height, shapes, shape_groups)
|
||||||
img = _render(canvas_width, # width
|
img = _render(canvas_width,
|
||||||
canvas_height, # height
|
canvas_height,
|
||||||
samples, # num_samples_x
|
samples,
|
||||||
samples, # num_samples_y
|
samples,
|
||||||
0, # seed
|
0,
|
||||||
None, # background
|
None,
|
||||||
*scene_args)
|
*scene_args)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
class MNISTCallback(ttools.callbacks.ImageDisplayCallback):
|
class MNISTCallback(ttools.callbacks.ImageDisplayCallback):
|
||||||
"""Simple callback that visualize generated images during training."""
|
"""Simple callback that visualize generated images during training."""
|
||||||
|
|
||||||
def visualized_image(self, batch, step_data, is_val=False):
|
def visualized_image(self, batch, step_data, is_val=False):
|
||||||
im = step_data["rendering"].detach().cpu()
|
im = step_data["rendering"].detach().cpu()
|
||||||
im = 0.5 + 0.5*im
|
im = 0.5 + 0.5*im
|
||||||
@@ -142,87 +142,17 @@ class VAEInterface(ttools.ModelInterface):
|
|||||||
ret["data_loss"] = data_loss.item()
|
ret["data_loss"] = data_loss.item()
|
||||||
ret["auxdata"] = auxdata
|
ret["auxdata"] = auxdata
|
||||||
ret["rendering"] = rendering
|
ret["rendering"] = rendering
|
||||||
|
ret["logvar"] = logvar.abs().max().item()
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
# def init_validation(self):
|
|
||||||
# return {"count": 0, "loss": 0}
|
|
||||||
#
|
|
||||||
# def update_validation(self, batch, fwd, running_data):
|
|
||||||
# with th.no_grad():
|
|
||||||
# ref = batch[1].to(self.device)
|
|
||||||
# loss = th.nn.functional.mse_loss(fwd, ref)
|
|
||||||
# n = ref.shape[0]
|
|
||||||
#
|
|
||||||
# return {
|
|
||||||
# "loss": running_data["loss"] + loss.item()*n,
|
|
||||||
# "count": running_data["count"] + n
|
|
||||||
# }
|
|
||||||
#
|
|
||||||
# def finalize_validation(self, running_data):
|
|
||||||
# return {
|
|
||||||
# "loss": running_data["loss"] / running_data["count"]
|
|
||||||
# }
|
|
||||||
|
|
||||||
|
|
||||||
class MNISTGenerator(th.nn.Module):
|
|
||||||
def __init__(self, imsize=28):
|
|
||||||
super(MNISTGenerator, self).__init__()
|
|
||||||
if imsize != 28:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
mul = 2
|
|
||||||
self.convnet = th.nn.Sequential(
|
|
||||||
# 4x4
|
|
||||||
th.nn.ConvTranspose2d(16 + 1, mul*32, 4, padding=1, stride=2),
|
|
||||||
th.nn.LeakyReLU(inplace=True),
|
|
||||||
th.nn.Conv2d(mul*32, mul*32, 3, padding=1),
|
|
||||||
th.nn.LeakyReLU(inplace=True),
|
|
||||||
|
|
||||||
# 8x8
|
|
||||||
th.nn.ConvTranspose2d(mul*32, mul*64, 4, padding=1, stride=2),
|
|
||||||
th.nn.LeakyReLU(inplace=True),
|
|
||||||
th.nn.Conv2d(mul*64, mul*64, 3, padding=1),
|
|
||||||
th.nn.LeakyReLU(inplace=True),
|
|
||||||
|
|
||||||
# 16x16
|
|
||||||
th.nn.ConvTranspose2d(mul*64, mul*128, 4, padding=1, stride=2),
|
|
||||||
th.nn.LeakyReLU(inplace=True),
|
|
||||||
th.nn.Conv2d(mul*128, mul*128, 3, padding=1),
|
|
||||||
th.nn.LeakyReLU(inplace=True),
|
|
||||||
|
|
||||||
# 32x32
|
|
||||||
th.nn.Conv2d(mul*128, mul*128, 3, padding=1),
|
|
||||||
th.nn.LeakyReLU(inplace=True),
|
|
||||||
|
|
||||||
th.nn.Conv2d(mul*128, mul*128, 3, padding=1),
|
|
||||||
th.nn.LeakyReLU(inplace=True),
|
|
||||||
|
|
||||||
th.nn.Conv2d(mul*128, 1, 1),
|
|
||||||
# th.nn.Tanh(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, im, label):
|
|
||||||
bs = im.shape[0]
|
|
||||||
|
|
||||||
# sample a hidden vector
|
|
||||||
z = th.randn(bs, 16, 4, 4).to(im.device)
|
|
||||||
|
|
||||||
# make the model conditional
|
|
||||||
in_ = th.cat([z, label.float().view(bs, 1, 1, 1).repeat(1, 1, 4, 4)], 1)
|
|
||||||
|
|
||||||
out = self.convnet(in_)
|
|
||||||
return out, None
|
|
||||||
|
|
||||||
|
|
||||||
class VectorMNISTVAE(th.nn.Module):
|
class VectorMNISTVAE(th.nn.Module):
|
||||||
def __init__(self, imsize=28, paths=4, segments=5, samples=2, zdim=128,
|
def __init__(self, imsize=28, paths=4, segments=5, samples=2, zdim=128,
|
||||||
conditional=False, variational=True, raster=False, fc=False):
|
conditional=False, variational=True, raster=False, fc=False,
|
||||||
|
stroke_width=None):
|
||||||
super(VectorMNISTVAE, self).__init__()
|
super(VectorMNISTVAE, self).__init__()
|
||||||
|
|
||||||
# if imsize != 28:
|
|
||||||
# raise NotImplementedError()
|
|
||||||
|
|
||||||
self.samples = samples
|
self.samples = samples
|
||||||
self.imsize = imsize
|
self.imsize = imsize
|
||||||
self.paths = paths
|
self.paths = paths
|
||||||
@@ -231,6 +161,12 @@ class VectorMNISTVAE(th.nn.Module):
|
|||||||
self.conditional = conditional
|
self.conditional = conditional
|
||||||
self.variational = variational
|
self.variational = variational
|
||||||
|
|
||||||
|
if stroke_width is None:
|
||||||
|
self.stroke_width = (1.0, 3.0)
|
||||||
|
LOG.warning("Setting default stroke with %s", self.stroke_width)
|
||||||
|
else:
|
||||||
|
self.stroke_width = stroke_width
|
||||||
|
|
||||||
ncond = 0
|
ncond = 0
|
||||||
if self.conditional: # one hot encoded input for conditional model
|
if self.conditional: # one hot encoded input for conditional model
|
||||||
ncond = 10
|
ncond = 10
|
||||||
@@ -278,8 +214,8 @@ class VectorMNISTVAE(th.nn.Module):
|
|||||||
th.nn.SELU(inplace=True),
|
th.nn.SELU(inplace=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
self.raster = raster
|
self.raster = raster
|
||||||
|
|
||||||
if self.raster:
|
if self.raster:
|
||||||
self.raster_decoder = th.nn.Sequential(
|
self.raster_decoder = th.nn.Sequential(
|
||||||
th.nn.Linear(nc, imsize*imsize),
|
th.nn.Linear(nc, imsize*imsize),
|
||||||
@@ -293,59 +229,21 @@ class VectorMNISTVAE(th.nn.Module):
|
|||||||
|
|
||||||
self.width_predictor = th.nn.Sequential(
|
self.width_predictor = th.nn.Sequential(
|
||||||
th.nn.Linear(nc, self.paths),
|
th.nn.Linear(nc, self.paths),
|
||||||
th.nn.Tanh()
|
th.nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.alpha_predictor = th.nn.Sequential(
|
self.alpha_predictor = th.nn.Sequential(
|
||||||
th.nn.Linear(nc, self.paths),
|
th.nn.Linear(nc, self.paths),
|
||||||
th.nn.Tanh()
|
th.nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|
||||||
self._reset_weights()
|
|
||||||
|
|
||||||
def _reset_weights(self):
|
|
||||||
for n, p in self.encoder.named_parameters():
|
|
||||||
if 'bias' in n:
|
|
||||||
p.data.zero_()
|
|
||||||
elif 'weight' in n:
|
|
||||||
th.nn.init.kaiming_normal_(p.data, nonlinearity="leaky_relu")
|
|
||||||
|
|
||||||
th.nn.init.kaiming_normal_(self.mu_predictor.weight.data, nonlinearity="linear")
|
|
||||||
if self.variational:
|
|
||||||
th.nn.init.kaiming_normal_(self.logvar_predictor.weight.data, nonlinearity="linear")
|
|
||||||
|
|
||||||
for n, p in self.decoder.named_parameters():
|
|
||||||
if 'bias' in n:
|
|
||||||
p.data.zero_()
|
|
||||||
elif 'weight' in n:
|
|
||||||
th.nn.init.kaiming_normal_(p, nonlinearity="linear")
|
|
||||||
|
|
||||||
if not self.raster:
|
|
||||||
for n, p in self.point_predictor.named_parameters():
|
|
||||||
pass
|
|
||||||
# if 'bias' in n:
|
|
||||||
# p.data.zero_()
|
|
||||||
# if 'weight' in n:
|
|
||||||
# th.nn.init.orthogonal_(p)
|
|
||||||
|
|
||||||
for n, p in self.width_predictor.named_parameters():
|
|
||||||
if 'bias' in n:
|
|
||||||
p.data.zero_()
|
|
||||||
elif 'weight' in n:
|
|
||||||
th.nn.init.orthogonal_(p)
|
|
||||||
|
|
||||||
for n, p in self.alpha_predictor.named_parameters():
|
|
||||||
if 'bias' in n:
|
|
||||||
p.data.zero_()
|
|
||||||
elif 'weight' in n:
|
|
||||||
th.nn.init.orthogonal_(p)
|
|
||||||
|
|
||||||
def encode(self, im, label):
|
def encode(self, im, label):
|
||||||
bs, _, h, w = im.shape
|
bs, _, h, w = im.shape
|
||||||
if self.conditional:
|
if self.conditional:
|
||||||
label_onehot = _onehot(label)
|
label_onehot = _onehot(label)
|
||||||
if not self.fc:
|
if not self.fc:
|
||||||
label_onehot = label_onehot.view(bs, 10, 1, 1).repeat(1, 1, h, w)
|
label_onehot = label_onehot.view(
|
||||||
|
bs, 10, 1, 1).repeat(1, 1, h, w)
|
||||||
out = self.encoder(th.cat([im, label_onehot], 1))
|
out = self.encoder(th.cat([im, label_onehot], 1))
|
||||||
else:
|
else:
|
||||||
out = self.encoder(th.cat([im.view(bs, -1), label_onehot], 1))
|
out = self.encoder(th.cat([im.view(bs, -1), label_onehot], 1))
|
||||||
@@ -365,7 +263,9 @@ class VectorMNISTVAE(th.nn.Module):
|
|||||||
|
|
||||||
def _decode_features(self, z, label):
|
def _decode_features(self, z, label):
|
||||||
if label is not None:
|
if label is not None:
|
||||||
assert self.conditional, "decoding with an input label requires a conditional AE"
|
if not self.conditional:
|
||||||
|
raise ValueError("decoding with an input label "
|
||||||
|
"requires a conditional AE")
|
||||||
label_onehot = _onehot(label)
|
label_onehot = _onehot(label)
|
||||||
z = th.cat([z, label_onehot], 1)
|
z = th.cat([z, label_onehot], 1)
|
||||||
|
|
||||||
@@ -378,7 +278,8 @@ class VectorMNISTVAE(th.nn.Module):
|
|||||||
feats = self._decode_features(z, label)
|
feats = self._decode_features(z, label)
|
||||||
|
|
||||||
if self.raster:
|
if self.raster:
|
||||||
out = self.raster_decoder(feats).view(bs, 1, self.imsize, self.imsize)
|
out = self.raster_decoder(feats).view(
|
||||||
|
bs, 1, self.imsize, self.imsize)
|
||||||
return out, {}
|
return out, {}
|
||||||
|
|
||||||
all_points = self.point_predictor(feats)
|
all_points = self.point_predictor(feats)
|
||||||
@@ -389,7 +290,10 @@ class VectorMNISTVAE(th.nn.Module):
|
|||||||
if False:
|
if False:
|
||||||
all_widths = th.ones(bs, self.paths) * 0.5
|
all_widths = th.ones(bs, self.paths) * 0.5
|
||||||
else:
|
else:
|
||||||
all_widths = self.width_predictor(feats) * 1.5 + .25
|
all_widths = self.width_predictor(feats)
|
||||||
|
min_width = self.stroke_width[0]
|
||||||
|
max_width = self.stroke_width[1]
|
||||||
|
all_widths = (max_width - min_width) * all_widths + min_width
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
all_alphas = th.ones(bs, self.paths)
|
all_alphas = th.ones(bs, self.paths)
|
||||||
@@ -426,14 +330,16 @@ class VectorMNISTVAE(th.nn.Module):
|
|||||||
[shapes, shape_groups, (self.imsize, self.imsize)])
|
[shapes, shape_groups, (self.imsize, self.imsize)])
|
||||||
|
|
||||||
# Rasterize
|
# Rasterize
|
||||||
out = render(self.imsize, self.imsize, shapes, shape_groups, samples=self.samples)
|
out = render(self.imsize, self.imsize, shapes, shape_groups,
|
||||||
|
samples=self.samples)
|
||||||
|
|
||||||
# Torch format, discard alpha, make gray
|
# Torch format, discard alpha, make gray
|
||||||
out = out.permute(2, 0, 1).view(4, self.imsize, self.imsize)[:3].mean(0, keepdim=True)
|
out = out.permute(2, 0, 1).view(
|
||||||
|
4, self.imsize, self.imsize)[:3].mean(0, keepdim=True)
|
||||||
|
|
||||||
outputs.append(out)
|
outputs.append(out)
|
||||||
|
|
||||||
output = th.stack(outputs).to(z.device)
|
output = th.stack(outputs).to(z.device)
|
||||||
|
|
||||||
auxdata = {
|
auxdata = {
|
||||||
"points": all_points,
|
"points": all_points,
|
||||||
@@ -446,8 +352,6 @@ class VectorMNISTVAE(th.nn.Module):
|
|||||||
return output, auxdata
|
return output, auxdata
|
||||||
|
|
||||||
def forward(self, im, label):
|
def forward(self, im, label):
|
||||||
bs = im.shape[0]
|
|
||||||
|
|
||||||
if self.variational:
|
if self.variational:
|
||||||
mu, logvar = self.encode(im, label)
|
mu, logvar = self.encode(im, label)
|
||||||
z = self.reparameterize(mu, logvar)
|
z = self.reparameterize(mu, logvar)
|
||||||
@@ -457,9 +361,9 @@ class VectorMNISTVAE(th.nn.Module):
|
|||||||
logvar = None
|
logvar = None
|
||||||
|
|
||||||
if self.conditional:
|
if self.conditional:
|
||||||
output, aux = self.decode(z, label=label)
|
output, aux = self.decode(z, label=label)
|
||||||
else:
|
else:
|
||||||
output, aux = self.decode(z)
|
output, aux = self.decode(z)
|
||||||
|
|
||||||
aux["logvar"] = logvar
|
aux["logvar"] = logvar
|
||||||
aux["mu"] = mu
|
aux["mu"] = mu
|
||||||
@@ -467,250 +371,6 @@ class VectorMNISTVAE(th.nn.Module):
|
|||||||
return output, aux
|
return output, aux
|
||||||
|
|
||||||
|
|
||||||
class VectorMNISTGenerator(th.nn.Module):
|
|
||||||
def __init__(self, imsize=28, paths=4, segments=5, samples=2, conditional=False,
|
|
||||||
zdim=20, fc=False):
|
|
||||||
super(VectorMNISTGenerator, self).__init__()
|
|
||||||
if imsize != 28:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
self.samples = samples
|
|
||||||
self.imsize = imsize
|
|
||||||
self.paths = paths
|
|
||||||
self.segments = segments
|
|
||||||
self.conditional = conditional
|
|
||||||
self.zdim = zdim
|
|
||||||
self.fc = fc
|
|
||||||
|
|
||||||
ncond = 0
|
|
||||||
if self.conditional: # one hot encoded input for conditional model
|
|
||||||
ncond = 10
|
|
||||||
|
|
||||||
nc = 1024
|
|
||||||
self.trunk = th.nn.Sequential(
|
|
||||||
th.nn.Linear(zdim + ncond, nc), # noise + one-hot
|
|
||||||
th.nn.SELU(inplace=True),
|
|
||||||
|
|
||||||
# th.nn.Linear(nc, nc),
|
|
||||||
# th.nn.SELU(inplace=True),
|
|
||||||
|
|
||||||
th.nn.Linear(nc, nc),
|
|
||||||
th.nn.SELU(inplace=True),
|
|
||||||
|
|
||||||
# th.nn.Linear(nc, nc),
|
|
||||||
# th.nn.SELU(inplace=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4 points bezier so n_segments -> 3*n_segments + 1 points
|
|
||||||
self.point_predictor = th.nn.Sequential(
|
|
||||||
th.nn.Linear(nc, 2*self.paths*(self.segments*3+1)),
|
|
||||||
# th.nn.Linear(nc, 2*self.paths*(self.segments*1+1)),
|
|
||||||
th.nn.Tanh() # bound spatial extent
|
|
||||||
)
|
|
||||||
|
|
||||||
self.width_predictor = th.nn.Sequential(
|
|
||||||
th.nn.Linear(nc, self.paths),
|
|
||||||
th.nn.Tanh()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.alpha_predictor = th.nn.Sequential(
|
|
||||||
th.nn.Linear(nc, self.paths),
|
|
||||||
th.nn.Tanh()
|
|
||||||
)
|
|
||||||
|
|
||||||
# self.postprocessor = th.nn.Sequential(
|
|
||||||
# th.nn.Conv2d(1, 32, 3, padding=1),
|
|
||||||
# th.nn.LeakyReLU(inplace=True),
|
|
||||||
# th.nn.Conv2d(32, 1, 1),
|
|
||||||
# )
|
|
||||||
self._reset_weights()
|
|
||||||
|
|
||||||
def _reset_weights(self):
|
|
||||||
for n, p in self.trunk.named_parameters():
|
|
||||||
if 'bias' in n:
|
|
||||||
p.data.zero_()
|
|
||||||
elif 'weight' in n:
|
|
||||||
th.nn.init.kaiming_normal_(p)
|
|
||||||
p.data.mul_(0.7)
|
|
||||||
# th.nn.init.kaiming_normal_(p, nonlinearity="leaky_relu")
|
|
||||||
|
|
||||||
for n, p in self.point_predictor.named_parameters():
|
|
||||||
# if 'bias' in n:
|
|
||||||
# p.data.zero_()
|
|
||||||
if 'weight' in n:
|
|
||||||
th.nn.init.orthogonal_(p)
|
|
||||||
# th.nn.init.kaiming_normal_(p, nonlinearity="tanh")
|
|
||||||
|
|
||||||
for n, p in self.width_predictor.named_parameters():
|
|
||||||
if 'bias' in n:
|
|
||||||
p.data.zero_()
|
|
||||||
elif 'weight' in n:
|
|
||||||
# th.nn.init.orthogonal_(p)
|
|
||||||
th.nn.init.kaiming_normal_(p, nonlinearity="tanh")
|
|
||||||
|
|
||||||
for n, p in self.alpha_predictor.named_parameters():
|
|
||||||
if 'bias' in n:
|
|
||||||
p.data.zero_()
|
|
||||||
elif 'weight' in n:
|
|
||||||
th.nn.init.kaiming_normal_(p, nonlinearity="tanh")
|
|
||||||
# th.nn.init.orthogonal_(p)
|
|
||||||
|
|
||||||
def sample_z(self, bs):
|
|
||||||
return th.randn(bs, self.zdim)
|
|
||||||
|
|
||||||
def gen_sample(self, z, label=None):
|
|
||||||
bs = z.shape[0]
|
|
||||||
if self.conditional:
|
|
||||||
if label is None:
|
|
||||||
raise ValueError("GAN is conditional, please provide a label")
|
|
||||||
|
|
||||||
# One-hot encoding of the image label
|
|
||||||
label_onehot = _onehot(label)
|
|
||||||
|
|
||||||
# get some embedding
|
|
||||||
in_ = th.cat([z, label_onehot.float()], 1)
|
|
||||||
else:
|
|
||||||
in_ = z
|
|
||||||
|
|
||||||
feats = self.trunk(in_)
|
|
||||||
|
|
||||||
all_points = self.point_predictor(feats)
|
|
||||||
all_points = all_points.view(bs, self.paths, -1, 2)
|
|
||||||
|
|
||||||
if False:
|
|
||||||
all_alphas = th.ones(bs, self.paths)
|
|
||||||
else:
|
|
||||||
all_alphas = self.alpha_predictor(feats)
|
|
||||||
|
|
||||||
# stroke size between 0.5 and 3.5 px
|
|
||||||
if False:
|
|
||||||
all_widths = th.ones(bs, self.paths) * 1
|
|
||||||
else:
|
|
||||||
all_widths = self.width_predictor(feats)
|
|
||||||
all_widths = 1.5*all_widths + 0.5
|
|
||||||
|
|
||||||
all_points = all_points*(self.imsize//2) + self.imsize//2
|
|
||||||
|
|
||||||
# Process the batch sequentially
|
|
||||||
outputs = []
|
|
||||||
for k in range(bs):
|
|
||||||
# Get point parameters from network
|
|
||||||
shapes = []
|
|
||||||
shape_groups = []
|
|
||||||
for p in range(self.paths):
|
|
||||||
points = all_points[k, p].contiguous().cpu()
|
|
||||||
# num_ctrl_pts = th.zeros(self.segments, dtype=th.int32)+0
|
|
||||||
num_ctrl_pts = th.zeros(self.segments, dtype=th.int32)+2
|
|
||||||
width = all_widths[k, p].cpu()
|
|
||||||
alpha = all_alphas[k, p].cpu()
|
|
||||||
color = th.cat([th.ones(3), alpha.view(1,)])
|
|
||||||
path = pydiffvg.Path(
|
|
||||||
num_control_points=num_ctrl_pts, points=points,
|
|
||||||
stroke_width=width, is_closed=False)
|
|
||||||
shapes.append(path)
|
|
||||||
path_group = pydiffvg.ShapeGroup(
|
|
||||||
shape_ids=th.tensor([len(shapes) - 1]),
|
|
||||||
fill_color=None,
|
|
||||||
stroke_color=color)
|
|
||||||
shape_groups.append(path_group)
|
|
||||||
|
|
||||||
# Rasterize
|
|
||||||
out = render(self.imsize, self.imsize, shapes, shape_groups, samples=self.samples)
|
|
||||||
|
|
||||||
# Torch format, discard alpha, make gray
|
|
||||||
out = out.permute(2, 0, 1).view(4, self.imsize, self.imsize)[:3].mean(0, keepdim=True)
|
|
||||||
|
|
||||||
outputs.append(out)
|
|
||||||
|
|
||||||
output = th.stack(outputs).to(z.device)
|
|
||||||
aux_data = {
|
|
||||||
"points": all_points,
|
|
||||||
"raw_vector": output,
|
|
||||||
}
|
|
||||||
|
|
||||||
# output = self.postprocessor(output)
|
|
||||||
|
|
||||||
# map to [-1, 1]
|
|
||||||
output = output*2.0 - 1.0
|
|
||||||
|
|
||||||
return output, aux_data
|
|
||||||
|
|
||||||
def forward(self, im, label):
|
|
||||||
bs = label.shape[0]
|
|
||||||
|
|
||||||
# sample a hidden vector (same dim as the raster version)
|
|
||||||
z = self.sample_z(bs).to(im.device)
|
|
||||||
if args.conditional:
|
|
||||||
return self.gen_sample(z, label=label)
|
|
||||||
else:
|
|
||||||
return self.gen_sample(z)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Discriminator(th.nn.Module):
|
|
||||||
def __init__(self, conditional=False, fc=False):
|
|
||||||
super(Discriminator, self).__init__()
|
|
||||||
|
|
||||||
self.conditional = conditional
|
|
||||||
|
|
||||||
ncond = 0
|
|
||||||
if self.conditional: # one hot encoded input for conditional model
|
|
||||||
ncond = 10
|
|
||||||
|
|
||||||
sn = th.nn.utils.spectral_norm
|
|
||||||
# sn = lambda x: x
|
|
||||||
|
|
||||||
self.fc = fc
|
|
||||||
|
|
||||||
mult = 2
|
|
||||||
if self.fc:
|
|
||||||
self.net = th.nn.Sequential(
|
|
||||||
Flatten(),
|
|
||||||
th.nn.Linear(28*28 + ncond, mult*256),
|
|
||||||
th.nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
|
|
||||||
# th.nn.Linear(mult*256, mult*256, 4),
|
|
||||||
# th.nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
# th.nn.Dropout(0.5),
|
|
||||||
|
|
||||||
th.nn.Linear(mult*256, mult*256, 4),
|
|
||||||
th.nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
|
|
||||||
th.nn.Linear(mult*256*1*1, 1),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.net = th.nn.Sequential(
|
|
||||||
th.nn.Conv2d(1 + ncond, mult*64, 4, padding=0, stride=2),
|
|
||||||
|
|
||||||
th.nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
# 16x16
|
|
||||||
|
|
||||||
sn(th.nn.Conv2d(mult*64, mult*128, 4, padding=0, stride=2)),
|
|
||||||
th.nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
# 8x8
|
|
||||||
|
|
||||||
sn(th.nn.Conv2d(mult*128, mult*256, 4, padding=0, stride=2)),
|
|
||||||
th.nn.LeakyReLU(0.2, inplace=True),
|
|
||||||
# 4x4
|
|
||||||
|
|
||||||
Flatten(),
|
|
||||||
|
|
||||||
th.nn.Linear(mult*256*1*1, 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
self._reset_weights()
|
|
||||||
|
|
||||||
def _reset_weights(self):
|
|
||||||
for n, p in self.net.named_parameters():
|
|
||||||
if 'bias' in n:
|
|
||||||
p.data.zero_()
|
|
||||||
if 'weight' in n:
|
|
||||||
th.nn.init.kaiming_normal_(p, nonlinearity="leaky_relu")
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = self.net(x)
|
|
||||||
return out
|
|
||||||
|
|
||||||
class Dataset(th.utils.data.Dataset):
|
class Dataset(th.utils.data.Dataset):
|
||||||
def __init__(self, data_dir, imsize):
|
def __init__(self, data_dir, imsize):
|
||||||
super(Dataset, self).__init__()
|
super(Dataset, self).__init__()
|
||||||
@@ -783,8 +443,8 @@ def train(args):
|
|||||||
extras, meta = checkpointer.load_latest()
|
extras, meta = checkpointer.load_latest()
|
||||||
|
|
||||||
if meta is not None and meta != model_params:
|
if meta is not None and meta != model_params:
|
||||||
LOG.info("Checkpoint's metaparams differ from CLI, aborting: %s and %s",
|
LOG.info(f"Checkpoint's metaparams differ from CLI, "
|
||||||
meta, model_params)
|
f"aborting: {meta} and {model_params}")
|
||||||
|
|
||||||
# Hook interface
|
# Hook interface
|
||||||
if args.generator in ["vae", "ae"]:
|
if args.generator in ["vae", "ae"]:
|
||||||
@@ -794,17 +454,18 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
LOG.info("Using an AE")
|
LOG.info("Using an AE")
|
||||||
interface = VAEInterface(model, lr=args.lr, cuda=args.cuda,
|
interface = VAEInterface(model, lr=args.lr, cuda=args.cuda,
|
||||||
variational=variational, w_kld=args.kld_weight)
|
variational=variational,
|
||||||
|
w_kld=args.kld_weight)
|
||||||
|
|
||||||
trainer = ttools.Trainer(interface)
|
trainer = ttools.Trainer(interface)
|
||||||
|
|
||||||
# Add callbacks
|
# Add callbacks
|
||||||
keys = ["loss_g", "loss_d"]
|
keys = []
|
||||||
if args.generator == "vae":
|
if args.generator == "vae":
|
||||||
keys = ["kld", "data_loss", "loss"]
|
keys = ["kld", "data_loss", "loss", "logvar"]
|
||||||
elif args.generator == "ae":
|
elif args.generator == "ae":
|
||||||
keys = ["data_loss", "loss"]
|
keys = ["data_loss", "loss"]
|
||||||
port = 8097
|
port = 8080
|
||||||
trainer.add_callback(ttools.callbacks.ProgressBarCallback(
|
trainer.add_callback(ttools.callbacks.ProgressBarCallback(
|
||||||
keys=keys, val_keys=keys))
|
keys=keys, val_keys=keys))
|
||||||
trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
|
trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
|
||||||
@@ -979,8 +640,10 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--cpu", dest="cuda", action="store_false",
|
parser.add_argument("--cpu", dest="cuda", action="store_false",
|
||||||
default=th.cuda.is_available(),
|
default=th.cuda.is_available(),
|
||||||
help="if true, use CPU instead of GPU.")
|
help="if true, use CPU instead of GPU.")
|
||||||
parser.add_argument("--conditional", action="store_true", default=False)
|
parser.add_argument("--no-conditional", dest="conditional",
|
||||||
parser.add_argument("--fc", action="store_true", default=False)
|
action="store_false", default=True)
|
||||||
|
parser.add_argument("--no-fc", dest="fc", action="store_false",
|
||||||
|
default=True)
|
||||||
parser.add_argument("--data_dir", default="mnist",
|
parser.add_argument("--data_dir", default="mnist",
|
||||||
help="path to download and store the data.")
|
help="path to download and store the data.")
|
||||||
|
|
||||||
@@ -992,19 +655,19 @@ if __name__ == "__main__":
|
|||||||
"autoencoder")
|
"autoencoder")
|
||||||
parser_train.add_argument("--freq", type=int, default=100,
|
parser_train.add_argument("--freq", type=int, default=100,
|
||||||
help="number of steps between visualizations")
|
help="number of steps between visualizations")
|
||||||
parser_train.add_argument("--lr", type=float, default=1e-4,
|
parser_train.add_argument("--lr", type=float, default=5e-5,
|
||||||
help="learning rate")
|
help="learning rate")
|
||||||
parser_train.add_argument("--kld_weight", type=float, default=1.0,
|
parser_train.add_argument("--kld_weight", type=float, default=1.0,
|
||||||
help="scalar weight for the KL divergence term.")
|
help="scalar weight for the KL divergence term.")
|
||||||
parser_train.add_argument("--bs", type=int, default=8, help="batch size")
|
parser_train.add_argument("--bs", type=int, default=8, help="batch size")
|
||||||
parser_train.add_argument("--num_epochs", type=int,
|
parser_train.add_argument("--num_epochs", default=50, type=int,
|
||||||
help="max number of epochs")
|
help="max number of epochs")
|
||||||
# Vector configs
|
# Vector configs
|
||||||
parser_train.add_argument("--paths", type=int, default=1,
|
parser_train.add_argument("--paths", type=int, default=1,
|
||||||
help="number of unique vector paths to generate.")
|
help="number of vector paths to generate.")
|
||||||
parser_train.add_argument("--segments", type=int, default=3,
|
parser_train.add_argument("--segments", type=int, default=3,
|
||||||
help="number of segments per vector path")
|
help="number of segments per vector path")
|
||||||
parser_train.add_argument("--samples", type=int, default=2,
|
parser_train.add_argument("--samples", type=int, default=4,
|
||||||
help="number of samples in the MC rasterizer")
|
help="number of samples in the MC rasterizer")
|
||||||
parser_train.add_argument("--zdim", type=int, default=20,
|
parser_train.add_argument("--zdim", type=int, default=20,
|
||||||
help="dimension of the latent space")
|
help="dimension of the latent space")
|
||||||
|
Reference in New Issue
Block a user