File size: 4,445 Bytes
9d59179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import re
import sys
from typing import List

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from recommender import SHLRecommender
from utils.validators import url as is_valid_url

# Application object; interactive docs are served at /docs and /redoc.
app = FastAPI(
    title="SHL Test Recommender API",
    description="API for recommending SHL tests based on job descriptions or queries",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# CORS: accept cross-origin requests from any origin, with any method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is the most
# permissive configuration available; confirm browsers really need to send
# credentials to this API before keeping it.
from fastapi.middleware.cors import CORSMiddleware

_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# One shared recommender, constructed when this module is imported and used by
# the endpoints below.
recommender = SHLRecommender()

# Define request and response models
# Request body for POST /recommend.
# (Kept free of a docstring on purpose: pydantic would publish it in the OpenAPI schema.)
class RecommendRequest(BaseModel):
    # Free-text query or job description; a URL is also accepted (checked with is_valid_url).
    query: str
    # Upper bound handed to the recommender; defaults to 10.
    # NOTE(review): no lower/upper bound is enforced here — confirm the recommender
    # copes with 0, negative, or very large values.
    max_recommendations: int = 10

# One recommended assessment as returned by POST /recommend.
# (No docstring on purpose: pydantic would publish it in the OpenAPI schema.)
class Assessment(BaseModel):
    # Link to the assessment, copied from the recommender's 'Link' value.
    url: str
    # "Yes"/"No" as filled in by process_recommendation (not enforced by the model).
    adaptive_support: str
    # Generated description text for the assessment.
    description: str
    # Whole-number duration; the unit is not established in this file — TODO confirm.
    duration: int
    # "Yes"/"No" as filled in by process_recommendation (not enforced by the model).
    remote_support: str
    # Test-type labels for the assessment.
    test_type: List[str]

# Response body for POST /recommend: the formatted assessments, in recommender order.
class RecommendationResponse(BaseModel):
    recommended_assessments: List[Assessment]

# API endpoints
@app.get("/health")
async def health_check():
    """Report whether the shared recommender looks fully initialised.

    The result is ``{"status": "healthy"}`` only when all of these hold:
    the recommender is truthy, it has a ``df`` attribute that is not empty,
    it has ``embedding_model``, ``model`` and ``tokenizer`` attributes, and it
    has a non-empty ``product_embeddings``. A failed check, or any exception
    raised while checking, gives ``{"status": "unhealthy"}``.

    NOTE(review): "unhealthy" is reported only in the body (no status code is
    set), so probes that look at the HTTP status will not notice it.
    """
    required_components = ('embedding_model', 'model', 'tokenizer')
    try:
        ready = (
            bool(recommender)
            and hasattr(recommender, 'df')
            and not recommender.df.empty
            and all(hasattr(recommender, name) for name in required_components)
            and hasattr(recommender, 'product_embeddings')
            and len(recommender.product_embeddings) != 0
        )
    except Exception:
        ready = False
    return {"status": "healthy" if ready else "unhealthy"}

@app.get("/")
async def root():
    """Return a fixed welcome message so clients can confirm the API is reachable."""
    greeting = "Welcome to the SHL Test Recommender API."
    return {"message": greeting}

@app.post("/optimize")
async def optimize_memory():
    """Call ``recommender.optimize_memory()`` and report the outcome.

    Returns a success payload when the call completes. Any exception is
    turned into an HTTP 500 whose detail is the exception's text.
    """
    try:
        recommender.optimize_memory()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"status": "success", "message": "Memory optimized successfully"}

# Main recommend endpoint
@app.post("/recommend", response_model=RecommendationResponse)
async def recommend(request: RecommendRequest):
    """POST /recommend: unpack the request body and delegate to ``process_recommendation``."""
    query = request.query
    limit = request.max_recommendations
    return await process_recommendation(query, limit)


# Fallback duration when no number can be read from a recommendation
# (the unit is not established in this file — TODO confirm).
_DEFAULT_DURATION = 60
# Label used when a recommendation has no usable test type.
_GENERIC_TEST_TYPE = "General Assessment"
# First run of decimal digits in a duration value.
_FIRST_NUMBER = re.compile(r"\d+")


def _parse_duration(raw_duration, default=_DEFAULT_DURATION):
    """Return the first whole number found in *raw_duration*, or *default*.

    The value is stringified first, so ints/floats work as well as text.
    The first number is used because joining every digit in the string
    turns a range such as "30-45" into 3045.
    """
    match = _FIRST_NUMBER.search(str(raw_duration))
    return int(match.group()) if match else default


def _normalize_test_type(raw_type):
    """Return *raw_type* if it is a usable label, else ``_GENERIC_TEST_TYPE``.

    Empty strings, the placeholder "Unknown", and non-string values
    (presumably NaN for missing cells — confirm against the data source)
    all map to the generic label. Non-string values would otherwise fail
    ``Assessment.test_type`` validation and turn the request into a 500.
    """
    if isinstance(raw_type, str) and raw_type and raw_type != "Unknown":
        return raw_type
    return _GENERIC_TEST_TYPE


async def process_recommendation(query: str, max_recommendations: int):
    """Run the recommender for *query* and shape the results into the response model.

    Args:
        query: Free-text query/job description, or a URL (detected with
            ``is_valid_url`` and forwarded as ``is_url``).
        max_recommendations: Passed through to ``get_recommendations``.

    Returns:
        A ``RecommendationResponse`` with one ``Assessment`` per recommendation.

    Raises:
        HTTPException: status 500 carrying the error text if anything fails.
            A best-effort ``recommender.optimize_memory()`` is attempted first.
    """
    try:
        recommendations = recommender.get_recommendations(
            query,
            is_url=is_valid_url(query),
            max_recommendations=max_recommendations,
        )

        formatted_assessments = []
        for rec in recommendations:
            duration = _parse_duration(rec['Duration'])
            # Normalise once and reuse for both the description prompt and the response.
            test_type = _normalize_test_type(rec['Test Type'])
            description = recommender.generate_test_description(
                test_name=rec['Test Name'],
                test_type=test_type,
            )
            formatted_assessments.append(
                Assessment(
                    url=rec['Link'],
                    adaptive_support="Yes" if rec['Adaptive/IRT'] == "Yes" else "No",
                    description=description,
                    duration=duration,
                    remote_support="Yes" if rec['Remote Testing'] == "Yes" else "No",
                    test_type=[test_type],
                )
            )

        return RecommendationResponse(recommended_assessments=formatted_assessments)
    except Exception as e:
        # Best-effort cleanup: a failure here must not mask the original error.
        try:
            recommender.optimize_memory()
        except Exception:
            pass
        raise HTTPException(status_code=500, detail=str(e)) from e

if __name__ == "__main__":
    # Hugging Face Spaces sets SPACE_ID; serve on port 7860 there, 8000 locally.
    IS_HF_SPACE = os.environ.get('SPACE_ID') is not None
    port = 7860 if IS_HF_SPACE else 8000

    print(f"Starting FastAPI server on port {port}")
    if IS_HF_SPACE:
        # Serve the already-built ``app`` object. With the "app:app" import
        # string, uvicorn imports this file again under the module name "app"
        # (it is currently running as "__main__"; assumes the file is app.py,
        # as the import string implies). That second import builds a second
        # SHLRecommender, and reload=True adds a supervisor process on top.
        # Neither is wanted outside local development.
        uvicorn.run(app, host="0.0.0.0", port=port)
    else:
        # Local development: uvicorn only supports reload with an import string.
        uvicorn.run("app:app", host="0.0.0.0", port=port, reload=True)