import os
import requests
import time
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse
from llama_cpp import Llama
from pydantic import BaseModel
import uvicorn
from typing import Generator

# Configuration
MODEL_URL = "https://huggingface.co/unsloth/DeepSeek-R1-Distill-Qwen-1.5B-GGUF/resolve/main/DeepSeek-R1-Distill-Qwen-1.5B-Q4_K_M.gguf"  # Changed to Q4 for faster inference
MODEL_NAME = "DeepSeek-R1-Distill-Qwen-1.5B-Q4_K_M.gguf"
MODEL_DIR = "model"
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)

# Create model directory if it doesn't exist
os.makedirs(MODEL_DIR, exist_ok=True)

# Download the model if it doesn't exist
if not os.path.exists(MODEL_PATH):
    print(f"Downloading model from {MODEL_URL}...")
    response = requests.get(MODEL_URL, stream=True)
    if response.status_code == 200:
        with open(MODEL_PATH, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print("Model downloaded successfully!")
    else:
        raise RuntimeError(f"Failed to download model: HTTP {response.status_code}")
else:
    print("Model already exists. Skipping download.")

# Initialize FastAPI
app = FastAPI(
    title="DeepSeek-R1 OpenAI-Compatible API",
    description="Optimized OpenAI-compatible API with streaming support",
    version="2.0.0"
)

# CORS Configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
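# Note: the wildcard CORS policy above allows any web origin to call this API; this is
# assumed to be intentional for a public demo Space and should be narrowed for private use.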

# Global model loader with optimized settings
print("Loading model with optimized settings...")
try:
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=1024,  # Reduced context window for faster processing
        n_threads=8,  # Increased threads for better CPU utilization
        n_batch=512,  # Larger batch size for improved throughput
        n_gpu_layers=0,  # CPU-only inference (no layers offloaded to a GPU)
        use_mlock=True,  # Prevent swapping to disk
        verbose=False
    )
    print("Model loaded with optimized settings!")
except Exception as e:
    raise RuntimeError(f"Failed to load model: {str(e)}")

# Streaming generator
def generate_stream(prompt: str, max_tokens: int, temperature: float, top_p: float) -> Generator[str, None, None]:
    start_time = time.time()
    stream = llm.create_completion(
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        stop=["</s>"],
        stream=True
    )
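    # Each event below carries the raw text delta in a bare SSE "data:" line; a fully
    # OpenAI-compatible stream would wrap each delta in a "chat.completion.chunk" JSON object.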
    
    for chunk in stream:
        delta = chunk['choices'][0]['text']
        yield f"data: {delta}\n\n"

        # Stop early if generation exceeds the 30s budget
        if time.time() - start_time > 30:
            break

    # Signal the end of the stream with the OpenAI-style terminator
    yield "data: [DONE]\n\n"

# OpenAI-Compatible Request Schema
class ChatCompletionRequest(BaseModel):
    model: str = "DeepSeek-R1-Distill-Qwen-1.5B"
    messages: list[dict]
    max_tokens: int = 256
    temperature: float = 0.7
    top_p: float = 0.9
    stream: bool = False
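# Example request body matching the schema above (illustrative values only):
#   {
#     "messages": [{"role": "user", "content": "Hello"}],
#     "max_tokens": 64,
#     "temperature": 0.7,
#     "top_p": 0.9,
#     "stream": false
#   }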

# Enhanced root endpoint with performance info
@app.get("/", response_class=HTMLResponse)
async def root():
    # On Hugging Face Spaces, SPACE_HOST is a bare hostname, so prepend the scheme here
    space_host = os.environ.get("SPACE_HOST")
    base_url = f"https://{space_host}" if space_host else "http://localhost:7860"
    return f"""
    <html>
        <head>
            <title>DeepSeek-R1 Optimized API</title>
            <style>
                body {{ font-family: Arial, sans-serif; max-width: 800px; margin: 20px auto; padding: 0 20px; }}
                .warning {{ color: #dc3545; background: #ffeef0; padding: 15px; border-radius: 5px; }}
                .info {{ color: #0c5460; background: #d1ecf1; padding: 15px; border-radius: 5px; }}
                a {{ color: #007bff; text-decoration: none; }}
                code {{ background: #f8f9fa; padding: 2px 4px; border-radius: 4px; }}
            </style>
        </head>
        <body>
            <h1>DeepSeek-R1 Optimized API</h1>
            
            <div class="warning">
                <h3>⚠️ Important Notice</h3>
                <p>For private use, please duplicate this Space:<br>
                1. Open the "..." menu at the top of the Space page<br>
                2. Select "Duplicate this Space"<br>
                3. Set visibility to Private</p>
            </div>

            <div class="info">
                <h3>⚡ Performance Optimizations</h3>
                <ul>
                    <li>Quantization: Q4_K_M (optimized speed/quality balance)</li>
                    <li>Batch size: 512 tokens</li>
                    <li>Streaming support with a 30s generation timeout</li>
                    <li>8 CPU threads for inference</li>
                </ul>
            </div>

            <h2>API Documentation</h2>
            <ul>
                <li><a href="/docs">Interactive Swagger Documentation</a></li>
                <li><a href="/redoc">ReDoc Documentation</a></li>
            </ul>

            <h2>Example Streaming Request</h2>
            <pre>
curl -N -X POST "{base_url}/v1/chat/completions" \\
-H "Content-Type: application/json" \\
-d '{{
  "messages": [{{"role": "user", "content": "Explain quantum computing"}}],
  "stream": true,
  "max_tokens": 150
}}'
            </pre>
        </body>
    </html>
    """

# Async endpoint handler
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
    try:
        prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in request.messages])
        prompt += "\nassistant:"
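        # This flattens the chat history into plain "role: content" lines. The DeepSeek-R1
        # distill checkpoints ship their own chat template, so this simple join should be
        # treated as an approximation of the model's native prompt format.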

        if request.stream:
            return StreamingResponse(
                generate_stream(
                    prompt=prompt,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p
                ),
                media_type="text/event-stream"
            )
            
        # Non-streaming response
        response = llm(
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=["</s>"]
        )

        return {
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": request.model,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": response['choices'][0]['text'].strip()
                },
                # finish_reason comes from llama.cpp ("stop" or "length")
                "finish_reason": response['choices'][0].get('finish_reason', 'stop')
            }],
            # Token counts as reported by llama.cpp
            "usage": response["usage"]
        }
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model_loaded": True,
        "performance_settings": {
            "n_threads": llm.n_threads,
            "n_ctx": llm.n_ctx(),
            "n_batch": llm.n_batch
        }
    }

if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        timeout_keep_alive=300  # Keep alive for streaming connections
    )
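
# Example client call, for reference only (hypothetical usage; assumes the server is
# reachable at http://localhost:7860 as configured above):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       json={"messages": [{"role": "user", "content": "Say hello"}], "max_tokens": 64},
#       timeout=120,
#   )
#   print(resp.json()["choices"][0]["message"]["content"])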