Spaces:
Sleeping
Sleeping
File size: 4,133 Bytes
24a4bd7 03e7772 8fd199a 9549eae 03e7772 8fd199a 24a4bd7 03e7772 24a4bd7 8fd199a 03e7772 24a4bd7 03e7772 24a4bd7 03e7772 8fd199a 03e7772 24a4bd7 03e7772 24a4bd7 03e7772 24a4bd7 03e7772 24a4bd7 8fd199a 24a4bd7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import gradio as gr
import sqlite3
import threading
import time
import os
import shutil
from gradio_client import Client, handle_file
# import spaces
# Database setup.
# A single SQLite file under /tmp tracks every job. check_same_thread=False lets
# the Gradio handlers and the worker threads started by submit_images share this
# one connection; `c` is the module-wide cursor the original handlers use.
# NOTE(review): the `job_id` column appears never to be written in this file.
_JOBS_SCHEMA = '''CREATE TABLE IF NOT EXISTS jobs
(id INTEGER PRIMARY KEY, image_path TEXT, job_id TEXT, status TEXT, output_path TEXT)'''

conn = sqlite3.connect('/tmp/jobs.db', check_same_thread=False)
c = conn.cursor()
c.execute(_JOBS_SCHEMA)
conn.commit()
# TRELLIS API client.
# One Client instance for the public TRELLIS Space; every job reuses it.
_TRELLIS_SPACE_ID = "jkorstad/TRELLIS"
trellis_client = Client(_TRELLIS_SPACE_ID)
# Serializes the three TRELLIS calls across worker threads. Every job shares the
# single `trellis_client`, and `/extract_glb` is called without any handle to the
# `/image_to_3d` result, so it presumably reads server-side state kept for this
# client's session; two interleaved jobs could otherwise extract each other's GLB.
_trellis_pipeline_lock = threading.Lock()


def _set_job_status(job_id, status, output_path=None):
    """Write ``status`` (and ``output_path`` when given) for one job and commit.

    Uses a fresh cursor from ``conn.execute`` rather than the shared module
    cursor ``c``, so concurrent threads don't reset each other's statements.
    """
    if output_path is None:
        conn.execute("UPDATE jobs SET status=? WHERE id=?", (status, job_id))
    else:
        conn.execute(
            "UPDATE jobs SET status=?, output_path=? WHERE id=?",
            (status, output_path, job_id),
        )
    conn.commit()


# @spaces.GPU
# Processing logic with three-step TRELLIS workflow
def process_job(job_id):
    """Run the TRELLIS image -> GLB pipeline for one row of the ``jobs`` table.

    Reads the job's ``image_path``, then calls the TRELLIS Space endpoints
    ``/preprocess_image``, ``/image_to_3d`` and ``/extract_glb`` in order,
    setting ``status`` to ``preprocessing`` / ``generating`` / ``extracting``
    as each step starts. On success the GLB is moved to
    ``/tmp/outputs/result_<job_id>.glb`` and the row becomes ``completed`` with
    that ``output_path``. Any exception marks the row ``failed`` and is printed
    rather than raised, because this runs in a daemon worker thread.

    Args:
        job_id: Primary key of the job row.
    """
    try:
        row = conn.execute(
            "SELECT image_path FROM jobs WHERE id=?", (job_id,)
        ).fetchone()
        if row is None or row[0] is None:
            raise ValueError(f"no image_path recorded for job {job_id}")
        image_path = row[0]

        # The status stays 'submitted' while this job waits for the lock.
        with _trellis_pipeline_lock:
            # Step 1: Preprocess the image
            _set_job_status(job_id, "preprocessing")
            preprocessed_image = trellis_client.predict(
                image=handle_file(image_path),
                api_name="/preprocess_image",
            )

            # Step 2: Generate 3D asset. The return value (formerly read into an
            # unused `video_path`) is not needed: step 3 presumably works from
            # the state the Space stored for this session.
            _set_job_status(job_id, "generating")
            # NOTE(review): `predict` blocks until the remote call finishes, so
            # these pauses are likely unnecessary; kept from the original tuning.
            time.sleep(5)
            trellis_client.predict(
                image=handle_file(preprocessed_image),  # Use preprocessed image
                multiimages=[],
                seed=0,  # Default; could make configurable
                ss_guidance_strength=7.5,
                ss_sampling_steps=12,
                slat_guidance_strength=3,
                slat_sampling_steps=12,
                multiimage_algo="stochastic",
                api_name="/image_to_3d",
            )

            # Step 3: Extract GLB
            _set_job_status(job_id, "extracting")
            time.sleep(10)
            glb_result = trellis_client.predict(
                mesh_simplify=0.95,
                texture_size=1024,
                api_name="/extract_glb",
            )

        glb_path = glb_result[0]  # First element is the GLB filepath
        # Move GLB to persistent storage
        output_path = f'/tmp/outputs/result_{job_id}.glb'
        os.makedirs('/tmp/outputs', exist_ok=True)
        shutil.move(glb_path, output_path)
        _set_job_status(job_id, "completed", output_path)
    except Exception as e:
        # Top-level boundary of a worker thread: record the failure and log it.
        _set_job_status(job_id, "failed")
        print(f"Error processing job {job_id}: {e}")
# Gradio interface
def submit_images(files):
    """Create one job per uploaded file and start a worker thread for each.

    Each upload is copied to ``/tmp/inputs/input_<job_id><ext>``, where the
    upload's own extension is kept (``.jpg`` if it has none). That path is
    recorded on the job row, and ``process_job`` is then started in a daemon
    thread. If an upload cannot be copied, its row is marked ``failed`` instead
    of being left in ``submitted``.

    Args:
        files: Value of the multi-file ``gr.File`` input. Items may be path
            strings or objects exposing the uploaded temp-file path as
            ``.name``. ``None`` or an empty list means nothing was uploaded.

    Returns:
        A message for the ``Message`` textbox.
    """
    if not files:
        return "No files uploaded."
    os.makedirs('/tmp/inputs', exist_ok=True)
    for file in files:
        src_path = getattr(file, "name", file)
        ext = os.path.splitext(src_path)[1] or '.jpg'
        # Fresh cursor per statement; the shared module cursor `c` is also used
        # by worker threads, which could clobber it between execute and lastrowid.
        job_id = conn.execute(
            "INSERT INTO jobs (status) VALUES ('submitted')"
        ).lastrowid
        conn.commit()
        image_path = f'/tmp/inputs/input_{job_id}{ext}'
        try:
            shutil.copy(src_path, image_path)
        except OSError as e:
            conn.execute("UPDATE jobs SET status='failed' WHERE id=?", (job_id,))
            conn.commit()
            print(f"Error copying upload for job {job_id}: {e}")
            continue
        conn.execute(
            "UPDATE jobs SET image_path=? WHERE id=?", (image_path, job_id)
        )
        conn.commit()
        threading.Thread(target=process_job, args=(job_id,), daemon=True).start()
    return "Jobs submitted. Check the status tab."
def get_status():
    """Return every job as ``(id, image_path, status, output_path)`` tuples.

    Feeds the "Job Status" DataFrame, ordered by job id. Uses a fresh cursor
    from ``conn.execute`` instead of the shared module cursor ``c``. With the
    shared cursor, a worker thread executing on ``c`` between this SELECT and
    the fetch would replace the pending result set.
    """
    return conn.execute(
        "SELECT id, image_path, status, output_path FROM jobs ORDER BY id"
    ).fetchall()
# UI: an "Upload" tab that submits jobs and a "Status" tab that lists every job
# row when "Refresh" is clicked.
with gr.Blocks(title="TRELLIS 3D Generator") as demo:
    with gr.Tab("Upload"):
        files_input = gr.File(
            file_count="multiple",
            file_types=[".jpg", ".jpeg", ".png"],  # enforce what the label promises
            label="Upload Images (JPG/PNG)",
        )
        submit_btn = gr.Button("Submit")
        output_msg = gr.Textbox(label="Message")
        submit_btn.click(fn=submit_images, inputs=files_input, outputs=output_msg)
    with gr.Tab("Status"):
        status_table = gr.DataFrame(
            headers=["ID", "Image Path", "Status", "Output Path"],
            label="Job Status",
            interactive=False,  # read-only view of the jobs table
        )
        refresh_btn = gr.Button("Refresh")
        refresh_btn.click(fn=get_status, inputs=None, outputs=status_table)

if __name__ == "__main__":
    demo.launch()