initial commit
This commit is contained in:
484
apps/generative_models/models.py
Normal file
484
apps/generative_models/models.py
Normal file
@@ -0,0 +1,484 @@
|
||||
"""Collection of generative models."""
|
||||
|
||||
import torch as th
|
||||
import ttools
|
||||
|
||||
import rendering
|
||||
import modules
|
||||
|
||||
LOG = ttools.get_logger(__name__)
|
||||
|
||||
|
||||
class BaseModel(th.nn.Module):
    """Common base class: provides latent-code sampling for all generators."""

    def sample_z(self, bs, device="cpu"):
        """Draw `bs` standard-normal latent vectors of size `self.zdim`."""
        noise = th.randn(bs, self.zdim)
        return noise.to(device)
|
||||
|
||||
|
||||
class BaseVectorModel(BaseModel):
    """Base for generators that produce both a raster and a vector scene.

    Subclasses implement `_forward(z)` returning a (raster, scenes) pair;
    this class exposes each half of that pair separately.
    """

    def _forward(self, x):
        """Subclass hook: return (raster, scenes) for latent codes `x`."""
        raise NotImplementedError()

    def forward(self, z):
        """Return only the rasterized output for latent codes `z`."""
        raster, _ = self._forward(z)
        return raster

    def get_vector(self, z):
        """Return only the vector scene description for latent codes `z`."""
        _, scenes = self._forward(z)
        return scenes
|
||||
|
||||
|
||||
class BezierVectorGenerator(BaseVectorModel):
    """Generates images by rendering a fixed set of cubic Bezier strokes.

    An MLP trunk maps the latent code to per-stroke control points, widths,
    opacities and (optionally) RGB colors, which are rasterized with
    `rendering.bezier_render`.
    """

    NUM_SEGMENTS = 2

    def __init__(self, num_strokes=4,
                 zdim=128, width=32, imsize=32,
                 color_output=False,
                 stroke_width=None):
        """Build the stroke-parameter predictors.

        Args:
            num_strokes (int): number of Bezier strokes per image.
            zdim (int): dimension of the latent code.
            width (int): base width of the MLP trunk.
            imsize (int): output canvas size (square).
            color_output (bool): if True, also predict per-stroke RGB colors.
            stroke_width (tuple or None): (min, max) stroke width range;
                defaults to (0.5, 3.0) with a warning.
        """
        super(BezierVectorGenerator, self).__init__()

        if stroke_width is None:
            self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width

        self.imsize = imsize
        self.num_strokes = num_strokes
        self.zdim = zdim

        self.trunk = th.nn.Sequential(
            th.nn.Linear(zdim, width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(width, 2*width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(2*width, 4*width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(4*width, 8*width),
            th.nn.SELU(inplace=True),
        )

        # Cubic Bezier chains: n_segments -> 3*n_segments + 1 control points.
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(8*width,
                         2*self.num_strokes*(
                             BezierVectorGenerator.NUM_SEGMENTS*3 + 1)),
            th.nn.Tanh()  # bound spatial extent
        )

        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, self.num_strokes),
            th.nn.Sigmoid()
        )

        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, self.num_strokes),
            th.nn.Sigmoid()
        )

        self.color_predictor = None
        if color_output:
            self.color_predictor = th.nn.Sequential(
                th.nn.Linear(8*width, 3*self.num_strokes),
                th.nn.Sigmoid()
            )

    def _forward(self, z):
        """Map latent codes to a (raster, scenes) pair.

        Args:
            z (th.Tensor): latent codes, shape (bs, zdim).

        Returns:
            tuple: raster tensor mapped to [-1, 1], and the vector scene
            data produced by the renderer.
        """
        bs = z.shape[0]

        feats = self.trunk(z)
        all_points = self.point_predictor(feats)
        all_alphas = self.alpha_predictor(feats)

        if self.color_predictor is not None:
            all_colors = self.color_predictor(feats)
            all_colors = all_colors.view(bs, self.num_strokes, 3)
        else:
            all_colors = None

        # Rescale sigmoid output to the configured stroke-width range.
        all_widths = self.width_predictor(feats)
        min_width, max_width = self.stroke_width
        all_widths = (max_width - min_width) * all_widths + min_width

        all_points = all_points.view(
            bs, self.num_strokes, BezierVectorGenerator.NUM_SEGMENTS*3 + 1, 2)

        output, scenes = rendering.bezier_render(
            all_points, all_widths, all_alphas,
            colors=all_colors,
            canvas_size=self.imsize)

        # Renderer output is in [0, 1]; map to [-1, 1].
        output = output*2.0 - 1.0

        return output, scenes
|
||||
|
||||
|
||||
class VectorGenerator(BaseVectorModel):
    """Generates images by rendering a fixed set of straight line strokes.

    An MLP trunk maps the latent code to per-stroke endpoints, widths,
    opacities and (optionally) RGB colors, which are rasterized with
    `rendering.line_render`.
    """

    def __init__(self, num_strokes=4,
                 zdim=128, width=32, imsize=32,
                 color_output=False,
                 stroke_width=None):
        """Build the stroke-parameter predictors.

        Args:
            num_strokes (int): number of line strokes per image.
            zdim (int): dimension of the latent code.
            width (int): base width of the MLP trunk.
            imsize (int): output canvas size (square).
            color_output (bool): if True, also predict per-stroke RGB colors.
            stroke_width (tuple or None): (min, max) stroke width range;
                defaults to (0.5, 3.0) with a warning.
        """
        super(VectorGenerator, self).__init__()

        if stroke_width is None:
            self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width

        self.imsize = imsize
        self.num_strokes = num_strokes
        self.zdim = zdim

        self.trunk = th.nn.Sequential(
            th.nn.Linear(zdim, width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(width, 2*width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(2*width, 4*width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(4*width, 8*width),
            th.nn.SELU(inplace=True),
        )

        # Straight lines: 2 endpoints, (x, y) each, per stroke.
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, 2*(self.num_strokes*2)),
            th.nn.Tanh()  # bound spatial extent
        )

        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, self.num_strokes),
            th.nn.Sigmoid()
        )

        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, self.num_strokes),
            th.nn.Sigmoid()
        )

        self.color_predictor = None
        if color_output:
            self.color_predictor = th.nn.Sequential(
                th.nn.Linear(8*width, 3*self.num_strokes),
                th.nn.Sigmoid()
            )

    def _forward(self, z):
        """Map latent codes to a (raster, scenes) pair.

        Args:
            z (th.Tensor): latent codes, shape (bs, zdim).

        Returns:
            tuple: raster tensor mapped to [-1, 1], and the vector scene
            data produced by the renderer.
        """
        bs = z.shape[0]

        feats = self.trunk(z)
        all_points = self.point_predictor(feats)
        all_alphas = self.alpha_predictor(feats)

        if self.color_predictor is not None:
            all_colors = self.color_predictor(feats)
            all_colors = all_colors.view(bs, self.num_strokes, 3)
        else:
            all_colors = None

        # Rescale sigmoid output to the configured stroke-width range.
        all_widths = self.width_predictor(feats)
        min_width, max_width = self.stroke_width
        all_widths = (max_width - min_width) * all_widths + min_width

        all_points = all_points.view(bs, self.num_strokes, 2, 2)

        output, scenes = rendering.line_render(
            all_points, all_widths, all_alphas,
            colors=all_colors,
            canvas_size=self.imsize)

        # Renderer output is in [0, 1]; map to [-1, 1].
        output = output*2.0 - 1.0

        return output, scenes
|
||||
|
||||
|
||||
class RNNVectorGenerator(BaseVectorModel):
    """Recurrent generator: an LSTM emits one line stroke per time step.

    The latent code initializes the LSTM state and is also fed as input at
    every step; each step predicts a 2-point line segment, a width and an
    opacity, rendered with `rendering.line_render`.

    NOTE(review): `color_output` is accepted for signature compatibility
    with the other generators but is ignored — no color predictor is built.
    """

    def __init__(self, num_strokes=64,
                 zdim=128, width=32, imsize=32,
                 hidden_size=512, dropout=0.9,
                 color_output=False,
                 num_layers=3, stroke_width=None):
        """Build the LSTM and per-step stroke predictors.

        Args:
            num_strokes (int): number of strokes (LSTM steps) per image.
            zdim (int): dimension of the latent code.
            width (int): unused here; kept for interface parity.
            imsize (int): output canvas size (square).
            hidden_size (int): LSTM hidden state size.
            dropout (float): LSTM inter-layer dropout probability.
            color_output (bool): ignored (see class docstring).
            num_layers (int): number of stacked LSTM layers.
            stroke_width (tuple or None): (min, max) stroke width range;
                defaults to (0.5, 3.0) with a warning.
        """
        super(RNNVectorGenerator, self).__init__()

        if stroke_width is None:
            self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width

        self.num_layers = num_layers
        self.imsize = imsize
        self.num_strokes = num_strokes
        self.hidden_size = hidden_size
        self.zdim = zdim

        # Predicts both the hidden and cell states for all LSTM layers.
        self.hidden_cell_predictor = th.nn.Linear(
            zdim, 2*hidden_size*num_layers)

        self.lstm = th.nn.LSTM(
            zdim, hidden_size,
            num_layers=self.num_layers, dropout=dropout,
            batch_first=True)

        # Straight line per step: 2 endpoints, (x, y) each.
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 2*2),
            th.nn.Tanh()  # bound spatial extent
        )

        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 1),
            th.nn.Sigmoid()
        )

        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 1),
            th.nn.Sigmoid()
        )

    def _init_hidden_and_cell(self, z):
        # One linear layer predicts both states for every layer; reshape to
        # the (num_layers, bs, hidden_size) layout th.nn.LSTM expects.
        pred = self.hidden_cell_predictor(th.tanh(z))
        split = self.hidden_size*self.num_layers
        hidden = pred[:, :split]
        hidden = hidden.view(-1, self.num_layers, self.hidden_size)
        hidden = hidden.permute(1, 0, 2).contiguous()
        cell = pred[:, split:]
        cell = cell.view(-1, self.num_layers, self.hidden_size)
        cell = cell.permute(1, 0, 2).contiguous()
        return (hidden, cell)

    def _forward(self, z, hidden_and_cell=None):
        """Map latent codes to a (raster, scenes) pair.

        Args:
            z (th.Tensor): latent codes, shape (bs, zdim).
            hidden_and_cell: optional initial LSTM state; derived from z
                when None.
        """
        steps = self.num_strokes
        bs = z.shape[0]

        # z is passed as LSTM input at each step, so duplicate it in time.
        expanded_z = z.unsqueeze(1).repeat(1, steps, 1)

        if hidden_and_cell is None:
            # Initialize the recurrent state from the latent vector.
            hidden_and_cell = self._init_hidden_and_cell(z)

        feats, _ = self.lstm(expanded_z, hidden_and_cell)
        feats = feats.reshape(bs*steps, self.hidden_size)

        all_points = self.point_predictor(feats).view(bs, steps, 2, 2)
        all_alphas = self.alpha_predictor(feats).view(bs, steps)
        all_widths = self.width_predictor(feats).view(bs, steps)

        # Rescale sigmoid output to the configured stroke-width range.
        min_width, max_width = self.stroke_width
        all_widths = (max_width - min_width) * all_widths + min_width

        output, scenes = rendering.line_render(
            all_points, all_widths, all_alphas,
            canvas_size=self.imsize)

        # Renderer output is in [0, 1]; map to [-1, 1].
        output = output*2.0 - 1.0

        return output, scenes
|
||||
|
||||
|
||||
class ChainRNNVectorGenerator(BaseVectorModel):
    """Recurrent generator whose strokes form a single long chain.

    Each LSTM step predicts only the *endpoint* of a line stroke; the
    stroke's start point is the previous stroke's endpoint (the first
    stroke starts at the canvas center), so all strokes connect into one
    continuous polyline.

    NOTE(review): `color_output` is accepted for signature compatibility
    with the other generators but is ignored — no color predictor is built.
    """

    def __init__(self, num_strokes=64,
                 zdim=128, width=32, imsize=32,
                 hidden_size=512, dropout=0.9,
                 color_output=False,
                 num_layers=3, stroke_width=None):
        """Build the LSTM and per-step stroke predictors.

        Args:
            num_strokes (int): number of chained strokes (LSTM steps).
            zdim (int): dimension of the latent code.
            width (int): unused here; kept for interface parity.
            imsize (int): output canvas size (square).
            hidden_size (int): LSTM hidden state size.
            dropout (float): LSTM inter-layer dropout probability.
            color_output (bool): ignored (see class docstring).
            num_layers (int): number of stacked LSTM layers.
            stroke_width (tuple or None): (min, max) stroke width range;
                defaults to (0.5, 3.0) with a warning.
        """
        super(ChainRNNVectorGenerator, self).__init__()

        if stroke_width is None:
            self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width

        self.num_layers = num_layers
        self.imsize = imsize
        self.num_strokes = num_strokes
        self.hidden_size = hidden_size
        self.zdim = zdim

        # Predicts both the hidden and cell states for all LSTM layers.
        self.hidden_cell_predictor = th.nn.Linear(
            zdim, 2*hidden_size*num_layers)

        self.lstm = th.nn.LSTM(
            zdim, hidden_size,
            num_layers=self.num_layers, dropout=dropout,
            batch_first=True)

        # One endpoint, (x, y), per step; start points come from the chain.
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 2),
            th.nn.Tanh()  # bound spatial extent
        )

        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 1),
            th.nn.Sigmoid()
        )

        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 1),
            th.nn.Sigmoid()
        )

    def _init_hidden_and_cell(self, z):
        # One linear layer predicts both states for every layer; reshape to
        # the (num_layers, bs, hidden_size) layout th.nn.LSTM expects.
        pred = self.hidden_cell_predictor(th.tanh(z))
        split = self.hidden_size*self.num_layers
        hidden = pred[:, :split]
        hidden = hidden.view(-1, self.num_layers, self.hidden_size)
        hidden = hidden.permute(1, 0, 2).contiguous()
        cell = pred[:, split:]
        cell = cell.view(-1, self.num_layers, self.hidden_size)
        cell = cell.permute(1, 0, 2).contiguous()
        return (hidden, cell)

    def _forward(self, z, hidden_and_cell=None):
        """Map latent codes to a (raster, scenes) pair.

        Args:
            z (th.Tensor): latent codes, shape (bs, zdim).
            hidden_and_cell: optional initial LSTM state; derived from z
                when None.
        """
        steps = self.num_strokes
        bs = z.shape[0]

        # z is passed as LSTM input at each step, so duplicate it in time.
        expanded_z = z.unsqueeze(1).repeat(1, steps, 1)

        if hidden_and_cell is None:
            # Initialize the recurrent state from the latent vector.
            hidden_and_cell = self._init_hidden_and_cell(z)

        feats, _ = self.lstm(expanded_z, hidden_and_cell)
        feats = feats.reshape(bs*steps, self.hidden_size)

        # Construct the chain: stroke i runs from start_points[i] to
        # end_points[i].
        end_points = self.point_predictor(feats).view(bs, steps, 1, 2)
        start_points = th.cat([
            # The first stroke starts at the canvas center.
            th.zeros(bs, 1, 1, 2, device=feats.device),
            # BUGFIX: each subsequent stroke starts where the *previous*
            # one ended (was `end_points[:, 1:]`, which paired stroke i
            # with stroke i+1's endpoint and broke the chain).
            end_points[:, :-1, :, :]], 1)
        all_points = th.cat([start_points, end_points], 2)

        all_alphas = self.alpha_predictor(feats).view(bs, steps)
        all_widths = self.width_predictor(feats).view(bs, steps)

        # Rescale sigmoid output to the configured stroke-width range.
        min_width, max_width = self.stroke_width
        all_widths = (max_width - min_width) * all_widths + min_width

        output, scenes = rendering.line_render(
            all_points, all_widths, all_alphas,
            canvas_size=self.imsize)

        # Renderer output is in [0, 1]; map to [-1, 1].
        output = output*2.0 - 1.0

        return output, scenes
|
||||
|
||||
|
||||
class Generator(BaseModel):
    """Convolutional raster generator: decodes a latent code into a 32x32
    image in [-1, 1].

    The latent vector is reshaped to a (zdim // 4, 2, 2) feature map and
    upsampled through four transpose-conv stages. The `stroke_width` and
    `num_strokes` arguments exist only for interface parity with the
    vector generators and are unused.
    """

    def __init__(self, width=64, imsize=32, zdim=128,
                 stroke_width=None,
                 color_output=False,
                 num_strokes=4):
        super(Generator, self).__init__()
        # The architecture is hard-coded for a 2x2 -> 32x32 decode path.
        assert imsize == 32

        self.imsize = imsize
        self.zdim = zdim

        in_chans = self.zdim // (2*2)
        out_chans = 3 if color_output else 1

        def upsample_stage(cin, cout):
            # Doubles spatial size: transpose conv, then conv + BN refine.
            return [
                th.nn.ConvTranspose2d(cin, cout, 4, padding=1, stride=2),
                th.nn.LeakyReLU(0.2, inplace=True),
                th.nn.Conv2d(cout, cout, 3, padding=1),
                th.nn.BatchNorm2d(cout),
                th.nn.LeakyReLU(0.2, inplace=True),
            ]

        layers = []
        layers += upsample_stage(in_chans, width*8)   # 2x2 -> 4x4
        layers += upsample_stage(width*8, width*4)    # 4x4 -> 8x8
        layers += upsample_stage(width*4, width*2)    # 8x8 -> 16x16
        layers += upsample_stage(width*2, width)      # 16x16 -> 32x32

        # Refinement head at full resolution, then project to out_chans.
        layers += [
            th.nn.Conv2d(width, width, 3, padding=1),
            th.nn.BatchNorm2d(width),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(width, width, 3, padding=1),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(width, out_chans, 1),
            th.nn.Tanh(),  # output in [-1, 1]
        ]

        self.net = th.nn.Sequential(*layers)

    def forward(self, z):
        """Decode latent codes `z` of shape (bs, zdim) into rasters."""
        bs = z.shape[0]
        in_chans = self.zdim // (2*2)
        return self.net(z.view(bs, in_chans, 2, 2))
|
||||
|
||||
|
||||
class Discriminator(th.nn.Module):
    """Convolutional discriminator producing one real/fake logit per image.

    Expects 32x32 inputs; downsamples to 2x2 with stride-2 convolutions,
    applying spectral normalization to all but the first two layers.
    NOTE(review): the `conditional` flag is stored but not used here.
    """

    def __init__(self, conditional=False, width=64, color_output=False):
        super(Discriminator, self).__init__()

        self.conditional = conditional

        sn = th.nn.utils.spectral_norm
        in_chans = 3 if color_output else 1

        layers = [
            # 32x32
            th.nn.Conv2d(in_chans, width, 3, padding=1),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(width, 2*width, 4, padding=1, stride=2),
            th.nn.LeakyReLU(0.2, inplace=True),

            # 16x16
            sn(th.nn.Conv2d(2*width, 2*width, 3, padding=1)),
            th.nn.LeakyReLU(0.2, inplace=True),
            sn(th.nn.Conv2d(2*width, 4*width, 4, padding=1, stride=2)),
            th.nn.LeakyReLU(0.2, inplace=True),

            # 8x8
            sn(th.nn.Conv2d(4*width, 4*width, 3, padding=1)),
            th.nn.LeakyReLU(0.2, inplace=True),
            sn(th.nn.Conv2d(4*width, width*4, 4, padding=1, stride=2)),
            th.nn.LeakyReLU(0.2, inplace=True),

            # 4x4
            sn(th.nn.Conv2d(4*width, 4*width, 3, padding=1)),
            th.nn.LeakyReLU(0.2, inplace=True),
            sn(th.nn.Conv2d(4*width, width*4, 4, padding=1, stride=2)),
            th.nn.LeakyReLU(0.2, inplace=True),

            # 2x2: flatten and score.
            modules.Flatten(),
            th.nn.Linear(width*4*2*2, 1),
        ]
        self.net = th.nn.Sequential(*layers)

    def forward(self, x):
        """Return a (bs, 1) real/fake score for input rasters `x`."""
        return self.net(x)
|
Reference in New Issue
Block a user