Spaces:
Build error
Build error
import base64 | |
import json | |
import os | |
import random | |
import re | |
import shutil | |
import sys | |
import tempfile | |
import uuid | |
import requests | |
from datetime import datetime | |
from io import BytesIO | |
from pathlib import Path | |
import gradio as gr | |
from PIL import Image | |
from dotenv import load_dotenv | |
from graphviz import Digraph | |
from huggingface_hub import InferenceClient | |
from together import Together | |
# ────────────────────────────────
# ENV / API
# ────────────────────────────────
# Copy any variables from a local .env file into os.environ before reading them.
load_dotenv()

# Hugging Face token (None when HF_TOKEN is unset). Add it to your .env.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Together API key; "" when TOGETHER_API_KEY is unset.
TOGETHER_TOKEN = os.environ.get("TOGETHER_API_KEY", "")

# One Together client serves both the chat endpoint (entity / scene extraction)
# and the image endpoint (FLUX generation).
together_client = Together(api_key=TOGETHER_TOKEN)
# NOTE(review): image_client is created without a model and does not appear to
# be referenced anywhere else in this file — confirm before removing it.
image_client = InferenceClient(token=HF_TOKEN)
# Optional Graphviz PATH helper (Windows only): uncomment if the `dot` executable is not found on PATH.
# if shutil.which("dot") is None: | |
# gv_path = r"C:\Program Files\Graphviz\bin" | |
# if os.path.exists(gv_path): | |
# os.environ["PATH"] = gv_path + os.pathsep + os.environ["PATH"] | |
# else: | |
# sys.exit("Graphviz not found. Please install Graphviz or remove the check.") | |
# ────────────────────────────────
# LLM templates
# ────────────────────────────────
# Entity-extraction prompt, filled with `LLAMA_JSON_PROMPT % text`: the single
# `%s` slot receives the story. The literal braces of the JSON example are safe
# under %-formatting; a literal percent sign in this template would need `%%`.
LLAMA_JSON_PROMPT = """
Extract every character and any explicit relationship between them.
Return pure JSON ONLY in this schema:
{
"characters": ["Alice", "Bob"],
"relations": [
{"from":"Alice","to":"Bob","type":"friend"}
]
}
TEXT:
\"\"\"%s\"\"\"
"""
# Scene-description prompt, filled with `IMAGE_PROMPT_TEMPLATE % (count, story)`:
# `%d` is the number of scenes requested and `%s` is the story text. The model
# is asked for one dash-prefixed scene per line so the reply can be split on lines.
IMAGE_PROMPT_TEMPLATE = """
Based on the following story, write %d distinct vivid scene descriptions, one per line.
Each line should begin with a dash (-) followed by a detailed image-worthy scene.
Include setting, mood, characters, and visual cues.
Return ONLY the list of scenes, each on its own line.
Story:
\"\"\"%s\"\"\"
"""
# ────────────────────────────────
# Entity extraction
# ────────────────────────────────
def _normalize_entities(data):
    """Coerce raw LLM JSON into {"characters": [str], "relations": [{"from", "to", "type"}]}."""
    if not isinstance(data, dict):
        return {"characters": [], "relations": []}
    raw_chars = data.get("characters")
    raw_rels = data.get("relations")

    names = {}  # insertion-ordered set: dedupes while keeping the LLM's order
    for c in raw_chars if isinstance(raw_chars, list) else []:
        name = str(c).strip() if c is not None else ""
        if name:
            names[name] = None

    relations = []
    for r in raw_rels if isinstance(raw_rels, list) else []:
        if not isinstance(r, dict):
            continue
        frm = str(r.get("from") or "").strip()
        to = str(r.get("to") or "").strip()
        if not frm or not to:
            continue
        # Keep the character list consistent with the edges, so the manual
        # "Add Relation" check (both endpoints must exist) works on this data.
        names.setdefault(frm, None)
        names.setdefault(to, None)
        relations.append({"from": frm, "to": to, "type": str(r.get("type") or "relation")})

    return {"characters": list(names), "relations": relations}


def extract_entities(text: str):
    """Ask the LLM for the characters and explicit relationships in ``text``.

    Returns ``(data, None)`` on success, where ``data`` is
    ``{"characters": [str, ...], "relations": [{"from", "to", "type"}, ...]}``
    and has been normalized, because LLM output is not guaranteed to follow the
    requested schema. Returns ``(None, message)`` when the call fails, no JSON
    object is found, or no characters could be extracted.
    """
    try:
        prompt = LLAMA_JSON_PROMPT % text
        resp = together_client.chat.completions.create(
            model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1024,
        )
        # message.content can be None (e.g. an empty completion).
        raw = (resp.choices[0].message.content or "").strip()
        # Greedy match: from the first "{" to the last "}", tolerating prose or
        # code fences around the JSON.
        m = re.search(r"\{[\s\S]*\}", raw)
        if not m:
            return None, f"⚠️ No JSON block found.\n\n{raw}"
        parsed = json.loads(m.group(0))
    except Exception as e:
        return None, f"⚠️ extractor error: {e}"

    data = _normalize_entities(parsed)
    if not data["characters"]:
        return None, f"⚠️ No characters found in the model output.\n\n{raw}"
    return data, None
# ────────────────────────────────
# Build visual prompt
# ────────────────────────────────
def generate_image_prompts(story_text: str, count=1):
    """Ask the LLM for ``count`` distinct scene descriptions of ``story_text``.

    Returns a list of at most ``count`` non-empty strings with leading list
    markers ("- ", "• ", "* ", "1. ") removed. Returns ``[]`` without calling the
    API when ``count`` < 1, and ``[]`` when the LLM call fails (the error is
    printed, not raised).
    """
    count = int(count)
    if count < 1:
        return []
    try:
        prompt_msg = IMAGE_PROMPT_TEMPLATE % (count, story_text)
        resp = together_client.chat.completions.create(
            model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
            messages=[{"role": "user", "content": prompt_msg}],
            # Scale the budget with the number of scenes; a flat 200 tokens is
            # likely to cut multi-scene answers off mid-description.
            max_tokens=200 * count,
        )
        # message.content can be None (e.g. an empty completion).
        raw_output = (resp.choices[0].message.content or "").strip()
    except Exception as e:
        print("⚠️ LLM scene prompt generation failed:", e)
        return []

    marker = re.compile(r"^(?:[-•*]+|\d+[.)])\s*")
    lines = [line.strip() for line in raw_output.splitlines() if line.strip()]
    # Prefer explicit list items, so a chatty preamble such as "Here are the
    # scenes:" is not returned as a scene; fall back to every line otherwise.
    listed = [marker.sub("", line).strip() for line in lines if marker.match(line)]
    prompts = [p for p in (listed or lines) if p]
    return prompts[:count]  # the LLM may return more scenes than requested
def _image_from_choice(choice):
    """Decode one entry of a Together ``images.generate`` response into a PIL image, or return None."""
    if getattr(choice, "url", None):
        try:
            r = requests.get(choice.url, timeout=30)
            r.raise_for_status()  # don't try to decode an HTTP error page as an image
            img = Image.open(BytesIO(r.content))
            img.load()  # Image.open is lazy; decode now so corrupt data fails here
            return img
        except Exception as e:
            print("⚠️ URL fetch failed:", e)
    elif getattr(choice, "b64_json", None):
        try:
            img = Image.open(BytesIO(base64.b64decode(choice.b64_json)))
            img.load()
            return img
        except Exception as e:
            print("⚠️ base64 decode failed:", e)
    return None


def generate_images_with_together(story, style, quality, count=1):
    """Generate up to ``count`` FLUX images that illustrate ``story``.

    One LLM-written scene is requested per image via generate_image_prompts().
    If fewer scenes come back they are reused in turn; if none come back, the
    raw story text is used as the scene. Each scene is prefixed with the style
    and quality hints.

    Returns a list of PIL images, possibly shorter than ``count``. Generation
    stops at the first image-API error. Per-image download or decode failures
    are printed and skipped. Returns ``[]`` when ``count`` < 1.
    """
    count = int(count)
    if count < 1:
        return []
    # Request all scenes at once and index into the list. Interpolating the
    # list object itself would embed its repr ("['...']") in the image prompt.
    scenes = generate_image_prompts(story, count) or [story]
    images = []
    for i in range(count):
        scene = scenes[i % len(scenes)]
        full_prompt = f"{style} style, cinematic lighting, quality {quality}, {scene} [Scene {i + 1}]"
        seed = random.randint(1, 10_000_000)
        try:
            resp = together_client.images.generate(
                model="black-forest-labs/FLUX.1-schnell-Free",
                prompt=full_prompt,
                seed=seed,
                width=768,
                height=512,
                steps=4,
            )
        except Exception as e:
            print("⚠️ Together image API error:", e)
            # Presumably auth/rate-limit style errors that would also hit every
            # remaining scene, so stop rather than retrying.
            break

        results = getattr(resp, "data", None)
        img = _image_from_choice(results[0]) if results else None
        if img is not None:
            images.append(img)
        else:
            print(f"⚠️ No image for scene {i + 1}")
    return images
# ────────────────────────────────
# Graph → PNG (Graphviz)
# ────────────────────────────────
def build_graph_png(data: dict) -> str:
    """Render the character/relation graph in ``data`` to a PNG file.

    ``data`` has the app's state shape:
    ``{"characters": [...], "relations": [{"from": ..., "to": ..., "type": ...}, ...]}``.
    Missing keys are treated as empty. A missing relation ``type`` gives an
    unlabeled edge, and relations lacking either endpoint are skipped. Names are
    converted with ``str()`` because the data may come from LLM output or
    user-loaded JSON.

    Returns the path of the PNG, written into a fresh temporary directory.
    Needs the Graphviz ``dot`` executable on PATH.
    """
    dot = Digraph(format="png")
    dot.attr(rankdir="LR", bgcolor="white", fontsize="11")
    for c in data.get("characters") or []:
        dot.node(str(c), shape="ellipse", style="filled", fillcolor="#8ecae6")
    for r in data.get("relations") or []:
        if not isinstance(r, dict):
            continue
        frm, to = r.get("from"), r.get("to")
        if not frm or not to:
            continue
        # NOTE(review): the graphviz package treats "node:port" specially in edge
        # endpoints, so a name containing ':' may be split — confirm if relevant.
        dot.edge(str(frm), str(to), label=str(r.get("type") or ""), fontsize="10")
    tmpdir = Path(tempfile.mkdtemp())
    stem = f"graph_{uuid.uuid4().hex}"
    # cleanup=True removes the intermediate DOT source. render() returns the
    # path of the rendered image (tmpdir/stem.png).
    return str(dot.render(stem, directory=tmpdir, cleanup=True))
# ────────────────────────────────
# Core generation
# ────────────────────────────────
def generate_assets(prompt, style, quality, num_images, state):
    """Run the full pipeline for one story prompt (the "Generate Assets" button).

    Extracts characters and relations with the LLM, renders the relationship
    graph, and optionally generates ``num_images`` scene images.

    Returns ``(images, graph_path, status, new_state)``, matching the UI
    outputs. ``graph_path`` is None when no graph could be rendered. The
    previous ``state`` is returned unchanged when nothing was extracted.
    """
    if not (prompt or "").strip():
        return [], None, "Enter a story prompt.", state

    data, err = extract_entities(prompt)
    if not data:
        return [], None, err or "No data.", state

    try:
        graph_path = build_graph_png(data)
    except Exception as e:
        # e.g. the Graphviz `dot` executable is missing. The extraction itself
        # succeeded, so keep it in state so it can still be edited or saved.
        return [], None, f"⚠️ Graph rendering failed: {e}", data

    wanted = int(num_images or 0)  # the slider may deliver a float
    images = []
    if wanted > 0:
        try:
            images = generate_images_with_together(prompt, style, quality, wanted)
        except Exception as e:
            return [], graph_path, f"⚠️ Image generation failed: {e}", data

    if wanted <= 0:
        status = "✅ Graph generated (no images)."
    elif len(images) >= wanted:
        status = "✅ All assets generated."
    elif images:
        status = f"⚠️ Graph generated; only {len(images)} of {wanted} images succeeded."
    else:
        status = "⚠️ Graph generated, but no images were produced (see server logs)."
    return images, graph_path, status, data
# Helper to rebuild graph after manual edits
def _regen_graph(state):
    """Re-render ``state`` to a PNG and wrap its path in a Gradio update for the graph image."""
    png_path = build_graph_png(state)
    return gr.update(value=png_path)
# ────────────────────────────────
# Manual tweak callbacks
# ────────────────────────────────
def add_character(name, state):
    """Add character ``name`` to the graph ``state`` and re-render the graph.

    Returns ``(graph update, status message, state)``. ``state`` is mutated in
    place. On invalid input the graph is left unchanged.
    """
    # Strip so that "  " is rejected and " Alice " does not become a second Alice.
    name = (name or "").strip()
    if not name:
        return gr.update(), "Enter a character name.", state
    if name in state["characters"]:
        return gr.update(), f"{name} already exists.", state
    state["characters"].append(name)
    return _regen_graph(state), "✅ Character added.", state
def add_relation(frm, to, typ, state):
    """Add a directed ``frm -> to`` relation labeled ``typ`` and re-render the graph.

    Both endpoints must already be characters in ``state``. An empty or
    whitespace-only ``typ`` becomes "relation". Returns
    ``(graph update, status message, state)``. ``state`` is mutated in place.
    """
    # Strip so trailing spaces from the textboxes don't fail the existence check.
    frm = (frm or "").strip()
    to = (to or "").strip()
    if frm not in state["characters"] or to not in state["characters"]:
        return gr.update(), "Both characters must exist first.", state
    label = (typ or "").strip() or "relation"
    state["relations"].append({"from": frm, "to": to, "type": label})
    return _regen_graph(state), "✅ Relation added.", state
def delete_character(name, state):
    """Remove character ``name`` and every relation touching it, then re-render.

    Returns ``(graph update, status message, state)``. ``state`` is updated in
    place (its "characters" and "relations" lists are replaced).
    """
    name = (name or "").strip()
    if name not in state["characters"]:
        return gr.update(), "Character not found.", state
    # Filter rather than list.remove(): the list may contain duplicates, e.g.
    # from LLM output or a loaded file, and remove() drops only the first one.
    state["characters"] = [c for c in state["characters"] if c != name]
    # .get(): relations loaded from JSON are not guaranteed to have both keys.
    state["relations"] = [
        r for r in state["relations"] if name not in (r.get("from"), r.get("to"))
    ]
    return _regen_graph(state), f"🚮 {name} deleted.", state
# Save / Load
def save_json(state):
    """Write ``state`` as indented JSON to a new file in the system temp dir.

    Returns the file path as a string.
    """
    # datetime.isoformat() contains ':', which is illegal in Windows filenames.
    # Microseconds are kept so that rapid successive saves don't overwrite each other.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    fp = Path(tempfile.gettempdir()) / f"story_{stamp}.json"
    fp.write_text(json.dumps(state, indent=2), encoding="utf-8")
    return str(fp)
def load_json(file_obj, state):
    """Load a saved graph JSON file into the workspace and re-render the graph.

    ``file_obj`` is the value of a ``gr.File`` component: a file path, or an
    object with a ``.name`` path, depending on the Gradio version. Returns
    ``(graph update, status message, state)``. On any error the previous
    ``state`` is kept.
    """
    if isinstance(file_obj, (str, os.PathLike)):
        path = file_obj
    else:
        path = getattr(file_obj, "name", None)
    if not path or not Path(path).exists():
        return gr.update(), "No file uploaded.", state
    try:
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not (
            isinstance(data, dict)
            and isinstance(data.get("characters"), list)
            and isinstance(data.get("relations"), list)
        ):
            raise ValueError("expected an object with 'characters' and 'relations' lists")
        return _regen_graph(data), "✅ File loaded.", data
    except Exception as e:
        return gr.update(), f"Load error: {e}", state
# ────────────────────────────────
# UI (same tabs you designed)
# ────────────────────────────────
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as demo:
    gr.Markdown("## ✨ EpicFrame — Narrative Workbench")
    # Per-session graph data shared by every callback:
    # {"characters": [name, ...], "relations": [{"from", "to", "type"}, ...]}
    state = gr.State({"characters": [], "relations": []})

    # Input tab
    with gr.Tab("Input"):
        text_input = gr.Textbox(label="Story prompt", lines=6)
        style_dropdown = gr.Dropdown(["Realistic", "Anime", "Sketch"], value="Realistic", label="Style")
        quality_slider = gr.Slider(1, 10, value=7, step=1, label="Image Quality")
        num_images_sl = gr.Slider(0, 4, value=0, step=1, label="Images to generate (0 = skip)")
        generate_btn = gr.Button("▶️ Generate Assets")
        status_box = gr.Textbox(label="Status", lines=2)

    # Images tab
    with gr.Tab("Images"):
        gallery = gr.Gallery(label="🖼️ Images", columns=4)

    # Graph/Edit tab
    with gr.Tab("Graph / Edit"):
        graph_img = gr.Image(label="Character Map", interactive=False, height=500)
        with gr.Row():
            add_char_name = gr.Textbox(label="Add Character — Name")
            add_char_btn = gr.Button("Add")
        with gr.Row():
            rel_from = gr.Textbox(label="Relation From")
            rel_to = gr.Textbox(label="To")
            rel_type = gr.Textbox(label="Type")
            add_rel_btn = gr.Button("Add Relation")
        with gr.Row():
            del_char_name = gr.Textbox(label="Delete Character — Name")
            del_char_btn = gr.Button("Delete")
        tweak_msg = gr.Textbox(label="Status", max_lines=2)

    # Save/Load tab
    with gr.Tab("Save / Load"):
        save_btn = gr.Button("💾 Download JSON")
        # A Button cannot deliver a file: returning the path to it only relabels
        # the button. A File component shows a download link for the path instead.
        save_file = gr.File(label="Saved JSON", interactive=False)
        load_file = gr.File(label="Load JSON")
        load_btn = gr.Button("⤵️ Load into workspace")
        save_msg = gr.Textbox(label="Status", max_lines=2)

    # callbacks
    generate_btn.click(
        generate_assets,
        inputs=[text_input, style_dropdown, quality_slider, num_images_sl, state],
        outputs=[gallery, graph_img, status_box, state],
    )
    add_char_btn.click(
        add_character,
        inputs=[add_char_name, state],
        outputs=[graph_img, tweak_msg, state],
    )
    add_rel_btn.click(
        add_relation,
        inputs=[rel_from, rel_to, rel_type, state],
        outputs=[graph_img, tweak_msg, state],
    )
    del_char_btn.click(
        delete_character,
        inputs=[del_char_name, state],
        outputs=[graph_img, tweak_msg, state],
    )
    # .then() has no inputs, so Gradio calls the function with zero arguments;
    # a one-parameter lambda would raise TypeError.
    save_btn.click(save_json, inputs=state, outputs=save_file, api_name="download").then(
        lambda: "✅ JSON ready.", outputs=save_msg
    )
    load_btn.click(
        load_json,
        inputs=[load_file, state],
        outputs=[graph_img, save_msg, state],
    )

demo.launch()