File size: 4,445 Bytes
9d59179 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import logging
import os
import re
import sys
from typing import List

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Put this file's parent directory on sys.path before the project-local
# imports below are attempted.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from recommender import SHLRecommender
from utils.validators import url as is_valid_url
# The FastAPI application object. Interactive documentation is exposed at
# /docs (Swagger UI) and /redoc.
app = FastAPI(
    title="SHL Test Recommender API",
    description="API for recommending SHL tests based on job descriptions or queries",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)
# Add CORS middleware to allow requests from any origin
# NOTE(review): allow_credentials=True together with wildcard origins is
# discouraged by the FastAPI CORS documentation. Confirm whether credentialed
# cross-origin requests are actually needed; if so, list explicit origins.
from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allow all methods
    allow_headers=["*"],  # Allow all headers
)
# Module-level recommender instance, constructed once at import time and
# shared by the request handlers in this file.
recommender = SHLRecommender()
# Define request and response models

# Request body accepted by POST /recommend.
class RecommendRequest(BaseModel):
    # Text passed to the recommender; it is also checked with is_valid_url
    # to decide whether the recommender should treat it as a URL.
    query: str
    # Forwarded to the recommender as max_recommendations; defaults to 10.
    max_recommendations: int = 10
# One recommended assessment in the shape returned to API clients.
class Assessment(BaseModel):
    # Taken from the record's 'Link' field.
    url: str
    # "Yes" or "No", derived from the record's 'Adaptive/IRT' field.
    adaptive_support: str
    # Text produced by recommender.generate_test_description.
    description: str
    # Integer derived from the record's 'Duration' field (60 when nothing
    # can be parsed); presumably minutes — TODO confirm the unit.
    duration: int
    # "Yes" or "No", derived from the record's 'Remote Testing' field.
    remote_support: str
    # Single-element list holding the record's test type, or
    # "General Assessment" when the type is missing or "Unknown".
    test_type: List[str]
# Response body of POST /recommend: the list of formatted assessments.
class RecommendationResponse(BaseModel):
    recommended_assessments: List[Assessment]
# API endpoints

# GET /health: reports whether the shared recommender looks ready to serve.
# The answer is always a normal JSON body; "unhealthy" is returned when the
# recommender is falsy, recommender.df is missing or empty, any of the
# embedding_model / model / tokenizer attributes is missing,
# product_embeddings is missing or has length 0, or any check raises.
@app.get("/health")
async def health_check():
    unhealthy = {"status": "unhealthy"}
    try:
        # Data must be loaded and non-empty.
        if not recommender or not hasattr(recommender, 'df') or recommender.df.empty:
            return unhealthy

        # All model components must be present on the recommender.
        model_attrs = ('embedding_model', 'model', 'tokenizer')
        if not all(hasattr(recommender, attr) for attr in model_attrs):
            return unhealthy

        # Embeddings must exist and contain at least one entry.
        if not hasattr(recommender, 'product_embeddings'):
            return unhealthy
        if len(recommender.product_embeddings) == 0:
            return unhealthy
    except Exception:
        # A failing probe is reported as unhealthy rather than as an error.
        return unhealthy
    return {"status": "healthy"}
# GET /: landing route that returns a static welcome message.
@app.get("/")
async def root():
    welcome = {"message": "Welcome to the SHL Test Recommender API."}
    return welcome
# POST /optimize: calls recommender.optimize_memory() and reports success.
# Any exception from that call is turned into an HTTP 500 whose detail is the
# exception text.
@app.post("/optimize")
async def optimize_memory():
    try:
        recommender.optimize_memory()
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
    else:
        return {"status": "success", "message": "Memory optimized successfully"}
# Main recommend endpoint
# POST /recommend: hands the validated request fields to
# process_recommendation and returns whatever it produces.
@app.post("/recommend", response_model=RecommendationResponse)
async def recommend(request: RecommendRequest):
    response = await process_recommendation(
        query=request.query,
        max_recommendations=request.max_recommendations,
    )
    return response
_logger = logging.getLogger(__name__)

# First run of ASCII digits in a duration value such as "30 minutes".
_FIRST_INT_RE = re.compile(r"[0-9]+")

# Fallbacks used when a record carries no usable duration / test type.
_DEFAULT_DURATION = 60
_DEFAULT_TEST_TYPE = "General Assessment"


def _parse_duration(raw, default=_DEFAULT_DURATION):
    """Return the first whole number found in *raw*, or *default* if none.

    *raw* is stringified first, so numeric values are accepted as well as
    text. Only the first run of digits is used: "15 to 35" yields 15 rather
    than the concatenation 1535.
    """
    match = _FIRST_INT_RE.search(str(raw))
    return int(match.group()) if match else default


def _normalise_test_type(raw):
    """Return *raw*, or the generic label when it is empty or "Unknown"."""
    if raw and raw != "Unknown":
        return raw
    return _DEFAULT_TEST_TYPE


def _format_assessment(rec):
    """Build the Assessment response model for one recommender record."""
    duration = _parse_duration(rec['Duration'])
    test_type = _normalise_test_type(rec['Test Type'])
    description = recommender.generate_test_description(
        test_name=rec['Test Name'],
        test_type=test_type,
    )
    return Assessment(
        url=rec['Link'],
        adaptive_support="Yes" if rec['Adaptive/IRT'] == "Yes" else "No",
        description=description,
        duration=duration,
        remote_support="Yes" if rec['Remote Testing'] == "Yes" else "No",
        test_type=[test_type],
    )


async def process_recommendation(query: str, max_recommendations: int):
    """Run the recommender for *query* and format the result for the API.

    Args:
        query: Text passed to the recommender. It is also checked with
            is_valid_url, and the outcome is forwarded as ``is_url``.
        max_recommendations: Forwarded unchanged to the recommender.

    Returns:
        A RecommendationResponse with one Assessment per recommender record.

    Raises:
        HTTPException: status 500, with the error text as detail, when the
            recommender or the formatting step raises.
    """
    try:
        recommendations = recommender.get_recommendations(
            query,
            is_url=is_valid_url(query),
            max_recommendations=max_recommendations,
        )
        return RecommendationResponse(
            recommended_assessments=[
                _format_assessment(rec) for rec in recommendations
            ]
        )
    except Exception as e:
        _logger.exception("Recommendation request failed")
        # Best-effort cleanup: a failure here must not mask the original
        # error, but it is logged instead of being silently discarded.
        try:
            recommender.optimize_memory()
        except Exception:
            _logger.warning(
                "optimize_memory() failed during error cleanup", exc_info=True
            )
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Check if running on Hugging Face Spaces (detected via SPACE_ID); the
    # Space uses port 7860, local runs use 8000.
    IS_HF_SPACE = os.environ.get('SPACE_ID') is not None
    port = 7860 if IS_HF_SPACE else 8000
    print(f"Starting FastAPI server on port {port}")
    # Auto-reload is a development aid (file watcher plus process restarts),
    # so it is enabled for local runs only and not in the hosted Space.
    # NOTE(review): the "app:app" import string assumes this module is
    # importable as `app` — confirm against the actual file name.
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=not IS_HF_SPACE)
|