SeedOfEvil committed (verified)
Commit e4e5d71 · 1 Parent(s): 204fc4b

Update app.py

Files changed (1):
1. app.py (+26, -28)
app.py CHANGED
@@ -1,53 +1,51 @@
  import gradio as gr
- import spaces  # ZeroGPU helper module
- from transformers import pipeline
+ import spaces
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch

- # Preload the text-generation model on CPU at startup.
- # We load EleutherAI/gpt-j-6B on CPU (device=-1).
- generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B", device=-1)
+ # Load the MagicPrompt-Stable-Diffusion model and tokenizer
+ model_name = "Gustavosta/MagicPrompt-Stable-Diffusion"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name).to("cpu")

- @spaces.GPU  # ZeroGPU will allocate the GPU only during this function call.
+ @spaces.GPU
  def expand_prompt(prompt, num_variants=5, max_length=100):
      """
-     Given a basic prompt, generate `num_variants` expanded prompts using GPT-J-6B.
-     This function explicitly tokenizes the input with truncation (strategy 'longest_first'),
-     moves the input to GPU, generates output using the GPU, and then moves the model back to CPU.
+     Generate expanded prompts using a specialized model fine-tuned for Stable Diffusion.
      """
-     # Move model to GPU for generation.
-     generator.model.to("cuda")
+     # Move model to GPU
+     model.to("cuda")

-     # Explicitly tokenize the input with truncation.
-     inputs = generator.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length)
-     # Move inputs to GPU.
-     inputs = {k: v.to("cuda") for k, v in inputs.items()}
+     # Tokenize input prompt
+     inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

-     # Generate text, explicitly setting pad_token_id to eos_token_id.
-     outputs = generator.model.generate(
+     # Generate multiple prompt variants
+     outputs = model.generate(
          **inputs,
          max_length=max_length,
-         num_return_sequences=num_variants,
          do_sample=True,
-         pad_token_id=generator.tokenizer.eos_token_id
+         num_return_sequences=num_variants,
+         pad_token_id=tokenizer.eos_token_id
      )

-     # Decode outputs.
-     expanded = [generator.tokenizer.decode(output, skip_special_tokens=True).strip() for output in outputs]
+     # Decode generated prompts
+     expanded_prompts = [tokenizer.decode(output, skip_special_tokens=True).strip() for output in outputs]

-     # Move model back to CPU.
-     generator.model.to("cpu")
+     # Move model back to CPU
+     model.to("cpu")

-     return "\n\n".join(expanded)
+     return "\n\n".join(expanded_prompts)

+ # Create a Gradio Interface
  iface = gr.Interface(
      fn=expand_prompt,
      inputs=gr.Textbox(lines=2, placeholder="Enter your basic prompt here...", label="Basic Prompt"),
      outputs=gr.Textbox(lines=10, label="Expanded Prompts"),
      title="Prompt Expansion Generator",
      description=(
-         "Enter a basic prompt to receive 5 creative, expanded prompt variants. "
-         "The model (EleutherAI/gpt-j-6B) is preloaded on CPU at startup and moved to GPU (via ZeroGPU) for generation. "
-         "Input is tokenized with truncation enabled. Once generation is complete, the model is moved back to CPU. "
-         "Simply copy the output for use in your downstream image-generation pipeline."
+         "Enter a basic prompt and receive multiple expanded prompt variants optimized for Stable Diffusion. "
+         "This tool uses a specialized model fine-tuned on Stable Diffusion prompts. "
+         "Simply copy the output for use with your image-generation pipeline."
      )
  )
51