"""FastAPI service exposing Llama-3.2-3B-Instruct through a vLLM engine.

Set HF_TOKEN in the environment before running so the model can be fetched.
"""

import logging
import threading
from typing import Any
from typing import Optional

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from vllm import LLM, SamplingParams, RequestOutput

logger = logging.getLogger(__name__)

# Model served by this app. Shared by the engine config and the "/" info
# endpoint so the two cannot drift apart.
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
MODEL_REVISION = "0cb88a4f764b7a12671c53f0838cd831a0843b95"

_BYTES_PER_GIB = 1024 ** 3

app = FastAPI()

# Initialize the vLLM engine once, at import time.
# NOTE(review): at max_model_len=131072 the KV cache for this model may not fit
# on a 16 GB T4 at 85% utilization, and some vLLM versions reject
# max_num_batched_tokens < max_model_len unless chunked prefill is enabled.
# Confirm against the installed vLLM version; lower max_model_len if needed.
engine_llama_3_2: LLM = LLM(
    model=MODEL_NAME,
    revision=MODEL_REVISION,
    max_num_batched_tokens=512,  # Reduced for T4
    max_num_seqs=16,  # Reduced for T4
    gpu_memory_utilization=0.85,  # vLLM's default is 0.9; adjust if needed
    max_model_len=131072,  # Llama-3.2-3B-Instruct context length
    enforce_eager=True,  # Disable CUDA graphs
    dtype='half',  # fp16; T4 GPUs do not support bfloat16
)

# FastAPI runs plain ``def`` endpoints in a worker thread pool, so requests can
# call the engine concurrently. vLLM's offline ``LLM`` class steps its engine
# synchronously inside ``generate`` and is not documented as thread-safe, so
# all access to it is serialized through this lock.
_engine_lock = threading.Lock()


def _format_gib(num_bytes: int) -> str:
    """Format a byte count as ``"<bytes / 1024**3 rounded to 1 dp> GB"``."""
    return f"{round(num_bytes / _BYTES_PER_GIB, 1)} GB"


@app.get("/")
def greet_json():
    """Report CUDA availability, current-device memory usage and the served model.

    ``cuda_info`` is an empty dict when CUDA is unavailable.
    """
    cuda_available = torch.cuda.is_available()
    cuda_info: dict[str, Any] = {}
    if cuda_available:
        device: int = torch.cuda.current_device()
        cuda_info = {
            "device_count": torch.cuda.device_count(),
            "cuda_device": torch.cuda.get_device_name(device),
            "cuda_capability": torch.cuda.get_device_capability(device),
            "allocated": _format_gib(torch.cuda.memory_allocated(device)),
            "cached": _format_gib(torch.cuda.memory_reserved(device)),
        }
    return {
        "message": f"CUDA availability is {cuda_available}",
        "cuda_info": cuda_info,
        "model": [
            {
                "name": MODEL_NAME,
                "revision": MODEL_REVISION,
            }
        ],
    }


class GenerationRequest(BaseModel):
    """Request body for ``POST /generate-llama3-2``."""

    prompt: str
    max_tokens: int = 100
    temperature: float = 0.7
    # Optional logit bias, forwarded unchanged to SamplingParams(logit_bias=...).
    logit_bias: Optional[dict[int, float]] = None


class GenerationResponse(BaseModel):
    """Text-or-error response shape (not currently returned by any endpoint here)."""

    # Both default to None so either field can be omitted when constructing.
    text: Optional[str] = None
    error: Optional[str] = None


@app.post("/generate-llama3-2", response_model=None)
def generate_text(request: GenerationRequest) -> list[RequestOutput] | dict[str, str]:
    """Generate a completion for ``request.prompt`` with the shared vLLM engine.

    ``response_model=None`` stops FastAPI from deriving a response model from
    the return annotation. ``RequestOutput`` is not a Pydantic type, so deriving
    one fails when the route is registered. The body is still produced by
    FastAPI's default JSON encoding of the returned value.

    Returns:
        The list of ``RequestOutput`` objects from ``LLM.generate`` on success,
        or ``{"error": <message>}`` if building the sampling parameters or
        running generation raises.
    """
    try:
        sampling_params: SamplingParams = SamplingParams(
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            logit_bias=request.logit_bias,
        )
        # Generate text; one request at a time on the shared engine.
        with _engine_lock:
            return engine_llama_3_2.generate(
                prompts=request.prompt,
                sampling_params=sampling_params,
            )
    except Exception as e:
        # Top-level boundary: keep the original {"error": ...} body (HTTP 200)
        # for compatibility, but record the traceback instead of discarding it.
        # Consider an HTTP error status once clients can handle it.
        logger.exception("Text generation failed")
        return {"error": str(e)}