File size: 3,054 Bytes
9b74ec6
 
06f0356
9b74ec6
a7d6d41
a3a9074
9b74ec6
 
 
 
 
 
 
 
 
a7d6d41
 
 
a3a9074
9b74ec6
 
 
 
 
 
65a2535
 
9b74ec6
a3a9074
74c6866
a3a9074
 
 
 
65a2535
9b74ec6
 
 
 
 
1ba0543
9b74ec6
 
 
 
 
 
 
 
 
 
 
 
 
1ba0543
 
65a2535
1ba0543
 
4113730
18416fb
1ba0543
65a2535
 
1ba0543
18416fb
1ba0543
 
 
65a2535
1ba0543
 
 
65a2535
9b74ec6
 
 
a3a9074
 
a7d6d41
 
a3a9074
9b74ec6
 
 
a3a9074
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
#from transformers import T5Tokenizer, T5ForConditionalGeneration


# Initialize FastAPI app
app = FastAPI()

# Load models at import time so every request reuses the same instances.
# NOTE(review): these downloads/loads block startup — assumed acceptable here.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # embedding model for search
question_model = "deepset/tinyroberta-squad2"  # extractive QA checkpoint id
nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)

# Previous T5-based generation path, kept for reference; /t5answer now summarizes instead.
#t5tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
#t5model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

# Define request models
class ModifyQueryRequest(BaseModel):
    """Request body for POST /modify_query."""
    query_string: str  # free-text query to embed

class AnswerQuestionRequest(BaseModel):
    """Request body for POST /answer_question."""
    question: str  # natural-language question
    context: list  # candidate passages searched for relevance; presumably list[str] — TODO confirm
    locations: list  # per-passage identifiers, indexed in lockstep with `context`

class T5QuestionRequest(BaseModel):
    """Request body for POST /t5answer: the text to summarize."""
    context: str

class T5Response(BaseModel):
    """Response for POST /t5answer."""
    answer: str  # generated summary text

# Define response models (if needed)
class ModifyQueryResponse(BaseModel):
    """Response for POST /modify_query."""
    embeddings: list  # binary-precision embedding of the query as a plain list

class AnswerQuestionResponse(BaseModel):
    """Response for POST /answer_question."""
    answer: str  # extracted answer span, or an apology when nothing matched
    locations: list  # locations of the passages that contributed to the answer

# Define API endpoints
@app.post("/modify_query", response_model=ModifyQueryResponse)
async def modify_query(request: ModifyQueryRequest):
    """Embed the query string and return its binary-precision vector.

    Any failure while encoding is surfaced as an HTTP 500 whose detail
    carries the underlying error text.
    """
    try:
        # encode() takes a batch; we send a single-item batch and unpack its only row.
        [vector] = model.encode([request.query_string], precision="binary")
        return ModifyQueryResponse(embeddings=vector.tolist())
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

@app.post("/answer_question", response_model=AnswerQuestionResponse)
async def answer_question(request: AnswerQuestionRequest):
    """Answer a question against the supplied context passages.

    Semantic search ranks each passage against the question; passages
    scoring above the similarity threshold are concatenated and handed
    to the extractive QA pipeline. Returns the answer plus the
    `locations` entries of every passage that contributed. When no
    passage clears the threshold, a fixed apology message is returned
    with an empty locations list. Any internal failure becomes an
    HTTP 500 carrying the error text.
    """
    # Minimum cosine-similarity score for a passage to count as relevant.
    SCORE_THRESHOLD = 0.4
    try:
        corpus_embeddings = model.encode(request.context, convert_to_tensor=True)
        query_embeddings = model.encode(request.question, convert_to_tensor=True)
        # hits[0] holds the ranked matches for our single query.
        hits = util.semantic_search(query_embeddings, corpus_embeddings)

        res_locs = []
        relevant_passages = []
        for hit in hits[0]:
            if hit['score'] > SCORE_THRESHOLD:
                loc = hit['corpus_id']
                res_locs.append(request.locations[loc])
                relevant_passages.append(request.context[loc])

        if not res_locs:
            ans = "Sorry, I couldn't find any results for your query. Please try again!"
        else:
            # join() instead of += in a loop; trailing space kept for parity
            # with the original concatenation.
            context_string = ' '.join(relevant_passages) + ' '
            QA_input = {
                'question': request.question,
                # Newlines confuse the QA tokenizer's context handling.
                'context': context_string.replace('\n', ' ')
            }
            result = nlp(QA_input)
            ans = result['answer']
        return AnswerQuestionResponse(answer=ans, locations=res_locs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/t5answer", response_model=T5Response)
async def t5answer(request: T5QuestionRequest):
    """Summarize the supplied context with the BART summarization pipeline.

    Length bounds (30–130 tokens) and greedy decoding (do_sample=False)
    keep output deterministic. Failures are wrapped in an HTTP 500, for
    consistency with the other endpoints (previously a model error here
    escaped unhandled instead of producing a structured 500).
    """
    try:
        resp = summarizer(request.context, max_length=130, min_length=30, do_sample=False)
        return T5Response(answer=resp[0]["summary_text"])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    # Dev entry point: run the ASGI app directly under uvicorn.
    import uvicorn

    # 0.0.0.0 so the service is reachable from outside the host/container.
    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)