"""Collection of generative models."""
import torch as th

import ttools

import rendering
import modules


LOG = ttools.get_logger(__name__)


class BaseModel(th.nn.Module):
    def sample_z(self, bs, device="cpu"):
        return th.randn(bs, self.zdim).to(device)


class BaseVectorModel(BaseModel):
    def get_vector(self, z):
        _, scenes = self._forward(z)
        return scenes

    def _forward(self, x):
        raise NotImplementedError()

    def forward(self, z):
        # Only return the raster
        return self._forward(z)[0]
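# Usage sketch (assumes the sibling `rendering` module imported above is
# functional); every BaseVectorModel subclass follows the same contract,
# `_forward(z) -> (raster, scenes)`:
#
#   gen = BezierVectorGenerator(num_strokes=4, zdim=128, imsize=32)
#   z = gen.sample_z(16)        # (16, 128) standard-normal latent codes
#   raster = gen(z)             # rasterized strokes, mapped to [-1, 1]
#   scenes = gen.get_vector(z)  # vector scenes produced by the renderer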
class BezierVectorGenerator(BaseVectorModel):
    NUM_SEGMENTS = 2

    def __init__(self, num_strokes=4,
                 zdim=128, width=32, imsize=32,
                 color_output=False,
                 stroke_width=None):
        super(BezierVectorGenerator, self).__init__()

        if stroke_width is None:
            self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width

        self.imsize = imsize
        self.num_strokes = num_strokes
        self.zdim = zdim

        self.trunk = th.nn.Sequential(
            th.nn.Linear(zdim, width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(width, 2*width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(2*width, 4*width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(4*width, 8*width),
            th.nn.SELU(inplace=True),
        )

        # Cubic Bezier segments have 4 control points each; chaining
        # n_segments of them with shared endpoints needs 3*n_segments + 1
        # points per stroke.
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(8*width,
                         2*self.num_strokes*(
                             BezierVectorGenerator.NUM_SEGMENTS*3 + 1)),
            th.nn.Tanh()  # bound spatial extent
        )
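        # Worked example with the defaults (num_strokes=4, NUM_SEGMENTS=2):
        # 3*2 + 1 = 7 control points per stroke, each an (x, y) pair, so the
        # predictor outputs 2*4*7 = 56 values per latent code.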
        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, self.num_strokes),
            th.nn.Sigmoid()
        )

        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, self.num_strokes),
            th.nn.Sigmoid()
        )

        self.color_predictor = None
        if color_output:
            self.color_predictor = th.nn.Sequential(
                th.nn.Linear(8*width, 3*self.num_strokes),
                th.nn.Sigmoid()
            )

    def _forward(self, z):
        bs = z.shape[0]

        feats = self.trunk(z)
        all_points = self.point_predictor(feats)
        all_alphas = self.alpha_predictor(feats)

        if self.color_predictor:
            all_colors = self.color_predictor(feats)
            all_colors = all_colors.view(bs, self.num_strokes, 3)
        else:
            all_colors = None

        all_widths = self.width_predictor(feats)
        min_width = self.stroke_width[0]
        max_width = self.stroke_width[1]
        all_widths = (max_width - min_width) * all_widths + min_width
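        # e.g. with the default (0.5, 3.0) range, a sigmoid output of 0.2
        # maps to 0.5 + 0.2*2.5 = 1.0 units of stroke width.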
        all_points = all_points.view(
            bs, self.num_strokes, BezierVectorGenerator.NUM_SEGMENTS*3+1, 2)

        output, scenes = rendering.bezier_render(all_points, all_widths,
                                                 all_alphas,
                                                 colors=all_colors,
                                                 canvas_size=self.imsize)

        # map to [-1, 1]
        output = output*2.0 - 1.0

        return output, scenes
class VectorGenerator(BaseVectorModel):
    def __init__(self, num_strokes=4,
                 zdim=128, width=32, imsize=32,
                 color_output=False,
                 stroke_width=None):
        super(VectorGenerator, self).__init__()

        if stroke_width is None:
            self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width

        self.imsize = imsize
        self.num_strokes = num_strokes
        self.zdim = zdim

        self.trunk = th.nn.Sequential(
            th.nn.Linear(zdim, width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(width, 2*width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(2*width, 4*width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(4*width, 8*width),
            th.nn.SELU(inplace=True),
        )

        # Straight-line strokes: each stroke has 2 endpoints, (x, y) each.
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, 2*(self.num_strokes*2)),
            th.nn.Tanh()  # bound spatial extent
        )
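        # With the default num_strokes=4 this is 2*(4*2) = 16 values per
        # latent code, reshaped to (bs, 4, 2, 2) in _forward below.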
        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, self.num_strokes),
            th.nn.Sigmoid()
        )

        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, self.num_strokes),
            th.nn.Sigmoid()
        )

        self.color_predictor = None
        if color_output:
            self.color_predictor = th.nn.Sequential(
                th.nn.Linear(8*width, 3*self.num_strokes),
                th.nn.Sigmoid()
            )

    def _forward(self, z):
        bs = z.shape[0]

        feats = self.trunk(z)

        all_points = self.point_predictor(feats)

        all_alphas = self.alpha_predictor(feats)

        if self.color_predictor:
            all_colors = self.color_predictor(feats)
            all_colors = all_colors.view(bs, self.num_strokes, 3)
        else:
            all_colors = None

        all_widths = self.width_predictor(feats)
        min_width = self.stroke_width[0]
        max_width = self.stroke_width[1]
        all_widths = (max_width - min_width) * all_widths + min_width

        all_points = all_points.view(bs, self.num_strokes, 2, 2)
        output, scenes = rendering.line_render(all_points, all_widths,
                                               all_alphas,
                                               colors=all_colors,
                                               canvas_size=self.imsize)

        # map to [-1, 1]
        output = output*2.0 - 1.0

        return output, scenes
class RNNVectorGenerator(BaseVectorModel):
    def __init__(self, num_strokes=64,
                 zdim=128, width=32, imsize=32,
                 hidden_size=512, dropout=0.9,
                 color_output=False,
                 num_layers=3, stroke_width=None):
        super(RNNVectorGenerator, self).__init__()

        if stroke_width is None:
            self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width

        self.num_layers = num_layers
        self.imsize = imsize
        self.num_strokes = num_strokes
        self.hidden_size = hidden_size
        self.zdim = zdim

        self.hidden_cell_predictor = th.nn.Linear(
            zdim, 2*hidden_size*num_layers)

        self.lstm = th.nn.LSTM(
            zdim, hidden_size,
            num_layers=self.num_layers, dropout=dropout,
            batch_first=True)

        # Each step predicts one straight segment: 2 points, (x, y) each.
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 2*2),  # 2 points, (x,y)
            th.nn.Tanh()  # bound spatial extent
        )

        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 1),
            th.nn.Sigmoid()
        )

        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 1),
            th.nn.Sigmoid()
        )

    def _forward(self, z, hidden_and_cell=None):
        steps = self.num_strokes

        # z is passed at each step, duplicate it
        bs = z.shape[0]
        expanded_z = z.unsqueeze(1).repeat(1, steps, 1)

        # First step in the RNN
        if hidden_and_cell is None:
            # Initialize from latent vector
            hidden_and_cell = self.hidden_cell_predictor(th.tanh(z))
            hidden = hidden_and_cell[:, :self.hidden_size*self.num_layers]
            hidden = hidden.view(-1, self.num_layers, self.hidden_size)
            hidden = hidden.permute(1, 0, 2).contiguous()
            cell = hidden_and_cell[:, self.hidden_size*self.num_layers:]
            cell = cell.view(-1, self.num_layers, self.hidden_size)
            cell = cell.permute(1, 0, 2).contiguous()
            hidden_and_cell = (hidden, cell)
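            # Shape check with the defaults (num_layers=3, hidden_size=512):
            # the predictor emits 2*3*512 = 3072 features, split into h and c
            # halves and reshaped to (num_layers, bs, hidden_size), the layout
            # the LSTM expects for its initial state.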
        feats, hidden_and_cell = self.lstm(expanded_z, hidden_and_cell)
        hidden, cell = hidden_and_cell

        feats = feats.reshape(bs*steps, self.hidden_size)

        all_points = self.point_predictor(feats).view(bs, steps, 2, 2)
        all_alphas = self.alpha_predictor(feats).view(bs, steps)
        all_widths = self.width_predictor(feats).view(bs, steps)

        min_width = self.stroke_width[0]
        max_width = self.stroke_width[1]
        all_widths = (max_width - min_width) * all_widths + min_width

        output, scenes = rendering.line_render(all_points, all_widths,
                                               all_alphas,
                                               canvas_size=self.imsize)

        # map to [-1, 1]
        output = output*2.0 - 1.0

        return output, scenes
class ChainRNNVectorGenerator(BaseVectorModel):
    """Strokes form a single long chain."""

    def __init__(self, num_strokes=64,
                 zdim=128, width=32, imsize=32,
                 hidden_size=512, dropout=0.9,
                 color_output=False,
                 num_layers=3, stroke_width=None):
        super(ChainRNNVectorGenerator, self).__init__()

        if stroke_width is None:
            self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width

        self.num_layers = num_layers
        self.imsize = imsize
        self.num_strokes = num_strokes
        self.hidden_size = hidden_size
        self.zdim = zdim

        self.hidden_cell_predictor = th.nn.Linear(
            zdim, 2*hidden_size*num_layers)

        self.lstm = th.nn.LSTM(
            zdim, hidden_size,
            num_layers=self.num_layers, dropout=dropout,
            batch_first=True)

        # Chained straight segments: each step predicts only the next
        # endpoint; the segment starts where the previous one ended.
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 2),  # 1 point, (x,y)
            th.nn.Tanh()  # bound spatial extent
        )

        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 1),
            th.nn.Sigmoid()
        )

        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 1),
            th.nn.Sigmoid()
        )
    def _forward(self, z, hidden_and_cell=None):
        steps = self.num_strokes

        # z is passed at each step, duplicate it
        bs = z.shape[0]
        expanded_z = z.unsqueeze(1).repeat(1, steps, 1)

        # First step in the RNN
        if hidden_and_cell is None:
            # Initialize from latent vector
            hidden_and_cell = self.hidden_cell_predictor(th.tanh(z))
            hidden = hidden_and_cell[:, :self.hidden_size*self.num_layers]
            hidden = hidden.view(-1, self.num_layers, self.hidden_size)
            hidden = hidden.permute(1, 0, 2).contiguous()
            cell = hidden_and_cell[:, self.hidden_size*self.num_layers:]
            cell = cell.view(-1, self.num_layers, self.hidden_size)
            cell = cell.permute(1, 0, 2).contiguous()
            hidden_and_cell = (hidden, cell)

        feats, hidden_and_cell = self.lstm(expanded_z, hidden_and_cell)
        hidden, cell = hidden_and_cell

        feats = feats.reshape(bs*steps, self.hidden_size)

        # Construct the chain: stroke i starts where stroke i-1 ended.
        end_points = self.point_predictor(feats).view(bs, steps, 1, 2)
        start_points = th.cat([
            # first point is canvas center
            th.zeros(bs, 1, 1, 2, device=feats.device),
            end_points[:, :-1, :, :]], 1)
        all_points = th.cat([start_points, end_points], 2)
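        # Layout example: stroke 0 runs center -> p0, stroke 1 runs p0 -> p1,
        # and so on, so all_points has shape (bs, steps, 2, 2) with matching
        # endpoints between consecutive strokes.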
        all_alphas = self.alpha_predictor(feats).view(bs, steps)
        all_widths = self.width_predictor(feats).view(bs, steps)

        min_width = self.stroke_width[0]
        max_width = self.stroke_width[1]
        all_widths = (max_width - min_width) * all_widths + min_width

        output, scenes = rendering.line_render(all_points, all_widths,
                                               all_alphas,
                                               canvas_size=self.imsize)

        # map to [-1, 1]
        output = output*2.0 - 1.0

        return output, scenes
class Generator(BaseModel):
    def __init__(self, width=64, imsize=32, zdim=128,
                 stroke_width=None,
                 color_output=False,
                 num_strokes=4):
        super(Generator, self).__init__()
        assert imsize == 32

        self.imsize = imsize
        self.zdim = zdim

        num_in_chans = self.zdim // (2*2)
        num_out_chans = 3 if color_output else 1
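        # With the default zdim=128 the latent is reshaped to (bs, 32, 2, 2)
        # and upsampled 2x2 -> 4x4 -> 8x8 -> 16x16 -> 32x32 by the stack below.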
        self.net = th.nn.Sequential(
            th.nn.ConvTranspose2d(num_in_chans, width*8, 4, padding=1,
                                  stride=2),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(width*8, width*8, 3, padding=1),
            th.nn.BatchNorm2d(width*8),
            th.nn.LeakyReLU(0.2, inplace=True),
            # 4x4

            th.nn.ConvTranspose2d(8*width, 4*width, 4, padding=1, stride=2),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(4*width, 4*width, 3, padding=1),
            th.nn.BatchNorm2d(width*4),
            th.nn.LeakyReLU(0.2, inplace=True),
            # 8x8

            th.nn.ConvTranspose2d(4*width, 2*width, 4, padding=1, stride=2),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(2*width, 2*width, 3, padding=1),
            th.nn.BatchNorm2d(width*2),
            th.nn.LeakyReLU(0.2, inplace=True),
            # 16x16

            th.nn.ConvTranspose2d(2*width, width, 4, padding=1, stride=2),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(width, width, 3, padding=1),
            th.nn.BatchNorm2d(width),
            th.nn.LeakyReLU(0.2, inplace=True),
            # 32x32

            th.nn.Conv2d(width, width, 3, padding=1),
            th.nn.BatchNorm2d(width),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(width, width, 3, padding=1),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(width, num_out_chans, 1),

            th.nn.Tanh(),
        )

    def forward(self, z):
        bs = z.shape[0]
        num_in_chans = self.zdim // (2*2)
        raster = self.net(z.view(bs, num_in_chans, 2, 2))
        return raster
class Discriminator(th.nn.Module):
    def __init__(self, conditional=False, width=64, color_output=False):
        super(Discriminator, self).__init__()

        self.conditional = conditional

        sn = th.nn.utils.spectral_norm

        num_chan_in = 3 if color_output else 1

        self.net = th.nn.Sequential(
            th.nn.Conv2d(num_chan_in, width, 3, padding=1),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(width, 2*width, 4, padding=1, stride=2),
            th.nn.LeakyReLU(0.2, inplace=True),
            # 16x16

            sn(th.nn.Conv2d(2*width, 2*width, 3, padding=1)),
            th.nn.LeakyReLU(0.2, inplace=True),
            sn(th.nn.Conv2d(2*width, 4*width, 4, padding=1, stride=2)),
            th.nn.LeakyReLU(0.2, inplace=True),
            # 8x8

            sn(th.nn.Conv2d(4*width, 4*width, 3, padding=1)),
            th.nn.LeakyReLU(0.2, inplace=True),
            sn(th.nn.Conv2d(4*width, width*4, 4, padding=1, stride=2)),
            th.nn.LeakyReLU(0.2, inplace=True),
            # 4x4

            sn(th.nn.Conv2d(4*width, 4*width, 3, padding=1)),
            th.nn.LeakyReLU(0.2, inplace=True),
            sn(th.nn.Conv2d(4*width, width*4, 4, padding=1, stride=2)),
            th.nn.LeakyReLU(0.2, inplace=True),
            # 2x2

            modules.Flatten(),
            th.nn.Linear(width*4*2*2, 1),
        )

    def forward(self, x):
        out = self.net(x)
        return out
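# Quick smoke test (a sketch, not part of the training pipeline): it assumes
# the repo's `rendering` and `modules` helpers imported above, and the
# differentiable rasterizer they wrap, are installed. It only instantiates
# each model with its defaults and logs the output shapes.
if __name__ == "__main__":
    bs = 4
    for model_cls in (BezierVectorGenerator, VectorGenerator,
                      RNNVectorGenerator, ChainRNNVectorGenerator):
        model = model_cls()
        model.eval()
        with th.no_grad():
            raster = model(model.sample_z(bs))
        LOG.info("%s -> raster %s", model_cls.__name__, tuple(raster.shape))

    gen = Generator()
    disc = Discriminator()
    with th.no_grad():
        fake = gen(gen.sample_z(bs))
        score = disc(fake)
    LOG.info("Generator -> %s, Discriminator -> %s",
             tuple(fake.shape), tuple(score.shape))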