102 lines
2.5 KiB
Python
102 lines
2.5 KiB
Python
from __future__ import annotations
|
|
|
|
from time import sleep
|
|
|
|
import pydiffvg
|
|
import torch
|
|
import random
|
|
from random import uniform
|
|
|
|
from gmtypes import GradientMesh, Quad, PointMapping, join_quads
|
|
|
|
def get_mesh() -> GradientMesh:
    """Build a gradient mesh out of four randomly generated quads.

    Four quads are created with ``Quad.random()``, stitched together in place
    by ``join_quads`` and then wrapped in a ``GradientMesh``.
    """
    quads = tuple(Quad.random() for _ in range(4))
    join_quads(*quads)
    return GradientMesh(*quads)
|
|
|
|
|
|
def render_mesh(mesh: PointMapping,
                filename='test_data/mesh.png',
                width=1024,
                height=1024,
                num_control_points=2,
                seed=None):
    """Rasterise ``mesh`` with diffvg, write it to ``filename`` and return it.

    Every shape produced by ``mesh.as_shapes()`` becomes one closed
    ``pydiffvg.Path``; shape ``i`` is filled with ``mesh.colors[i]``.

    Args:
        mesh: Mesh to render; its ``points``, ``colors`` and ``as_shapes()``
            members are used.
        filename: Path the rendered image is written to (gamma 2.2).
        width: Output width in pixels; x coordinates are multiplied by it.
        height: Output height in pixels; y coordinates are multiplied by it.
        num_control_points: Control-point count used for every path segment.
        seed: Passed to ``random.seed``. Note this reseeds Python's global
            ``random`` module; nothing in this function draws from it.

    Returns:
        The image tensor returned by ``pydiffvg.RenderFunction.apply``.
    """
    random.seed(seed)

    pydiffvg.set_use_gpu(torch.cuda.is_available())
    render = pydiffvg.RenderFunction.apply

    # NOTE(review): pydiffvg expects one entry per path *segment*, but this
    # length is the number of mesh patches (one path per entry, see
    # shape_groups below). It only agrees when both are 4 -- TODO confirm.
    ncp = torch.tensor([num_control_points] * len(mesh.points))

    # Scale shape coordinates to pixels per axis: x by width, y by height.
    # Assumes each shape is an (N, 2) tensor as pydiffvg.Path requires;
    # new_tensor keeps the scale on the same device/dtype as the points.
    points = [pts * pts.new_tensor([width, height])
              for pts in mesh.as_shapes()]

    shapes = [
        pydiffvg.Path(num_control_points=ncp,
                      points=pts,
                      is_closed=True)
        for pts in points
    ]

    shape_groups = [
        pydiffvg.ShapeGroup(shape_ids=torch.tensor([i]),
                            fill_color=mesh.colors[i])
        for i in range(len(mesh.points))
    ]

    scene_args = pydiffvg.RenderFunction.serialize_scene(
        width,
        height,
        shapes,
        shape_groups
    )

    img = render(width,
                 height,
                 2,     # num samples x
                 2,     # num samples y
                 0,     # seed
                 None,  # background image (none)
                 *scene_args)
    pydiffvg.imwrite(img.cpu(), filename, gamma=2.2)

    return img
|
|
|
|
|
|
def test_render(filename='test_data/target.png',
                width=1024,
                height=1024):
    """Render a freshly generated random mesh to ``filename``.

    Returns the image tensor produced by ``render_mesh``.
    """
    random_mesh = get_mesh()
    mapping = random_mesh.as_mapping()
    return render_mesh(mapping,
                       filename=filename,
                       width=width,
                       height=height)
|
|
|
|
|
|
def optimize(num_iterations=150, lr=1e-2, width=256, height=256):
    """Fit a random mesh to the rendering of another random mesh.

    A target image is rendered from one random mesh (written to
    ``test_data/target.png``). A second random mesh is then optimised with
    Adam on the sum of squared pixel differences. Each iteration's rendering
    is written to ``test_data/mesh_optim_NNN.png``.

    Args:
        num_iterations: Number of optimisation steps.
        lr: Adam learning rate.
        width: Render width in pixels.
        height: Render height in pixels.
    """
    # The target is a fixed reference, so cut it out of the autograd graph.
    # If the target mesh's tensors require grad, every backward pass would
    # otherwise also walk (and free) the target's graph. That is presumably
    # why retain_graph=True used to be needed.
    target = test_render(width=width, height=height).detach()

    mesh = get_mesh().as_mapping()

    optimizer = torch.optim.Adam([mesh.data, mesh.colors], lr=lr)

    for t in range(num_iterations):
        print(f"iteration {t}")
        optimizer.zero_grad()

        img = render_mesh(mesh,
                          filename=f"test_data/mesh_optim_{t:03d}.png",
                          width=width,
                          height=height)

        loss = (img - target).pow(2).sum()

        # img is re-rendered each iteration, so its graph is only used once.
        loss.backward()

        print(f'loss: {loss.item()}')

        optimizer.step()
|