examples.infer
Running inference.
Below is a code example of an inference script. Simply plug in your own checkpoint and start running inference. To extend this further, you might want to look into exporting your model to ONNX, TensorRT, OpenVINO, or other deployment formats.
import torch
from torchvision.transforms import v2

from diffusion_models.diffusion_inference import DiffusionInference
from diffusion_models.gaussian_diffusion.gaussian_diffuser import (
    GaussianDiffuser,
)
from diffusion_models.models.SimpleUnet import SimpleUnet
from diffusion_models.utils.schemas import Checkpoint


def main() -> None:
    """Load a checkpoint, rebuild the model and generate a batch of images.

    Steps, in order:
      1. Read the checkpoint from ``checkpoint_file_path``.
      2. Rebuild the diffuser from the checkpoint.
      3. Instantiate ``SimpleUnet`` and load the stored weights.
      4. Run ``DiffusionInference.generate`` for 25 images with
         ``save_gif=True``.
    """
    # Replace with the path to your own trained checkpoint.
    checkpoint_file_path = "your_checkpoint.pt"

    checkpoint = Checkpoint.from_file(checkpoint_file_path)

    # Switch to DdimDiffuser for faster inference.
    gaussian_diffuser = GaussianDiffuser.from_checkpoint(checkpoint)

    model = SimpleUnet(
        image_channels=checkpoint.image_channels, diffuser=gaussian_diffuser
    )
    model.load_state_dict(checkpoint.model_state_dict)
    # Optional: compile the model. ``nn.Module.compile`` works in place and
    # returns None, so do NOT assign its result back to ``model``.
    # model.compile(mode="reduce-overhead", fullgraph=True)

    # Post-processing applied to generated tensors: rescale with (x + 1) / 2
    # (presumably from [-1, 1] to [0, 1] -- confirm against the training
    # transforms), then resize to 128x128.
    reverse_transforms = v2.Compose(
        [
            v2.Lambda(lambda x: (x + 1) / 2),
            v2.Resize((128, 128)),
        ]
    )

    # Prefer the GPU, but keep the example runnable on machines without CUDA.
    # NOTE(review): assumes DiffusionInference accepts "cpu" -- confirm.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    inference = DiffusionInference(
        model=model,
        device=device,
        reverse_transforms=reverse_transforms,
    )
    inference.generate(number_of_images=25, save_gif=True)


if __name__ == "__main__":
    main()