diff --git a/pydiffvg/render_pytorch.py b/pydiffvg/render_pytorch.py index b776ce6..a686fb1 100644 --- a/pydiffvg/render_pytorch.py +++ b/pydiffvg/render_pytorch.py @@ -384,9 +384,7 @@ class RenderFunction(torch.autograd.Function): if background_image is not None: background_image = background_image.to(pydiffvg.get_device()) if background_image.shape[2] == 3: - background_image = torch.cat((\ - background_image, torch.ones(background_image.shape[0], background_image.shape[1], 1, - device = background_image.device)), dim = 2) + raise NotImplementedError('Background image must have 4 channels, not 3. Add a fourth channel with all ones via torch.ones().') background_image = background_image.contiguous() assert(background_image.shape[0] == rendered_image.shape[0]) assert(background_image.shape[1] == rendered_image.shape[1])