vllm-inference / main.py
yusufs's picture
feat(t4-gpu): add t4 gpu capability
4998ce7
raw
history blame
2.62 kB
import torch
from typing import Any
from typing import Optional
from fastapi import FastAPI
from pydantic import BaseModel
from vllm import LLM, SamplingParams, RequestOutput
# NOTE: HF_TOKEN must be set in the environment when running this server.
app = FastAPI()

# Single vLLM engine shared by every request; it is constructed once, when
# this module is imported. The checkpoint is pinned to an exact revision and
# the scheduler limits are the reduced values chosen for a T4 GPU.
# NOTE(review): confirm that a 131072-token context actually fits on the
# target T4 at this gpu_memory_utilization; lower max_model_len if it does not.
engine_llama_3_2: LLM = LLM(
    # --- which checkpoint to load ---
    model='meta-llama/Llama-3.2-3B-Instruct',
    revision="0cb88a4f764b7a12671c53f0838cd831a0843b95",
    dtype='half',  # half precision
    # --- memory / context budget ---
    gpu_memory_utilization=0.85,  # slightly increased; adjust if needed
    max_model_len=131072,  # Llama-3.2-3B-Instruct context length
    # --- scheduler limits (reduced for T4) ---
    max_num_batched_tokens=512,
    max_num_seqs=16,
    # --- execution mode ---
    enforce_eager=True,  # disable CUDA graph
)
@app.get("/")
def greet_json():
    # Status endpoint. Reports whether CUDA is available, a few facts about
    # the current CUDA device (empty dict when CUDA is unavailable), and the
    # model name/revision this server advertises.
    has_cuda = torch.cuda.is_available()

    def _in_gb(byte_count):
        # Bytes -> "<x.y> GB": divide by 1024 ** 3, round to one decimal.
        return f"{round(byte_count / 1024 ** 3, 1)} GB"

    if has_cuda:
        device: int = torch.cuda.current_device()
        gpu_details: dict[str, Any] = {
            "device_count": torch.cuda.device_count(),
            "cuda_device": torch.cuda.get_device_name(device),
            "cuda_capability": torch.cuda.get_device_capability(device),
            "allocated": _in_gb(torch.cuda.memory_allocated(device)),
            "cached": _in_gb(torch.cuda.memory_reserved(device)),
        }
    else:
        gpu_details = {}

    # Mirrors the model/revision literals used to build the engine.
    served_model = {
        "name": "meta-llama/Llama-3.2-3B-Instruct",
        "revision": "0cb88a4f764b7a12671c53f0838cd831a0843b95",
    }

    return {
        "message": f"CUDA availability is {has_cuda}",
        "cuda_info": gpu_details,
        "model": [served_model],
    }
class GenerationRequest(BaseModel):
    # Request body accepted by the /generate-llama3-2 endpoint. Every field
    # except `prompt` is passed to SamplingParams under the same name.

    # Text handed to the engine as the prompt.
    prompt: str
    # Forwarded as SamplingParams(max_tokens=...).
    max_tokens: int = 100
    # Forwarded as SamplingParams(temperature=...).
    temperature: float = 0.7
    # Forwarded as SamplingParams(logit_bias=...); None when not supplied.
    logit_bias: dict[int, float] | None = None
class GenerationResponse(BaseModel):
    # Response envelope carrying either generated text or an error message.
    # Both fields are nullable and have no default value.
    # NOTE(review): nothing in the visible part of this file references this
    # model; confirm whether it is used elsewhere.

    # Generated text, or None.
    text: str | None
    # Error description, or None.
    error: str | None
@app.post("/generate-llama3-2")
def generate_text(request: GenerationRequest) -> list[RequestOutput] | dict[str, str]:
    # Runs one generation on the shared Llama 3.2 engine.
    #
    # The request's temperature, max_tokens and logit_bias are turned into
    # SamplingParams and the prompt is passed to the engine. On success the
    # engine's result is returned as-is. Any exception raised while building
    # the sampling parameters or while generating is reported to the caller
    # as {"error": "<message>"} rather than propagated.
    #
    # NOTE(review): FastAPI treats the return annotation as the response
    # model; confirm RequestOutput is acceptable there for the pinned
    # FastAPI/vLLM versions.
    try:
        params: SamplingParams = SamplingParams(
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            logit_bias=request.logit_bias,
        )
        outputs = engine_llama_3_2.generate(
            prompts=request.prompt,
            sampling_params=params,
        )
    except Exception as exc:
        # Deliberate catch-all at the endpoint boundary: the failure text is
        # handed back in the response body.
        return {"error": str(exc)}
    else:
        return outputs