|
import gradio as gr |
|
from gradio_client import Client |
|
from PIL import Image |
|
import os |
|
import time |
|
import traceback |
|
import asyncio |
|
|
|
|
|
# Hugging Face token used to authenticate against the Spaces below; None when
# the MY_API_KEY environment variable is not set.
api_key = os.environ.get('MY_API_KEY')

# Spaces that receive every prompt; one inference task is launched per entry.
repos = [
    "hsuwill000/LCM_SoteMix_OpenVINO_CPU_Space_TAESD",
    "hsuwill000/LCM_SoteMix_OpenVINO_CPU_Space_TAESD_0",
]
|
|
|
class CustomClient(Client):
    """``gradio_client.Client`` subclass that forces a fixed timeout on ``_request``.

    Every call routed through ``_request`` has its ``timeout`` keyword replaced
    with ``self.timeout``.

    NOTE(review): this only has an effect if the installed gradio_client version
    defines ``Client._request(method, url, **kwargs)`` and sends its HTTP traffic
    through it. Confirm against the pinned version; otherwise the override is
    never called.
    """

    def __init__(self, *args, timeout=30, **kwargs):
        """Store ``timeout`` and forward all other arguments to ``Client``.

        Args:
            timeout: Value injected as ``timeout=`` into every ``_request`` call.
        """
        # Assign before delegating. If the parent constructor already issues
        # requests through the overridden ``_request``, the attribute must exist.
        self.timeout = timeout
        super().__init__(*args, **kwargs)

    def _request(self, method, url, **kwargs):
        """Call the parent ``_request`` with ``timeout`` overridden by ``self.timeout``."""
        kwargs['timeout'] = self.timeout
        return super()._request(method, url, **kwargs)
|
|
|
|
|
# Running index used to build saved image file names (img_XXXXXXXX.jpg);
# infer_single_gradio advances it past names that already exist on disk.
count: int = 0
|
|
|
async def infer_single_gradio(client, prompt):
    """Request one image from ``client``'s ``/infer`` endpoint and save it locally.

    The image is written to the first free ``img_XXXXXXXX.jpg`` name in the
    current working directory.

    Args:
        client: A connected gradio client. Its ``predict`` is called with the
            inputs dict and ``api_name="/infer"``.
        prompt: Text prompt forwarded to the endpoint.

    Returns:
        The generated ``PIL.Image.Image``, or ``None`` if prediction, decoding or
        saving raised. In that case the error and stack trace are printed.
    """
    global count

    inputs = {
        "prompt": prompt,
        "num_inference_steps": 10,
    }

    try:
        # predict() is synchronous; run it in a worker thread so the requests to
        # the different Spaces can overlap on the event loop.
        result = await asyncio.to_thread(client.predict, inputs, api_name="/infer")

        # NOTE(review): assumes the endpoint returns a local file path that Pillow
        # can open. Confirm against the Space's API.
        image = Image.open(result)

        # Pick the first unused sequential name. No await occurs between this
        # check and the save below, so other tasks on the event loop cannot claim
        # the same name in between.
        filename = f"img_{count:08d}.jpg"
        while os.path.exists(filename):
            count += 1
            filename = f"img_{count:08d}.jpg"

        # JPEG cannot store alpha or palette data (e.g. RGBA/P images). Convert
        # only the copy written to disk, so the save does not raise and discard
        # an image that was generated successfully.
        if image.mode in ("RGB", "L", "CMYK"):
            to_save = image
        else:
            to_save = image.convert("RGB")
        to_save.save(filename)
        print(f"Saved image as {filename}")

        # Move past the name just written so the next call starts at a free slot.
        count += 1

        return image

    except Exception as e:
        # Best-effort: report and return None so sibling Spaces still deliver.
        print(f"An exception occurred: {str(e)}")
        print("Stack trace:")
        traceback.print_exc()
        return None
|
|
|
async def infer_gradio(prompt: str):
    """Generate images for ``prompt`` from every Space in ``repos`` concurrently.

    Args:
        prompt: Text prompt forwarded to each Space.

    Returns:
        A list with one PIL image per Space that succeeded. Spaces that could not
        be connected to, or whose inference failed, are skipped and their errors
        printed. This keeps ``None`` entries out of the Gallery output, which
        cannot render them.
    """

    async def _connect(repo):
        # Constructing the client contacts the Space synchronously. Do it in a
        # worker thread so the event loop is not blocked, and trap errors so one
        # unreachable Space does not abort the whole request.
        try:
            return await asyncio.to_thread(
                CustomClient, repo, hf_token=api_key, timeout=300
            )
        except Exception as e:
            print(f"Could not connect to {repo}: {e}")
            traceback.print_exc()
            return None

    clients = await asyncio.gather(*(_connect(repo) for repo in repos))

    results = await asyncio.gather(
        *(infer_single_gradio(client, prompt) for client in clients if client is not None)
    )

    # infer_single_gradio returns None on failure; drop those entries.
    return [image for image in results if image is not None]
|
|
|
|
|
# UI: a prompt box with a RUN button, and a gallery showing the images returned
# by infer_gradio (one per Space that succeeded).
with gr.Blocks() as demo:
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter Your Prompt",
            # Must be the boolean False: the string "False" is truthy and would
            # leave the label visible.
            show_label=False,
            placeholder="Type your prompt for image generation here",
            lines=1,
            interactive=True,
        )

        run_button = gr.Button("RUN")

    output_images = gr.Gallery(label="Generated Images", elem_id="gallery", show_label=False)

    run_button.click(infer_gradio, inputs=prompt_input, outputs=output_images)

# Launching at import time is intentional for a Space-style app script.
demo.launch()
|
|