
Deploy Inference for Gemma

In this tutorial, we will create a model endpoint for the Gemma model. We currently provide four variants: 2B, 2B-it, 7B, and 7B-it. This tutorial deploys a model endpoint for the 2B-it variant; the same steps work for the other variants.

The tutorial will mainly focus on the following:

Model Endpoint Creation for Gemma 2B-it using a Prebuilt Container

When a model endpoint is created in the TIR dashboard, in the background, a model server is launched to serve the inference requests.

The TIR platform supports a variety of model formats through pre-built containers (e.g., PyTorch, Triton, LLaMA, MPT).

For the scope of this tutorial, we will use the pre-built container (Gemma 2B-it) for the model endpoint, but you may choose to create your own custom container by following this tutorial.

In most cases, the pre-built container will work for your use case. The advantage is that you won't have to build an API handler; it is created for you automatically.

Step 1: Create a Model Endpoint

  • Go to TIR AI Platform

  • Choose a project

  • Go to the Model Endpoints section

  • Create a new Endpoint

  • Choose the Gemma 2B-IT model card

  • Pick a suitable GPU plan and set the number of replicas

  • Set the environment variables

    Compulsory Environment Variables
    Note: Gemma is a gated model, so you will need permission to access it. The model checkpoint is downloaded from Kaggle. Follow these steps to get access to the model:

    • Visit Kaggle Gemma Model

    • Log in and click "Request Access".

    • Once approved, go to Account Settings > API > Create New API Token.

    • Copy the API token key, then set the following environment variables (a short sketch for locating these values follows this list):

      • KAGGLE_KEY: Paste the API token key.
      • KAGGLE_USERNAME: The Kaggle username you used to request access to Gemma in the previous step.
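
    If you downloaded the token from Kaggle as a kaggle.json file, you can read both values from it. The snippet below is a minimal sketch, assuming Kaggle's default file name and that the file sits in the current directory.

        # Minimal sketch: read the credentials file downloaded from
        # Account Settings > API > Create New API Token. "kaggle.json" is
        # Kaggle's default file name; adjust the path if yours differs.
        import json

        with open("kaggle.json") as f:
            creds = json.load(f)

        # Paste these values into the corresponding TIR environment variables:
        print("KAGGLE_USERNAME:", creds["username"])
        print("KAGGLE_KEY:", creds["key"])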

Advanced Environment Variables

We do not recommend modifying these values for general use cases. Suitable values depend largely on the GPU you use. These parameters configure the model server built on TensorRT-LLM (illustrative values are sketched after the list below).

  • MAX_BATCH_SIZE: Denotes the maximum number of input sequences that the model can handle concurrently in a single batch during inference. Smaller values for this parameter mean that fewer input sequences are processed simultaneously. While smaller batch sizes may reduce memory usage and allow for more efficient processing, they may also result in lower throughput and less parallelism, particularly on hardware optimized for larger batch sizes.

  • MAX_INPUT_LEN: Specifies the maximum length of input sequences that the model can accept during inference. For language models, smaller values for this parameter limit the number of tokens (words or subwords) in the input text that the model can process at once. Limiting the input length can help control memory usage and computation complexity, but it may also lead to information loss if important context is truncated.

  • MAX_OUTPUT_LEN: Represents the maximum length of output sequences generated by the model. Smaller values for this parameter restrict the length of generated text, potentially preventing overly verbose or irrelevant output. However, overly small values may limit the model's ability to produce coherent and meaningful responses, particularly for tasks like text generation or completion where longer outputs may be necessary to convey complete thoughts or responses.
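
A minimal sketch of hypothetical values for these three variables is shown below. The numbers are purely illustrative; the right limits depend on the GPU's memory and your expected prompt and response lengths.

    # Hypothetical values, for illustration only; tune them against your
    # GPU's memory and expected prompt/response lengths.
    advanced_env = {
        "MAX_BATCH_SIZE": "8",     # at most 8 sequences per inference batch
        "MAX_INPUT_LEN": "2048",   # longest prompt, in tokens, the server accepts
        "MAX_OUTPUT_LEN": "1024",  # generation stops after at most 1024 tokens
    }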

  • Complete the endpoint creation.

  • Endpoint creation might take a few minutes. Check the logs for updates; once the inference server has started, the logs will indicate that it is up and serving.

Step 2: Generate Your API Token

Requests to the model endpoint API must carry a valid auth token, so let's generate one.

  • Go to the API Tokens section under the project.

  • Create a new API Token by clicking the Create Token button in the top-right corner. You can also use an existing token if one is already created.

  • Once created, you'll be able to see the list of API Tokens containing the API Key and Auth Token. You will need this Auth Token in the next step.

Step 3: Make an Inference Request

  • When your endpoint is ready, visit the Sample API request section to test your endpoint using curl.
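
As a starting point, here is a minimal Python sketch of an inference call. The endpoint URL, the Bearer-style Authorization header, and the payload fields are assumptions for illustration; copy the exact URL and request format from your endpoint's Sample API request section, and see the Supported Parameters table below for the available fields.

    import requests

    # Placeholders: copy the real values from your TIR dashboard.
    ENDPOINT_URL = "<endpoint-url-from-the-sample-api-request-section>"
    AUTH_TOKEN = "<auth-token-generated-in-step-2>"

    payload = {
        "text_input": "Explain gravity to a five-year-old.",
        "max_tokens": 128,
        "temperature": 0.7,
        "top_k": 50,
        "top_p": 0.9,
    }

    # Assumption: the auth token is sent as a Bearer Authorization header.
    response = requests.post(
        ENDPOINT_URL,
        json=payload,
        headers={"Authorization": f"Bearer {AUTH_TOKEN}"},
    )
    response.raise_for_status()
    print(response.json())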

Supported Parameters

| Field | Description | Shape | Data Type |
| --- | --- | --- | --- |
| text_input | Input text to be used as a prompt for text generation. | [-1] | TYPE_STRING |
| max_tokens | The maximum number of tokens to generate in the output text. | [-1] | TYPE_INT32 |
| bad_words | A list of words or phrases that should not appear in the generated text. | [-1] | TYPE_STRING |
| stop_words | A list of words that are considered stop words and are excluded from the generation. | [-1] | TYPE_STRING |
| end_id | The token ID marking the end of a sequence. | [1] | TYPE_INT32 |
| pad_id | The token ID used for padding sequences. | [1] | TYPE_INT32 |
| top_k | The number of highest-probability vocabulary tokens to consider for generation. | [1] | TYPE_INT32 |
| top_p | Nucleus sampling parameter, limiting the cumulative probability of tokens. | [1] | TYPE_FP32 |
| temperature | Controls the randomness of token selection during generation. | [1] | TYPE_FP32 |
| length_penalty | Penalty applied to the length of the generated text. | [1] | TYPE_FP32 |
| repetition_penalty | Penalty applied to repeated sequences in the generated text. | [1] | TYPE_FP32 |
| min_length | The minimum number of tokens in the generated text. | [1] | TYPE_INT32 |
| presence_penalty | Penalty applied based on the presence of specific tokens in the generated text. | [1] | TYPE_FP32 |
| frequency_penalty | Penalty applied based on the frequency of tokens in the generated text. | [1] | TYPE_FP32 |
| random_seed | Seed for controlling the randomness of generation. | [1] | TYPE_UINT64 |
| return_log_probs | Whether to return log probabilities for each token. | [1] | TYPE_BOOL |
| return_context_logits | Whether to return logits for each token in the context. | [1] | TYPE_BOOL |
| return_generation_logits | Whether to return logits for each token in the generated text. | [1] | TYPE_BOOL |
| prompt_embedding_table | Table of embeddings for words in the prompt. | [-1, -1] | TYPE_FP16 |
| prompt_vocab_size | Size of the vocabulary for prompt embeddings. | [1] | TYPE_INT32 |
| embedding_bias_words | Words to bias the word embeddings. | [-1] | TYPE_STRING |
| embedding_bias_weights | Weights for the biasing of word embeddings. | [-1] | TYPE_FP32 |
| cum_log_probs | Cumulative log probabilities of generated tokens. | [-1] | TYPE_FP32 |
| output_log_probs | Log probabilities of each token in the generated text. | [-1, -1] | TYPE_FP32 |
| context_logits | Logits for each token in the context. | [-1, -1] | TYPE_FP32 |
| generation_logits | Logits for each token in the generated text. | [-1, -1, -1] | TYPE_FP32 |
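
As a rough illustration of how these shapes map onto a JSON request body, parameters listed with shape [-1] (such as bad_words and stop_words) are typically passed as JSON arrays, while shape-[1] parameters take single values. The payload below is a hypothetical sketch only; confirm the exact accepted format against your endpoint's Sample API request section.

    # Hypothetical payload illustrating list-shaped vs. scalar parameters.
    payload = {
        "text_input": "Write a haiku about the monsoon.",  # prompt string
        "max_tokens": 64,
        "bad_words": [],               # shape [-1]: list of strings
        "stop_words": ["\n\n"],        # shape [-1]: list of strings
        "temperature": 0.8,            # shape [1]: single value
        "repetition_penalty": 1.1,     # shape [1]: single value
        "random_seed": 42,             # shape [1]: single value
        "return_log_probs": True,      # shape [1]: single boolean
    }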