Tonyivan committed on
Commit
e0452b3
·
verified ·
1 Parent(s): 06f0356

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -4
app.py CHANGED
@@ -1,15 +1,23 @@
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from sentence_transformers import SentenceTransformer, util
4
  from transformers import pipeline
5
 
 
 
 
 
6
  # Initialize FastAPI app
7
  app = FastAPI()
8
 
 
 
9
  # Load models
10
  model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
11
  question_model = "deepset/tinyroberta-squad2"
12
  nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)
 
13
 
14
  # Define request models
15
  class ModifyQueryRequest(BaseModel):
@@ -19,7 +27,7 @@ class AnswerQuestionRequest(BaseModel):
19
  question: str
20
  context: dict
21
 
22
- # Define response models (if needed)
23
  class ModifyQueryResponse(BaseModel):
24
  embeddings: list
25
 
@@ -30,38 +38,50 @@ class AnswerQuestionResponse(BaseModel):
30
  # Define API endpoints
31
@app.post("/modify_query", response_model=ModifyQueryResponse)
async def modify_query(request: ModifyQueryRequest):
    """Encode the request's query string into a binary embedding.

    Returns the single resulting vector as a plain list so it is
    JSON-serializable; any failure is surfaced as an HTTP 500.
    """
    try:
        encoded = model.encode([request.query_string], precision="binary")
        first_vector = encoded[0].tolist()
        return ModifyQueryResponse(embeddings=first_vector)
    except Exception as exc:
        # Propagate the error text to the client as a server error.
        raise HTTPException(status_code=500, detail=str(exc))
38
 
39
@app.post("/answer_question", response_model=AnswerQuestionResponse)
async def answer_question(request: AnswerQuestionRequest):
    """Answer a question from the most semantically relevant context passages.

    Assumes request.context carries parallel lists under the keys
    'context' (text passages) and 'locations' (their identifiers) —
    TODO confirm against the caller.

    Raises:
        HTTPException: 500 with the underlying error text on any failure.
    """
    try:
        res_locs = []
        context_string = ''
        corpus_embeddings = model.encode(request.context['context'], convert_to_tensor=True)
        query_embeddings = model.encode(request.question, convert_to_tensor=True)
        # semantic_search returns one list of hit dicts per query. Only one
        # query was encoded, so the relevant hits are hits[0]; iterating the
        # outer list (as before) indexes a list with ['score'] and raises.
        hits = util.semantic_search(query_embeddings, corpus_embeddings)
        for hit in hits[0]:
            if hit['score'] > .5:
                loc = hit['corpus_id']
                res_locs.append(request.context['locations'][loc])
                context_string += request.context['context'][loc] + ' '
        if len(res_locs) == 0:
            ans = "Sorry, I couldn't find any results for your query."
        else:
            QA_input = {
                'question': request.question,
                # The QA pipeline handles single-line context best.
                'context': context_string.replace('\n', ' ')
            }
            result = nlp(QA_input)
            ans = result['answer']
        return AnswerQuestionResponse(answer=ans, locations=res_locs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
64
 
65
if __name__ == "__main__":
    # Script entry point: serve the API on all interfaces, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
import logging

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline

# Configure root logging once at import time.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# FastAPI application instance.
app = FastAPI()

# Load the embedding model and the extractive-QA pipeline at startup.
logger.info("Loading models...")
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
question_model = "deepset/tinyroberta-squad2"
nlp = pipeline('question-answering', model=question_model, tokenizer=question_model)
logger.info("Models loaded successfully.")
21
 
22
  # Define request models
23
  class ModifyQueryRequest(BaseModel):
 
27
  question: str
28
  context: dict
29
 
30
# Define response models
class ModifyQueryResponse(BaseModel):
    """Response body for /modify_query: the query's binary embedding."""

    embeddings: list  # binary embedding vector as a plain Python list
33
 
 
38
  # Define API endpoints
39
@app.post("/modify_query", response_model=ModifyQueryResponse)
async def modify_query(request: ModifyQueryRequest):
    """Encode the request's query string into a binary embedding.

    Returns the single resulting vector as a plain list so it is
    JSON-serializable in the response model.

    Raises:
        HTTPException: 500 with the underlying error text on any failure.
    """
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info("Received /modify_query request: %s", request.query_string)
    try:
        binary_embeddings = model.encode([request.query_string], precision="binary")
        logger.info("Embeddings generated successfully.")
        return ModifyQueryResponse(embeddings=binary_embeddings[0].tolist())
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Error generating embeddings: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
49
 
50
@app.post("/answer_question", response_model=AnswerQuestionResponse)
async def answer_question(request: AnswerQuestionRequest):
    """Answer a question from the most semantically relevant context passages.

    Assumes request.context carries parallel lists under the keys
    'context' (text passages) and 'locations' (their identifiers) —
    TODO confirm against the caller.

    Raises:
        HTTPException: 500 with the underlying error text on any failure.
    """
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info("Received /answer_question request: %s", request.question)
    try:
        res_locs = []
        context_string = ''

        corpus_embeddings = model.encode(request.context['context'], convert_to_tensor=True)
        query_embeddings = model.encode(request.question, convert_to_tensor=True)
        # semantic_search returns one list of hit dicts per query. Only one
        # query was encoded, so the relevant hits are hits[0]; iterating the
        # outer list (as before) indexes a list with ['score'] and raises.
        hits = util.semantic_search(query_embeddings, corpus_embeddings)

        for hit in hits[0]:
            if hit['score'] > 0.5:
                loc = hit['corpus_id']
                res_locs.append(request.context['locations'][loc])
                context_string += request.context['context'][loc] + ' '

        if len(res_locs) == 0:
            ans = "Sorry, I couldn't find any results for your query."
            logger.info("No relevant context found.")
        else:
            QA_input = {
                'question': request.question,
                # The QA pipeline handles single-line context best.
                'context': context_string.replace('\n', ' ')
            }
            result = nlp(QA_input)
            ans = result['answer']
            logger.info("Answer generated successfully.")

        return AnswerQuestionResponse(answer=ans, locations=res_locs)
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Error answering question: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
83
 
84
if __name__ == "__main__":
    # Script entry point: serve the API on all interfaces, port 8000.
    import uvicorn
    logger.info("Starting FastAPI server...")
    uvicorn.run(app, host="0.0.0.0", port=8000)