File size: 4,024 Bytes
9b74ec6 06f0356 9b74ec6 5464450 a3a9074 9b74ec6 a7d6d41 a3a9074 9b74ec6 65a2535 9b74ec6 a3a9074 74c6866 a3a9074 65a2535 9b74ec6 1ba0543 9b74ec6 1ba0543 65a2535 1ba0543 4113730 18416fb 1ba0543 65a2535 1ba0543 18416fb 1ba0543 65a2535 1ba0543 65a2535 9b74ec6 a3a9074 a7d6d41 a3a9074 5464450 ebb0019 5464450 3e23390 5464450 3e23390 5464450 3e23390 5464450 3e23390 5464450 9b74ec6 a3a9074 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
import numpy as np
# FastAPI application exposing embedding, semantic-QA and summarization endpoints.
app = FastAPI()
# Bi-encoder used both to embed queries (/modify_query, /modify_query2) and to
# run semantic search over context passages (/answer_question).
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Extractive question-answering pipeline (SQuAD2-style span extraction).
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)
# Abstractive summarizer backing the /t5answer endpoint.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
# Define request models
class ModifyQueryRequest(BaseModel):
    """Request body for /modify_query and /modify_query2: raw query text to embed."""
    query_string: str
class AnswerQuestionRequest(BaseModel):
    """Request body for /answer_question."""
    question: str
    # Candidate passages to search — presumably list[str]; TODO confirm with callers.
    context: list
    # Source identifiers parallel to `context` (indexed by the same corpus_id).
    locations: list
class T5QuestionRequest(BaseModel):
    """Request body for /t5answer: the text to summarize."""
    context: str
class T5Response(BaseModel):
    """Response for /t5answer: the generated summary text."""
    answer: str
# Define response models (if needed)
class ModifyQueryResponse(BaseModel):
    """Response for the embedding endpoints: the quantized embedding as a list."""
    embeddings: list
class AnswerQuestionResponse(BaseModel):
    """Response for /answer_question."""
    answer: str
    # Locations of the passages that matched the query (may be empty).
    locations: list
# Define API endpoints
@app.post("/modify_query", response_model=ModifyQueryResponse)
async def modify_query(request: ModifyQueryRequest):
    """Encode the query string into a binary-precision embedding.

    Returns the first (and only) embedding as a plain list; any failure is
    surfaced as an HTTP 500 with the exception message as detail.
    """
    try:
        encoded = model.encode([request.query_string], precision="binary")
        first_row = encoded[0]
        return ModifyQueryResponse(embeddings=first_row.tolist())
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/answer_question", response_model=AnswerQuestionResponse)
async def answer_question(request: AnswerQuestionRequest):
    """Semantic-search the context passages, then extract an answer span.

    Embeds the question and all passages, keeps passages whose semantic-search
    score exceeds a threshold, and feeds the concatenated matches to the
    extractive QA pipeline. Returns a fallback message (with no locations)
    when nothing matches or when no context was supplied.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    # Minimum semantic-search score for a passage to be considered relevant.
    score_threshold = 0.4
    try:
        res_locs = []
        matched_passages = []
        # Guard: model.encode([]) would fail; an empty context simply means
        # there is nothing to search, so fall through to the fallback answer.
        if request.context:
            corpus_embeddings = model.encode(request.context, convert_to_tensor=True)
            query_embeddings = model.encode(request.question, convert_to_tensor=True)
            hits = util.semantic_search(query_embeddings, corpus_embeddings)
            for hit in hits[0]:
                if hit['score'] > score_threshold:
                    loc = hit['corpus_id']
                    res_locs.append(request.locations[loc])
                    matched_passages.append(request.context[loc])
        if not res_locs:
            ans = "Sorry, I couldn't find any results for your query. Please try again!"
        else:
            # join() instead of repeated += — avoids quadratic string building.
            context_string = ' '.join(matched_passages).replace('\n', ' ')
            QA_input = {
                'question': request.question,
                'context': context_string,
            }
            result = nlp(QA_input)
            ans = result['answer']
        return AnswerQuestionResponse(answer=ans, locations=res_locs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/t5answer", response_model=T5Response)
async def t5answer(request: T5QuestionRequest):
    """Summarize the provided context with the BART summarization pipeline.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
        (Added for consistency with the other endpoints, which all convert
        internal errors into HTTPException instead of leaking raw 500s.)
    """
    try:
        resp = summarizer(request.context, max_length=130, min_length=30, do_sample=False)
        return T5Response(answer=resp[0]["summary_text"])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Define API endpoints
@app.post("/modify_query2", response_model=ModifyQueryResponse)
async def modify_query2(request: ModifyQueryRequest):
    """Encode the query string via optimize_embedding (min-max quantized uint8).

    Any failure is surfaced as an HTTP 500 with the exception message.
    """
    try:
        quantized = optimize_embedding([request.query_string])
        first_row = quantized[0]
        return ModifyQueryResponse(embeddings=first_row.tolist())
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
def optimize_embedding(texts, precision='uint8'):
    """Encode texts and quantize the embeddings to unsigned integers.

    Each embedding row is min-max normalized independently (per-row min/max,
    axis=1) into [0, 1], then scaled to the full range of the requested
    unsigned integer type.

    Args:
        texts: Sequence of strings to embed with the module-level model.
        precision: 'uint8' (scale 0-255) or 'uint16' (scale 0-65535).

    Returns:
        numpy array of quantized embeddings with dtype `precision`.

    Raises:
        ValueError: If precision is not 'uint8' or 'uint16'.
    """
    scales = {'uint8': 255, 'uint16': 65535}
    # Fail fast on a bad precision BEFORE running the expensive encode step.
    if precision not in scales:
        raise ValueError("Unsupported precision. Use 'uint8' or 'uint16'.")
    embeddings = model.encode(texts)
    row_min = embeddings.min(axis=1, keepdims=True)
    row_max = embeddings.max(axis=1, keepdims=True)
    # Epsilon guards against division by zero for a constant (flat) row.
    normalized = (embeddings - row_min) / (row_max - row_min + 1e-8)
    return (normalized * scales[precision]).astype(precision)
if __name__ == "__main__":
    # Run the API directly with uvicorn when executed as a script
    # (in production this module is typically served by an external runner).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|