Thouph commited on
Commit
6675f35
·
verified ·
1 Parent(s): 4defcef

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import Qwen2ForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
+
6
+
7
+ torch.set_grad_enabled(False)
8
+ model = Qwen2ForCausalLM.from_pretrained("Thouph/tag2prompt-qwen2-0.5b-v0.1")
9
+ model.generation_config.max_new_tokens = None
10
+ """
11
+ Otherwise you will get this warning
12
+ Both `max_new_tokens` (=2048) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
13
+ """
14
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
15
+
16
+ def post(
17
+ input_text,
18
+ temperature,
19
+ top_p,
20
+ top_k,
21
+ output_num_words,
22
+ ):
23
+ global model, processor
24
+
25
+ prompt = f"{input_text}\n{output_num_words}\n\n"
26
+ inputs = tokenizer(
27
+ prompt,
28
+ padding="do_not_pad",
29
+ max_length=512,
30
+ truncation=True,
31
+ return_tensors="pt",
32
+ )
33
+
34
+ generate_ids = model.generate(
35
+ **inputs,
36
+ max_length=512,
37
+ do_sample=True,
38
+ temperature=temperature,
39
+ top_p=top_p,
40
+ top_k=top_k
41
+ )
42
+ generated_text = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
43
+ generated_text = generated_text[len(prompt):]
44
+ return generated_text
45
+
46
+ def main():
47
+
48
+ with gr.Blocks() as iface:
49
+
50
+ with gr.Row():
51
+ with gr.Column(scale=1):
52
+ text_input = gr.TextArea(label="Tags (in Underscore Format)",)
53
+
54
+ temperature = gr.Slider(maximum=1., value=0.5, minimum=0., label='Temperature')
55
+ top_p = gr.Slider(maximum=1., value=0.8, minimum=0.1, label='Top P')
56
+ top_k = gr.Slider(maximum=100, value=20, minimum=1, step=1, label='Top K')
57
+ output_num_words = gr.Slider(maximum=512, value=100, minimum=1, step=1, label='Output Num Words')
58
+ with gr.Column(scale=1):
59
+ with gr.Column():
60
+ caption_output = gr.Textbox(lines=1, label="Output")
61
+ caption_button = gr.Button(
62
+ value="Run tag2prompt", interactive=True, variant="primary"
63
+ )
64
+ caption_button.click(
65
+ post,
66
+ [
67
+ text_input,
68
+ temperature,
69
+ top_p,
70
+ top_k,
71
+ output_num_words
72
+ ],
73
+ [caption_output],
74
+ )
75
+
76
+ iface.launch()
77
+
78
+
79
+
80
+
81
+ if __name__ == "__main__":
82
+ main()