Maximofn committed
Commit fe46a24 · 1 Parent(s): 494b36f

refactor(app.py): :sparkles: Improve response handling and streamline configuration for model invocation

Files changed (1):
  1. app.py +23 -11
app.py CHANGED
@@ -78,13 +78,15 @@ def call_model(state: MessagesState, system_prompt: str):
         pad_token_id=tokenizer.eos_token_id
     )
 
-    # Decode and clean the response
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    # Extract only the assistant's response (after the last user message)
-    response = response.split("Assistant:")[-1].strip()
+    # Get just the new tokens (excluding the input prompt tokens)
+    input_length = inputs.shape[1]
+    generated_tokens = outputs[0][input_length:]
+
+    # Decode only the new tokens to get just the assistant's response
+    assistant_response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
 
     # Convert the response to LangChain format
-    ai_message = AIMessage(content=response)
+    ai_message = AIMessage(content=assistant_response)
     return {"messages": state["messages"] + [ai_message]}
 
 # Define the graph
@@ -92,7 +94,6 @@ workflow = StateGraph(state_schema=MessagesState)
 
 # Define the node in the graph
 workflow.add_edge(START, "model")
-workflow.add_node("model", call_model)
 
 # Add memory
 memory = MemorySaver()
@@ -149,10 +150,16 @@ async def generate(request: QueryRequest):
     input_messages = [HumanMessage(content=request.query)]
 
     # Invoke the graph with custom system prompt
+    # Combine config parameters into a single dictionary
+    combined_config = {
+        **config,
+        "model": {"system_prompt": request.system_prompt}
+    }
+
+    # Invoke the graph with proper argument count
     output = graph_app.invoke(
         {"messages": input_messages},
-        config,
-        {"model": {"system_prompt": request.system_prompt}}
+        combined_config
     )
 
     # Get the model response
@@ -189,11 +196,16 @@ async def summarize(request: SummaryRequest):
     # Create the input message
     input_messages = [HumanMessage(content=request.text)]
 
-    # Invoke the graph with summarization system prompt
+    # Combine config parameters into a single dictionary
+    combined_config = {
+        **config,
+        "model": {"system_prompt": summary_system_prompt}
+    }
+
+    # Invoke the graph with proper argument count
     output = graph_app.invoke(
         {"messages": input_messages},
-        config,
-        {"model": {"system_prompt": summary_system_prompt}}
+        combined_config
    )
 
     # Get the model response
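Note on the decoding change in call_model: generate() in Hugging Face transformers returns the prompt tokens concatenated with the newly generated ones, so slicing at the prompt length isolates the assistant's reply directly, instead of decoding the whole sequence and splitting on "Assistant:" (which breaks whenever that marker appears inside the generated text). A minimal standalone sketch of the pattern; the model name and prompt are illustrative, not taken from this repo:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# `inputs` is a (batch, prompt_length) tensor of token ids
inputs = tokenizer("User: Hi, who are you?\nAssistant:", return_tensors="pt").input_ids

# generate() returns the prompt tokens followed by the new tokens in one sequence
outputs = model.generate(inputs, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)

# Slice off the prompt so only the newly generated tokens are decoded
input_length = inputs.shape[1]
generated_tokens = outputs[0][input_length:]
response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
print(response)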
 
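Note on the invoke fix: a compiled LangGraph graph follows the LangChain Runnable interface, whose invoke() takes the input plus a single optional config mapping; any further positional argument raises a TypeError, hence the "proper argument count" comments above. Per-request parameters therefore get merged into that one mapping. A self-contained sketch of the pattern, with a stub node in place of the repo's real call_model and an assumed thread_id value:

from langchain_core.messages import AIMessage, HumanMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph

# Stub node standing in for the repo's call_model (no real LLM needed here)
def call_model(state: MessagesState):
    return {"messages": state["messages"] + [AIMessage(content="ok")]}

workflow = StateGraph(state_schema=MessagesState)
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")
graph_app = workflow.compile(checkpointer=MemorySaver())

# Base config carrying the checkpointer's thread id (value assumed)
config = {"configurable": {"thread_id": "1"}}

# Merge per-request parameters into the one config mapping invoke accepts;
# later keys override earlier ones on collision
combined_config = {
    **config,
    "model": {"system_prompt": "You are a helpful assistant."},  # illustrative
}

# invoke takes (input, config); a third positional argument, as in the
# old code, raises a TypeError
output = graph_app.invoke({"messages": [HumanMessage(content="Hi")]}, combined_config)
print(output["messages"][-1].content)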