initial commit
apps/generative_models/losses.py (new file, 99 lines)
@@ -0,0 +1,99 @@
"""Losses for the generative models and baselines."""
import torch as th
import numpy as np

import ttools.modules.image_operators as imops


class KLDivergence(th.nn.Module):
    """KL divergence between a diagonal Gaussian and the standard normal.

    Args:
        min_value(float): the loss is clipped so that values below this
            number don't affect the optimization.
    """
    def __init__(self, min_value=0.2):
        super(KLDivergence, self).__init__()
        self.min_value = min_value

    def forward(self, mu, log_sigma):
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), averaged over all
        # elements, then floored at `min_value` ("free bits") so terms that
        # are already small stop contributing gradient.
        loss = -0.5 * (1.0 + log_sigma - mu.pow(2) - log_sigma.exp())
        loss = loss.mean()
        loss = th.max(loss, self.min_value*th.ones_like(loss))
        return loss


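# Usage sketch for KLDivergence; the encoder outputs and weighting below
# are illustrative assumptions, not part of this module. In a VAE-style
# step the encoder predicts `mu` and `log_sigma` and the clipped KL is
# weighted into the total objective:
#
#     kl_loss = KLDivergence(min_value=0.2)
#     mu, log_sigma = th.zeros(8, 64), th.zeros(8, 64)  # encoder outputs
#     total_loss = recons_loss + kl_weight * kl_loss(mu, log_sigma)

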
class MultiscaleMSELoss(th.nn.Module):
    def __init__(self, channels=3):
        super(MultiscaleMSELoss, self).__init__()
        self.blur = imops.GaussianBlur(1, channels=channels)

    def forward(self, im, target):
        bs, c, h, w = im.shape
        num_levels = max(int(np.ceil(np.log2(h))) - 2, 1)

        losses = []
        for lvl in range(num_levels):
            loss = th.nn.functional.mse_loss(im, target)
            losses.append(loss)
            # Blur, then 2x downsample both images to build the next
            # pyramid level.
            im = th.nn.functional.interpolate(self.blur(im),
                                              scale_factor=0.5,
                                              mode="nearest")
            target = th.nn.functional.interpolate(self.blur(target),
                                                  scale_factor=0.5,
                                                  mode="nearest")

        losses = th.stack(losses)
        return losses.sum()


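# MultiscaleMSELoss compares the images at every level of a Gaussian
# pyramid: MSE at the current resolution, then blur and halve both images.
# A minimal sketch (sizes are illustrative):
#
#     ms_mse = MultiscaleMSELoss(channels=3)
#     im, target = th.rand(2, 3, 64, 64), th.rand(2, 3, 64, 64)
#     loss = ms_mse(im, target)  # sum over ceil(log2(64)) - 2 = 4 levels

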
def gaussian_pdfs(dx, dy, params):
    """Returns the pdf at (dx, dy) for each Gaussian in the mixture."""
    dx = dx.unsqueeze(-1)  # replicate dx, dy to evaluate all pdfs at once
    dy = dy.unsqueeze(-1)

    mu_x = params[..., 0]
    mu_y = params[..., 1]
    sigma_x = params[..., 2].exp()
    sigma_y = params[..., 3].exp()
    rho_xy = th.tanh(params[..., 4])

    x = ((dx-mu_x) / sigma_x).pow(2)
    y = ((dy-mu_y) / sigma_y).pow(2)
    xy = (dx-mu_x)*(dy-mu_y) / (sigma_x * sigma_y)
    arg = x + y - 2.0*rho_xy*xy
    pdf = th.exp(-arg / (2*(1.0 - rho_xy.pow(2))))
    norm = 2.0 * np.pi * sigma_x * sigma_y * (1.0 - rho_xy.pow(2)).sqrt()

    return pdf / norm


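# For reference, gaussian_pdfs evaluates the bivariate normal density of
# each mixture component:
#
#     z   = ((dx-mu_x)/sx)^2 + ((dy-mu_y)/sy)^2
#           - 2*rho*(dx-mu_x)*(dy-mu_y)/(sx*sy)
#     pdf = exp(-z / (2*(1-rho^2))) / (2*pi*sx*sy*sqrt(1-rho^2))
#
# params[..., 2:4] store log standard deviations and params[..., 4] an
# unconstrained correlation, squashed into (-1, 1) by tanh.

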
class GaussianMixtureReconstructionLoss(th.nn.Module):
    """Stroke likelihood under a 2D Gaussian mixture, plus pen-state loss.

    Args:
        eps(float): small constant added inside the log for numerical
            stability.
    """
    def __init__(self, eps=1e-5):
        super(GaussianMixtureReconstructionLoss, self).__init__()
        self.eps = eps

    def forward(self, pen_logits, mixture_logits, gaussian_params, targets):
        dx = targets[..., 0]
        dy = targets[..., 1]
        pen_state = targets[..., 2:].argmax(-1)  # target index

        # Likelihood loss on the stroke position. There is no need to
        # predict an accurate pen position for end-of-sequence tokens, so
        # those steps are masked out.
        valid_stroke = (targets[..., -1] != 1.0).float()
        mixture_weights = th.nn.functional.softmax(mixture_logits, -1)
        pdfs = gaussian_pdfs(dx, dy, gaussian_params)
        position_loss = -th.log(self.eps + (pdfs * mixture_weights).sum(-1))

        # Normalize by the actual number of valid (non-empty) strokes.
        position_loss = (position_loss*valid_stroke).sum() / valid_stroke.sum()

        # Classification loss for the stroke mode (3 pen states, the last
        # of which marks end-of-sequence).
        pen_loss = th.nn.functional.cross_entropy(pen_logits.view(-1, 3),
                                                  pen_state.view(-1))

        return position_loss + pen_loss
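
# Usage sketch for GaussianMixtureReconstructionLoss. The shapes below are
# assumptions inferred from the code (T steps, K mixture components), not a
# documented interface:
#
#     crit = GaussianMixtureReconstructionLoss()
#     T, K = 16, 20
#     pen_logits = th.randn(T, 3)          # 3-way pen-state logits
#     mixture_logits = th.randn(T, K)      # per-step mixture weights
#     gaussian_params = th.randn(T, K, 5)  # mu_x, mu_y, log_sx, log_sy, rho
#     pen_onehot = th.nn.functional.one_hot(th.randint(0, 3, (T,)), 3)
#     targets = th.cat([th.randn(T, 2), pen_onehot.float()], -1)
#     loss = crit(pen_logits, mixture_logits, gaussian_params, targets)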