# MajorProject / app.py
# Author: aniruddh1907 — "Update app.py (#1)", commit 44d8ee2 (verified)
import base64
import json
import os
import random
import re
import shutil
import sys
import tempfile
import uuid
import requests
from datetime import datetime
from io import BytesIO
from pathlib import Path
import gradio as gr
from PIL import Image
from dotenv import load_dotenv
from graphviz import Digraph
from huggingface_hub import InferenceClient
from together import Together
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# ENV / API
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# Read API keys from a local .env file (if present) into the environment.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN") # <-- add your HF token to .env
TOGETHER_TOKEN = os.getenv("TOGETHER_API_KEY", "")
# Together client drives both chat (entity/scene extraction) and image generation.
together_client = Together(api_key=TOGETHER_TOKEN)
# NOTE(review): image_client is never referenced elsewhere in this file —
# presumably kept for a planned HF-hosted image model; confirm before removing.
image_client = InferenceClient(token=HF_TOKEN) # default model set later
# Optional Graphviz path helper (Windows ONLY (RIP Gotham))
# if shutil.which("dot") is None:
# gv_path = r"C:\Program Files\Graphviz\bin"
# if os.path.exists(gv_path):
# os.environ["PATH"] = gv_path + os.pathsep + os.environ["PATH"]
# else:
# sys.exit("Graphviz not found. Please install Graphviz or remove the check.")
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# LLM templates
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# %-style template: the single %s slot receives the raw story text.
LLAMA_JSON_PROMPT = """
Extract every character and any explicit relationship between them.
Return pure JSON ONLY in this schema:
{
"characters": ["Alice", "Bob"],
"relations": [
{"from":"Alice","to":"Bob","type":"friend"}
]
}
TEXT:
\"\"\"%s\"\"\"
"""
# %-style template: %d = number of scenes requested, %s = the story text.
IMAGE_PROMPT_TEMPLATE = """
Based on the following story, write %d distinct vivid scene descriptions, one per line.
Each line should begin with a dash (-) followed by a detailed image-worthy scene.
Include setting, mood, characters, and visual cues.
Return ONLY the list of scenes, each on its own line.
Story:
\"\"\"%s\"\"\"
"""
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# Entity extraction
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def extract_entities(text: str):
    """Ask the LLM for characters/relations mentioned in *text*.

    Returns ``(data, None)`` on success, where ``data`` is the parsed
    ``{"characters": [...], "relations": [...]}`` dict, or ``(None, message)``
    on any failure (API error, no JSON block in the reply, bad JSON).
    """
    try:
        completion = together_client.chat.completions.create(
            model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
            messages=[{"role": "user", "content": LLAMA_JSON_PROMPT % text}],
            max_tokens=1024,
        )
        reply = completion.choices[0].message.content.strip()
        # Grab the outermost {...} span — the model may wrap it in chatter.
        match = re.search(r"\{[\s\S]*\}", reply)
        if match is None:
            return None, f"โš ๏ธย No JSON block found.\n\n{reply}"
        return json.loads(match.group(0)), None
    except Exception as exc:
        return None, f"โš ๏ธย extractor error: {exc}"
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# Build visual prompt
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def generate_image_prompts(story_text: str, count=1):
    """Have the LLM write *count* one-line scene descriptions for the story.

    Returns a list of at most *count* cleaned prompt strings; any API
    failure is printed and an empty list is returned.
    """
    try:
        completion = together_client.chat.completions.create(
            model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
            messages=[{"role": "user", "content": IMAGE_PROMPT_TEMPLATE % (count, story_text)}],
            max_tokens=200,
        )
        reply = completion.choices[0].message.content.strip()
        # Strip list bullets/dashes and blank lines from the model output.
        scenes = []
        for line in reply.split("\n"):
            if line.strip():
                scenes.append(line.strip("-โ€ข ").strip())
        return scenes[:count]  # just in case LLM gives more than needed
    except Exception as exc:
        print("โš ๏ธ LLM scene prompt generation failed:", exc)
        return []
def generate_images_with_together(story, style, quality, count=1):
    """Generate up to *count* images for *story* via Together's FLUX model.

    One scene prompt per image is requested from the LLM; when fewer
    prompts come back, the raw story text is used as a fallback. Failed
    scenes are logged and skipped, so the returned list of PIL images may
    be shorter than *count*.
    """
    # BUG FIX: the old code called generate_image_prompts(story) — always a
    # single prompt because `count` was never forwarded — and then
    # interpolated the returned *list* into the f-string, producing prompts
    # like "... ['scene text'] [Scene 1]". Request one prompt per scene.
    scene_prompts = generate_image_prompts(story, count)
    images = []
    for i in range(count):
        scene = scene_prompts[i] if i < len(scene_prompts) else story
        full_prompt = f"{style} style, cinematic lighting, quality {quality}, {scene} [Scene {i + 1}]"
        seed = random.randint(1, 10_000_000)  # fresh seed → visually distinct scenes
        try:
            resp = together_client.images.generate(
                model="black-forest-labs/FLUX.1-schnell-Free",
                prompt=full_prompt,
                seed=seed,
                width=768,
                height=512,
                steps=4,
            )
        except Exception as e:
            print("๐Ÿ”ฅ Together image API error:", e)
            break
        img = None
        if resp.data:
            choice = resp.data[0]
            # The API returns either a hosted URL or inline base64 content.
            if getattr(choice, "url", None):
                try:
                    img_bytes = requests.get(choice.url, timeout=30).content
                    img = Image.open(BytesIO(img_bytes))
                except Exception as e:
                    print("โš ๏ธย URL fetch failed:", e)
            elif getattr(choice, "b64_json", None):
                try:
                    img_bytes = base64.b64decode(choice.b64_json)
                    img = Image.open(BytesIO(img_bytes))
                except Exception as e:
                    print("โš ๏ธย base64 decode failed:", e)
        if img is not None:
            images.append(img)
        else:
            print(f"โš ๏ธย No image for scene {i+1}")
    return images
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# Graph โ†’ PNG (Graphviz)
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def build_graph_png(data: dict) -> str:
    """Render the character/relation dict as a PNG and return its path.

    *data* is expected to hold ``characters`` (list of names) and
    ``relations`` (list of ``{"from","to","type"}`` dicts). Missing keys
    are treated as empty so a partially-built state still renders instead
    of raising KeyError.
    """
    dot = Digraph(format="png")
    dot.attr(rankdir="LR", bgcolor="white", fontsize="11")
    for name in data.get("characters", []):
        dot.node(str(name), shape="ellipse", style="filled", fillcolor="#8ecae6")
    for rel in data.get("relations", []):
        # Default the label so a relation saved without "type" still draws.
        dot.edge(str(rel["from"]), str(rel["to"]), label=str(rel.get("type", "relation")), fontsize="10")
    out_dir = Path(tempfile.mkdtemp())  # fresh dir → unique, race-free output
    stem = f"graph_{uuid.uuid4().hex}"
    # render() returns the path of the produced PNG (stem + ".png").
    rendered = dot.render(stem, directory=out_dir, cleanup=True)
    return str(rendered)
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# Core generation
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def generate_assets(prompt, style, quality, num_images, state):
    """Full pipeline for the Generate button.

    Extracts entities from *prompt*, renders the relationship graph, and
    optionally generates *num_images* scene images. Returns the Gradio
    output tuple ``(images, graph_path, status, state)``.
    """
    extracted, error = extract_entities(prompt)
    if not extracted:
        return [], None, error or "No data.", state
    graph_path = build_graph_png(extracted)
    if num_images > 0:
        try:
            pictures = generate_images_with_together(prompt, style, quality, int(num_images))
        except Exception as exc:
            return [], graph_path, f"โš ๏ธ Image generation failed: {exc}", extracted
        if pictures:
            return pictures, graph_path, "โœ… All assets generated.", extracted
    return [], graph_path, "โœ… Graph generated (no images).", extracted
def _regen_graph(state):
    """Re-render the relationship graph for *state*, wrapped for a Gradio output."""
    return gr.update(value=build_graph_png(state))
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# Manual tweak callbacks
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def add_character(name, state):
    """Add *name* to the character list and refresh the graph.

    Returns the Gradio output tuple ``(graph update, status message, state)``.
    """
    # Strip input so whitespace-only names don't slip past the emptiness
    # check and so "Alice " can't coexist with "Alice".
    name = (name or "").strip()
    if not name:
        return gr.update(), "Enter a character name.", state
    if name in state["characters"]:
        return gr.update(), f"{name} already exists.", state
    state["characters"].append(name)
    return _regen_graph(state), "โœ…ย Character added.", state
def add_relation(frm, to, typ, state):
    """Add a typed relation between two existing characters.

    Returns the Gradio output tuple ``(graph update, status message, state)``.
    """
    # Strip inputs so trailing spaces don't defeat the membership check
    # against the stored (stripped) character names.
    frm = (frm or "").strip()
    to = (to or "").strip()
    if frm not in state["characters"] or to not in state["characters"]:
        return gr.update(), "Both characters must exist first.", state
    state["relations"].append({"from": frm, "to": to, "type": (typ or "").strip() or "relation"})
    return _regen_graph(state), "โœ…ย Relation added.", state
def delete_character(name, state):
    """Remove *name* and every relation touching it; refresh the graph.

    Returns the Gradio output tuple ``(graph update, status message, state)``.
    """
    if name not in state["characters"]:
        return gr.update(), "Character not found.", state
    state["characters"] = [c for c in state["characters"] if c != name]
    # Drop any edge that references the deleted character on either end.
    state["relations"] = [
        rel for rel in state["relations"]
        if name not in (rel["from"], rel["to"])
    ]
    return _regen_graph(state), f"๐Ÿšฎย {name} deleted.", state
# Save / Load
def save_json(state):
    """Serialize *state* to a JSON file in the temp dir and return its path.

    BUG FIX: the filename previously embedded ``datetime.now().isoformat()``,
    which contains ``:`` — an illegal filename character on Windows. Use a
    colon-free timestamp instead (microseconds keep names unique).
    """
    stamp = datetime.now().strftime("%Y%m%dT%H%M%S%f")
    fp = Path(tempfile.gettempdir()) / f"story_{stamp}.json"
    fp.write_text(json.dumps(state, indent=2))
    return str(fp)
def load_json(file_obj, state):
    """Load a previously saved graph JSON into the workspace.

    Returns ``(graph update, status message, new state)``; on any problem
    the current *state* is kept and an error message is shown.
    """
    if not file_obj or not Path(file_obj).exists():
        return gr.update(), "No file uploaded.", state
    try:
        data = json.loads(Path(file_obj).read_text())
        # Explicit validation instead of `assert`, which is stripped under -O.
        if "characters" not in data or "relations" not in data:
            raise ValueError("JSON must contain 'characters' and 'relations'")
        return _regen_graph(data), "โœ…ย File loaded.", data
    except Exception as e:
        return gr.update(), f"Load error: {e}", state
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# UI (same tabs you designed)
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# Build the Gradio UI: four tabs sharing one graph state dict.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as demo:
    gr.Markdown("## โœจ EpicFrame โ€“ Narrative Workbench")
    # Shared workspace state: {"characters": [...], "relations": [...]}
    state = gr.State({"characters": [], "relations": []})
    # Input tab — story prompt plus generation knobs
    with gr.Tab("Input"):
        text_input = gr.Textbox(label="Story prompt", lines=6)
        style_dropdown = gr.Dropdown(["Realistic", "Anime", "Sketch"], value="Realistic", label="Style")
        quality_slider = gr.Slider(1, 10, value=7, step=1, label="Image Quality")
        num_images_sl = gr.Slider(0, 4, value=0, step=1, label="Images to generate (0 = skip)")
        generate_btn = gr.Button("โ–ถ๏ธ Generate Assets")
        status_box = gr.Textbox(label="Status", lines=2)
    # Images tab — gallery of generated scenes
    with gr.Tab("Images"):
        gallery = gr.Gallery(label="๐Ÿ–ผ๏ธ Images", columns=4)
    # Graph/Edit tab — view the character map and tweak it manually
    with gr.Tab("Graph / Edit"):
        graph_img = gr.Image(label="๐Ÿ“Œ Character Map", interactive=False, height=500)
        with gr.Row():
            add_char_name = gr.Textbox(label="Add Character โ†’ Name")
            add_char_btn = gr.Button("Add")
        with gr.Row():
            rel_from = gr.Textbox(label="Relation From")
            rel_to = gr.Textbox(label="To")
            rel_type = gr.Textbox(label="Type")
            add_rel_btn = gr.Button("Add Relation")
        with gr.Row():
            del_char_name = gr.Textbox(label="Delete Character โ†’ Name")
            del_char_btn = gr.Button("Delete")
        tweak_msg = gr.Textbox(label="โžฐ Status", max_lines=2)
    # Save/Load tab — persist and restore the workspace as JSON
    with gr.Tab("Save / Load"):
        save_btn = gr.Button("๐Ÿ’พ Download JSON")
        load_file = gr.File(label="Load JSON")
        load_btn = gr.Button("โคต๏ธ Load into workspace")
        save_msg = gr.Textbox(label="Status", max_lines=2)
    # callbacks
    generate_btn.click(
        generate_assets,
        inputs=[text_input, style_dropdown, quality_slider, num_images_sl, state],
        outputs=[gallery, graph_img, status_box, state],
    )
    add_char_btn.click(add_character,
                       inputs=[add_char_name, state],
                       outputs=[graph_img, tweak_msg, state])
    add_rel_btn.click(add_relation,
                      inputs=[rel_from, rel_to, rel_type, state],
                      outputs=[graph_img, tweak_msg, state])
    del_char_btn.click(delete_character,
                       inputs=[del_char_name, state],
                       outputs=[graph_img, tweak_msg, state])
    # BUG FIX: the .then() lambda previously took a parameter `p` but wired
    # no inputs, so Gradio would invoke it with zero arguments and raise a
    # TypeError. A no-arg lambda just sets the status text.
    # NOTE(review): outputs=save_btn makes the button label become the file
    # path; a gr.File/DownloadButton output may be intended — confirm.
    save_btn.click(save_json, inputs=state, outputs=save_btn, api_name="download") \
        .then(lambda: "โœ… JSON ready.", outputs=save_msg)
    load_btn.click(load_json, inputs=[load_file, state],
                   outputs=[graph_img, save_msg, state])
demo.launch()