File size: 4,133 Bytes
24a4bd7
 
 
 
 
 
03e7772
8fd199a
9549eae
03e7772
8fd199a
24a4bd7
 
 
 
 
03e7772
 
24a4bd7
8fd199a
 
03e7772
24a4bd7
 
03e7772
24a4bd7
 
03e7772
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fd199a
 
03e7772
 
 
 
24a4bd7
03e7772
24a4bd7
 
 
03e7772
24a4bd7
03e7772
24a4bd7
 
 
 
 
 
 
8fd199a
 
24a4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import sqlite3
import threading
import time
import os
import shutil
from gradio_client import Client, handle_file
# import spaces

# --- Job database -------------------------------------------------------------
# A single SQLite connection (plus one module-level cursor) shared by the Gradio
# callbacks and the daemon worker threads; check_same_thread=False is what lets
# the connection be used from threads other than the one that created it.
conn = sqlite3.connect('/tmp/jobs.db', check_same_thread=False)
c = conn.cursor()

# NOTE(review): the job_id column is not written by any code visible here; rows
# are identified by their INTEGER PRIMARY KEY `id`.
_JOBS_SCHEMA = '''CREATE TABLE IF NOT EXISTS jobs
             (id INTEGER PRIMARY KEY, image_path TEXT, job_id TEXT, status TEXT, output_path TEXT)'''
c.execute(_JOBS_SCHEMA)
conn.commit()

# --- TRELLIS API client ---------------------------------------------------------
# gradio_client connection to the "jkorstad/TRELLIS" Space, created at import time.
trellis_client = Client("jkorstad/TRELLIS")


# @spaces.GPU
# Processing logic with three-step TRELLIS workflow
def process_job(job_id, client=None):
    """Run one job through the TRELLIS preprocess -> image_to_3d -> extract_glb steps.

    Before each step the job's ``status`` column is set to ``'preprocessing'``,
    ``'generating'`` and ``'extracting'`` respectively. On success the GLB file
    returned by ``/extract_glb`` is moved to ``/tmp/outputs/result_<job_id>.glb``
    and the row is marked ``'completed'`` with that ``output_path``. Any exception
    marks the row ``'failed'`` and is printed rather than raised, because this
    function is the target of a daemon worker thread.

    Args:
        job_id: ``id`` of the row in the ``jobs`` table to process.
        client: Optional gradio_client ``Client`` for the TRELLIS Space. When
            omitted, a new client is created for this job so that concurrent
            jobs do not share one remote session.
    """

    def set_status(status):
        # conn.execute() creates a fresh cursor per statement. The shared module
        # cursor `c` is also used by the Gradio handlers and other worker threads,
        # so an execute()/fetch pair on it can be interleaved by another thread.
        conn.execute("UPDATE jobs SET status=? WHERE id=?", (status, job_id))
        conn.commit()

    try:
        # Get the uploaded image path
        row = conn.execute("SELECT image_path FROM jobs WHERE id=?", (job_id,)).fetchone()
        if row is None:
            raise LookupError(f"job {job_id} not found")
        image_path = row[0]

        # Step 1: Preprocess the image
        set_status('preprocessing')
        if client is None:
            # /extract_glb below receives no reference to the generated asset, so
            # the Space presumably keeps it in per-session state. A shared client
            # means one session for all worker threads, letting concurrent jobs
            # extract each other's mesh; a per-job client gets its own session.
            client = Client("jkorstad/TRELLIS")
        preprocessed_image = client.predict(
            image=handle_file(image_path),
            api_name="/preprocess_image",
        )

        # Step 2: Generate 3D asset. The returned value (a video preview) is not
        # needed here; the GLB is extracted in step 3.
        set_status('generating')
        time.sleep(5)  # Wait between steps; adjust based on observed timing
        client.predict(
            image=handle_file(preprocessed_image),  # Use preprocessed image
            multiimages=[],
            seed=0,  # Default; could make configurable
            ss_guidance_strength=7.5,
            ss_sampling_steps=12,
            slat_guidance_strength=3,
            slat_sampling_steps=12,
            multiimage_algo="stochastic",
            api_name="/image_to_3d",
        )

        # Step 3: Extract GLB
        set_status('extracting')
        time.sleep(10)  # Wait for 3D processing; adjust as needed
        glb_result = client.predict(
            mesh_simplify=0.95,
            texture_size=1024,
            api_name="/extract_glb",
        )
        glb_path = glb_result[0]  # First element is the GLB filepath

        # Move the GLB returned by the client into the outputs directory
        os.makedirs('/tmp/outputs', exist_ok=True)
        output_path = f'/tmp/outputs/result_{job_id}.glb'
        shutil.move(glb_path, output_path)

        conn.execute(
            "UPDATE jobs SET status='completed', output_path=? WHERE id=?",
            (output_path, job_id),
        )
        conn.commit()

    except Exception as e:
        # Top-level boundary of a worker thread: record the failure and log it.
        set_status('failed')
        print(f"Error processing job {job_id}: {e}")

# Gradio interface
def submit_images(files):
    """Queue each uploaded image as a job and start a worker thread for it.

    For every file a ``jobs`` row is inserted with status ``'submitted'``, the
    upload is copied to ``/tmp/inputs/input_<id><ext>`` (keeping the upload's
    own extension, ``.jpg`` if it has none), its path is stored in
    ``image_path``, and ``process_job(id)`` is started on a daemon thread.

    Args:
        files: Uploads from the ``gr.File(file_count="multiple")`` component,
            given either as path strings / path-like objects or as objects
            exposing the path via ``.name``.

    Returns:
        A status message for the Upload tab.

    Raises:
        OSError: if copying an upload fails. The job's row is marked
            ``'failed'`` first, so it is not left stuck in ``'submitted'``.
    """
    if not files:
        return "No files uploaded."
    os.makedirs('/tmp/inputs', exist_ok=True)
    for file in files:
        if isinstance(file, (str, os.PathLike)):
            source_path = os.fspath(file)
        else:
            source_path = file.name
        # Keep the real extension (e.g. .png) instead of labelling everything .jpg
        extension = os.path.splitext(source_path)[1] or '.jpg'

        # Read lastrowid from the cursor returned for this INSERT, not from the
        # shared module cursor that worker threads may be executing on.
        job_id = conn.execute("INSERT INTO jobs (status) VALUES ('submitted')").lastrowid
        conn.commit()

        image_path = f'/tmp/inputs/input_{job_id}{extension}'
        try:
            shutil.copy(source_path, image_path)
        except OSError:
            conn.execute("UPDATE jobs SET status='failed' WHERE id=?", (job_id,))
            conn.commit()
            raise
        conn.execute("UPDATE jobs SET image_path=? WHERE id=?", (image_path, job_id))
        conn.commit()
        threading.Thread(target=process_job, args=(job_id,), daemon=True).start()
    return "Jobs submitted. Check the status tab."

def get_status():
    """Return every job row for the Status tab's DataFrame.

    Returns:
        A list of ``(id, image_path, status, output_path)`` tuples ordered by
        job id. ``image_path`` and ``output_path`` are ``None`` until set.
    """
    # Use a fresh cursor via conn.execute(). The shared module cursor may be in
    # use by a worker thread, and interleaved execute()/fetchall() on one cursor
    # could return that thread's results instead of this query's.
    return conn.execute(
        "SELECT id, image_path, status, output_path FROM jobs ORDER BY id"
    ).fetchall()

# Gradio UI: an "Upload" tab that queues jobs and a "Status" tab that lists them.
with gr.Blocks(title="TRELLIS 3D Generator") as demo:
    with gr.Tab("Upload"):
        upload_files = gr.File(label="Upload Images (JPG/PNG)", file_count="multiple")
        submit_button = gr.Button("Submit")
        submit_message = gr.Textbox(label="Message")
        # submit_images returns one confirmation string for the textbox.
        submit_button.click(submit_images, inputs=upload_files, outputs=submit_message)

    with gr.Tab("Status"):
        status_table = gr.DataFrame(
            label="Job Status",
            headers=["ID", "Image Path", "Status", "Output Path"],
        )
        refresh_button = gr.Button("Refresh")
        # get_status takes no inputs; the table is filled only when Refresh is clicked.
        refresh_button.click(get_status, outputs=status_table)

demo.launch()