Yanqing0327 commited on
Commit
e7b2b9e
·
verified ·
1 Parent(s): 2de12eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -106
app.py CHANGED
@@ -1,128 +1,48 @@
1
  import gradio as gr
2
- import torch
3
- from transformers.generation.streamers import TextIteratorStreamer
4
- from PIL import Image
5
  import requests
6
- from io import BytesIO
7
- from threading import Thread
8
- import os
9
-
10
- # 导入 LLaVA 相关模块
11
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
12
- from llava.conversation import conv_templates, SeparatorStyle
13
- from llava.model.builder import load_pretrained_model
14
- from llava.utils import disable_torch_init
15
- from llava.mm_utils import tokenizer_image_token
16
-
17
- # **确保 Hugging Face 缓存目录正确**
18
- os.environ["HUGGINGFACE_HUB_CACHE"] = os.getcwd() + "/weights"
19
-
20
- # **加载 LLaVA-1.5-13B**
21
- disable_torch_init()
22
- model_id = "Yanqing0327/LLaVA-project" # 替换为你的 Hugging Face 模型仓库
23
- tokenizer, model, image_processor, context_len = load_pretrained_model(
24
- model_id, model_name="llava-v1.5-13b", model_base=None, load_8bit=False, load_4bit=False, device_map="auto"
25
- )
26
-
27
- device = "cuda" if torch.cuda.is_available() else "cpu"
28
- model = model.to(device)
29
-
30
- def load_image(image_file):
31
- """确保 image 是 `PIL.Image`"""
32
- if isinstance(image_file, Image.Image):
33
- return image_file.convert("RGB") # 直接返回 `PIL.Image`
34
-
35
- elif isinstance(image_file, str) and (image_file.startswith('http') or image_file.startswith('https')):
36
- response = requests.get(image_file)
37
- return Image.open(BytesIO(response.content)).convert('RGB')
38
 
39
- else: # 这里如果 `image_file` 是路径
40
- return Image.open(image_file).convert("RGB")
41
 
 
 
 
 
 
42
 
43
- def llava_infer(image, text, temperature, top_p, max_tokens):
44
- """LLaVA 模型推理"""
45
  if image is None or text.strip() == "":
46
  return "请提供图片和文本输入"
47
 
48
- # 预处理图像
49
- image_data = load_image(image)
50
- image_tensor = image_processor.preprocess(image_data, return_tensors='pt')['pixel_values']
51
-
52
- # **确保数据在正确的设备上**
53
- image_tensor = image_tensor.to(device)
54
- if torch.cuda.is_available():
55
- image_tensor = image_tensor.half() # GPU: 用 float16
56
- else:
57
- image_tensor = image_tensor.float() # CPU: 用 float32
58
-
59
- # **处理对话**
60
- conv_mode = "llava_v1"
61
- conv = conv_templates[conv_mode].copy()
62
-
63
- # 生成输入文本,添加特殊 token
64
- inp = DEFAULT_IMAGE_TOKEN + '\n' + text
65
- conv.append_message(conv.roles[0], inp)
66
- conv.append_message(conv.roles[1], None)
67
-
68
- prompt = conv.get_prompt()
69
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
70
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
71
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=20.0)
72
-
73
- # **执行推理**
74
- with torch.inference_mode():
75
- thread = Thread(target=model.generate, kwargs=dict(
76
- inputs=input_ids,
77
- images=image_tensor,
78
- do_sample=True,
79
- temperature=temperature,
80
- top_p=top_p,
81
- max_new_tokens=max_tokens,
82
- streamer=streamer,
83
- use_cache=True
84
- ))
85
- thread.start()
86
-
87
- response = ""
88
- prepend_space = False
89
- for new_text in streamer:
90
- if new_text == " ":
91
- prepend_space = True
92
- continue
93
- if new_text.endswith(stop_str):
94
- new_text = new_text[:-len(stop_str)].strip()
95
- prepend_space = False
96
- elif prepend_space:
97
- new_text = " " + new_text
98
- prepend_space = False
99
- response += new_text
100
- if prepend_space:
101
- response += " "
102
-
103
- thread.join()
104
-
105
- return response
106
 
 
 
 
 
 
 
107
 
108
- # **创建 Gradio Web 界面**
109
- with gr.Blocks(title="LLaVA 1.5-13B Web UI") as demo:
110
- gr.Markdown("# 🌋 LLaVA-1.5-13B Web Interface")
111
- gr.Markdown("上传图片并输入文本,LLaVA 将返回回答")
112
 
113
  with gr.Row():
114
  with gr.Column(scale=3):
115
  image_input = gr.Image(type="pil", label="上传图片")
116
  text_input = gr.Textbox(placeholder="输入文本...", label="输入文本")
117
- temperature = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
118
- top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Top P")
119
- max_tokens = gr.Slider(10, 1024, value=512, step=10, label="Max Tokens")
120
  submit_button = gr.Button("提交")
121
 
122
  with gr.Column(scale=7):
123
  chatbot_output = gr.Textbox(label="LLaVA 输出", interactive=False)
124
 
125
- submit_button.click(fn=llava_infer, inputs=[image_input, text_input, temperature, top_p, max_tokens], outputs=chatbot_output)
126
 
127
- # **启动 Gradio Web 界面**
128
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
1
  import gradio as gr
 
 
 
2
  import requests
3
+ import base64
4
+ from PIL import Image
5
+ import io
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ # **本地 GPU 服务器 API**
8
+ LOCAL_SERVER_URL = "http://169.233.7.2:5000/infer"
9
 
10
+ def image_to_base64(image):
11
+ """PIL Image -> Base64"""
12
+ buffer = io.BytesIO()
13
+ image.save(buffer, format="PNG")
14
+ return base64.b64encode(buffer.getvalue()).decode("utf-8")
15
 
16
+ def llava_infer(image, text):
17
+ """把用户输入的图片+文本发送到本地服务器"""
18
  if image is None or text.strip() == "":
19
  return "请提供图片和文本输入"
20
 
21
+ image_base64 = image_to_base64(image)
22
+ payload = {"image": image_base64, "text": text}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ try:
25
+ response = requests.post(LOCAL_SERVER_URL, json=payload)
26
+ response_data = response.json()
27
+ return response_data["response"]
28
+ except Exception as e:
29
+ return f"服务器错误: {e}"
30
 
31
+ # **Gradio Web UI**
32
+ with gr.Blocks(title="LLaVA Remote Web UI") as demo:
33
+ gr.Markdown("# 🌋 LLaVA Web Interface (Remote Inference)")
34
+ gr.Markdown("上传图片并输入文本,LLaVA 将在远程 GPU 服务器推理")
35
 
36
  with gr.Row():
37
  with gr.Column(scale=3):
38
  image_input = gr.Image(type="pil", label="上传图片")
39
  text_input = gr.Textbox(placeholder="输入文本...", label="输入文本")
 
 
 
40
  submit_button = gr.Button("提交")
41
 
42
  with gr.Column(scale=7):
43
  chatbot_output = gr.Textbox(label="LLaVA 输出", interactive=False)
44
 
45
+ submit_button.click(fn=llava_infer, inputs=[image_input, text_input], outputs=chatbot_output)
46
 
47
+ # **启动 Hugging Face Web UI**
48
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)