seekr / src /evaluation /writer /agent_write.py
Hemang Thakur
Deploy project on Hugging Face Spaces
4279593
import json
import asyncio
from typing import List
from typing_extensions import TypedDict
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langgraph.graph import StateGraph, END
from src.utils.api_key_manager import with_api_manager
from src.helpers.helper import remove_markdown
# Define the Graph State
# State dictionary threaded through every node of the LangGraph workflow:
#   initial_prompt - the user's instructions for the JSON document
#   plan           - plan text from planning_node; the writers treat each
#                    non-blank line as one step
#   write_steps    - the parsed JSON fragment produced for each plan step
#   final_json     - the consolidated document serialized as a JSON string
GraphState = TypedDict(
    "GraphState",
    {
        "initial_prompt": str,
        "plan": str,
        "write_steps": List[dict],
        "final_json": str,
    },
)
@with_api_manager(temperature=0.0, top_p=1.0)
def planning_node(state: GraphState, *, llm) -> GraphState:
    """Ask the LLM for a line-per-step plan for the requested JSON document.

    Reads ``state['initial_prompt']``, sends a planning prompt to ``llm`` and
    stores the markdown-stripped reply in ``state['plan']``. The writing
    nodes split the plan on newlines and issue one LLM call per non-blank
    line, so the prompt insists on exactly one step per line.

    Args:
        state: Current graph state; must contain ``initial_prompt``.
        llm: Chat model, keyword-only. The graph calls this node with
            ``state`` alone, so it is supplied by ``with_api_manager``.

    Returns:
        The same ``state`` object, mutated with the new ``plan``.
    """
    print("\n---PLANNING---\n")
    initial_prompt = state['initial_prompt']

    # Rules are numbered sequentially and never forbid newlines: the plan
    # MUST be line-separated, because downstream nodes split it on '\n' and
    # a single-line plan would collapse the whole document into one step.
    plan_template = f"""You need to create a structured JSON based on the following instructions:
{initial_prompt}
Rules:
1. Outline a multi-step plan (exactly one step per line) that will guide the creation of the final JSON.
2. You must create the entire plan yourself without asking others to create it for you.
3. The steps should be as follows:
- Each step should be a high-level task or section of the JSON.
- Check if breaking down each step into smaller, low-level sub-tasks or sections is required.
- If yes, ONLY include the sub-steps (one sub-step per line).
4. The plan should be concise and clear, and each step and sub-step should be distinct.
5. The plan should be unformatted plain text. DO NOT use bullet points or blank lines; separate steps with a single new line only.
6. The number of steps should be as few as possible, but still enough to cover ALL sections.
7. If the user request contains any specific details, include them in the plan.
8. DO NOT create the final content, just the plan/outline.
9. DO NOT include any markdown or formatting in the plan."""

    # The whole prompt is passed as the value of the single "{text}"
    # variable, so it is substituted verbatim rather than parsed as a template.
    chat_template = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template("{text}"),
    ])
    prompt = chat_template.invoke({"text": plan_template})
    response = llm.invoke(prompt)

    plan = remove_markdown(response.content.strip())
    state['plan'] = plan
    # Print exactly what was stored, so the log matches what the writers use.
    print(plan)
    return state
@with_api_manager(temperature=0.0, top_p=1.0)
def writing_node_sync(state: GraphState, *, llm) -> GraphState:
    """Generate one JSON fragment per plan step, one LLM call at a time.

    Every non-blank line of ``state['plan']`` is one step. For each step the
    LLM is asked for the JSON covering that step. The markdown-stripped reply
    is parsed with ``json.loads``, and the results are stored in plan order
    in ``state['write_steps']``.

    Args:
        state: Graph state holding ``initial_prompt`` and ``plan``.
        llm: Chat model, keyword-only. The graph calls this node with
            ``state`` alone, so it is supplied by ``with_api_manager``.

    Returns:
        The same ``state`` object with ``write_steps`` populated.

    Raises:
        Exception: If a step's reply is not valid JSON. The underlying
            ``json.JSONDecodeError`` is chained as ``__cause__``.
    """
    print("\n---WRITING THE JSON---\n")
    initial_prompt = state['initial_prompt']

    # Discard blank lines *before* numbering, so the part/step numbers shown
    # to the model and in error messages are contiguous (1, 2, 3, ...).
    steps = [line.strip() for line in state['plan'].strip().splitlines() if line.strip()]

    # The template is identical for every step; build it once.
    chat_template = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template("{text}"),
    ])

    partial_jsons: List[dict] = []
    for step_number, step in enumerate(steps, start=1):
        step_prompt_text = f"""You are creating part {step_number} of the final JSON document.
User request:
{initial_prompt}
Plan step (outline):
{step}
Rules:
1. You need to write the JSON data for this step.
2. The JSON should be structured and valid.
3. If the user request contains any specific details, include them in the JSON.
4. If the user request contains the format of the JSON, follow it. If not, create a generic JSON as you see fit.
5. Respond ONLY with valid JSON for this step without any markdown or formatting."""

        prompt = chat_template.invoke({"text": step_prompt_text})
        response = llm.invoke(prompt)
        cleaned_result = remove_markdown(response.content.strip())

        # Only the parse can raise JSONDecodeError; keep the try that narrow.
        try:
            partial_obj = json.loads(cleaned_result)
        except json.JSONDecodeError as e:
            raise Exception(f"Failed to parse JSON data for step {step_number}: {e}") from e
        partial_jsons.append(partial_obj)

    state['write_steps'] = partial_jsons
    return state
@with_api_manager(temperature=0.0, top_p=1.0)
async def writing_node_async(state: GraphState, *, llm) -> GraphState:
    """Generate one JSON fragment per plan step, with all steps in flight at once.

    Every non-blank line of ``state['plan']`` is one step. Each step gets
    its own ``llm.ainvoke`` call wrapped in an asyncio task. The parsed
    results are stored in plan order in ``state['write_steps']``.

    Args:
        state: Graph state holding ``initial_prompt`` and ``plan``.
        llm: Chat model, keyword-only. The graph calls this node with
            ``state`` alone, so it is supplied by ``with_api_manager``.

    Returns:
        The same ``state`` object with ``write_steps`` populated.

    Raises:
        Exception: If any step's reply is not valid JSON. The underlying
            ``json.JSONDecodeError`` is chained as ``__cause__``. Before the
            error propagates, the remaining step tasks are cancelled.
    """
    print("\n---WRITING THE JSON---\n")
    initial_prompt = state['initial_prompt']

    # Discard blank lines *before* numbering, so part/step numbers are
    # contiguous.
    steps = [line.strip() for line in state['plan'].strip().splitlines() if line.strip()]

    # The template is identical for every step; build it once and share it.
    chat_template = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template("{text}"),
    ])

    async def get_partial_json(step_number: int, step: str) -> dict:
        """Request the JSON for a single plan step and parse the reply."""
        step_prompt_text = f"""You are creating part {step_number} of the final JSON document.
User request:
{initial_prompt}
Plan step (outline):
{step}
Rules:
1. You need to write the JSON data for this step.
2. The JSON should be structured and valid.
3. If the user request contains any specific details, include them in the JSON.
4. If the user request contains the format of the JSON, follow it. If not, create a generic JSON as you see fit.
5. Respond ONLY with valid JSON for this step without any markdown or formatting."""
        prompt = chat_template.invoke({"text": step_prompt_text})
        response = await llm.ainvoke(prompt)
        cleaned_result = remove_markdown(response.content.strip())
        try:
            return json.loads(cleaned_result)
        except json.JSONDecodeError as e:
            raise Exception(f"Failed to parse JSON data for step {step_number}: {e}") from e

    tasks = [
        asyncio.create_task(get_partial_json(step_number, step))
        for step_number, step in enumerate(steps, start=1)
    ]
    try:
        partial_jsons = await asyncio.gather(*tasks)
    except BaseException:
        # gather() propagates the first failure but leaves the other tasks
        # running. Cancel them so no orphaned LLM calls keep consuming quota.
        for task in tasks:
            task.cancel()
        raise

    state['write_steps'] = list(partial_jsons)
    return state
def consolidation_node(state: GraphState) -> GraphState:
    """Bundle the plan and the per-step fragments into one JSON string.

    The stored document has the shape ``{"plan": <plan>, "steps": <write_steps>}``.
    It is serialized with two-space indentation, and non-ASCII characters are
    kept as-is (``ensure_ascii=False``).

    Args:
        state: Graph state holding ``plan`` and ``write_steps``.

    Returns:
        The same ``state`` object with ``final_json`` set.
    """
    print("\n---CONSOLIDATING THE JSON---\n")
    state['final_json'] = json.dumps(
        {"plan": state['plan'], "steps": state['write_steps']},
        ensure_ascii=False,
        indent=2,
    )
    return state
def create_workflow_sync() -> StateGraph:
    """Build and compile the planning -> writing -> consolidation pipeline.

    The writing stage is ``writing_node_sync``, which handles plan steps one
    after another.

    Returns:
        The object returned by ``compile()``: the compiled, runnable graph,
        not the ``StateGraph`` builder named in the annotation.
    """
    graph = StateGraph(GraphState)

    node_table = (
        ("planning_node", planning_node),
        ("writing_node", writing_node_sync),
        ("consolidation_node", consolidation_node),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    graph.set_entry_point("planning_node")

    # Strictly linear: each stage feeds the next, and the last one ends the run.
    edge_table = (
        ("planning_node", "writing_node"),
        ("writing_node", "consolidation_node"),
        ("consolidation_node", END),
    )
    for source, target in edge_table:
        graph.add_edge(source, target)

    return graph.compile()
def create_workflow_async() -> StateGraph:
    """Build and compile the pipeline variant with the concurrent writing stage.

    Same topology as the sequential workflow, but the writing stage is
    ``writing_node_async``. That node is a coroutine, so the compiled graph
    is driven with ``ainvoke``, as the ``__main__`` block here does.

    Returns:
        The object returned by ``compile()``: the compiled, runnable graph,
        not the ``StateGraph`` builder named in the annotation.
    """
    builder = StateGraph(GraphState)

    stages = [
        ("planning_node", planning_node),
        ("writing_node", writing_node_async),
        ("consolidation_node", consolidation_node),
    ]
    for stage_name, stage_fn in stages:
        builder.add_node(stage_name, stage_fn)

    builder.set_entry_point("planning_node")

    # Wire stage i to stage i+1, then hand the final stage off to END.
    transitions = [
        ("planning_node", "writing_node"),
        ("writing_node", "consolidation_node"),
        ("consolidation_node", END),
    ]
    for upstream, downstream in transitions:
        builder.add_edge(upstream, downstream)

    return builder.compile()
if __name__ == "__main__":
    import time

    # Adjacent string literals are concatenated. Each fragment ends with a
    # space so the sentences don't run together.
    test_instruction = (
        "Write a 1500-word piece on the HBO TV show Westworld, covering major characters, "
        "themes of AI and consciousness, and how the story might have continued had it not been cancelled. "
        "Include specific details, quotes, and references to the show and its creators. "
        "Do not include any spoilers for the climax of the show's final season."
    )

    app = create_workflow_async()

    # Seed every GraphState key. Only 'initial_prompt' carries input; the
    # other fields are overwritten by the graph's nodes.
    state_input: GraphState = {
        "initial_prompt": test_instruction,
        "plan": "",
        "write_steps": [],
        "final_json": "",
    }

    # perf_counter() is a monotonic clock meant for measuring elapsed time;
    # time.time() can jump if the system clock is adjusted.
    start = time.perf_counter()
    final_state = asyncio.run(app.ainvoke(state_input))
    elapsed = time.perf_counter() - start

    # consolidation_node stores the serialized document under 'final_json'.
    print("\n===== FINAL JSON OUTPUT =====\n")
    print(final_state['final_json'])
    print("=============================\n")

    print("\n===== PERFORMANCE =====\n")
    print(f"Time taken: {elapsed:.2f} seconds")
    print("=======================\n")