100 lines
3.3 KiB
Python
100 lines
3.3 KiB
Python
"""Losses for the generative models and baselines."""
|
|
import torch as th
|
|
import numpy as np
|
|
|
|
import ttools.modules.image_operators as imops
|
|
|
|
|
|
class KLDivergence(th.nn.Module):
    """KL-style regularizer whose averaged value is floored at `min_value`.

    The per-element term is
    ``-0.5 * (1 + log_sigma - mu**2 - exp(log_sigma))``; it is averaged over
    every element of the inputs and then bounded from below.

    Args:
        min_value(float): floor applied to the averaged loss, so that values
            below this number don't affect the optimization.
    """

    def __init__(self, min_value=0.2):
        super(KLDivergence, self).__init__()
        self.min_value = min_value

    def forward(self, mu, log_sigma):
        """Compute the floored, element-averaged divergence term.

        Args:
            mu(th.Tensor): first tensor of the term (squared in the formula).
            log_sigma(th.Tensor): second tensor of the term; it enters the
                formula both linearly and through `exp`. Must broadcast
                against `mu`.

        Returns:
            th.Tensor: scalar tensor, never smaller than `min_value`.
        """
        inner = 1.0 + log_sigma - mu.pow(2) - log_sigma.exp()
        averaged = (-0.5 * inner).mean()

        # Element-wise max against a constant tensor of the same shape/dtype
        # implements the lower bound.
        floor = self.min_value * th.ones_like(averaged)
        return th.max(averaged, floor)
|
|
|
|
|
|
class MultiscaleMSELoss(th.nn.Module):
    """Sum of MSE losses evaluated on a blur-and-downsample image pyramid.

    Args:
        channels(int): forwarded as `channels` to `imops.GaussianBlur`, the
            filter applied before each 2x downsampling. The blur's first
            positional argument is fixed to 1 (presumably its standard
            deviation -- confirm against ttools).
    """

    def __init__(self, channels=3):
        super(MultiscaleMSELoss, self).__init__()
        self.blur = imops.GaussianBlur(1, channels=channels)

    def _downsample(self, x):
        """Blur `x`, then halve its spatial resolution (nearest neighbor)."""
        return th.nn.functional.interpolate(self.blur(x),
                                            scale_factor=0.5,
                                            mode="nearest")

    def forward(self, im, target):
        """Accumulate the MSE between `im` and `target` over all levels.

        Args:
            im(th.Tensor): 4D tensor, unpacked as (bs, c, h, w).
            target(th.Tensor): tensor compared against `im` with `mse_loss`
                at every level.

        Returns:
            th.Tensor: scalar tensor, the sum of the per-level MSE values.
        """
        # Unpacking into four names keeps the original's 4D-input check.
        _, _, h, _ = im.shape

        # The level count is derived from the height only, and is always at
        # least one (the full-resolution comparison).
        num_levels = max(int(np.ceil(np.log2(h))) - 2, 1)

        losses = []
        for lvl in range(num_levels):
            losses.append(th.nn.functional.mse_loss(im, target))

            # Only build the next level if it is going to be used: the
            # downsample after the last level was wasted work and made
            # `interpolate` fail on inputs too small to be halved (e.g. 1x1).
            if lvl + 1 < num_levels:
                im = self._downsample(im)
                target = self._downsample(target)

        return th.stack(losses).sum()
|
|
|
|
|
|
def gaussian_pdfs(dx, dy, params, eps=1e-6):
    """Returns the pdf at (dx, dy) for each Gaussian in the mixture.

    Each Gaussian is a bivariate normal read from the last axis of `params`:
    index 0 and 1 are the means, 2 and 3 are log standard deviations
    (exponentiated here), and 4 is passed through `tanh` to obtain the
    correlation coefficient.

    Args:
        dx(th.Tensor): x coordinates at which to evaluate the densities.
        dy(th.Tensor): y coordinates at which to evaluate the densities.
        params(th.Tensor): Gaussian parameters; the last axis must hold at
            least 5 entries. Presumably shaped (..., num_gaussians, 5) with
            leading axes matching `dx`/`dy` -- confirm against the caller.
        eps(float): lower bound applied to `1 - rho^2`. When `tanh` saturates
            to +/-1 this factor becomes exactly zero, which made both the
            exponent and the normalization divide by zero and return NaN.

    Returns:
        th.Tensor: densities, with the broadcast shape of `dx.unsqueeze(-1)`
        and `params[..., 0]`.
    """
    dx = dx.unsqueeze(-1)  # replicate dx, dy to evaluate all pdfs at once
    dy = dy.unsqueeze(-1)

    mu_x = params[..., 0]
    mu_y = params[..., 1]
    sigma_x = params[..., 2].exp()
    sigma_y = params[..., 3].exp()
    rho_xy = th.tanh(params[..., 4])

    # Shared by the exponent and the normalization; floored so a perfectly
    # correlated component stays finite.
    one_minus_rho2 = (1.0 - rho_xy.pow(2)).clamp(min=eps)

    x = ((dx-mu_x) / sigma_x).pow(2)
    y = ((dy-mu_y) / sigma_y).pow(2)

    xy = (dx-mu_x)*(dy-mu_y) / (sigma_x * sigma_y)
    arg = x + y - 2.0*rho_xy*xy
    pdf = th.exp(-arg / (2*one_minus_rho2))
    norm = 2.0 * np.pi * sigma_x * sigma_y * one_minus_rho2.sqrt()

    return pdf / norm
|
|
|
|
|
|
class GaussianMixtureReconstructionLoss(th.nn.Module):
    """Mixture-of-Gaussians likelihood loss plus pen-state classification.

    The result is the sum of two terms: the negative log-likelihood of the
    target offsets under the predicted Gaussian mixture (averaged over steps
    that are not end-of-sequence), and a 3-way cross-entropy on the pen state.

    Args:
        eps(float): constant added to the mixture likelihood before taking
            its log, so that a zero likelihood does not yield an infinite
            loss.
    """

    def __init__(self, eps=1e-5):
        super(GaussianMixtureReconstructionLoss, self).__init__()
        self.eps = eps

    def forward(self, pen_logits, mixture_logits, gaussian_params, targets):
        """Compute the combined position and pen-state loss.

        Args:
            pen_logits(th.Tensor): unnormalized pen-state scores; flattened
                to (-1, 3) for the cross-entropy.
            mixture_logits(th.Tensor): unnormalized mixture weights, softmaxed
                over the last axis.
            gaussian_params(th.Tensor): per-component parameters forwarded to
                `gaussian_pdfs`.
            targets(th.Tensor): last axis holds dx at index 0, dy at index 1
                and the pen-state scores from index 2 on, whose last entry
                flags end-of-sequence. Presumably a one-hot over 3 states
                (5 entries in total) -- confirm against the data loader.

        Returns:
            th.Tensor: scalar tensor, position loss plus pen loss.
        """
        dx = targets[..., 0]
        dy = targets[..., 1]
        pen_state = targets[..., 2:].argmax(-1)  # target index

        # Likelihood loss on the stroke position
        # No need to predict accurate pen position for end-of-sequence tokens
        valid_stroke = (targets[..., -1] != 1.0).float()
        mixture_weights = th.nn.functional.softmax(mixture_logits, -1)
        pdfs = gaussian_pdfs(dx, dy, gaussian_params)
        position_loss = - th.log(self.eps + (pdfs * mixture_weights).sum(-1))

        # Average over the steps that actually carry a stroke. The count is
        # floored at one so that a batch made only of end-of-sequence tokens
        # yields 0 instead of 0/0 = NaN; any non-empty count is unaffected.
        num_valid = valid_stroke.sum().clamp(min=1.0)
        position_loss = (position_loss*valid_stroke).sum() / num_valid

        # Classification loss for the stroke mode
        # `reshape` rather than `view`, so non-contiguous logits are accepted.
        pen_loss = th.nn.functional.cross_entropy(pen_logits.reshape(-1, 3),
                                                  pen_state.reshape(-1))

        return position_loss + pen_loss
|