Spaces:
Running
Running
import gradio as gr | |
from transformers import pipeline,AutoTokenizer,AutoModelForSeq2SeqLM | |
import re,torch | |
# Device string handed to every model loader in this file: CUDA when torch
# reports it available, otherwise CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
def load_models():
    """Build the three prompt-enhancement pipelines on ``device``.

    Returns:
        A ``(enhancer_medium, enhancer_long, enhancer_flux)`` tuple. If any
        loading step raises, the exception is printed and all three entries
        are ``None`` (even pipelines that had already loaded successfully).
    """
    try:
        medium_pipe = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
        long_pipe = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)

        flux_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
        flux_tokenizer = AutoTokenizer.from_pretrained(flux_checkpoint)
        flux_model = AutoModelForSeq2SeqLM.from_pretrained(flux_checkpoint).eval().to(device=device)
        flux_pipe = pipeline('text2text-generation', model=flux_model, tokenizer=flux_tokenizer, repetition_penalty=1.5, device=device)
    except Exception as err:
        # Best-effort: report the failure instead of raising, so the
        # module-level call below does not abort the import.
        print(err)
        return None, None, None
    return medium_pipe, long_pipe, flux_pipe


# Loaded once at import time and shared as module globals by the handlers.
enhancer_medium, enhancer_long, enhancer_flux = load_models()
# Matches everything up to and including the first standalone word "of"
# (case-insensitive) plus the whitespace after it, and captures the text that
# follows up to the next period or the end of the string. The \b prevents
# words that merely end in "of" (e.g. "roof") from triggering the trim.
_MEDIUM_LEAD_IN = re.compile(r'^.*?\bof\s+(.*?(?:\.|$))', re.IGNORECASE | re.DOTALL)


def _trim_medium_lead_in(text):
    """Drop the text before the first standalone "of" and capitalize what follows.

    Only the first character of the kept sentence is upper-cased; the rest of
    the string keeps its original casing. Text without a standalone "of" is
    returned unchanged.
    """
    match = _MEDIUM_LEAD_IN.match(text)
    if match is None:
        return text
    # NOTE(review): with DOTALL, ".*?" can cross sentence boundaries, so an
    # "of" in a later sentence discards every sentence before it. Confirm
    # whether only the first sentence should be considered.
    head = match.group(1)
    # str.capitalize() would lower-case proper nouns and acronyms in the rest
    # of the sentence; upper-case only the first character instead.
    head = head[:1].upper() + head[1:]
    rest = text[match.end():].strip()
    # Only add a separating space when there is text after the first sentence.
    return f"{head} {rest}" if rest else head


def _loaded(enhancer, name):
    """Return *enhancer*, raising a descriptive error if it is None (failed to load)."""
    if enhancer is None:
        raise RuntimeError(f"The {name} prompt enhancer is not available; model loading failed at startup.")
    return enhancer


def enhance_prompt(input_prompt, model_choice):
    """Rewrite *input_prompt* with the enhancer selected by *model_choice*.

    Args:
        input_prompt: The prompt text to enhance.
        model_choice: "Medium" or "Flux"; any other value selects the Long
            enhancer.

    Returns:
        The enhanced prompt text. For "Medium", the lead-in up to the first
        standalone "of" is removed.

    Raises:
        RuntimeError: if the selected enhancer is None (model loading failed).
    """
    if model_choice == "Medium":
        result = _loaded(enhancer_medium, "Medium")("Enhance the description: " + input_prompt)
        return _trim_medium_lead_in(result[0]['summary_text'])
    if model_choice == "Flux":
        result = _loaded(enhancer_flux, "Flux")("enhance prompt: " + input_prompt, max_length=256)
        return result[0]['generated_text']
    # Any other choice uses the Long model.
    result = _loaded(enhancer_long, "Long")("Enhance the description: " + input_prompt)
    return result[0]['summary_text']
def _split_tags(field):
    """Split a comma-separated tag string into stripped, non-empty tags."""
    if not field:
        return []
    return [tag.strip() for tag in field.split(",") if tag.strip()]


def prompt_enhancer(character: str, series: str, general: str, model_choice: str):
    """Join the tag fields into a prompt and append its enhanced version.

    Args:
        character: Comma-separated character tags (may be empty or None).
        series: Comma-separated series tags (may be empty or None).
        general: Comma-separated general tags (may be empty or None).
        model_choice: Enhancer selection, passed through to enhance_prompt.

    Returns:
        A 3-tuple: the combined prompt ("<tags>, <enhanced text>", or just
        the enhanced text when there are no tags), followed by two
        ``gr.update(interactive=True)`` values (presumably re-enabling two UI
        components; confirm against the event wiring).
    """
    # Strip whitespace and drop empty entries so stray or trailing commas
    # don't leave ",," or ", ," runs in the prompt.
    tags = _split_tags(character) + _split_tags(series) + _split_tags(general)
    cprompt = ", ".join(tags)
    output = enhance_prompt(cprompt, model_choice)
    # Without any tags, avoid prefixing the output with a dangling ", ".
    prompt = f"{cprompt}, {output}" if cprompt else output
    return prompt, gr.update(interactive=True), gr.update(interactive=True)