# NOTE: import order is significant. The star imports come first so that the
# explicit imports below (Path, Image, ...) win over any same-named export.
from fasthtml.common import *
from fasthtml.svg import *
from monsterui.all import *
from pathlib import Path
import requests
import base64
from PIL import Image
import numpy as np
import io
import json

# Feature flags for developer-only UI (extra button / extra form fields).
SHOW_DEV_BUTTONS = True
# Directory containing this file; demo assets are resolved relative to it.
DEMO_DIR = Path(__file__).parent
ADD_DEV_FORM = True


def list_ckpt_files(dir: Path):
    """Return the sorted paths of every file or symlink found under *dir*.

    The search is recursive. A directory that does not exist yields an empty
    list instead of raising.

    Args:
        dir: Directory to scan; it is made absolute before globbing, so the
            returned strings are absolute paths.

    Returns:
        A sorted list of path strings.
    """
    root = dir.absolute()
    if not root.exists():
        return []
    return sorted(
        str(entry)
        for entry in root.glob("**/*")
        if entry.is_file() or entry.is_symlink()
    )


# Startup diagnostics: show which checkpoints and demo assets are visible.
print(f'CKPT files: {list_ckpt_files(Path("ckpts"))}')
print(f'Demo assets: {list_ckpt_files(DEMO_DIR / "assets")}')

# Demo presets. Entries without an "image" key are text-only.
# NOTE(review): some "text" values read as if placeholder tokens were lost
# (e.g. "The in ."); they are kept verbatim here - confirm the intended text.
DEMOS = [
    {
        "name": "Dog",
        "image": DEMO_DIR / "assets" / "dog.jpg",
        "mask": DEMO_DIR / "assets" / "dog.json",
        "text": "A brown bulldog.",
    },
    {
        "name": "Pickup Truck",
        "image": DEMO_DIR / "assets" / "pickup.jpg",
        "mask": DEMO_DIR / "assets" / "pickup.json",
        "text": "A pickup truck.",
    },
    {
        "name": "Taj Mahal",
        "image": DEMO_DIR / "assets" / "tajmahal.jpg",
        "mask": DEMO_DIR / "assets" / "tajmahal.json",
        "text": "The in .",
    },
    {
        "name": "Venice",
        "image": DEMO_DIR / "assets" / "venice.jpg",
        "mask": DEMO_DIR / "assets" / "venice.json",
        "text": "A in.",
    },
    {
        "name": "T2I",
        "text": "A sits at the counter of an art-deco loungebar, drinking whisky from a tumbler glass.",
    }
]

# Use MonsterUI's theme headers.
app, rt = fast_app(hdrs=(Theme.blue.headers(),))

# (connect, read) timeout in seconds for calls to the inference server.
# Generation can be slow, hence the generous read timeout.
API_TIMEOUT_SECONDS = (5, 600)

# Cache of processed demo thumbnails (data URLs), keyed by index into DEMOS,
# so the images are not re-cropped and re-encoded on every request.
_DEMO_IMAGE_URLS: dict[int, str] = {}


def square_crop(image: Image.Image) -> Image.Image:
    """Return the largest centered square region of *image*."""
    width, height = image.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return image.crop((left, top, left + side, top + side))


def process(image: Image.Image, desired_resolution: int = 512) -> Image.Image:
    """Convert *image* to RGB, center-crop it to a square and resize it.

    Args:
        image: Source image in any PIL mode.
        desired_resolution: Side length, in pixels, of the returned image.

    Returns:
        A ``desired_resolution`` x ``desired_resolution`` RGB image.
    """
    cropped_image = square_crop(image.convert("RGB"))
    return cropped_image.resize((desired_resolution, desired_resolution), Image.LANCZOS)


def encode_image(file: Path | io.BytesIO | Image.Image) -> dict[str, str]:
    """Encode an image as a base64 ``data:image/jpeg`` URL.

    Args:
        file: A PIL image (re-encoded as JPEG), a path whose bytes are read
            as-is, or a ``BytesIO`` whose contents are used as-is.

    Returns:
        ``{"url": "data:image/jpeg;base64,..."}``. Path and ``BytesIO``
        inputs are labelled as JPEG without being inspected.
    """
    if isinstance(file, Image.Image):
        buffered = io.BytesIO()
        file.save(buffered, format="JPEG")
        raw = buffered.getvalue()
    elif isinstance(file, Path):
        raw = file.read_bytes()
    else:
        raw = file.getvalue()
    base64_str = base64.b64encode(raw).decode("utf-8")
    return {"url": f"data:image/jpeg;base64,{base64_str}"}


def encode_array_image(array: np.ndarray) -> dict[str, str]:
    """Encode a NumPy array as a base64 JPEG data URL.

    Boolean arrays are mapped to 0/255 ``uint8`` first so a mask is rendered
    as black and white.
    """
    if array.dtype == bool:
        array = array.astype(np.uint8) * 255
    buffered = io.BytesIO()
    Image.fromarray(array).save(buffered, format="JPEG", quality=95)
    base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return {"url": f"data:image/jpeg;base64,{base64_str}"}


def get_boolean_mask(mask_data: str) -> np.ndarray:
    """Decode the bit-packed mask produced by the client.

    Args:
        mask_data: JSON text with keys ``data`` (base64 of the bit-packed
            pixels, most significant bit first, row-major), ``width`` and
            ``height``.

    Returns:
        A boolean array of shape ``(height, width)``.

    Raises:
        ValueError: If the JSON or base64 is malformed, or the payload holds
            fewer than ``width * height`` bits.
        KeyError: If a required key is missing.
    """
    mask_info = json.loads(mask_data)
    packed = np.frombuffer(base64.b64decode(mask_info['data']), dtype=np.uint8)
    width, height = mask_info['width'], mask_info['height']
    # The packed buffer is padded to a whole byte; drop the padding bits.
    bits = np.unpackbits(packed, bitorder='big')[:width * height]
    if bits.size != width * height:
        raise ValueError(
            f"mask payload holds {bits.size} bits, expected {width * height}"
        )
    return bits.reshape((height, width)).astype(bool)


def get_input_card_params():
    """Return the keyword arguments shared by every rendering of the input card."""
    return dict(
        header=Div(H4("Input"), Subtitle("You can mask the image, text, or both.")),
        id="input-card",
        title="Input",
    )


def create_input_card_content(text_content=""):
    """Create the shared input card content structure.

    Args:
        text_content: Initial text of the prompt text area.

    Returns:
        A list with the preview container, the prompt text area, the file
        input and the hidden field that carries the encoded mask.
    """
    return [
        Div(id="preview-container", cls="relative flex justify-center items-center mb-4 p-4 empty:p-0 empty:mb-0"),
        TextArea(text_content, name="user_input", id="user-input-text", cls="resize-none h-12 w-full mb-4"),
        Input(type="file", name="uploaded_file", id="upload-image-input", cls="mb-4"),
        Input(type="hidden", name="mask_data", id="mask-data"),
    ]


def _demo_image_url(demo_idx: int) -> str:
    """Return the cached data URL of the processed image of ``DEMOS[demo_idx]``."""
    if demo_idx not in _DEMO_IMAGE_URLS:
        with Image.open(DEMOS[demo_idx]['image']) as source:
            _DEMO_IMAGE_URLS[demo_idx] = encode_image(process(source))['url']
    return _DEMO_IMAGE_URLS[demo_idx]


def _error_card(message: str):
    """Build the card used to report a failure in the output pane."""
    return Card(P(message), cls="response-card destructive mb-4", title="Error")


def _render_output(components):
    """Wrap result cards in the container that replaces ``#output-content``."""
    return Div(*components, id="output-content", cls="space-y-4 flex flex-col")


@rt("/")
def get(session):
    """Render the main page: demo gallery, input/output cards and config form.

    Only demos that define an ``image`` get a gallery card; text-only demos
    are skipped because a card needs a thumbnail.
    """
    demo_cards = []
    for demo_idx, demo in enumerate(DEMOS):
        if 'image' not in demo:
            continue
        spinner_id = f"demo-spinner-{demo_idx}"
        inner_content = Div(
            Div(
                Loading(cls="hidden", htmx_indicator=True),
                id=spinner_id,
                cls="absolute inset-0 flex items-center justify-center"
            ),
            Div(
                Img(src=_demo_image_url(demo_idx), cls="w-32 h-32 object-cover rounded-md transition-opacity hover:opacity-60 cursor-pointer mb-3"),
                cls="demo-image-container relative flex justify-center"
            ),
            P(demo['text'], cls="mt-2 text-sm text-muted-foreground group-hover:text-foreground transition-colors text-center"),
            cls="flex flex-col items-center p-1"
        )
        demo_cards.append(Card(
            inner_content,
            cls="demo-card hover:shadow-md transition-shadow cursor-pointer w-fit mx-auto",
            title=demo['name'],
            hx_post=f"/load_demo/{demo_idx}",
            hx_target="#input-card",
            hx_swap="innerHTML",
            hx_indicator=f"#{spinner_id}"
        ))

    # NOTE(review): the session values interpolated below (demo_mask,
    # demo_image) can be tens of kilobytes; if the session is cookie-backed
    # they may never persist. Confirm the session backend before relying on them.
    js_script = fr"""
    document.body.addEventListener('htmx:beforeRequest', function(ev) {{
        const target = ev.detail.elt.querySelector('[hx-indicator]');
        const spinner = target ? target.querySelector('.loading') : null;
        if (spinner) spinner.classList.remove('hidden');
    }});

    document.body.addEventListener('htmx:afterRequest', function(ev) {{
        const target = ev.detail.elt.querySelector('[hx-indicator]');
        const spinner = target ? target.querySelector('.loading') : null;
        if (spinner) spinner.classList.add('hidden');
    }});

    // Paint a bit-packed mask (MSB first, row-major) onto the canvas alpha
    // channel and refresh the hidden mask field.
    function applyMaskToCanvas(canvas, maskInfo) {{
        const ctx = canvas.getContext('2d');
        const data = atob(maskInfo.data);
        const arr = new Uint8Array(data.length);
        for (let i = 0; i < data.length; i++) {{
            arr[i] = data.charCodeAt(i);
        }}
        const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
        for (let i = 0; i < arr.length * 8; i++) {{
            const byteIndex = Math.floor(i / 8);
            const bitIndex = 7 - (i % 8);
            if (arr[byteIndex] & (1 << bitIndex)) {{
                const x = i % canvas.width;
                const y = Math.floor(i / canvas.width);
                imageData.data[(y * canvas.width + x) * 4 + 3] = 255;
            }}
        }}
        ctx.putImageData(imageData, 0, 0);
        updateMaskData(canvas);
    }}

    // Restore a mask remembered in the session, but only when a drawing
    // canvas already exists (there is none on a fresh page load).
    const demoMaskData = {json.dumps(session.get('demo_mask'))} || undefined;
    const existingCanvas = document.querySelector('#preview-container canvas');
    if (typeof demoMaskData !== 'undefined' && demoMaskData !== null && existingCanvas) {{
        const maskInfo = (typeof demoMaskData === 'string') ? JSON.parse(demoMaskData) : demoMaskData;
        applyMaskToCanvas(existingCanvas, maskInfo);
    }}

    function downloadMaskData() {{
        const maskData = document.getElementById('mask-data').value;
        if (!maskData) return;
        const blob = new Blob([maskData], {{type: 'application/json'}});
        const link = document.createElement('a');
        link.href = URL.createObjectURL(blob);
        link.download = `mask_${{Date.now()}}.json`;
        link.click();
    }}

    function initializeCanvas(img, wrapper) {{
        const DISPLAY_SIZE = 256; // fixed display size in pixels
        // Set the preview image to the fixed size
        img.style.width = DISPLAY_SIZE + "px";
        img.style.height = DISPLAY_SIZE + "px";

        const canvas = document.createElement('canvas');
        // Use our fixed display size for the canvas dimensions
        canvas.width = DISPLAY_SIZE;
        canvas.height = DISPLAY_SIZE;
        canvas.style.position = 'absolute';
        canvas.style.top = '0';
        canvas.style.left = '0';
        canvas.style.cursor = 'crosshair';

        const ctx = canvas.getContext('2d');
        // Compute a scale factor relative to the image's natural dimensions
        const scaleFactor = DISPLAY_SIZE / Math.max(img.naturalWidth, img.naturalHeight);
        ctx.strokeStyle = 'black';
        ctx.lineWidth = 35 * scaleFactor; // adjust line width proportionally

        let drawing = false;
        canvas.addEventListener('mousedown', e => {{
            drawing = true;
            ctx.beginPath();
            ctx.moveTo(e.offsetX, e.offsetY);
        }});
        canvas.addEventListener('mousemove', e => {{
            if (drawing) {{
                ctx.lineTo(e.offsetX, e.offsetY);
                ctx.stroke();
            }}
        }});
        canvas.addEventListener('mouseup', e => {{
            drawing = false;
            updateMaskData(canvas);
        }});
        canvas.addEventListener('mouseleave', e => {{
            if (drawing) {{
                drawing = false;
                updateMaskData(canvas);
            }}
        }});

        wrapper.appendChild(canvas);
        return canvas;
    }}

    function updateMaskData(canvas) {{
        const ctx = canvas.getContext('2d');
        const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
        const data = imageData.data;
        const buffer = new Uint8Array(Math.ceil((canvas.width * canvas.height) / 8));

        for (let i = 0; i < data.length; i += 4) {{
            const pixelIndex = i / 4;
            const byteIndex = Math.floor(pixelIndex / 8);
            const bitIndex = 7 - (pixelIndex % 8);
            if (data[i + 3] > 0) {{
                buffer[byteIndex] |= (1 << bitIndex);
            }}
        }}

        const base64 = btoa(String.fromCharCode(...buffer));
        document.getElementById('mask-data').value = JSON.stringify({{
            data: base64,
            width: canvas.width,
            height: canvas.height
        }});
    }}

    function clearMask() {{
        const canvas = document.querySelector('#preview-container canvas');
        if (canvas) {{
            const ctx = canvas.getContext('2d');
            ctx.clearRect(0, 0, canvas.width, canvas.height);
            updateMaskData(canvas);
        }}
        document.getElementById('mask-data').value = '';
    }}

    function clearImage() {{
        // Clear file input and preview
        const fileInput = document.getElementById('upload-image-input');
        fileInput.value = ''; // Reset file input
        const previewContainer = document.getElementById('preview-container');
        previewContainer.innerHTML = ''; // Clear canvas and image
        document.getElementById('mask-data').value = ''; // Clear mask data

        // If there's a demo image, re-initialize it
        const demoImg = {json.dumps(session.get('demo_image', ''))};
        if (demoImg) {{
            const img = new Image();
            img.onload = function() {{
                const wrapper = document.createElement('div');
                wrapper.style.position = 'relative';
                wrapper.style.display = 'inline-block';
                initializeCanvas(img, wrapper);
                wrapper.appendChild(img);
                previewContainer.appendChild(wrapper);
            }};
            img.src = demoImg;
        }}
    }}

    // Helper function to square crop an image (crop centered)
    function squareCropImage(img) {{
        const side = Math.min(img.naturalWidth, img.naturalHeight);
        const left = (img.naturalWidth - side) / 2;
        const top = (img.naturalHeight - side) / 2;
        const offCanvas = document.createElement("canvas");
        offCanvas.width = side;
        offCanvas.height = side;
        const offCtx = offCanvas.getContext("2d");
        offCtx.drawImage(img, left, top, side, side, 0, 0, side, side);
        return offCanvas.toDataURL("image/jpeg");
    }}

    // Listen for file uploads and square crop the image before previewing.
    // Delegated on document because htmx replaces the file input whenever a
    // demo is loaded, which would drop a listener bound to the element itself.
    document.addEventListener('change', function(event) {{
        if (!event.target || event.target.id !== 'upload-image-input') return;
        const file = event.target.files[0];
        if (file) {{
            const objectUrl = URL.createObjectURL(file);
            const img = new Image();
            img.onload = function() {{
                // Square crop the loaded image
                const croppedDataUrl = squareCropImage(img);
                URL.revokeObjectURL(objectUrl);
                const croppedImg = new Image();
                croppedImg.onload = function() {{
                    const previewContainer = document.getElementById('preview-container');
                    previewContainer.innerHTML = '';

                    const wrapper = document.createElement('div');
                    wrapper.style.position = 'relative';
                    wrapper.style.display = 'inline-block';

                    croppedImg.style.display = 'block';
                    croppedImg.style.maxWidth = '100%';

                    initializeCanvas(croppedImg, wrapper);
                    wrapper.appendChild(croppedImg);
                    previewContainer.appendChild(wrapper);
                }};
                croppedImg.src = croppedDataUrl;
            }};
            img.src = objectUrl;
        }}
    }});
    """

    main_content = Container(
        Div(
            DivFullySpaced(
                Style("""
                    .top-left {
                        position: absolute;
                        top: 3%;
                        left: 2%;
                        /* Additional styling as needed */
                    }
                    .custom_middle {
                        position: relative;
                        top: 0%;
                        left: 50%;
                        transform: translate(-50%, 0%);
                        /* Additional styling as needed */
                    }
                """),
                H1("UniDisc Demo", cls="text-4xl font-light tracking-tight top-left"),
                Div(*demo_cards, cls="grid grid-cols-3 gap-4 max-w-5xl custom_middle"),
                cls="flex items-center justify-between mb-8 px-4"
            ),
            Form(
                Grid(
                    Card(
                        Div(*create_input_card_content()),
                        **get_input_card_params()
                    ),
                    Card(
                        Div(id="output-content", cls="space-y-4"),
                        header=Div(H4("Output")),
                        id="output-card",
                        title="Output"
                    ),
                    cls="grid grid-cols-2 gap-6 mb-0"
                ),
                CardFooter(
                    Grid(
                        Button(
                            Div(
                                Span("Submit", cls="submit-text"),
                                Loading(cls="hidden h-4 w-4 animate-spin", id='loading', htmx_indicator=True),
                                cls="flex gap-2 items-center justify-center"
                            ),
                            cls=(ButtonT.primary, 'w-full'),
                            hx_indicator="this .loading"
                        ),
                        Button("Clear Mask", type="button", cls=(ButtonT.primary, 'w-full'), onclick="clearMask()"),
                        Button("Clear Image", type="button", cls=(ButtonT.primary, 'w-full'), onclick="clearImage()"),
                        *([Button("Download Mask", type="button", cls=(ButtonT.primary, 'w-full', 'dev-only'), onclick="downloadMaskData()")] if SHOW_DEV_BUTTONS else []),
                        Button(
                            "Sampling Configs",
                            uk_toggle="target: #config-modal",
                            id="config-modal-button",
                            cls=(ButtonT.primary, 'w-full')
                        ),
                        cls="grid grid-cols-4 gap-2"
                    ),
                ),
                Card(
                    Grid(
                        Div(
                            LabelInput("Max Tokens", name="max_tokens", type="number", value=32, cls="w-full"),
                            cls="space-y-1.5"
                        ),
                        Div(
                            LabelSelect(
                                *Options(256, 512, 1024, selected_idx=1),
                                name="resolution",
                                label="Resolution",
                                cls="w-full",
                            ),
                            cls="space-y-1.5"
                        ),
                        Div(
                            LabelInput("Sampling Steps", name="sampling_steps", type="number", value=32, cls="w-full"),
                            cls="space-y-1.5"
                        ),
                        Div(
                            LabelInput("Top P", name="top_p", type="number", value=0.95, step="0.01", cls="w-full"),
                            cls="space-y-1.5"
                        ),
                        Div(
                            # NOTE(review): min_value/max_value are passed
                            # through as written; confirm they render as the
                            # HTML min/max attributes.
                            LabelInput("Temperature", name="temperature", type="number", value=0.9, step="0.1", min_value="0.0", max_value="2.0", cls="w-full"),
                            cls="space-y-1.5"
                        ),
                        Div(
                            LabelInput("MaskGit R Temp", name="maskgit_r_temp", type="number", value=4.5, step="0.1", cls="w-full"),
                            cls="space-y-1.5"
                        ),
                        Div(
                            LabelInput("CFG", name="cfg", type="number", value=2.5, step="0.1", cls="w-full"),
                            cls="space-y-1.5"
                        ),
                        Div(
                            LabelSelect(
                                *Options("maskgit", "maskgit_nucleus", "ddpm_cache", selected_idx=1),
                                name="sampler",
                                label="Sampler",
                                cls="w-full",
                            ),
                            cls="space-y-1.5"
                        ),
                        *([
                            Div(
                                # A port is an integer, so step in whole numbers.
                                LabelInput("Port", name="port", type="number", value=8001, step="1", cls="w-full"),
                                cls="space-y-1.5"
                            ),
                            Div(
                                LabelSelect(*Options("False", "True", selected_idx=0), name="reward_models", label="Reward Models", cls="w-full",),
                                cls="space-y-1.5"
                            )
                        ] if ADD_DEV_FORM else []),
                        Hidden(name="save_mask_enabled", value="True"),
                        cls="grid grid-cols-4 gap-4",
                    ),
                    cls="mb-6",
                    title="Configuration",
                    id="config-modal",
                    hidden=True
                ),
                hx_swap="innerHTML",
                hx_target="#output-content",
                hx_post="/submit",
                enctype="multipart/form-data",
                cls="mb-6"
            ),
        ),
        Script(js_script),
    )
    return main_content


@rt("/load_demo/{demo_index}")
def post(demo_index: int, session):
    """Return the input card pre-filled with the selected demo.

    The demo's image, text and mask are stored in the session, and a script
    is appended that previews the image, copies it into the file input and
    paints the mask. An out-of-range index returns an empty input card.
    """
    if not 0 <= demo_index < len(DEMOS):
        return Card(Div(*create_input_card_content()), **get_input_card_params())

    demo = DEMOS[demo_index]
    if 'image' in demo:
        session['demo_image'] = _demo_image_url(demo_index)
    if 'text' in demo:
        session['demo_text'] = demo['text']

    mask_path = demo.get('mask')
    if mask_path and Path(mask_path).exists():
        session['demo_mask'] = json.loads(Path(mask_path).read_text())
    else:
        session['demo_mask'] = None

    content = create_input_card_content(
        text_content=session.get('demo_text', ''),
    )

    mask_json = 'undefined' if not session['demo_mask'] else json.dumps(session['demo_mask'])
    demo_image_json = json.dumps(session.get('demo_image', ''))

    # `img` is assigned without const/let on purpose: this script is
    # evaluated again on every demo load, and re-declaring a top-level
    # const/let in the same page would throw.
    content.append(Script(fr"""
        img = new Image();
        img.onload = async function() {{
            const previewContainer = document.getElementById('preview-container');
            previewContainer.innerHTML = '';

            const wrapper = document.createElement('div');
            wrapper.style.position = 'relative';
            wrapper.style.display = 'inline-block';

            const canvas = initializeCanvas(img, wrapper);
            wrapper.appendChild(img);
            previewContainer.appendChild(wrapper);

            const dataUrl = {demo_image_json};
            const base64Data = dataUrl.split(',')[1];
            const byteCharacters = atob(base64Data);
            const byteArrays = [];
            for (let offset = 0; offset < byteCharacters.length; offset += 1024) {{
                const slice = byteCharacters.slice(offset, offset + 1024);
                const byteNumbers = new Array(slice.length);
                for (let i = 0; i < slice.length; i++) {{
                    byteNumbers[i] = slice.charCodeAt(i);
                }}
                const byteArray = new Uint8Array(byteNumbers);
                byteArrays.push(byteArray);
            }}
            const blob = new Blob(byteArrays, {{ type: 'image/jpeg' }});
            const file = new File([blob], "demo_image.jpg", {{
                type: 'image/jpeg',
                lastModified: Date.now()
            }});

            const dataTransfer = new DataTransfer();
            dataTransfer.items.add(file);
            const fileInput = document.getElementById('upload-image-input');
            fileInput.files = dataTransfer.files;
            fileInput.dispatchEvent(new Event('change'));

            const demoMaskData = {mask_json};
            if (typeof demoMaskData !== 'undefined' && demoMaskData !== null) {{
                applyMaskToCanvas(canvas, demoMaskData);
            }}
        }};
        img.src = {demo_image_json};
    """))

    return Card(
        Div(*content),
        **get_input_card_params()
    )


@rt("/submit")
def post(
    req,
    temperature: float,
    top_p: float,
    maskgit_r_temp: float,
    cfg: float,
    max_tokens: int,
    resolution: int,
    sampling_steps: int,
    sampler: str,
    user_input: str | None = None,
    mask_data: str | None = None,
    uploaded_file: UploadFile | None = None,
    port: int | None = 8001,
    reward_models: str | None = "False"
):
    """Send the prompt, image and mask to the inference server and render the reply.

    The text goes in a ``user`` message; the processed image and the mask go
    together in a second message as ``image_url`` parts flagged with
    ``is_mask``. Sampling settings are sent as top-level payload fields.

    Returns:
        The replacement ``#output-content`` element with one card per
        response part, or a single error card when the input cannot be
        decoded or the server call fails.
    """
    payload_messages = []
    if user_input:
        payload_messages.append({"role": "user", "content": [{"type": "text", "text": user_input}]})

    image_message_content = []
    if uploaded_file is not None and uploaded_file.filename != "No image":
        image_bytes = uploaded_file.file.read()
        # An untouched file input may arrive as an empty upload; treat it as
        # "no image" rather than failing to decode zero bytes.
        if image_bytes:
            try:
                current_image = process(Image.open(io.BytesIO(image_bytes)), int(resolution))
            except OSError as err:
                return _render_output([_error_card(f"Could not read the uploaded image: {err}")])
            image_message_content.append({
                "type": "image_url",
                "image_url": {"url": encode_image(current_image)["url"]},
                "is_mask": False
            })

    if mask_data:
        try:
            mask_array = get_boolean_mask(mask_data)
        except (KeyError, TypeError, ValueError) as err:
            return _render_output([_error_card(f"Could not decode the mask: {err}")])
        image_message_content.append({
            "type": "image_url",
            "image_url": {"url": encode_array_image(mask_array)["url"]},
            "is_mask": True
        })

    if image_message_content:
        payload_messages.append({"role": "assistant", "content": image_message_content})

    payload = {
        "messages": payload_messages,
        "model": "unidisc",
        "max_tokens": int(max_tokens),
        "resolution": int(resolution),
        "sampling_steps": int(sampling_steps),
        "top_p": float(top_p),
        "temperature": float(temperature),
        "maskgit_r_temp": float(maskgit_r_temp),
        "cfg": float(cfg),
        "sampler": sampler,
        "use_reward_models": reward_models == "True"
    }

    api_port = 8001 if port is None else port
    API_URL = f"http://localhost:{api_port}/v1/chat/completions"
    try:
        response = requests.post(API_URL, json=payload, timeout=API_TIMEOUT_SECONDS)
    except requests.RequestException as err:
        return _render_output([_error_card(f"API Error: {err}")])

    if response.status_code != 200:
        return _render_output([_error_card(f"API Error: {response.text}")])

    try:
        response_json = response.json()
    except ValueError:
        return _render_output([_error_card(f"API Error: {response.text}")])

    choices = response_json.get("choices") if isinstance(response_json, dict) else None
    if not choices:
        return _render_output([_error_card("API Error: response contained no choices.")])

    components = []
    content = choices[0]["message"]["content"]
    if isinstance(content, list):
        for part in content:
            part_type = part.get("type")
            if part_type == "text":
                components.append(Card(
                    P(part["text"], cls="p-4"),
                    cls="response-card mb-4",
                    title="Response"
                ))
            elif part_type == "image_url":
                components.append(
                    Card(
                        Div(
                            Img(
                                src=part["image_url"]["url"],
                                cls="w-64 h-64 object-cover rounded-md"
                            ),
                            cls="flex justify-center items-center p-4"
                        ),
                        cls="response-card mb-4"
                    )
                )
    else:
        components.append(Card(P(content, cls="p-4"), cls="response-card mb-4", title="Response"))

    return _render_output(components)


print("Before serve...")
serve(port=5003)