Major refactors
This commit is contained in:
@@ -7,171 +7,95 @@ import torch
|
||||
import random
|
||||
from random import uniform
|
||||
|
||||
from gmtypes import GradientMesh, Quad, Patch, Point, join_quads
|
||||
from gmtypes import GradientMesh, Quad, PointMapping, join_quads
|
||||
|
||||
def get_mesh() -> GradientMesh:
|
||||
"""Helper function to get a random mesh."""
|
||||
a, b, c, d = [Quad.random() for _ in range(4)]
|
||||
join_quads(a, b, c, d)
|
||||
return GradientMesh(a, b, c, d)
|
||||
|
||||
|
||||
def quads():
|
||||
return [
|
||||
Quad.random(),
|
||||
Quad.random(),
|
||||
Quad.random(),
|
||||
Quad.random(),
|
||||
]
|
||||
def render_mesh(mesh: PointMapping,
|
||||
filename='test_data/mesh.png',
|
||||
width=1024,
|
||||
height=1024,
|
||||
num_control_points=2,
|
||||
seed=None):
|
||||
|
||||
random.seed(seed)
|
||||
|
||||
def rand_quad_test(filename='random_quad.png', width=256, height=256,
|
||||
degree=4, num_control_points=2):
|
||||
pydiffvg.set_use_gpu(torch.cuda.is_available())
|
||||
render = pydiffvg.RenderFunction.apply
|
||||
ncp = torch.tensor([num_control_points] * len(mesh.points))
|
||||
|
||||
patch = Patch.random()
|
||||
# Scale
|
||||
# TODO non-uniform scaling
|
||||
points = [x * width for x in mesh.as_shapes()]
|
||||
|
||||
shape_groups = [patch.as_shape_group()]
|
||||
shapes = [patch.as_path(width, height)]
|
||||
shapes = [
|
||||
pydiffvg.Path(num_control_points=ncp,
|
||||
points=pts,
|
||||
is_closed=True)
|
||||
for pts in points
|
||||
]
|
||||
|
||||
scene_args = pydiffvg.RenderFunction.serialize_scene(width, height,
|
||||
shapes, shape_groups)
|
||||
shape_groups = [
|
||||
pydiffvg.ShapeGroup(shape_ids=torch.tensor([i]),
|
||||
fill_color=mesh.colors[i])
|
||||
for i in range(len(mesh.points))
|
||||
]
|
||||
|
||||
img = render(width, height, 2, 2, 0, None, *scene_args)
|
||||
pydiffvg.imwrite(img.cpu(), f"test_data/{filename}", gamma=2.2)
|
||||
scene_args = pydiffvg.RenderFunction.serialize_scene(
|
||||
width,
|
||||
height,
|
||||
shapes,
|
||||
shape_groups
|
||||
)
|
||||
|
||||
img = render(width,
|
||||
height,
|
||||
2, # num smaples x
|
||||
2, # num samples y
|
||||
0, # seed
|
||||
None,
|
||||
*scene_args)
|
||||
pydiffvg.imwrite(img.cpu(), filename, gamma=2.2)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def mult_quad_test(filename='multiple_quads.png', width=1024,
|
||||
height=1024, num_control_points=None, mask=None, seed=None):
|
||||
random.seed(seed)
|
||||
mask = mask or [1, 1, 1, 1]
|
||||
pydiffvg.set_use_gpu(torch.cuda.is_available())
|
||||
render = pydiffvg.RenderFunction.apply
|
||||
|
||||
a, b, c, d = quads()
|
||||
join_quads(a, b, c, d)
|
||||
|
||||
to_render = [a, b, c, d]
|
||||
to_render = [x for x in to_render if mask[to_render.index(x)]]
|
||||
|
||||
shape_groups = [patch.as_shape_group(color=(
|
||||
uniform(0, 1),
|
||||
uniform(0, 1),
|
||||
uniform(0, 1),
|
||||
0.8
|
||||
)) for patch in to_render]
|
||||
|
||||
for i in range(len(to_render)):
|
||||
shape_groups[i].shape_ids = torch.tensor([i])
|
||||
shapes = [patch.as_path(width, height) for patch in to_render]
|
||||
|
||||
scene_args = pydiffvg.RenderFunction.serialize_scene(width, height,
|
||||
shapes, shape_groups)
|
||||
|
||||
img = render(width, height, 2, 2, 0, None, *scene_args)
|
||||
pydiffvg.imwrite(img.cpu(), f"test_data/{filename}", gamma=2.2)
|
||||
return img.clone()
|
||||
def test_render(filename='test_data/target.png',
|
||||
width=1024,
|
||||
height=1024):
|
||||
return render_mesh(get_mesh().as_mapping(),
|
||||
width=width,
|
||||
height=height,
|
||||
filename=filename)
|
||||
|
||||
|
||||
def om():
|
||||
filename = 'optimize_test.png'
|
||||
pydiffvg.set_use_gpu(torch.cuda.is_available())
|
||||
render = pydiffvg.RenderFunction.apply
|
||||
def optimize():
|
||||
width, height = 256, 256
|
||||
target = test_render(width=width, height=height).clone()
|
||||
|
||||
target = mult_quad_test(width=256, height=256)
|
||||
mesh = get_mesh().as_mapping()
|
||||
|
||||
squad = quads()
|
||||
optimizer = torch.optim.Adam([mesh.data, mesh.colors], lr=1e-2)
|
||||
|
||||
join_quads(*squad)
|
||||
|
||||
gm = GradientMesh(*squad)
|
||||
|
||||
points_n = []
|
||||
for s in squad:
|
||||
out = []
|
||||
for pt in s.points:
|
||||
out.append([pt.x, pt.y])
|
||||
for cpt in pt.controls:
|
||||
out.append([cpt.x, cpt.y])
|
||||
points_n.append(out)
|
||||
|
||||
points_n = torch.tensor(points_n, requires_grad=True)
|
||||
color = torch.tensor([s.color for s in squad], requires_grad=True)
|
||||
|
||||
paths = [s.as_path() for s in squad]
|
||||
path_groups = [pydiffvg.ShapeGroup(shape_ids=torch.tensor([i]),
|
||||
fill_color=torch.tensor(squad[i].color))
|
||||
for i in range(len(squad))]
|
||||
scene_args = pydiffvg.RenderFunction.serialize_scene(
|
||||
256, 256, paths, path_groups
|
||||
)
|
||||
img = render(256, # width
|
||||
256, # height
|
||||
2, # num_samples_x
|
||||
2, # num_samples_y
|
||||
1, # seed
|
||||
None,
|
||||
*scene_args)
|
||||
|
||||
points, controls, color = [torch.tensor(x, requires_grad=True)
|
||||
for x in gm.to_numbers()]
|
||||
|
||||
optimizer = torch.optim.Adam([points, color, points_n], lr=1e-2)
|
||||
|
||||
for t in range(180):
|
||||
for t in range(150):
|
||||
print(f"iteration {t}")
|
||||
optimizer.zero_grad()
|
||||
|
||||
points_n.data = torch.tensor(
|
||||
GradientMesh.from_path_points(points_n, color).to_path_points()
|
||||
)
|
||||
|
||||
for i in range(len(paths)):
|
||||
paths[i].points = points_n[i] * 256
|
||||
|
||||
for i in range(len(path_groups)):
|
||||
path_groups[i].fill_color = color[i]
|
||||
|
||||
scene_args = pydiffvg.RenderFunction.serialize_scene(
|
||||
256, 256, paths, path_groups)
|
||||
|
||||
img = render(256, # width
|
||||
256, # height
|
||||
2, # num_samples_x
|
||||
2, # num_samples_y
|
||||
t+1, # seed
|
||||
None,
|
||||
*scene_args)
|
||||
|
||||
pydiffvg.imwrite(img.cpu(),
|
||||
f'test_data/test_curve/iter_{filename}_'
|
||||
f'{str(t).zfill(5)}.png',
|
||||
gamma=2.2)
|
||||
img = render_mesh(mesh,
|
||||
filename=f"test_data/mesh_optim_{str(t).zfill(3)}.png",
|
||||
width=width,
|
||||
height=height)
|
||||
|
||||
loss = (img - target).pow(2).sum()
|
||||
|
||||
loss.backward()
|
||||
# FIXME no need to retain graph
|
||||
loss.backward(retain_graph=True)
|
||||
|
||||
print(f'loss: {loss}')
|
||||
print(f'points.grad {points.grad}')
|
||||
print(f'color.grad {color.grad}')
|
||||
|
||||
optimizer.step()
|
||||
|
||||
|
||||
def slideshow(n=30, s=1, do_mask=False):
|
||||
mask = None
|
||||
for i in range(n):
|
||||
if do_mask:
|
||||
mask = [1] * 4
|
||||
print(i % n)
|
||||
mask[i % 4] = 0
|
||||
print(mask)
|
||||
|
||||
mult_quad_test(mask=mask)
|
||||
sleep(s)
|
||||
|
||||
|
||||
def get_mesh():
|
||||
a, b, c, d = quads()
|
||||
join_quads(a,b,c,d)
|
||||
|
||||
gm = GradientMesh(a, b, c, d)
|
||||
return gm
|
||||
|
Reference in New Issue
Block a user