SeedOfEvil committed on
Commit
f97ea96
·
verified ·
1 Parent(s): e6033b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -26
app.py CHANGED
@@ -1,36 +1,27 @@
1
  import gradio as gr
2
- import spaces # Import ZeroGPU's helper module
3
  from transformers import pipeline
4
- import torch
5
 
6
- # Global generator variable; load lazily.
7
- generator = None
 
 
8
 
9
- def get_generator():
10
- global generator
11
- if generator is None:
12
- try:
13
- # If GPU is available, load on GPU (device=0)
14
- if torch.cuda.is_available():
15
- generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B", device=0)
16
- else:
17
- generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B", device=-1)
18
- except Exception as e:
19
- print("Error loading model on GPU, falling back to CPU:", e)
20
- generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B", device=-1)
21
- return generator
22
-
23
- @spaces.GPU # This decorator ensures ZeroGPU allocates a GPU when the function is called.
24
  def expand_prompt(prompt, num_variants=5, max_length=100):
25
  """
26
- Given a basic prompt, generate `num_variants` expanded prompts using GPT-J-6B.
27
- The GPU is only engaged during this function call.
28
  """
29
- gen = get_generator()
30
- outputs = gen(prompt, max_length=max_length, num_return_sequences=num_variants, do_sample=True)
 
 
 
31
  expanded = [out["generated_text"].strip() for out in outputs]
32
  return "\n\n".join(expanded)
33
 
 
34
  iface = gr.Interface(
35
  fn=expand_prompt,
36
  inputs=gr.Textbox(lines=2, placeholder="Enter your basic prompt here...", label="Basic Prompt"),
@@ -38,9 +29,8 @@ iface = gr.Interface(
38
  title="Prompt Expansion Generator",
39
  description=(
40
  "Enter a basic prompt and receive 5 creative, expanded prompt variants. "
41
- "This tool leverages the EleutherAI/gpt-j-6B model on an A100 GPU via ZeroGPU. "
42
- "The GPU is only allocated when a prompt is submitted, ensuring proper ZeroGPU initialization. "
43
- "Simply copy the output for use with your downstream image-generation pipeline."
44
  )
45
  )
46
 
 
1
  import gradio as gr
2
+ import spaces # ZeroGPU helper module
3
  from transformers import pipeline
 
4
 
5
+ # Preload the text-generation model on CPU at startup.
6
+ # Model: EleutherAI/gpt-j-6B (https://huggingface.co/EleutherAI/gpt-j-6B)
7
+ # We load on CPU (device=-1) so that initialization is done before the GUI is up.
8
+ generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B", device=-1)
9
 
10
@spaces.GPU  # ZeroGPU attaches a GPU only for the duration of this call.
def expand_prompt(prompt, num_variants=5, max_length=100):
    """
    Given a basic prompt, generate `num_variants` expanded prompt variants.

    The module-level `generator` pipeline is created on CPU. For the duration
    of this call its model is moved to the GPU, and it is moved back to CPU
    afterwards, even when generation fails.

    Args:
        prompt: Seed text to expand.
        num_variants: Number of sampled sequences to request from the pipeline.
        max_length: Forwarded to the pipeline as its `max_length` argument.

    Returns:
        The stripped generated texts, joined by blank lines.
    """
    # Remember where the pipeline believes it lives so it can be restored.
    original_device = generator.device

    # Move the model to GPU for generation.
    generator.model.to("cuda")
    # The pipeline was constructed with device=-1 and places its input tensors
    # on `generator.device`; keep that in sync with the model so inputs and
    # weights end up on the same device.
    generator.device = generator.model.device
    try:
        outputs = generator(
            prompt,
            max_length=max_length,
            num_return_sequences=num_variants,
            do_sample=True,
        )
    finally:
        # Move the model back to CPU after generation, including on error, so
        # the next call starts from the same state.
        generator.model.to("cpu")
        generator.device = original_device

    expanded = [out["generated_text"].strip() for out in outputs]
    return "\n\n".join(expanded)
23
 
24
+ # Create a Gradio Interface
25
  iface = gr.Interface(
26
  fn=expand_prompt,
27
  inputs=gr.Textbox(lines=2, placeholder="Enter your basic prompt here...", label="Basic Prompt"),
 
29
  title="Prompt Expansion Generator",
30
  description=(
31
  "Enter a basic prompt and receive 5 creative, expanded prompt variants. "
32
+ "The model (EleutherAI/gpt-j-6B) is preloaded on CPU at startup and then moved to GPU (via ZeroGPU) only "
33
+ "when a prompt is submitted. Simply copy the output for use with your downstream image-generation pipeline."
 
34
  )
35
  )
36