import os
import time

import torch as th
import numpy as np
import torchvision.datasets as dset
import torchvision.transforms as transforms
import imageio

import ttools

import rendering


BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
DATA = os.path.join(BASE_DIR, "data")

LOG = ttools.get_logger(__name__)


class QuickDrawImageDataset(th.utils.data.Dataset):
    """QuickDraw cat sketches loaded as rasterized 28x28 bitmaps.

    Args:
        imsize(int): side length (in pixels) to which each bitmap is resized.
        train(bool): unused.
    """

    BASE_DATA_URL = \
        "https://console.cloud.google.com/storage/browser/_details/quickdraw_dataset/full/numpy_bitmap/cat.npy"

    def __init__(self, imsize, train=True):
        super(QuickDrawImageDataset, self).__init__()
        file = os.path.join(DATA, "cat.npy")

        self.imsize = imsize

        if not os.path.exists(file):
            msg = ("Dataset file %s does not exist, please download"
                   " it from %s") % (file, QuickDrawImageDataset.BASE_DATA_URL)
            LOG.error(msg)
            raise RuntimeError(msg)

        self.data = np.load(file, allow_pickle=True, encoding="latin1")

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        im = np.reshape(self.data[idx], (1, 1, 28, 28))
        im = th.from_numpy(im).float() / 255.0
        im = th.nn.functional.interpolate(im, size=(self.imsize, self.imsize))

        # Clamp to [0, 1], then bring it to [-1, 1]
        im = th.clamp(im, 0, 1)
        im -= 0.5
        im /= 0.5

        return im.squeeze(0)


class QuickDrawDataset(th.utils.data.Dataset):
    """QuickDraw sketches stored in the stroke-3 format used by sketch-rnn.

    Args:
        dataset(str): name of the sketch-rnn data file; the file
            "sketchrnn_<dataset>" is expected under the DATA directory.
        mode(str): one of "train", "test" or "valid".
        max_seq_length(int): sequences longer than this are discarded.
        spatial_limit(int): maximum spatial extent in pixels; larger offsets
            are clipped to this value.
    """

    BASE_DATA_URL = \
        "https://storage.cloud.google.com/quickdraw_dataset/sketchrnn"

    def __init__(self, dataset, mode="train",
                 max_seq_length=250,
                 spatial_limit=1000):
        super(QuickDrawDataset, self).__init__()
        file = os.path.join(DATA, "sketchrnn_"+dataset)
        remote = os.path.join(QuickDrawDataset.BASE_DATA_URL, dataset)

        self.max_seq_length = max_seq_length
        self.spatial_limit = spatial_limit

        if mode not in ["train", "test", "valid"]:
            raise ValueError("Allowed data modes are 'train', 'test' and"
                             " 'valid'.")

        if not os.path.exists(file):
            msg = ("Dataset file %s does not exist, please download"
                   " it from %s") % (file, remote)
            LOG.error(msg)
            raise RuntimeError(msg)

        data = np.load(file, allow_pickle=True, encoding="latin1")[mode]
        data = self.purify(data)
        data = self.normalize(data)

        # Length of the longest sequence in the dataset
        self.nmax = max([len(seq) for seq in data])
        self.sketches = data

    def __repr__(self):
        return "Dataset with %d sequences of max length %d" % \
            (len(self.sketches), self.nmax)

    def __len__(self):
        return len(self.sketches)

    def __getitem__(self, idx):
        """Return the idx-th sketch in 5-D format, padded to length nmax+2.

        The first and last elements of the sequence are fixed to the "start-"
        and "end-of-sequence" tokens.

        Each point is (dx, dy) plus a one-hot encoding of the pen state:
            1 0 0: pen touching paper until the next point
            0 1 0: pen lifted from paper after the current point
            0 0 1: drawing has ended; subsequent points (including the current
                one) will not be drawn
        """
        sample_data = self.sketches[idx]

        # Allow two extra slots for the start/end-of-sequence tokens
        sample = np.zeros((self.nmax+2, 5), dtype=np.float32)

        n = sample_data.shape[0]

        # Normalize dx, dy
        deltas = sample_data[:, :2]
        # Absolute coordinates
        positions = deltas[..., :2].cumsum(0)
        maxi = np.abs(positions).max() + 1e-8
        deltas = deltas / (1.1 * maxi)  # leave some margin on the edges

        # Fill in the dx, dy coordinates
        sample[1:n+1, :2] = deltas

        # On-paper indicator: in the stroke-3 format, 0 means the pen is
        # touching the paper, so flip it
        sample[1:n+1, 2] = 1 - sample_data[:, 2]

        # Off-paper indicator, complement of the previous flag
        sample[1:n+1, 3] = 1 - sample[1:n+1, 2]

        # Fill the remainder with end-of-sequence tokens
        sample[n+1:, 4] = 1

        # Start-of-sequence token
        sample[0] = [0, 0, 1, 0, 0]

        return sample

    def purify(self, strokes):
        """Removes sequences that are too long and clips large spatial gaps."""
        data = []
        for seq in strokes:
            if seq.shape[0] <= self.max_seq_length:
                # and seq.shape[0] > 10:

                # Limit large spatial gaps
                seq = np.minimum(seq, self.spatial_limit)
                seq = np.maximum(seq, -self.spatial_limit)
                seq = np.array(seq, dtype=np.float32)
                data.append(seq)
        return data

    def calculate_normalizing_scale_factor(self, strokes):
        """Calculate the normalizing factor explained in the appendix of
        sketch-rnn: the standard deviation of all (dx, dy) offsets."""
        data = []
        for stroke in strokes:
            for pt in stroke:
                data.append(pt[0])
                data.append(pt[1])
        data = np.array(data)
        return np.std(data)

    def normalize(self, strokes):
        """Normalize the entire dataset's (delta_x, delta_y) by the scaling
        factor."""
        data = []
        scale_factor = self.calculate_normalizing_scale_factor(strokes)
        for seq in strokes:
            seq[:, 0:2] /= scale_factor
            data.append(seq)
        return data


class FixedLengthQuickDrawDataset(QuickDrawDataset):
    """A variant of the QuickDraw dataset where the strokes are represented as
    a fixed-length sequence of triplets (dx, dy, opacity), with opacity in
    {0, 1}, returned together with a rendering of the sketch.
    """
    def __init__(self, *args, canvas_size=64, **kwargs):
        super(FixedLengthQuickDrawDataset, self).__init__(*args, **kwargs)
        self.canvas_size = canvas_size

    def __getitem__(self, idx):
        sample = super(FixedLengthQuickDrawDataset, self).__getitem__(idx)

        # We construct a stroke opacity variable from the pen-down state;
        # dx, dy remain unchanged
        strokes = sample[:, :3]

        im = np.zeros((1, 1))

        # Render the image
        # start = time.time()
        im = rendering.opacityStroke2diffvg(
            th.from_numpy(strokes).unsqueeze(0), canvas_size=self.canvas_size,
            relative=True, debug=False)
        im = im.squeeze(0).numpy()
        # elapsed = (time.time() - start)*1000
        # print("item %d pipeline gt rendering took %.2fms" % (idx, elapsed))

        return strokes, im


class MNISTDataset(th.utils.data.Dataset):
    def __init__(self, imsize, train=True):
        super(MNISTDataset, self).__init__()
        self.mnist = dset.MNIST(root=os.path.join(DATA, "mnist"),
                                train=train,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.Resize((imsize, imsize)),
                                    transforms.ToTensor(),
                                ]))

    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, idx):
        im, label = self.mnist[idx]

        # Make sure the data uses the [0, 1] range
        im -= im.min()
        im /= im.max() + 1e-8

        # Bring it to [-1, 1]
        im -= 0.5
        im /= 0.5
        return im


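# Minimal usage sketch (not part of the original module): iterate over one of
# the datasets with a standard DataLoader. This assumes the relevant data
# files are present under DATA ("cat.npy" for QuickDrawImageDataset, the
# sketch-rnn archives for QuickDrawDataset); MNIST is downloaded automatically.
if __name__ == "__main__":
    dataset = MNISTDataset(imsize=64, train=True)
    loader = th.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
    for batch in loader:
        LOG.info("MNIST batch of shape %s in [%f, %f]",
                 tuple(batch.shape), batch.min().item(), batch.max().item())
        break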