import streamlit as st import requests from pathlib import Path import base64 from PIL import Image import numpy as np import io import uuid from streamlit_drawable_canvas import st_canvas from demo.api_data_defs import ChatRequest, ChatMessage, ContentPart from typing import Dict import time import json API_URL = "http://localhost:8000/v1/chat/completions" DEMO_DIR = Path("demo") def square_crop(image: Image.Image) -> Image.Image: width, height = image.size side = min(width, height) left = (width - side) // 2 top = (height - side) // 2 right = left + side bottom = top + side return image.crop((left, top, right, bottom)) def process(image: Image.Image, desired_resolution: int = 256) -> Image.Image: cropped_image = square_crop(image.convert("RGB")) return cropped_image.resize( (int(desired_resolution), int(desired_resolution)), Image.LANCZOS ) DEMOS = [ { "name": "Dog", "image": DEMO_DIR / "assets" / "dog.jpg", "mask": DEMO_DIR / "assets" / "dog.json", "text": "A corgi playing in the snow", }, { "name": "Landscape", "image": DEMO_DIR / "assets" / "mountain.jpg", "mask": DEMO_DIR / "assets" / "mountain.json", "text": "Snowy mountain peak.", }, { "name": "Architecture", "image": DEMO_DIR / "assets" / "building.jpg", "mask": DEMO_DIR / "assets" / "building.json", "text": "Modern glass skyscraper", } ] # Custom CSS for animations and layout st.markdown(""" """, unsafe_allow_html=True) def load_demo_assets(demo, config): """Load demo assets with error handling""" try: st.session_state.demo_image = process(Image.open(demo["image"]), config["resolution"]) st.session_state.original_image = np.array(st.session_state.demo_image) st.session_state.demo_text = demo["text"] if demo["mask"].exists(): with demo["mask"].open("r") as f: print(f"Loaded mask from {demo['mask']}") st.session_state.initial_drawing = json.load(f) breakpoint() else: st.warning(f"Mask not found for {demo['name']}") st.session_state.initial_drawing = None except Exception as e: st.error(f"Failed to load {demo['name']} demo: {str(e)}") def encode_image(file: Path | io.BytesIO | Image.Image) -> Dict[str, str]: if isinstance(file, Image.Image): buffered = io.BytesIO() file.save(buffered, format="JPEG") base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8") elif isinstance(file, Path): with file.open("rb") as img_file: base64_str = base64.b64encode(img_file.read()).decode("utf-8") else: base64_str = base64.b64encode(file.getvalue()).decode("utf-8") return {"url": f"data:image/jpeg;base64,{base64_str}"} def encode_array_image(array: np.ndarray) -> Dict[str, str]: im = Image.fromarray(array) if isinstance(array, np.ndarray) else array buffered = io.BytesIO() im.save(buffered, format="JPEG") base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8") return {"url": f"data:image/jpeg;base64,{base64_str}"} def get_boolean_mask(canvas_data): if canvas_data is None or canvas_data.image_data is None: return None, None mask_data = canvas_data.json_data.get("objects", []) if not mask_data: return np.zeros_like(st.session_state.original_image, dtype=np.uint8), None mask = np.zeros(st.session_state.original_image.shape[:2], dtype=np.uint8) for obj in mask_data: if obj.get("type") == "path": path = obj.get("path") # Custom processing of the path could be added here return mask * 255, None # Initialize session state variables if "demo_image" not in st.session_state: st.session_state.demo_image = None if "demo_text" not in st.session_state: st.session_state.demo_text = "" if "initial_drawing" not in st.session_state: st.session_state.initial_drawing = None if "original_image" not in st.session_state: st.session_state.original_image = None if "stroke_image" not in st.session_state: st.session_state.stroke_image = None if "response" not in st.session_state: st.session_state.response = None # Main UI title and demo selection st.title("Image + Text Input Demo") # Add configuration options in sidebar before any processing st.sidebar.header("Configuration") config = { "max_tokens": st.sidebar.number_input("Max Tokens", value=32, min_value=1, key="max_tokens"), "resolution": st.sidebar.number_input("Resolution", value=256, min_value=64, key="resolution"), "sampling_steps": st.sidebar.number_input("Sampling Steps", value=32, min_value=1, key="sampling_steps"), "top_p": st.sidebar.number_input("Top P", value=0.95, min_value=0.0, max_value=1.0, key="top_p"), "temperature": st.sidebar.number_input("Temperature", value=0.9, min_value=0.0, max_value=2.0, key="temperature"), "maskgit_r_temp": st.sidebar.number_input("MaskGit R Temp", value=4.5, min_value=0.0, key="maskgit_r_temp"), "cfg": st.sidebar.number_input("CFG", value=2.5, min_value=0.0, key="cfg"), "sampler": st.sidebar.selectbox( "Sampler", options=["maskgit", "maskgit_nucleus", "ddpm_cache"], index=1, key="sampler" ), "save_mask_enabled": True } st.subheader("Example Inputs") with st.container(): cols = st.columns(len(DEMOS)) for col, demo in zip(cols, DEMOS): with col: try: demo_html = f"""