Fine-tune Google Flan UL2 with Multiple GPUs

In this tutorial, we will fine-tune Flan-UL2 using PEFT, LoRA, and DeepSpeed on multiple GPUs within a single machine. With A100 80GB cards, we can expect the training to finish in about 24 hours for 3 epochs.

Flan-UL2 is an encoder-decoder model based on the T5 architecture. It is a 20B parameter model fine-tuned using the “Flan” prompt tuning and dataset collection.

The model was initialized using UL2 checkpoints. For more information on UL2, please take a look at the original paper.

While the focus of this article is the use of multiple GPUs on the TIR platform, you can follow the same steps with a single GPU at the cost of training speed.

We could also try to fit the model across GPUs using DeepSpeed’s pipeline parallelism, but that requires hand-tuning the device map, mainly because the T5 blocks in UL2 have residual connections that cannot be split across GPUs. There are workarounds, but for the scope of this tutorial we will use DeepSpeed’s ZeRO offloading feature, which partitions the model's parameters and optimizer states across the GPU processes and offloads them to CPU memory, so each GPU trains on its own batches without holding a full copy of the model.

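As a rough illustration of why sharding and offloading matter here, the sketch below estimates the memory needed just to hold the weights of a 20B-parameter model, and how that figure changes when the weights are split evenly across GPUs. It assumes 4 bytes per parameter (FP32, matching the choice of no mixed precision in step [4]) and ignores activations, gradients, optimizer states, and framework overhead, so treat the numbers as a lower bound rather than a measurement. The function name `weight_memory_gib` is hypothetical.

```python
def weight_memory_gib(num_params: float, bytes_per_param: int = 4, num_shards: int = 1) -> float:
    """Approximate memory (GiB) per shard needed to store model weights only."""
    total_bytes = num_params * bytes_per_param
    return total_bytes / num_shards / 2**30


# Flan-UL2 has roughly 20 billion parameters.
full_copy = weight_memory_gib(20e9)               # one full FP32 copy
per_gpu = weight_memory_gib(20e9, num_shards=4)   # split evenly over 4 GPUs

print(f"Full FP32 copy: {full_copy:.1f} GiB")
print(f"Per GPU when split 4 ways: {per_gpu:.1f} GiB")
```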

[1] Start a GPU Notebook (with the 4xA10080 plan and at least 100GB disk) from the TIR Dashboard.

[2] Install requirements:

!pip install accelerate transformers peft deepspeed datasets pandas psutil

[3] Initiate setup of the Accelerate config:

accelerate config --config_file launcher_config.yaml

[4] Choose the following parameters:

In which environment are you running? This machine
Which type of machine are you using? multi-GPU
How many different machines will you use? 1
Do you wish to optimise your script with torch dynamo? No
Do you want to use Deepspeed? Yes
Do you want to specify json for deepspeed config? No
What should be your Deepspeed's Zero Optimization stage? 3
Where to offload optimizer states? cpu
Where to offload parameters? cpu
How many gradient accumulation steps are you passing to the script? 8
Do you want to use gradient clipping? no
Do you want to save 16-bit model.. ? no
Do you want to enable '' ... ? no
How many gpus should be used for training? 4

Do you wish to use FP16 or BF16? no

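For readers who prefer to see these settings as a DeepSpeed-style configuration, the answers above correspond roughly to the dictionary sketched below. This is an illustration only: Accelerate writes its own YAML file, and the exact keys it stores may differ from this sketch. The variable name `ds_config_sketch` is hypothetical.

```python
import json

# Rough DeepSpeed-style equivalent of the answers given to `accelerate config`.
# Illustrative only; Accelerate writes launcher_config.yaml in its own format.
ds_config_sketch = {
    "zero_optimization": {
        "stage": 3,                               # ZeRO stage 3
        "offload_optimizer": {"device": "cpu"},   # optimizer states go to CPU
        "offload_param": {"device": "cpu"},       # parameters go to CPU
    },
    "gradient_accumulation_steps": 8,
    "fp16": {"enabled": False},                   # no mixed precision
    "bf16": {"enabled": False},
}

print(json.dumps(ds_config_sketch, indent=2))
```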
[5] Prepare dataset.

The fine-tuning script can work with any CSV file. Here, we will use the Alpaca dataset. Create a new file in JupyterLab and copy the following contents into it:

import json

import pandas as pd

with open('alpaca_data.json') as f:
    data = json.load(f)

new_format = []
for i, point in enumerate(data):
    if len(point['input']) == 0:
        # no input: the prompt contains only the instruction
        inputt = "Below is an instruction that describes a task.\n "
        inputt += "Write a response that appropriately completes the request.\n\n"
        inputt += f"### Instruction:\n{point['instruction']}\n\n### Response:"
    else:
        # with input: the prompt contains the instruction and the input
        inputt = "Below is an instruction that describes a task.\n "
        inputt += "Write a response that appropriately completes the request.\n\n"
        inputt += f"### Instruction:\n{point['instruction']}\n\n### Input:\n{point['input']}\n\n### Response:"

    item = {'input': inputt, 'output': str(point['output'])}
    new_format.append(item)

df = pd.DataFrame(new_format)
df = df.dropna()
# write the CSV file that is passed to the training command in step [8]
df.to_csv('alpaca_data.csv', index=False)

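To make the two prompt layouts easier to see, here is a small, self-contained sketch that builds the prompt for a single record. The helper name `build_prompt` and the sample record are made up for illustration; only the template text follows the script above.

```python
def build_prompt(instruction: str, input_text: str = "") -> str:
    """Build an Alpaca-style prompt; the Input section is added only when present."""
    prompt = "Below is an instruction that describes a task.\n "
    prompt += "Write a response that appropriately completes the request.\n\n"
    prompt += f"### Instruction:\n{instruction}\n\n"
    if input_text:
        prompt += f"### Input:\n{input_text}\n\n"
    prompt += "### Response:"
    return prompt


# Hypothetical sample record in the same shape as the Alpaca JSON entries.
record = {"instruction": "Translate to French.", "input": "Good morning", "output": "Bonjour"}
print(build_prompt(record["instruction"], record["input"]))
```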
[6] Open a terminal in JupyterLab. Run the following commands to download the Alpaca dataset and prepare a CSV from it:

# download the json alpaca dataset

# run in the terminal shell

[7] Prepare fine-tuning script.

Create a new file in JupyterLab (from the file browser) and copy the following contents into it.

# Modified from
import argparse
import gc
import logging
import os
import threading

import psutil
import torch
from accelerate import Accelerator
from datasets import load_dataset
from deepspeed.accelerator import get_accelerator
from peft import LoraConfig, TaskType, get_peft_model
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
                        get_linear_schedule_with_warmup, set_seed)

def b2mb(x):
    """Converting Bytes to Megabytes"""
    return int(x / 2**20)

class TorchTracemalloc:
    # Context manager is used to track the peak memory usage of the process

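    def __enter__(self):
        # free unused memory so the baseline reading is not inflated
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
        self.begin = torch.cuda.memory_allocated()
        self.process = psutil.Process()

        self.cpu_begin = self.cpu_mem_used()
        self.peak_monitoring = True
        # track peak CPU memory in a background thread
        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        peak_monitor_thread.daemon = True
        peak_monitor_thread.start()
        return self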

    def cpu_mem_used(self):
        """get resident set size memory for the current process"""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        self.cpu_peak = -1
        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
            # stop polling once __exit__ clears the flag
            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        self.peak_monitoring = False
        self.end = torch.cuda.memory_allocated()
        self.peak = torch.cuda.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)
        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)

# Handle argument parsing
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str,
                        help="Model path. Supports T5/UL2 models")
    parser.add_argument("--datafile_path", type=str, default="sample.csv",
                        help="Path to the already processed dataset.")
    parser.add_argument("--num_epochs", type=int, default=1,
                        help="Number of epochs to train for.")
    parser.add_argument("--per_device_batch_size", type=int,
                        default=2, help="Batch size to use for training.")
    parser.add_argument("--input_max_length", type=int, default=128,
                        help="Maximum input length to use for generation")
    parser.add_argument("--target_max_length", type=int, default=128,
                        help="Maximum target length to use for generation")
    parser.add_argument("--lr", type=float, default=3e-4,
                        help="Learning rate to use for training.")
    parser.add_argument("--seed", type=int, default=42,
                        help="Seed to use for training.")
    parser.add_argument("--input_column", type=str,
                        default='input', help='csv input text column')
    parser.add_argument("--target_column", type=str,
                        default='output', help='csv target text column')
    parser.add_argument("--save_path", type=str,
                        default='peft_ckpt', help="Save path")

    args = parser.parse_known_args()
    return args

# Main function
def main():
    args, _ = parse_args()
    text_column = args.input_column
    label_column = args.target_column
    lr = args.lr
    num_epochs = args.num_epochs
    batch_size = args.per_device_batch_size
    seed = args.seed
    model_name_or_path = args.model_path
    data_file = args.datafile_path
    save_path = args.save_path
    target_max_length = args.target_max_length
    source_max_length = args.input_max_length

    # Create the save directory if it doesn't exist
    os.makedirs(save_path, exist_ok=True)

    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            logging.FileHandler(os.path.join(save_path, "training_log.log")),
        ],
    )
    logging.info(f'Args:\n {args}')

    # launch configs
    accelerator = Accelerator()
    peft_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8,
        lora_alpha=32, lora_dropout=0.1
    )

    # Save logs only on the main process
    def log_info(logging, s):
        if accelerator.is_main_process:
            logging.info(s)

    # load dataset
    dataset = load_dataset('csv', data_files={'train': data_file})
    log_info(logging, f"Dataset length :{len(dataset['train'])}")
    # load model
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
    # load peft model
    model = get_peft_model(model, peft_config)
    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    def preprocess_function(sample, padding="max_length"):
        # created prompted input
        inputs = sample[text_column]
        # tokenize inputs
        model_inputs = tokenizer(
            inputs, max_length=source_max_length,
            padding=padding, truncation=True)
        # Tokenize targets with the `text_target` keyword argument
        labels = tokenizer(text_target=sample[label_column],
                        max_length=target_max_length,
                        padding=padding, truncation=True)
        # If we are padding here, replace all tokenizer.pad_token_id
        # in the labels by -100 when we want to ignore padding in the loss.
        if padding == "max_length":
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    # Prepare and preprocess the dataset
    with accelerator.main_process_first():
        # preventing string conversion errors

        def str_convert(example):
            example[label_column] = str(example[label_column])
            return example

        dataset['train'] = dataset['train'].map(str_convert)
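        # tokenize in batches and drop the raw text columns
        processed_datasets = dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=dataset["train"].column_names,
            desc="Running tokenizer on dataset",
        )
    train_dataset = processed_datasets["train"]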

    def collate_fn(examples):
        return tokenizer.pad(examples, padding="longest", return_tensors="pt")

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn,
        batch_size=batch_size, pin_memory=True
    )

    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # lr scheduler
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=0,  # assumed value; the original setting is not shown
        num_training_steps=(len(train_dataloader) * num_epochs),
    )

    # accelerator prepare
    model, train_dataloader, optimizer, lr_scheduler = accelerator.prepare(
        model, train_dataloader, optimizer, lr_scheduler
    )

    # Train the model
    for epoch in range(num_epochs):
        with TorchTracemalloc() as tracemalloc:
            model.train()
            total_loss = 0
            for step, batch in enumerate(tqdm(train_dataloader)):
                # using accelerator accumulate to perform gradient accumulation
                with accelerator.accumulate(model):
                    outputs = model(**batch)
                    loss = outputs.loss
                    total_loss += loss.detach().float()
                    # backpropagate and update the weights
                    accelerator.backward(loss)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()


        # Printing the GPU and CPU memory usage details (values in MB)
        log_info(
            logging,
            "GPU Peak Memory consumed during train: {}".format(tracemalloc.peaked))
        log_info(
            logging,
            "GPU Total Peak Memory consumed during the train: {}".format(
                tracemalloc.peaked + b2mb(tracemalloc.begin)))
        log_info(
            logging, "CPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.cpu_peaked))
        log_info(
            logging,
            "CPU Total Peak Memory consumed during the train (max): {}".format(
                tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)))

        train_epoch_loss = total_loss / len(train_dataloader)
        train_ppl = torch.exp(train_epoch_loss)
        log_info(
            logging, "........................ : TRAINING DETAILS : .......................")
        log_info(logging, f"{epoch=}: {train_ppl=} {train_epoch_loss=}")

        # save intermediate checkpoint
        log_info(logging, "Saving intermediate ckpt")
        success = model.save_checkpoint(f'{save_path}', f'{epoch}')
        # save peft config
        peft_config.save_pretrained(os.path.join(f'{save_path}', f'{epoch}'))

        status_msg = f"checkpointing: checkpoint_folder={save_path}"
        if success:
            log_info(logging, f"Success {status_msg}")
        else:
            log_info(logging, f"Failure {status_msg}")

    log_info(logging, "Training complete ......")

if __name__ == "__main__":
    main()

[8] Run the training with accelerate:

accelerate launch --config_file launcher_config.yaml \
--model_path google/flan-ul2 \
--datafile_path alpaca_data.csv \
--save_path checkpoint \
--num_epochs 1 \
--lr 1e-4 \
--per_device_batch_size 2 \
--input_max_length 256 \
--target_max_length 256

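It can help to work out the effective batch size implied by these settings. The sketch below multiplies the per-device batch size from the command above by the number of GPUs and the gradient accumulation steps chosen in step [4]. The dataset size used at the end is a hypothetical value chosen only to show the arithmetic, the resulting update count is approximate, and the function names are made up for this example.

```python
import math


def effective_batch_size(per_device: int, num_gpus: int, grad_accum_steps: int) -> int:
    """Number of samples contributing to one optimizer update."""
    return per_device * num_gpus * grad_accum_steps


def updates_per_epoch(num_samples: int, batch: int) -> int:
    """Optimizer updates needed to see every sample once (last batch may be partial)."""
    return math.ceil(num_samples / batch)


batch = effective_batch_size(per_device=2, num_gpus=4, grad_accum_steps=8)
print(batch)                              # 64
print(updates_per_epoch(50_000, batch))   # 782 for a hypothetical 50,000-row CSV
```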
[9] Prepare checkpoint for inference

DeepSpeed will create sharded model files and a conversion script in the checkpoint folder. Run the following command to convert the checkpoints to a PyTorch bin file:

# replace 0 with the latest checkpoint number
python ./checkpoint/ ./checkpoint/ ./checkpoint/0/adapter_model.bin

[10] Load the fine-tuned model for inference:

# Start a new notebook to run this code from a notebook cell.
from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfig

peft_model_id = 'checkpoint/0'
base_model_id = 'google/flan-ul2'

model = AutoModelForSeq2SeqLM.from_pretrained(base_model_id) # The original model path
model = PeftModel.from_pretrained(model, peft_model_id) # The fine-tuned model path

You may also further merge the base weights with LoRA checkpoints.

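Merging means folding the low-rank update into the base weight so that inference no longer needs the adapter as a separate module. PEFT exposes this for LoRA models through `merge_and_unload()`; the pure-Python sketch below only illustrates the underlying arithmetic, W' = W + (lora_alpha / r) * B A, on tiny made-up matrices. The helper names and the numbers are hypothetical, and a rank of 1 is used to keep the example small (the training script uses r=8 and lora_alpha=32).

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, lora_a, lora_b, lora_alpha, r):
    """Return weight + (lora_alpha / r) * (lora_b @ lora_a)."""
    scale = lora_alpha / r
    delta = matmul(lora_b, lora_a)
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(weight, delta)]


# Tiny illustrative example: a 2x2 base weight with a rank-1 adapter.
W = [[1.0, 0.0],
     [0.0, 1.0]]
A = [[1.0, 2.0]]        # shape (r, in_features) = (1, 2)
B = [[0.5], [0.25]]     # shape (out_features, r) = (2, 1)

merged = merge_lora(W, A, B, lora_alpha=4, r=1)
print(merged)  # [[3.0, 4.0], [1.0, 3.0]]
```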

In this tutorial, we fine-tuned a large language model on multiple GPUs using DeepSpeed and LoRA. A similar approach can be followed for fine-tuning other models.