File size: 2,274 Bytes
55567e0
4e963c4
55567e0
0ba5adf
55567e0
0af890e
 
55567e0
 
fe312ba
 
55567e0
 
 
 
 
f3b5bcb
55567e0
 
 
fe312ba
 
55567e0
 
 
 
 
 
 
 
 
 
 
 
0ba5adf
0af890e
55567e0
0af890e
55567e0
 
 
0ba5adf
 
 
 
 
 
 
 
 
0af890e
fe312ba
 
0ba5adf
 
 
 
 
 
 
 
fe312ba
55567e0
 
 
 
0ba5adf
55567e0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import re
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from langchain_community.llms import Ollama
from langchain_core.messages import HumanMessage
import logging
from functools import lru_cache
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()
# Ollama model tag served by this API. It can be overridden through the
# OLLAMA_MODEL environment variable so a different model can be deployed
# without editing code; when unset, the original default is used.
MODEL_NAME = os.environ.get("OLLAMA_MODEL", "gemma3:1b")

@lru_cache()
def get_llm():
    """Return the shared Ollama LLM client for ``MODEL_NAME``.

    The function takes no arguments, so ``lru_cache`` builds the client on the
    first call and returns that same instance on every later call.

    The client is wired with a ``StreamingStdOutCallbackHandler``, which echoes
    generated tokens to the server process's stdout.
    """
    stdout_handler = StreamingStdOutCallbackHandler()
    return Ollama(
        model=MODEL_NAME,
        callback_manager=CallbackManager(handlers=[stdout_handler]),
    )

class Question(BaseModel):
    """Request body accepted by the /ask and /ask_stream endpoints."""

    text: str  # prompt passed unchanged to the LLM

@app.get("/")
def read_root():
    """Return a small greeting that names the configured model."""
    greeting = f"Welcome to {MODEL_NAME} FastAPI"
    return dict(Hello=greeting)

@app.post("/ask")
async def ask_question(question: Question):
    """Answer a question with a single, non-streamed LLM completion.

    Args:
        question: Request body; ``question.text`` is sent to the model as the prompt.

    Returns:
        ``{"answer": <completion text>}``.

    Raises:
        HTTPException: status 500 carrying the underlying error message when the
            LLM client cannot be obtained or generation fails.
    """
    logger.info(f"Received question: {question.text}")
    try:
        llm = get_llm()
        # Use the async API. A blocking ``invoke`` inside an ``async def``
        # endpoint runs on the event loop thread and stalls every other
        # request for the whole generation.
        response = await llm.ainvoke(question.text)
    except Exception as e:
        # logger.exception records the traceback, which logger.error(str(e)) lost.
        logger.exception(f"Error in /ask endpoint: {str(e)}")
        # NOTE(review): str(e) is returned to the client as-is and may expose
        # internal details; consider a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e
    logger.info("Response generated successfully")
    return {"answer": response}

@app.post("/ask_stream")
async def ask_question_stream(question: Question):
    """Stream the LLM's answer to ``question.text`` back as plain text.

    Args:
        question: Request body; ``question.text`` is sent to the model as the prompt.

    Returns:
        A ``StreamingResponse`` that yields model output chunks as they arrive.

    Raises:
        HTTPException: status 500 if the LLM client cannot be obtained. Errors
            raised while streaming cannot change the status code (the response
            has already started); they are logged and re-raised, which aborts
            the response.
    """
    try:
        logger.info(f"Received question for streaming: {question.text}")
        llm = get_llm()
    except Exception as e:
        logger.exception(f"Error in /ask_stream endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    async def generate():
        # Collect chunks in a list and join once, instead of repeated
        # string concatenation.
        chunks = []
        try:
            async for chunk in llm.astream(question.text):
                chunks.append(chunk)
                yield chunk
        except Exception as e:
            # This runs after the response has started, so the outer handler
            # above can never see these errors. Log here and re-raise so the
            # stream is aborted instead of looking complete.
            logger.exception(f"Error in /ask_stream endpoint while streaming: {str(e)}")
            raise

        # Log the full response after streaming is complete
        full_response = "".join(chunks)
        logger.info(f"Full streamed response: {full_response}")

    return StreamingResponse(generate(), media_type="text/plain")
    
@app.on_event("startup")
async def startup_event():
    """Log the configured model and pre-populate ``get_llm``'s cache."""
    startup_message = f"Starting up with model: {MODEL_NAME}"
    logger.info(startup_message)
    # Building the client now means the first request reuses the cached instance.
    get_llm()

@app.on_event("shutdown")
async def shutdown_event():
    """Record in the log that the application is stopping."""
    shutdown_message = "Shutting down"
    logger.info(shutdown_message)