Source code for diffusion_models.diffusion_inference

import logging
from typing import Callable, Tuple

import torch
from PIL.Image import Image
from torchvision import transforms
from torchvision.utils import make_grid

from diffusion_models.models.base_diffusion_model import BaseDiffusionModel


class DiffusionInference:
    def __init__(
        self,
        model: BaseDiffusionModel,
        reverse_transforms: Callable = lambda x: x,
        image_shape: Tuple[int, int] = (3, 64),
        device: str = "cuda",
    ):
        """A diffusion inference framework.

        This is a simplified inference framework to easily start using your
        diffusion model.

        Args:
            model: The trained diffusion model.
            reverse_transforms: A set of reverse transforms.
            image_shape: The shape of the image to produce. This is a tuple where
                the first value is the number of channels and the second is the
                size of the image. Images are expected to be square.
            device: The device to run the inference on.
        """
        self.image_channels = image_shape[0]
        """The number of channels of the image."""
        self.image_size = image_shape[1]
        """The size of the image."""

        self.model = model.to(device)
        """The trained diffusion model."""

        self.reverse_transforms = reverse_transforms
        """The set of reverse transforms."""
44 """The device to run the inference on.""" 45 46 def _visualise_images(self, denoised_images: torch.Tensor): 47 reverse_transformed_images = self.reverse_transforms(denoised_images) 48 image_grid = make_grid(reverse_transformed_images, nrow=5) 49 pil_images = transforms.ToPILImage()(image_grid) 50 return pil_images 51
[docs] 52 def generate(self, number_of_images: int, save_gif: bool = False) -> Image: 53 """Generate a batch of images. 54 55 Args: 56 number_of_images: The number of images to generate. 57 save_gif: Whether to save the generation process as a GIF. 58 59 Returns: 60 A PIL image of the generated images stacked together. 61 """ 62 images = torch.randn( 63 ( 64 number_of_images, 65 self.image_channels, 66 self.image_size, 67 self.image_size, 68 ), 69 device=self.device, 70 ) 71 72 denoised_images = self.model.denoise(images) 73 74 if save_gif: 75 pil_images = [] 76 for timestep, images in enumerate(denoised_images): 77 pil_image_grid = self._visualise_images(images) 78 pil_images.append(pil_image_grid) 79 80 images = pil_images[1::2] 81 frame_one = pil_images[0] 82 83 logging.info("Saving GIF") 84 frame_one.save( 85 "generated.gif", 86 format="GIF", 87 append_images=images, 88 save_all=True, 89 duration=[10 for i in range(len(images))] + [2000], 90 loop=0, 91 ) 92 93 else: 94 pil_image_grid = self._visualise_images(denoised_images[-1]) 95 return pil_image_grid
96
[docs] 97 def get_generator(self, number_of_images: int = 1): 98 """An image generator. 99 100 This method is a generator that will generate a batch of images. At 101 every call, the generator denoises the images by one more step until the 102 image is fully generated. This can be particularly useful for running the 103 image generation step by step or for a streaming API. 104 105 Args: 106 number_of_images: The number of images the generator should generate. 107 """ 108 images = torch.randn( 109 ( 110 number_of_images, 111 self.image_channels, 112 self.image_size, 113 self.image_size, 114 ), 115 device=self.device, 116 ) 117 118 for i in self.model.diffuser.steps: 119 timestep = self.model.diffuser.get_timestep(images.shape[0], i) 120 images = self.model.diffuser._denoise_step( 121 images, model=self.model, timestep=timestep 122 ) 123 images = torch.clamp(images, -1.0, 1.0) 124 125 yield self._visualise_images(images)
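As a rough usage sketch (not part of the module above): load_pretrained_model is a hypothetical placeholder for however a trained BaseDiffusionModel is obtained in this project, and the reverse transform assumes the training images were normalised to [-1, 1].

# Usage sketch, under the assumptions stated above.
import torch
from torchvision import transforms

from diffusion_models.diffusion_inference import DiffusionInference

model = load_pretrained_model()  # hypothetical helper returning a trained BaseDiffusionModel

inference = DiffusionInference(
    model=model,
    # Map model outputs from [-1, 1] back to [0, 1] (assumed training normalisation).
    reverse_transforms=transforms.Lambda(lambda t: (t + 1) / 2),
    image_shape=(3, 64),
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# One-shot generation: a PIL grid of 10 samples, optionally saving the denoising GIF.
grid = inference.generate(number_of_images=10, save_gif=False)
grid.save("samples.png")

# Step-by-step generation, e.g. for streaming partial results to a client.
for partial_grid in inference.get_generator(number_of_images=4):
    partial_grid.save("latest_step.png")  # overwritten after every denoising step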