# Fine-tune Stable Diffusion Model on TIR

In this tutorial, we will fine-tune Stability AI's Stable Diffusion (v2.1) model with the DreamBooth and Textual Inversion training methods, using the Hugging Face Diffusers library. With just 3-5 images, we can teach Stable Diffusion new concepts and personalize the model on our own images.

## About Training Methods

The Stable Diffusion model can be fine-tuned using either of two training methods: DreamBooth or Textual Inversion. Before training the model, let us get a brief overview of the two methods.

### DreamBooth

DreamBooth is a way to fine-tune Stable Diffusion to generate a subject in different environments and styles. It teaches the diffusion model about a specific object or style using approximately three to five example images. After fine-tuning, the model can produce images containing that object in new settings. You can read the original paper [here](https://arxiv.org/pdf/2208.12242.pdf).

Here's how it works:

- You provide a set of reference images to teach the model how the subject looks.
- You re-train the model to learn to associate that subject with a specific word.
- You prompt the new model using the special word.

With DreamBooth, we're training an entirely new version of Stable Diffusion. The result is a self-sufficient model that can produce high-quality images of the learned subject on its own.

### Textual Inversion

The Textual Inversion training method captures new concepts from a small number of example images and associates them with new words in the textual embedding space of the pre-trained model. Using only 3-5 images of a user-provided concept, such as an object or a style, we learn to represent it through new “words” in the textual embedding space. These “words” can then be used in natural language sentences as text prompts, guiding personalized image creation intuitively. You can read the original paper [here](https://arxiv.org/pdf/2208.01618.pdf).
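To make the idea of a new “word” concrete, here is a toy sketch of the core mechanism, not the Diffusers implementation: an embedding table gains one new row, that row is initialized from an existing word, and only that row is updated while every other row stays frozen. The table size, the values, and the squared-error target are made-up stand-ins (assumptions); in real Textual Inversion, the gradient comes from the diffusion model's denoising loss on your example images.

```python
import numpy as np

# Toy vocabulary: 4 existing "words", each with a 3-dimensional embedding (made-up values).
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(4, 3))
initializer_id = 2  # hypothetical id of the initializer word (e.g. "toy")

# Add one new row for the placeholder "word", initialized from the initializer word.
embeddings = np.vstack([embeddings, embeddings[initializer_id]])
placeholder_id = embeddings.shape[0] - 1
frozen = embeddings[:placeholder_id].copy()

# Stand-in objective (assumption): pull the new row toward a fixed target vector.
# In real training, the gradient would come from the diffusion loss instead.
target = np.array([1.0, -1.0, 0.5])
learning_rate = 0.1
for _ in range(100):
    grad = np.zeros_like(embeddings)
    # Gradient of the squared error, applied to the placeholder row only.
    grad[placeholder_id] = 2 * (embeddings[placeholder_id] - target)
    embeddings -= learning_rate * grad

print(np.allclose(embeddings[:placeholder_id], frozen))  # True: existing words are unchanged
print(np.round(embeddings[placeholder_id], 3))
```

The same pattern appears later in this tutorial: the placeholder token's embedding is copied from the initializer token, and the rest of the text encoder is frozen.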
Here’s how it works:

- You provide a set of reference images of your subject.
- You figure out what specific embedding the model associates with those images.
- You prompt the model using the embedding.

Textual Inversion is akin to finding a “special word” that prompts your model to produce the images you want.

## Fine-Tuning the Model

### Step 1: Launch a Notebook on TIR

Before beginning, let us launch a Notebook on TIR to train our model. To launch the notebook, follow the steps below:

1. Go to the [TIR AI Platform](https://tir.e2enetworks.com).
2. Choose a project, navigate to the Notebooks section, and click the **Create Notebook** button.
3. Name the notebook *stable-diffusion-fine-tune* (you can choose any name).
4. **Notebook Image & Machine:** Launch a new notebook with the *Diffusers Image* and a *GPU Machine plan*. For fine-tuning, we recommend choosing a GPU plan from the *A100, A40, or A30* series (e.g., GDC3.A10080-16.115GB).
5. **Disk Size:** The default *Disk Size* of 30GB is sufficient for our use case. Increase the disk size if you need more space.
6. **Datasets:** If your training data is stored in a *TIR Dataset*, select that dataset for use during fine-tuning.
7. Create the notebook.
8. **Launch Notebook:** Once the notebook is in the *Running* state, launch it by clicking the three dots (`...`) and start the JupyterLab environment. (The sidebar on the left also displays quick-launch links for your notebooks.)
9. In JupyterLab, create a new *Python3 Notebook* from the *New Launcher* window. A new *.ipynb* file will open.

Our notebook is ready. Now, let's proceed with the model training.

:::info Note
This tutorial covers both the *DreamBooth* and *Textual Inversion* training methods. Most steps and code are the same for both methods; where they differ, the code for each method is shown separately.
:::

## Step 2: Initial Setup

### Install the Required Libraries

```bash
!pip install -qq ftfy
```

Import the required libraries and packages:

import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

**DreamBooth:**

```python
import argparse
import itertools
import math
import os
from contextlib import nullcontext
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

import bitsandbytes as bnb
```

**Textual Inversion:**

```python
import argparse
import itertools
import math
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
```

Now, let's define a helper function that we will use during the training and inference process to display images in a grid.
```python
def image_grid(imgs, rows, cols):
    '''Display multiple images in a grid.'''
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    # Paste each image into its cell, filling the grid row by row
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
```

## Step 3: Settings for Teaching the New Concept

### 3.1. Defining the Model

Specify the *Stable Diffusion model* that we want to train. For this tutorial, we will use Stable Diffusion v2.1 (`stabilityai/stable-diffusion-2-1`).

```python
# `pretrained_model_name_or_path` defines the Stable Diffusion checkpoint you want to use.
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1"  # using Stable Diffusion v2.1

# Some other models that you can use:
# ["stabilityai/stable-diffusion-2", "stabilityai/stable-diffusion-2-base",
#  "CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5"]
```

### 3.2. Input Images

To teach the model a new concept, we need 3-5 input images to train it on, so let's gather and set up these images. You can load the input images from any one of the following three sources:
**Images available on the Internet**

Using public image URLs, we can download the images and save them locally in the notebook.

```python
import requests
from io import BytesIO

# Add the URLs of the images of the concept you are adding. 3-5 should be fine.
urls = [
    "https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg",
    "https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg",
    "https://huggingface.co/datasets/valhalla/images/resolve/main/5.jpeg",
    "https://huggingface.co/datasets/valhalla/images/resolve/main/6.jpeg",
    # You can add additional image URLs here
]

# Download the images from the URLs; failed downloads are skipped
def download_image(url):
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
    except requests.RequestException:
        return None
    return Image.open(BytesIO(response.content)).convert("RGB")

images = list(filter(None, [download_image(url) for url in urls]))

# Save the images in a directory
save_path = "./my_concept"
os.makedirs(save_path, exist_ok=True)
for i, image in enumerate(images):
    image.save(f"{save_path}/{i}.jpeg")
```
**Images stored in a TIR Dataset**

To use the images stored in a *TIR Dataset*, make sure that the dataset containing your images is mounted on the notebook.

:::info Note
**How to check if your dataset is mounted on the notebook?**

Go to TIR Dashboard >> Notebooks Section >> Select the notebook >> Go to the **Associated Datasets Tab**.

You will see the list of datasets mounted on your notebook. If your desired dataset is not in the list, you can mount it now. You can also open the terminal and list all the mounted datasets using the `ls /datasets/` command.

If you have updated the mounted datasets, wait for your notebook to return to the *Running* state, and start afresh from Step 2.

**P.S.:** The dataset and the notebook must be in the same project; otherwise, you cannot mount the dataset on the notebook.
:::

Specify your dataset name and the path to the directory containing the images in the code block below, and run the cell.

```python
dataset_name = ""  # enter the dataset name
images_dir = ""  # enter the path to the directory containing the training images

images_path = os.path.join("/datasets/", str(dataset_name), str(images_dir))  # "/datasets/{dataset_name}/{images_dir}"
if not dataset_name:
    print("dataset_name must be provided")
elif not os.path.exists(images_path):
    print("The images_path specified does not exist. Check the path and try again.")
else:
    save_path = images_path
```
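The `ls /datasets/` check can also be run from a notebook cell. Below is a minimal standard-library sketch; `list_mounted_datasets` is a hypothetical helper (not part of TIR), and it assumes each mounted dataset shows up as a directory under `/datasets`, as the `ls /datasets/` check suggests.

```python
import os

def list_mounted_datasets(root="/datasets"):
    """Return the sorted names of directories under `root`, or [] if `root` does not exist."""
    if not os.path.isdir(root):
        return []
    return sorted(
        name for name in os.listdir(root) if os.path.isdir(os.path.join(root, name))
    )

print(list_mounted_datasets())
```

If the dataset you expect is missing from the printed list, mount it from the **Associated Datasets Tab** as described above.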
**Images present on your local system**

You can load your own training images by uploading them to the notebook using the **Upload Files** option.

![Upload Files Option](images/UploadFilesOption.png)

```python
# `images_path` is the path to the directory containing the training images.
images_path = ""  # type: "string"

if not os.path.exists(str(images_path)):
    print("The images_path specified does not exist. Check the path and try again.")
else:
    save_path = images_path
```
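Whichever source you used, every entry in `save_path` will be read as a training image. The hypothetical helper below is a minimal sketch that flags entries whose names do not end in a common image extension, such as stray text files or hidden folders; the extension list is an assumption, so adjust it to your data.

```python
import os

# Extensions treated as images (an assumption; extend it if your images use other formats).
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}

def find_non_image_files(directory):
    """Return the sorted names of entries in `directory` that do not look like image files."""
    return sorted(
        name
        for name in os.listdir(directory)
        if os.path.splitext(name)[1].lower() not in IMAGE_EXTENSIONS
    )
```

In the notebook, run `find_non_image_files(save_path)` once `save_path` is set; an empty list means every entry has an image extension. The helper checks names only, so opening each file, as the image check below does, is still worthwhile.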
:::info Note
Make sure that your input image directory contains only the input images. The code reads every file in the provided directory, so any other files will interfere with the training process.
:::

Before proceeding further, let us check our training images.

```python
images = []
for file_path in os.listdir(save_path):
    image_path = os.path.join(save_path, file_path)
    try:
        images.append(Image.open(image_path).resize((512, 512)))
    except Exception:
        print(f"{image_path} is not a valid image. Remove this file from the directory, otherwise the training could fail.")

image_grid(images, 1, len(images))
```

Output:

![Cat Toy Input Images](images/CatToyInputImages.png)

### 3.3. Settings for the New Concept

**DreamBooth:**

```python
# `instance_prompt` should contain a good description of what your object or style is,
# together with a unique identifier for the concept (here, `<cat-toy>`)
instance_prompt = "<cat-toy> toy"  # type: "string"

# Set `prior_preservation` to True to preserve the class of the concept (e.g., toy, dog, painting).
# This improves quality and helps with generalization, at the cost of training time.
prior_preservation = False  # type: bool
prior_preservation_class_prompt = "a photo of a cat clay toy"  # type: "string"

num_class_images = 12
sample_batch_size = 2
prior_loss_weight = 0.5
prior_preservation_class_folder = "./class_images"
class_data_root = prior_preservation_class_folder
class_prompt = prior_preservation_class_prompt
```

**Advanced settings for prior preservation (optional)**

```python
num_class_images = 12  # type: int
sample_batch_size = 2

# `prior_loss_weight` determines how strongly the prior preservation loss is weighted
prior_loss_weight = 1  # type: int

# If `prior_preservation_class_folder` is empty, images for the class will be generated with the class prompt.
# Otherwise, fill this folder with images of items of the same class as your concept (but not images of the concept itself).
prior_preservation_class_folder = "./class_images"  # type: string
class_data_root = prior_preservation_class_folder
```

**Textual Inversion:**

```python
# what_to_teach: what are you teaching the model?
#   object: teach the model a new object
#   style: teach the model a new style
what_to_teach = "object"  # allowed values: ["object", "style"]

# placeholder_token: the token that represents your new concept (so when you prompt the model,
# you will say "A <cat-toy> in an amusement park"). Angle brackets distinguish the token from
# other words/tokens and avoid collisions.
placeholder_token = "<cat-toy>"  # type: string

# initializer_token: a word that summarizes your new concept, used as a starting point
initializer_token = "toy"  # type: string
```

## Step 4: Teach the Model the New Concept (Fine-Tuning with the Training Method)

Run the following cells in sequence to train the model. The whole process may take anywhere from 30 minutes to 3 hours.

The cells below also show how the process works under the hood, and you can use them to change advanced training settings or hyperparameters.

### 4.1. Setup for Training

**DreamBooth:**

Set up the dataset classes:

```python
from pathlib import Path
from torchvision import transforms

class DreamBoothDataset(Dataset):
    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(Path(class_data_root).iterdir())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example


class PromptDataset(Dataset):
    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example
```

Generate the class images (used only when prior preservation is enabled):

```python
import gc

if prior_preservation:
    class_images_dir = Path(class_data_root)
    if not class_images_dir.exists():
        class_images_dir.mkdir(parents=True)
    cur_class_images = len(list(class_images_dir.iterdir()))

    if cur_class_images < num_class_images:
        pipeline = StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path, revision="fp16", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.enable_attention_slicing()
        pipeline.set_progress_bar_config(disable=True)

        num_new_images = num_class_images - cur_class_images
        print(f"Number of class images to sample: {num_new_images}.")

        sample_dataset = PromptDataset(class_prompt, num_new_images)
        sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=sample_batch_size)

        for example in tqdm(sample_dataloader, desc="Generating class images"):
            images = pipeline(example["prompt"]).images

            for i, image in enumerate(images):
                image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")

        # Free the GPU memory used by the sampling pipeline before training
        del pipeline
        gc.collect()
        with torch.no_grad():
            torch.cuda.empty_cache()
```

Load the Stable Diffusion model:

```python
# Load the models that make up the Stable Diffusion pipeline
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder"
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae"
)
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="unet"
)
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="tokenizer",
)
```

**Textual Inversion:**

Create the dataset. First, set up the prompt templates used for training:

```python
# Set up the prompt templates for training
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]
```

Set up the dataset:

```python
# Set up the dataset
class TextualInversionDataset(Dataset):
    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        self.interpolation = {
            # "linear" is not available in Pillow >= 10.0.0; use "bilinear" instead
            "bilinear": PIL.Image.Resampling.BILINEAR,
            "bicubic": PIL.Image.Resampling.BICUBIC,
            "lanczos": PIL.Image.Resampling.LANCZOS,
        }[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)

        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example
```

Set up the model:

```python
# Load the tokenizer; the placeholder token is added to it as a new token below.
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="tokenizer",
)

# Add the placeholder token to the tokenizer
num_added_tokens = tokenizer.add_tokens(placeholder_token)
if num_added_tokens == 0:
    raise ValueError(
        f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
        " `placeholder_token` that is not already in the tokenizer."
    )
```

```python
# Get the token ids for our placeholder and initializer tokens.
# This code block raises an error if the initializer string is not a single token.

# Convert the initializer_token to ids
token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
# Check whether initializer_token is a single token or a sequence of tokens
if len(token_ids) > 1:
    raise ValueError("The initializer token must be a single token.")

initializer_token_id = token_ids[0]
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
```

```python
# Load the models that make up the Stable Diffusion pipeline
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder"
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae"
)
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet"
)
```

Because we added `placeholder_token` to the `tokenizer`, we need to resize the token embeddings. This adds a new embedding vector for our `placeholder_token`:

```python
text_encoder.resize_token_embeddings(len(tokenizer))
```

Initialize the newly added placeholder token with the embedding of the initializer token:

```python
token_embeds = text_encoder.get_input_embeddings().weight.data
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
```

In Textual Inversion, we train only the newly added embedding vector, so let's freeze the rest of the model parameters:

```python
def freeze_params(params):
    for param in params:
        param.requires_grad = False

# Freeze vae and unet
freeze_params(vae.parameters())
freeze_params(unet.parameters())
# Freeze all parameters except for the token embeddings in the text encoder
params_to_freeze = itertools.chain(
    text_encoder.text_model.encoder.parameters(),
    text_encoder.text_model.final_layer_norm.parameters(),
    text_encoder.text_model.embeddings.position_embedding.parameters(),
)
freeze_params(params_to_freeze)
```

Create the training data: the dataset, the dataloader, and the noise scheduler.

```python
# Create the dataset for training
train_dataset = TextualInversionDataset(
    data_root=save_path,
    tokenizer=tokenizer,
    size=vae.config.sample_size,
    placeholder_token=placeholder_token,
    repeats=100,
    learnable_property=what_to_teach,  # option selected above: "object" or "style"
    center_crop=False,
    set="train",
)

# Create the dataloader for training
def create_dataloader(train_batch_size=1):
    return torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)

# Create the noise scheduler for training
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
```

### 4.2. Training

Set up the training arguments and define the hyperparameters for training. You can tune hyperparameters such as `learning_rate` and `max_train_steps` to experiment.

**DreamBooth:**

```python
from argparse import Namespace

args = Namespace(
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    resolution=vae.config.sample_size,
    center_crop=True,
    train_text_encoder=False,
    instance_data_dir=save_path,
    instance_prompt=instance_prompt,
    learning_rate=5e-06,
    max_train_steps=300,
    save_steps=50,
    train_batch_size=2,  # set to 1 if using prior preservation
    gradient_accumulation_steps=2,
    max_grad_norm=1.0,
    mixed_precision="fp16",  # "fp16" enables mixed-precision training
    gradient_checkpointing=True,  # lowers memory usage at the cost of extra compute
    use_8bit_adam=True,  # use the 8-bit optimizer from bitsandbytes
    seed=3434554,
    with_prior_preservation=prior_preservation,
    prior_loss_weight=prior_loss_weight,
    sample_batch_size=2,
    class_data_dir=prior_preservation_class_folder,
    class_prompt=prior_preservation_class_prompt,
    num_class_images=num_class_images,
    lr_scheduler="constant",
    lr_warmup_steps=100,
    output_dir="dreambooth-concept",
)
```

**Textual Inversion:**

```python
hyperparameters = {
    "learning_rate": 5e-04,
    "scale_lr": True,
    "max_train_steps": 2000,
    "save_steps": 250,
    "train_batch_size": 4,
    "gradient_accumulation_steps": 1,
    "gradient_checkpointing": True,
    "mixed_precision": "fp16",
    "seed": 42,
    "output_dir": "textual-inversion-concept",
}
!mkdir -p textual-inversion-concept
```

Define the training function:

**DreamBooth:**

```python
from accelerate.utils import set_seed

def training_function(text_encoder, vae, unet):
    logger = get_logger(__name__)

    set_seed(args.seed)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate.
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
    if args.use_8bit_adam:
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
    )
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
    )

    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # Concatenate class and instance examples for prior preservation
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            return_tensors="pt",
            max_length=tokenizer.model_max_length,
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }
        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
    )

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encoder and vae to the GPU.
    # For mixed precision training, we cast the text_encoder and vae weights to half precision;
    # these models are only used for inference, so keeping their weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    vae.decoder.to("cpu")  # only the VAE encoder is needed during training
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # Recalculate the total training steps, as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0

    for epoch in range(num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * 0.18215  # Stable Diffusion's VAE latent scaling factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Get the target for the loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                if args.with_prior_preservation:
                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                    target, target_prior = torch.chunk(target, 2, dim=0)

                    # Compute instance loss
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

                    # Compute prior loss
                    prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = (
                        itertools.chain(unet.parameters(), text_encoder.parameters())
                        if args.train_text_encoder
                        else unet.parameters()
                    )
                    # Clip the gradients of every module being trained, not only the UNet
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                # Advance the learning-rate schedule created above (warmup/decay)
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.save_steps == 0:
                    if accelerator.is_main_process:
                        pipeline = StableDiffusionPipeline.from_pretrained(
                            args.pretrained_model_name_or_path,
                            unet=accelerator.unwrap_model(unet),
                            text_encoder=accelerator.unwrap_model(text_encoder),
                        )
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        pipeline.save_pretrained(save_path)

            logs = {"loss": loss.detach().item()}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=accelerator.unwrap_model(unet),
            text_encoder=accelerator.unwrap_model(text_encoder),
        )
        pipeline.save_pretrained(args.output_dir)
```

**Textual Inversion:**

```python
logger = get_logger(__name__)


def save_progress(text_encoder, placeholder_token_id, accelerator, save_path):
    """Save only the learned embedding of the placeholder token."""
    logger.info("Saving embeddings")
    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
    # `placeholder_token` refers to the placeholder token defined earlier in the notebook
    learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
    torch.save(learned_embeds_dict, save_path)


def training_function(text_encoder, vae, unet):
    train_batch_size = hyperparameters["train_batch_size"]
    gradient_accumulation_steps = hyperparameters["gradient_accumulation_steps"]
    learning_rate = hyperparameters["learning_rate"]
    max_train_steps = hyperparameters["max_train_steps"]
    output_dir = hyperparameters["output_dir"]
    gradient_checkpointing = hyperparameters["gradient_checkpointing"]

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=hyperparameters["mixed_precision"],
    )

    if gradient_checkpointing:
        text_encoder.gradient_checkpointing_enable()
        unet.enable_gradient_checkpointing()

    train_dataloader = create_dataloader(train_batch_size)

    if hyperparameters["scale_lr"]:
        learning_rate = (
            learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
        lr=learning_rate,
    )

    text_encoder, optimizer, train_dataloader = accelerator.prepare(
        text_encoder, optimizer, train_dataloader
    )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae and unet to device
    vae.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=weight_dtype)

    # Keep vae in eval mode as we don't train it
    vae.eval()
    # Keep unet in train mode to enable gradient checkpointing
    unet.train()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0

    for epoch in range(num_train_epochs):
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(text_encoder):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
                ).long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states.to(weight_dtype)).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                loss = F.mse_loss(noise_pred, target, reduction="none").mean([1, 2, 3]).mean()
                accelerator.backward(loss)

                # Zero out the gradients for all token embeddings except the newly added
                # embeddings for the concept, as we only want to optimize the concept embeddings
                if accelerator.num_processes > 1:
                    grads = text_encoder.module.get_input_embeddings().weight.grad
                else:
                    grads = text_encoder.get_input_embeddings().weight.grad
                # Get the index for tokens that we want to zero the grads for
                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)

                optimizer.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if global_step % hyperparameters["save_steps"] == 0:
                    save_path = os.path.join(output_dir, f"learned_embeds-step-{global_step}.bin")
                    save_progress(text_encoder, placeholder_token_id, accelerator, save_path)

            logs = {"loss": loss.detach().item()}
            progress_bar.set_postfix(**logs)

            if global_step >= max_train_steps:
                break

    accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path,
            text_encoder=accelerator.unwrap_model(text_encoder),
            tokenizer=tokenizer,
            vae=vae,
            unet=unet,
        )
        pipeline.save_pretrained(output_dir)
        # Also save the newly trained embeddings
        save_path = os.path.join(output_dir, "learned_embeds.bin")
        save_progress(text_encoder, placeholder_token_id, accelerator, save_path)
```

#### 4.3. Run the Training

```python
import itertools

import accelerate
import torch

num_gpus = 1  # number of GPUs to use; this depends on your machine configuration
accelerate.notebook_launcher(training_function, args=(text_encoder, vae, unet), num_processes=num_gpus)

# Free some GPU memory once training is done
for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
    if param.grad is not None:
        del param.grad
torch.cuda.empty_cache()
```

### Step 5: Run Inference with the Newly Trained Model

Bravo! Our model training is complete, and we have taught our model a new concept, `<cat-toy>`. Let us now test the newly trained model by running inference with it and looking at the results.

#### 5.1. Set Up the Pipeline

The newly trained model artifacts are available in the output directory specified while setting up the training arguments.
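Besides the final pipeline, the training functions above also write intermediate snapshots into the same output directory every `save_steps` optimization steps: `checkpoint-<step>` folders for DreamBooth and `learned_embeds-step-<step>.bin` files for Textual Inversion. If you want to try an earlier snapshot, a small helper can find the most recent one. This is a minimal sketch: `latest_snapshot` is a hypothetical helper (not part of Diffusers), and it relies only on the file-name patterns used in the training code.

```python
import re
from pathlib import Path
from typing import Optional

# File-name patterns written by the training functions in this tutorial (assumption: unchanged names)
SNAPSHOT_PATTERNS = {
    "dreambooth": re.compile(r"^checkpoint-(\d+)$"),
    "textual_inversion": re.compile(r"^learned_embeds-step-(\d+)\.bin$"),
}


def latest_snapshot(output_dir: str, method: str) -> Optional[Path]:
    """Return the snapshot with the highest step number, or None if there is none."""
    pattern = SNAPSHOT_PATTERNS[method]
    best_step, best_path = -1, None
    for entry in Path(output_dir).iterdir():
        match = pattern.match(entry.name)
        # Compare step numbers as integers, not as strings
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), entry
    return best_path


# Example (hypothetical usage): print(latest_snapshot(args.output_dir, "dreambooth"))
```

Comparing step numbers as integers matters here: a plain string sort would rank `checkpoint-500` after `checkpoint-1000`.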
We will use the artifacts in this output directory to load the model and set up the pipeline.

**DreamBooth:**

```python
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

output_dir = args.output_dir
pipe = StableDiffusionPipeline.from_pretrained(
    output_dir,
    scheduler=DPMSolverMultistepScheduler.from_pretrained(output_dir, subfolder="scheduler"),
    torch_dtype=torch.float16,
).to("cuda")
```

**Textual Inversion:**

```python
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

output_dir = hyperparameters["output_dir"]
pipe = StableDiffusionPipeline.from_pretrained(
    output_dir,
    scheduler=DPMSolverMultistepScheduler.from_pretrained(output_dir, subfolder="scheduler"),
    torch_dtype=torch.float16,
).to("cuda")
```

#### 5.2. Run the Stable Diffusion Pipeline

```python
# Don't forget to use the placeholder token in your prompt
prompt = "a <cat-toy> in mad max fury road"
num_samples = 2  # number of images to generate for the prompt

images = pipe(
    prompt,
    num_images_per_prompt=num_samples,
    num_inference_steps=30,
    guidance_scale=9,
).images

# Experiment with the newly trained model by changing `prompt`, `num_inference_steps`
# and `guidance_scale`, and compare the results.
grid = image_grid(images, rows=1, cols=num_samples)
grid
```

Output:

![Cat Toy Input Images](images/CatToyInputImages.png)

### Step 6: Save the Newly Created Concept

Our model training is complete, and we have tested the model by generating some sample images. We will now save the model artifacts to preserve our newly created concept. This also lets us launch an Inference Server for it, as described in Step 7.

#### 6.1. Save the Model Artifacts

**DreamBooth:**

```python
import os

model_artifacts_dir = "dreambooth_model"  # directory in which to save the model artifacts
os.makedirs(model_artifacts_dir, exist_ok=True)
pipe.save_pretrained(model_artifacts_dir)  # save the model artifacts
```

**Textual Inversion:**

```python
import os

model_artifacts_dir = "textual_inversion_model"  # directory in which to save the model artifacts
os.makedirs(model_artifacts_dir, exist_ok=True)
pipe.save_pretrained(model_artifacts_dir)  # save the model artifacts

# In Textual Inversion, the learned concept is stored in the `learned_embeds.bin` file,
# so this file is required every time we need to load this concept.
!cp $output_dir/learned_embeds.bin $model_artifacts_dir/
```

#### 6.2. Create a Model on TIR

* Go to the [TIR AI Platform](https://tir.e2enetworks.com).
* Choose **a project**, navigate to the **Models** section, and click on **Create Model**.
* Enter a model name of your choosing (e.g., stable-diffusion).
* Select **Custom** as the Model Type and click on **CREATE**.
* You will now see the details of the EOS (E2E Object Storage) bucket created for this model.
* EOS provides an S3-compatible API to upload or download content. We will use the MinIO CLI in this tutorial.
* Copy the **Setup Host** command from the **Setup MinIO CLI** tab. We will use it to set up the MinIO CLI in the next step.

:::info Note
If you forget to copy the **Setup Host** command for the MinIO CLI, don't worry. You can always go back to the **model details** and get it again.
:::

#### 6.3. Upload the Artifacts to the Model Bucket

We have already created the model on TIR. Now let's set up the MinIO CLI in our notebook and upload the artifacts to the model bucket.
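Before switching to the terminal, you can optionally run a quick check in the notebook to confirm that the artifacts directory holds a complete pipeline. `save_pretrained` writes a `model_index.json` file plus one subfolder per pipeline component. For Textual Inversion, we also copied `learned_embeds.bin` into the same directory. The sketch below uses a hypothetical helper, `missing_artifacts`, which is not a TIR or Diffusers function. The list of expected entries is an assumption based on the components a Stable Diffusion v2.1 pipeline typically contains, so adjust it to whatever your pipeline actually saved.

```python
from pathlib import Path
from typing import List

# Assumed minimal contents of a saved Stable Diffusion pipeline directory
REQUIRED_ENTRIES = ["model_index.json", "unet", "vae", "text_encoder", "tokenizer", "scheduler"]


def missing_artifacts(model_artifacts_dir: str, textual_inversion: bool = False) -> List[str]:
    """Return the expected files/folders that are not present in the artifacts directory."""
    root = Path(model_artifacts_dir)
    expected = list(REQUIRED_ENTRIES)
    if textual_inversion:
        # For Textual Inversion, the learned concept lives in this file
        expected.append("learned_embeds.bin")
    return [name for name in expected if not (root / name).exists()]


# Example (hypothetical usage):
# print(missing_artifacts("textual_inversion_model", textual_inversion=True))
```

An empty list means every expected entry is present and the directory is ready to upload.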
Go back to the TIR Notebook named `stable-diffusion-fine-tune`, which we used to fine-tune the model in the previous steps, and follow the steps below:

* In Jupyter Labs, click **New Launcher** and select **Terminal**.
* Set up the MinIO CLI by running the **Setup Host** command copied in *Step 6.2*.
* Use the command below to go to the `model_artifacts_dir` directory, where we saved the model artifacts in Step 6.1.

**DreamBooth:**

```bash
export model_artifacts_dir=dreambooth_model && \
cd $model_artifacts_dir
```

**Textual Inversion:**

```bash
export model_artifacts_dir=textual_inversion_model && \
cd $model_artifacts_dir
```

* Copy the model artifacts to the model bucket.
    * Go to TIR Dashboard >> Models >> Select your model >> Copy the `cp` command from the **Setup MinIO CLI** tab.
    * The copy command will look like this: `mc cp -r <MODEL_NAME> stable-diffusion/stable-diffusion-854588`
    * Replace `<MODEL_NAME>` with `*` to upload all contents of the current folder to the bucket.
    * Append `$model_artifacts_dir/` to the bucket path to upload the contents to a folder named `$model_artifacts_dir` inside the bucket. Your copy command should then look something like this:

```bash
mc cp -r * stable-diffusion/stable-diffusion-854588/$model_artifacts_dir/
```

:::info Note
The command above is only a sample that shows what the command looks like. Do not run it as is.
:::

Follow the steps above to build the copy command that is **specific to your model bucket**. Once the upload completes, your model artifacts are saved in the bucket.

### Step 7: Create an Inference Server for the Newly Trained Model

All that remains is to create an Inference Server for our trained model and use it to serve API requests.

:::info Note
If you are not familiar with creating inference endpoints on TIR, refer to the TIR tutorial on Model Endpoint (Inference) creation for Stable Diffusion. It gives a detailed, step-by-step guide.
:::

#### 7.1. Create a New Model Endpoint in TIR

* Create a new Model Endpoint in TIR with the Stable Diffusion framework and a GPU machine plan.
* **[Important!]** In the Model Details subsection, choose the model that we created in *Step 6.2*. Make sure to specify the `$model_artifacts_dir` path (`dreambooth_model` or `textual_inversion_model`). This is necessary because our trained model artifacts are stored in this directory.

#### 7.2. Generate Your API Token

The model endpoint API requires a valid **auth token**, which you'll need in the following steps. Generate a new API Token from the *API Tokens* section, or use an existing token if you have already created one. This generates an API key and an auth token. Copy the *auth token*, because you will need it in the next step.

#### 7.3. Make an API Request to Generate Image Output

The final step is to send API requests to the model endpoint and generate images from text prompts. We will use a TIR Notebook for this.

* Once your model is *Ready*, visit the **Sample API Request** section of that model and copy the Python code.
* Launch a TIR Notebook with the *PyTorch* or *StableDiffusion* image and any basic machine plan. Once it is in the *Running* state, launch it and start a new notebook `untitled.ipynb` in Jupyter Labs.
* Paste the **Sample API Request** code (for Python) into a notebook cell.
* Copy the auth token generated in *Step 7.2* and use it in place of `$AUTH_TOKEN` in the Sample API Request.
* Replace the **prompt** string in the payload with the prompt below. This puts the placeholder token (in our case, `<cat-toy>`) in the prompt, which is needed to get the desired output.

```python
"a <cat-toy> in mad max fury road"
```

* Execute the code to send the request. The output is a list of tensors, because the Stable Diffusion v2.1 model endpoint returns the generated images as a list of PyTorch tensors.
* To view the generated images, copy the code below, paste it into a notebook cell, and execute it.
```python
import torch
import torchvision.transforms as transforms
from IPython.display import display


def display_images(tensor_image_data_list):
    """Convert the PyTorch tensors returned by the endpoint to PIL images and display them."""
    for tensor_data in tensor_image_data_list:
        tensor_image = torch.tensor(tensor_data.get("data"))  # initialize the tensor
        pil_img = transforms.ToPILImage()(tensor_image)  # convert to PIL Image
        display(pil_img)  # render inline in the notebook
        # To save the generated images, uncomment the line below
        # pil_img.save(tensor_data.get("name"))


# `response` is the object returned by the Sample API Request code
if response.status_code == 200:
    display_images(response.json().get("predictions"))
```

Output:

![Cat Toy Inference Output](images/CatToyInference2.png)

#### 7.4. Settings for the New Concept

That's it! We have successfully taught our model a new concept, `<cat-toy>`.

The Stable Diffusion model supports various other parameters for controlling the generated images. Refer to the *Supported parameters for image generation* documentation for more details.

## Conclusion

In this tutorial, we fine-tuned the Stable Diffusion v2.1 model using two training methods: DreamBooth and Textual Inversion. With just 3-5 input images, we taught the model a new concept and personalized it on our own images. After fine-tuning, we saw how to store the trained model artifacts in TIR Model Storage and how to launch Inference Servers with them to serve API requests.

---