fixes mnist_vae
committed by Tzu-Mao Li
parent 2ec688ebb7
commit afe2e674d6
@@ -3,11 +3,11 @@
Usage:

* Train a model:

  `python mnist_vae.py train`

* Generate samples from a trained model:

  `python mnist_vae.py sample`
@@ -26,7 +26,6 @@ import torchvision.transforms as transforms

import ttools
import ttools.interfaces
from ttools.modules import networks

from modules import Flatten
@@ -50,20 +49,21 @@ def _onehot(label):

def render(canvas_width, canvas_height, shapes, shape_groups, samples=2):
    _render = pydiffvg.RenderFunction.apply
    scene_args = pydiffvg.RenderFunction.serialize_scene(\
    scene_args = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    img = _render(canvas_width,  # width
                  canvas_height,  # height
                  samples,  # num_samples_x
                  samples,  # num_samples_y
                  0,  # seed
                  None,  # background
                  *scene_args)
    img = _render(canvas_width,
                  canvas_height,
                  samples,
                  samples,
                  0,
                  None,
                  *scene_args)
    return img
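
For context, a minimal sketch of how this wrapper might be called on a single open Bézier stroke. It only reuses pydiffvg calls that appear later in this file (`pydiffvg.Path`, `pydiffvg.ShapeGroup`); the coordinates and stroke width are made-up illustration values, not part of the commit.

```python
import torch as th
import pydiffvg

# One cubic Bezier segment: 2 control points between 2 endpoints -> 4 points total.
points = th.tensor([[2.0, 2.0], [10.0, 25.0], [18.0, 3.0], [26.0, 26.0]])
num_ctrl_pts = th.zeros(1, dtype=th.int32) + 2
path = pydiffvg.Path(num_control_points=num_ctrl_pts, points=points,
                     stroke_width=th.tensor(1.5), is_closed=False)
group = pydiffvg.ShapeGroup(shape_ids=th.tensor([0]), fill_color=None,
                            stroke_color=th.tensor([1.0, 1.0, 1.0, 1.0]))

# 28x28 canvas, 2x2 samples per pixel; returns an HxWx4 RGBA tensor.
img = render(28, 28, [path], [group], samples=2)
```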

class MNISTCallback(ttools.callbacks.ImageDisplayCallback):
    """Simple callback that visualizes generated images during training."""

    def visualized_image(self, batch, step_data, is_val=False):
        im = step_data["rendering"].detach().cpu()
        im = 0.5 + 0.5*im
@@ -142,87 +142,17 @@ class VAEInterface(ttools.ModelInterface):
        ret["data_loss"] = data_loss.item()
        ret["auxdata"] = auxdata
        ret["rendering"] = rendering
        ret["logvar"] = logvar.abs().max().item()

        return ret

    # def init_validation(self):
    #     return {"count": 0, "loss": 0}
    #
    # def update_validation(self, batch, fwd, running_data):
    #     with th.no_grad():
    #         ref = batch[1].to(self.device)
    #         loss = th.nn.functional.mse_loss(fwd, ref)
    #         n = ref.shape[0]
    #
    #     return {
    #         "loss": running_data["loss"] + loss.item()*n,
    #         "count": running_data["count"] + n
    #     }
    #
    # def finalize_validation(self, running_data):
    #     return {
    #         "loss": running_data["loss"] / running_data["count"]
    #     }
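
The interface above reports `data_loss`, a KL term, and a total `loss` (see the callback keys added further down), but the actual computation sits outside this hunk. A minimal sketch of the usual VAE formulation it appears to follow, with `w_kld` standing in for `args.kld_weight`:

```python
import torch as th

def vae_loss(rendering, target, mu, logvar, w_kld=1.0):
    # Reconstruction term between the rasterized output and the input image.
    data_loss = th.nn.functional.mse_loss(rendering, target)
    # Standard Gaussian KL divergence, averaged over the batch.
    kld = -0.5 * th.mean(th.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
    return data_loss + w_kld * kld, data_loss, kld
```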


class MNISTGenerator(th.nn.Module):
    def __init__(self, imsize=28):
        super(MNISTGenerator, self).__init__()
        if imsize != 28:
            raise NotImplementedError()

        mul = 2
        self.convnet = th.nn.Sequential(
            # 4x4
            th.nn.ConvTranspose2d(16 + 1, mul*32, 4, padding=1, stride=2),
            th.nn.LeakyReLU(inplace=True),
            th.nn.Conv2d(mul*32, mul*32, 3, padding=1),
            th.nn.LeakyReLU(inplace=True),

            # 8x8
            th.nn.ConvTranspose2d(mul*32, mul*64, 4, padding=1, stride=2),
            th.nn.LeakyReLU(inplace=True),
            th.nn.Conv2d(mul*64, mul*64, 3, padding=1),
            th.nn.LeakyReLU(inplace=True),

            # 16x16
            th.nn.ConvTranspose2d(mul*64, mul*128, 4, padding=1, stride=2),
            th.nn.LeakyReLU(inplace=True),
            th.nn.Conv2d(mul*128, mul*128, 3, padding=1),
            th.nn.LeakyReLU(inplace=True),

            # 32x32
            th.nn.Conv2d(mul*128, mul*128, 3, padding=1),
            th.nn.LeakyReLU(inplace=True),

            th.nn.Conv2d(mul*128, mul*128, 3, padding=1),
            th.nn.LeakyReLU(inplace=True),

            th.nn.Conv2d(mul*128, 1, 1),
            # th.nn.Tanh(),
        )

    def forward(self, im, label):
        bs = im.shape[0]

        # sample a hidden vector
        z = th.randn(bs, 16, 4, 4).to(im.device)

        # make the model conditional
        in_ = th.cat([z, label.float().view(bs, 1, 1, 1).repeat(1, 1, 4, 4)], 1)

        out = self.convnet(in_)
        return out, None
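
The conditioning in `forward` works by turning the scalar class label into a constant 4x4 feature map and stacking it on the noise channels. A standalone sketch of just that step (shapes are illustrative):

```python
import torch as th

bs = 2
z = th.randn(bs, 16, 4, 4)                                   # noise: (bs, 16, 4, 4)
label = th.tensor([3, 7])                                    # one digit class per sample
cond = label.float().view(bs, 1, 1, 1).repeat(1, 1, 4, 4)    # (bs, 1, 4, 4)
in_ = th.cat([z, cond], 1)                                   # (bs, 17, 4, 4), matching ConvTranspose2d(16 + 1, ...)
```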


class VectorMNISTVAE(th.nn.Module):
    def __init__(self, imsize=28, paths=4, segments=5, samples=2, zdim=128,
                 conditional=False, variational=True, raster=False, fc=False):
                 conditional=False, variational=True, raster=False, fc=False,
                 stroke_width=None):
        super(VectorMNISTVAE, self).__init__()

        # if imsize != 28:
        #     raise NotImplementedError()

        self.samples = samples
        self.imsize = imsize
        self.paths = paths
@@ -231,6 +161,12 @@ class VectorMNISTVAE(th.nn.Module):
        self.conditional = conditional
        self.variational = variational

        if stroke_width is None:
            self.stroke_width = (1.0, 3.0)
            LOG.warning("Setting default stroke with %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width

        ncond = 0
        if self.conditional:  # one hot encoded input for conditional model
            ncond = 10
@@ -278,8 +214,8 @@ class VectorMNISTVAE(th.nn.Module):
            th.nn.SELU(inplace=True),
        )

        self.raster = raster

        if self.raster:
            self.raster_decoder = th.nn.Sequential(
                th.nn.Linear(nc, imsize*imsize),
@@ -293,59 +229,21 @@ class VectorMNISTVAE(th.nn.Module):

        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(nc, self.paths),
            th.nn.Tanh()
            th.nn.Sigmoid()
        )

        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(nc, self.paths),
            th.nn.Tanh()
            th.nn.Sigmoid()
        )

        self._reset_weights()

    def _reset_weights(self):
        for n, p in self.encoder.named_parameters():
            if 'bias' in n:
                p.data.zero_()
            elif 'weight' in n:
                th.nn.init.kaiming_normal_(p.data, nonlinearity="leaky_relu")

        th.nn.init.kaiming_normal_(self.mu_predictor.weight.data, nonlinearity="linear")
        if self.variational:
            th.nn.init.kaiming_normal_(self.logvar_predictor.weight.data, nonlinearity="linear")

        for n, p in self.decoder.named_parameters():
            if 'bias' in n:
                p.data.zero_()
            elif 'weight' in n:
                th.nn.init.kaiming_normal_(p, nonlinearity="linear")

        if not self.raster:
            for n, p in self.point_predictor.named_parameters():
                pass
                # if 'bias' in n:
                #     p.data.zero_()
                # if 'weight' in n:
                #     th.nn.init.orthogonal_(p)

            for n, p in self.width_predictor.named_parameters():
                if 'bias' in n:
                    p.data.zero_()
                elif 'weight' in n:
                    th.nn.init.orthogonal_(p)

            for n, p in self.alpha_predictor.named_parameters():
                if 'bias' in n:
                    p.data.zero_()
                elif 'weight' in n:
                    th.nn.init.orthogonal_(p)

    def encode(self, im, label):
        bs, _, h, w = im.shape
        if self.conditional:
            label_onehot = _onehot(label)
            if not self.fc:
                label_onehot = label_onehot.view(bs, 10, 1, 1).repeat(1, 1, h, w)
                label_onehot = label_onehot.view(
                    bs, 10, 1, 1).repeat(1, 1, h, w)
                out = self.encoder(th.cat([im, label_onehot], 1))
            else:
                out = self.encoder(th.cat([im.view(bs, -1), label_onehot], 1))
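
`encode` relies on the `_onehot` helper, whose body lies outside this diff. A minimal sketch of such a helper for the 10 MNIST classes (an assumption about its behavior, not the file's actual implementation):

```python
import torch as th

def _onehot(label):
    # Hypothetical stand-in: (bs,) integer labels -> (bs, 10) one-hot floats on the same device.
    return th.nn.functional.one_hot(label.long(), num_classes=10).float()
```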

@@ -365,7 +263,9 @@ class VectorMNISTVAE(th.nn.Module):

    def _decode_features(self, z, label):
        if label is not None:
            assert self.conditional, "decoding with an input label requires a conditional AE"
            if not self.conditional:
                raise ValueError("decoding with an input label "
                                 "requires a conditional AE")
            label_onehot = _onehot(label)
            z = th.cat([z, label_onehot], 1)

@@ -378,7 +278,8 @@ class VectorMNISTVAE(th.nn.Module):
        feats = self._decode_features(z, label)

        if self.raster:
            out = self.raster_decoder(feats).view(bs, 1, self.imsize, self.imsize)
            out = self.raster_decoder(feats).view(
                bs, 1, self.imsize, self.imsize)
            return out, {}

        all_points = self.point_predictor(feats)
@@ -389,7 +290,10 @@ class VectorMNISTVAE(th.nn.Module):
        if False:
            all_widths = th.ones(bs, self.paths) * 0.5
        else:
            all_widths = self.width_predictor(feats) * 1.5 + .25
            all_widths = self.width_predictor(feats)
            min_width = self.stroke_width[0]
            max_width = self.stroke_width[1]
            all_widths = (max_width - min_width) * all_widths + min_width

        if False:
            all_alphas = th.ones(bs, self.paths)
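
The width change in this hunk is the core of the commit: the predictor now ends in a `Sigmoid`, so its output lies in (0, 1) and is rescaled into the configurable `stroke_width` range (default `(1.0, 3.0)`) instead of the old hard-coded `* 1.5 + .25`. A standalone sketch of the mapping:

```python
import torch as th

min_width, max_width = 1.0, 3.0           # default self.stroke_width
sigmoid_out = th.tensor([0.0, 0.5, 1.0])  # extremes of what the Sigmoid-terminated predictor can emit
widths = (max_width - min_width) * sigmoid_out + min_width
print(widths)                             # tensor([1., 2., 3.]): strokes always stay within [min, max]
```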
@@ -426,14 +330,16 @@ class VectorMNISTVAE(th.nn.Module):
                [shapes, shape_groups, (self.imsize, self.imsize)])

            # Rasterize
            out = render(self.imsize, self.imsize, shapes, shape_groups, samples=self.samples)
            out = render(self.imsize, self.imsize, shapes, shape_groups,
                         samples=self.samples)

            # Torch format, discard alpha, make gray
            out = out.permute(2, 0, 1).view(4, self.imsize, self.imsize)[:3].mean(0, keepdim=True)
            out = out.permute(2, 0, 1).view(
                4, self.imsize, self.imsize)[:3].mean(0, keepdim=True)

            outputs.append(out)

        output = th.stack(outputs).to(z.device)

        auxdata = {
            "points": all_points,
@@ -446,8 +352,6 @@ class VectorMNISTVAE(th.nn.Module):
        return output, auxdata

    def forward(self, im, label):
        bs = im.shape[0]

        if self.variational:
            mu, logvar = self.encode(im, label)
            z = self.reparameterize(mu, logvar)
@@ -457,9 +361,9 @@ class VectorMNISTVAE(th.nn.Module):
            logvar = None

        if self.conditional:
            output, aux = self.decode(z, label=label)
        else:
            output, aux = self.decode(z)

        aux["logvar"] = logvar
        aux["mu"] = mu
@@ -467,250 +371,6 @@ class VectorMNISTVAE(th.nn.Module):
        return output, aux
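
`forward` calls `self.reparameterize(mu, logvar)`, which is defined outside this hunk. A sketch of the standard reparameterization trick it presumably implements, sampling z = mu + sigma * eps so gradients flow through mu and logvar:

```python
import torch as th

def reparameterize(mu, logvar):
    # logvar parameterizes log(sigma^2); eps is unit Gaussian noise.
    std = th.exp(0.5 * logvar)
    eps = th.randn_like(std)
    return mu + eps * std
```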


class VectorMNISTGenerator(th.nn.Module):
    def __init__(self, imsize=28, paths=4, segments=5, samples=2, conditional=False,
                 zdim=20, fc=False):
        super(VectorMNISTGenerator, self).__init__()
        if imsize != 28:
            raise NotImplementedError()

        self.samples = samples
        self.imsize = imsize
        self.paths = paths
        self.segments = segments
        self.conditional = conditional
        self.zdim = zdim
        self.fc = fc

        ncond = 0
        if self.conditional:  # one hot encoded input for conditional model
            ncond = 10

        nc = 1024
        self.trunk = th.nn.Sequential(
            th.nn.Linear(zdim + ncond, nc),  # noise + one-hot
            th.nn.SELU(inplace=True),

            # th.nn.Linear(nc, nc),
            # th.nn.SELU(inplace=True),

            th.nn.Linear(nc, nc),
            th.nn.SELU(inplace=True),

            # th.nn.Linear(nc, nc),
            # th.nn.SELU(inplace=True),
        )

        # 4 points bezier so n_segments -> 3*n_segments + 1 points
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(nc, 2*self.paths*(self.segments*3+1)),
            # th.nn.Linear(nc, 2*self.paths*(self.segments*1+1)),
            th.nn.Tanh()  # bound spatial extent
        )
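
As the comment above notes, each path is a chain of cubic (4-point) Bézier segments: every segment contributes 3 new points (2 control points plus the next endpoint) on top of the shared start point, so the predictor emits 2 coordinates for each of the 3*segments + 1 points of each path. A quick check of the output size, mirroring the later `view(bs, self.paths, -1, 2)`:

```python
paths, segments = 4, 5                      # the class defaults above
points_per_path = 3 * segments + 1          # 16 points per path
out_features = 2 * paths * points_per_path  # 128 output features for Linear(nc, ...)
assert out_features == 2 * paths * (segments * 3 + 1)
# Reshaped later as (bs, paths, points_per_path, 2) before being scaled to pixel coordinates.
```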

        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(nc, self.paths),
            th.nn.Tanh()
        )

        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(nc, self.paths),
            th.nn.Tanh()
        )

        # self.postprocessor = th.nn.Sequential(
        #     th.nn.Conv2d(1, 32, 3, padding=1),
        #     th.nn.LeakyReLU(inplace=True),
        #     th.nn.Conv2d(32, 1, 1),
        # )
        self._reset_weights()

    def _reset_weights(self):
        for n, p in self.trunk.named_parameters():
            if 'bias' in n:
                p.data.zero_()
            elif 'weight' in n:
                th.nn.init.kaiming_normal_(p)
                p.data.mul_(0.7)
                # th.nn.init.kaiming_normal_(p, nonlinearity="leaky_relu")

        for n, p in self.point_predictor.named_parameters():
            # if 'bias' in n:
            #     p.data.zero_()
            if 'weight' in n:
                th.nn.init.orthogonal_(p)
                # th.nn.init.kaiming_normal_(p, nonlinearity="tanh")

        for n, p in self.width_predictor.named_parameters():
            if 'bias' in n:
                p.data.zero_()
            elif 'weight' in n:
                # th.nn.init.orthogonal_(p)
                th.nn.init.kaiming_normal_(p, nonlinearity="tanh")

        for n, p in self.alpha_predictor.named_parameters():
            if 'bias' in n:
                p.data.zero_()
            elif 'weight' in n:
                th.nn.init.kaiming_normal_(p, nonlinearity="tanh")
                # th.nn.init.orthogonal_(p)

    def sample_z(self, bs):
        return th.randn(bs, self.zdim)

    def gen_sample(self, z, label=None):
        bs = z.shape[0]
        if self.conditional:
            if label is None:
                raise ValueError("GAN is conditional, please provide a label")

            # One-hot encoding of the image label
            label_onehot = _onehot(label)

            # get some embedding
            in_ = th.cat([z, label_onehot.float()], 1)
        else:
            in_ = z

        feats = self.trunk(in_)

        all_points = self.point_predictor(feats)
        all_points = all_points.view(bs, self.paths, -1, 2)

        if False:
            all_alphas = th.ones(bs, self.paths)
        else:
            all_alphas = self.alpha_predictor(feats)

        # stroke size between 0.5 and 3.5 px
        if False:
            all_widths = th.ones(bs, self.paths) * 1
        else:
            all_widths = self.width_predictor(feats)
            all_widths = 1.5*all_widths + 0.5

        all_points = all_points*(self.imsize//2) + self.imsize//2

        # Process the batch sequentially
        outputs = []
        for k in range(bs):
            # Get point parameters from network
            shapes = []
            shape_groups = []
            for p in range(self.paths):
                points = all_points[k, p].contiguous().cpu()
                # num_ctrl_pts = th.zeros(self.segments, dtype=th.int32)+0
                num_ctrl_pts = th.zeros(self.segments, dtype=th.int32)+2
                width = all_widths[k, p].cpu()
                alpha = all_alphas[k, p].cpu()
                color = th.cat([th.ones(3), alpha.view(1,)])
                path = pydiffvg.Path(
                    num_control_points=num_ctrl_pts, points=points,
                    stroke_width=width, is_closed=False)
                shapes.append(path)
                path_group = pydiffvg.ShapeGroup(
                    shape_ids=th.tensor([len(shapes) - 1]),
                    fill_color=None,
                    stroke_color=color)
                shape_groups.append(path_group)

            # Rasterize
            out = render(self.imsize, self.imsize, shapes, shape_groups, samples=self.samples)

            # Torch format, discard alpha, make gray
            out = out.permute(2, 0, 1).view(4, self.imsize, self.imsize)[:3].mean(0, keepdim=True)

            outputs.append(out)

        output = th.stack(outputs).to(z.device)
        aux_data = {
            "points": all_points,
            "raw_vector": output,
        }

        # output = self.postprocessor(output)

        # map to [-1, 1]
        output = output*2.0 - 1.0

        return output, aux_data

    def forward(self, im, label):
        bs = label.shape[0]

        # sample a hidden vector (same dim as the raster version)
        z = self.sample_z(bs).to(im.device)
        if args.conditional:
            return self.gen_sample(z, label=label)
        else:
            return self.gen_sample(z)


class Discriminator(th.nn.Module):
    def __init__(self, conditional=False, fc=False):
        super(Discriminator, self).__init__()

        self.conditional = conditional

        ncond = 0
        if self.conditional:  # one hot encoded input for conditional model
            ncond = 10

        sn = th.nn.utils.spectral_norm
        # sn = lambda x: x

        self.fc = fc

        mult = 2
        if self.fc:
            self.net = th.nn.Sequential(
                Flatten(),
                th.nn.Linear(28*28 + ncond, mult*256),
                th.nn.LeakyReLU(0.2, inplace=True),

                # th.nn.Linear(mult*256, mult*256, 4),
                # th.nn.LeakyReLU(0.2, inplace=True),
                # th.nn.Dropout(0.5),

                th.nn.Linear(mult*256, mult*256, 4),
                th.nn.LeakyReLU(0.2, inplace=True),

                th.nn.Linear(mult*256*1*1, 1),
            )
        else:
            self.net = th.nn.Sequential(
                th.nn.Conv2d(1 + ncond, mult*64, 4, padding=0, stride=2),
                th.nn.LeakyReLU(0.2, inplace=True),
                # 16x16

                sn(th.nn.Conv2d(mult*64, mult*128, 4, padding=0, stride=2)),
                th.nn.LeakyReLU(0.2, inplace=True),
                # 8x8

                sn(th.nn.Conv2d(mult*128, mult*256, 4, padding=0, stride=2)),
                th.nn.LeakyReLU(0.2, inplace=True),
                # 4x4

                Flatten(),

                th.nn.Linear(mult*256*1*1, 1),
            )

        self._reset_weights()

    def _reset_weights(self):
        for n, p in self.net.named_parameters():
            if 'bias' in n:
                p.data.zero_()
            if 'weight' in n:
                th.nn.init.kaiming_normal_(p, nonlinearity="leaky_relu")

    def forward(self, x):
        out = self.net(x)
        return out
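
The convolutional branch wraps its inner convolutions in `th.nn.utils.spectral_norm`, which reparameterizes the weight so that its largest singular value is normalized to 1 on every forward pass, a common way to stabilize GAN discriminators. A minimal standalone illustration:

```python
import torch as th

conv = th.nn.utils.spectral_norm(th.nn.Conv2d(64, 128, 4, padding=0, stride=2))
x = th.randn(1, 64, 16, 16)
y = conv(x)  # the weight is divided by its spectral norm before the convolution is applied
```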


class Dataset(th.utils.data.Dataset):
    def __init__(self, data_dir, imsize):
        super(Dataset, self).__init__()
@@ -783,8 +443,8 @@ def train(args):
    extras, meta = checkpointer.load_latest()

    if meta is not None and meta != model_params:
        LOG.info("Checkpoint's metaparams differ from CLI, aborting: %s and %s",
                 meta, model_params)
        LOG.info(f"Checkpoint's metaparams differ from CLI, "
                 f"aborting: {meta} and {model_params}")

    # Hook interface
    if args.generator in ["vae", "ae"]:
@@ -794,17 +454,18 @@ def train(args):
    else:
        LOG.info("Using an AE")
    interface = VAEInterface(model, lr=args.lr, cuda=args.cuda,
                             variational=variational, w_kld=args.kld_weight)
                             variational=variational,
                             w_kld=args.kld_weight)

    trainer = ttools.Trainer(interface)

    # Add callbacks
    keys = ["loss_g", "loss_d"]
    keys = []
    if args.generator == "vae":
        keys = ["kld", "data_loss", "loss"]
        keys = ["kld", "data_loss", "loss", "logvar"]
    elif args.generator == "ae":
        keys = ["data_loss", "loss"]
    port = 8097
    port = 8080
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=keys, val_keys=keys))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
@@ -979,8 +640,10 @@ if __name__ == "__main__":
    parser.add_argument("--cpu", dest="cuda", action="store_false",
                        default=th.cuda.is_available(),
                        help="if true, use CPU instead of GPU.")
    parser.add_argument("--conditional", action="store_true", default=False)
    parser.add_argument("--fc", action="store_true", default=False)
    parser.add_argument("--no-conditional", dest="conditional",
                        action="store_false", default=True)
    parser.add_argument("--no-fc", dest="fc", action="store_false",
                        default=True)
    parser.add_argument("--data_dir", default="mnist",
                        help="path to download and store the data.")

@@ -992,19 +655,19 @@ if __name__ == "__main__":
                              "autoencoder")
    parser_train.add_argument("--freq", type=int, default=100,
                              help="number of steps between visualizations")
    parser_train.add_argument("--lr", type=float, default=1e-4,
    parser_train.add_argument("--lr", type=float, default=5e-5,
                              help="learning rate")
    parser_train.add_argument("--kld_weight", type=float, default=1.0,
                              help="scalar weight for the KL divergence term.")
    parser_train.add_argument("--bs", type=int, default=8, help="batch size")
    parser_train.add_argument("--num_epochs", type=int,
    parser_train.add_argument("--num_epochs", default=50, type=int,
                              help="max number of epochs")
    # Vector configs
    parser_train.add_argument("--paths", type=int, default=1,
                              help="number of unique vector paths to generate.")
                              help="number of vector paths to generate.")
    parser_train.add_argument("--segments", type=int, default=3,
                              help="number of segments per vector path")
    parser_train.add_argument("--samples", type=int, default=2,
    parser_train.add_argument("--samples", type=int, default=4,
                              help="number of samples in the MC rasterizer")
    parser_train.add_argument("--zdim", type=int, default=20,
                              help="dimension of the latent space")