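"""LangGraph workflow that turns a single user request into a structured JSON document in
three stages: plan the sections, write each section as partial JSON (sequentially or
concurrently), and consolidate the parts into one final JSON string."""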
import json
import asyncio
from typing import List

from typing_extensions import TypedDict
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langgraph.graph import StateGraph, END

from src.utils.api_key_manager import with_api_manager
from src.helpers.helper import remove_markdown
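
# Shared state flowing through the graph:
#   initial_prompt - the raw user request
#   plan           - newline-separated outline produced by the planning node
#   write_steps    - one parsed JSON object per plan step
#   final_json     - the consolidated JSON document as a string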
# Define the Graph State
class GraphState(TypedDict):
    initial_prompt: str
    plan: str
    write_steps: List[dict]
    final_json: str
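
# Planning node: asks the LLM for a plain-text outline of the JSON, one step per line.
# The with_api_manager decorator (from src.utils.api_key_manager) is assumed to inject a
# configured chat model as the keyword-only `llm` argument with the given sampling settings.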
@with_api_manager(temperature=0.0, top_p=1.0)
def planning_node(state: GraphState, *, llm) -> GraphState:
    print("\n---PLANNING---\n")

    initial_prompt = state['initial_prompt']

    plan_template = \
f"""You need to create a structured JSON based on the following instructions:
{initial_prompt}

Rules:
1. Outline a multi-step plan (one step per line) that will guide the creation of the final JSON.
2. You must create the entire plan yourself without asking others to create it for you.
3. The steps should be as follows:
- Each step should be a high-level task or section of the JSON.
- Check whether breaking a step into smaller, low-level sub-tasks or sections is required.
- If yes, ONLY include the sub-steps (one sub-step per line).
4. The plan should be concise and clear, and each step and sub-step should be distinct.
5. The plan should be plain text with no formatting. DO NOT use bullet points or numbering; put each step on its own line.
6. The number of steps should be as few as possible, but still enough to cover ALL sections.
7. If the user request contains any specific details, include them in the plan.
8. DO NOT create the final content, just the plan/outline.
9. DO NOT include any markdown or formatting in the plan."""

    chat_template = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template("{text}"),
    ])

    prompt = chat_template.invoke({"text": plan_template})
    response = llm.invoke(prompt)
    plan = response.content.strip()

    # Store plan text in state
    state['plan'] = remove_markdown(plan)
    print(plan)

    return state
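
# Sequential writing node: walks the plan line by line and makes one LLM call per step,
# parsing each response into a partial JSON object.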
@with_api_manager(temperature=0.0, top_p=1.0)
def writing_node_sync(state: GraphState, *, llm) -> GraphState:
    print("\n---WRITING THE JSON---\n")

    initial_prompt = state['initial_prompt']
    plan = state['plan'].strip()

    # Split the plan by lines
    plan_lines = plan.split('\n')

    # Our final partial JSON objects
    partial_jsons: List[dict] = []

    # Generate the partial JSON for each plan step
    for idx, step_line in enumerate(plan_lines):
        if len(step_line.strip()) > 0:
            step_prompt_text = \
f"""You are creating part {idx+1} of the final JSON document.

User request:
{initial_prompt}

Plan step (outline):
{step_line.strip()}

Rules:
1. You need to write the JSON data for this step.
2. The JSON should be structured and valid.
3. If the user request contains any specific details, include them in the JSON.
4. If the user request contains the format of the JSON, follow it. If not, create a generic JSON as you see fit.
5. Respond ONLY with valid JSON for this step without any markdown or formatting."""

            chat_template = ChatPromptTemplate.from_messages([
                HumanMessagePromptTemplate.from_template("{text}"),
            ])

            prompt = chat_template.invoke({"text": step_prompt_text})
            response = llm.invoke(prompt)
            step_result = response.content.strip()

            # Attempt to parse the partial JSON
            try:
                cleaned_result = remove_markdown(step_result)
                partial_obj = json.loads(cleaned_result)
            except json.JSONDecodeError as e:
                # If the model didn't produce valid JSON, raise an error
                raise Exception(f"Failed to parse JSON data for step {idx+1}: {e}")

            # print(f"Step {idx+1} JSON:\n{json.dumps(partial_obj, indent=2)}\n")

            # Add the partial JSON to the list
            partial_jsons.append(partial_obj)

    # Save all partial JSON in the state
    state['write_steps'] = partial_jsons

    return state
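
# Concurrent writing node: same prompt and parsing as the sync version, but every plan step
# is dispatched as an asyncio task and awaited together with asyncio.gather, so the per-step
# LLM calls run concurrently.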
@with_api_manager(temperature=0.0, top_p=1.0)
async def writing_node_async(state: GraphState, *, llm) -> GraphState:
    async def get_partial_json(idx: int, step_line: str) -> dict:
        step_prompt_text = \
f"""You are creating part {idx+1} of the final JSON document.

User request:
{initial_prompt}

Plan step (outline):
{step_line.strip()}

Rules:
1. You need to write the JSON data for this step.
2. The JSON should be structured and valid.
3. If the user request contains any specific details, include them in the JSON.
4. If the user request contains the format of the JSON, follow it. If not, create a generic JSON as you see fit.
5. Respond ONLY with valid JSON for this step without any markdown or formatting."""

        chat_template = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template("{text}"),
        ])

        prompt = chat_template.invoke({"text": step_prompt_text})
        response = await llm.ainvoke(prompt)
        step_result = response.content.strip()

        cleaned_result = remove_markdown(step_result)
        try:
            partial_obj = json.loads(cleaned_result)
        except json.JSONDecodeError as e:
            raise Exception(f"Failed to parse JSON data for step {idx+1}: {e}")

        # print(f"Step {idx+1} JSON:\n{json.dumps(partial_obj, indent=2)}\n")

        return partial_obj

    print("\n---WRITING THE JSON---\n")

    initial_prompt = state['initial_prompt']
    plan = state['plan'].strip()
    plan_lines = plan.split('\n')

    partial_jsons: List[dict] = []

    # Build tasks for each step
    tasks = []
    for idx, line in enumerate(plan_lines):
        if len(line.strip()) > 0:
            tasks.append(asyncio.create_task(get_partial_json(idx, line)))

    # Run them concurrently
    partial_jsons = await asyncio.gather(*tasks)

    # Store results
    state['write_steps'] = list(partial_jsons)

    return state
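
# Consolidation node: merges the plan and the per-step JSON objects into a single document
# and serialises it for the final output.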
def consolidation_node(state: GraphState) -> GraphState:
    print("\n---CONSOLIDATING THE JSON---\n")

    plan = state['plan']
    partial_jsons = state['write_steps']

    final_obj = {
        "plan": plan,
        "steps": partial_jsons
    }

    # Convert to string
    final_json_str = json.dumps(final_obj, ensure_ascii=False, indent=2)

    # Store it in the state
    state['final_json'] = final_json_str

    return state
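
# Graph wiring: planning_node -> writing_node -> consolidation_node -> END.
# The sync and async factories differ only in which writing node they register.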
def create_workflow_sync() -> StateGraph:
    workflow = StateGraph(GraphState)

    # Add nodes
    workflow.add_node("planning_node", planning_node)
    workflow.add_node("writing_node", writing_node_sync)
    workflow.add_node("consolidation_node", consolidation_node)

    # Set entry point
    workflow.set_entry_point("planning_node")

    # Add edges
    workflow.add_edge("planning_node", "writing_node")
    workflow.add_edge("writing_node", "consolidation_node")

    # Finally, consolidation_node leads to END
    workflow.add_edge("consolidation_node", END)

    return workflow.compile()
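
# Minimal usage sketch for the sync variant (same GraphState fields as in the demo below):
#   app = create_workflow_sync()
#   final_state = app.invoke({"initial_prompt": "...", "plan": "", "write_steps": [], "final_json": ""})
#   print(final_state['final_json'])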
def create_workflow_async() -> StateGraph:
    workflow = StateGraph(GraphState)

    # Add nodes
    workflow.add_node("planning_node", planning_node)
    workflow.add_node("writing_node", writing_node_async)
    workflow.add_node("consolidation_node", consolidation_node)

    # Set entry point
    workflow.set_entry_point("planning_node")

    # Add edges
    workflow.add_edge("planning_node", "writing_node")
    workflow.add_edge("writing_node", "consolidation_node")

    # Finally, consolidation_node leads to END
    workflow.add_edge("consolidation_node", END)

    return workflow.compile()
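
# Smoke test: runs the async workflow end to end on a sample request and reports timing.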
if __name__ == "__main__":
    import time

    test_instruction = "Write a 1500-word piece on the HBO TV show Westworld, covering major characters, \
themes of AI and consciousness, and how the story might have continued had it not been cancelled. \
Include specific details, quotes, and references to the show and its creators. \
Do not include any spoilers for the climax of the show's final season."

    app = create_workflow_async()

    # We supply an initial state.
    # (We only need 'initial_prompt' here; the other fields will be set by the nodes.)
    state_input: GraphState = {
        "initial_prompt": test_instruction,
        "plan": "",
        "write_steps": [],
        "final_json": ""
    }

    start = time.time()
    final_state = asyncio.run(app.ainvoke(state_input))
    end = time.time()

    # The final JSON is in final_state['final_json']
    print("\n===== FINAL JSON OUTPUT =====\n")
    print(final_state['final_json'])
    print("=============================\n")

    print("\n===== PERFORMANCE =====\n")
    print(f"Time taken: {end-start:.2f} seconds")
    print("======================\n")