Drag2121 committed on
Commit 0ba5adf · 1 Parent(s): 15ec193

stream code again

Files changed (1)
  1. app.py +29 -15
app.py CHANGED
@@ -1,10 +1,9 @@
 import os
 from fastapi import FastAPI, HTTPException
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from langchain_ollama import ChatOllama
-from langchain.schema import StrOutputParser
-from langchain.prompts import ChatPromptTemplate
-
+from langchain.schema import HumanMessage
 import logging
 from functools import lru_cache
 
@@ -13,19 +12,12 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 app = FastAPI()
-
 MODEL_NAME = 'phi3:mini'
 
 @lru_cache()
 def get_llm():
     return ChatOllama(model=MODEL_NAME)
 
-@lru_cache()
-def get_chain():
-    llm = get_llm()
-    prompt = ChatPromptTemplate.from_template("Question: {question}\n\nAnswer:")
-    return prompt | llm | StrOutputParser()
-
 class Question(BaseModel):
     text: str
 
@@ -37,20 +29,42 @@ def read_root():
 async def ask_question(question: Question):
     try:
         logger.info(f"Received question: {question.text}")
-        chain = get_chain()
-        response = chain.invoke({"question": question.text})
+        llm = get_llm()
+        messages = [HumanMessage(content=question.text)]
+        response = llm(messages)
         logger.info("Response generated successfully")
-        return {"answer": response}
+        return {"answer": response.content}
     except Exception as e:
         logger.error(f"Error in /ask endpoint: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))
-
+
+@app.post("/ask_stream")
+async def ask_question_stream(question: Question):
+    try:
+        logger.info(f"Received question for streaming: {question.text}")
+        llm = get_llm()
+        messages = [HumanMessage(content=question.text)]
+
+        async def generate():
+            full_response = ""
+            async for chunk in llm.astream(messages):
+                if chunk.content:
+                    full_response += chunk.content
+                    yield chunk.content
+
+            # Log the full response after streaming is complete
+            logger.info(f"Full streamed response: {full_response}")
+
+        return StreamingResponse(generate(), media_type="text/plain")
+    except Exception as e:
+        logger.error(f"Error in /ask_stream endpoint: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
 
 @app.on_event("startup")
 async def startup_event():
     logger.info(f"Starting up with model: {MODEL_NAME}")
     # Warm up the cache
-    get_chain()
+    get_llm()
 
 @app.on_event("shutdown")
 async def shutdown_event():
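
For reference, a minimal client-side sketch of how the new /ask_stream endpoint could be consumed. This is not part of the commit: the base URL and port (uvicorn's default localhost:8000) and the use of the httpx library are assumptions, so adjust them to however the app is actually served.

# Hypothetical client for the /ask_stream endpoint added in this commit.
# Assumes the FastAPI app is reachable at http://localhost:8000 and that
# httpx is installed; neither is guaranteed by the repository itself.
import httpx

def stream_answer(question: str, base_url: str = "http://localhost:8000") -> str:
    collected = []
    # The endpoint streams plain-text chunks as the model generates them.
    with httpx.stream(
        "POST",
        f"{base_url}/ask_stream",
        json={"text": question},
        timeout=None,  # generation can be slow; disable the read timeout
    ) as response:
        response.raise_for_status()
        for chunk in response.iter_text():
            print(chunk, end="", flush=True)
            collected.append(chunk)
    return "".join(collected)

if __name__ == "__main__":
    stream_answer("What does FastAPI's StreamingResponse do?")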