from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional, Union, AsyncGenerator
import torch
import logging
from pathlib import Path
from litgpt.api import LLM
import json
import asyncio

# Set up logging
logger = logging.getLogger(__name__)

# Create router instance
router = APIRouter()

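# This module only defines the router and the handler functions; route
# registration and app wiring are assumed to happen elsewhere. A typical way
# to mount the router would look like the following sketch:
#
#   from fastapi import FastAPI
#   app = FastAPI()
#   app.include_router(router)
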
# Global variable to store the LLM instance
llm_instance = None


class InitializeRequest(BaseModel):
    """
    Configuration for model initialization including model path
    """
    mode: str = "cpu"
    precision: Optional[str] = None
    quantize: Optional[str] = None
    gpu_count: Union[str, int] = "auto"
    model_path: str

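# Example initialization request body (illustrative values only; the model
# path follows the "<org>/<model>" layout used by LitGPT downloads, as handled
# in initialize_model below):
#
#   {
#       "mode": "gpu",
#       "precision": "bf16-true",
#       "quantize": null,
#       "gpu_count": "auto",
#       "model_path": "mistralai/Mistral-7B-Instruct-v0.3"
#   }
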
class GenerateRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50
    temperature: float = 1.0
    top_k: Optional[int] = None
    top_p: float = 1.0
    return_as_token_ids: bool = False
    stream: bool = False

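# Example /generate request body (illustrative values):
#
#   {
#       "prompt": "Write a haiku about GPUs.",
#       "max_new_tokens": 64,
#       "temperature": 0.8,
#       "top_k": 50,
#       "top_p": 0.95,
#       "return_as_token_ids": false,
#       "stream": false
#   }
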
# A Pydantic model for the streaming generation request
class StreamGenerateRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50
    temperature: float = 1.0
    top_k: Optional[int] = None
    top_p: float = 1.0

class InitializeCustomRequest(BaseModel):
    """
    Configuration for custom model initialization using from_pretrained
    """
    mode: str = "cpu"
    precision: Optional[str] = None
    quantize: Optional[str] = None
    gpu_count: Union[str, int] = "auto"
    folder_path: str  # Path to the model folder relative to checkpoints
    model_filename: str  # Name of the model file (e.g., "lit_model.pth")
    config_filename: str = "config.json"  # Default config filename
    tokenizer_filename: Optional[str] = "tokenizer.json"  # Optional tokenizer filename

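# Example custom-initialization request body (illustrative; the folder name is
# an assumption about how checkpoints are laid out on disk, the file names use
# the defaults above):
#
#   {
#       "mode": "gpu",
#       "gpu_count": 1,
#       "folder_path": "my-finetuned-model",
#       "model_filename": "lit_model.pth",
#       "config_filename": "config.json",
#       "tokenizer_filename": "tokenizer.json"
#   }
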
async def initialize_custom_model(request: InitializeCustomRequest):
    """
    Initialize a custom model using the from_pretrained method.
    This is for models that are already downloaded and stored in the checkpoints directory.
    """
    global llm_instance

    try:
        # Get the project root directory and construct paths
        project_root = Path(__file__).parent
        checkpoints_dir = project_root / "checkpoints"
        model_dir = checkpoints_dir / request.folder_path

        logger.info(f"Loading custom model from directory: {model_dir}")

        # Verify that all required files exist
        model_path = model_dir / request.model_filename
        config_path = model_dir / request.config_filename

        if not model_path.exists():
            raise HTTPException(
                status_code=400,
                detail=f"Model file not found: {request.model_filename}"
            )
        if not config_path.exists():
            raise HTTPException(
                status_code=400,
                detail=f"Config file not found: {request.config_filename}"
            )

        # Check for tokenizer if specified
        tokenizer_path = None
        if request.tokenizer_filename:
            tokenizer_path = model_dir / request.tokenizer_filename
            if not tokenizer_path.exists():
                raise HTTPException(
                    status_code=400,
                    detail=f"Tokenizer file not found: {request.tokenizer_filename}"
                )

        # Load the model using from_pretrained
        llm_instance = LLM.from_pretrained(
            path=str(model_dir),
            model_file=request.model_filename,
            config_file=request.config_filename,
            tokenizer_file=request.tokenizer_filename if request.tokenizer_filename else None,
            distribute=None if request.precision or request.quantize else "auto"
        )

        # If manual distribution is needed
        if request.precision or request.quantize:
            llm_instance.distribute(
                accelerator="cuda" if request.mode == "gpu" else "cpu",
                devices=request.gpu_count,
                precision=request.precision,
                quantize=request.quantize
            )

        # Log success and memory stats
        logger.info(
            f"Custom model initialized successfully with config:\n"
            f"Mode: {request.mode}\n"
            f"Precision: {request.precision}\n"
            f"Quantize: {request.quantize}\n"
            f"GPU Count: {request.gpu_count}\n"
            f"Model Directory: {model_dir}\n"
            f"Model File: {request.model_filename}\n"
            f"Config File: {request.config_filename}\n"
            f"Tokenizer File: {request.tokenizer_filename}\n"
            f"Current GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, "
            f"{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved"
        )

        return {
            "success": True,
            "message": "Custom model initialized successfully",
            "model_info": {
                "folder": str(model_dir),
                "model_file": request.model_filename,
                "config_file": request.config_filename,
                "tokenizer_file": request.tokenizer_filename
            }
        }
    except HTTPException:
        # Re-raise client errors (missing files) without converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Error initializing custom model: {str(e)}")
        # Print detailed memory statistics on failure
        logger.error(f"GPU Memory Stats:\n"
                     f"Allocated: {torch.cuda.memory_allocated()/1024**3:.2f}GB\n"
                     f"Reserved: {torch.cuda.memory_reserved()/1024**3:.2f}GB\n"
                     f"Max Allocated: {torch.cuda.max_memory_allocated()/1024**3:.2f}GB")
        raise HTTPException(status_code=500, detail=f"Error initializing custom model: {str(e)}")

# Endpoint for streaming generation
async def generate_stream(request: StreamGenerateRequest):
    """
    Generate text using the initialized model with streaming response.
    Returns a StreamingResponse that yields JSON-formatted chunks of text.
    """
    global llm_instance

    if llm_instance is None:
        raise HTTPException(
            status_code=400,
            detail="Model not initialized. Call /initialize first."
        )

    async def event_generator() -> AsyncGenerator[str, None]:
        try:
            # Start the generation with streaming enabled. LLM.generate is a
            # synchronous method that returns a plain generator when
            # stream=True, so it is iterated with a regular for loop.
            for token in llm_instance.generate(
                prompt=request.prompt,
                max_new_tokens=request.max_new_tokens,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
                stream=True  # Enable streaming
            ):
                # Create a JSON response for each token
                chunk = {
                    "token": token,
                    "metadata": {
                        "prompt": request.prompt,
                        "is_finished": False
                    }
                }
                # Format as SSE data
                yield f"data: {json.dumps(chunk)}\n\n"
                # Small delay to prevent overwhelming the client
                await asyncio.sleep(0.01)

            # Send final message indicating completion
            final_chunk = {
                "token": "",
                "metadata": {
                    "prompt": request.prompt,
                    "is_finished": True
                }
            }
            yield f"data: {json.dumps(final_chunk)}\n\n"
        except Exception as e:
            logger.error(f"Error in stream generation: {str(e)}")
            error_chunk = {
                "error": str(e),
                "metadata": {
                    "prompt": request.prompt,
                    "is_finished": True
                }
            }
            yield f"data: {json.dumps(error_chunk)}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
        }
    )

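# Example of the SSE stream emitted above (each event carries one JSON chunk;
# values are illustrative):
#
#   data: {"token": "Hello", "metadata": {"prompt": "...", "is_finished": false}}
#
#   data: {"token": "", "metadata": {"prompt": "...", "is_finished": true}}
#
# A minimal client sketch for consuming the stream with httpx, assuming the
# handler is exposed at "/generate/stream" (the route path and host are
# assumptions, not defined in this module):
#
#   import asyncio, json
#   import httpx
#
#   async def consume(url: str, prompt: str):
#       async with httpx.AsyncClient(timeout=None) as client:
#           async with client.stream("POST", url, json={"prompt": prompt}) as resp:
#               async for line in resp.aiter_lines():
#                   if line.startswith("data: "):
#                       chunk = json.loads(line[len("data: "):])
#                       print(chunk.get("token", ""), end="", flush=True)
#                       if chunk.get("metadata", {}).get("is_finished"):
#                           break
#
#   # asyncio.run(consume("http://localhost:8000/generate/stream", "Hello"))
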
async def root():
    """Root endpoint to verify service is running"""
    return {
        "status": "running",
        "service": "LLM Engine",
        "endpoints": {
            "initialize": "/initialize",
            "generate": "/generate",
            "health": "/health"
        }
    }

async def initialize_model(request: InitializeRequest):
    """
    Initialize the LLM model with the specified configuration.
    """
    global llm_instance

    try:
        # Get the project root directory (where main.py is located)
        project_root = Path(__file__).parent
        checkpoints_dir = project_root / "checkpoints"
        logger.info(f"Checkpoint dir is: {checkpoints_dir}")

        # For LitGPT downloaded models, the path includes the organization
        if "/" in request.model_path:
            # e.g., "mistralai/Mistral-7B-Instruct-v0.3"
            org, model_name = request.model_path.split("/")
            model_path = str(checkpoints_dir / org / model_name)
        else:
            # Fallback for direct model paths
            model_path = str(checkpoints_dir / request.model_path)

        logger.info(f"Using model path: {model_path}")

        # Load the model
        logger.info("Loading model")
        llm_instance = LLM.load(
            model=model_path,
            distribute=None if request.precision or request.quantize else "auto"
        )
        logger.info("Done loading model")

        # If manual distribution is needed
        if request.precision or request.quantize:
            logger.info("Distributing model")
            llm_instance.distribute(
                accelerator="cuda" if request.mode == "gpu" else "cpu",
                devices=request.gpu_count,
                precision=request.precision,
                quantize=request.quantize
            )
            logger.info("Done distributing model")

        logger.info(
            f"Model initialized successfully with config:\n"
            f"Mode: {request.mode}\n"
            f"Precision: {request.precision}\n"
            f"Quantize: {request.quantize}\n"
            f"GPU Count: {request.gpu_count}\n"
            f"Model Path: {model_path}\n"
            f"Current GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, "
            f"{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved"
        )
        return {"success": True, "message": "Model initialized successfully"}
    except Exception as e:
        logger.error(f"Error initializing model: {str(e)}")
        # Print detailed memory statistics on failure
        logger.error(f"GPU Memory Stats:\n"
                     f"Allocated: {torch.cuda.memory_allocated()/1024**3:.2f}GB\n"
                     f"Reserved: {torch.cuda.memory_reserved()/1024**3:.2f}GB\n"
                     f"Max Allocated: {torch.cuda.max_memory_allocated()/1024**3:.2f}GB")
        raise HTTPException(status_code=500, detail=f"Error initializing model: {str(e)}")

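# Example invocation (illustrative; assumes the router is mounted at the root
# of an app served on localhost:8000 and that the handler is registered at the
# "/initialize" path advertised by root() above):
#
#   curl -X POST http://localhost:8000/initialize \
#        -H "Content-Type: application/json" \
#        -d '{"mode": "cpu", "model_path": "mistralai/Mistral-7B-Instruct-v0.3"}'
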
async def generate(request: GenerateRequest):
    """
    Generate text using the initialized model.
    """
    global llm_instance

    if llm_instance is None:
        raise HTTPException(status_code=400, detail="Model not initialized. Call /initialize first.")

    try:
        if request.stream:
            raise HTTPException(
                status_code=400,
                detail="Streaming is not currently supported through this endpoint"
            )

        generated_text = llm_instance.generate(
            prompt=request.prompt,
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            return_as_token_ids=request.return_as_token_ids,
            stream=False  # Force stream to False for now
        )

        response = {
            "generated_text": generated_text if not request.return_as_token_ids else generated_text.tolist(),
            "metadata": {
                "prompt": request.prompt,
                "max_new_tokens": request.max_new_tokens,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p
            }
        }
        return response
    except HTTPException:
        # Re-raise client errors (e.g. the streaming rejection above) without
        # converting them into 500 responses
        raise
    except Exception as e:
        logger.error(f"Error generating text: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")

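# Example /generate response (illustrative; "generated_text" is a plain string
# unless return_as_token_ids was requested, in which case it is a list of ids):
#
#   {
#       "generated_text": "...model output...",
#       "metadata": {
#           "prompt": "Write a haiku about GPUs.",
#           "max_new_tokens": 64,
#           "temperature": 0.8,
#           "top_k": 50,
#           "top_p": 0.95
#       }
#   }
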
async def health_check():
    """
    Check if the service is running and the model is loaded.
    """
    global llm_instance

    status = {
        "status": "healthy",
        "model_loaded": llm_instance is not None,
    }

    if llm_instance is not None:
        logger.info(f"llm_instance is: {llm_instance}")
        status["model_info"] = {
            "model_path": llm_instance.config.name,
            "device": str(next(llm_instance.model.parameters()).device)
        }

    return status

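# Example /health response once a model has been loaded (illustrative values;
# "model_path" reflects the loaded config's name attribute):
#
#   {
#       "status": "healthy",
#       "model_loaded": true,
#       "model_info": {
#           "model_path": "Mistral-7B-Instruct-v0.3",
#           "device": "cuda:0"
#       }
#   }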