initial commit

This commit is contained in:
Tzu-Mao Li
2020-09-03 22:30:30 -04:00
commit 413a3e5cee
148 changed files with 138536 additions and 0 deletions

9
pydiffvg/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
from .device import *
from .shape import *
from .pixel_filter import *
from .render_pytorch import *
from .image import *
from .parse_svg import *
from .color import *
from .optimize_svg import *
from .save_svg import *

24
pydiffvg/color.py Normal file
View File

@@ -0,0 +1,24 @@
import pydiffvg
import torch
class LinearGradient:
def __init__(self,
begin = torch.tensor([0.0, 0.0]),
end = torch.tensor([0.0, 0.0]),
offsets = torch.tensor([0.0]),
stop_colors = torch.tensor([0.0, 0.0, 0.0, 0.0])):
self.begin = begin
self.end = end
self.offsets = offsets
self.stop_colors = stop_colors
class RadialGradient:
def __init__(self,
center = torch.tensor([0.0, 0.0]),
radius = torch.tensor([0.0, 0.0]),
offsets = torch.tensor([0.0]),
stop_colors = torch.tensor([0.0, 0.0, 0.0, 0.0])):
self.center = center
self.radius = radius
self.offsets = offsets
self.stop_colors = stop_colors

25
pydiffvg/device.py Normal file
View File

@@ -0,0 +1,25 @@
import torch
use_gpu = torch.cuda.is_available()
device = torch.device('cuda') if use_gpu else torch.device('cpu')
def set_use_gpu(v):
global use_gpu
global device
use_gpu = v
if not use_gpu:
device = torch.device('cpu')
def get_use_gpu():
global use_gpu
return use_gpu
def set_device(d):
global device
global use_gpu
device = d
use_gpu = device.type == 'cuda'
def get_device():
global device
return device

22
pydiffvg/image.py Normal file
View File

@@ -0,0 +1,22 @@
import numpy as np
import skimage
import skimage.io
import os
def imwrite(img, filename, gamma = 2.2, normalize = False):
directory = os.path.dirname(filename)
if directory != '' and not os.path.exists(directory):
os.makedirs(directory)
if not isinstance(img, np.ndarray):
img = img.data.numpy()
if normalize:
img_rng = np.max(img) - np.min(img)
if img_rng > 0:
img = (img - np.min(img)) / img_rng
img = np.clip(img, 0.0, 1.0)
if img.ndim==2:
#repeat along the third dimension
img=np.expand_dims(img,2)
img[:, :, :3] = np.power(img[:, :, :3], 1.0/gamma)
skimage.io.imsave(filename, (img * 255).astype(np.uint8))

1606
pydiffvg/optimize_svg.py Normal file

File diff suppressed because it is too large Load Diff

578
pydiffvg/parse_svg.py Normal file
View File

@@ -0,0 +1,578 @@
import torch
import xml.etree.ElementTree as etree
import numpy as np
import diffvg
import os
import pydiffvg
import svgpathtools
import svgpathtools.parser
import re
import warnings
import cssutils
import logging
cssutils.log.setLevel(logging.ERROR)
def remove_namespaces(s):
"""
{...} ... -> ...
"""
return re.sub('{.*}', '', s)
def parse_style(s, defs):
style_dict = {}
for e in s.split(';'):
key_value = e.split(':')
if len(key_value) == 2:
key = key_value[0].strip()
value = key_value[1].strip()
if key == 'fill' or key == 'stroke':
# Special case: convert colors into tensor in definitions so
# that different shapes can share the same color
value = parse_color(value, defs)
style_dict[key] = value
return style_dict
def parse_hex(s):
"""
Hex to tuple
"""
s = s.lstrip('#')
if len(s) == 3:
s = s[0] + s[0] + s[1] + s[1] + s[2] + s[2]
rgb = tuple(int(s[i:i+2], 16) for i in (0, 2, 4))
# sRGB to RGB
# return torch.pow(torch.tensor([rgb[0] / 255.0, rgb[1] / 255.0, rgb[2] / 255.0]), 2.2)
return torch.pow(torch.tensor([rgb[0] / 255.0, rgb[1] / 255.0, rgb[2] / 255.0]), 1.0)
def parse_int(s):
"""
trim alphabets
"""
return int(float(''.join(i for i in s if (not i.isalpha()))))
def parse_color(s, defs):
if s is None:
return None
if isinstance(s, torch.Tensor):
return s
s = s.lstrip(' ')
color = torch.tensor([0.0, 0.0, 0.0, 1.0])
if s[0] == '#':
color[:3] = parse_hex(s)
elif s[:3] == 'url':
# url(#id)
color = defs[s[4:-1].lstrip('#')]
elif s == 'none':
color = None
elif s[:4] == 'rgb(':
rgb = s[4:-1].split(',')
color = torch.tensor([int(rgb[0]) / 255.0, int(rgb[1]) / 255.0, int(rgb[2]) / 255.0, 1.0])
elif s == 'none':
return None
else:
warnings.warn('Unknown color command ' + s)
return color
# https://github.com/mathandy/svgpathtools/blob/7ebc56a831357379ff22216bec07e2c12e8c5bc6/svgpathtools/parser.py
def _parse_transform_substr(transform_substr):
type_str, value_str = transform_substr.split('(')
value_str = value_str.replace(',', ' ')
values = list(map(float, filter(None, value_str.split(' '))))
transform = np.identity(3)
if 'matrix' in type_str:
transform[0:2, 0:3] = np.array([values[0:6:2], values[1:6:2]])
elif 'translate' in transform_substr:
transform[0, 2] = values[0]
if len(values) > 1:
transform[1, 2] = values[1]
elif 'scale' in transform_substr:
x_scale = values[0]
y_scale = values[1] if (len(values) > 1) else x_scale
transform[0, 0] = x_scale
transform[1, 1] = y_scale
elif 'rotate' in transform_substr:
angle = values[0] * np.pi / 180.0
if len(values) == 3:
offset = values[1:3]
else:
offset = (0, 0)
tf_offset = np.identity(3)
tf_offset[0:2, 2:3] = np.array([[offset[0]], [offset[1]]])
tf_rotate = np.identity(3)
tf_rotate[0:2, 0:2] = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])
tf_offset_neg = np.identity(3)
tf_offset_neg[0:2, 2:3] = np.array([[-offset[0]], [-offset[1]]])
transform = tf_offset.dot(tf_rotate).dot(tf_offset_neg)
elif 'skewX' in transform_substr:
transform[0, 1] = np.tan(values[0] * np.pi / 180.0)
elif 'skewY' in transform_substr:
transform[1, 0] = np.tan(values[0] * np.pi / 180.0)
else:
# Return an identity matrix if the type of transform is unknown, and warn the user
warnings.warn('Unknown SVG transform type: {0}'.format(type_str))
return transform
def parse_transform(transform_str):
"""
Converts a valid SVG transformation string into a 3x3 matrix.
If the string is empty or null, this returns a 3x3 identity matrix
"""
if not transform_str:
return np.identity(3)
elif not isinstance(transform_str, str):
raise TypeError('Must provide a string to parse')
total_transform = np.identity(3)
transform_substrs = transform_str.split(')')[:-1] # Skip the last element, because it should be empty
for substr in transform_substrs:
total_transform = total_transform.dot(_parse_transform_substr(substr))
return torch.from_numpy(total_transform).type(torch.float32)
def parse_linear_gradient(node, transform, defs):
begin = torch.tensor([0.0, 0.0])
end = torch.tensor([0.0, 0.0])
offsets = []
stop_colors = []
# Inherit from parent
for key in node.attrib:
if remove_namespaces(key) == 'href':
value = node.attrib[key]
parent = defs[value.lstrip('#')]
begin = parent.begin
end = parent.end
offsets = parent.offsets
stop_colors = parent.stop_colors
for attrib in node.attrib:
attrib = remove_namespaces(attrib)
if attrib == 'x1':
begin[0] = float(node.attrib['x1'])
elif attrib == 'y1':
begin[1] = float(node.attrib['y1'])
elif attrib == 'x2':
end[0] = float(node.attrib['x2'])
elif attrib == 'y2':
end[1] = float(node.attrib['y2'])
elif attrib == 'gradientTransform':
transform = transform @ parse_transform(node.attrib['gradientTransform'])
begin = transform @ torch.cat((begin, torch.ones([1])))
begin = begin / begin[2]
begin = begin[:2]
end = transform @ torch.cat((end, torch.ones([1])))
end = end / end[2]
end = end[:2]
for child in node:
tag = remove_namespaces(child.tag)
if tag == 'stop':
offset = float(child.attrib['offset'])
color = [0.0, 0.0, 0.0, 1.0]
if 'stop-color' in child.attrib:
c = parse_color(child.attrib['stop-color'], defs)
color[:3] = [c[0], c[1], c[2]]
if 'stop-opacity' in child.attrib:
color[3] = float(child.attrib['stop-opacity'])
if 'style' in child.attrib:
style = parse_style(child.attrib['style'], defs)
if 'stop-color' in style:
c = parse_color(style['stop-color'], defs)
color[:3] = [c[0], c[1], c[2]]
if 'stop-opacity' in style:
color[3] = float(style['stop-opacity'])
offsets.append(offset)
stop_colors.append(color)
if isinstance(offsets, list):
offsets = torch.tensor(offsets)
if isinstance(stop_colors, list):
stop_colors = torch.tensor(stop_colors)
return pydiffvg.LinearGradient(begin, end, offsets, stop_colors)
def parse_radial_gradient(node, transform, defs):
begin = torch.tensor([0.0, 0.0])
end = torch.tensor([0.0, 0.0])
center = torch.tensor([0.0, 0.0])
radius = torch.tensor([0.0, 0.0])
offsets = []
stop_colors = []
# Inherit from parent
for key in node.attrib:
if remove_namespaces(key) == 'href':
value = node.attrib[key]
parent = defs[value.lstrip('#')]
begin = parent.begin
end = parent.end
offsets = parent.offsets
stop_colors = parent.stop_colors
for attrib in node.attrib:
attrib = remove_namespaces(attrib)
if attrib == 'cx':
center[0] = float(node.attrib['cx'])
elif attrib == 'cy':
center[1] = float(node.attrib['cy'])
elif attrib == 'fx':
radius[0] = float(node.attrib['fx'])
elif attrib == 'fy':
radius[1] = float(node.attrib['fy'])
elif attrib == 'fr':
radius[0] = float(node.attrib['fr'])
radius[1] = float(node.attrib['fr'])
elif attrib == 'gradientTransform':
transform = transform @ parse_transform(node.attrib['gradientTransform'])
# TODO: this is incorrect
center = transform @ torch.cat((center, torch.ones([1])))
center = center / center[2]
center = center[:2]
for child in node:
tag = remove_namespaces(child.tag)
if tag == 'stop':
offset = float(child.attrib['offset'])
color = [0.0, 0.0, 0.0, 1.0]
if 'stop-color' in child.attrib:
c = parse_color(child.attrib['stop-color'], defs)
color[:3] = [c[0], c[1], c[2]]
if 'stop-opacity' in child.attrib:
color[3] = float(child.attrib['stop-opacity'])
if 'style' in child.attrib:
style = parse_style(child.attrib['style'], defs)
if 'stop-color' in style:
c = parse_color(style['stop-color'], defs)
color[:3] = [c[0], c[1], c[2]]
if 'stop-opacity' in style:
color[3] = float(style['stop-opacity'])
offsets.append(offset)
stop_colors.append(color)
if isinstance(offsets, list):
offsets = torch.tensor(offsets)
if isinstance(stop_colors, list):
stop_colors = torch.tensor(stop_colors)
return pydiffvg.RadialGradient(begin, end, offsets, stop_colors)
def parse_stylesheet(node, transform, defs):
# collect CSS classes
sheet = cssutils.parseString(node.text)
for rule in sheet:
if hasattr(rule, 'selectorText') and hasattr(rule, 'style'):
name = rule.selectorText
if len(name) >= 2 and name[0] == '.':
defs[name[1:]] = parse_style(rule.style.getCssText(), defs)
return defs
def parse_defs(node, transform, defs):
for child in node:
tag = remove_namespaces(child.tag)
if tag == 'linearGradient':
if 'id' in child.attrib:
defs[child.attrib['id']] = parse_linear_gradient(child, transform, defs)
elif tag == 'radialGradient':
if 'id' in child.attrib:
defs[child.attrib['id']] = parse_radial_gradient(child, transform, defs)
elif tag == 'style':
defs = parse_stylesheet(child, transform, defs)
return defs
def parse_common_attrib(node, transform, fill_color, defs):
attribs = {}
if 'class' in node.attrib:
attribs.update(defs[node.attrib['class']])
attribs.update(node.attrib)
name = ''
if 'id' in node.attrib:
name = node.attrib['id']
stroke_color = None
stroke_width = torch.tensor(0.5)
use_even_odd_rule = False
new_transform = transform
if 'transform' in attribs:
new_transform = transform @ parse_transform(attribs['transform'])
if 'fill' in attribs:
fill_color = parse_color(attribs['fill'], defs)
fill_opacity = 1.0
if 'fill-opacity' in attribs:
fill_opacity *= float(attribs['fill-opacity'])
if 'opacity' in attribs:
fill_opacity *= float(attribs['opacity'])
# Ignore opacity if the color is a gradient
if isinstance(fill_color, torch.Tensor):
fill_color[3] = fill_opacity
if 'fill-rule' in attribs:
if attribs['fill-rule'] == "evenodd":
use_even_odd_rule = True
elif attribs['fill-rule'] == "nonzero":
use_even_odd_rule = False
else:
warnings.warn('Unknown fill-rule: {}'.format(attribs['fill-rule']))
if 'stroke' in attribs:
stroke_color = parse_color(attribs['stroke'], defs)
if 'stroke-width' in attribs:
stroke_width = attribs['stroke-width']
if stroke_width[-2:] == 'px':
stroke_width = stroke_width[:-2]
stroke_width = torch.tensor(float(stroke_width) / 2.0)
if 'style' in attribs:
style = parse_style(attribs['style'], defs)
if 'fill' in style:
fill_color = parse_color(style['fill'], defs)
fill_opacity = 1.0
if 'fill-opacity' in style:
fill_opacity *= float(style['fill-opacity'])
if 'opacity' in style:
fill_opacity *= float(style['opacity'])
if 'fill-rule' in style:
if style['fill-rule'] == "evenodd":
use_even_odd_rule = True
elif style['fill-rule'] == "nonzero":
use_even_odd_rule = False
else:
warnings.warn('Unknown fill-rule: {}'.format(style['fill-rule']))
# Ignore opacity if the color is a gradient
if isinstance(fill_color, torch.Tensor):
fill_color[3] = fill_opacity
if 'stroke' in style:
if style['stroke'] != 'none':
stroke_color = parse_color(style['stroke'], defs)
# Ignore opacity if the color is a gradient
if isinstance(stroke_color, torch.Tensor):
if 'stroke-opacity' in style:
stroke_color[3] = float(style['stroke-opacity'])
if 'opacity' in style:
stroke_color[3] *= float(style['opacity'])
if 'stroke-width' in style:
stroke_width = style['stroke-width']
if stroke_width[-2:] == 'px':
stroke_width = stroke_width[:-2]
stroke_width = torch.tensor(float(stroke_width) / 2.0)
if isinstance(fill_color, pydiffvg.LinearGradient):
fill_color.begin = new_transform @ torch.cat((fill_color.begin, torch.ones([1])))
fill_color.begin = fill_color.begin / fill_color.begin[2]
fill_color.begin = fill_color.begin[:2]
fill_color.end = new_transform @ torch.cat((fill_color.end, torch.ones([1])))
fill_color.end = fill_color.end / fill_color.end[2]
fill_color.end = fill_color.end[:2]
if isinstance(stroke_color, pydiffvg.LinearGradient):
stroke_color.begin = new_transform @ torch.cat((stroke_color.begin, torch.ones([1])))
stroke_color.begin = stroke_color.begin / stroke_color.begin[2]
stroke_color.begin = stroke_color.begin[:2]
stroke_color.end = new_transform @ torch.cat((stroke_color.end, torch.ones([1])))
stroke_color.end = stroke_color.end / stroke_color.end[2]
stroke_color.end = stroke_color.end[:2]
if 'filter' in style:
print('*** WARNING ***: Ignoring filter for path with id "{}"'.format(name))
return new_transform, fill_color, stroke_color, stroke_width, use_even_odd_rule
def is_shape(tag):
return tag == 'path' or tag == 'polygon' or tag == 'line' or tag == 'circle' or tag == 'rect'
def parse_shape(node, transform, fill_color, shapes, shape_groups, defs):
tag = remove_namespaces(node.tag)
new_transform, new_fill_color, stroke_color, stroke_width, use_even_odd_rule = \
parse_common_attrib(node, transform, fill_color, defs)
if tag == 'path':
d = node.attrib['d']
name = ''
if 'id' in node.attrib:
name = node.attrib['id']
force_closing = new_fill_color is not None
paths = pydiffvg.from_svg_path(d, new_transform, force_closing)
for idx, path in enumerate(paths):
assert(path.points.shape[1] == 2)
path.stroke_width = stroke_width
path.source_id = name
path.id = "{}-{}".format(name,idx) if len(paths)>1 else name
prev_shapes_size = len(shapes)
shapes = shapes + paths
shape_ids = torch.tensor(list(range(prev_shapes_size, len(shapes))))
shape_groups.append(pydiffvg.ShapeGroup(\
shape_ids = shape_ids,
fill_color = new_fill_color,
stroke_color = stroke_color,
use_even_odd_rule = use_even_odd_rule,
id = name))
elif tag == 'polygon':
name = ''
if 'id' in node.attrib:
name = node.attrib['id']
force_closing = new_fill_color is not None
pts = node.attrib['points'].strip()
pts = pts.split(' ')
# import ipdb; ipdb.set_trace()
pts = [[float(y) for y in re.split(',| ', x)] for x in pts if x]
pts = torch.tensor(pts, dtype=torch.float32).view(-1, 2)
polygon = pydiffvg.Polygon(pts, force_closing)
polygon.stroke_width = stroke_width
shape_ids = torch.tensor([len(shapes)])
shapes.append(polygon)
shape_groups.append(pydiffvg.ShapeGroup(\
shape_ids = shape_ids,
fill_color = new_fill_color,
stroke_color = stroke_color,
use_even_odd_rule = use_even_odd_rule,
shape_to_canvas = new_transform,
id = name))
elif tag == 'line':
x1 = float(node.attrib['x1'])
y1 = float(node.attrib['y1'])
x2 = float(node.attrib['x2'])
y2 = float(node.attrib['y2'])
p1 = torch.tensor([x1, y1])
p2 = torch.tensor([x2, y2])
points = torch.stack((p1, p2))
line = pydiffvg.Polygon(points, False)
line.stroke_width = stroke_width
shape_ids = torch.tensor([len(shapes)])
shapes.append(line)
shape_groups.append(pydiffvg.ShapeGroup(\
shape_ids = shape_ids,
fill_color = new_fill_color,
stroke_color = stroke_color,
use_even_odd_rule = use_even_odd_rule,
shape_to_canvas = new_transform))
elif tag == 'circle':
radius = float(node.attrib['r'])
cx = float(node.attrib['cx'])
cy = float(node.attrib['cy'])
name = ''
if 'id' in node.attrib:
name = node.attrib['id']
center = torch.tensor([cx, cy])
circle = pydiffvg.Circle(radius = torch.tensor(radius),
center = center)
circle.stroke_width = stroke_width
shape_ids = torch.tensor([len(shapes)])
shapes.append(circle)
shape_groups.append(pydiffvg.ShapeGroup(\
shape_ids = shape_ids,
fill_color = new_fill_color,
stroke_color = stroke_color,
use_even_odd_rule = use_even_odd_rule,
shape_to_canvas = new_transform))
elif tag == 'ellipse':
rx = float(node.attrib['rx'])
ry = float(node.attrib['ry'])
cx = float(node.attrib['cx'])
cy = float(node.attrib['cy'])
name = ''
if 'id' in node.attrib:
name = node.attrib['id']
center = torch.tensor([cx, cy])
circle = pydiffvg.Circle(radius = torch.tensor(radius),
center = center)
circle.stroke_width = stroke_width
shape_ids = torch.tensor([len(shapes)])
shapes.append(circle)
shape_groups.append(pydiffvg.ShapeGroup(\
shape_ids = shape_ids,
fill_color = new_fill_color,
stroke_color = stroke_color,
use_even_odd_rule = use_even_odd_rule,
shape_to_canvas = new_transform))
elif tag == 'rect':
x = 0.0
y = 0.0
if x in node.attrib:
x = float(node.attrib['x'])
if y in node.attrib:
y = float(node.attrib['y'])
w = float(node.attrib['width'])
h = float(node.attrib['height'])
p_min = torch.tensor([x, y])
p_max = torch.tensor([x + w, x + h])
rect = pydiffvg.Rect(p_min = p_min, p_max = p_max)
rect.stroke_width = stroke_width
shape_ids = torch.tensor([len(shapes)])
shapes.append(rect)
shape_groups.append(pydiffvg.ShapeGroup(\
shape_ids = shape_ids,
fill_color = new_fill_color,
stroke_color = stroke_color,
use_even_odd_rule = use_even_odd_rule,
shape_to_canvas = new_transform))
return shapes, shape_groups
def parse_group(node, transform, fill_color, shapes, shape_groups, defs):
if 'transform' in node.attrib:
transform = transform @ parse_transform(node.attrib['transform'])
if 'fill' in node.attrib:
fill_color = parse_color(node.attrib['fill'], defs)
for child in node:
tag = remove_namespaces(child.tag)
if is_shape(tag):
shapes, shape_groups = parse_shape(\
child, transform, fill_color, shapes, shape_groups, defs)
elif tag == 'g':
shapes, shape_groups = parse_group(\
child, transform, fill_color, shapes, shape_groups, defs)
return shapes, shape_groups
def parse_scene(node):
canvas_width = -1
canvas_height = -1
defs = {}
shapes = []
shape_groups = []
fill_color = torch.tensor([0.0, 0.0, 0.0, 1.0])
transform = torch.eye(3)
if 'viewBox' in node.attrib:
view_box_array = node.attrib['viewBox'].split()
canvas_width = parse_int(view_box_array[2])
canvas_height = parse_int(view_box_array[3])
else:
if 'width' in node.attrib:
canvas_width = parse_int(node.attrib['width'])
else:
print('Warning: Can\'t find canvas width.')
if 'height' in node.attrib:
canvas_height = parse_int(node.attrib['height'])
else:
print('Warning: Can\'t find canvas height.')
for child in node:
tag = remove_namespaces(child.tag)
if tag == 'defs':
defs = parse_defs(child, transform, defs)
elif tag == 'style':
defs = parse_stylesheet(child, transform, defs)
elif tag == 'linearGradient':
if 'id' in child.attrib:
defs[child.attrib['id']] = parse_linear_gradient(child, transform, defs)
elif tag == 'radialGradient':
if 'id' in child.attrib:
defs[child.attrib['id']] = parse_radial_gradient(child, transform, defs)
elif is_shape(tag):
shapes, shape_groups = parse_shape(\
child, transform, fill_color, shapes, shape_groups, defs)
elif tag == 'g':
shapes, shape_groups = parse_group(\
child, transform, fill_color, shapes, shape_groups, defs)
return canvas_width, canvas_height, shapes, shape_groups
def svg_to_scene(filename):
"""
Load from a SVG file and convert to PyTorch tensors.
"""
tree = etree.parse(filename)
root = tree.getroot()
cwd = os.getcwd()
if (os.path.dirname(filename) != ''):
os.chdir(os.path.dirname(filename))
ret = parse_scene(root)
os.chdir(cwd)
return ret

9
pydiffvg/pixel_filter.py Normal file
View File

@@ -0,0 +1,9 @@
import torch
import pydiffvg
class PixelFilter:
def __init__(self,
type,
radius = torch.tensor(0.5)):
self.type = type
self.radius = radius

870
pydiffvg/render_pytorch.py Normal file
View File

@@ -0,0 +1,870 @@
import torch
import diffvg
import pydiffvg
import time
from enum import IntEnum
import warnings
print_timing = False
def set_print_timing(val):
global print_timing
print_timing=val
class OutputType(IntEnum):
color = 1
sdf = 2
class RenderFunction(torch.autograd.Function):
"""
The PyTorch interface of diffvg.
"""
@staticmethod
def serialize_scene(canvas_width,
canvas_height,
shapes,
shape_groups,
filter = pydiffvg.PixelFilter(type = diffvg.FilterType.box,
radius = torch.tensor(0.5)),
output_type = OutputType.color,
use_prefiltering = False,
eval_positions = torch.tensor([])):
"""
Given a list of shapes, convert them to a linear list of argument,
so that we can use it in PyTorch.
"""
num_shapes = len(shapes)
num_shape_groups = len(shape_groups)
args = []
args.append(canvas_width)
args.append(canvas_height)
args.append(num_shapes)
args.append(num_shape_groups)
args.append(output_type)
args.append(use_prefiltering)
args.append(eval_positions.to(pydiffvg.get_device()))
for shape in shapes:
use_thickness = False
if isinstance(shape, pydiffvg.Circle):
assert(shape.center.is_contiguous())
args.append(diffvg.ShapeType.circle)
args.append(shape.radius.cpu())
args.append(shape.center.cpu())
elif isinstance(shape, pydiffvg.Ellipse):
assert(shape.radius.is_contiguous())
assert(shape.center.is_contiguous())
args.append(diffvg.ShapeType.ellipse)
args.append(shape.radius.cpu())
args.append(shape.center.cpu())
elif isinstance(shape, pydiffvg.Path):
assert(shape.num_control_points.is_contiguous())
assert(shape.points.is_contiguous())
assert(shape.points.shape[1] == 2)
assert(torch.isfinite(shape.points).all())
args.append(diffvg.ShapeType.path)
args.append(shape.num_control_points.to(torch.int32).cpu())
args.append(shape.points.cpu())
if len(shape.stroke_width.shape) > 0 and shape.stroke_width.shape[0] > 1:
assert(torch.isfinite(shape.stroke_width).all())
use_thickness = True
args.append(shape.stroke_width.cpu())
else:
args.append(None)
args.append(shape.is_closed)
args.append(shape.use_distance_approx)
elif isinstance(shape, pydiffvg.Polygon):
assert(shape.points.is_contiguous())
assert(shape.points.shape[1] == 2)
args.append(diffvg.ShapeType.path)
if shape.is_closed:
args.append(torch.zeros(shape.points.shape[0], dtype = torch.int32))
else:
args.append(torch.zeros(shape.points.shape[0] - 1, dtype = torch.int32))
args.append(shape.points.cpu())
args.append(None)
args.append(shape.is_closed)
args.append(False) # use_distance_approx
elif isinstance(shape, pydiffvg.Rect):
assert(shape.p_min.is_contiguous())
assert(shape.p_max.is_contiguous())
args.append(diffvg.ShapeType.rect)
args.append(shape.p_min.cpu())
args.append(shape.p_max.cpu())
else:
assert(False)
if use_thickness:
args.append(torch.tensor(0.0))
else:
args.append(shape.stroke_width.cpu())
for shape_group in shape_groups:
assert(shape_group.shape_ids.is_contiguous())
args.append(shape_group.shape_ids.to(torch.int32).cpu())
# Fill color
if shape_group.fill_color is None:
args.append(None)
elif isinstance(shape_group.fill_color, torch.Tensor):
assert(shape_group.fill_color.is_contiguous())
args.append(diffvg.ColorType.constant)
args.append(shape_group.fill_color.cpu())
elif isinstance(shape_group.fill_color, pydiffvg.LinearGradient):
assert(shape_group.fill_color.begin.is_contiguous())
assert(shape_group.fill_color.end.is_contiguous())
assert(shape_group.fill_color.offsets.is_contiguous())
assert(shape_group.fill_color.stop_colors.is_contiguous())
args.append(diffvg.ColorType.linear_gradient)
args.append(shape_group.fill_color.begin.cpu())
args.append(shape_group.fill_color.end.cpu())
args.append(shape_group.fill_color.offsets.cpu())
args.append(shape_group.fill_color.stop_colors.cpu())
elif isinstance(shape_group.fill_color, pydiffvg.RadialGradient):
assert(shape_group.fill_color.center.is_contiguous())
assert(shape_group.fill_color.radius.is_contiguous())
assert(shape_group.fill_color.offsets.is_contiguous())
assert(shape_group.fill_color.stop_colors.is_contiguous())
args.append(diffvg.ColorType.radial_gradient)
args.append(shape_group.fill_color.center.cpu())
args.append(shape_group.fill_color.radius.cpu())
args.append(shape_group.fill_color.offsets.cpu())
args.append(shape_group.fill_color.stop_colors.cpu())
if shape_group.fill_color is not None:
# go through the underlying shapes and check if they are all closed
for shape_id in shape_group.shape_ids:
if isinstance(shapes[shape_id], pydiffvg.Path):
if not shapes[shape_id].is_closed:
warnings.warn("Detected non-closed paths with fill color. This might causes unexpected results.", Warning)
# Stroke color
if shape_group.stroke_color is None:
args.append(None)
elif isinstance(shape_group.stroke_color, torch.Tensor):
assert(shape_group.stroke_color.is_contiguous())
args.append(diffvg.ColorType.constant)
args.append(shape_group.stroke_color.cpu())
elif isinstance(shape_group.stroke_color, pydiffvg.LinearGradient):
assert(shape_group.stroke_color.begin.is_contiguous())
assert(shape_group.stroke_color.end.is_contiguous())
assert(shape_group.stroke_color.offsets.is_contiguous())
assert(shape_group.stroke_color.stop_colors.is_contiguous())
assert(torch.isfinite(shape_group.stroke_color.stop_colors).all())
args.append(diffvg.ColorType.linear_gradient)
args.append(shape_group.stroke_color.begin.cpu())
args.append(shape_group.stroke_color.end.cpu())
args.append(shape_group.stroke_color.offsets.cpu())
args.append(shape_group.stroke_color.stop_colors.cpu())
elif isinstance(shape_group.stroke_color, pydiffvg.RadialGradient):
assert(shape_group.stroke_color.center.is_contiguous())
assert(shape_group.stroke_color.radius.is_contiguous())
assert(shape_group.stroke_color.offsets.is_contiguous())
assert(shape_group.stroke_color.stop_colors.is_contiguous())
assert(torch.isfinite(shape_group.stroke_color.stop_colors).all())
args.append(diffvg.ColorType.radial_gradient)
args.append(shape_group.stroke_color.center.cpu())
args.append(shape_group.stroke_color.radius.cpu())
args.append(shape_group.stroke_color.offsets.cpu())
args.append(shape_group.stroke_color.stop_colors.cpu())
args.append(shape_group.use_even_odd_rule)
# Transformation
args.append(shape_group.shape_to_canvas.contiguous().cpu())
args.append(filter.type)
args.append(filter.radius.cpu())
return args
@staticmethod
def forward(ctx,
width,
height,
num_samples_x,
num_samples_y,
seed,
background_image,
*args):
"""
Forward rendering pass.
"""
# Unpack arguments
current_index = 0
canvas_width = args[current_index]
current_index += 1
canvas_height = args[current_index]
current_index += 1
num_shapes = args[current_index]
current_index += 1
num_shape_groups = args[current_index]
current_index += 1
output_type = args[current_index]
current_index += 1
use_prefiltering = args[current_index]
current_index += 1
eval_positions = args[current_index]
current_index += 1
shapes = []
shape_groups = []
shape_contents = [] # Important to avoid GC deleting the shapes
color_contents = [] # Same as above
for shape_id in range(num_shapes):
shape_type = args[current_index]
current_index += 1
if shape_type == diffvg.ShapeType.circle:
radius = args[current_index]
current_index += 1
center = args[current_index]
current_index += 1
shape = diffvg.Circle(radius, diffvg.Vector2f(center[0], center[1]))
elif shape_type == diffvg.ShapeType.ellipse:
radius = args[current_index]
current_index += 1
center = args[current_index]
current_index += 1
shape = diffvg.Ellipse(diffvg.Vector2f(radius[0], radius[1]),
diffvg.Vector2f(center[0], center[1]))
elif shape_type == diffvg.ShapeType.path:
num_control_points = args[current_index]
current_index += 1
points = args[current_index]
current_index += 1
thickness = args[current_index]
current_index += 1
is_closed = args[current_index]
current_index += 1
use_distance_approx = args[current_index]
current_index += 1
shape = diffvg.Path(diffvg.int_ptr(num_control_points.data_ptr()),
diffvg.float_ptr(points.data_ptr()),
diffvg.float_ptr(thickness.data_ptr() if thickness is not None else 0),
num_control_points.shape[0],
points.shape[0],
is_closed,
use_distance_approx)
elif shape_type == diffvg.ShapeType.rect:
p_min = args[current_index]
current_index += 1
p_max = args[current_index]
current_index += 1
shape = diffvg.Rect(diffvg.Vector2f(p_min[0], p_min[1]),
diffvg.Vector2f(p_max[0], p_max[1]))
else:
assert(False)
stroke_width = args[current_index]
current_index += 1
shapes.append(diffvg.Shape(\
shape_type, shape.get_ptr(), stroke_width.item()))
shape_contents.append(shape)
for shape_group_id in range(num_shape_groups):
shape_ids = args[current_index]
current_index += 1
fill_color_type = args[current_index]
current_index += 1
if fill_color_type == diffvg.ColorType.constant:
color = args[current_index]
current_index += 1
fill_color = diffvg.Constant(\
diffvg.Vector4f(color[0], color[1], color[2], color[3]))
elif fill_color_type == diffvg.ColorType.linear_gradient:
beg = args[current_index]
current_index += 1
end = args[current_index]
current_index += 1
offsets = args[current_index]
current_index += 1
stop_colors = args[current_index]
current_index += 1
assert(offsets.shape[0] == stop_colors.shape[0])
fill_color = diffvg.LinearGradient(diffvg.Vector2f(beg[0], beg[1]),
diffvg.Vector2f(end[0], end[1]),
offsets.shape[0],
diffvg.float_ptr(offsets.data_ptr()),
diffvg.float_ptr(stop_colors.data_ptr()))
elif fill_color_type == diffvg.ColorType.radial_gradient:
center = args[current_index]
current_index += 1
radius = args[current_index]
current_index += 1
offsets = args[current_index]
current_index += 1
stop_colors = args[current_index]
current_index += 1
assert(offsets.shape[0] == stop_colors.shape[0])
fill_color = diffvg.RadialGradient(diffvg.Vector2f(center[0], center[1]),
diffvg.Vector2f(radius[0], radius[1]),
offsets.shape[0],
diffvg.float_ptr(offsets.data_ptr()),
diffvg.float_ptr(stop_colors.data_ptr()))
elif fill_color_type is None:
fill_color = None
else:
assert(False)
stroke_color_type = args[current_index]
current_index += 1
if stroke_color_type == diffvg.ColorType.constant:
color = args[current_index]
current_index += 1
stroke_color = diffvg.Constant(\
diffvg.Vector4f(color[0], color[1], color[2], color[3]))
elif stroke_color_type == diffvg.ColorType.linear_gradient:
beg = args[current_index]
current_index += 1
end = args[current_index]
current_index += 1
offsets = args[current_index]
current_index += 1
stop_colors = args[current_index]
current_index += 1
assert(offsets.shape[0] == stop_colors.shape[0])
stroke_color = diffvg.LinearGradient(diffvg.Vector2f(beg[0], beg[1]),
diffvg.Vector2f(end[0], end[1]),
offsets.shape[0],
diffvg.float_ptr(offsets.data_ptr()),
diffvg.float_ptr(stop_colors.data_ptr()))
elif stroke_color_type == diffvg.ColorType.radial_gradient:
center = args[current_index]
current_index += 1
radius = args[current_index]
current_index += 1
offsets = args[current_index]
current_index += 1
stop_colors = args[current_index]
current_index += 1
assert(offsets.shape[0] == stop_colors.shape[0])
stroke_color = diffvg.RadialGradient(diffvg.Vector2f(center[0], center[1]),
diffvg.Vector2f(radius[0], radius[1]),
offsets.shape[0],
diffvg.float_ptr(offsets.data_ptr()),
diffvg.float_ptr(stop_colors.data_ptr()))
elif stroke_color_type is None:
stroke_color = None
else:
assert(False)
use_even_odd_rule = args[current_index]
current_index += 1
shape_to_canvas = args[current_index]
current_index += 1
if fill_color is not None:
color_contents.append(fill_color)
if stroke_color is not None:
color_contents.append(stroke_color)
shape_groups.append(diffvg.ShapeGroup(\
diffvg.int_ptr(shape_ids.data_ptr()),
shape_ids.shape[0],
diffvg.ColorType.constant if fill_color_type is None else fill_color_type,
diffvg.void_ptr(0) if fill_color is None else fill_color.get_ptr(),
diffvg.ColorType.constant if stroke_color_type is None else stroke_color_type,
diffvg.void_ptr(0) if stroke_color is None else stroke_color.get_ptr(),
use_even_odd_rule,
diffvg.float_ptr(shape_to_canvas.data_ptr())))
filter_type = args[current_index]
current_index += 1
filter_radius = args[current_index]
current_index += 1
filt = diffvg.Filter(filter_type, filter_radius)
start = time.time()
scene = diffvg.Scene(canvas_width, canvas_height,
shapes, shape_groups, filt, pydiffvg.get_use_gpu(),
pydiffvg.get_device().index if pydiffvg.get_device().index is not None else -1)
time_elapsed = time.time() - start
global print_timing
if print_timing:
print('Scene construction, time: %.5f s' % time_elapsed)
if output_type == OutputType.color:
assert(eval_positions.shape[0] == 0)
rendered_image = torch.zeros(height, width, 4, device = pydiffvg.get_device())
else:
assert(output_type == OutputType.sdf)
if eval_positions.shape[0] == 0:
rendered_image = torch.zeros(height, width, 1, device = pydiffvg.get_device())
else:
rendered_image = torch.zeros(eval_positions.shape[0], 1, device = pydiffvg.get_device())
if background_image is not None:
background_image = background_image.to(pydiffvg.get_device())
if background_image.shape[2] == 3:
background_image = torch.cat((\
background_image, torch.ones(background_image.shape[0], background_image.shape[1], 1,
device = background_image.device)), dim = 2)
background_image = background_image.contiguous()
assert(background_image.shape[0] == rendered_image.shape[0])
assert(background_image.shape[1] == rendered_image.shape[1])
assert(background_image.shape[2] == 4)
start = time.time()
diffvg.render(scene,
diffvg.float_ptr(background_image.data_ptr() if background_image is not None else 0),
diffvg.float_ptr(rendered_image.data_ptr() if output_type == OutputType.color else 0),
diffvg.float_ptr(rendered_image.data_ptr() if output_type == OutputType.sdf else 0),
width,
height,
num_samples_x,
num_samples_y,
seed,
diffvg.float_ptr(0), # d_background_image
diffvg.float_ptr(0), # d_render_image
diffvg.float_ptr(0), # d_render_sdf
diffvg.float_ptr(0), # d_translation
use_prefiltering,
diffvg.float_ptr(eval_positions.data_ptr()),
eval_positions.shape[0])
assert(torch.isfinite(rendered_image).all())
time_elapsed = time.time() - start
if print_timing:
print('Forward pass, time: %.5f s' % time_elapsed)
ctx.scene = scene
ctx.background_image = background_image
ctx.shape_contents = shape_contents
ctx.color_contents = color_contents
ctx.filter = filt
ctx.width = width
ctx.height = height
ctx.num_samples_x = num_samples_x
ctx.num_samples_y = num_samples_y
ctx.seed = seed
ctx.output_type = output_type
ctx.use_prefiltering = use_prefiltering
ctx.eval_positions = eval_positions
return rendered_image
@staticmethod
def render_grad(grad_img,
width,
height,
num_samples_x,
num_samples_y,
seed,
background_image,
*args):
if not grad_img.is_contiguous():
grad_img = grad_img.contiguous()
assert(torch.isfinite(grad_img).all())
# Unpack arguments
current_index = 0
canvas_width = args[current_index]
current_index += 1
canvas_height = args[current_index]
current_index += 1
num_shapes = args[current_index]
current_index += 1
num_shape_groups = args[current_index]
current_index += 1
output_type = args[current_index]
current_index += 1
use_prefiltering = args[current_index]
current_index += 1
eval_positions = args[current_index]
current_index += 1
shapes = []
shape_groups = []
shape_contents = [] # Important to avoid GC deleting the shapes
color_contents = [] # Same as above
for shape_id in range(num_shapes):
shape_type = args[current_index]
current_index += 1
if shape_type == diffvg.ShapeType.circle:
radius = args[current_index]
current_index += 1
center = args[current_index]
current_index += 1
shape = diffvg.Circle(radius, diffvg.Vector2f(center[0], center[1]))
elif shape_type == diffvg.ShapeType.ellipse:
radius = args[current_index]
current_index += 1
center = args[current_index]
current_index += 1
shape = diffvg.Ellipse(diffvg.Vector2f(radius[0], radius[1]),
diffvg.Vector2f(center[0], center[1]))
elif shape_type == diffvg.ShapeType.path:
num_control_points = args[current_index]
current_index += 1
points = args[current_index]
current_index += 1
thickness = args[current_index]
current_index += 1
is_closed = args[current_index]
current_index += 1
use_distance_approx = args[current_index]
current_index += 1
shape = diffvg.Path(diffvg.int_ptr(num_control_points.data_ptr()),
diffvg.float_ptr(points.data_ptr()),
diffvg.float_ptr(thickness.data_ptr() if thickness is not None else 0),
num_control_points.shape[0],
points.shape[0],
is_closed,
use_distance_approx)
elif shape_type == diffvg.ShapeType.rect:
p_min = args[current_index]
current_index += 1
p_max = args[current_index]
current_index += 1
shape = diffvg.Rect(diffvg.Vector2f(p_min[0], p_min[1]),
diffvg.Vector2f(p_max[0], p_max[1]))
else:
assert(False)
stroke_width = args[current_index]
current_index += 1
shapes.append(diffvg.Shape(\
shape_type, shape.get_ptr(), stroke_width.item()))
shape_contents.append(shape)
for shape_group_id in range(num_shape_groups):
shape_ids = args[current_index]
current_index += 1
fill_color_type = args[current_index]
current_index += 1
if fill_color_type == diffvg.ColorType.constant:
color = args[current_index]
current_index += 1
fill_color = diffvg.Constant(\
diffvg.Vector4f(color[0], color[1], color[2], color[3]))
elif fill_color_type == diffvg.ColorType.linear_gradient:
beg = args[current_index]
current_index += 1
end = args[current_index]
current_index += 1
offsets = args[current_index]
current_index += 1
stop_colors = args[current_index]
current_index += 1
assert(offsets.shape[0] == stop_colors.shape[0])
fill_color = diffvg.LinearGradient(diffvg.Vector2f(beg[0], beg[1]),
diffvg.Vector2f(end[0], end[1]),
offsets.shape[0],
diffvg.float_ptr(offsets.data_ptr()),
diffvg.float_ptr(stop_colors.data_ptr()))
elif fill_color_type == diffvg.ColorType.radial_gradient:
center = args[current_index]
current_index += 1
radius = args[current_index]
current_index += 1
offsets = args[current_index]
current_index += 1
stop_colors = args[current_index]
current_index += 1
assert(offsets.shape[0] == stop_colors.shape[0])
fill_color = diffvg.RadialGradient(diffvg.Vector2f(center[0], center[1]),
diffvg.Vector2f(radius[0], radius[1]),
offsets.shape[0],
diffvg.float_ptr(offsets.data_ptr()),
diffvg.float_ptr(stop_colors.data_ptr()))
elif fill_color_type is None:
fill_color = None
else:
assert(False)
stroke_color_type = args[current_index]
current_index += 1
if stroke_color_type == diffvg.ColorType.constant:
color = args[current_index]
current_index += 1
stroke_color = diffvg.Constant(\
diffvg.Vector4f(color[0], color[1], color[2], color[3]))
elif stroke_color_type == diffvg.ColorType.linear_gradient:
beg = args[current_index]
current_index += 1
end = args[current_index]
current_index += 1
offsets = args[current_index]
current_index += 1
stop_colors = args[current_index]
current_index += 1
assert(offsets.shape[0] == stop_colors.shape[0])
stroke_color = diffvg.LinearGradient(diffvg.Vector2f(beg[0], beg[1]),
diffvg.Vector2f(end[0], end[1]),
offsets.shape[0],
diffvg.float_ptr(offsets.data_ptr()),
diffvg.float_ptr(stop_colors.data_ptr()))
elif stroke_color_type == diffvg.ColorType.radial_gradient:
center = args[current_index]
current_index += 1
radius = args[current_index]
current_index += 1
offsets = args[current_index]
current_index += 1
stop_colors = args[current_index]
current_index += 1
assert(offsets.shape[0] == stop_colors.shape[0])
stroke_color = diffvg.RadialGradient(diffvg.Vector2f(center[0], center[1]),
diffvg.Vector2f(radius[0], radius[1]),
offsets.shape[0],
diffvg.float_ptr(offsets.data_ptr()),
diffvg.float_ptr(stop_colors.data_ptr()))
elif stroke_color_type is None:
stroke_color = None
else:
assert(False)
use_even_odd_rule = args[current_index]
current_index += 1
shape_to_canvas = args[current_index]
current_index += 1
if fill_color is not None:
color_contents.append(fill_color)
if stroke_color is not None:
color_contents.append(stroke_color)
shape_groups.append(diffvg.ShapeGroup(\
diffvg.int_ptr(shape_ids.data_ptr()),
shape_ids.shape[0],
diffvg.ColorType.constant if fill_color_type is None else fill_color_type,
diffvg.void_ptr(0) if fill_color is None else fill_color.get_ptr(),
diffvg.ColorType.constant if stroke_color_type is None else stroke_color_type,
diffvg.void_ptr(0) if stroke_color is None else stroke_color.get_ptr(),
use_even_odd_rule,
diffvg.float_ptr(shape_to_canvas.data_ptr())))
filter_type = args[current_index]
current_index += 1
filter_radius = args[current_index]
current_index += 1
filt = diffvg.Filter(filter_type, filter_radius)
scene = diffvg.Scene(canvas_width, canvas_height,
shapes, shape_groups, filt, pydiffvg.get_use_gpu(),
pydiffvg.get_device().index if pydiffvg.get_device().index is not None else -1)
if output_type == OutputType.color:
assert(grad_img.shape[2] == 4)
else:
assert(grad_img.shape[2] == 1)
if background_image is not None:
background_image = background_image.to(pydiffvg.get_device())
if background_image.shape[2] == 3:
background_image = torch.cat((\
background_image, torch.ones(background_image.shape[0], background_image.shape[1], 1,
device = background_image.device)), dim = 2)
background_image = background_image.contiguous()
assert(background_image.shape[0] == rendered_image.shape[0])
assert(background_image.shape[1] == rendered_image.shape[1])
assert(background_image.shape[2] == 4)
translation_grad_image = \
torch.zeros(height, width, 2, device = pydiffvg.get_device())
start = time.time()
diffvg.render(scene,
diffvg.float_ptr(background_image.data_ptr() if background_image is not None else 0),
diffvg.float_ptr(0), # render_image
diffvg.float_ptr(0), # render_sdf
width,
height,
num_samples_x,
num_samples_y,
seed,
diffvg.float_ptr(0), # d_background_image
diffvg.float_ptr(grad_img.data_ptr() if output_type == OutputType.color else 0),
diffvg.float_ptr(grad_img.data_ptr() if output_type == OutputType.sdf else 0),
diffvg.float_ptr(translation_grad_image.data_ptr()),
use_prefiltering,
diffvg.float_ptr(eval_positions.data_ptr()),
eval_positions.shape[0])
time_elapsed = time.time() - start
if print_timing:
print('Gradient pass, time: %.5f s' % time_elapsed)
assert(torch.isfinite(translation_grad_image).all())
return translation_grad_image
@staticmethod
def backward(ctx,
grad_img):
if not grad_img.is_contiguous():
grad_img = grad_img.contiguous()
assert(torch.isfinite(grad_img).all())
scene = ctx.scene
width = ctx.width
height = ctx.height
num_samples_x = ctx.num_samples_x
num_samples_y = ctx.num_samples_y
seed = ctx.seed
output_type = ctx.output_type
use_prefiltering = ctx.use_prefiltering
eval_positions = ctx.eval_positions
background_image = ctx.background_image
if background_image is not None:
d_background_image = torch.zeros_like(background_image)
else:
d_background_image = None
start = time.time()
diffvg.render(scene,
diffvg.float_ptr(background_image.data_ptr() if background_image is not None else 0),
diffvg.float_ptr(0), # render_image
diffvg.float_ptr(0), # render_sdf
width,
height,
num_samples_x,
num_samples_y,
seed,
diffvg.float_ptr(d_background_image.data_ptr() if background_image is not None else 0),
diffvg.float_ptr(grad_img.data_ptr() if output_type == OutputType.color else 0),
diffvg.float_ptr(grad_img.data_ptr() if output_type == OutputType.sdf else 0),
diffvg.float_ptr(0), # d_translation
use_prefiltering,
diffvg.float_ptr(eval_positions.data_ptr()),
eval_positions.shape[0])
time_elapsed = time.time() - start
global print_timing
if print_timing:
print('Backward pass, time: %.5f s' % time_elapsed)
d_args = []
d_args.append(None) # width
d_args.append(None) # height
d_args.append(None) # num_samples_x
d_args.append(None) # num_samples_y
d_args.append(None) # seed
d_args.append(d_background_image)
d_args.append(None) # canvas_width
d_args.append(None) # canvas_height
d_args.append(None) # num_shapes
d_args.append(None) # num_shape_groups
d_args.append(None) # output_type
d_args.append(None) # use_prefiltering
d_args.append(None) # eval_positions
for shape_id in range(scene.num_shapes):
d_args.append(None) # type
d_shape = scene.get_d_shape(shape_id)
use_thickness = False
if d_shape.type == diffvg.ShapeType.circle:
d_circle = d_shape.as_circle()
radius = torch.tensor(d_circle.radius)
assert(torch.isfinite(radius).all())
d_args.append(radius)
c = d_circle.center
c = torch.tensor((c.x, c.y))
assert(torch.isfinite(c).all())
d_args.append(c)
elif d_shape.type == diffvg.ShapeType.ellipse:
d_ellipse = d_shape.as_ellipse()
r = d_ellipse.radius
r = torch.tensor((d_ellipse.radius.x, d_ellipse.radius.y))
assert(torch.isfinite(r).all())
d_args.append(r)
c = d_ellipse.center
c = torch.tensor((c.x, c.y))
assert(torch.isfinite(c).all())
d_args.append(c)
elif d_shape.type == diffvg.ShapeType.path:
d_path = d_shape.as_path()
points = torch.zeros((d_path.num_points, 2))
thickness = None
if d_path.has_thickness():
use_thickness = True
thickness = torch.zeros(d_path.num_points)
d_path.copy_to(diffvg.float_ptr(points.data_ptr()), diffvg.float_ptr(thickness.data_ptr()))
else:
d_path.copy_to(diffvg.float_ptr(points.data_ptr()), diffvg.float_ptr(0))
assert(torch.isfinite(points).all())
if thickness is not None:
assert(torch.isfinite(thickness).all())
d_args.append(None) # num_control_points
d_args.append(points)
d_args.append(thickness)
d_args.append(None) # is_closed
d_args.append(None) # use_distance_approx
elif d_shape.type == diffvg.ShapeType.rect:
d_rect = d_shape.as_rect()
p_min = torch.tensor((d_rect.p_min.x, d_rect.p_min.y))
p_max = torch.tensor((d_rect.p_max.x, d_rect.p_max.y))
assert(torch.isfinite(p_min).all())
assert(torch.isfinite(p_max).all())
d_args.append(p_min)
d_args.append(p_max)
else:
assert(False)
if use_thickness:
d_args.append(None)
else:
w = torch.tensor((d_shape.stroke_width))
assert(torch.isfinite(w).all())
d_args.append(w)
for group_id in range(scene.num_shape_groups):
d_shape_group = scene.get_d_shape_group(group_id)
d_args.append(None) # shape_ids
d_args.append(None) # fill_color_type
if d_shape_group.has_fill_color():
if d_shape_group.fill_color_type == diffvg.ColorType.constant:
d_constant = d_shape_group.fill_color_as_constant()
c = d_constant.color
d_args.append(torch.tensor((c.x, c.y, c.z, c.w)))
elif d_shape_group.fill_color_type == diffvg.ColorType.linear_gradient:
d_linear_gradient = d_shape_group.fill_color_as_linear_gradient()
beg = d_linear_gradient.begin
d_args.append(torch.tensor((beg.x, beg.y)))
end = d_linear_gradient.end
d_args.append(torch.tensor((end.x, end.y)))
offsets = torch.zeros((d_linear_gradient.num_stops))
stop_colors = torch.zeros((d_linear_gradient.num_stops, 4))
d_linear_gradient.copy_to(\
diffvg.float_ptr(offsets.data_ptr()),
diffvg.float_ptr(stop_colors.data_ptr()))
assert(torch.isfinite(stop_colors).all())
d_args.append(offsets)
d_args.append(stop_colors)
elif d_shape_group.fill_color_type == diffvg.ColorType.radial_gradient:
d_radial_gradient = d_shape_group.fill_color_as_radial_gradient()
center = d_radial_gradient.center
d_args.append(torch.tensor((center.x, center.y)))
radius = d_radial_gradient.radius
d_args.append(torch.tensor((radius.x, radius.y)))
offsets = torch.zeros((d_radial_gradient.num_stops))
stop_colors = torch.zeros((d_radial_gradient.num_stops, 4))
d_radial_gradient.copy_to(\
diffvg.float_ptr(offsets.data_ptr()),
diffvg.float_ptr(stop_colors.data_ptr()))
assert(torch.isfinite(stop_colors).all())
d_args.append(offsets)
d_args.append(stop_colors)
else:
assert(False)
d_args.append(None) # stroke_color_type
if d_shape_group.has_stroke_color():
if d_shape_group.stroke_color_type == diffvg.ColorType.constant:
d_constant = d_shape_group.stroke_color_as_constant()
c = d_constant.color
d_args.append(torch.tensor((c.x, c.y, c.z, c.w)))
elif d_shape_group.stroke_color_type == diffvg.ColorType.linear_gradient:
d_linear_gradient = d_shape_group.stroke_color_as_linear_gradient()
beg = d_linear_gradient.begin
d_args.append(torch.tensor((beg.x, beg.y)))
end = d_linear_gradient.end
d_args.append(torch.tensor((end.x, end.y)))
offsets = torch.zeros((d_linear_gradient.num_stops))
stop_colors = torch.zeros((d_linear_gradient.num_stops, 4))
d_linear_gradient.copy_to(\
diffvg.float_ptr(offsets.data_ptr()),
diffvg.float_ptr(stop_colors.data_ptr()))
assert(torch.isfinite(stop_colors).all())
d_args.append(offsets)
d_args.append(stop_colors)
elif d_shape_group.fill_color_type == diffvg.ColorType.radial_gradient:
d_radial_gradient = d_shape_group.stroke_color_as_radial_gradient()
center = d_radial_gradient.center
d_args.append(torch.tensor((center.x, center.y)))
radius = d_radial_gradient.radius
d_args.append(torch.tensor((radius.x, radius.y)))
offsets = torch.zeros((d_radial_gradient.num_stops))
stop_colors = torch.zeros((d_radial_gradient.num_stops, 4))
d_radial_gradient.copy_to(\
diffvg.float_ptr(offsets.data_ptr()),
diffvg.float_ptr(stop_colors.data_ptr()))
assert(torch.isfinite(stop_colors).all())
d_args.append(offsets)
d_args.append(stop_colors)
else:
assert(False)
d_args.append(None) # use_even_odd_rule
d_shape_to_canvas = torch.zeros((3, 3))
d_shape_group.copy_to(diffvg.float_ptr(d_shape_to_canvas.data_ptr()))
assert(torch.isfinite(d_shape_to_canvas).all())
d_args.append(d_shape_to_canvas)
d_args.append(None) # filter_type
d_args.append(torch.tensor(scene.get_d_filter_radius()))
return tuple(d_args)

150
pydiffvg/save_svg.py Normal file
View File

@@ -0,0 +1,150 @@
import torch
import pydiffvg
import xml.etree.ElementTree as etree
from xml.dom import minidom
def prettify(elem):
"""Return a pretty-printed XML string for the Element.
"""
rough_string = etree.tostring(elem, 'utf-8')
reparsed = minidom.parseString(rough_string)
return reparsed.toprettyxml(indent=" ")
def save_svg(filename, width, height, shapes, shape_groups, use_gamma = False):
root = etree.Element('svg')
root.set('version', '1.1')
root.set('xmlns', 'http://www.w3.org/2000/svg')
root.set('width', str(width))
root.set('height', str(height))
defs = etree.SubElement(root, 'defs')
g = etree.SubElement(root, 'g')
if use_gamma:
f = etree.SubElement(defs, 'filter')
f.set('id', 'gamma')
f.set('x', '0')
f.set('y', '0')
f.set('width', '100%')
f.set('height', '100%')
gamma = etree.SubElement(f, 'feComponentTransfer')
gamma.set('color-interpolation-filters', 'sRGB')
feFuncR = etree.SubElement(gamma, 'feFuncR')
feFuncR.set('type', 'gamma')
feFuncR.set('amplitude', str(1))
feFuncR.set('exponent', str(1/2.2))
feFuncG = etree.SubElement(gamma, 'feFuncG')
feFuncG.set('type', 'gamma')
feFuncG.set('amplitude', str(1))
feFuncG.set('exponent', str(1/2.2))
feFuncB = etree.SubElement(gamma, 'feFuncB')
feFuncB.set('type', 'gamma')
feFuncB.set('amplitude', str(1))
feFuncB.set('exponent', str(1/2.2))
feFuncA = etree.SubElement(gamma, 'feFuncA')
feFuncA.set('type', 'gamma')
feFuncA.set('amplitude', str(1))
feFuncA.set('exponent', str(1/2.2))
g.set('style', 'filter:url(#gamma)')
# Store color
for i, shape_group in enumerate(shape_groups):
def add_color(shape_color, name):
if isinstance(shape_color, pydiffvg.LinearGradient):
lg = shape_color
color = etree.SubElement(defs, 'linearGradient')
color.set('id', name)
color.set('x1', str(lg.begin[0].item()))
color.set('y1', str(lg.begin[1].item()))
color.set('x2', str(lg.end[0].item()))
color.set('y2', str(lg.end[1].item()))
offsets = lg.offsets.data.cpu().numpy()
stop_colors = lg.stop_colors.data.cpu().numpy()
for j in range(offsets.shape[0]):
stop = etree.SubElement(color, 'stop')
stop.set('offset', offsets[j])
c = lg.stop_colors[j, :]
stop.set('stop-color', 'rgb({}, {}, {})'.format(\
int(255 * c[0]), int(255 * c[1]), int(255 * c[2])))
stop.set('stop-opacity', '{}'.format(c[3]))
if shape_group.fill_color is not None:
add_color(shape_group.fill_color, 'shape_{}_fill'.format(i))
if shape_group.stroke_color is not None:
add_color(shape_group.stroke_color, 'shape_{}_stroke'.format(i))
for i, shape_group in enumerate(shape_groups):
shape = shapes[shape_group.shape_ids[0]]
if isinstance(shape, pydiffvg.Circle):
shape_node = etree.SubElement(g, 'circle')
shape_node.set('r', shape.radius.item())
shape_node.set('cx', shape.center[0].item())
shape_node.set('cy', shape.center[1].item())
elif isinstance(shape, pydiffvg.Polygon):
shape_node = etree.SubElement(g, 'polygon')
points = shape.points.data.cpu().numpy()
path_str = ''
for j in range(0, shape.points.shape[0]):
path_str += '{} {}'.format(points[j, 0], points[j, 1])
if j != shape.points.shape[0] - 1:
path_str += ' '
shape_node.set('points', path_str)
elif isinstance(shape, pydiffvg.Path):
shape_node = etree.SubElement(g, 'path')
num_segments = shape.num_control_points.shape[0]
num_control_points = shape.num_control_points.data.cpu().numpy()
points = shape.points.data.cpu().numpy()
num_points = shape.points.shape[0]
path_str = 'M {} {}'.format(points[0, 0], points[0, 1])
point_id = 1
for j in range(0, num_segments):
if num_control_points[j] == 0:
p = point_id % num_points
path_str += ' L {} {}'.format(\
points[p, 0], points[p, 1])
point_id += 1
elif num_control_points[j] == 1:
p1 = (point_id + 1) % num_points
path_str += ' Q {} {} {} {}'.format(\
points[point_id, 0], points[point_id, 1],
points[p1, 0], points[p1, 1])
point_id += 2
elif num_control_points[j] == 2:
p2 = (point_id + 2) % num_points
path_str += ' C {} {} {} {} {} {}'.format(\
points[point_id, 0], points[point_id, 1],
points[point_id + 1, 0], points[point_id + 1, 1],
points[p2, 0], points[p2, 1])
point_id += 3
shape_node.set('d', path_str)
elif isinstance(shape, pydiffvg.Rect):
shape_node = etree.SubElement(g, 'rect')
shape_node.set('x', shape.p_min[0].item())
shape_node.set('y', shape.p_min[1].item())
shape_node.set('width', shape.p_max[0].item() - shape.p_min[0].item())
shape_node.set('height', shape.p_max[1].item() - shape.p_min[1].item())
else:
assert(False)
shape_node.set('stroke-width', str(2 * shape.stroke_width.data.cpu().item()))
if shape_group.fill_color is not None:
if isinstance(shape_group.fill_color, pydiffvg.LinearGradient):
shape_node.set('fill', 'url(#shape_{}_fill)'.format(i))
else:
c = shape_group.fill_color.data.cpu().numpy()
shape_node.set('fill', 'rgb({}, {}, {})'.format(\
int(255 * c[0]), int(255 * c[1]), int(255 * c[2])))
shape_node.set('opacity', str(c[3]))
else:
shape_node.set('fill', 'none')
if shape_group.stroke_color is not None:
if isinstance(shape_group.stroke_color, pydiffvg.LinearGradient):
shape_node.set('stroke', 'url(#shape_{}_stroke)'.format(i))
else:
c = shape_group.stroke_color.data.cpu().numpy()
shape_node.set('stroke', 'rgb({}, {}, {})'.format(\
int(255 * c[0]), int(255 * c[1]), int(255 * c[2])))
shape_node.set('stroke-opacity', str(c[3]))
shape_node.set('stroke-linecap', 'round')
shape_node.set('stroke-linejoin', 'round')
with open(filename, "w") as f:
f.write(prettify(root))

172
pydiffvg/shape.py Normal file
View File

@@ -0,0 +1,172 @@
import torch
import svgpathtools
import math
class Circle:
def __init__(self, radius, center, stroke_width = torch.tensor(1.0), id = ''):
self.radius = radius
self.center = center
self.stroke_width = stroke_width
self.id = id
class Ellipse:
def __init__(self, radius, center, stroke_width = torch.tensor(1.0), id = ''):
self.radius = radius
self.center = center
self.stroke_width = stroke_width
self.id = id
class Path:
def __init__(self,
num_control_points,
points,
is_closed,
stroke_width = torch.tensor(1.0),
id = '',
use_distance_approx = False):
self.num_control_points = num_control_points
self.points = points
self.is_closed = is_closed
self.stroke_width = stroke_width
self.id = id
self.use_distance_approx = use_distance_approx
class Polygon:
def __init__(self, points, is_closed, stroke_width = torch.tensor(1.0), id = ''):
self.points = points
self.is_closed = is_closed
self.stroke_width = stroke_width
self.id = id
class Rect:
def __init__(self, p_min, p_max, stroke_width = torch.tensor(1.0), id = ''):
self.p_min = p_min
self.p_max = p_max
self.stroke_width = stroke_width
self.id = id
class ShapeGroup:
def __init__(self,
shape_ids,
fill_color,
use_even_odd_rule = True,
stroke_color = None,
shape_to_canvas = torch.eye(3),
id = ''):
self.shape_ids = shape_ids
self.fill_color = fill_color
self.use_even_odd_rule = use_even_odd_rule
self.stroke_color = stroke_color
self.shape_to_canvas = shape_to_canvas
self.id = id
def from_svg_path(path_str, shape_to_canvas = torch.eye(3), force_close = False):
path = svgpathtools.parse_path(path_str)
if len(path) == 0:
return []
ret_paths = []
subpaths = path.continuous_subpaths()
for subpath in subpaths:
if subpath.isclosed():
if len(subpath) > 1 and isinstance(subpath[-1], svgpathtools.Line) and subpath[-1].length() < 1e-5:
subpath.remove(subpath[-1])
subpath[-1].end = subpath[0].start # Force closing the path
subpath.end = subpath[-1].end
assert(subpath.isclosed())
else:
beg = subpath[0].start
end = subpath[-1].end
if abs(end - beg) < 1e-5:
subpath[-1].end = beg # Force closing the path
subpath.end = subpath[-1].end
assert(subpath.isclosed())
elif force_close:
subpath.append(svgpathtools.Line(end, beg))
subpath.end = subpath[-1].end
assert(subpath.isclosed())
num_control_points = []
points = []
for i, e in enumerate(subpath):
if i == 0:
points.append((e.start.real, e.start.imag))
else:
# Must begin from the end of previous segment
assert(e.start.real == points[-1][0])
assert(e.start.imag == points[-1][1])
if isinstance(e, svgpathtools.Line):
num_control_points.append(0)
elif isinstance(e, svgpathtools.QuadraticBezier):
num_control_points.append(1)
points.append((e.control.real, e.control.imag))
elif isinstance(e, svgpathtools.CubicBezier):
num_control_points.append(2)
points.append((e.control1.real, e.control1.imag))
points.append((e.control2.real, e.control2.imag))
elif isinstance(e, svgpathtools.Arc):
# Convert to Cubic curves
# https://www.joecridge.me/content/pdf/bezier-arcs.pdf
start = e.theta * math.pi / 180.0
stop = (e.theta + e.delta) * math.pi / 180.0
sign = 1.0
if stop < start:
sign = -1.0
epsilon = 0.00001
debug = abs(e.delta) >= 90.0
while (sign * (stop - start) > epsilon):
arc_to_draw = stop - start
if arc_to_draw > 0.0:
arc_to_draw = min(arc_to_draw, 0.5 * math.pi)
else:
arc_to_draw = max(arc_to_draw, -0.5 * math.pi)
alpha = arc_to_draw / 2.0
cos_alpha = math.cos(alpha)
sin_alpha = math.sin(alpha)
cot_alpha = 1.0 / math.tan(alpha)
phi = start + alpha
cos_phi = math.cos(phi)
sin_phi = math.sin(phi)
lambda_ = (4.0 - cos_alpha) / 3.0
mu = sin_alpha + (cos_alpha - lambda_) * cot_alpha
last = sign * (stop - (start + arc_to_draw)) <= epsilon
num_control_points.append(2)
rx = e.radius.real
ry = e.radius.imag
cx = e.center.real
cy = e.center.imag
rot = e.phi * math.pi / 180.0
cos_rot = math.cos(rot)
sin_rot = math.sin(rot)
x = lambda_ * cos_phi + mu * sin_phi
y = lambda_ * sin_phi - mu * cos_phi
xx = x * cos_rot - y * sin_rot
yy = x * sin_rot + y * cos_rot
points.append((cx + rx * xx, cy + ry * yy))
x = lambda_ * cos_phi - mu * sin_phi
y = lambda_ * sin_phi + mu * cos_phi
xx = x * cos_rot - y * sin_rot
yy = x * sin_rot + y * cos_rot
points.append((cx + rx * xx, cy + ry * yy))
if not last:
points.append((cx + rx * math.cos(rot + start + arc_to_draw),
cy + ry * math.sin(rot + start + arc_to_draw)))
start += arc_to_draw
first = False
if i != len(subpath) - 1:
points.append((e.end.real, e.end.imag))
else:
if subpath.isclosed():
# Must end at the beginning of first segment
assert(e.end.real == points[0][0])
assert(e.end.imag == points[0][1])
else:
points.append((e.end.real, e.end.imag))
points = torch.tensor(points)
points = torch.cat((points, torch.ones([points.shape[0], 1])), dim = 1) @ torch.transpose(shape_to_canvas, 0, 1)
points = points / points[:, 2:3]
points = points[:, :2].contiguous()
ret_paths.append(Path(torch.tensor(num_control_points), points, subpath.isclosed()))
return ret_paths