# vllm-inference / main.py
import torch
from typing import Any
from typing import Optional
from fastapi import FastAPI
from pydantic import BaseModel
from vllm import LLM, SamplingParams, RequestOutput
# Don't forget to set HF_TOKEN in the environment before running.
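
# Optional sanity check (a sketch, not part of the original flow): warn early if
# HF_TOKEN is missing, since the gated Llama 3.2 weights cannot be downloaded without it.
import os

if not os.environ.get("HF_TOKEN"):
    print("Warning: HF_TOKEN is not set; downloading meta-llama/Llama-3.2-3B-Instruct may fail.")
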
app = FastAPI()
# Initialize the vLLM engine for meta-llama/Llama-3.2-3B-Instruct,
# pinned to a specific revision for reproducibility.
engine_llama_3_2: LLM = LLM(
model='meta-llama/Llama-3.2-3B-Instruct',
revision="0cb88a4f764b7a12671c53f0838cd831a0843b95",
max_num_batched_tokens=512, # Reduced for T4
max_num_seqs=16, # Reduced for T4
gpu_memory_utilization=0.85, # Slightly increased, adjust if needed
    # Llama-3.2-3B-Instruct supports a context length of up to 131072 tokens,
    # but we cap it at 32k here. At roughly 0.75 words per token, 32k tokens is
    # about 24k words; at ~500 words per page, that is 24k / 0.5k = ~48 pages.
    # Using the full context length would be slower and would not fit in the T4's memory.
max_model_len=32768,
enforce_eager=True, # Disable CUDA graph
dtype='auto', # Use 'half' if you want half precision
)
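
# Optional helper sketch (not wired into the endpoints below, an addition on top of the
# original code): since the context window is capped at 32k tokens above, a prompt can
# be pre-checked against that limit with the engine's own tokenizer.
def prompt_fits_context(prompt: str, limit: int = 32768) -> bool:
    tokenizer = engine_llama_3_2.get_tokenizer()
    return len(tokenizer.encode(prompt)) <= limit
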
@app.get("/")
def greet_json():
cuda_info: dict[str, Any] = {}
if torch.cuda.is_available():
cuda_current_device: int = torch.cuda.current_device()
cuda_info = {
"device_count": torch.cuda.device_count(),
"cuda_device": torch.cuda.get_device_name(cuda_current_device),
"cuda_capability": torch.cuda.get_device_capability(cuda_current_device),
"allocated": f"{round(torch.cuda.memory_allocated(cuda_current_device) / 1024 ** 3, 1)} GB",
"cached": f"{round(torch.cuda.memory_reserved(cuda_current_device) / 1024 ** 3, 1)} GB",
}
return {
"message": f"CUDA availability is {torch.cuda.is_available()}",
"cuda_info": cuda_info,
"model": [
{
"name": "meta-llama/Llama-3.2-3B-Instruct",
"revision": "0cb88a4f764b7a12671c53f0838cd831a0843b95",
}
]
}
class GenerationRequest(BaseModel):
prompt: str
max_tokens: int = 100
temperature: float = 0.7
logit_bias: Optional[dict[int, float]] = None
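
# Illustrative request body for /generate-llama3-2 (values are only examples; JSON keys
# for logit_bias are strings and are coerced to ints by Pydantic):
#   {"prompt": "Write a haiku about GPUs.", "max_tokens": 64,
#    "temperature": 0.7, "logit_bias": {"128000": -100.0}}
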
class GenerationResponse(BaseModel):
    text: Optional[str] = None
    error: Optional[str] = None

@app.post("/generate-llama3-2")
def generate_text(request: GenerationRequest) -> GenerationResponse:
    try:
        sampling_params: SamplingParams = SamplingParams(
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            logit_bias=request.logit_bias,
        )

        # Generate text; vLLM returns one RequestOutput per prompt.
        outputs: list[RequestOutput] = engine_llama_3_2.generate(
            prompts=request.prompt,
            sampling_params=sampling_params,
        )
        return GenerationResponse(text=outputs[0].outputs[0].text)
    except Exception as e:
        return GenerationResponse(error=str(e))
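
# Sketch of a local entry point (an assumption, not part of the original file): port
# 7860 is guessed here because it is the usual Hugging Face Spaces default.
# Example call once the server is running:
#   curl -X POST http://localhost:7860/generate-llama3-2 \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Hello", "max_tokens": 32}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)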