from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import InferenceClient

from langchain_core.messages import HumanMessage, AIMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph

import os
from dotenv import load_dotenv
load_dotenv()

# Hugging Face token (read from the environment / .env file)
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
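# Example .env file contents (illustrative placeholder, not a real token):
# HUGGINGFACE_TOKEN=hf_xxxxxxxxxxxxxxxxxxxxxxxx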

# Initialize the HuggingFace model
model = InferenceClient(
    model="Qwen/Qwen2.5-72B-Instruct",
    api_key=HUGGINGFACE_TOKEN
)

# Define the function that calls the model
def call_model(state: MessagesState):
    """
    Call the model with the given messages

    Args:
        state: MessagesState

    Returns:
        dict: A state update whose "messages" key holds the model's reply as an AIMessage
    """
    # Convert LangChain messages to HuggingFace format
    hf_messages = []
    for msg in state["messages"]:
        if isinstance(msg, HumanMessage):
            hf_messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            hf_messages.append({"role": "assistant", "content": msg.content})

    # Call the API
    response = model.chat_completion(
        messages=hf_messages,
        temperature=0.5,
        max_tokens=64,
        top_p=0.7
    )

    # Wrap the response as an AIMessage; MessagesState's add_messages reducer
    # appends it to the existing conversation history.
    ai_message = AIMessage(content=response.choices[0].message.content)
    return {"messages": [ai_message]}

# Define the graph
workflow = StateGraph(state_schema=MessagesState)

# Add the model node and wire it from the START entry point
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")

# Add in-memory checkpointing so each thread_id keeps its own conversation history
memory = MemorySaver()
graph_app = workflow.compile(checkpointer=memory)

# Define the data model for the request
class QueryRequest(BaseModel):
    query: str
    thread_id: str = "default"
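
# Example request body matching QueryRequest (illustrative values):
# {"query": "What is LangGraph?", "thread_id": "user-42"}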

# Create the FastAPI application
app = FastAPI(title="LangChain FastAPI", description="API to generate text using LangChain and LangGraph")

# Welcome endpoint
@app.get("/")
async def api_home():
    """Welcome endpoint"""
    return {"detail": "Welcome to FastAPI, Langchain, Docker tutorial"}

# Generate endpoint
@app.post("/generate")
async def generate(request: QueryRequest):
    """
    Endpoint to generate text using the language model

    Args:
        request: QueryRequest with the user's query and an optional
            thread_id (defaults to "default") identifying the conversation

    Returns:
        dict: A dictionary containing the generated text and the thread ID
    """
    try:
        # Configure the thread ID
        config = {"configurable": {"thread_id": request.thread_id}}

        # Create the input message
        input_messages = [HumanMessage(content=request.query)]

        # Invoke the graph
        output = graph_app.invoke({"messages": input_messages}, config)

        # Get the model response
        response = output["messages"][-1].content

        return {
            "generated_text": response,
            "thread_id": request.thread_id
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
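
# Example usage (a sketch, assuming the server is running locally on port 7860):
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Hi, my name is Ada", "thread_id": "demo"}'
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What is my name?", "thread_id": "demo"}'
#
# Reusing the same thread_id lets the MemorySaver checkpointer replay the earlier
# messages, so the second call can answer from the conversation history.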