Yanqing0327 committed on
Commit 330b634 · verified · 1 Parent(s): 662f179

Update app.py

Files changed (1)
  1. app.py +111 -25
app.py CHANGED
@@ -1,34 +1,120 @@
  import gradio as gr
  import torch
- from transformers import AutoProcessor, LlavaForConditionalGeneration
  from PIL import Image

- # Load the model
- model_id = "Yanqing0327/LLaVA-project"
- processor = AutoProcessor.from_pretrained(model_id)
- model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)

- def llava_infer(image, text):
      if image is None or text.strip() == "":
          return "Please provide both an image and a text input"

-     # Preprocess the inputs
-     inputs = processor(text=text, images=image, return_tensors="pt").to("cuda")
-
-     # Generate the output
-     with torch.no_grad():
-         output = model.generate(**inputs, max_new_tokens=100)
-
-     result = processor.batch_decode(output, skip_special_tokens=True)[0]
-     return result
-
- # Create the Gradio interface
- iface = gr.Interface(
-     fn=llava_infer,
-     inputs=[gr.Image(type="pil"), gr.Textbox(placeholder="Enter text...")],
-     outputs="text",
-     title="LLaVA Web UI",
-     description="Upload an image and enter text, and LLaVA will return an answer"
- )

- iface.launch()
 
  import gradio as gr
  import torch
+ from transformers.generation.streamers import TextIteratorStreamer
  from PIL import Image
+ import requests
+ from io import BytesIO
+ from threading import Thread
+ import os

+ # Import the LLaVA modules
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+ from llava.conversation import conv_templates, SeparatorStyle
+ from llava.model.builder import load_pretrained_model
+ from llava.utils import disable_torch_init
+ from llava.mm_utils import tokenizer_image_token

+ # Make sure the Hugging Face cache directory is set correctly
+ os.environ["HUGGINGFACE_HUB_CACHE"] = os.getcwd() + "/weights"
+
+ # Load LLaVA-1.5-13B
+ disable_torch_init()
+ model_id = "liuhaotian/llava-v1.5-13b"
+
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
+     model_id, model_name="llava-v1.5-13b", model_base=None, load_8bit=False, load_4bit=False
+ )
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = model.to(device)
+
+
+ def load_image(image_file):
+     """Load an image from a local file or a URL"""
+     if isinstance(image_file, str) and (image_file.startswith('http') or image_file.startswith('https')):
+         response = requests.get(image_file)
+         image = Image.open(BytesIO(response.content)).convert('RGB')
+     else:
+         image = Image.open(image_file).convert('RGB')
+     return image
+
+
+ def llava_infer(image, text, temperature, top_p, max_tokens):
+     """Run LLaVA inference"""
      if image is None or text.strip() == "":
          return "Please provide both an image and a text input"
+
+     # Preprocess the image
+     image_data = load_image(image)
+     image_tensor = image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].half().to(device)
+
+     # Set up the conversation template
+     conv_mode = "llava_v1"
+     conv = conv_templates[conv_mode].copy()
+
+     # Build the prompt, prepending the special image token
+     inp = DEFAULT_IMAGE_TOKEN + '\n' + text
+     conv.append_message(conv.roles[0], inp)
+     conv.append_message(conv.roles[1], None)
+
+     prompt = conv.get_prompt()
+     input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
+     stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=20.0)
+
+     # Run generation in a background thread and stream the output tokens
+     with torch.inference_mode():
+         thread = Thread(target=model.generate, kwargs=dict(
+             inputs=input_ids,
+             images=image_tensor,
+             do_sample=True,
+             temperature=temperature,
+             top_p=top_p,
+             max_new_tokens=max_tokens,
+             streamer=streamer,
+             use_cache=True
+         ))
+         thread.start()
+
+         response = ""
+         prepend_space = False
+         for new_text in streamer:
+             if new_text == " ":
+                 prepend_space = True
+                 continue
+             if new_text.endswith(stop_str):
+                 new_text = new_text[:-len(stop_str)].strip()
+                 prepend_space = False
+             elif prepend_space:
+                 new_text = " " + new_text
+                 prepend_space = False
+             response += new_text
+         if prepend_space:
+             response += " "
+
+         thread.join()

+     return response
+
+
+ # Build the Gradio web interface
+ with gr.Blocks(title="LLaVA 1.5-13B Web UI") as demo:
+     gr.Markdown("# 🌋 LLaVA-1.5-13B Web Interface")
+     gr.Markdown("Upload an image and enter text, and LLaVA will return an answer")
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             image_input = gr.Image(type="pil", label="Upload image")
+             text_input = gr.Textbox(placeholder="Enter text...", label="Input text")
+             temperature = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
+             top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Top P")
+             max_tokens = gr.Slider(10, 1024, value=512, step=10, label="Max Tokens")
+             submit_button = gr.Button("Submit")
+
+         with gr.Column(scale=7):
+             chatbot_output = gr.Textbox(label="LLaVA output", interactive=False)
+
+     submit_button.click(fn=llava_infer, inputs=[image_input, text_input, temperature, top_p, max_tokens], outputs=chatbot_output)

+ # Launch the Gradio web interface
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)