From a5e26c439566d4d7ead2b05f7fe7a8b205477a99 Mon Sep 17 00:00:00 2001 From: Dan Nissenbaum Date: Sat, 5 Nov 2022 15:07:47 -0400 Subject: [PATCH] Resolve crash in 'backward' when a background image with only 3 channels is passed to 'forward' The existing code adds the fourth channel to the background image directly inside 'forward'. However, this breaks back propagation because Torch's autograd framework records the shapes of all inputs to the 'forward' function and expects shapes passed to 'backward' to match. By adding a channel to the background image inside 'forward' and passing this to 'backward', there is an extra channel that autograd does not expect, and it crashes. The resolution is to instead raise an exception with a useful error message for the end user that they need to add a channel of all ones to the background image. --- pydiffvg/render_pytorch.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pydiffvg/render_pytorch.py b/pydiffvg/render_pytorch.py index b776ce6..a686fb1 100644 --- a/pydiffvg/render_pytorch.py +++ b/pydiffvg/render_pytorch.py @@ -384,9 +384,7 @@ class RenderFunction(torch.autograd.Function): if background_image is not None: background_image = background_image.to(pydiffvg.get_device()) if background_image.shape[2] == 3: - background_image = torch.cat((\ - background_image, torch.ones(background_image.shape[0], background_image.shape[1], 1, - device = background_image.device)), dim = 2) + raise NotImplementedError('Background image must have 4 channels, not 3. Add a fourth channel with all ones via torch.ones().') background_image = background_image.contiguous() assert(background_image.shape[0] == rendered_image.shape[0]) assert(background_image.shape[1] == rendered_image.shape[1])