"""FastAPI service exposing sentence-embedding, extractive QA, and summarization endpoints."""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
#from transformers import T5Tokenizer, T5ForConditionalGeneration
# Initialize FastAPI app
app = FastAPI()
# Load models once at import time so every request reuses the same instances.
# Sentence embedder used both for /modify_query and for ranking passages in /answer_question.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Extractive question-answering model (SQuAD2-style span prediction).
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)
#t5tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
#t5model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
# Abstractive summarizer backing the /t5answer endpoint (despite the name, it uses BART).
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
# Define request models
class ModifyQueryRequest(BaseModel):
    """Request body for POST /modify_query: the raw query text to embed."""
    query_string: str
class AnswerQuestionRequest(BaseModel):
    """Request body for POST /answer_question.

    `context` and `locations` are parallel lists: locations[i] identifies
    the source of passage context[i] (the handler indexes both with the
    same corpus id).
    """
    question: str
    context: list  # candidate text passages to search over; presumably strings — verify against caller
    locations: list  # source identifiers, parallel to `context`
class T5QuestionRequest(BaseModel):
    """Request body for POST /t5answer: the text to summarize."""
    context: str
class T5Response(BaseModel):
    """Response for POST /t5answer: the generated summary text."""
    answer: str
# Define response models (if needed)
class ModifyQueryResponse(BaseModel):
    """Response for POST /modify_query: the binary embedding as a list of ints."""
    embeddings: list
class AnswerQuestionResponse(BaseModel):
    """Response for POST /answer_question.

    `locations` lists the sources of the passages that scored above the
    relevance threshold (empty when nothing matched).
    """
    answer: str
    locations: list
# Define API endpoints
@app.post("/modify_query", response_model=ModifyQueryResponse)
async def modify_query(request: ModifyQueryRequest):
    """Encode the query string into a binary sentence embedding.

    Returns the first (and only) encoded vector as a plain Python list;
    any failure is surfaced as a 500 with the error message as detail.
    """
    try:
        encoded = model.encode([request.query_string], precision="binary")
        vector = encoded[0].tolist()
        return ModifyQueryResponse(embeddings=vector)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
@app.post("/answer_question", response_model=AnswerQuestionResponse)
async def answer_question(request: AnswerQuestionRequest):
    """Answer a question via extractive QA over the most relevant passages.

    Ranks `request.context` passages against the question with semantic
    search, keeps those scoring above a similarity threshold, and runs the
    QA pipeline over their concatenation. Returns a fallback message (with
    empty locations) when no passage is relevant.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    no_results = "Sorry, I couldn't find any results for your query. Please try again!"
    min_score = 0.4  # minimum semantic-search score for a passage to count as relevant
    try:
        # Guard: encoding an empty corpus would error out and surface as an
        # opaque 500; treat it as "nothing found" instead.
        if not request.context:
            return AnswerQuestionResponse(answer=no_results, locations=[])
        corpus_embeddings = model.encode(request.context, convert_to_tensor=True)
        query_embeddings = model.encode(request.question, convert_to_tensor=True)
        hits = util.semantic_search(query_embeddings, corpus_embeddings)
        res_locs = []
        passages = []
        for hit in hits[0]:
            if hit['score'] > min_score:
                idx = hit['corpus_id']
                res_locs.append(request.locations[idx])
                passages.append(request.context[idx])
        if not res_locs:
            ans = no_results
        else:
            # Join with (and keep a trailing) single space, newlines flattened,
            # matching the original concatenation behavior.
            QA_input = {
                'question': request.question,
                'context': (' '.join(passages) + ' ').replace('\n', ' ')
            }
            result = nlp(QA_input)
            ans = result['answer']
        return AnswerQuestionResponse(answer=ans, locations=res_locs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/t5answer", response_model=T5Response)
async def t5answer(request: T5QuestionRequest):
    """Summarize the supplied context text.

    Despite the endpoint name, summarization is performed by the BART
    pipeline loaded at module level (the T5 code path is commented out).

    Raises:
        HTTPException: 500 with the underlying error message on any failure
        (added for consistency with the other endpoints, which wrap errors
        instead of leaking raw exceptions).
    """
    try:
        resp = summarizer(request.context, max_length=130, min_length=30, do_sample=False)
        return T5Response(answer=resp[0]["summary_text"])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Development entry point: serve the app directly on all interfaces.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)