import gradio as gr
from gradio_client import Client
from PIL import Image
import os
import time
import traceback
import asyncio
from huggingface_hub import HfApi

# Your Hugging Face access token (set MY_API_KEY in the environment, or replace it directly here)
api_key = os.getenv('MY_API_KEY')
api = HfApi(token=api_key)
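# HfApi is used further down to check each Space's runtime stage and to
# restart Spaces that have gone to sleep before work is dispatched to them.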

# List of repos (private spaces)
repos = [
    "hsuwill000/Fluently-v4-LCM-openvino_0",
    "hsuwill000/Fluently-v4-LCM-openvino_1",
    "hsuwill000/Fluently-v4-LCM-openvino_2",
    "hsuwill000/Fluently-v4-LCM-openvino_3",
    "hsuwill000/Fluently-v4-LCM-openvino_4",
]
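
# Each repo is a duplicate of the same image-generation Space; the prompt is
# sent to all of them in parallel, so a single request yields several images.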

class CustomClient(Client):
    def __init__(self, *args, timeout=30, **kwargs):
        super().__init__(*args, **kwargs)
        self.timeout = timeout

    def _request(self, method, url, **kwargs):
        kwargs['timeout'] = self.timeout
        return super()._request(method, url, **kwargs)
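
# Note: this override assumes gradio_client routes its HTTP calls through a
# _request helper; if the installed version does not, the custom timeout is
# never applied and the library's default timeouts are used instead.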

# Counter for image filenames to avoid overwriting
count = 0

async def infer_single_gradio(client, prompt):
    global count
    # Prepare the inputs for the prediction
    inputs = {
        "prompt": prompt,
        #"num_inference_steps": 10  # Number of inference steps for the model
    }

    try:
        # Send the request to the model and receive the image
        result = await asyncio.to_thread(client.predict, inputs, api_name="/infer")
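        # asyncio.to_thread runs the blocking predict call in a worker thread,
        # so requests to all Spaces can proceed concurrently under asyncio.gather.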
        
        # Open the resulting image
        image = Image.open(result)
        
        # Create a unique filename to save the image
        filename = f"img_{count:08d}.jpg"
        while os.path.exists(filename):
            count += 1
            filename = f"img_{count:08d}.jpg"
        
        # Save the image locally
        image.save(filename)
        print(f"Saved image as {filename}")
        
        # Return the image to be displayed in Gradio
        return image
    
    except Exception as e:
        # Handle any errors that occur
        print(f"An exception occurred: {str(e)}")
        print("Stack trace:")
        traceback.print_exc()  # Print stack trace for debugging
        return None  # Return nothing if an error occurs

async def infer_gradio(prompt: str):
    # Note: clearing the previous gallery contents from this handler does not
    # work, so old results are simply replaced when the new list is returned.

    # Record the start time
    start_time = time.time()

    # Create a list of tasks (one for each repo)
    tasks = []
    for repo in repos:
        runtime_info = api.get_space_runtime(repo)
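        # Skip Spaces that are still starting; wake sleeping or paused ones so
        # they become available for a later request (they cannot serve this one).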
        if runtime_info.stage == 'APP_STARTING':
            continue
        if runtime_info.stage in ['SLEEPING', 'PAUSED']:
            print(f"{repo} is now SLEEPING or PAUSED.")
            api.restart_space(repo_id=repo)
            continue   
        # Create a CustomClient instance for each repo
        client = CustomClient(repo, hf_token=api_key, timeout=300)
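        # hf_token authorizes access to the private Spaces; timeout=300 is
        # intended to allow up to five minutes per generation request.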
        task = infer_single_gradio(client, prompt)
        tasks.append(task)
    
    # Run all tasks concurrently (i.e., generate images from all repos)
    results = await asyncio.gather(*tasks)

    # Drop failed generations so only valid images are sent to the gallery
    results = [image for image in results if image is not None]

    # Calculate the time taken for image generation
    end_time = time.time()
    time_taken = end_time - start_time
    print(f"Time taken to generate the image(s): {time_taken:.2f} seconds")

    # Return the results (images)
    return results  # Return all the images as a list

# Define Gradio Interface
with gr.Blocks() as demo:
    with gr.Row():  # Use a Row to place the prompt input and the button side by side
        prompt_input = gr.Textbox(
            label="Enter Your Prompt",
            show_label=False,
            placeholder="Type your prompt for image generation here",
            lines=1,  # Keep the input a single line tall
            interactive=True  # Allow the user to edit the textbox
        )
        
        # Run button sits next to the prompt input and triggers generation
        run_button = gr.Button("RUN")
    
    # Output image display area (will show multiple images)
    output_images = gr.Gallery(label="Generated Images", elem_id="gallery", show_label=False)
    
    # Connect the button click to the image generation function
    run_button.click(infer_gradio, inputs=[prompt_input], outputs=output_images)

# Launch Gradio app
demo.launch()
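
# Local usage (assuming this file is saved as app.py):
#   MY_API_KEY=<your HF token> python app.py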