Maximofn committed on
Commit df7a4c3 · verified · 1 Parent(s): f46cacd

First commit

Files changed (1)
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from huggingface_hub import InferenceClient
+
+ from langchain_core.messages import HumanMessage, AIMessage
+ from langgraph.checkpoint.memory import MemorySaver
+ from langgraph.graph import START, MessagesState, StateGraph
+
+ import os
+ from dotenv import load_dotenv
+ load_dotenv()
+
+ # HuggingFace token
+ HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
+
+ # Initialize the HuggingFace inference client
+ model = InferenceClient(
+     model="Qwen/Qwen2.5-72B-Instruct",
+     api_key=HUGGINGFACE_TOKEN
+ )
+
+ # Define the function that calls the model
+ def call_model(state: MessagesState):
+     """
+     Call the model with the accumulated conversation messages
+
+     Args:
+         state: MessagesState with the conversation history
+
+     Returns:
+         dict: A state update whose "messages" entry holds the model's reply
+     """
+     # Convert LangChain messages to HuggingFace chat format
+     hf_messages = []
+     for msg in state["messages"]:
+         if isinstance(msg, HumanMessage):
+             hf_messages.append({"role": "user", "content": msg.content})
+         elif isinstance(msg, AIMessage):
+             hf_messages.append({"role": "assistant", "content": msg.content})
+
+     # Call the API
+     response = model.chat_completion(
+         messages=hf_messages,
+         temperature=0.5,
+         max_tokens=64,
+         top_p=0.7
+     )
+
+     # Convert the response back to LangChain format; MessagesState appends it to the history
+     ai_message = AIMessage(content=response.choices[0].message.content)
+     return {"messages": [ai_message]}
+
+ # Define the graph
+ workflow = StateGraph(state_schema=MessagesState)
+
+ # Define the node and the edge in the graph
+ workflow.add_node("model", call_model)
+ workflow.add_edge(START, "model")
+
+ # Add memory
+ memory = MemorySaver()
+ graph_app = workflow.compile(checkpointer=memory)
+
+ # Define the data model for the request
+ class QueryRequest(BaseModel):
+     query: str
+     thread_id: str = "default"
+
+ # Create the FastAPI application
+ app = FastAPI(title="LangChain FastAPI", description="API to generate text using LangChain and LangGraph")
+
+ # Welcome endpoint
+ @app.get("/")
+ async def api_home():
+     """Welcome endpoint"""
+     return {"detail": "Welcome to FastAPI, Langchain, Docker tutorial"}
+
+ # Generate endpoint
+ @app.post("/generate")
+ async def generate(request: QueryRequest):
+     """
+     Endpoint to generate text using the language model
+
+     Args:
+         request: QueryRequest
+             query: str
+             thread_id: str = "default"
+
+     Returns:
+         dict: A dictionary containing the generated text and the thread ID
+     """
+     try:
+         # Configure the thread ID so LangGraph can restore this conversation's state
+         config = {"configurable": {"thread_id": request.thread_id}}
+
+         # Create the input message
+         input_messages = [HumanMessage(content=request.query)]
+
+         # Invoke the graph
+         output = graph_app.invoke({"messages": input_messages}, config)
+
+         # Get the model response (the last message in the state)
+         response = output["messages"][-1].content
+
+         return {
+             "generated_text": response,
+             "thread_id": request.thread_id
+         }
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
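To try the endpoint added in this commit, here is a minimal client sketch. It assumes the app is running locally on port 7860 (as in the __main__ block above) and that the requests package is installed; the base URL, thread name, and prompts are illustrative. Reusing the same thread_id lets the MemorySaver checkpointer restore the conversation history between calls.

import requests

BASE_URL = "http://localhost:7860"  # assumed local server started via `python app.py`

# First turn: open a conversation on an illustrative thread "demo"
r1 = requests.post(
    f"{BASE_URL}/generate",
    json={"query": "Hi, my name is Ana.", "thread_id": "demo"},
)
print(r1.json()["generated_text"])

# Second turn on the same thread: the MemorySaver checkpointer replays the
# stored history, so the model can answer from the earlier message
r2 = requests.post(
    f"{BASE_URL}/generate",
    json={"query": "What is my name?", "thread_id": "demo"},
)
print(r2.json()["generated_text"])

Because max_tokens is set to 64 in call_model, replies are short; sending a different thread_id starts a fresh, empty history.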