From af48db06cd2dd16141e7b0eb23c88b5361d966cc Mon Sep 17 00:00:00 2001 From: Dan Nissenbaum Date: Sat, 5 Nov 2022 15:07:47 -0400 Subject: [PATCH] Resolve crash in 'backward' when a background image with only 3 channels is passed to 'forward' The existing code adds the fourth channel to the background image directly inside 'forward'. However, this breaks back propagation because Torch's autograd framework records the shapes of all inputs to the 'forward' function and expects shapes passed to 'backward' to match. By adding a channel to the background image inside 'forward' and passing this to 'backward', there is an extra channel that autograd does not expect, and it crashes. The resolution is to instead raise an exception with a useful error message for the end user that they need to add a channel of all ones to the background image. --- pydiffvg/render_pytorch.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pydiffvg/render_pytorch.py b/pydiffvg/render_pytorch.py index b776ce6..a686fb1 100644 --- a/pydiffvg/render_pytorch.py +++ b/pydiffvg/render_pytorch.py @@ -384,9 +384,7 @@ class RenderFunction(torch.autograd.Function): if background_image is not None: background_image = background_image.to(pydiffvg.get_device()) if background_image.shape[2] == 3: - background_image = torch.cat((\ - background_image, torch.ones(background_image.shape[0], background_image.shape[1], 1, - device = background_image.device)), dim = 2) + raise NotImplementedError('Background image must have 4 channels, not 3. Add a fourth channel with all ones via torch.ones().') background_image = background_image.contiguous() assert(background_image.shape[0] == rendered_image.shape[0]) assert(background_image.shape[1] == rendered_image.shape[1])