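"""FastAPI app serving DeepSeek-R1-Distill-Qwen-1.5B (GGUF, via llama-cpp-python)
through an OpenAI-compatible /v1/chat/completions endpoint with optional SSE streaming."""
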
import os
import requests
import time
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse
from llama_cpp import Llama
from pydantic import BaseModel
import uvicorn
from typing import Generator
import threading

# Configuration
MODEL_URL = "https://huggingface.co/unsloth/DeepSeek-R1-Distill-Qwen-1.5B-GGUF/resolve/main/DeepSeek-R1-Distill-Qwen-1.5B-Q4_K_M.gguf"  # Q4 quantization for faster inference
MODEL_NAME = "DeepSeek-R1-Distill-Qwen-1.5B-Q4_K_M.gguf"
MODEL_DIR = "model"
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)

# Create model directory if it doesn't exist
os.makedirs(MODEL_DIR, exist_ok=True)

# Download the model if it doesn't exist
if not os.path.exists(MODEL_PATH):
    print(f"Downloading model from {MODEL_URL}...")
    response = requests.get(MODEL_URL, stream=True)
    if response.status_code == 200:
        with open(MODEL_PATH, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print("Model downloaded successfully!")
    else:
        raise RuntimeError(f"Failed to download model: HTTP {response.status_code}")
else:
    print("Model already exists. Skipping download.")

# Initialize FastAPI
app = FastAPI(
    title="DeepSeek-R1 OpenAI-Compatible API",
    description="Optimized OpenAI-compatible API with streaming support",
    version="2.0.0"
)

# CORS Configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model loader with optimized settings
print("Loading model with optimized settings...")
try:
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=1024,      # Reduced context window for faster processing
        n_threads=8,     # Increased threads for better CPU utilization
        n_batch=512,     # Larger batch size for improved throughput
        n_gpu_layers=0,
        use_mlock=True,  # Prevent swapping to disk
        verbose=False
    )
    print("Model loaded with optimized settings!")
except Exception as e:
    raise RuntimeError(f"Failed to load model: {str(e)}")

# Streaming generator
def generate_stream(prompt: str, max_tokens: int, temperature: float, top_p: float) -> Generator[str, None, None]:
    start_time = time.time()
    stream = llm.create_completion(
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        stop=["</s>"],
        stream=True
    )
    for chunk in stream:
        # Each chunk carries the newly generated text; it is forwarded as a
        # plain-text SSE payload rather than a full OpenAI chunk object
        delta = chunk['choices'][0]['text']
        yield f"data: {delta}\n\n"
        # Early stopping if taking too long
        if time.time() - start_time > 30:  # 30s timeout
            break
    # OpenAI-style SSE stream terminator
    yield "data: [DONE]\n\n"

# OpenAI-Compatible Request Schema
class ChatCompletionRequest(BaseModel):
    model: str = "DeepSeek-R1-Distill-Qwen-1.5B"
    messages: list[dict]
    max_tokens: int = 256
    temperature: float = 0.7
    top_p: float = 0.9
    stream: bool = False

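# Example request body accepted by /v1/chat/completions (field names mirror the
# OpenAI chat API; the values shown are illustrative only):
#   {"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 64, "stream": false}
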
# Enhanced root endpoint with performance info
@app.get("/", response_class=HTMLResponse)
async def root():
    # SPACE_HOST (set by Hugging Face Spaces) is a bare hostname, so prepend a scheme
    space_host = os.environ.get("SPACE_HOST")
    base_url = f"https://{space_host}" if space_host else "http://localhost:7860"
    return f"""
    <html>
        <head>
            <title>DeepSeek-R1 Optimized API</title>
            <style>
                body {{ font-family: Arial, sans-serif; max-width: 800px; margin: 20px auto; padding: 0 20px; }}
                .warning {{ color: #dc3545; background: #ffeef0; padding: 15px; border-radius: 5px; }}
                .info {{ color: #0c5460; background: #d1ecf1; padding: 15px; border-radius: 5px; }}
                a {{ color: #007bff; text-decoration: none; }}
                code {{ background: #f8f9fa; padding: 2px 4px; border-radius: 4px; }}
            </style>
        </head>
        <body>
            <h1>DeepSeek-R1 Optimized API</h1>
            <div class="warning">
                <h3>⚠️ Important Notice</h3>
                <p>For private use, please duplicate this space:<br>
                1. Click your profile picture in the top-right<br>
                2. Select "Duplicate Space"<br>
                3. Set visibility to Private</p>
            </div>
            <div class="info">
                <h3>⚡ Performance Optimizations</h3>
                <ul>
                    <li>Quantization: Q4_K_M (optimized speed/quality balance)</li>
                    <li>Batch processing: 512 tokens/chunk</li>
                    <li>Streaming support with 30s timeout</li>
                    <li>8 CPU threads utilization</li>
                </ul>
            </div>
            <h2>API Documentation</h2>
            <ul>
                <li><a href="/docs">Interactive Swagger Documentation</a></li>
                <li><a href="/redoc">ReDoc Documentation</a></li>
            </ul>
            <h2>Example Streaming Request</h2>
            <pre>
curl -N -X POST "{base_url}/v1/chat/completions" \\
-H "Content-Type: application/json" \\ | |
-d '{{ | |
"messages": [{{"role": "user", "content": "Explain quantum computing"}}], | |
"stream": true, | |
"max_tokens": 150 | |
}}' | |
</pre> | |
</body> | |
</html> | |
""" | |
# Chat completions endpoint (OpenAI-compatible)
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
    try:
        # Flatten the chat history into a simple "role: content" prompt
        prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in request.messages])
        prompt += "\nassistant:"

        if request.stream:
            return StreamingResponse(
                generate_stream(
                    prompt=prompt,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p
                ),
                media_type="text/event-stream"
            )
        # Non-streaming response
        start_time = time.time()
        response = llm(
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=["</s>"]
        )
        return {
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": request.model,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": response['choices'][0]['text'].strip()
                },
                "finish_reason": response['choices'][0].get('finish_reason', 'stop')
            }],
            # llama-cpp-python reports token counts in the completion response;
            # fall back to character lengths as a rough proxy if the field is missing
            "usage": response.get("usage", {
                "prompt_tokens": len(prompt),
                "completion_tokens": len(response['choices'][0]['text']),
                "total_tokens": len(prompt) + len(response['choices'][0]['text'])
            })
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Health check endpoint
@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model_loaded": True,
        "performance_settings": {
            "n_threads": llm.n_threads,
            "n_ctx": llm.n_ctx(),
            "n_batch": llm.n_batch
        }
    }

if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        timeout_keep_alive=300  # Keep alive for streaming connections
    )
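
# Illustrative client call (a sketch, assuming the server is reachable at
# http://localhost:7860; adjust the URL for a deployed Space):
#
#   import requests
#   r = requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       json={"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 64},
#   )
#   print(r.json()["choices"][0]["message"]["content"])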