Custom Containers in TIR

The TIR platform supports a variety of pre-built containers that can launch API handlers for you. However, you may sometimes want to handle API requests differently or add steps to the flow. This is where a custom container image can help.

You may also have your own containers that you want to launch on a GPU plan.

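In this tutorial, we will:

  • Write an API handler for model inference

  • Package the API handler in a container image

  • Configure a model endpoint in TIR to serve the model over a REST API

  • Use TIR Models to improve the launch time of containers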

Step 1: Write an API handler for model inference

By default, each Model Endpoint in TIR follows the KServe Open Inference Protocol for handling inference requests. We recommend using the same format for your REST API endpoints, but you may choose a different one.
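The handlers in this tutorial use KServe's V1 request/response shape: the request body carries an "instances" list and the response carries a "predictions" list. The sketch below builds and reads such bodies. The {"text": ...} field inside each instance is this tutorial's own convention rather than a protocol requirement, and the helper names are hypothetical.

```python
import json
from typing import Any, Dict, List


def make_request_body(texts: List[str]) -> str:
    """Serialize prompts into a V1-style body: {"instances": [{"text": ...}, ...]}."""
    return json.dumps({"instances": [{"text": t} for t in texts]})


def read_predictions(response_body: str) -> List[Any]:
    """Extract the "predictions" list from a V1-style response body."""
    data: Dict[str, Any] = json.loads(response_body)
    return data["predictions"]


body = make_request_body(["Life is such that "])
print(body)  # {"instances": [{"text": "Life is such that "}]}
```

The predict method in the handler reads the same structure back with payload["instances"][0]["text"].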

In this tutorial, we use the KServe Model Server to wrap our model inference calls, so we don't have to implement liveness and readiness probes ourselves.

Let's walk through a simple API handler template. If you intend to use the KServe model server, your class must extend kserve.Model and implement methods such as load and predict, as shown below:

from typing import Dict

from kserve import Model, ModelServer


class MyCustomModel(Model):
    def __init__(self, name: str):
        super().__init__(name)
        self.name = name
        self.ready = False
        self.load()

    def load(self):
        # fetch your model from disk or remote
        self.model = ...
        # mark the model as ready so the readiness endpoint reports true
        self.ready = True

    def predict(self, payload: Dict, headers: Dict[str, str] = None) -> Dict:
        # read request input from payload dict
        # for example
        # inputs = payload["instances"]
        # source_text = inputs[0]["text"]

        # call inference
        result = ...

        return {"predictions": result}


if __name__ == "__main__":
    # Here we have named the model meta-llama-2-7b-chat, but you may choose any name.
    # The name matters because it becomes part of your REST endpoint. For example, if
    # you name the model 'mnist', your REST endpoints will end with
    # https://infer.e2enetworks.net/project/<project-id>/endpoint/is-<endpoint-id>/v1/models/mnist
    model = MyCustomModel("meta-llama-2-7b-chat")
    ModelServer().start([model])
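Because the model name becomes part of every REST path the server exposes, it can help to derive the URLs from it in one place. The hypothetical helper below builds the readiness and predict URLs from a base endpoint URL and a model name, following the /v1/models/<name> and :predict paths used later in this tutorial. The base URL is a placeholder, not a real endpoint.

```python
def model_urls(endpoint_base: str, model_name: str) -> dict:
    """Return the V1 readiness and predict URLs for a named model."""
    base = endpoint_base.rstrip("/")
    model_path = f"{base}/v1/models/{model_name}"
    return {"ready": model_path, "predict": f"{model_path}:predict"}


# Placeholder base URL; substitute your own project and endpoint IDs.
urls = model_urls("https://infer.e2enetworks.net/project/<project>/endpoint/<endpoint-id>", "mnist")
print(urls["predict"])
```

Renaming the model in ModelServer().start([...]) therefore changes both URLs, so clients must be updated with it.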

To take this further, create a project directory on your local machine or in a TIR notebook, and create model_server.py in it with the following contents.

# filename: model_server.py
from typing import Dict

import torch
import transformers
from kserve import Model, ModelServer
from transformers import AutoModelForCausalLM, AutoTokenizer


class MetaLLMA2Model(Model):
    def __init__(self, name: str):
        super().__init__(name)
        self.name = name
        self.ready = False
        self.tokenizer = None
        # Hugging Face model ID; access requires accepting the Llama 2 license
        self.model_id = 'meta-llama/Llama-2-7b-chat-hf'
        self.load()

    def load(self):
        # This step fetches the model directly from Hugging Face. The download can be
        # slow depending on the upstream link, so we recommend using TIR Models instead.
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id,
                                                          trust_remote_code=True,
                                                          torch_dtype=torch.float16,
                                                          device_map='auto')

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        # The model is already loaded in float16 and placed on the GPU(s), so the
        # pipeline only needs the model and tokenizer objects.
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )
        self.ready = True

    def predict(self, payload: Dict, headers: Dict[str, str] = None) -> Dict:
        inputs = payload["instances"]
        source_text = inputs[0]["text"]

        sequences = self.pipeline(source_text,
                                  do_sample=True,
                                  top_k=10,
                                  num_return_sequences=1,
                                  eos_token_id=self.tokenizer.eos_token_id,
                                  max_length=200,
                                  )

        results = []
        for seq in sequences:
            results.append(seq['generated_text'])

        return {"predictions": results}


if __name__ == "__main__":
    model = MetaLLMA2Model("meta-llama-2-7b-chat")
    ModelServer().start([model])
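Loading Llama 2 needs a large GPU, so it can help to check the request parsing and response formatting of predict on their own. The sketch below is an assumption rather than part of the tutorial's server: run_predict copies the body of predict into a plain function that takes the pipeline as an argument, and fake_pipeline mimics the list-of-dicts output (each with a "generated_text" key) of a text-generation pipeline. eos_token_id is left out because there is no tokenizer here.

```python
from typing import Callable, Dict, List


def run_predict(pipeline: Callable[..., List[Dict]], payload: Dict) -> Dict:
    """Same parsing and formatting as MetaLLMA2Model.predict, without the real model."""
    source_text = payload["instances"][0]["text"]
    sequences = pipeline(source_text, do_sample=True, top_k=10,
                         num_return_sequences=1, max_length=200)
    return {"predictions": [seq["generated_text"] for seq in sequences]}


def fake_pipeline(text: str, **kwargs) -> List[Dict]:
    # Stand-in for transformers.pipeline("text-generation", ...): appends a marker to the prompt.
    return [{"generated_text": text + "[generated]"}]


print(run_predict(fake_pipeline, {"instances": [{"text": "Life is such that "}]}))
# {'predictions': ['Life is such that [generated]']}
```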

Note

The Llama 2 model weights must be downloaded from Hugging Face in accordance with the model's license terms. Once you have the weights on your local machine or TIR notebook, you can upload them to a model bucket in EOS.
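After downloading, the weights sit in the local Hugging Face cache. The hypothetical helper below locates the snapshot directory for a model ID, assuming the default cache location ~/.cache/huggingface/hub (environment variables such as HF_HOME can move it) and the models--<org>--<name>/snapshots/<revision> layout used by huggingface_hub. That directory holds the files you would upload to EOS.

```python
from pathlib import Path
from typing import Optional


def find_snapshot_dir(repo_id: str, cache_root: Optional[Path] = None) -> Path:
    """Return the most recently modified snapshot directory for repo_id in the Hugging Face cache."""
    root = cache_root or Path.home() / ".cache" / "huggingface" / "hub"
    # The hub cache stores "org/name" as the folder "models--org--name".
    snapshots = root / ("models--" + repo_id.replace("/", "--")) / "snapshots"
    candidates = [p for p in snapshots.iterdir() if p.is_dir()]
    if not candidates:
        raise FileNotFoundError(f"no snapshots under {snapshots}")
    # If several revisions were downloaded, pick the newest one.
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

For example, find_snapshot_dir("meta-llama/Llama-2-7b-chat-hf") returns the folder whose contents (config.json, tokenizer files, weight shards) make up the model.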

Step 2: Package the API handler in a container image

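Now, let's package the API handler from Step 1 using the Dockerfile below. The Dockerfile also copies a requirements.txt file, which should sit next to model_server.py and list the Python packages the handler needs, such as kserve, transformers and accelerate (required for device_map='auto'):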

# Dockerfile
FROM pytorch/torchserve-kfs:0.8.1-gpu

ENV APP_HOME /app
WORKDIR $APP_HOME

# Install production dependencies.
COPY requirements.txt ./
RUN pip install --no-cache-dir -r ./requirements.txt

# Copy local code to container image
COPY model_server.py ./

CMD ["python", "model_server.py"]

Now build the container image and push it to Docker Hub. You may also use a private registry.

docker build -t <your-docker-handle-here>/meta-llm2-server .
docker push <your-docker-handle-here>/meta-llm2-server

Note

You can run the container locally to test the API, provided your hardware can run the Llama 2 model. If you are on a TIR notebook with an A10080 plan, or your local machine can run the model, go ahead and test the API locally.

Step 3: Configure a model endpoint in TIR to serve the model over REST API

Now that we have a container image in Docker Hub, we can define a model endpoint in TIR.

  • Go to TIR Dashboard

  • Select a Project

  • Go to Model Endpoints

  • Click Create Endpoint

  • Select Custom Container and press Continue

  • Select a GPU plan - GDC3.A10080

  • Set Disk Size to 15G or higher depending on the model size

  • Click Next

  • Enter an appropriate name for the endpoint.

  • Click Next

  • In Container Details, enter image as <your-docker-handle-here>/meta-llm2-server and select other parameters as necessary.

  • In Environment Details, enter these key-value pairs:

      - HUGGING_FACE_HUB_TOKEN: your access token from the Hugging Face website

      - TRANSFORMERS_CACHE: /mnt/models

  • In Model Details, do not select a model. In the example above, we fetch the model from Hugging Face directly, so we don't need to fetch it from EOS.

  • Click Finish and create the endpoint
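Because the handler depends on the variables entered in Environment Details, it can help to check them at startup so a missing token shows up as a clear error in the endpoint logs rather than as a failed download. read_endpoint_env is a hypothetical helper you could call at the top of load(); it is not part of KServe or transformers.

```python
import os
from typing import Dict, Optional


def read_endpoint_env() -> Dict[str, Optional[str]]:
    """Return the settings entered in Environment Details, failing fast if the token is missing."""
    token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if not token:
        # Without a token, downloading the gated Llama 2 weights fails later with a
        # less obvious error, so stop here with a clear message instead.
        raise RuntimeError("HUGGING_FACE_HUB_TOKEN is not set")
    # TRANSFORMERS_CACHE controls where downloaded files are stored (e.g. /mnt/models).
    return {"token": token, "cache_dir": os.environ.get("TRANSFORMERS_CACHE")}
```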

If all goes well, the endpoint will reach a ready state. When it does, you can test the model using the curl commands from the Sample API Request tab.

Sample API request to check the readiness of the endpoint:

curl -H "Authorization: Bearer $token" https://infer.e2enetworks.net/project/<project>/endpoint/<endpoint-id>/v1/models/meta-llama-2-7b-chat

# The response shows whether the model is ready:
# {"name": "meta-llama-2-7b-chat", "ready": true}   ("ready": false while it is still loading)
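Rather than re-running the readiness curl by hand, a small loop can poll until the response reports "ready": true. This is a hypothetical sketch: the HTTP call is passed in as fetch_status so the retry logic stays testable, and in practice that function would GET the readiness URL above with the Authorization header and return the decoded JSON. The 30 attempts and 10-second delay are arbitrary defaults.

```python
import time
from typing import Callable, Dict


def wait_until_ready(fetch_status: Callable[[], Dict], attempts: int = 30,
                     delay_s: float = 10.0,
                     sleep: Callable[[float], None] = time.sleep) -> bool:
    """Poll fetch_status() until it returns {"ready": True}; give up after `attempts` tries."""
    for attempt in range(attempts):
        try:
            if fetch_status().get("ready") is True:
                return True
        except OSError:
            # Network errors (urllib's URLError is an OSError) count as "not ready yet".
            pass
        if attempt < attempts - 1:
            sleep(delay_s)
    return False
```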

Sample API request to test the model:

# Request format: {"instances": [...]}
# Response format: {"predictions": [...]}

curl -H "Authorization: Bearer $token" -H "Content-Type: application/json" -X POST https://infer.e2enetworks.net/project/<project>/endpoint/<endpoint-id>/v1/models/meta-llama-2-7b-chat:predict -d '{"instances": [{"text": "Life is such that "}]}'
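The same predict call can be made from Python with only the standard library. This sketch builds a request equivalent to the curl command; build_predict_request is a hypothetical helper, and the token is whatever you would put in $token. Passing the result to urllib.request.urlopen sends it.

```python
import json
import urllib.request
from typing import Optional


def build_predict_request(endpoint_base: str, text: str,
                          model_name: str = "meta-llama-2-7b-chat",
                          token: Optional[str] = None) -> urllib.request.Request:
    """Build the same POST as the curl command; add a Bearer header if a token is given."""
    headers = {"Content-Type": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    body = json.dumps({"instances": [{"text": text}]}).encode("utf-8")
    return urllib.request.Request(
        f"{endpoint_base.rstrip('/')}/v1/models/{model_name}:predict",
        data=body, headers=headers, method="POST",
    )
```

To read the result, decode the JSON body returned by urlopen and take its "predictions" list. For a container running locally, http://localhost:8080 (KServe's default HTTP port) can serve as the base URL, with no token.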

Step 4: Use TIR Models to improve launch time of containers

You may notice that the model endpoint takes a while to deploy, or even times out in some cases. This is because model_server.py downloads the model directly from the Hugging Face hub at startup.

To fix this, we can define a TIR Model and host the model weights in an EOS bucket.

  • Go to TIR Dashboard

  • Go to Models

  • Create a new model with a name (e.g. my-model) and the format set to custom

  • Once the TIR model is created, you will get EOS bucket details

  • Use the instructions in the Setup Minio CLI tab to configure a Minio host on your local machine or TIR notebook

  • Download the target model (e.g. meta-llama/Llama-2-7b-chat-hf) from huggingface hub

  • Upload the model code and weights (from the $HOME/.cache/huggingface/hub/<model>/snapshots/<revision> directory) to the EOS bucket using the Minio cp command. You can use the cp command template from the Setup Minio CLI tab.

  • Now repeat Step 3 above, but this time choose the model (e.g. my-model) in the Model Details section

  • The new endpoint will now download the model weights to the /mnt/models directory before starting the API handler. You may also need to change model_server.py to load the weights from /mnt/models instead of the Hugging Face hub.
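A minimal sketch of that change, assuming the model files were uploaded to the top level of the bucket so that config.json ends up directly in /mnt/models (adjust the path if you uploaded a subdirectory). resolve_model_source is a hypothetical helper. Its return value can be used as self.model_id in load(), since from_pretrained accepts either a hub ID or a local directory.

```python
from pathlib import Path

HUB_MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"
LOCAL_MODEL_DIR = "/mnt/models"


def resolve_model_source(local_dir: str = LOCAL_MODEL_DIR, hub_id: str = HUB_MODEL_ID) -> str:
    """Prefer weights that TIR placed in local_dir; otherwise fall back to the hub ID."""
    # A Hugging Face model directory always contains config.json, so use it as the marker.
    if (Path(local_dir) / "config.json").is_file():
        return local_dir
    return hub_id
```

Keeping the hub fallback means the same image still starts (slowly) on an endpoint created without a TIR Model.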