Deploy Inference for Gemma
In this tutorial, we will create a model endpoint for the Gemma model. Currently, we provide four variants: 2B, 2B-it, 7B, and 7B-it. We will deploy the model endpoint for the 2B-it variant; the same steps work for the other variants.
The tutorial will mainly focus on the following:
- Model endpoint creation for Gemma 2B-it using a prebuilt container
- A brief description of the supported parameters
Model Endpoint Creation for Gemma 2B-it using Prebuilt Container
When a model endpoint is created in the TIR dashboard, in the background, a model server is launched to serve the inference requests.
The TIR platform supports a variety of model formats through pre-built containers (e.g., PyTorch, Triton, LLaMA, and MPT).
For the scope of this tutorial, we will use the pre-built container (Gemma 2B-it) for the model endpoint, but you may choose to create your own custom container by following this tutorial.
In most cases, the pre-built container will work for your use case. The advantage is that you won't have to worry about building an API handler; it will be created for you automatically.
Step 1: Create a Model Endpoint
- Go to TIR AI Platform
- Choose a project
- Go to the Model Endpoints section
- Create a new endpoint
- Choose the Gemma 2B-IT model card
- Pick a suitable GPU plan of your choice and set the replicas
- Set the environment variables, as described below
Compulsory Environment Variables
Note: Gemma is a gated model; you will need permission to access it. The model checkpoint will be downloaded from Kaggle. Follow these steps to get access to the model:
- Visit the Kaggle Gemma Model page
- Log in and click on "Request Access"
- Once approved, go to Account Settings > API > Create New API Token
- Copy the API token key

Then set the following variables on the endpoint:
- KAGGLE_KEY: Paste the API token key.
- KAGGLE_USERNAME: Add the Kaggle username you used to request access to Gemma in the previous step.
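Clicking Create New API Token typically downloads a kaggle.json file containing your username and key. If you are unsure which values to paste, a minimal sketch like the one below prints both; the ~/.kaggle/kaggle.json path is only an assumption, so adjust it to wherever your browser saved the file.

```python
import json
from pathlib import Path

# Common location for the Kaggle token; adjust if you saved it elsewhere.
token_path = Path.home() / ".kaggle" / "kaggle.json"

with open(token_path) as f:
    token = json.load(f)

# These two values go into the endpoint's environment variables.
print("KAGGLE_USERNAME =", token["username"])
print("KAGGLE_KEY      =", token["key"])
```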
Advanced Environment Variables
We do not recommend modifying these values for general use cases. The appropriate values largely depend on the GPU you use; they are used to configure the model server built with TensorRT-LLM.
- MAX_BATCH_SIZE: The maximum number of input sequences the model can handle concurrently in a single batch during inference. Smaller values mean fewer input sequences are processed simultaneously, which can reduce memory usage, but it may also lower throughput and parallelism, particularly on hardware optimized for larger batch sizes.
- MAX_INPUT_LEN: The maximum length of an input sequence the model accepts during inference. Smaller values limit the number of tokens (words or subwords) the model can process at once. Limiting the input length helps control memory usage and computational complexity, but it may also lead to information loss if important context is truncated.
- MAX_OUTPUT_LEN: The maximum length of an output sequence generated by the model. Smaller values restrict the length of the generated text, which can prevent overly verbose or irrelevant output; however, values that are too small may keep the model from producing coherent, complete responses, particularly for generation or completion tasks where longer outputs are needed.
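To get a feel for how these values interact with GPU memory, the sketch below estimates the KV-cache footprint for a given configuration. The Gemma 2B architecture numbers used as defaults (18 layers, 1 KV head, head dimension 256) and the FP16 cache assumption are approximations; treat the result as a rough lower bound on memory needed in addition to the model weights.

```python
# Rough KV-cache sizing sketch, assuming an FP16 cache and approximate
# Gemma 2B architecture values; verify against the checkpoint you deploy.
def kv_cache_gib(max_batch_size: int, max_input_len: int, max_output_len: int,
                 num_layers: int = 18, num_kv_heads: int = 1,
                 head_dim: int = 256, bytes_per_elem: int = 2) -> float:
    tokens = max_batch_size * (max_input_len + max_output_len)
    # Factor of 2 accounts for both the key and the value tensors per layer.
    cache_bytes = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem * tokens
    return cache_bytes / 1024**3

# Example: MAX_BATCH_SIZE=8, MAX_INPUT_LEN=2048, MAX_OUTPUT_LEN=1024
print(f"~{kv_cache_gib(8, 2048, 1024):.2f} GiB of KV cache on top of the weights")
```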
- Complete the endpoint creation.
- Model creation might take a few minutes. Check the logs for updates; once the inference server starts, the log will resemble the image provided below.
Step 2: Generate Your API_TOKEN
The model endpoint API requires a valid auth token, which you'll need to perform further steps. So, let's generate one.
- Go to the API Tokens section under the project.
- Create a new API token by clicking the Create Token button in the top-right corner. You can also use an existing token if one has already been created.
- Once created, you'll see the list of API tokens containing the API Key and Auth Token. You will need this Auth Token in the next step.
Step 3: Make an Inference Request
When your endpoint is ready, visit the Sample API Request section to test your endpoint using curl.
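If you prefer Python over curl, a minimal sketch is shown below. The endpoint URL, the /v2/models/ensemble/generate path, the Bearer auth header, and the text_output field name are assumptions; copy the exact URL, headers, and payload from the Sample API Request section of your endpoint.

```python
import requests

# Placeholders; copy the real values from your endpoint's Sample API Request section.
ENDPOINT_URL = "https://<your-endpoint-host>/v2/models/ensemble/generate"
AUTH_TOKEN = "<your-auth-token>"

payload = {
    "text_input": "Write a short poem about the monsoon.",
    "max_tokens": 128,
}

response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {AUTH_TOKEN}"},
    json=payload,
    timeout=120,
)
response.raise_for_status()
# The output field name may differ; inspect response.json() if needed.
print(response.json()["text_output"])
```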
Supported Parameters
Field | Description | Shape | Data Type |
---|---|---|---|
text_input | Input text to be used as a prompt for text generation. | [-1] | TYPE_STRING |
max_tokens | The maximum number of tokens to generate in the output text. | [-1] | TYPE_INT32 |
bad_words | A list of words or phrases that should not appear in the generated text. | [-1] | TYPE_STRING |
stop_words | A list of words or phrases that end generation when the model produces them. | [-1] | TYPE_STRING |
end_id | The token ID marking the end of a sequence. | [1] | TYPE_INT32 |
pad_id | The token ID used for padding sequences. | [1] | TYPE_INT32 |
top_k | The number of highest probability vocabulary tokens to consider for generation. | [1] | TYPE_INT32 |
top_p | Nucleus sampling parameter, limiting the cumulative probability of tokens. | [1] | TYPE_FP32 |
temperature | Controls the randomness of token selection during generation. | [1] | TYPE_FP32 |
length_penalty | Penalty applied to the length of the generated text. | [1] | TYPE_FP32 |
repetition_penalty | Penalty applied to repeated sequences in the generated text. | [1] | TYPE_FP32 |
min_length | The minimum number of tokens in the generated text. | [1] | TYPE_INT32 |
presence_penalty | Penalty applied based on the presence of specific tokens in the generated text. | [1] | TYPE_FP32 |
frequency_penalty | Penalty applied based on the frequency of tokens in the generated text. | [1] | TYPE_FP32 |
random_seed | Seed for controlling the randomness of generation. | [1] | TYPE_UINT64 |
return_log_probs | Whether to return log probabilities for each token. | [1] | TYPE_BOOL |
return_context_logits | Whether to return logits for each token in the context. | [1] | TYPE_BOOL |
return_generation_logits | Whether to return logits for each token in the generated text. | [1] | TYPE_BOOL |
prompt_embedding_table | Table of embeddings for words in the prompt. | [-1, -1] | TYPE_FP16 |
prompt_vocab_size | Size of the vocabulary for prompt embeddings. | [1] | TYPE_INT32 |
embedding_bias_words | Words to bias the word embeddings. | [-1] | TYPE_STRING |
embedding_bias_weights | Weights for the biasing of word embeddings. | [-1] | TYPE_FP32 |
cum_log_probs | Cumulative log probabilities of generated tokens. | [-1] | TYPE_FP32 |
output_log_probs | Log probabilities of each token in the generated text. | [-1, -1] | TYPE_FP32 |
context_logits | Logits for each token in the context. | [-1, -1] | TYPE_FP32 |
generation_logits | Logits for each token in the generated text. | [-1, -1, -1] | TYPE_FP32 |
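Most of these fields are optional and can be added to the request payload alongside text_input and max_tokens. The sketch below shows how a few of the sampling parameters from the table might be combined; the values are illustrative only, not recommended defaults, and the payload shape should be confirmed against your endpoint's Sample API Request section.

```python
# Illustrative payload mixing a few optional sampling parameters from the table.
payload = {
    "text_input": "Summarize the benefits of model endpoints in two sentences.",
    "max_tokens": 100,
    "temperature": 0.7,         # lower = more deterministic sampling
    "top_k": 40,                # consider only the 40 most likely tokens
    "top_p": 0.9,               # nucleus sampling cutoff
    "repetition_penalty": 1.1,  # discourage repeated phrases
    "stop_words": ["\n\n"],     # stop once a blank line is generated
    "random_seed": 42,          # reproducible sampling
}
```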