import os

os.environ['HF_HOME'] = '/tmp/.cache/huggingface'  # Use /tmp in Spaces
os.makedirs(os.environ['HF_HOME'], exist_ok=True)  # Ensure directory exists
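# NOTE: HF_HOME must be set before importing huggingface_hub / transformers,
# since those libraries resolve the cache location when they are first imported.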

from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from qwen_classifier.predict import predict_single  # single-text prediction helper
from qwen_classifier.evaluate import evaluate_batch  # batch evaluation helper
import torch
from huggingface_hub import login
from qwen_classifier.model import QwenClassifier
from qwen_classifier.config import HF_REPO, SPACE_URL
from pydantic import BaseModel, Field

app = FastAPI(title="Qwen Classifier")

# Landing page listing the available endpoints
@app.get("/", response_class=HTMLResponse)
def home():
    return f"""
    <html>
        <head>
            <title>Qwen Classifier</title>
        </head>
        <body>
            <h1>Qwen Classifier API</h1>
            <p>Available endpoints:</p>
            <ul>
                <li><strong>POST /predict</strong> - Classify text</li>
                <li><strong>POST /evaluate</strong> - Evaluate batch predictions from a zip file of texts</li>
                <li><strong>GET /health</strong> - Check API status</li>
            </ul>
            <p>Try it: <code>curl -X POST {SPACE_URL}/predict -H "Content-Type: application/json" -d '{{"text":"your text"}}'</code></p>
        </body>
    </html>
    """

@app.on_event("startup")
async def load_model():
    # Warm up the GPU if one is available (avoids crashing on CPU-only hardware)
    if torch.cuda.is_available():
        torch.zeros(1).cuda()
    # Read HF_TOKEN from Hugging Face Space secrets
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables")

    # Authenticate
    login(token=hf_token)
    
    # Load model weights (cached under HF_HOME, i.e. /tmp/.cache/huggingface)
    app.state.model = QwenClassifier.from_pretrained(HF_REPO)
    print("Model loaded successfully!")


class PredictionRequest(BaseModel):
    text: str = Field(..., min_length=1)  # ← Enforces that 'text' is a non-empty string

@app.post("/predict")
async def predict(request: PredictionRequest):  # ← Validates input automatically
    return predict_single(request.text, HF_REPO, backend="local")
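
# Example request (illustrative; host and port depend on the deployment):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" -d '{"text":"your text"}'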

@app.post("/evaluate")
async def evaluate(request: PredictionRequest):  # ← Validates input automatically
    return evaluate_batch(request.text, HF_REPO, backend="local")

@app.get("/health")
def health_check():
    return {"status": "healthy", "model": "loaded"}