initial commit
apps/generative_models/data.py: 229 lines (Normal file)
@@ -0,0 +1,229 @@
import os
import time

import torch as th
import numpy as np
import torchvision.datasets as dset
import torchvision.transforms as transforms
import imageio

import ttools
import rendering


BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
DATA = os.path.join(BASE_DIR, "data")

LOG = ttools.get_logger(__name__)


class QuickDrawImageDataset(th.utils.data.Dataset):
    """Rasterized Quick, Draw! cat bitmaps as image tensors.

    Args:
        imsize(int): side length (in pixels) the 28x28 bitmaps are
            resized to.
        train(bool): unused; kept for API symmetry with the other
            datasets.
    """

    BASE_DATA_URL = \
        "https://console.cloud.google.com/storage/browser/_details/quickdraw_dataset/full/numpy_bitmap/cat.npy"

    def __init__(self, imsize, train=True):
        super(QuickDrawImageDataset, self).__init__()
        file = os.path.join(DATA, "cat.npy")

        self.imsize = imsize

        if not os.path.exists(file):
            msg = ("Dataset file %s does not exist, please download it"
                   " from %s" % (file, QuickDrawImageDataset.BASE_DATA_URL))
            LOG.error(msg)
            raise RuntimeError(msg)

        self.data = np.load(file, allow_pickle=True, encoding="latin1")

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        im = np.reshape(self.data[idx], (1, 1, 28, 28))
        im = th.from_numpy(im).float() / 255.0
        im = th.nn.functional.interpolate(im, size=(self.imsize, self.imsize))

        # Bring it to [-1, 1]
        im = th.clamp(im, 0, 1)
        im -= 0.5
        im /= 0.5

        return im.squeeze(0)
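

# Usage sketch (editor addition, not part of the original commit): typical
# consumption of the image dataset via a DataLoader. The batch size and
# image size below are arbitrary illustrative choices.
def _demo_image_dataset(imsize=64):
    dataset = QuickDrawImageDataset(imsize)
    loader = th.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
    im = next(iter(loader))    # (8, 1, imsize, imsize), values in [-1, 1]
    return im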


class QuickDrawDataset(th.utils.data.Dataset):
    """Vector sketches from the Quick, Draw! sketch-rnn dataset.

    Args:
        dataset(str): name of the sketch-rnn archive to load.
        mode(str): one of 'train', 'test' or 'valid'.
        max_seq_length(int): sequences longer than this are discarded.
        spatial_limit(int): maximum spatial extent in pixels; larger pen
            jumps are clipped to this value.
    """

    BASE_DATA_URL = \
        "https://storage.cloud.google.com/quickdraw_dataset/sketchrnn"

    def __init__(self, dataset, mode="train",
                 max_seq_length=250,
                 spatial_limit=1000):
        super(QuickDrawDataset, self).__init__()
        file = os.path.join(DATA, "sketchrnn_"+dataset)
        remote = os.path.join(QuickDrawDataset.BASE_DATA_URL, dataset)

        self.max_seq_length = max_seq_length
        self.spatial_limit = spatial_limit

        if mode not in ["train", "test", "valid"]:
            raise ValueError("Allowed data modes are 'train', 'test'"
                             " and 'valid'.")

        if not os.path.exists(file):
            msg = ("Dataset file %s does not exist, please download it"
                   " from %s" % (file, remote))
            LOG.error(msg)
            raise RuntimeError(msg)

        data = np.load(file, allow_pickle=True, encoding="latin1")[mode]
        data = self.purify(data)
        data = self.normalize(data)

        # Length of longest sequence in the dataset
        self.nmax = max([len(seq) for seq in data])
        self.sketches = data

    def __repr__(self):
        return "Dataset with %d sequences of max length %d" % \
                (len(self.sketches), self.nmax)

    def __len__(self):
        return len(self.sketches)

    def __getitem__(self, idx):
        """Return the idx-th stroke in 5-D format, padded to length (nmax+2).

        The first and last elements of the sequence are fixed to the
        "start-" and "end-of-sequence" tokens.

        Each step holds dx, dy, plus 3 numbers one-hot encoding the pen
        state:
            1 0 0: pen touching paper till the next point
            0 1 0: pen lifted from paper after the current point
            0 0 1: drawing has ended, remaining points (including the
                   current one) will not be drawn
        """
        sample_data = self.sketches[idx]

        # Allow two extra slots for start/end of sequence tokens
        sample = np.zeros((self.nmax+2, 5), dtype=np.float32)

        n = sample_data.shape[0]

        # normalize dx, dy
        deltas = sample_data[:, :2]
        # Absolute coordinates
        positions = deltas.cumsum(0)
        maxi = np.abs(positions).max() + 1e-8
        deltas = deltas / (1.1 * maxi)  # leave some margin on edges

        # fill in dx, dy coordinates
        sample[1:n+1, :2] = deltas

        # on-paper indicator: 0 means touching paper in the 3-D format,
        # so flip it
        sample[1:n+1, 2] = 1 - sample_data[:, 2]

        # off-paper indicator, complement of previous flag
        sample[1:n+1, 3] = 1 - sample[1:n+1, 2]

        # fill with end of sequence tokens for the remainder
        sample[n+1:, 4] = 1

        # Start of sequence token
        sample[0] = [0, 0, 1, 0, 0]

        return sample

    def purify(self, strokes):
        """Removes sequences that are too long and clips large gaps."""
        data = []
        for seq in strokes:
            if seq.shape[0] <= self.max_seq_length:
                # and seq.shape[0] > 10:

                # Limit large spatial gaps
                seq = np.minimum(seq, self.spatial_limit)
                seq = np.maximum(seq, -self.spatial_limit)
                seq = np.array(seq, dtype=np.float32)
                data.append(seq)
        return data

    def calculate_normalizing_scale_factor(self, strokes):
        """Calculate the normalizing factor explained in the appendix of
        sketch-rnn."""
        data = []
        for stroke in strokes:
            for pt in stroke:
                data.append(pt[0])
                data.append(pt[1])
        data = np.array(data)
        return np.std(data)

    def normalize(self, strokes):
        """Normalize the entire dataset (delta_x, delta_y) by the scaling
        factor."""
        data = []
        scale_factor = self.calculate_normalizing_scale_factor(strokes)
        for seq in strokes:
            seq[:, 0:2] /= scale_factor
            data.append(seq)
        return data
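

# Usage sketch (editor addition, not part of the original commit): decode a
# padded 5-D sample back to absolute pen positions, mirroring the encoding
# documented in __getitem__. The archive name "cat.npz" is an assumption;
# any sketch-rnn archive under DATA works.
def _demo_decode_sample(dataset_name="cat.npz", idx=0):
    dataset = QuickDrawDataset(dataset_name, mode="valid")
    sample = dataset[idx]            # (nmax + 2, 5) float32 array
    deltas = sample[:, :2]
    positions = deltas.cumsum(0)     # absolute coordinates, roughly [-1, 1]
    pen_down = sample[:, 2] > 0.5    # True where the pen touches paper
    return positions, pen_down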


class FixedLengthQuickDrawDataset(QuickDrawDataset):
    """A variant of the QuickDraw dataset where the strokes are represented
    as a fixed-length sequence of triplets (dx, dy, opacity), with opacity
    in {0, 1}.
    """
    def __init__(self, *args, canvas_size=64, **kwargs):
        super(FixedLengthQuickDrawDataset, self).__init__(*args, **kwargs)
        self.canvas_size = canvas_size

    def __getitem__(self, idx):
        sample = super(FixedLengthQuickDrawDataset, self).__getitem__(idx)

        # We construct a stroke opacity variable from the pen-down state;
        # dx, dy remain unchanged.
        strokes = sample[:, :3]

        # render image
        # start = time.time()
        im = rendering.opacityStroke2diffvg(
            th.from_numpy(strokes).unsqueeze(0), canvas_size=self.canvas_size,
            relative=True, debug=False)
        im = im.squeeze(0).numpy()
        # elapsed = (time.time() - start)*1000
        # print("item %d pipeline gt rendering took %.2fms" % (idx, elapsed))

        return strokes, im
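

# Usage sketch (editor addition, not part of the original commit): fetch one
# fixed-length sample together with its diffvg rendering. The archive name
# "cat.npz" and the canvas size are assumptions.
def _demo_fixed_length(dataset_name="cat.npz", idx=0):
    dataset = FixedLengthQuickDrawDataset(dataset_name, mode="valid",
                                          canvas_size=64)
    strokes, im = dataset[idx]    # (nmax + 2, 3) strokes and rendered image
    return strokes, im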


class MNISTDataset(th.utils.data.Dataset):
    def __init__(self, imsize, train=True):
        super(MNISTDataset, self).__init__()
        self.mnist = dset.MNIST(root=os.path.join(DATA, "mnist"),
                                train=train,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.Resize((imsize, imsize)),
                                    transforms.ToTensor(),
                                ]))

    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, idx):
        im, label = self.mnist[idx]

        # make sure data uses the [0, 1] range
        im -= im.min()
        im /= im.max() + 1e-8

        # Bring it to [-1, 1]
        im -= 0.5
        im /= 0.5
        return im
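

# Usage sketch (editor addition, not part of the original commit): quick
# sanity check of the MNIST wrapper's output shape and value range.
def _demo_mnist(imsize=32):
    dataset = MNISTDataset(imsize, train=False)
    im = dataset[0]    # (1, imsize, imsize) tensor with values in [-1, 1]
    assert im.min() >= -1 and im.max() <= 1
    return im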