Fine-tune Stable Diffusion model on TIR

In this tutorial, we will fine-tune Stability AI’s Stable Diffusion (v2.1) model with the DreamBooth and Textual Inversion training methods, using the Hugging Face Diffusers library. With just 3-5 images, we can teach Stable Diffusion new concepts and personalize the model on our own images.

About Training methods

The Stable Diffusion model can be fine-tuned with either of two training methods: DreamBooth or Textual Inversion. Before training the model, let us get a brief overview of the two methods.

DreamBooth

DreamBooth is a way to fine-tune Stable Diffusion to generate a subject in different environments and styles. It teaches the diffusion model about a specific object or style using approximately three to five example images. After the model is fine-tuned on a specific object, it can produce images containing that object in new settings. You can read the original paper here. Here’s how it works:

  • You provide a set of reference images to teach the model how the subject looks

  • You re-train the model to learn to associate that subject with a specific word

  • You prompt the new model using the special word

With DreamBooth, we train an entirely new version of Stable Diffusion. The result is a fully self-contained model that can produce very good results, at the cost of a larger model.

Textual Inversion

The Textual Inversion training method captures new concepts from a small number of example images and associates the concepts with new words in the textual embedding space of the pre-trained model. Using only 3-5 images of a user-provided concept, like an object or a style, we learn to represent it through new “words” in the textual embedding space. These “words” can be used in natural-language text prompts, guiding personalized image creation in an intuitive way. You can read the original paper here. Here’s how it works:

  • You provide a set of reference images of your subject

  • You figure out what specific embedding the model associates with those images

  • You prompt the new model using the embedding

Textual Inversion is akin to finding a “special word” that gets your model to produce the images that you want.
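The idea can be illustrated with a toy sketch in plain Python. This is not the Diffusers implementation: the dictionary is a hypothetical stand-in for the text encoder's embedding table, and the squared distance to a made-up target vector stands in for the real diffusion loss. The helper names (`add_placeholder_token`, `train_placeholder`) and all numbers are illustrative assumptions. What it shows is the key property of Textual Inversion: a new token is added, initialized from an existing word, and only its embedding row is optimized while every other embedding stays frozen.

```python
import copy


def add_placeholder_token(vocab, placeholder, initializer):
    """Add `placeholder` to a toy embedding table, initialized from `initializer`."""
    if placeholder in vocab:
        raise ValueError(f"{placeholder} already exists in the vocabulary")
    vocab[placeholder] = list(vocab[initializer])  # copy the row, don't alias it
    return vocab


def train_placeholder(vocab, placeholder, target, lr=0.1, steps=200):
    """Gradient descent on ONE embedding row; every other row stays frozen.

    The squared distance to `target` is a stand-in for the real diffusion loss.
    """
    vec = vocab[placeholder]
    for _ in range(steps):
        # gradient of sum((v - t)^2) with respect to v is 2 * (v - t)
        grad = [2.0 * (v - t) for v, t in zip(vec, target)]
        vec = [v - lr * g for v, g in zip(vec, grad)]
    vocab[placeholder] = vec
    return vocab


vocab = {"toy": [0.5, -0.2, 0.1], "cat": [0.9, 0.3, -0.4]}  # toy 3-d embeddings
add_placeholder_token(vocab, "<cat-toy>", initializer="toy")
train_placeholder(vocab, "<cat-toy>", target=[1.0, 0.25, 0.5])
print([round(v, 4) for v in vocab["<cat-toy>"]])  # [1.0, 0.25, 0.5]
print(vocab["toy"])  # [0.5, -0.2, 0.1] (unchanged)
```

In real Textual Inversion training the same principle applies to the text encoder's token embeddings, and the loss is the diffusion denoising loss on your example images.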

Fine tuning the model

Step-1: Launch a Notebook on TIR

Before beginning, let us launch a Notebook on TIR to train our model. To launch the notebook follow the steps below:

  • Go to TIR AI Platform

  • Choose a Project, navigate to the Notebooks section and click the Create Notebook button

  • Let us name the notebook as stable-diffusion-fine-tune. You can give any name of your choice.

  • Notebook Image & Machine: Launch a new Notebook with Diffusers Image and a GPU Machine plan. To fine-tune the model, we recommend choosing a GPU plan from one of A100, A40 or A30 series (e.g.: GDC3.A10080-16.115GB)

  • Disk Size: A default Disk Size of 30GB would be sufficient for our use-case. In case you need more space, set the disk size accordingly.

  • Datasets: If you have your training data stored in a TIR Dataset, select that dataset to be able to use it during fine tuning.

  • Create the notebook

  • Launch Notebook: Once the notebook reaches the Running state, launch it by clicking the three dots (...) and start the jupyter labs environment. (The sidebar on the left also displays quick launch links for your notebooks.)

  • In the jupyter labs, create a new Python3 Notebook from the New Launcher window. A new .ipynb file will open in jupyter labs.

Our notebook is ready. Now, let’s proceed with the model training.

Note

This tutorial covers both the DreamBooth and Textual Inversion training methods. Most of the steps and code are the same for both methods; where they differ, we have used tabs to separate the two training methods.

Step-2: Initial Setup

Install the required libraries

!pip install -qq ftfy

Import the required libraries and packages

import argparse
import itertools
import math
import os
from contextlib import nullcontext
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

import bitsandbytes as bnb

Now, let’s define a helper function that we will need during the training and inference process.

def image_grid(imgs, rows, cols):
    '''display multiple images in a grid'''
    assert len(imgs) == rows*cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid
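`image_grid` pastes image `i` at column `i % cols` and row `i // cols` on a canvas of `cols*w` by `rows*h` pixels. The hypothetical helper below, `grid_layout`, reproduces just that arithmetic without PIL, so you can check where each image will land before building a grid.

```python
def grid_layout(n_images, rows, cols, w, h):
    """Return (canvas_size, offsets) as computed by image_grid (hypothetical helper)."""
    if n_images != rows * cols:
        raise ValueError("n_images must equal rows * cols")
    canvas_size = (cols * w, rows * h)
    # top-left (x, y) corner of each pasted image, filled row by row
    offsets = [(i % cols * w, i // cols * h) for i in range(n_images)]
    return canvas_size, offsets


size, offsets = grid_layout(4, rows=2, cols=2, w=512, h=512)
print(size)     # (1024, 1024)
print(offsets)  # [(0, 0), (512, 0), (0, 512), (512, 512)]
```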

Step-3: Settings for teaching the new concept

3.1. Defining the Model

Specify the Stable Diffusion model that we want to train. For this tutorial, we will use Stable Diffusion v2.1 (stabilityai/stable-diffusion-2-1).

# `pretrained_model_name_or_path` defines the Stable Diffusion checkpoint you want to use.
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1"  # using Stable Diffusion v2.1
# some other models that you can use: ["stabilityai/stable-diffusion-2-1", "stabilityai/stable-diffusion-2", "stabilityai/stable-diffusion-2-base", "CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5"]

3.2. Input Images

To teach the model a new concept, we need 3-5 input images to train it on. Let’s gather and set up the images for training.

The input images can come from multiple sources. Set up the input images using any one of the following three sources:

  • Images available on the Internet

    Using public image URLs, we can download the images from the internet and save them locally on the notebook.

    # Add the URLs to the images of the concept you are adding. 3-5 should be fine
    urls = [
        "https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg",
        "https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg",
        "https://huggingface.co/datasets/valhalla/images/resolve/main/5.jpeg",
        "https://huggingface.co/datasets/valhalla/images/resolve/main/6.jpeg",
        ## You can add additional image urls here
    ]
    
    # Download the images from urls
    import requests
    from io import BytesIO
    
    def download_image(url):
        '''download an image and return it as an RGB PIL Image, or None if the download fails'''
        try:
            response = requests.get(url, timeout=30)
            response.raise_for_status()  # treat HTTP errors (e.g. 404) as failures
        except requests.RequestException:
            return None
        return Image.open(BytesIO(response.content)).convert("RGB")
    
    images = list(filter(None, [download_image(url) for url in urls]))
    
    # save the images in a directory
    save_path = "./my_concept"
    os.makedirs(save_path, exist_ok=True)
    for i, image in enumerate(images):
        image.save(f"{save_path}/{i}.jpeg")
    
  • Images stored in a TIR Dataset

    To use the images stored in a TIR Dataset, ensure that the particular dataset containing your images is mounted on the notebook.

    Note

    How to check if your Dataset is mounted on the notebook?
    Go to TIR Dashboard >> Notebooks Section >> Select the notebook >> Go to the Associated Datasets Tab
    You will be able to see the list of Datasets mounted on your notebook. If your desired dataset is not in the list, you can mount it now.

    You can also open the terminal and list all the mounted datasets using the ls /datasets/ command.

    If you have updated the mounted dataset, wait for your notebook to be back in Running state, and start afresh from Step-2.

    P.S.: The dataset and the notebook must be in the same project; otherwise, you cannot mount the dataset on the notebook.

    Specify your dataset name and the path of the directory containing the images in the code block below, and run the code cell.

    dataset_name = ""  # enter the dataset name
    images_dir = ""  # enter path to directory containing the training images
    
    images_path = os.path.join("/datasets/", str(dataset_name), str(images_dir))  # "/datasets/{dataset_name}/{images_dir}"
    if not dataset_name:
        print("dataset_name must be provided")
    elif not os.path.exists(str(images_path)):
        print('The images_path specified does not exist. Check the path and try again')
    else:
        save_path = images_path
    
  • Images present on your local system

    You can load your own training images by uploading them to the notebook using the Upload Files option.

    [Figure: the Upload Files option in jupyter labs]
    # `images_path` is a path to directory containing the training images.
    images_path = ""  # type: "string"
    if not os.path.exists(str(images_path)):
        print('The images_path specified does not exist. Check the path and try again')
    else:
        save_path = images_path
    

Note

Make sure that your input image directory contains only the input images. The code reads every file in the provided directory, and any other file will interfere with the training process.
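One way to catch stray entries before training is a small check like the sketch below. The helper name `find_non_images` and the set of accepted extensions are assumptions; adjust them to your data. It also flags subdirectories, such as the hidden `.ipynb_checkpoints` folder that jupyter labs can create, because the training dataset reads every entry of the directory.

```python
import os

# Assumed set of image extensions; extend it if your images use another format.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}


def find_non_images(directory):
    """Return the sorted names of entries in `directory` that are not image files."""
    stray = []
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        extension = os.path.splitext(name)[1].lower()
        # directories (e.g. .ipynb_checkpoints) and unknown extensions are both flagged
        if not os.path.isfile(path) or extension not in IMAGE_EXTENSIONS:
            stray.append(name)
    return stray
```

Running `find_non_images(save_path)` should return an empty list before you start training; anything it lists should be moved out of the directory.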

Before proceeding further, let us first check our training images.

images = []
for file_path in sorted(os.listdir(save_path)):
    image_path = os.path.join(save_path, file_path)
    try:
        images.append(Image.open(image_path).resize((512, 512)))
    except OSError:  # raised for directories and for files PIL cannot read as images
        print(f"{image_path} is not a valid image, please make sure to remove this file from the directory otherwise the training could fail.")
image_grid(images, 1, len(images))

Output:

[Figure: grid of the cat-toy training images]

3.3. Settings for the new concept

# `instance_prompt` should contain a good description of what your object or style is, together with the unique identifier `<cat-toy>`
instance_prompt = "<cat-toy> toy"  # type: "string"

# Set `prior_preservation` to True if you want the class of the concept (e.g.: toy, dog, painting) to be preserved. This increases the quality and helps with generalization at the cost of training time
prior_preservation = False  # type: bool
prior_preservation_class_prompt = "a photo of a cat clay toy"  # type: "string"

num_class_images = 12
sample_batch_size = 2
prior_loss_weight = 0.5
prior_preservation_class_folder = "./class_images"
class_data_root = prior_preservation_class_folder
class_prompt = prior_preservation_class_prompt

Advanced settings for prior preservation (optional)

num_class_images = 12  # type: int
sample_batch_size = 2

# `prior_loss_weight` determines how strong the class for prior preservation should be
prior_loss_weight = 1  # type: int

# If the `prior_preservation_class_folder` is empty, images for the class will be generated with the class prompt. Otherwise, fill this folder with images of items of the same class as your concept (but not images of the concept itself)
prior_preservation_class_folder = "./class_images"  # type: string
class_data_root = prior_preservation_class_folder
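`prior_loss_weight` scales the class (prior) term of the loss. With prior preservation on, each training batch holds the instance examples followed by the same number of class examples, and the training function in Step-4 splits the predictions in half and adds the weighted prior loss to the instance loss. A simplified sketch of that combination, using one number per example instead of a latent tensor (the helper names are hypothetical):

```python
def mse(pred, target):
    """Mean squared error over two equal-length lists."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def dreambooth_loss(pred, target, prior_loss_weight):
    """Instance loss + prior_loss_weight * prior loss.

    `pred` and `target` are ordered [instance examples..., class examples...],
    mirroring how the batch is concatenated before the UNet forward pass.
    """
    if len(pred) != len(target) or len(pred) % 2:
        raise ValueError("expected equal, even-length prediction and target lists")
    half = len(pred) // 2
    instance_loss = mse(pred[:half], target[:half])
    prior_loss = mse(pred[half:], target[half:])
    return instance_loss + prior_loss_weight * prior_loss


print(dreambooth_loss([1.0, 2.0, 3.0, 4.0], [0.0] * 4, prior_loss_weight=0.5))  # 8.75
```

A higher weight pushes the model harder to keep generating ordinary members of the class; with a weight of 0 the loss ignores the class images entirely.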

Step-4: Teach the model the new concept (Fine-tuning with the training method)

Run the following cells in order to start the training process. The whole process may take from 30 minutes to 3 hours. The cells also show how the process works under the hood, and you can change the advanced training settings or hyperparameters here.

4.1. Setup for Training

Set up the dataset classes

from pathlib import Path
from torchvision import transforms

class DreamBoothDataset(Dataset):
    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(Path(class_data_root).iterdir())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example


class PromptDataset(Dataset):
    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example
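Because `__len__` returns the larger of the two image counts when class images are used, `__getitem__` cycles through the smaller set with the modulo operator. The hypothetical helper below reproduces only that index arithmetic, to show which files each dataset index reads.

```python
def dataset_index_map(num_instance_images, num_class_images=None):
    """Return the (instance_idx, class_idx) pair DreamBoothDataset reads for each index.

    `num_class_images=None` corresponds to training without prior preservation.
    """
    if num_class_images is None:
        length = num_instance_images
    else:
        length = max(num_class_images, num_instance_images)
    pairs = []
    for index in range(length):
        instance_idx = index % num_instance_images
        class_idx = None if num_class_images is None else index % num_class_images
        pairs.append((instance_idx, class_idx))
    return pairs


print(dataset_index_map(4, 6))
# [(0, 0), (1, 1), (2, 2), (3, 3), (0, 4), (1, 5)]
```

With 4 instance images and 12 class images, for example, each instance image is read three times per epoch.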

Generate Class Images

import gc

if(prior_preservation):
    class_images_dir = Path(class_data_root)
    if not class_images_dir.exists():
        class_images_dir.mkdir(parents=True)
    cur_class_images = len(list(class_images_dir.iterdir()))

    if cur_class_images < num_class_images:
        pipeline = StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path, revision="fp16", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.enable_attention_slicing()
        pipeline.set_progress_bar_config(disable=True)

        num_new_images = num_class_images - cur_class_images
        print(f"Number of class images to sample: {num_new_images}.")

        sample_dataset = PromptDataset(class_prompt, num_new_images)
        sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=sample_batch_size)

        for example in tqdm(sample_dataloader, desc="Generating class images"):
            images = pipeline(example["prompt"]).images

            for i, image in enumerate(images):
                image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
        pipeline = None
        gc.collect()
        del pipeline
        with torch.no_grad():
            torch.cuda.empty_cache()
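The generation loop above only creates the missing `num_class_images - cur_class_images` images, and it numbers the new files starting at `cur_class_images`. If the existing class images follow the same `0.jpg`, `1.jpg`, ... naming, they are kept rather than overwritten. The hypothetical helper below reproduces that naming scheme without running the pipeline.

```python
def new_class_image_names(cur_class_images, num_class_images, sample_batch_size):
    """File names the class-image loop would write, batch by batch (hypothetical helper)."""
    num_new_images = max(0, num_class_images - cur_class_images)
    names = []
    for start in range(0, num_new_images, sample_batch_size):
        # PromptDataset yields indices 0..num_new_images-1; each batch covers a slice of them
        batch_indices = range(start, min(start + sample_batch_size, num_new_images))
        names.extend(f"{index + cur_class_images}.jpg" for index in batch_indices)
    return names


print(new_class_image_names(cur_class_images=9, num_class_images=12, sample_batch_size=2))
# ['9.jpg', '10.jpg', '11.jpg']
```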

Load the Stable Diffusion Model

# Load models and create wrapper for stable diffusion
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder"
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae"
)
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet"
)
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="tokenizer",
)
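With `train_text_encoder=False` (the setting used in this tutorial), only the UNet is trained: the VAE encodes each training image into a smaller latent tensor, and the UNet learns to denoise in that latent space. Assuming the standard Stable Diffusion VAE, which downsamples by a factor of 8 into 4 latent channels, the latent shape for a batch can be computed as below (the helper is hypothetical). A 768-pixel image therefore has 2.25 times as many latent elements as a 512-pixel one, which is one reason higher resolutions need more GPU memory.

```python
def latent_shape(batch_size, resolution, latent_channels=4, downsample_factor=8):
    """Shape of the VAE latents for square images (assumes the standard SD VAE)."""
    if resolution % downsample_factor:
        raise ValueError(f"resolution must be a multiple of {downsample_factor}")
    side = resolution // downsample_factor
    return (batch_size, latent_channels, side, side)


print(latent_shape(2, 512))  # (2, 4, 64, 64)
print(latent_shape(2, 768))  # (2, 4, 96, 96)
```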

4.2. Training

Set up the training arguments and define the hyperparameters for the training. You can also tune hyperparameters such as learning_rate and max_train_steps to experiment.

from argparse import Namespace

args = Namespace(
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    resolution=vae.sample_size,
    center_crop=True,
    train_text_encoder=False,
    instance_data_dir=save_path,
    instance_prompt=instance_prompt,
    learning_rate=5e-06,
    max_train_steps=300,
    save_steps=50,
    train_batch_size=2,  # set to 1 if using prior preservation
    gradient_accumulation_steps=2,
    max_grad_norm=1.0,
    mixed_precision="fp16",  # set to "fp16" for mixed-precision training.
    gradient_checkpointing=True,  # set this to True to lower the memory usage.
    use_8bit_adam=True,  # use 8bit optimizer from bitsandbytes
    seed=3434554,
    with_prior_preservation=prior_preservation,
    prior_loss_weight=prior_loss_weight,
    sample_batch_size=2,
    class_data_dir=prior_preservation_class_folder,
    class_prompt=prior_preservation_class_prompt,
    num_class_images=num_class_images,
    lr_scheduler="constant",
    lr_warmup_steps=100,  # only used by warmup schedulers such as "constant_with_warmup"; ignored by "constant"
    output_dir="dreambooth-concept",
)
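The training function turns these arguments into an effective batch size and a number of epochs. The sketch below reproduces that arithmetic for a single GPU, assuming the DataLoader keeps its last incomplete batch (the PyTorch default); the helper name `training_schedule` is hypothetical.

```python
import math


def training_schedule(num_instance_images, train_batch_size, gradient_accumulation_steps,
                      max_train_steps, num_class_images=0):
    """Return (total_batch_size, updates_per_epoch, num_train_epochs) for one GPU.

    Pass num_class_images > 0 only when prior preservation is on; the dataset
    length then becomes max(num_class_images, num_instance_images), and each
    batch also carries one class image per instance image.
    """
    dataset_len = max(num_class_images, num_instance_images)
    batches_per_epoch = math.ceil(dataset_len / train_batch_size)  # last partial batch kept
    updates_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
    num_train_epochs = math.ceil(max_train_steps / updates_per_epoch)
    total_batch_size = train_batch_size * gradient_accumulation_steps
    return total_batch_size, updates_per_epoch, num_train_epochs


# settings above with 4 instance images, no prior preservation
print(training_schedule(4, train_batch_size=2, gradient_accumulation_steps=2, max_train_steps=300))
# (4, 1, 300)

# prior preservation with 12 class images and train_batch_size=1
print(training_schedule(4, train_batch_size=1, gradient_accumulation_steps=2, max_train_steps=300,
                        num_class_images=12))
# (2, 6, 50)
```

With the default settings and 4 images, one optimizer update consumes the whole dataset, so `max_train_steps=300` amounts to 300 passes over the images.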

Define the training function

from accelerate.utils import set_seed

def training_function(text_encoder, vae, unet):
    logger = get_logger(__name__)

    set_seed(args.seed)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
    )

    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
    )

    # Load the noise scheduler from the model repository
    # (from_config expects a config dict, so a repo path goes to from_pretrained)
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # concat class and instance examples for prior preservation
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            return_tensors="pt",
            max_length=tokenizer.model_max_length
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }
        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
    )

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encoder and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    vae.decoder.to("cpu")  # only the VAE encoder is needed during training
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0

    for epoch in range(num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                if args.with_prior_preservation:
                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                    target, target_prior = torch.chunk(target, 2, dim=0)

                    # Compute instance loss
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

                    # Compute prior loss
                    prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    params_to_clip = (
                        itertools.chain(unet.parameters(), text_encoder.parameters())
                        if args.train_text_encoder
                        else unet.parameters()
                    )
                    # clip the gradients of every trained model, not only the UNet
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.save_steps == 0:
                    if accelerator.is_main_process:
                        pipeline = StableDiffusionPipeline.from_pretrained(
                            args.pretrained_model_name_or_path,
                            unet=accelerator.unwrap_model(unet),
                            text_encoder=accelerator.unwrap_model(text_encoder),
                        )
                        checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        pipeline.save_pretrained(checkpoint_dir)

            logs = {"loss": loss.detach().item()}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=accelerator.unwrap_model(unet),
            text_encoder=accelerator.unwrap_model(text_encoder),
        )
        pipeline.save_pretrained(args.output_dir)
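Inside the training loop, `noise_scheduler.add_noise` implements the forward diffusion step, and the loss target depends on the scheduler's `prediction_type`. The plain-Python sketch below writes both formulas out for a single timestep. It assumes a "scaled_linear" beta schedule with 1000 steps between 0.00085 and 0.012 (the values in the Stable Diffusion scheduler configs); in practice, read the real values from `noise_scheduler.config` rather than relying on these. The helper names are hypothetical.

```python
import math


def scaled_linear_alphas_cumprod(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
    """Cumulative products alpha_bar_t of a 'scaled_linear' beta schedule."""
    start, end = math.sqrt(beta_start), math.sqrt(beta_end)
    alphas_cumprod, product = [], 1.0
    for t in range(num_train_timesteps):
        # betas are spaced linearly in sqrt space, then squared
        beta = (start + (end - start) * t / (num_train_timesteps - 1)) ** 2
        product *= 1.0 - beta
        alphas_cumprod.append(product)
    return alphas_cumprod


def add_noise(latents, noise, alpha_bar):
    """Forward diffusion: x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * noise."""
    a, b = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return [a * x + b * n for x, n in zip(latents, noise)]


def velocity(latents, noise, alpha_bar):
    """v-prediction target: v = sqrt(alpha_bar) * noise - sqrt(1 - alpha_bar) * x_0."""
    a, b = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return [a * n - b * x for x, n in zip(latents, noise)]


alphas_cumprod = scaled_linear_alphas_cumprod()
x0, eps = [0.5, -1.0], [0.3, 0.7]  # toy "latents" and sampled noise
t = 500
noisy = add_noise(x0, eps, alphas_cumprod[t])
target_epsilon = eps                              # prediction_type == "epsilon"
target_v = velocity(x0, eps, alphas_cumprod[t])   # prediction_type == "v_prediction"
```

Because x_0 = sqrt(alpha_bar) * x_t - sqrt(1 - alpha_bar) * v, a model trained to predict v can still recover the clean latents, just as an epsilon-prediction model can.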

4.3. Run the Training

import accelerate

num_gpus = 1  # specify the number of GPUs you would like to use. This also depends on your machine config
accelerate.notebook_launcher(training_function, args=(text_encoder, vae, unet), num_processes=num_gpus)

# free the memory held by gradients once training is done
for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
    if param.grad is not None:
        del param.grad
torch.cuda.empty_cache()

Step-5: Run Inference with the newly trained Model

Bravo! Our model training is complete and we have taught our model a new concept, namely, <cat-toy>. Let us now test our newly trained model by running inference against it and see the results.

5.1. Set up the pipeline

The newly trained model artifacts are available in the output directory specified while setting up the training args. We will use those artifacts to load the model and set up the pipeline.

from diffusers import DPMSolverMultistepScheduler

output_dir = args.output_dir
pipe = StableDiffusionPipeline.from_pretrained(
    output_dir,
    scheduler=DPMSolverMultistepScheduler.from_pretrained(output_dir, subfolder="scheduler"),
    torch_dtype=torch.float16,
).to("cuda")

5.2. Run the Stable Diffusion pipeline

# Don't forget to use the placeholder token in your prompt
prompt = "a <cat-toy> in mad max fury road"  # type: str

num_samples = 2  # type: int : number of samples to generate for the prompt

images = pipe(
    prompt, num_images_per_prompt=num_samples, num_inference_steps=30, guidance_scale=9
).images
# you can play around with the newly trained model by changing the `prompt`, `num_inference_steps` and `guidance_scale` and see the difference in results.

grid = image_grid(images, rows=1, cols=num_samples)
grid

Output:

[Figure: images generated for the prompt "a <cat-toy> in mad max fury road"]

Step-6: Save the newly created concept

Our model training is complete and we have also tested it by successfully generating some image samples. We shall now save the model artifacts to preserve our newly created concept. This will also enable us to launch an Inference server against it, as described in the next section.

6.1. Save the Model artifacts

model_artifacts_dir = "dreambooth_model"  # specify the directory to save the model artifacts
os.makedirs(model_artifacts_dir, exist_ok=True)
pipe.save_pretrained(model_artifacts_dir)  # save the model artifacts

6.2. Create Model on TIR

  • Go to TIR AI Platform

  • Choose a project, navigate to the Models section and click on Create Model

  • Enter a model name of your choosing (e.g. stable-diffusion)

  • Select Model Type as Custom & click on CREATE

  • You will now see the details of the EOS (E2E Object Storage) bucket created for this model.

  • EOS provides an S3-compatible API to upload or download content. We will be using the MinIO CLI in this tutorial.

  • Copy the Setup Host command from the Setup MinIO CLI tab. We will use it to set up the MinIO CLI in the next step.

    Note

    In case you forget to copy the setup host command for MinIO CLI, don’t worry. You can always go back to model details and get it again.

6.3. Upload the Artifacts to Model Bucket

We have already created the model on TIR. Let’s set it up on our notebook and upload the artifacts to the model.

Let’s again go back to our TIR Notebook named stable-diffusion-fine-tune, that we used to fine-tune the model in the previous steps and follow the steps below:

  • In the jupyter labs, click New Launcher and select Terminal

  • Set up the MinIO CLI by running the Setup Host command copied in Step-6.2

  • Go to the directory model_artifacts_dir, where we saved the model artifacts in step-6.1 using the below command.

    export model_artifacts_dir=dreambooth_model && \
    cd $model_artifacts_dir
    
  • Copy the model artifacts to the model bucket.
    • Go to TIR Dashboard >> Models >> Select your model >> Copy the cp command from Setup MinIO CLI tab

    • The copy command would look like this: mc cp -r <MODEL_NAME> stable-diffusion/stable-diffusion-854588

    • Here we will replace <MODEL_NAME> with ‘*’ to upload all contents of the current folder to the bucket

    • We will append $model_artifacts_dir/ to the bucket path to upload the contents to a folder named $model_artifacts_dir inside the bucket, so that your copy command would look something like this:

    mc cp -r * stable-diffusion/stable-diffusion-854588/$model_artifacts_dir/
    

    Note

    The above command to copy model artifacts should not be used as is. It is just a sample command to show what the command would look like.

    You may follow the above steps to generate the copy command that will be specific to your model bucket.

Your model artifacts will be saved successfully upon upload completion.

Step-7: Create Inference Server against our newly trained model

What now remains is to create an Inference Server against our trained model and serve API requests.

Note

If you are not familiar with creating inference endpoints on TIR, follow this tutorial for a detailed, step-by-step guide on Model Endpoint (Inference) Creation for Stable Diffusion.

7.1. Create a new Model Endpoint in TIR

  • Create a new Model Endpoint in TIR with Stable Diffusion Framework and a GPU Machine Plan.

  • [Important!] In the Model Details Subsection, choose the Model that we created in Step-6.2. Make sure to specify the $model_artifacts_dir path. This is necessary, because our trained model artifacts are present in this directory.

7.2. Generate your API Token

The model endpoint API requires a valid auth token which you’ll need to perform further steps. So, generate a new API Token from the API Tokens Section (or use an existing token, if already created). An api-key and an auth token will be generated. Copy this auth token. You will need it in the next step.

7.3. Make API Request to generate Image output

The final step is to send API requests to the created model endpoint & generate images using text prompts. We will use TIR Notebook to do the same.

  • Once your model is Ready, visit the Sample API Request section of that model and copy the Python code

  • Launch a TIR Notebook with PyTorch or StableDiffusion Image with any basic machine plan. Once it is in Running state, launch it, and start a new notebook untitled.ipynb in the jupyter labs

  • Paste the Sample API Request code (for Python) in the notebook cell

  • Copy the Auth Token generated in Step-7.2 & use it in place of $AUTH_TOKEN in the Sample API Request

  • Replace the prompt string in the payload with the prompt below. This ensures that the placeholder token (in our case, <cat-toy>) is used in the prompt to get the desired output.

    "a <cat-toy> in mad max fury road"
    
  • Execute the code and send the request. You’ll get a list of tensors as output.
    This is because the Stable Diffusion v2.1 model endpoint returns the generated images as a list of PyTorch Tensors.

  • To view the generated images, copy the below code, paste it in the notebook cell and execute it. You’ll be able to view the generated images.

    import torch
    import torchvision.transforms as transforms
    
    def display_images(tensor_image_data_list):
        '''convert PyTorch Tensors to PIL Image'''
        for tensor_data in tensor_image_data_list:
            tensor_image = torch.tensor(tensor_data.get("data"))  # initialise the tensor
            pil_img = transforms.ToPILImage()(tensor_image)  # convert to PIL Image
            pil_img.show()
            # to save the generated image, uncomment the line below
            # pil_img.save(tensor_data.get("name"))
    
    if response.status_code == 200:
        display_images(response.json().get("predictions"))
    

    Output:

    [Figure: images returned by the model endpoint for the prompt "a <cat-toy> in mad max fury road"]

That’s it! We have successfully taught our model a new concept, <cat-toy>.

The Stable Diffusion model supports various other parameters for controlling the generation of image output. Refer to Supported parameters for image generation for more details.

Conclusion

Through this tutorial, we fine-tuned the Stable Diffusion v2.1 model using two training methods, namely, DreamBooth and Textual Inversion. By giving just 3-5 images as input, we taught the model a new concept and personalized it on our own images.

After fine-tuning, we also saw how to store the trained model artifacts in TIR Model Storage and launch Inference Servers with them to serve API requests.