Source code for diffusion_models.diffusion_inference

  1import logging
  2from typing import Callable, Tuple
  3
  4import torch
  5from PIL.Image import Image
  6from torchvision import transforms
  7from torchvision.utils import make_grid
  8
  9from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
 10
 11
[docs] 12class DiffusionInference: 13 def __init__( 14 self, 15 model: BaseDiffusionModel, 16 reverse_transforms: Callable = lambda x: x, 17 image_shape: Tuple[int, int] = (3, 64), 18 device: str = "cuda", 19 ): 20 """A diffusion inference framework. 21 22 This is a simplified inference framework to easily start using your 23 diffusion model. 24 25 Args: 26 model: The trained diffusion model. 27 reverse_transforms: A set of reverse transforms. 28 image_shape: The shape of the image to produce. This is a tuple where the 29 first value is the number of channels and the second is the size of the 30 image. Images are expected to be square. 31 device: The device to run the inference on. 32 """
[docs] 33 self.image_channels = image_shape[0]
34 """The number of channels of the image."""
[docs] 35 self.image_size = image_shape[1]
36 """The size of the image.""" 37
[docs] 38 self.model = model.to(device)
39 """The trained diffusion model.""" 40
[docs] 41 self.reverse_transforms = reverse_transforms
42 """The set of reverse transforms."""
[docs] 43 self.device = device
44 """The device to run the inference on.""" 45 46 def _visualise_images(self, denoised_images: torch.Tensor): 47 reverse_transformed_images = self.reverse_transforms(denoised_images) 48 image_grid = make_grid(reverse_transformed_images, nrow=5) 49 pil_images = transforms.ToPILImage()(image_grid) 50 return pil_images 51
[docs] 52 def generate( 53 self, number_of_images: int, save_gif: bool = False 54 ) -> Image: 55 """Generate a batch of images. 56 57 Args: 58 number_of_images: The number of images to generate. 59 save_gif: Whether to save the generation process as a GIF. 60 61 Returns: 62 A PIL image of the generated images stacked together. 63 """ 64 images = torch.randn( 65 ( 66 number_of_images, 67 self.image_channels, 68 self.image_size, 69 self.image_size, 70 ), 71 device=self.device, 72 ) 73 74 denoised_images = self.model.denoise(images) 75 76 if save_gif: 77 pil_images = [] 78 for timestep, images in enumerate(denoised_images): 79 pil_image_grid = self._visualise_images(images) 80 pil_images.append(pil_image_grid) 81 82 images = pil_images[1::2] 83 frame_one = pil_images[0] 84 85 logging.info("Saving GIF") 86 frame_one.save( 87 "generated.gif", 88 format="GIF", 89 append_images=images, 90 save_all=True, 91 duration=[10 for i in range(len(images))] + [2000], 92 loop=0, 93 ) 94 95 else: 96 pil_image_grid = self._visualise_images(denoised_images[-1]) 97 return pil_image_grid
98
[docs] 99 def get_generator(self, number_of_images: int = 1): 100 """An image generator. 101 102 This method is a generator that will generate a batch of images. At 103 every call, the generator denoises the images by one more step until the 104 image is fully generated. This can be particularly useful for running the 105 image generation step by step or for a streaming API. 106 107 Args: 108 number_of_images: The number of images the generator should generate. 109 """ 110 images = torch.randn( 111 ( 112 number_of_images, 113 self.image_channels, 114 self.image_size, 115 self.image_size, 116 ), 117 device=self.device, 118 ) 119 120 for i in range(self.model.diffuser.beta_scheduler.steps)[::-1]: 121 timestep = torch.full((images.shape[0],), i, device=self.device) 122 images = self.model.diffuser._denoise_step( 123 images, model=self.model, timestep=timestep 124 ) 125 images = torch.clamp(images, -1.0, 1.0) 126 127 yield self._visualise_images(images)