File size: 4,024 Bytes
9b74ec6
 
06f0356
9b74ec6
5464450
a3a9074
9b74ec6
 
 
 
 
 
 
a7d6d41
a3a9074
9b74ec6
 
 
 
 
 
65a2535
 
9b74ec6
a3a9074
74c6866
a3a9074
 
 
 
65a2535
9b74ec6
 
 
 
 
1ba0543
9b74ec6
 
 
 
 
 
 
 
 
 
 
 
 
1ba0543
 
65a2535
1ba0543
 
4113730
18416fb
1ba0543
65a2535
 
1ba0543
18416fb
1ba0543
 
 
65a2535
1ba0543
 
 
65a2535
9b74ec6
 
 
a3a9074
 
a7d6d41
 
a3a9074
5464450
 
 
ebb0019
5464450
 
 
 
 
 
 
 
 
 
 
3e23390
 
 
 
 
 
5464450
3e23390
5464450
3e23390
5464450
 
 
3e23390
5464450
9b74ec6
 
 
a3a9074
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
import numpy as np


app = FastAPI()

# Sentence-embedding model shared by all endpoints (used for query embeddings
# and semantic search over context passages).
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Extractive question-answering pipeline (SQuAD2-style model).
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)

# Summarization pipeline used by the /t5answer endpoint.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

# Define request models
class ModifyQueryRequest(BaseModel):
    """Request body for /modify_query and /modify_query2."""
    # Raw query text to be encoded into an embedding.
    query_string: str

class AnswerQuestionRequest(BaseModel):
    """Request body for /answer_question."""
    # Question to answer.
    question: str
    # Candidate context passages searched semantically; indexed in parallel
    # with `locations`. Presumably a list of strings — TODO confirm with callers.
    context: list
    # Source identifier for each entry of `context` (parallel lists).
    locations: list

class T5QuestionRequest(BaseModel):
    """Request body for /t5answer."""
    # Text to be summarized.
    context: str

class T5Response(BaseModel):
    """Response body for /t5answer."""
    # Summary text produced by the summarization pipeline.
    answer: str

# Define response models (if needed)
class ModifyQueryResponse(BaseModel):
    """Response body for /modify_query and /modify_query2."""
    # Embedding vector for the query, serialized as a plain list of numbers.
    embeddings: list

class AnswerQuestionResponse(BaseModel):
    """Response body for /answer_question."""
    # Extracted answer text, or an apology message when nothing matched.
    answer: str
    # Locations of the context passages that scored above the search threshold.
    locations: list

# Define API endpoints
@app.post("/modify_query", response_model=ModifyQueryResponse)
async def modify_query(request: ModifyQueryRequest):
    """Encode the query string into a binary-precision embedding vector.

    Failures of any kind are surfaced as HTTP 500 with the error text.
    """
    try:
        encoded = model.encode([request.query_string], precision="binary")
        return ModifyQueryResponse(embeddings=encoded[0].tolist())
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

@app.post("/answer_question", response_model=AnswerQuestionResponse)
async def answer_question(request: AnswerQuestionRequest):
    """Answer a question by semantic search over context passages + extractive QA.

    Passages scoring above 0.4 in semantic search are concatenated and fed to
    the QA pipeline; their locations are returned alongside the answer. When
    nothing scores above the threshold, a fixed apology message is returned
    with an empty location list. Errors become HTTP 500 with the error text.
    """
    try:
        corpus_emb = model.encode(request.context, convert_to_tensor=True)
        question_emb = model.encode(request.question, convert_to_tensor=True)
        matches = util.semantic_search(question_emb, corpus_emb)

        found_locations = []
        selected_passages = []
        for match in matches[0]:
            if match['score'] > .4:
                idx = match['corpus_id']
                found_locations.append(request.locations[idx])
                selected_passages.append(request.context[idx] + ' ')

        if not found_locations:
            answer_text = "Sorry, I couldn't find any results for your query. Please try again!"
        else:
            result = nlp({
                'question': request.question,
                'context': ''.join(selected_passages).replace('\n', ' '),
            })
            answer_text = result['answer']
        return AnswerQuestionResponse(answer=answer_text, locations=found_locations)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

@app.post("/t5answer", response_model=T5Response)
async def t5answer(request: T5QuestionRequest):
    """Summarize the supplied context with the BART summarization pipeline.

    Fix: unlike every other endpoint in this module, this one had no error
    handling, so pipeline failures surfaced as unhandled errors. It now
    mirrors the siblings and returns HTTP 500 with the error text.
    """
    try:
        resp = summarizer(request.context, max_length=130, min_length=30, do_sample=False)
        return T5Response(answer=resp[0]["summary_text"])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# Define API endpoints
@app.post("/modify_query2", response_model=ModifyQueryResponse)
async def modify_query2(request: ModifyQueryRequest):
    """Encode the query string into a quantized (min-max scaled uint8) embedding.

    Failures of any kind are surfaced as HTTP 500 with the error text.
    """
    try:
        quantized = optimize_embedding([request.query_string])
        vector = quantized[0].tolist()
        return ModifyQueryResponse(embeddings=vector)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))


def optimize_embedding(texts, precision='uint8'):
    """Encode *texts* and quantize the embeddings to unsigned integers.

    Each embedding row is min-max normalized independently to [0, 1], then
    scaled to the full range of the requested integer precision.

    Args:
        texts: Sequence of strings to encode.
        precision: 'uint8' (default, scale 0..255) or 'uint16' (0..65535).

    Returns:
        numpy.ndarray of shape (len(texts), dim) with dtype `precision`.

    Raises:
        ValueError: If `precision` is not 'uint8' or 'uint16'.
    """
    embeddings = model.encode(texts)
    return _quantize_embeddings(embeddings, precision)


def _quantize_embeddings(embeddings, precision='uint8'):
    """Min-max normalize each row of a 2-D float array and scale to `precision`."""
    # Scale factor per supported dtype; validating up front fails fast,
    # before any normalization work is done.
    scales = {'uint8': 255, 'uint16': 65535}
    if precision not in scales:
        raise ValueError("Unsupported precision. Use 'uint8' or 'uint16'.")

    # Per-row min-max normalization to [0, 1]; the epsilon guards constant
    # rows (zero range) against division by zero.
    row_min = embeddings.min(axis=1, keepdims=True)
    row_max = embeddings.max(axis=1, keepdims=True)
    normalized = (embeddings - row_min) / (row_max - row_min + 1e-8)

    return (normalized * scales[precision]).astype(precision)

if __name__ == "__main__":
    # Development entry point: serve the app on all interfaces, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)