File size: 2,449 Bytes
cdb99b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import gradio as gr
from transformers import pipeline,AutoTokenizer,AutoModelForSeq2SeqLM
import re,torch

# Run every model on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def load_models():
    """Load the three prompt-enhancer pipelines onto the module-level ``device``.

    Each model is loaded independently, so a failure in one of them leaves the
    others usable. A failure is printed and that model's slot is returned as
    ``None``.

    Returns:
        tuple: ``(enhancer_medium, enhancer_long, enhancer_flux)``; any entry
        may be ``None`` if its model failed to load.
    """

    def _attempt(label, loader):
        # Deliberate best-effort: keep the app running without this one model
        # rather than failing at import time, and don't let it take the other
        # models down with it.
        try:
            return loader()
        except Exception as e:
            print(f"Failed to load the {label} enhancer: {e}")
            return None

    def _load_flux():
        # Flux needs an explicit tokenizer/model pair so the model can be put
        # in eval mode and moved to `device` before wrapping it in a pipeline.
        model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
        tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).eval().to(device=device)
        return pipeline(
            'text2text-generation',
            model=model,
            tokenizer=tokenizer,
            repetition_penalty=1.5,
            device=device,
        )

    # Repo ids are used verbatim; "Enchance" is presumably the published
    # spelling of these Hub repos — confirm before "correcting" it.
    enhancer_medium = _attempt(
        "Medium",
        lambda: pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device),
    )
    enhancer_long = _attempt(
        "Long",
        lambda: pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device),
    )
    enhancer_flux = _attempt("Flux", _load_flux)
    return enhancer_medium, enhancer_long, enhancer_flux

# Load every enhancer pipeline once, at import time, so all requests share
# them. Any of the three may be None if its model failed to load.
(
    enhancer_medium,
    enhancer_long,
    enhancer_flux,
) = load_models()

# Everything up to and including the first standalone word "of" (any case),
# followed by the rest of text up to the next period or end of string.
# The \b stops words that merely end in "of" ("roof", "proof") from matching.
_MEDIUM_LEAD_IN = re.compile(r'^.*?\bof\s+(.*?(?:\.|$))', re.IGNORECASE | re.DOTALL)


def _strip_leading_descriptor(text):
    """Remove the "<preamble> of " lead-in from a Medium-model summary.

    If *text* contains a standalone "of", everything up to and including it
    (and the following whitespace) is dropped. The text after it, up to the
    next period, gets an upper-cased first letter, and the rest of *text* is
    re-attached after a single space. Text without such an "of" is returned
    unchanged.
    """
    match = _MEDIUM_LEAD_IN.match(text)
    if not match:
        return text
    sentence = match.group(1)
    # Upper-case only the first character; str.capitalize() would also
    # lower-case the rest and mangle proper nouns/acronyms ("NASA" -> "Nasa").
    sentence = sentence[:1].upper() + sentence[1:]
    remainder = text[match.end():].strip()
    # Join with a space only when something follows, to avoid a trailing blank.
    return f"{sentence} {remainder}" if remainder else sentence


def _require_model(enhancer, label):
    """Return *enhancer*, raising a descriptive error if it is ``None``.

    ``None`` means load_models() could not load that model.
    """
    if enhancer is None:
        raise RuntimeError(f"The {label} prompt-enhancer model failed to load; see the startup output.")
    return enhancer


def enhance_prompt(input_prompt, model_choice):
    """Expand *input_prompt* with the enhancer selected by *model_choice*.

    Args:
        input_prompt: The prompt text to enhance.
        model_choice: "Medium" or "Flux"; any other value selects the Long
            model.

    Returns:
        The enhanced text produced by the chosen model. Medium output is
        additionally passed through _strip_leading_descriptor().

    Raises:
        RuntimeError: if the selected model failed to load at startup.
    """
    if model_choice == "Medium":
        enhancer = _require_model(enhancer_medium, "Medium")
        result = enhancer("Enhance the description: " + input_prompt)
        return _strip_leading_descriptor(result[0]['summary_text'])

    if model_choice == "Flux":
        enhancer = _require_model(enhancer_flux, "Flux")
        result = enhancer("enhance prompt: " + input_prompt, max_length=256)
        return result[0]['generated_text']

    # Any other choice falls back to the Long model.
    enhancer = _require_model(enhancer_long, "Long")
    result = enhancer("Enhance the description: " + input_prompt)
    return result[0]['summary_text']

def prompt_enhancer(character: str, series: str, general: str, model_choice: str):
    """Combine the three tag fields into a prompt and append its enhancement.

    Each field is an optional comma-separated tag list. The fields are used in
    character, series, general order. Tags keep their original text, including
    any spaces the user typed. Blank fragments, such as those from ",," or a
    trailing comma, are dropped.

    Args:
        character: Character tags, or empty/None.
        series: Series tags, or empty/None.
        general: General tags, or empty/None.
        model_choice: Forwarded to enhance_prompt() to pick the model.

    Returns:
        A 3-tuple of:
        - "<tags>, <enhanced text>", or just the enhanced text when there are
          no tags;
        - two ``gr.update(interactive=True)`` values (presumably re-enabling
          two UI components wired as outputs — confirm against the layout).
    """
    tags = []
    for field in (character, series, general):
        if field:
            tags.extend(tag for tag in field.split(",") if tag.strip())
    cprompt = ",".join(tags)

    output = enhance_prompt(cprompt, model_choice)
    # Only prefix the tags (and separator) when there are any, so an empty tag
    # list doesn't produce a prompt that starts with ", ".
    prompt = f"{cprompt}, {output}" if cprompt else output

    return prompt, gr.update(interactive=True), gr.update(interactive=True)