import logging
from typing import Callable, Iterator, Tuple

import torch
from PIL.Image import Image
from torchvision import transforms
from torchvision.utils import make_grid

from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
10
11
class DiffusionInference:
    def __init__(
        self,
        model: BaseDiffusionModel,
        reverse_transforms: Callable = lambda x: x,
        image_shape: Tuple[int, int] = (3, 64),
        device: str = "cuda",
    ):
        """A diffusion inference framework.

        This is a simplified inference framework to easily start using your
        diffusion model.

        Args:
            model: The trained diffusion model. It is moved to ``device``.
            reverse_transforms: A set of reverse transforms, applied to a batch
                of images before the batch is tiled into a grid.
            image_shape: The shape of the image to produce. This is a tuple where the
                first value is the number of channels and the second is the size of the
                image. Images are expected to be square.
            device: The device to run the inference on.
        """
        self.image_channels = image_shape[0]
        """The number of channels of the image."""
        self.image_size = image_shape[1]
        """The size of the image."""

        self.model = model.to(device)
        """The trained diffusion model."""

        # Previously missing: ``_visualise_images`` reads this attribute.
        self.reverse_transforms = reverse_transforms
        """The set of reverse transforms."""
        self.device = device
        """The device to run the inference on."""

    def _sample_noise(self, number_of_images: int) -> torch.Tensor:
        """Sample the standard-normal noise batch that generation starts from.

        Args:
            number_of_images: The number of images in the batch.

        Returns:
            A tensor of shape
            ``(number_of_images, image_channels, image_size, image_size)``
            on ``self.device``.

        Raises:
            ValueError: If ``number_of_images`` is not positive.
        """
        if number_of_images < 1:
            raise ValueError(
                f"number_of_images must be positive, got {number_of_images}"
            )
        return torch.randn(
            (
                number_of_images,
                self.image_channels,
                self.image_size,
                self.image_size,
            ),
            device=self.device,
        )

    def _visualise_images(self, denoised_images: torch.Tensor) -> Image:
        """Turn a batch of images into a single PIL image grid.

        Args:
            denoised_images: A batch of images. Presumably ``(N, C, H, W)`` as
                accepted by ``make_grid`` (TODO confirm for all callers).

        Returns:
            The reverse-transformed images tiled into a grid, 5 per row.
        """
        reverse_transformed_images = self.reverse_transforms(denoised_images)
        image_grid = make_grid(reverse_transformed_images, nrow=5)
        return transforms.ToPILImage()(image_grid)

    def _save_gif(self, denoised_images, final_grid: Image, gif_path: str) -> None:
        """Write the denoising trajectory to ``gif_path`` as an animated GIF.

        The first step and every odd-indexed step are rendered, to keep the
        GIF small. ``final_grid`` is always the last frame and is held for
        2000 ms; every other frame is shown for 10 ms.

        Args:
            denoised_images: The indexable per-step sequence produced by
                ``self.model.denoise``.
            final_grid: The already-rendered grid of the last step.
            gif_path: Destination file path.
        """
        last_index = len(denoised_images) - 1
        # The final step is excluded here and appended explicitly, so the GIF
        # always ends on the fully generated images, whatever the parity of
        # the trajectory length.
        frames = [
            self._visualise_images(denoised_images[index])
            for index in range(last_index)
            if index == 0 or index % 2 == 1
        ]
        frames.append(final_grid)

        first_frame, *other_frames = frames
        logging.info("Saving GIF to %s", gif_path)
        first_frame.save(
            gif_path,
            format="GIF",
            append_images=other_frames,
            save_all=True,
            # One duration per frame: quick playback, then hold the result.
            duration=[10] * (len(frames) - 1) + [2000],
            loop=0,
        )

    def generate(
        self,
        number_of_images: int,
        save_gif: bool = False,
        gif_path: str = "generated.gif",
    ) -> Image:
        """Generate a batch of images.

        Args:
            number_of_images: The number of images to generate.
            save_gif: Whether to save the generation process as a GIF.
            gif_path: Where the GIF is written when ``save_gif`` is True.

        Returns:
            A PIL image of the generated images stacked together.

        Raises:
            ValueError: If ``number_of_images`` is not positive.
        """
        images = self._sample_noise(number_of_images)

        # Inference only: avoid building an autograd graph across the
        # denoising steps (harmless if the model already disables gradients).
        with torch.no_grad():
            # Materialise so the trajectory can be indexed and measured
            # whatever sequence type ``denoise`` returns.
            # NOTE(review): assumed ordered from noise to final images -- confirm.
            denoised_images = list(self.model.denoise(images))

        final_grid = self._visualise_images(denoised_images[-1])

        if save_gif:
            self._save_gif(denoised_images, final_grid, gif_path)

        return final_grid

    def get_generator(self, number_of_images: int = 1) -> Iterator[Image]:
        """An image generator.

        This method is a generator that will generate a batch of images. At
        every call, the generator denoises the images by one more step until the
        image is fully generated. This can be particularly useful for running the
        image generation step by step or for a streaming API.

        Args:
            number_of_images: The number of images the generator should generate.

        Yields:
            A PIL image grid of the batch after each reverse step, from timestep
            ``steps - 1`` down to ``0``. Images are clamped to ``[-1, 1]`` after
            every step.

        Raises:
            ValueError: If ``number_of_images`` is not positive. Because this is
                a generator, the error is raised on the first ``next`` call.
        """
        images = self._sample_noise(number_of_images)
        total_steps = self.model.diffuser.beta_scheduler.steps

        for step in reversed(range(total_steps)):
            timestep = torch.full(
                (images.shape[0],), step, device=self.device, dtype=torch.long
            )
            # Scope no-grad to the model call only: a ``with`` spanning the
            # ``yield`` would leave gradients disabled in the caller's code
            # while this generator is suspended.
            with torch.no_grad():
                images = self.model.diffuser._denoise_step(
                    images, model=self.model, timestep=timestep
                )
            images = torch.clamp(images, -1.0, 1.0)

            yield self._visualise_images(images)