hsuwill000 commited on
Commit
604dba1
·
verified ·
1 Parent(s): 9398cd5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -17
app.py CHANGED
@@ -4,33 +4,31 @@ from PIL import Image
4
  import os
5
  import time
6
  import traceback
 
7
 
 
8
  api_key = os.getenv('MY_API_KEY')
9
 
 
10
  repos = [
11
  "hsuwill000/LCM_SoteMix_OpenVINO_CPU_Space_TAESD",
12
  "hsuwill000/LCM_SoteMix_OpenVINO_CPU_Space_TAESD_0"
13
  ]
 
14
  class CustomClient(Client):
15
  def __init__(self, *args, timeout=30, **kwargs):
16
  super().__init__(*args, **kwargs)
17
  self.timeout = timeout
18
 
19
  def _request(self, method, url, **kwargs):
20
- # 设置 timeout 参数
21
  kwargs['timeout'] = self.timeout
22
  return super()._request(method, url, **kwargs)
23
-
24
  # Counter for image filenames to avoid overwriting
25
  count = 0
26
- repo_index = 0 # This will keep track of the current repository
27
 
28
- # Gradio Interface Function to handle image generation
29
- def infer_gradio(prompt: str):
30
- global count, repo_index
31
- # Create a Client instance to communicate with the Hugging Face space
32
- client = CustomClient(repos[repo_index], hf_token=api_key,timeout=300)
33
-
34
  # Prepare the inputs for the prediction
35
  inputs = {
36
  "prompt": prompt,
@@ -39,7 +37,7 @@ def infer_gradio(prompt: str):
39
 
40
  try:
41
  # Send the request to the model and receive the image
42
- result = client.predict(inputs, api_name="/infer")
43
 
44
  # Open the resulting image
45
  image = Image.open(result)
@@ -54,9 +52,6 @@ def infer_gradio(prompt: str):
54
  image.save(filename)
55
  print(f"Saved image as {filename}")
56
 
57
- # Increment the repo_index to choose the next repository in the list
58
- repo_index = (repo_index + 1) % len(repos) # Cycle through repos list
59
-
60
  # Return the image to be displayed in Gradio
61
  return image
62
 
@@ -67,12 +62,25 @@ def infer_gradio(prompt: str):
67
  traceback.print_exc() # Print stack trace for debugging
68
  return None # Return nothing if an error occurs
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  # Define Gradio Interface
71
  with gr.Blocks() as demo:
72
  with gr.Row(): # Use a Row to place the prompt input and the button side by side
73
  prompt_input = gr.Textbox(
74
  label="Enter Your Prompt",
75
- show_label = "False",
76
  placeholder="Type your prompt for image generation here",
77
  lines=1, # Set the input to be only one line tall
78
  interactive=True # Allow user to interact with the textbox
@@ -81,10 +89,11 @@ with gr.Blocks() as demo:
81
  # Change the button text to "RUN:" and align it with the prompt input
82
  run_button = gr.Button("RUN")
83
 
84
- # Output image display area
85
- output_image = gr.Image(label="Generated Image")
86
 
87
  # Connecting the button click to the image generation function
88
- run_button.click(infer_gradio, inputs=prompt_input, outputs=output_image)
89
 
 
90
  demo.launch()
 
4
  import os
5
  import time
6
  import traceback
7
+ import asyncio
8
 
9
+ # Your Hugging Face API key (ensure this is set in your environment or replace directly)
10
  api_key = os.getenv('MY_API_KEY')
11
 
12
+ # List of repos (private spaces)
13
  repos = [
14
  "hsuwill000/LCM_SoteMix_OpenVINO_CPU_Space_TAESD",
15
  "hsuwill000/LCM_SoteMix_OpenVINO_CPU_Space_TAESD_0"
16
  ]
17
+
18
  class CustomClient(Client):
19
  def __init__(self, *args, timeout=30, **kwargs):
20
  super().__init__(*args, **kwargs)
21
  self.timeout = timeout
22
 
23
  def _request(self, method, url, **kwargs):
 
24
  kwargs['timeout'] = self.timeout
25
  return super()._request(method, url, **kwargs)
26
+
27
  # Counter for image filenames to avoid overwriting
28
  count = 0
 
29
 
30
+ async def infer_single_gradio(client, prompt):
31
+ global count
 
 
 
 
32
  # Prepare the inputs for the prediction
33
  inputs = {
34
  "prompt": prompt,
 
37
 
38
  try:
39
  # Send the request to the model and receive the image
40
+ result = await asyncio.to_thread(client.predict, inputs, api_name="/infer")
41
 
42
  # Open the resulting image
43
  image = Image.open(result)
 
52
  image.save(filename)
53
  print(f"Saved image as {filename}")
54
 
 
 
 
55
  # Return the image to be displayed in Gradio
56
  return image
57
 
 
62
  traceback.print_exc() # Print stack trace for debugging
63
  return None # Return nothing if an error occurs
64
 
65
+ async def infer_gradio(prompt: str):
66
+ # Create a list of tasks (one for each repo)
67
+ tasks = []
68
+ for repo in repos:
69
+ # Create a CustomClient instance for each repo
70
+ client = CustomClient(repo, hf_token=api_key, timeout=300)
71
+ task = infer_single_gradio(client, prompt)
72
+ tasks.append(task)
73
+
74
+ # Run all tasks concurrently (i.e., generate images from all repos)
75
+ results = await asyncio.gather(*tasks)
76
+ return results # Return all the images as a list
77
+
78
  # Define Gradio Interface
79
  with gr.Blocks() as demo:
80
  with gr.Row(): # Use a Row to place the prompt input and the button side by side
81
  prompt_input = gr.Textbox(
82
  label="Enter Your Prompt",
83
+ show_label="False",
84
  placeholder="Type your prompt for image generation here",
85
  lines=1, # Set the input to be only one line tall
86
  interactive=True # Allow user to interact with the textbox
 
89
  # Change the button text to "RUN:" and align it with the prompt input
90
  run_button = gr.Button("RUN")
91
 
92
+ # Output image display area (will show multiple images)
93
+ output_images = gr.Gallery(label="Generated Images", elem_id="gallery", show_label=False)
94
 
95
  # Connecting the button click to the image generation function
96
+ run_button.click(infer_gradio, inputs=prompt_input, outputs=output_images)
97
 
98
+ # Launch Gradio app
99
  demo.launch()