import json
import asyncio
from typing import List

from typing_extensions import TypedDict
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langgraph.graph import StateGraph, END

from src.utils.api_key_manager import with_api_manager
from src.helpers.helper import remove_markdown
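
# NOTE (assumption): the node functions below take `llm` as a keyword-only
# argument. It is presumably supplied by the caller or injected through the
# imported `with_api_manager` decorator, which is not applied in this file.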


# Define the Graph State
class GraphState(TypedDict):
    initial_prompt: str
    plan: str
    write_steps: List[dict]
    final_json: str


def planning_node(state: GraphState, *, llm) -> GraphState:
    print("\n---PLANNING---\n")

    initial_prompt = state['initial_prompt']

    plan_template = \
f"""You need to create a structured JSON based on the following instructions:
{initial_prompt}

Rules:
1. Outline a multi-step plan (one step per line) that will guide the creation of the final JSON.
2. You must create the entire plan yourself without asking others to create it for you.
3. The steps should be as follows:
    - Each step should be a high-level task or section of the JSON.
    - Check if breaking down each step into smaller, low-level sub-tasks or sections is required.
    - If yes, ONLY include the sub-steps (one sub-step per line).
4. The plan should be concise and clear, and each step and sub-step should be distinct.
5. The plan should be unformatted and in plain text. DO NOT even use bullet points or new lines.
6. The number of steps should be as few as possible, but still enough to cover ALL sections.
7. If the user request contains any specific details, include them in the plan.
8. DO NOT create the final content, just the plan/outline.
9. DO NOT include any markdown or formatting in the plan."""

    chat_template = ChatPromptTemplate.from_messages(
        [
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )
    prompt = chat_template.invoke({"text": plan_template})
    response = llm.invoke(prompt)
    plan = response.content.strip()

    # Store the plan text in the state
    state['plan'] = remove_markdown(plan)
    print(plan)

    return state


def writing_node_sync(state: GraphState, *, llm) -> GraphState:
    print("\n---WRITING THE JSON---\n")

    initial_prompt = state['initial_prompt']
    plan = state['plan'].strip()

    # Split the plan by lines (one step per line)
    plan_lines = plan.split('\n')

    # Our final partial JSON objects
    partial_jsons: List[dict] = []

    # Generate the partial JSON for each non-empty plan step
    for idx, step_line in enumerate(plan_lines):
        if len(step_line.strip()) > 0:
            step_prompt_text = \
f"""You are creating part {idx+1} of the final JSON document.

User request:
{initial_prompt}

Plan step (outline):
{step_line.strip()}

Rules:
1. You need to write the JSON data for this step.
2. The JSON should be structured and valid.
3. If the user request contains any specific details, include them in the JSON.
4. If the user request contains the format of the JSON, follow it. If not, create a generic JSON as you see fit.
5. Respond ONLY with valid JSON for this step without any markdown or formatting."""

            chat_template = ChatPromptTemplate.from_messages(
                [
                    HumanMessagePromptTemplate.from_template("{text}"),
                ]
            )
            prompt = chat_template.invoke({"text": step_prompt_text})
            response = llm.invoke(prompt)
            step_result = response.content.strip()

            # Attempt to parse the partial JSON
            try:
                cleaned_result = remove_markdown(step_result)
                partial_obj = json.loads(cleaned_result)
            except json.JSONDecodeError as e:
                # If the model didn't produce valid JSON, raise an error
                raise Exception(f"Failed to parse JSON data for step {idx+1}: {e}") from e

            # print(f"Step {idx+1} JSON:\n{json.dumps(partial_obj, indent=2)}\n")

            # Add the partial JSON to the list
            partial_jsons.append(partial_obj)

    # Save all partial JSON objects in the state
    state['write_steps'] = partial_jsons

    return state


async def writing_node_async(state: GraphState, *, llm) -> GraphState:
    async def get_partial_json(idx: int, step_line: str) -> dict:
        step_prompt_text = \
f"""You are creating part {idx+1} of the final JSON document.

User request:
{initial_prompt}

Plan step (outline):
{step_line.strip()}

Rules:
1. You need to write the JSON data for this step.
2. The JSON should be structured and valid.
3. If the user request contains any specific details, include them in the JSON.
4. If the user request contains the format of the JSON, follow it. If not, create a generic JSON as you see fit.
5. Respond ONLY with valid JSON for this step without any markdown or formatting."""

        chat_template = ChatPromptTemplate.from_messages(
            [
                HumanMessagePromptTemplate.from_template("{text}"),
            ]
        )
        prompt = chat_template.invoke({"text": step_prompt_text})
        response = await llm.ainvoke(prompt)
        step_result = response.content.strip()

        cleaned_result = remove_markdown(step_result)
        try:
            partial_obj = json.loads(cleaned_result)
        except json.JSONDecodeError as e:
            raise Exception(f"Failed to parse JSON data for step {idx+1}: {e}")

        # print(f"Step {idx+1} JSON:\n{json.dumps(partial_obj, indent=2)}\n")

        return partial_obj

    print("\n---WRITING THE JSON---\n")

    initial_prompt = state['initial_prompt']
    plan = state['plan'].strip()
    plan_lines = plan.split('\n')

    # Build one task per non-empty plan step
    tasks = []
    for idx, line in enumerate(plan_lines):
        if len(line.strip()) > 0:
            tasks.append(asyncio.create_task(get_partial_json(idx, line)))

    # Run them concurrently
    partial_jsons = await asyncio.gather(*tasks)

    # Store results
    state['write_steps'] = list(partial_jsons)

    return state


def consolidation_node(state: GraphState) -> GraphState:
    print("\n---CONSOLIDATING THE JSON---\n")

    plan = state['plan']
    partial_jsons = state['write_steps']

    final_obj = {
        "plan": plan,
        "steps": partial_jsons
    }

    # Convert the consolidated object to a JSON string
    final_json_str = json.dumps(final_obj, ensure_ascii=False, indent=2)

    # Store it in the state
    state['final_json'] = final_json_str

    return state


def create_workflow_sync() -> StateGraph:
    workflow = StateGraph(GraphState)

    # Add nodes
    workflow.add_node("planning_node", planning_node)
    workflow.add_node("writing_node", writing_node_sync)
    workflow.add_node("consolidation_node", consolidation_node)

    # Set entry point
    workflow.set_entry_point("planning_node")

    # Add edges
    workflow.add_edge("planning_node", "writing_node")
    workflow.add_edge("writing_node", "consolidation_node")

    # Finally, consolidation_node leads to END
    workflow.add_edge("consolidation_node", END)

    return workflow.compile()
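

# Example (sketch): invoking the compiled sync graph directly. This assumes
# `llm` has already been bound into the node functions (e.g. via
# `with_api_manager` or functools.partial); only the async path is exercised
# in the __main__ block below.
#
#   sync_app = create_workflow_sync()
#   result = sync_app.invoke({
#       "initial_prompt": "Describe three planets as JSON objects.",
#       "plan": "",
#       "write_steps": [],
#       "final_json": "",
#   })
#   print(result["final_json"])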


def create_workflow_async() -> StateGraph:
    workflow = StateGraph(GraphState)

    # Add nodes
    workflow.add_node("planning_node", planning_node)
    workflow.add_node("writing_node", writing_node_async)
    workflow.add_node("consolidation_node", consolidation_node)

    # Set entry point
    workflow.set_entry_point("planning_node")

    # Add edges
    workflow.add_edge("planning_node", "writing_node")
    workflow.add_edge("writing_node", "consolidation_node")

    # Finally, consolidation_node leads to END
    workflow.add_edge("consolidation_node", END)

    return workflow.compile()


if __name__ == "__main__":
    import time

    test_instruction = "Write a 1500-word piece on the HBO TV show Westworld, covering major characters, \
themes of AI and consciousness, and how the story might have continued had it not been cancelled. \
Include specific details, quotes, and references to the show and its creators. \
Do not include any spoilers for the climax of the show's final season."

    app = create_workflow_async()

    # We supply an initial state.
    # (We only need 'initial_prompt' here; the other fields will be set by the nodes.)
    state_input: GraphState = {
        "initial_prompt": test_instruction,
        "plan": "",
        "write_steps": [],
        "final_json": ""
    }

    start = time.time()
    final_state = asyncio.run(app.ainvoke(state_input))
    end = time.time()

    # The final JSON is in final_state['final_json']
    print("\n===== FINAL JSON OUTPUT =====\n")
    print(final_state['final_json'])
    print("=============================\n")

    print("\n===== PERFORMANCE =====\n")
    print(f"Time taken: {end - start:.2f} seconds")
    print("=======================\n")