Fine-tune Stable Diffusion Model on TIR
In this tutorial, we will fine-tune Stability AI's Stable Diffusion (v2.1) model using the DreamBooth and Textual Inversion training methods with the Hugging Face Diffusers library. Using just 3-5 images, we will teach new concepts to Stable Diffusion and personalize the model on our own images.
About Training Methods
The Stable Diffusion model can be fine-tuned using either of the two training methods: Dreambooth or Textual Inversion. Before training the model, let us get a brief overview of the two methods.
Dreambooth
DreamBooth is a way to fine-tune Stable Diffusion to generate a subject in different environments and styles. It teaches the diffusion model about a specific object or style using approximately three to five example images. After fine-tuning, the model can produce images containing that object in new settings. You can read the original paper here. Here's how it works:
- You provide a set of reference images to teach the model how the subject looks.
- You re-train the model to learn to associate that subject with a specific word.
- You prompt the new model using the special word.
With DreamBooth, we are training an entirely new version of Stable Diffusion: a self-sufficient checkpoint of the full model that can yield impressive results for the new subject.
Textual Inversion
The Textual Inversion training method captures new concepts from a small number of example images and associates the concepts with new words in the textual embedding space of the pre-trained model. Using only 3-5 images of a user-provided concept, like an object or style, we learn to represent it through new “words” in the textual embedding space. These “words” can be used in natural language sentences as text prompts, guiding personalized image creation intuitively. You can read the original paper here. Here’s how it works:
- You provide a set of reference images of your subject.
- You figure out what specific embedding the model associates with those images.
- You prompt the new model using the embedding.
Textual Inversion is akin to finding a “special word” that prompts your model to produce the images you want.
Fine-Tuning the Model
Step 1: Launch a Notebook on TIR
Before beginning, let us launch a Notebook on TIR to train our model. To launch the notebook, follow the steps below:
- Go to the TIR AI Platform.
- Choose a Project, navigate to the Notebooks section, and click on the Create Notebook button.
- Name the notebook stable-diffusion-fine-tune (you can choose any name).
- Notebook Image & Machine: Launch a new Notebook with the Diffusers Image and a GPU Machine plan. For fine-tuning, we recommend choosing a GPU plan from one of the A100, A40, or A30 series (e.g., GDC3.A10080-16.115GB).
- Disk Size: A default Disk Size of 30GB is sufficient for our use case. Adjust the disk size if you need more space.
- Datasets: If your training data is stored in a TIR Dataset, select that dataset for use during fine-tuning.
- Create the notebook.
- Launch Notebook: Once the notebook is in the Running state, launch it by clicking on the three dots (...) and start the Jupyter Labs environment. (The sidebar on the left also displays quick launch links for your notebooks.)
- In Jupyter Labs, create a new Python3 Notebook from the New Launcher window. A new .ipynb file will open.
Our notebook is ready. Now, let's proceed with the model training.
This tutorial covers both DreamBooth and Textual Inversion training methods. Most steps and code are the same for both methods, but when they differ, we have used tabs to separate the two training methods.
Step 2: Initial Setup
Install the Required Libraries
!pip install -qq ftfy
Import the required libraries and packages
- Dreambooth
- Textual Inversion
import argparse
import itertools
import math
import os
from contextlib import nullcontext
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import bitsandbytes as bnb
import argparse
import itertools
import math
import os
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
Now, let's define a helper function that we will need at various points during training and inference.
def image_grid(imgs, rows, cols):
'''display multiple images in a grid'''
assert len(imgs) == rows*cols
w, h = imgs[0].size
grid = Image.new('RGB', size=(cols*w, rows*h))
grid_w, grid_h = grid.size
for i, img in enumerate(imgs):
grid.paste(img, box=(i%cols*w, i//cols*h))
return grid
Step 3: Settings for Teaching the New Concept
3.1. Defining the Model
Specify the Stable Diffusion model that we want to train. For this tutorial, we will use Stable Diffusion v2.1 (stabilityai/stable-diffusion-2-1).
# `pretrained_model_name_or_path` defines the Stable Diffusion checkpoint you want to use.
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1" # using Stable Diffusion v2.1
# some other models that you can use: ["stabilityai/stable-diffusion-2-1", "stabilityai/stable-diffusion-2", "stabilityai/stable-diffusion-2-base", "CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5"]
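Optionally, before fine-tuning, you can run a quick sanity check on the base checkpoint and keep a baseline image to compare against later. This is a minimal sketch (the prompt is just an example); it loads the pipeline in half precision on the GPU and can be skipped to save time and memory.
# Optional sanity check: generate one baseline image with the untouched base model
pipe_check = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path, torch_dtype=torch.float16
).to("cuda")
baseline = pipe_check("a photo of a cat toy on a table", num_inference_steps=25).images[0]
baseline.save("baseline_sample.png")  # keep it around to compare with post-training outputs
# free the GPU memory before moving on
del pipe_check
torch.cuda.empty_cache()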
3.2. Input Images
To teach the model a new concept, we need 3-5 input images with which we will train the model. So let's gather and set up the images for training purposes.
There can be multiple sources for the input images. Set up the input images using any one of the three sources:
Images available on the Internet
Using public image URLs, we can download the images from the internet and save them locally on the notebook.
# Add the URLs to the images of the concept you are adding. 3-5 should be fine
urls = [
"https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg",
"https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg",
"https://huggingface.co/datasets/valhalla/images/resolve/main/5.jpeg",
"https://huggingface.co/datasets/valhalla/images/resolve/main/6.jpeg",
## You can add additional image urls here
]
# Download the images from urls
import requests
import glob
from io import BytesIO
def download_image(url):
try:
response = requests.get(url)
except:
return None
return Image.open(BytesIO(response.content)).convert("RGB")
images = list(filter(None, [download_image(url) for url in urls]))
# save the images in a directory
save_path = "./my_concept"
if not os.path.exists(save_path):
os.mkdir(save_path)
[image.save(f"{save_path}/{i}.jpeg") for i, image in enumerate(images)]
Images stored in a TIR Dataset
To use the images stored in a TIR Dataset, ensure that the particular dataset containing your images is mounted on the notebook.
How to check if your Dataset is mounted on the notebook? Go to TIR Dashboard >> Notebooks Section >> Select the notebook >> Go to the Associated Datasets tab. You will see the list of Datasets mounted on your notebook. If your desired dataset is not in the list, you can mount it now.
You can also open the terminal and list all the mounted datasets using the ls /datasets/ command.
If you have updated the mounted dataset, wait for your notebook to be back in the Running state, and then start afresh from Step 2.
Note: The Dataset and the Notebook must be in the same project; otherwise, you cannot mount the dataset on the notebook.
Specify your dataset name and the path of the directory containing the images in the code block below, then run the code cell.
dataset_name = "" # enter the dataset name
images_dir = "" # enter path to directory containing the training images
images_path = os.path.join("/datasets/", str(dataset_name), str(images_dir)) # "/datasets/{dataset_name}/{images_path}"
if not dataset_name:
print("dataset_name must be provided")
elif not os.path.exists(str(images_path)):
print('The images_path specified does not exist. Check the path and try again')
else:
save_path = images_path
Images present on your local system
You can load your own training images by uploading them to the notebook using the Upload Files option.
# `images_path` is a path to directory containing the training images.
images_path = "" # type: "string"
if not os.path.exists(str(images_path)):
print('The images_path specified does not exist. Check the path and try again')
else:
save_path = images_path
Make sure that your input image directory contains only the input images: the code reads every file in the provided directory, and extraneous files will interfere with the training process.
Before proceeding further, let us first check our training images.
images = []
for file_path in os.listdir(save_path):
try:
image_path = os.path.join(save_path, file_path)
images.append(Image.open(image_path).resize((512, 512)))
except:
print(f"{image_path} is not a valid image, please make sure to remove this file from the directory otherwise the training could fail.")
image_grid(images, 1, len(images))
Output:
3.3. Settings for the new concept
- Dreambooth
- Textual Inversion
# `instance_prompt` is a prompt that should contain a good description of what your object or style is, together with the initializer word `cat_toy`
instance_prompt = "<cat-toy> toy" # type: "string"
# Enable the `prior_preservation` option if you would like the class of the concept (e.g., toy, dog, painting) to be preserved. This increases quality and helps with generalization, at the cost of longer training time
prior_preservation = False # type: bool
prior_preservation_class_prompt = "a photo of a cat clay toy" # type: "string"
num_class_images = 12
sample_batch_size = 2
prior_loss_weight = 0.5
prior_preservation_class_folder = "./class_images"
class_data_root = prior_preservation_class_folder
class_prompt = prior_preservation_class_prompt
Advanced settings for prior preservation (optional)
num_class_images = 12 # type: int
sample_batch_size = 2
# `prior_loss_weight` determines how strongly the prior-preservation class images are weighted in the loss
prior_loss_weight = 1 # type: int
# If the `prior_preservation_class_folder` is empty, images for the class will be generated with the class prompt. Otherwise, fill this folder with images of items of the same class as your concept (but not images of the concept itself)
prior_preservation_class_folder = "./class_images" # type: string
class_data_root = prior_preservation_class_folder
# what_to_teach: what is it that you are teaching?
# object: enables you to teach the model a new object to be used
# style: allows you to teach the model a new style one can use.
what_to_teach = "object" # allowed values: ["object", "style"]
# placeholder_token: the token you are going to use to represent your new concept (so when you prompt the model, you will say "A <my-placeholder-token> in an amusement park"). We use angle brackets to differentiate a token from other words/tokens, to avoid collision.
placeholder_token = "<cat-toy>" # type: string
# initializer_token: a word that can summarise what your new concept is, to be used as a starting point
initializer_token = "toy" # type: string
Step 4: Teach the Model the New Concept (Fine-Tuning with the Training Method)
Execute the sequence of cells below to run the training process. The whole process may take anywhere from 30 minutes to 3 hours, depending on the chosen hyperparameters and GPU plan. Open the cells below if you are interested in how the process works under the hood or if you want to change the advanced training settings and hyperparameters.
4.1. Setup for Training
- Dreambooth
- Textual Inversion
Setup the classes
from pathlib import Path
from torchvision import transforms
class DreamBoothDataset(Dataset):
def __init__(
self,
instance_data_root,
instance_prompt,
tokenizer,
class_data_root=None,
class_prompt=None,
size=512,
center_crop=False,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")
self.instance_images_path = list(Path(instance_data_root).iterdir())
self.num_instance_images = len(self.instance_images_path)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(Path(class_data_root).iterdir())
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt
else:
self.class_data_root = None
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length
def __getitem__(self, index):
example = {}
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)
example["instance_prompt_ids"] = self.tokenizer(
self.instance_prompt,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
example["class_prompt_ids"] = self.tokenizer(
self.class_prompt,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids
return example
class PromptDataset(Dataset):
def __init__(self, prompt, num_samples):
self.prompt = prompt
self.num_samples = num_samples
def __len__(self):
return self.num_samples
def __getitem__(self, index):
example = {}
example["prompt"] = self.prompt
example["index"] = index
return example
Generate Class Images
import gc
if(prior_preservation):
class_images_dir = Path(class_data_root)
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True)
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < num_class_images:
pipeline = StableDiffusionPipeline.from_pretrained(
pretrained_model_name_or_path, revision="fp16", torch_dtype=torch.float16
).to("cuda")
pipeline.enable_attention_slicing()
pipeline.set_progress_bar_config(disable=True)
num_new_images = num_class_images - cur_class_images
print(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=sample_batch_size)
for example in tqdm(sample_dataloader, desc="Generating class images"):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
pipeline = None
gc.collect()
del pipeline
with torch.no_grad():
torch.cuda.empty_cache()
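If prior preservation is enabled, you can optionally preview a few of the generated class images before training. A small sketch using the image_grid helper defined earlier:
# Optional: preview a few class images (only meaningful when prior_preservation is True)
if prior_preservation:
    class_image_files = sorted(Path(class_data_root).iterdir())[:4]
    if class_image_files:
        preview = [Image.open(p).resize((256, 256)) for p in class_image_files]
        display(image_grid(preview, 1, len(preview)))  # `display` is available in Jupyter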
Load the Stable Diffusion Model
# Load models and create wrapper for stable diffusion
text_encoder = CLIPTextModel.from_pretrained(
pretrained_model_name_or_path, subfolder="text_encoder"
)
vae = AutoencoderKL.from_pretrained(
pretrained_model_name_or_path, subfolder="vae"
)
unet = UNet2DConditionModel.from_pretrained(
pretrained_model_name_or_path, subfolder="unet"
)
tokenizer = CLIPTokenizer.from_pretrained(
pretrained_model_name_or_path,
subfolder="tokenizer",
)
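To get a sense of what DreamBooth is actually updating, you can optionally print the parameter count of each component; the UNet (and optionally the text encoder) is what gets fine-tuned, while the VAE stays frozen.
# Optional: inspect the size of each component
for name, model in [("text_encoder", text_encoder), ("vae", vae), ("unet", unet)]:
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{name}: {n_params / 1e6:.1f}M parameters")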
Create Dataset
# Setup the prompt templates for training
imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
imagenet_style_templates_small = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]
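During training, each image is paired with a caption built from a randomly chosen template, formatted with the placeholder token. For example (assuming the object templates and the <cat-toy> placeholder defined above):
# Illustration: how a training caption is built from a template and the placeholder token
print(random.choice(imagenet_templates_small).format(placeholder_token))
# e.g. "a photo of a clean <cat-toy>"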
Setup the Dataset
# Setup the dataset
class TextualInversionDataset(Dataset):
def __init__(
self,
data_root,
tokenizer,
learnable_property="object", # [object, style]
size=512,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="*",
center_crop=False,
):
self.data_root = data_root
self.tokenizer = tokenizer
self.learnable_property = learnable_property
self.size = size
self.placeholder_token = placeholder_token
self.center_crop = center_crop
self.flip_p = flip_p
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
self.num_images = len(self.image_paths)
self._length = self.num_images
if set == "train":
self._length = self.num_images * repeats
self.interpolation = {
# "linear": PIL.Image.Resampling.LINEAR, # Removed in version 10.0.0. use bilinear instead
"bilinear": PIL.Image.Resampling.BILINEAR,
"bicubic": PIL.Image.Resampling.BICUBIC,
"lanczos": PIL.Image.Resampling.LANCZOS,
}[interpolation]
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
placeholder_string = self.placeholder_token
text = random.choice(self.templates).format(placeholder_string)
example["input_ids"] = self.tokenizer(
text,
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids[0]
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = (
img.shape[0],
img.shape[1],
)
img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
image = Image.fromarray(img)
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip_transform(image)
image = np.array(image).astype(np.uint8)
image = (image / 127.5 - 1.0).astype(np.float32)
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
return example
Set up the Model
# Load the tokenizer and add the placeholder token as an additional special token.
tokenizer = CLIPTokenizer.from_pretrained(
pretrained_model_name_or_path,
subfolder="tokenizer",
)
# Add the placeholder token in tokenizer
num_added_tokens = tokenizer.add_tokens(placeholder_token)
if num_added_tokens == 0:
raise ValueError(
f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
)
# Get token ids for our placeholder and initializer token.
# This code block will complain if initializer string is not a single token
# Convert the initializer_token, placeholder_token to ids
token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
# Check if initializer_token is a single token or a sequence of tokens
if len(token_ids) > 1:
raise ValueError("The initializer token must be a single token.")
initializer_token_id = token_ids[0]
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
# Load the Stable Diffusion model
# Load models and create wrapper for stable diffusion
text_encoder = CLIPTextModel.from_pretrained(
pretrained_model_name_or_path, subfolder="text_encoder"
)
vae = AutoencoderKL.from_pretrained(
pretrained_model_name_or_path, subfolder="vae"
)
unet = UNet2DConditionModel.from_pretrained(
pretrained_model_name_or_path, subfolder="unet"
)
We have added the placeholder_token to the tokenizer, so we resize the token embeddings here; this creates a new embedding vector in the token embeddings for our placeholder_token.
text_encoder.resize_token_embeddings(len(tokenizer))
Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
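As a quick, optional check, the embedding row for the placeholder token should now match that of the initializer token:
# Optional check: the placeholder embedding starts as a copy of the initializer ("toy") embedding
assert torch.equal(token_embeds[placeholder_token_id], token_embeds[initializer_token_id])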
In Textual Inversion we only train the newly added embedding vector, so let's freeze the rest of the model parameters here.
def freeze_params(params):
for param in params:
param.requires_grad = False
# Freeze vae and unet
freeze_params(vae.parameters())
freeze_params(unet.parameters())
# Freeze all parameters except for the token embeddings in text encoder
params_to_freeze = itertools.chain(
text_encoder.text_model.encoder.parameters(),
text_encoder.text_model.final_layer_norm.parameters(),
text_encoder.text_model.embeddings.position_embedding.parameters(),
)
freeze_params(params_to_freeze)
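You can optionally verify that only the token-embedding table of the text encoder remains trainable, while the VAE and UNet are fully frozen:
# Optional check: count trainable vs. frozen parameters
trainable = sum(p.numel() for p in text_encoder.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in itertools.chain(vae.parameters(), unet.parameters()))
print(f"trainable (text encoder embeddings): {trainable / 1e6:.1f}M, frozen (vae + unet): {frozen / 1e6:.1f}M")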
Create the training data: Dataset, DataLoader and noise_scheduler
# create dataset for training
train_dataset = TextualInversionDataset(
data_root=save_path,
tokenizer=tokenizer,
size=vae.sample_size,
placeholder_token=placeholder_token,
repeats=100,
learnable_property=what_to_teach, #Option selected above between object and style
center_crop=False,
set="train",
)
# create dataloader for training
def create_dataloader(train_batch_size=1):
return torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
# Create noise_scheduler for training
noise_scheduler = DDPMScheduler.from_config(pretrained_model_name_or_path, subfolder="scheduler")
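Optionally, pull one batch from the dataloader to verify the shapes the training loop will receive; pixel_values should be [batch, 3, size, size] and input_ids [batch, 77] for the CLIP tokenizer.
# Optional: inspect one training batch
sample_batch = next(iter(create_dataloader(train_batch_size=2)))
print(sample_batch["pixel_values"].shape, sample_batch["input_ids"].shape)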
4.2. Training
Set up the training args and define the hyperparameters for training.
You can also tune hyperparameters such as learning_rate and max_train_steps to experiment with the results.
- Dreambooth
- Textual Inversion
from argparse import Namespace
args = Namespace(
pretrained_model_name_or_path=pretrained_model_name_or_path,
resolution=vae.sample_size,
center_crop=True,
train_text_encoder=False,
instance_data_dir=save_path,
instance_prompt=instance_prompt,
learning_rate=5e-06,
max_train_steps=300,
save_steps=50,
train_batch_size=2, # set to 1 if using prior preservation
gradient_accumulation_steps=2,
max_grad_norm=1.0,
mixed_precision="fp16", # set to "fp16" for mixed-precision training.
gradient_checkpointing=True, # set this to True to lower the memory usage.
use_8bit_adam=True, # use 8bit optimizer from bitsandbytes
seed=3434554,
with_prior_preservation=prior_preservation,
prior_loss_weight=prior_loss_weight,
sample_batch_size=2,
class_data_dir=prior_preservation_class_folder,
class_prompt=prior_preservation_class_prompt,
num_class_images=num_class_images,
lr_scheduler="constant",
lr_warmup_steps=100,
output_dir="dreambooth-concept",
)
hyperparameters = {
"learning_rate": 5e-04,
"scale_lr": True,
"max_train_steps": 2000,
"save_steps": 250,
"train_batch_size": 4,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": True,
"mixed_precision": "fp16",
"seed": 42,
"output_dir": "textual-inversion-concept",
}
!mkdir -p textual-inversion-concept
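Note that with scale_lr enabled, the training function below scales the base learning rate by the effective batch size. A small sketch of the value that will actually be used (assuming a single process):
# Sketch: learning rate actually used when `scale_lr` is True (single process assumed)
effective_lr = (
    hyperparameters["learning_rate"]
    * hyperparameters["gradient_accumulation_steps"]
    * hyperparameters["train_batch_size"]
)
print(f"effective learning rate: {effective_lr}")  # 5e-04 * 1 * 4 = 2e-03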
Define the training function
- Dreambooth
- Textual Inversion
from accelerate.utils import set_seed
def training_function(text_encoder, vae, unet):
logger = get_logger(__name__)
set_seed(args.seed)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
)
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
raise ValueError(
"Gradient accumulation is not supported when training the text encoder in distributed training. "
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
)
vae.requires_grad_(False)
if not args.train_text_encoder:
text_encoder.requires_grad_(False)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.train_text_encoder:
text_encoder.gradient_checkpointing_enable()
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
params_to_optimize = (
itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
)
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
)
noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt,
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
)
def collate_fn(examples):
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
# concat class and instance examples for prior preservation
if args.with_prior_preservation:
input_ids += [example["class_prompt_ids"] for example in examples]
pixel_values += [example["class_images"] for example in examples]
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = tokenizer.pad(
{"input_ids": input_ids},
padding="max_length",
return_tensors="pt",
max_length=tokenizer.model_max_length
).input_ids
batch = {
"input_ids": input_ids,
"pixel_values": pixel_values,
}
return batch
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
)
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move text_encode and vae to gpu.
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
vae.to(accelerator.device, dtype=weight_dtype)
vae.decoder.to("cpu")
if not args.train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
global_step = 0
for epoch in range(num_train_epochs):
unet.train()
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
if args.with_prior_preservation:
# Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
# Compute prior loss
prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else unet.parameters()
)
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
optimizer.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if global_step % args.save_steps == 0:
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
pipeline.save_pretrained(save_path)
logs = {"loss": loss.detach().item()}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
accelerator.wait_for_everyone()
# Create the pipeline using the trained modules and save it.
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
)
pipeline.save_pretrained(args.output_dir)
logger = get_logger(__name__)
def save_progress(text_encoder, placeholder_token_id, accelerator, save_path):
logger.info("Saving embeddings")
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, save_path)
def training_function(text_encoder, vae, unet):
train_batch_size = hyperparameters["train_batch_size"]
gradient_accumulation_steps = hyperparameters["gradient_accumulation_steps"]
learning_rate = hyperparameters["learning_rate"]
max_train_steps = hyperparameters["max_train_steps"]
output_dir = hyperparameters["output_dir"]
gradient_checkpointing = hyperparameters["gradient_checkpointing"]
accelerator = Accelerator(
gradient_accumulation_steps=gradient_accumulation_steps,
mixed_precision=hyperparameters["mixed_precision"]
)
if gradient_checkpointing:
text_encoder.gradient_checkpointing_enable()
unet.enable_gradient_checkpointing()
train_dataloader = create_dataloader(train_batch_size)
if hyperparameters["scale_lr"]:
learning_rate = (
learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
)
# Initialize the optimizer
optimizer = torch.optim.AdamW(
text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
lr=learning_rate,
)
text_encoder, optimizer, train_dataloader = accelerator.prepare(
text_encoder, optimizer, train_dataloader
)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move vae and unet to device
vae.to(accelerator.device, dtype=weight_dtype)
unet.to(accelerator.device, dtype=weight_dtype)
# Keep vae in eval mode as we don't train it
vae.eval()
# Keep unet in train mode to enable gradient checkpointing
unet.train()
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
# Train!
total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Instantaneous batch size per device = {train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
global_step = 0
for epoch in range(num_train_epochs):
text_encoder.train()
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(text_encoder):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states.to(weight_dtype)).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(noise_pred, target, reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)
# Zero out the gradients for all token embeddings except the newly added
# embeddings for the concept, as we only want to optimize the concept embeddings
if accelerator.num_processes > 1:
grads = text_encoder.module.get_input_embeddings().weight.grad
else:
grads = text_encoder.get_input_embeddings().weight.grad
# Get the index for tokens that we want to zero the grads for
index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
optimizer.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if global_step % hyperparameters["save_steps"] == 0:
save_path = os.path.join(output_dir, f"learned_embeds-step-{global_step}.bin")
save_progress(text_encoder, placeholder_token_id, accelerator, save_path)
logs = {"loss": loss.detach().item()}
progress_bar.set_postfix(**logs)
if global_step >= max_train_steps:
break
accelerator.wait_for_everyone()
# Create the pipeline using the trained modules and save it.
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline.from_pretrained(
pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder),
tokenizer=tokenizer,
vae=vae,
unet=unet,
)
pipeline.save_pretrained(output_dir)
# Also save the newly trained embeddings
save_path = os.path.join(output_dir, f"learned_embeds.bin")
save_progress(text_encoder, placeholder_token_id, accelerator, save_path)
4.3. Run the Training
import accelerate
num_gpus = 1 # specify the number of GPUs you would like to use. This also depends on your machine config
accelerate.notebook_launcher(training_function, args=(text_encoder, vae, unet), num_processes=num_gpus)
for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
if param.grad is not None:
del param.grad # free some memory
torch.cuda.empty_cache()
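Once the launcher returns, you can optionally confirm that the trained artifacts were written to the output directory configured earlier ("dreambooth-concept" or "textual-inversion-concept", depending on the training method used above):
# Optional: list the contents of the training output directory
trained_output_dir = "dreambooth-concept"  # or "textual-inversion-concept" for Textual Inversion
print(os.listdir(trained_output_dir))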
Step 5: Run Inference with the Newly Trained Model
Bravo! Our model training is complete and we have taught our model a new concept, namely cat-toy.
Let us now test our newly trained model by running inference against it and see the results.
5.1. Set up the pipeline
The newly trained model artifacts are available in the output directory specified while setting up the training args. We will use those artifacts to load the model and setup the pipeline.
- Dreambooth
- Textual Inversion
from diffusers import DPMSolverMultistepScheduler
output_dir = args.output_dir
pipe = StableDiffusionPipeline.from_pretrained(
output_dir,
scheduler=DPMSolverMultistepScheduler.from_pretrained(output_dir, subfolder="scheduler"),
torch_dtype=torch.float16,
).to("cuda")
from diffusers import DPMSolverMultistepScheduler
output_dir = hyperparameters["output_dir"]
pipe = StableDiffusionPipeline.from_pretrained(
output_dir,
scheduler=DPMSolverMultistepScheduler.from_pretrained(output_dir, subfolder="scheduler"),
torch_dtype=torch.float16,
).to("cuda")
5.2. Run the Stable Diffusion pipeline
# Don't forget to use the placeholder token in your prompt
prompt = "a <cat-toy> in mad max fury road" # type: str
num_samples = 2 # type: int : number of samples to generate for the prompt
images = pipe(
prompt, num_images_per_prompt=num_samples, num_inference_steps=30, guidance_scale=9
).images
# you can play around with the newly trained model by changing the `prompt`, `num_inference_steps` and `guidance_scale` and see the difference in results.
grid = image_grid(images, rows=1, cols=num_samples)
grid
Output:
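Optionally, you can save the generated samples to disk so they are easy to compare against later runs (the filenames below are arbitrary):
# Optional: save the generated samples
for i, img in enumerate(images):
    img.save(f"inference_sample_{i}.png")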
Step 6: Save the Newly Created Concept
Our model training is complete and we have also tested it by successfully generating some image samples. We shall now save the model artifacts to preserve our newly created concept. This will also enable us to launch an Inference server against it, as described in the next section.
6.1. Save the Model artifacts
- Dreambooth
- Textual Inversion
model_artifacts_dir = "dreambooth_model" # specify the directory to save the model artifacts
os.makedirs(model_artifacts_dir, exist_ok=True)
pipe.save_pretrained(model_artifacts_dir) # save the model artifacts
model_artifacts_dir = "textual_inversion_model" # specify the directory to save the model artifacts
os.makedirs(model_artifacts_dir, exist_ok=True)
pipe.save_pretrained(model_artifacts_dir) # save the model artifacts
# In case of Textual Inversion, the learned concept is stored in the `learned_embeds.bin` file.
# Hence this file will be required every time we need to load this concept.
!cp $output_dir/learned_embeds.bin $model_artifacts_dir/
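As an aside, if you ever keep only the learned_embeds.bin file instead of the full pipeline, the learned embedding can be loaded back into a fresh copy of the base model. A minimal sketch of the standard recipe (variable names are illustrative):
# Sketch: load `learned_embeds.bin` into a fresh base pipeline
learned_embeds = torch.load(f"{model_artifacts_dir}/learned_embeds.bin", map_location="cpu")
token, embedding = next(iter(learned_embeds.items()))  # e.g. "<cat-toy>" -> tensor

fresh_pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path)
fresh_pipe.tokenizer.add_tokens(token)
fresh_pipe.text_encoder.resize_token_embeddings(len(fresh_pipe.tokenizer))
token_id = fresh_pipe.tokenizer.convert_tokens_to_ids(token)
fresh_pipe.text_encoder.get_input_embeddings().weight.data[token_id] = embedding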
6.2. Create Model on TIR
- Go to the TIR AI Platform (https://tir.e2enetworks.com).
- Choose a project, navigate to the Models section, and click on Create Model.
- Enter a model name of your choosing (e.g. stable-diffusion)
- Select Model Type as Custom & click on CREATE
- You will now see details of EOS (E2E Object Storage) bucket created for this model.
- EOS provides an S3-compatible API to upload or download content. We will be using the MinIO CLI in this tutorial.
- Copy the Setup Host command from Setup MinIO CLI tab. We will use it to setup MinIO CLI in the next step.
In case you forget to copy the setup host command for MinIO CLI, don't worry. You can always go back to model details and get it again.
6.3. Upload the Artifacts to Model Bucket
We have already created the model on TIR. Let's set it up on our notebook and upload the artifacts to the model.
Let's go back to our TIR Notebook named stable-diffusion-fine-tune, which we used to fine-tune the model in the previous steps, and follow the steps below:
- In Jupyter Labs, click New Launcher and select Terminal.
- Set up the MinIO CLI by running the Setup Host command copied in Step 6.2.
- Go to the directory model_artifacts_dir, where we saved the model artifacts in Step 6.1, using the command below.
- Dreambooth
- Textual Inversion
export model_artifacts_dir=dreambooth_model && \
cd $model_artifacts_dir
export model_artifacts_dir=textual_inversion_model && \
cd $model_artifacts_dir
- Copy the model artifacts to the model bucket.
- Go to TIR Dashboard >> Models >> Select your model >> Copy the cp command from the Setup MinIO CLI tab.
- The copy command would look like this:
mc cp -r <MODEL_NAME> stable-diffusion/stable-diffusion-854588
- Here, we will replace MODEL_NAME with '*' to upload all contents of the current folder to the bucket.
- We will append $model_artifacts_dir/ to the bucket path to upload the contents to a folder named $model_artifacts_dir inside the bucket, so your copy command would look something like this:
mc cp -r * stable-diffusion/stable-diffusion-854588/$model_artifacts_dir/
Note: The above command to copy model artifacts should not be used as-is; it is just a sample showing what the command looks like. Follow the steps above to generate the copy command specific to your model bucket.
Your model artifacts will be saved successfully upon upload completion.
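If you want to double-check the upload, you can list the bucket contents with the MinIO CLI (the bucket path below is the same sample path as above; use the path from your own model's Setup MinIO CLI tab):
mc ls stable-diffusion/stable-diffusion-854588/$model_artifacts_dir/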
Step 7: Create an Inference Server for the Newly Trained Model
What now remains is to create an Inference Server against our trained model and serve API requests.
If you are not very familiar with Inference creation on TIR, follow this tutorial for a detailed, step-by-step guide on Model Endpoint (Inference) creation for Stable Diffusion: https://docs.e2enetworks.com/tir/inference/Tutorials/stable_diffusion_inference.html#a-guide-on-model-endpoint-creation-image-generation
7.1. Create a new Model Endpoint in TIR
- Create a new Model Endpoint in TIR with Stable Diffusion Framework and a GPU Machine Plan.
- [Important!] In the Model Details subsection, choose the Model that we created in Step 6.2, and make sure to specify the $model_artifacts_dir path. This is necessary because our trained model artifacts are present in that directory.
7.2. Generate your API Token
The model endpoint API requires a valid auth token which you'll need to perform further steps. So, generate a new API Token from the API Tokens Section (or use an existing token, if already created). An api-key and an auth token will be generated. Copy this auth token. You will need it in the next step.
7.3. Make API Request to generate Image output
The final step is to send API requests to the created model endpoint & generate images using text prompts. We will use TIR Notebook to do the same.
- Once your model is Ready, visit the Sample API Request section of that model and copy the Python code.
- Launch a TIR Notebook with the PyTorch or Stable Diffusion image on any basic machine plan. Once it is in the Running state, launch it and start a new notebook (e.g., untitled.ipynb) in Jupyter Labs.
- Paste the Sample API Request code (for Python) into a notebook cell.
- Copy the Auth Token generated in Step 7.2 and use it in place of $AUTH_TOKEN in the Sample API Request.
- Replace the prompt string in the payload with the prompt below. This ensures that the placeholder token (in our case, <cat-toy>) is used in the prompt to get the desired output:
"a <cat-toy> in mad max fury road"
- Execute the code to send the request. You will get a list of tensors as output, because the Stable Diffusion v2.1 model endpoint returns the generated images as a list of PyTorch tensors.
- To view the generated images, copy the code below, paste it into a notebook cell, and execute it.
import torch
import torchvision.transforms as transforms
def display_images(tensor_image_data_list):
'''convert PyTorch Tensors to PIL Image'''
for tensor_data in tensor_image_data_list:
tensor_image = torch.tensor(tensor_data.get("data")) # initialise the tensor
pil_img = transforms.ToPILImage()(tensor_image) # convert to PIL Image
pil_img.show()
# to save the generated images, uncomment the line below
# pil_img.save(tensor_data.get("name"))
if response.status_code == 200:
display_images(response.json().get("predictions"))
Output:
That's it! We have successfully taught our model a new concept, <cat-toy>.
The Stable Diffusion model supports various other parameters for controlling the generation of image output. Refer to Supported parameters for image generation for more details: https://docs.e2enetworks.com/tir/inference/Tutorials/stable_diffusion_inference.html#supported-parameters-for-image-generation
Conclusion
Through this tutorial, we fine-tuned the Stable Diffusion v2.1 model using two training methods, namely DreamBooth and Textual Inversion. By giving just 3-5 images as input, we taught the model a new concept and personalized it on our own images.
After fine-tuning, we also saw how to store the trained model artifacts in TIR Model Storage and launch an Inference Server from them to serve API requests.