Spaces:
Running
Running
File size: 3,353 Bytes
0dfc187 7a13824 c248483 6298bb9 604dba1 c20305d 604dba1 1c5bbef 604dba1 a5537f1 1c5bbef 33179dc 604dba1 9398cd5 0dfc187 9398cd5 604dba1 0dfc187 6298bb9 604dba1 0dfc187 a5537f1 604dba1 0dfc187 6298bb9 0dfc187 a5537f1 0dfc187 604dba1 0dfc187 d7f02ec 604dba1 d7f02ec a5537f1 7e241d8 5469b0c 604dba1 5469b0c 604dba1 0dfc187 604dba1 7b9aa9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import gradio as gr
from gradio_client import Client
from PIL import Image
import os
import time
import traceback
import asyncio
# Hugging Face access token, taken from the MY_API_KEY environment variable.
# The value is None when the variable is not set.
api_key = os.environ.get('MY_API_KEY')

# Spaces (private) that each receive every prompt; one image is requested
# from each entry.
repos = [
    "hsuwill000/LCM_SoteMix_OpenVINO_CPU_Space_TAESD",
    "hsuwill000/LCM_SoteMix_OpenVINO_CPU_Space_TAESD_0",
]
class CustomClient(Client):
    """A ``Client`` that remembers a timeout and injects it into ``_request``.

    All positional and keyword arguments other than ``timeout`` are forwarded
    unchanged to ``Client.__init__``.
    """

    def __init__(self, *args, timeout=30, **kwargs):
        """Build the underlying client, then record ``timeout`` on the instance."""
        super().__init__(*args, **kwargs)
        self.timeout = timeout

    def _request(self, method, url, **kwargs):
        """Delegate to the parent ``_request`` with ``timeout`` forced to ours.

        Any ``timeout`` supplied by the caller is overwritten.
        """
        # NOTE(review): this only has an effect if the installed gradio_client
        # version actually routes its HTTP calls through ``_request`` — confirm.
        request_options = {**kwargs, "timeout": self.timeout}
        return super()._request(method, url, **request_options)
# Counter for image filenames to avoid overwriting
count = 0


async def infer_single_gradio(client, prompt):
    """Request one image from a Space, save a JPEG copy, and return the image.

    Args:
        client: Object exposing a blocking ``predict`` method (a
            ``CustomClient`` in this file). The call is run in a worker
            thread so it does not block the event loop.
        prompt: Text prompt forwarded to the ``/infer`` endpoint.

    Returns:
        The opened PIL image, or ``None`` when the request, decoding, or
        saving raised an exception (the error and stack trace are printed).
    """
    global count
    # NOTE(review): assumes the Space's "/infer" endpoint accepts this dict as
    # its single positional input — confirm against the Space's API page.
    inputs = {
        "prompt": prompt,
        "num_inference_steps": 10,  # Number of inference steps for the model
    }
    try:
        # Send the request to the model and receive the image
        result = await asyncio.to_thread(client.predict, inputs, api_name="/infer")
        image = Image.open(result)

        # Probe for the first unused img_XXXXXXXX.jpg name. There is no
        # ``await`` between this probe and the save below, so coroutines
        # sharing this event loop cannot pick the same name.
        filename = f"img_{count:08d}.jpg"
        while os.path.exists(filename):
            count += 1
            filename = f"img_{count:08d}.jpg"

        # JPEG cannot store alpha or palette data; without this conversion an
        # RGBA/P/LA result would make save() raise and the image would be
        # dropped. Only the on-disk copy is converted; the caller still gets
        # the image exactly as it was opened.
        if image.mode in ("L", "RGB", "CMYK"):
            jpeg_ready = image
        else:
            jpeg_ready = image.convert("RGB")
        jpeg_ready.save(filename)
        print(f"Saved image as {filename}")

        # Return the image to be displayed in Gradio
        return image
    except Exception as e:
        # Best-effort: report the failure and signal it to the caller as None
        print(f"An exception occurred: {str(e)}")
        print("Stack trace:")
        traceback.print_exc()  # Print stack trace for debugging
        return None
async def infer_gradio(prompt: str):
    """Generate one image per configured repo, concurrently.

    Args:
        prompt: Text prompt sent to every Space listed in ``repos``.

    Returns:
        A list of the images that were generated successfully, in ``repos``
        order. Repos that could not be reached or whose inference failed are
        omitted, so the list may be shorter than ``repos`` (or empty).
    """

    async def _generate(repo):
        # Connecting to a Space is blocking, so do it in a worker thread; this
        # keeps the event loop free and lets all repos connect in parallel.
        try:
            client = await asyncio.to_thread(
                CustomClient, repo, hf_token=api_key, timeout=300
            )
        except Exception as e:
            # One unreachable Space must not abort the other requests.
            print(f"Could not connect to {repo}: {e}")
            traceback.print_exc()
            return None
        return await infer_single_gradio(client, prompt)

    # Run all repos concurrently (i.e., generate images from all repos)
    results = await asyncio.gather(*(_generate(repo) for repo in repos))

    # infer_single_gradio reports failure as None; drop those entries so the
    # gallery only receives real images.
    return [image for image in results if image is not None]
# Define Gradio Interface
with gr.Blocks() as demo:
    with gr.Row():  # Use a Row to place the prompt input and the button side by side
        prompt_input = gr.Textbox(
            label="Enter Your Prompt",
            # Must be the boolean False: the string "False" is non-empty and
            # therefore truthy, which would leave the label visible.
            show_label=False,
            placeholder="Type your prompt for image generation here",
            lines=1,  # Set the input to be only one line tall
            interactive=True,  # Allow user to interact with the textbox
        )
        # Button labelled "RUN", placed next to the prompt input
        run_button = gr.Button("RUN")

    # Output image display area (will show multiple images)
    output_images = gr.Gallery(label="Generated Images", elem_id="gallery", show_label=False)

    # Connecting the button click to the image generation function
    run_button.click(infer_gradio, inputs=prompt_input, outputs=output_images)

# Launch Gradio app
demo.launch()
|