import re

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_models():
    """Load the three prompt-enhancement models once at startup; return None for each on failure."""
    try:
        enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
        enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
        model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
        tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).eval().to(device=device)
        enhancer_flux = pipeline("text2text-generation", model=model, tokenizer=tokenizer, repetition_penalty=1.5, device=device)
    except Exception as e:
        print(e)
        enhancer_medium = enhancer_long = enhancer_flux = None
    return enhancer_medium, enhancer_long, enhancer_flux


enhancer_medium, enhancer_long, enhancer_flux = load_models()
def enhance_prompt(input_prompt, model_choice):
    if model_choice == "Medium":
        result = enhancer_medium("Enhance the description: " + input_prompt)
        enhanced_text = result[0]['summary_text']
        # Drop the leading "... of" phrase the Medium model tends to prepend,
        # then re-capitalize the first remaining sentence.
        pattern = r'^.*?of\s+(.*?(?:\.|$))'
        match = re.match(pattern, enhanced_text, re.IGNORECASE | re.DOTALL)
        if match:
            remaining_text = enhanced_text[match.end():].strip()
            modified_sentence = match.group(1).capitalize()
            enhanced_text = modified_sentence + ' ' + remaining_text
    elif model_choice == "Flux":
        result = enhancer_flux("enhance prompt: " + input_prompt, max_length=256)
        enhanced_text = result[0]['generated_text']
    else:  # Long
        result = enhancer_long("Enhance the description: " + input_prompt)
        enhanced_text = result[0]['summary_text']
    return enhanced_text
def prompt_enhancer(character: str, series: str, general: str, model_choice: str):
    # Merge the character, series, and general tags into one comma-separated prompt,
    # enhance it, and append the enhanced text to the original tag string.
    characters = character.split(",") if character else []
    serieses = series.split(",") if series else []
    generals = general.split(",") if general else []
    tags = characters + serieses + generals
    cprompt = ",".join(tags) if tags else ""
    output = enhance_prompt(cprompt, model_choice)
    prompt = cprompt + ", " + output
    return prompt, gr.update(interactive=True), gr.update(interactive=True)
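

# The Space's UI wiring is not shown above; the sketch below is an assumed,
# minimal Gradio Blocks layout for exercising prompt_enhancer. Component names,
# labels, and the choice of which components receive the two gr.update() return
# values are illustrative guesses, not the Space's actual interface.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        character = gr.Textbox(label="Character tags")
        series = gr.Textbox(label="Series tags")
        general = gr.Textbox(label="General tags")
        model_choice = gr.Radio(["Medium", "Long", "Flux"], value="Flux", label="Model")
        enhance_btn = gr.Button("Enhance")
        prompt_out = gr.Textbox(label="Enhanced prompt")
        # prompt_enhancer returns three values: the prompt plus two gr.update()
        # payloads, so three output components are wired here.
        enhance_btn.click(
            prompt_enhancer,
            inputs=[character, series, general, model_choice],
            outputs=[prompt_out, enhance_btn, model_choice],
        )
    demo.launch()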