examples.infer
Running inference.
Below is a code example of an inference script. Simply plug in your own checkpoint and start running inference. To extend this further, you might want to look into exporting or compiling your model to ONNX, TensorRT, OpenVINO, or other formats.
1from torchvision import datasets
2from torchvision.transforms import v2
3
4from diffusion_models.diffusion_trainer import DiffusionTrainer
5from diffusion_models.gaussian_diffusion.beta_schedulers import (
6 LinearBetaScheduler,
7)
8from diffusion_models.gaussian_diffusion.gaussian_diffuser import (
9 GaussianDiffuser,
10)
11from diffusion_models.models.SimpleUnet import SimpleUnet
12from diffusion_models.utils.schemas import LogConfiguration
13from diffusion_models.utils.schemas import TrainingConfiguration
14
15
16if __name__ == "__main__":
17 image_size = 64
18 image_channels = 3
19
20 training_configuration = TrainingConfiguration(
21 batch_size=256,
22 learning_rate=2 * 10e-4,
23 number_of_epochs=500,
24 training_name="ReworkedFrameworkBase",
25 checkpoint_rate=100,
26 mixed_precision_training=False,
27 # gradient_clip=0.1,
28 )
29 log_configuration = LogConfiguration(
30 log_rate=10,
31 image_rate=635,
32 number_of_images=5,
33 )
34 model = SimpleUnet(
35 image_channels=image_channels,