# FastAPI app serving the Qwen classifier on a Hugging Face Space.
# (Removed non-Python residue pasted from the Spaces file viewer:
# status banner, commit-hash list, and column ruler.)
import os

# Point the Hugging Face cache at /tmp: Spaces containers only guarantee
# write access there. This MUST happen before importing huggingface_hub
# (and anything that imports it), which reads HF_HOME at import time.
os.environ['HF_HOME'] = '/tmp/.cache/huggingface' # Use /tmp in Spaces
os.makedirs(os.environ['HF_HOME'], exist_ok=True) # Ensure directory exists

from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from qwen_classifier.predict import predict_single # Single-text prediction helper
from qwen_classifier.evaluate import evaluate_batch # Batch evaluation helper
import torch
from huggingface_hub import login
from qwen_classifier.model import QwenClassifier
from qwen_classifier.config import HF_REPO, SPACE_URL
from pydantic import BaseModel

# Application instance; endpoints are registered on it below.
app = FastAPI(title="Qwen Classifier")
@app.get("/", response_class=HTMLResponse)
def home():
    """Landing page: a small HTML document listing the available endpoints.

    Returns:
        str: HTML rendered by FastAPI as ``text/html`` via HTMLResponse.
    """
    # NOTE: literal JSON braces must be doubled ({{ }}) inside an f-string.
    # The original single-braced {"text":"your text"} was parsed as a format
    # expression ("text" with format spec "your text") and raised ValueError
    # on every request to this endpoint.
    return f"""
<html>
<head>
<title>Qwen Classifier</title>
</head>
<body>
<h1>Qwen Classifier API</h1>
<p>Available endpoints:</p>
<ul>
<li><strong>POST /predict</strong> - Classify text</li>
<li><strong>POST /evaluate</strong> - Evaluate batch text prediction from zip file</li>
<li><strong>GET /health</strong> - Check API status</li>
</ul>
<p>Try it: <code>curl -X POST {SPACE_URL}/predict -H "Content-Type: application/json" -d '{{"text":"your text"}}'</code></p>
</body>
</html>
"""
@app.on_event("startup")
async def load_model():
    """Authenticate with the Hugging Face Hub and load the classifier once.

    Runs at application startup; the loaded model is stored on
    ``app.state.model`` for use by the request handlers.

    Raises:
        ValueError: if the HF_TOKEN secret is not set in the environment.
    """
    # Warm up CUDA only when a GPU is actually present: calling .cuda()
    # unconditionally crashes on CPU-only Spaces hardware.
    if torch.cuda.is_available():
        torch.zeros(1).cuda()

    # HF_TOKEN is injected via the Space's secrets configuration.
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables")

    # Authenticate so private repos under HF_REPO are reachable.
    login(token=hf_token)

    # Weights are cached under HF_HOME (set to /tmp/.cache/huggingface at
    # module import — the original comment claiming /home/user was wrong).
    app.state.model = QwenClassifier.from_pretrained(
        HF_REPO,
    )
    print("Model loaded successfully!")
class PredictionRequest(BaseModel):
    """Request body for /predict and /evaluate."""
    text: str # Required string field; NOTE: plain `str` does NOT reject "" — an empty string passes validation
@app.post("/predict")
async def predict(request: PredictionRequest):
    """Classify a single text with the locally loaded model."""
    # Pydantic has already validated the payload shape by the time we get here.
    result = predict_single(request.text, HF_REPO, backend="local")
    return result
@app.post("/evaluate")
async def evaluate(request: PredictionRequest):
    """Run batch evaluation; `text` is forwarded verbatim to evaluate_batch.

    NOTE(review): the landing page says this evaluates "from zip file" —
    presumably `text` carries a path/URL to that archive; confirm against
    evaluate_batch's contract.
    """
    result = evaluate_batch(request.text, HF_REPO, backend="local")
    return result
@app.get("/health")
def health_check():
    """Liveness probe.

    Returns:
        dict: static status payload. (Removed a stray trailing ``|``
        scrape artifact that made the original line a syntax error.)

    NOTE(review): "model": "loaded" is hard-coded; it does not actually
    check ``app.state.model`` — consider verifying it on a follow-up.
    """
    return {"status": "healthy", "model": "loaded"}