Resolve crash in 'backward' when a background image with only 3 channels is passed to 'forward'
The existing code adds the fourth channel to the background image directly inside 'forward'. However, this breaks back propagation because Torch's autograd framework records the shapes of all inputs to the 'forward' function and expects shapes passed to 'backward' to match. By adding a channel to the background image inside 'forward' and passing this to 'backward', there is an extra channel that autograd does not expect, and it crashes. The resolution is to instead raise an exception with a useful error message for the end user that they need to add a channel of all ones to the background image.
This commit is contained in:
@@ -384,9 +384,7 @@ class RenderFunction(torch.autograd.Function):
|
||||
if background_image is not None:
|
||||
background_image = background_image.to(pydiffvg.get_device())
|
||||
if background_image.shape[2] == 3:
|
||||
background_image = torch.cat((\
|
||||
background_image, torch.ones(background_image.shape[0], background_image.shape[1], 1,
|
||||
device = background_image.device)), dim = 2)
|
||||
raise NotImplementedError('Background image must have 4 channels, not 3. Add a fourth channel with all ones via torch.ones().')
|
||||
background_image = background_image.contiguous()
|
||||
assert(background_image.shape[0] == rendered_image.shape[0])
|
||||
assert(background_image.shape[1] == rendered_image.shape[1])
|
||||
|
Reference in New Issue
Block a user