1import logging
2from typing import Callable, Tuple
3
4import torch
5from PIL.Image import Image
6from torchvision import transforms
7from torchvision.utils import make_grid
8
9from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
10
11
class DiffusionInference:
    def __init__(
        self,
        model: BaseDiffusionModel,
        reverse_transforms: Callable = lambda x: x,
        image_shape: Tuple[int, int] = (3, 64),
        device: str = "cuda",
    ):
        """A diffusion inference framework.

        This is a simplified inference framework to easily start using your
        diffusion model.

        Args:
            model: The trained diffusion model. It is moved to ``device``.
            reverse_transforms: A callable applied to a batch of denoised image
                tensors before they are tiled into a grid and converted to a
                PIL image. Defaults to the identity.
            image_shape: The shape of the image to produce. This is a tuple where the
                first value is the number of channels and the second is the size of the
                image. Images are expected to be square.
            device: The device to run the inference on.
        """
        self.image_channels = image_shape[0]
        """The number of channels of the image."""
        self.image_size = image_shape[1]
        """The size of the image."""

        self.model = model.to(device)
        """The trained diffusion model."""

        # This assignment is required: _visualise_images reads it on every call.
        self.reverse_transforms = reverse_transforms
        """The set of reverse transforms."""
        self.device = device
        """The device to run the inference on."""

    def _sample_noise(self, number_of_images: int) -> torch.Tensor:
        """Draw the standard-normal starting batch for the reverse process.

        Args:
            number_of_images: Batch size of the returned tensor.

        Returns:
            A tensor of shape
            ``(number_of_images, image_channels, image_size, image_size)``
            allocated on ``self.device``.
        """
        return torch.randn(
            (
                number_of_images,
                self.image_channels,
                self.image_size,
                self.image_size,
            ),
            device=self.device,
        )

    def _visualise_images(self, denoised_images: torch.Tensor) -> Image:
        """Turn a batch of image tensors into a single PIL image grid.

        The batch is passed through ``self.reverse_transforms``, tiled with
        ``make_grid`` (5 images per row) and converted with ``ToPILImage``.

        Args:
            denoised_images: A batch of images; presumably
                ``(batch, channels, height, width)`` as produced by
                ``_sample_noise`` and the model (TODO confirm for
                ``model.denoise`` outputs).

        Returns:
            The image grid as a PIL image.
        """
        reverse_transformed_images = self.reverse_transforms(denoised_images)
        image_grid = make_grid(reverse_transformed_images, nrow=5)
        return transforms.ToPILImage()(image_grid)

    def generate(
        self,
        number_of_images: int,
        save_gif: bool = False,
        gif_path: str = "generated.gif",
    ) -> Image:
        """Generate a batch of images.

        Args:
            number_of_images: The number of images to generate.
            save_gif: Whether to save the generation process as a GIF.
            gif_path: Where to write the GIF when ``save_gif`` is true.

        Returns:
            A PIL image of the final generated images stacked together. This
            is returned whether or not a GIF is saved.

        Raises:
            ValueError: If ``save_gif`` is true and the model's denoising
                trajectory contains no steps.
        """
        noise = self._sample_noise(number_of_images)

        # The last element of model.denoise's output is used as the final
        # result; with save_gif every element is rendered as a frame.
        denoised_images = self.model.denoise(noise)

        if not save_gif:
            return self._visualise_images(denoised_images[-1])

        pil_images = [
            self._visualise_images(step_images) for step_images in denoised_images
        ]
        if not pil_images:
            raise ValueError("model.denoise returned no images; cannot build a GIF")

        first_frame = pil_images[0]
        # Keep every other intermediate frame to keep the GIF small...
        appended_frames = pil_images[1::2]
        # ...but the slice skips the final index when the frame count is odd.
        # Make sure the fully denoised image is always the last frame, since
        # that frame gets the long hold duration below.
        if len(pil_images) > 1 and len(pil_images) % 2 == 1:
            appended_frames.append(pil_images[-1])

        logging.info("Saving GIF")
        first_frame.save(
            gif_path,
            format="GIF",
            append_images=appended_frames,
            save_all=True,
            # 10 ms per frame, then hold the final frame for 2 s.
            duration=[10] * len(appended_frames) + [2000],
            loop=0,
        )

        return pil_images[-1]

    def get_generator(self, number_of_images: int = 1):
        """An image generator.

        This method is a generator that will generate a batch of images. At
        every call, the generator denoises the images by one more step until the
        image is fully generated. This can be particularly useful for running the
        image generation step by step or for a streaming API.

        Each yielded value is a PIL image grid of the current batch, with
        pixel values clamped to ``[-1, 1]`` before the reverse transforms are
        applied.

        Args:
            number_of_images: The number of images the generator should generate.
        """
        images = self._sample_noise(number_of_images)

        for step in self.model.diffuser.steps:
            timestep = self.model.diffuser.get_timestep(images.shape[0], step)
            images = self.model.diffuser._denoise_step(
                images, model=self.model, timestep=timestep
            )
            # Clamp so each intermediate stays in the [-1, 1] range.
            images = torch.clamp(images, -1.0, 1.0)

            yield self._visualise_images(images)