"""FastAPI service exposing embedding, extractive-QA and summarization endpoints.

Endpoints:
    POST /modify_query     -> binary-quantized sentence embedding of a query string.
    POST /answer_question  -> pick the passages most similar to a question, then
                              run extractive question answering over them.
    POST /t5answer         -> abstractive summary of a text. Despite the name, this
                              is served by ``facebook/bart-large-cnn``, not T5.
"""
import logging
import threading
from typing import List

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline

logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Load models once at import time; every request reuses these instances.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline("question-answering", model=question_model, tokenizer=question_model)
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

# A passage is used as QA context only if its semantic_search score
# (cosine similarity by default) is strictly above this value.
MIN_PASSAGE_SCORE = 0.4
# Number of candidate passages semantic_search returns (matches its default).
TOP_K_PASSAGES = 10
NO_RESULTS_ANSWER = "Sorry, I couldn't find any results for your query. Please try again!"

# Model calls run in worker threads so they do not block the event loop.
# They are serialized because HF pipelines/tokenizers are not guaranteed to be
# thread-safe. This matches the original one-at-a-time execution.
_inference_lock = threading.Lock()


async def _run_inference(func, *args, **kwargs):
    """Run a blocking model call in a worker thread, one call at a time."""

    def _call():
        with _inference_lock:
            return func(*args, **kwargs)

    return await run_in_threadpool(_call)


# Define request models
class ModifyQueryRequest(BaseModel):
    query_string: str


class AnswerQuestionRequest(BaseModel):
    question: str
    # Candidate passages to search for an answer.
    context: List[str]
    # Per-passage metadata (presumably source locations; verify with callers).
    # locations[i] is returned for each matching context[i].
    locations: list


class T5QuestionRequest(BaseModel):
    # Text to summarize.
    context: str


class T5Response(BaseModel):
    # Generated summary text.
    answer: str


# Define response models
class ModifyQueryResponse(BaseModel):
    embeddings: list


class AnswerQuestionResponse(BaseModel):
    answer: str
    locations: list


def _answer_from_passages(question: str, passages: List[str], locations: list):
    """Retrieve passages relevant to *question* and extract an answer from them.

    Returns:
        ``(answer, matched_locations)``. ``matched_locations`` lists the
        ``locations`` entries of the passages used as context, most similar
        first. If no passage clears ``MIN_PASSAGE_SCORE``, returns
        ``(NO_RESULTS_ANSWER, [])``.
    """
    if not passages:
        return NO_RESULTS_ANSWER, []

    corpus_embeddings = model.encode(passages, convert_to_tensor=True)
    query_embedding = model.encode(question, convert_to_tensor=True)
    # One query -> one result list, ordered by decreasing score.
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=TOP_K_PASSAGES)[0]
    matched_ids = [hit["corpus_id"] for hit in hits if hit["score"] > MIN_PASSAGE_SCORE]

    if not matched_ids:
        return NO_RESULTS_ANSWER, []

    # Concatenate matched passages into a single-line context for the QA model.
    context_string = " ".join(passages[i] for i in matched_ids).replace("\n", " ")
    result = nlp({"question": question, "context": context_string})
    return result["answer"], [locations[i] for i in matched_ids]


def _summarize(text: str) -> str:
    """Return the BART summary of *text* (greedy decoding, 30-130 tokens)."""
    resp = summarizer(text, max_length=130, min_length=30, do_sample=False)
    return resp[0]["summary_text"]


# Define API endpoints
@app.post("/modify_query", response_model=ModifyQueryResponse)
async def modify_query(request: ModifyQueryRequest):
    """Return the binary-quantized embedding of ``request.query_string``.

    Raises:
        HTTPException: 500 if encoding fails.
    """
    try:
        binary_embeddings = await _run_inference(
            model.encode, [request.query_string], precision="binary"
        )
        embeddings = binary_embeddings[0].tolist()
    except Exception as e:
        logger.exception("modify_query failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return ModifyQueryResponse(embeddings=embeddings)


@app.post("/answer_question", response_model=AnswerQuestionResponse)
async def answer_question(request: AnswerQuestionRequest):
    """Answer ``request.question`` using the most relevant ``request.context`` passages.

    Raises:
        HTTPException: 422 if ``locations`` lacks an entry for some passage;
            500 if retrieval or QA inference fails.
    """
    # Every passage must have a matching location, or indexing below would fail.
    if len(request.locations) < len(request.context):
        raise HTTPException(
            status_code=422,
            detail="'locations' must provide an entry for every 'context' passage",
        )
    try:
        ans, res_locs = await _run_inference(
            _answer_from_passages, request.question, request.context, request.locations
        )
    except Exception as e:
        logger.exception("answer_question failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return AnswerQuestionResponse(answer=ans, locations=res_locs)


@app.post("/t5answer", response_model=T5Response)
async def t5answer(request: T5QuestionRequest):
    """Summarize ``request.context`` with the BART summarization pipeline.

    Raises:
        HTTPException: 500 if summarization fails.
    """
    try:
        summary = await _run_inference(_summarize, request.context)
    except Exception as e:
        logger.exception("t5answer failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return T5Response(answer=summary)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)