809 lines
37 KiB
Python
809 lines
37 KiB
Python
import torch
|
|
import diffvg
|
|
import pydiffvg
|
|
import time
|
|
from enum import IntEnum
|
|
import warnings
|
|
|
|
# Module-wide switch read by RenderFunction: when true, scene construction,
# forward, gradient and backward passes print their wall-clock duration.
print_timing: bool = False
|
|
|
|
def popmult(lst, n):
    """Split *lst* into its first *n* items and everything after them.

    Returns a ``(head, tail)`` pair of slices; *lst* itself is not modified.
    """
    head = lst[:n]
    tail = lst[n:]
    return head, tail
|
|
|
|
def set_print_timing(val):
    """Turn RenderFunction's timing printouts on or off.

    Rebinds the module-level ``print_timing`` flag to *val*.
    """
    global print_timing
    print_timing = val
|
|
|
|
class OutputType(IntEnum):
    """Selects what RenderFunction renders.

    ``color`` allocates a 4-channel image; ``sdf`` allocates 1 channel per
    pixel (or per evaluation position).
    """

    color = 1  # 4-channel colour image
    sdf = 2  # single-channel distance values
|
|
|
|
class RenderFunction(torch.autograd.Function):
    """
    The PyTorch interface of diffvg.

    ``serialize_scene`` flattens shapes and shape groups into a flat list of
    positional arguments so that every tensor is visible to autograd.
    ``forward`` and ``render_grad`` rebuild a ``diffvg.Scene`` from that list,
    and ``backward`` returns one gradient (or ``None``) per forward input, in
    the same order.
    """

    @staticmethod
    def _serialize_color(color, check_finite = False):
        """
        Flatten a fill or stroke color into argument-list entries.

        Args:
            color: ``None``, a constant color tensor, a
                ``pydiffvg.LinearGradient`` or a ``pydiffvg.RadialGradient``.
            check_finite: also assert that gradient stop colors are finite.

        Returns:
            ``[None]`` when there is no color, otherwise
            ``[color_type, *tensors]`` with the tensors moved to the CPU.

        Raises:
            TypeError: if *color* is of any other type.
        """
        if color is None:
            return [None]
        if isinstance(color, torch.Tensor):
            assert color.is_contiguous()
            return [diffvg.ColorType.constant, color.cpu()]
        if isinstance(color, pydiffvg.LinearGradient):
            assert color.begin.is_contiguous()
            assert color.end.is_contiguous()
            assert color.offsets.is_contiguous()
            assert color.stop_colors.is_contiguous()
            if check_finite:
                assert torch.isfinite(color.stop_colors).all()
            return [diffvg.ColorType.linear_gradient,
                    color.begin.cpu(),
                    color.end.cpu(),
                    color.offsets.cpu(),
                    color.stop_colors.cpu()]
        if isinstance(color, pydiffvg.RadialGradient):
            assert color.center.is_contiguous()
            assert color.radius.is_contiguous()
            assert color.offsets.is_contiguous()
            assert color.stop_colors.is_contiguous()
            if check_finite:
                assert torch.isfinite(color.stop_colors).all()
            return [diffvg.ColorType.radial_gradient,
                    color.center.cpu(),
                    color.radius.cpu(),
                    color.offsets.cpu(),
                    color.stop_colors.cpu()]
        # Emitting nothing would shift every following argument and make the
        # decoder misread the rest of the scene, so fail loudly instead.
        raise TypeError('Unsupported color type: {}'.format(type(color).__name__))

    @staticmethod
    def serialize_scene(canvas_width,
                        canvas_height,
                        shapes,
                        shape_groups,
                        filter = pydiffvg.PixelFilter(type = diffvg.FilterType.box,
                                                      radius = torch.tensor(0.5)),
                        output_type = OutputType.color,
                        use_prefiltering = False,
                        eval_positions = torch.tensor([])):
        """
        Given a list of shapes, convert them to a linear list of arguments
        so that we can use it in PyTorch.

        Layout: 7 header entries (canvas_width, canvas_height, num_shapes,
        num_shape_groups, output_type, use_prefiltering, eval_positions).
        Then, per shape: its type, its data tensors and a stroke-width slot.
        Then, per shape group: shape_ids, fill color, stroke color,
        use_even_odd_rule and shape_to_canvas.
        Last come filter.type and filter.radius.

        Raises:
            TypeError: if a fill or stroke color has an unsupported type.
        """
        args = [canvas_width,
                canvas_height,
                len(shapes),
                len(shape_groups),
                output_type,
                use_prefiltering,
                eval_positions.to(pydiffvg.get_device())]

        for shape in shapes:
            use_thickness = False
            if isinstance(shape, pydiffvg.Circle):
                assert shape.center.is_contiguous()
                args += [diffvg.ShapeType.circle,
                         shape.radius.cpu(),
                         shape.center.cpu()]
            elif isinstance(shape, pydiffvg.Ellipse):
                assert shape.radius.is_contiguous()
                assert shape.center.is_contiguous()
                args += [diffvg.ShapeType.ellipse,
                         shape.radius.cpu(),
                         shape.center.cpu()]
            elif isinstance(shape, pydiffvg.Path):
                assert shape.num_control_points.is_contiguous()
                assert shape.points.is_contiguous()
                assert shape.points.shape[1] == 2
                assert torch.isfinite(shape.points).all()
                args += [diffvg.ShapeType.path,
                         shape.num_control_points.to(torch.int32).cpu(),
                         shape.points.cpu()]
                # A stroke_width with more than one entry is a per-point
                # thickness and goes into the path's own thickness slot.
                if len(shape.stroke_width.shape) > 0 and shape.stroke_width.shape[0] > 1:
                    assert torch.isfinite(shape.stroke_width).all()
                    use_thickness = True
                    args.append(shape.stroke_width.cpu())
                else:
                    args.append(None)
                args += [shape.is_closed, shape.use_distance_approx]
            elif isinstance(shape, pydiffvg.Polygon):
                assert shape.points.is_contiguous()
                assert shape.points.shape[1] == 2
                # Encoded as a path whose num_control_points entries are all
                # zero: one entry per point when closed, one fewer when open.
                num_entries = shape.points.shape[0] if shape.is_closed else shape.points.shape[0] - 1
                args += [diffvg.ShapeType.path,
                         torch.zeros(num_entries, dtype = torch.int32),
                         shape.points.cpu(),
                         None,             # no per-point thickness
                         shape.is_closed,  # a plain attribute, not a method
                         False]            # use_distance_approx
            elif isinstance(shape, pydiffvg.Rect):
                assert shape.p_min.is_contiguous()
                assert shape.p_max.is_contiguous()
                args += [diffvg.ShapeType.rect,
                         shape.p_min.cpu(),
                         shape.p_max.cpu()]
            else:
                assert False
            # With per-point thickness the scalar stroke-width slot is only a
            # placeholder.
            args.append(torch.tensor(0.0) if use_thickness else shape.stroke_width.cpu())

        for shape_group in shape_groups:
            assert shape_group.shape_ids.is_contiguous()
            args.append(shape_group.shape_ids.to(torch.int32).cpu())

            # Fill color
            args += RenderFunction._serialize_color(shape_group.fill_color)
            if shape_group.fill_color is not None:
                # go through the underlying shapes and check if they are all closed
                for shape_id in shape_group.shape_ids.tolist():
                    member = shapes[shape_id]
                    if isinstance(member, pydiffvg.Path) and not member.is_closed:
                        warnings.warn("Detected non-closed paths with fill color. This might causes unexpected results.", Warning)

            # Stroke color (gradient stop colors are also checked for finiteness)
            args += RenderFunction._serialize_color(shape_group.stroke_color, check_finite = True)

            args += [shape_group.use_even_odd_rule,
                     shape_group.shape_to_canvas.contiguous().cpu()]

        args += [filter.type, filter.radius.cpu()]
        return args

    @staticmethod
    def _build_scene(args, report_timing = False):
        """
        Rebuild a ``diffvg.Scene`` from the flat list made by ``serialize_scene``.

        Args:
            args: the serialized scene arguments, in ``serialize_scene`` order.
            report_timing: print the scene-construction time when the
                module-level ``print_timing`` flag is set.

        Returns:
            ``(scene, filt, output_type, use_prefiltering, eval_positions,
            shape_contents, color_contents)``. The last two lists hold the
            diffvg shape and color objects. Callers must keep them referenced
            while the scene is in use so that GC does not delete them.
        """
        # Walk the list with an iterator. Re-slicing the remainder for every
        # item would copy it each time, which is O(n^2) in the argument count.
        stream = iter(args)

        def take(count):
            """Consume and return the next *count* arguments as a list."""
            return [next(stream) for _ in range(count)]

        def decode_color(color_type):
            """Consume one color's data and build its diffvg object (None for no color)."""
            if color_type == diffvg.ColorType.constant:
                color = next(stream)
                return diffvg.Constant(
                    diffvg.Vector4f(color[0], color[1], color[2], color[3]))
            elif color_type == diffvg.ColorType.linear_gradient:
                beg, end, offsets, stop_colors = take(4)
                assert offsets.shape[0] == stop_colors.shape[0]
                return diffvg.LinearGradient(diffvg.Vector2f(beg[0], beg[1]),
                                             diffvg.Vector2f(end[0], end[1]),
                                             offsets.shape[0],
                                             diffvg.float_ptr(offsets.data_ptr()),
                                             diffvg.float_ptr(stop_colors.data_ptr()))
            elif color_type == diffvg.ColorType.radial_gradient:
                center, radius, offsets, stop_colors = take(4)
                assert offsets.shape[0] == stop_colors.shape[0]
                return diffvg.RadialGradient(diffvg.Vector2f(center[0], center[1]),
                                             diffvg.Vector2f(radius[0], radius[1]),
                                             offsets.shape[0],
                                             diffvg.float_ptr(offsets.data_ptr()),
                                             diffvg.float_ptr(stop_colors.data_ptr()))
            elif color_type is None:
                return None
            else:
                assert False

        (canvas_width, canvas_height, num_shapes, num_shape_groups,
         output_type, use_prefiltering, eval_positions) = take(7)

        shapes = []
        shape_contents = []  # Important to avoid GC deleting the shapes
        for _ in range(num_shapes):
            shape_type = next(stream)
            if shape_type == diffvg.ShapeType.circle:
                radius, center = take(2)
                shape = diffvg.Circle(radius, diffvg.Vector2f(center[0], center[1]))
            elif shape_type == diffvg.ShapeType.ellipse:
                radius, center = take(2)
                shape = diffvg.Ellipse(diffvg.Vector2f(radius[0], radius[1]),
                                       diffvg.Vector2f(center[0], center[1]))
            elif shape_type == diffvg.ShapeType.path:
                (num_control_points, points, thickness,
                 is_closed, use_distance_approx) = take(5)
                shape = diffvg.Path(diffvg.int_ptr(num_control_points.data_ptr()),
                                    diffvg.float_ptr(points.data_ptr()),
                                    diffvg.float_ptr(thickness.data_ptr() if thickness is not None else 0),
                                    num_control_points.shape[0],
                                    points.shape[0],
                                    is_closed,
                                    use_distance_approx)
            elif shape_type == diffvg.ShapeType.rect:
                p_min, p_max = take(2)
                shape = diffvg.Rect(diffvg.Vector2f(p_min[0], p_min[1]),
                                    diffvg.Vector2f(p_max[0], p_max[1]))
            else:
                assert False
            stroke_width = next(stream)
            shapes.append(diffvg.Shape(shape_type, shape.get_ptr(), stroke_width.item()))
            shape_contents.append(shape)

        shape_groups = []
        color_contents = []  # Same as above
        for _ in range(num_shape_groups):
            shape_ids, fill_color_type = take(2)
            fill_color = decode_color(fill_color_type)
            stroke_color_type = next(stream)
            stroke_color = decode_color(stroke_color_type)
            use_even_odd_rule, shape_to_canvas = take(2)

            if fill_color is not None:
                color_contents.append(fill_color)
            if stroke_color is not None:
                color_contents.append(stroke_color)

            shape_groups.append(diffvg.ShapeGroup(
                diffvg.int_ptr(shape_ids.data_ptr()),
                shape_ids.shape[0],
                diffvg.ColorType.constant if fill_color_type is None else fill_color_type,
                diffvg.void_ptr(0) if fill_color is None else fill_color.get_ptr(),
                diffvg.ColorType.constant if stroke_color_type is None else stroke_color_type,
                diffvg.void_ptr(0) if stroke_color is None else stroke_color.get_ptr(),
                use_even_odd_rule,
                diffvg.float_ptr(shape_to_canvas.data_ptr())))

        filter_type, filter_radius = take(2)
        filt = diffvg.Filter(filter_type, filter_radius)

        device = pydiffvg.get_device()
        start = time.time()
        scene = diffvg.Scene(canvas_width, canvas_height,
                             shapes, shape_groups, filt, pydiffvg.get_use_gpu(),
                             device.index if device.index is not None else -1)
        time_elapsed = time.time() - start
        if report_timing and print_timing:
            print('Scene construction, time: %.5f s' % time_elapsed)

        return (scene, filt, output_type, use_prefiltering, eval_positions,
                shape_contents, color_contents)

    @staticmethod
    def forward(ctx,
                width,
                height,
                num_samples_x,
                num_samples_y,
                seed,
                background_image,
                *args):
        """
        Forward rendering pass.

        Renders the serialized scene in *args* (see ``serialize_scene``).
        Returns one of:
        - a ``(height, width, 4)`` color image,
        - a ``(height, width, 1)`` SDF image,
        - ``(N, 1)`` SDF values at N ``eval_positions``.

        Raises:
            NotImplementedError: if *background_image* has 3 channels.
        """
        (scene, filt, output_type, use_prefiltering, eval_positions,
         shape_contents, color_contents) = RenderFunction._build_scene(args, report_timing = True)

        device = pydiffvg.get_device()
        if output_type == OutputType.color:
            assert eval_positions.shape[0] == 0
            rendered_image = torch.zeros(height, width, 4, device = device)
        else:
            assert output_type == OutputType.sdf
            if eval_positions.shape[0] == 0:
                rendered_image = torch.zeros(height, width, 1, device = device)
            else:
                rendered_image = torch.zeros(eval_positions.shape[0], 1, device = device)

        if background_image is not None:
            background_image = background_image.to(device)
            if background_image.shape[2] == 3:
                raise NotImplementedError('Background image must have 4 channels, not 3. Add a fourth channel with all ones via torch.ones().')
            background_image = background_image.contiguous()
            assert background_image.shape[0] == rendered_image.shape[0]
            assert background_image.shape[1] == rendered_image.shape[1]
            assert background_image.shape[2] == 4

        start = time.time()
        diffvg.render(scene,
                      diffvg.float_ptr(background_image.data_ptr() if background_image is not None else 0),
                      diffvg.float_ptr(rendered_image.data_ptr() if output_type == OutputType.color else 0),
                      diffvg.float_ptr(rendered_image.data_ptr() if output_type == OutputType.sdf else 0),
                      width,
                      height,
                      num_samples_x,
                      num_samples_y,
                      seed,
                      diffvg.float_ptr(0),  # d_background_image
                      diffvg.float_ptr(0),  # d_render_image
                      diffvg.float_ptr(0),  # d_render_sdf
                      diffvg.float_ptr(0),  # d_translation
                      use_prefiltering,
                      diffvg.float_ptr(eval_positions.data_ptr()),
                      eval_positions.shape[0])
        assert torch.isfinite(rendered_image).all()
        time_elapsed = time.time() - start
        if print_timing:
            print('Forward pass, time: %.5f s' % time_elapsed)

        # Everything backward() needs; the *_contents lists also keep the
        # diffvg shape/color objects alive alongside the scene.
        ctx.scene = scene
        ctx.background_image = background_image
        ctx.shape_contents = shape_contents
        ctx.color_contents = color_contents
        ctx.filter = filt
        ctx.width = width
        ctx.height = height
        ctx.num_samples_x = num_samples_x
        ctx.num_samples_y = num_samples_y
        ctx.seed = seed
        ctx.output_type = output_type
        ctx.use_prefiltering = use_prefiltering
        ctx.eval_positions = eval_positions
        return rendered_image

    @staticmethod
    def render_grad(grad_img,
                    width,
                    height,
                    num_samples_x,
                    num_samples_y,
                    seed,
                    background_image,
                    *args):
        """
        Push *grad_img* through the serialized scene in *args*.

        Returns the ``d_translation`` output of ``diffvg.render`` as a
        ``(height, width, 2)`` tensor. A 3-channel *background_image* is
        padded with a fourth channel of ones.
        """
        if not grad_img.is_contiguous():
            grad_img = grad_img.contiguous()
        assert torch.isfinite(grad_img).all()

        # shape_contents / color_contents / filt stay bound as locals so that
        # the diffvg objects outlive the render call below.
        (scene, filt, output_type, use_prefiltering, eval_positions,
         shape_contents, color_contents) = RenderFunction._build_scene(args)

        # The last axis is 4 for color output and 1 for SDF output. Using -1
        # also accepts the 2-D (N, 1) gradient of eval_positions output.
        assert grad_img.shape[-1] == (4 if output_type == OutputType.color else 1)

        device = pydiffvg.get_device()
        if background_image is not None:
            background_image = background_image.to(device)
            if background_image.shape[2] == 3:
                background_image = torch.cat(
                    (background_image,
                     torch.ones(background_image.shape[0],
                                background_image.shape[1],
                                1,
                                device = background_image.device)),
                    dim = 2)
            background_image = background_image.contiguous()
            # No rendered image exists here; check against the canvas size.
            assert background_image.shape[0] == height
            assert background_image.shape[1] == width
            assert background_image.shape[2] == 4

        translation_grad_image = torch.zeros(height, width, 2, device = device)
        start = time.time()
        diffvg.render(scene,
                      diffvg.float_ptr(background_image.data_ptr() if background_image is not None else 0),
                      diffvg.float_ptr(0),  # render_image
                      diffvg.float_ptr(0),  # render_sdf
                      width,
                      height,
                      num_samples_x,
                      num_samples_y,
                      seed,
                      diffvg.float_ptr(0),  # d_background_image
                      diffvg.float_ptr(grad_img.data_ptr() if output_type == OutputType.color else 0),
                      diffvg.float_ptr(grad_img.data_ptr() if output_type == OutputType.sdf else 0),
                      diffvg.float_ptr(translation_grad_image.data_ptr()),
                      use_prefiltering,
                      diffvg.float_ptr(eval_positions.data_ptr()),
                      eval_positions.shape[0])
        time_elapsed = time.time() - start
        if print_timing:
            print('Gradient pass, time: %.5f s' % time_elapsed)
        assert torch.isfinite(translation_grad_image).all()

        return translation_grad_image

    @staticmethod
    def _d_shape_args(d_shape):
        """
        Convert one shape derivative from the scene into argument gradients.

        The returned list lines up with the entries ``serialize_scene`` emits
        for that shape after its type: its data tensors (``None`` for the
        non-differentiable ones) followed by the stroke-width slot.
        """
        use_thickness = False
        if d_shape.type == diffvg.ShapeType.circle:
            d_circle = d_shape.as_circle()
            radius = torch.tensor(d_circle.radius)
            assert torch.isfinite(radius).all()
            c = d_circle.center
            center = torch.tensor((c.x, c.y))
            assert torch.isfinite(center).all()
            grads = [radius, center]
        elif d_shape.type == diffvg.ShapeType.ellipse:
            d_ellipse = d_shape.as_ellipse()
            r = d_ellipse.radius
            radius = torch.tensor((r.x, r.y))
            assert torch.isfinite(radius).all()
            c = d_ellipse.center
            center = torch.tensor((c.x, c.y))
            assert torch.isfinite(center).all()
            grads = [radius, center]
        elif d_shape.type == diffvg.ShapeType.path:
            d_path = d_shape.as_path()
            points = torch.zeros((d_path.num_points, 2))
            thickness = None
            if d_path.has_thickness():
                use_thickness = True
                thickness = torch.zeros(d_path.num_points)
                d_path.copy_to(diffvg.float_ptr(points.data_ptr()),
                               diffvg.float_ptr(thickness.data_ptr()))
            else:
                d_path.copy_to(diffvg.float_ptr(points.data_ptr()), diffvg.float_ptr(0))
            assert torch.isfinite(points).all()
            if thickness is not None:
                assert torch.isfinite(thickness).all()
            grads = [None,       # num_control_points
                     points,
                     thickness,
                     None,       # is_closed
                     None]       # use_distance_approx
        elif d_shape.type == diffvg.ShapeType.rect:
            d_rect = d_shape.as_rect()
            p_min = torch.tensor((d_rect.p_min.x, d_rect.p_min.y))
            p_max = torch.tensor((d_rect.p_max.x, d_rect.p_max.y))
            assert torch.isfinite(p_min).all()
            assert torch.isfinite(p_max).all()
            grads = [p_min, p_max]
        else:
            assert False

        # With per-point thickness the stroke-width slot was a placeholder.
        if use_thickness:
            grads.append(None)
        else:
            w = torch.tensor((d_shape.stroke_width))
            assert torch.isfinite(w).all()
            grads.append(w)
        return grads

    @staticmethod
    def _d_color_args(color_type, as_constant, as_linear_gradient, as_radial_gradient):
        """
        Convert a fill or stroke color derivative into argument gradients.

        The ``as_*`` arguments are the matching accessors of a diffvg shape
        group (e.g. ``fill_color_as_constant``). Only the one selected by
        *color_type* is called. The returned list lines up with the entries
        ``_serialize_color`` emits after the color type.
        """
        if color_type == diffvg.ColorType.constant:
            c = as_constant().color
            return [torch.tensor((c.x, c.y, c.z, c.w))]
        if color_type == diffvg.ColorType.linear_gradient:
            d_gradient = as_linear_gradient()
            first, second = d_gradient.begin, d_gradient.end
        elif color_type == diffvg.ColorType.radial_gradient:
            d_gradient = as_radial_gradient()
            first, second = d_gradient.center, d_gradient.radius
        else:
            assert False
        offsets = torch.zeros((d_gradient.num_stops))
        stop_colors = torch.zeros((d_gradient.num_stops, 4))
        d_gradient.copy_to(diffvg.float_ptr(offsets.data_ptr()),
                           diffvg.float_ptr(stop_colors.data_ptr()))
        assert torch.isfinite(stop_colors).all()
        return [torch.tensor((first.x, first.y)),
                torch.tensor((second.x, second.y)),
                offsets,
                stop_colors]

    @staticmethod
    def backward(ctx,
                 grad_img):
        """
        Backward pass: map the output gradient onto every forward input.

        Returns one entry per forward input: ``width`` .. ``seed`` (5),
        ``background_image``, then the serialized scene arguments in
        ``serialize_scene`` order. Non-differentiable slots get ``None``.
        """
        if not grad_img.is_contiguous():
            grad_img = grad_img.contiguous()
        assert torch.isfinite(grad_img).all()

        scene = ctx.scene
        output_type = ctx.output_type
        eval_positions = ctx.eval_positions
        background_image = ctx.background_image
        d_background_image = None if background_image is None else torch.zeros_like(background_image)

        start = time.time()
        diffvg.render(scene,
                      diffvg.float_ptr(background_image.data_ptr() if background_image is not None else 0),
                      diffvg.float_ptr(0),  # render_image
                      diffvg.float_ptr(0),  # render_sdf
                      ctx.width,
                      ctx.height,
                      ctx.num_samples_x,
                      ctx.num_samples_y,
                      ctx.seed,
                      diffvg.float_ptr(d_background_image.data_ptr() if d_background_image is not None else 0),
                      diffvg.float_ptr(grad_img.data_ptr() if output_type == OutputType.color else 0),
                      diffvg.float_ptr(grad_img.data_ptr() if output_type == OutputType.sdf else 0),
                      diffvg.float_ptr(0),  # d_translation
                      ctx.use_prefiltering,
                      diffvg.float_ptr(eval_positions.data_ptr()),
                      eval_positions.shape[0])
        time_elapsed = time.time() - start
        if print_timing:
            print('Backward pass, time: %.5f s' % time_elapsed)

        # [width, height, num_samples_x, num_samples_y, seed,
        #  d_background_image, canvas_width, canvas_height, num_shapes,
        #  num_shape_groups, output_type, use_prefiltering, eval_positions]
        d_args = [None] * 5 + [d_background_image] + [None] * 7

        for shape_id in range(scene.num_shapes):
            d_args.append(None)  # shape type
            d_args += RenderFunction._d_shape_args(scene.get_d_shape(shape_id))

        for group_id in range(scene.num_shape_groups):
            d_shape_group = scene.get_d_shape_group(group_id)
            d_args += [None, None]  # shape_ids, fill_color_type
            if d_shape_group.has_fill_color():
                d_args += RenderFunction._d_color_args(
                    d_shape_group.fill_color_type,
                    d_shape_group.fill_color_as_constant,
                    d_shape_group.fill_color_as_linear_gradient,
                    d_shape_group.fill_color_as_radial_gradient)
            d_args.append(None)  # stroke_color_type
            if d_shape_group.has_stroke_color():
                # Dispatch on stroke_color_type: fill and stroke colors of the
                # same group may be of different types.
                d_args += RenderFunction._d_color_args(
                    d_shape_group.stroke_color_type,
                    d_shape_group.stroke_color_as_constant,
                    d_shape_group.stroke_color_as_linear_gradient,
                    d_shape_group.stroke_color_as_radial_gradient)
            d_args.append(None)  # use_even_odd_rule
            d_shape_to_canvas = torch.zeros((3, 3))
            d_shape_group.copy_to(diffvg.float_ptr(d_shape_to_canvas.data_ptr()))
            assert torch.isfinite(d_shape_to_canvas).all()
            d_args.append(d_shape_to_canvas)

        d_args.append(None)  # filter_type
        d_args.append(torch.tensor(scene.get_d_filter_radius()))

        return tuple(d_args)
|