Daemontatox committed
Commit 15245b5 · verified · 1 Parent(s): 543e080

Update app.py

Files changed (1)
  1. app.py +103 -67
app.py CHANGED
@@ -1,6 +1,5 @@
- from transformers import MllamaForConditionalGeneration, AutoProcessor, TextIteratorStreamer
  from PIL import Image
- import requests
  import torch
  from threading import Thread
  import gradio as gr
@@ -8,100 +7,137 @@ from gradio import FileData
  import time
  import spaces
  from unsloth import FastVisionModel
- ckpt ="Daemontatox/DocumentLlama"
- #model = MllamaForConditionalGeneration.from_pretrained(ckpt,
- #torch_dtype=torch.bfloat16).to("cuda")
- #processor = AutoProcessor.from_pretrained(ckpt)
  model, tokenizer = FastVisionModel.from_pretrained(
      ckpt,
-     load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.
-     use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
  )

-

  @spaces.GPU()
  def bot_streaming(message, history, max_new_tokens=2048):
-
      txt = message["text"]
-     ext_buffer = f"{txt}"
-
-     messages= []
      images = []

-
-     for i, msg in enumerate(history):
          if isinstance(msg[0], tuple):
-             messages.append({"role": "user", "content": [{"type": "text", "text": history[i+1][0]}, {"type": "image"}]})
-             messages.append({"role": "assistant", "content": [{"type": "text", "text": history[i+1][1]}]})
              images.append(Image.open(msg[0][0]).convert("RGB"))
          elif isinstance(history[i-1], tuple) and isinstance(msg[0], str):
-             # messages are already handled
              pass
-         elif isinstance(history[i-1][0], str) and isinstance(msg[0], str): # text only turn
-             messages.append({"role": "user", "content": [{"type": "text", "text": msg[0]}]})
-             messages.append({"role": "assistant", "content": [{"type": "text", "text": msg[1]}]})

-     # add current message
      if len(message["files"]) == 1:
-
-         if isinstance(message["files"][0], str): # examples
              image = Image.open(message["files"][0]).convert("RGB")
-         else: # regular input
              image = Image.open(message["files"][0]["path"]).convert("RGB")
          images.append(image)
-         messages.append({"role": "user", "content": [{"type": "text", "text": txt}, {"type": "image"}]})
      else:
-         messages.append({"role": "user", "content": [{"type": "text", "text": txt}]})
-

-     texts = processor.apply_chat_template(messages, add_generation_prompt=True)
-
-     if images == []:
-         inputs = processor(text=texts, return_tensors="pt").to("cuda")
      else:
-         inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")
-     streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)

-     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
-     generated_text = ""

-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
      thread.start()
-     buffer = ""

-     for new_text in streamer:
          buffer += new_text
-         generated_text_without_prompt = buffer
          time.sleep(0.01)
          yield buffer

-
- demo = gr.ChatInterface(fn=bot_streaming, title="Document Analyzer", examples=[
-     [{"text": "Which era does this piece belong to? Give details about the era.", "files":["./examples/rococo.jpg"]},
-     200],
-     [{"text": "Where do the droughts happen according to this diagram?", "files":["./examples/weather_events.png"]},
-     250],
-     [{"text": "What happens when you take out white cat from this chain?", "files":["./examples/ai2d_test.jpg"]},
-     250],
-     [{"text": "How long does it take from invoice date to due date? Be short and concise.", "files":["./examples/invoice.png"]},
-     250],
-     [{"text": "Where to find this monument? Can you give me other recommendations around the area?", "files":["./examples/wat_arun.jpg"]},
-     250],
      ],
-     textbox=gr.MultimodalTextbox(),
-     additional_inputs = [gr.Slider(
-         minimum=10,
-         maximum=500,
-         value=2048,
-         step=10,
-         label="Maximum number of new tokens to generate",
-     )
-     ],
-     cache_examples=False,
-     description="MllM ",
-     stop_btn="Stop Generation",
-     fill_height=True,
-     multimodal=True)
-

  demo.launch(debug=True)
 
+ from transformers import AutoTokenizer, TextIteratorStreamer  # TextIteratorStreamer (not TextStreamer) is iterable, as the streaming loop below requires
  from PIL import Image
  import torch
  from threading import Thread
  import gradio as gr
  import time
  import spaces
  from unsloth import FastVisionModel
+
+ # Load model and tokenizer
+ ckpt = "Daemontatox/DocumentLlama"
  model, tokenizer = FastVisionModel.from_pretrained(
      ckpt,
+     load_in_4bit=True,
+     use_gradient_checkpointing="unsloth",
  )

+ # Enable inference mode
+ FastVisionModel.for_inference(model)

  @spaces.GPU()
  def bot_streaming(message, history, max_new_tokens=2048):
      txt = message["text"]
+     messages = []
      images = []

+     # Process history
+     for i, msg in enumerate(history):
          if isinstance(msg[0], tuple):
+             messages.append({
+                 "role": "user",
+                 "content": [
+                     {"type": "text", "text": history[i+1][0]},
+                     {"type": "image"}
+                 ]
+             })
+             messages.append({
+                 "role": "assistant",
+                 "content": [{"type": "text", "text": history[i+1][1]}]
+             })
              images.append(Image.open(msg[0][0]).convert("RGB"))
          elif isinstance(history[i-1], tuple) and isinstance(msg[0], str):
              pass
+         elif isinstance(history[i-1][0], str) and isinstance(msg[0], str):
+             messages.append({
+                 "role": "user",
+                 "content": [{"type": "text", "text": msg[0]}]
+             })
+             messages.append({
+                 "role": "assistant",
+                 "content": [{"type": "text", "text": msg[1]}]
+             })

+     # Handle current message
      if len(message["files"]) == 1:
+         if isinstance(message["files"][0], str):  # examples
              image = Image.open(message["files"][0]).convert("RGB")
+         else:  # regular input
              image = Image.open(message["files"][0]["path"]).convert("RGB")
          images.append(image)
+         messages.append({
+             "role": "user",
+             "content": [
+                 {"type": "image"},
+                 {"type": "text", "text": txt}
+             ]
+         })
      else:
+         messages.append({
+             "role": "user",
+             "content": [{"type": "text", "text": txt}]
+         })

+     # Prepare inputs
+     input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+
+     if images:
+         inputs = tokenizer(
+             images[-1],  # Use the last image
+             input_text,
+             add_special_tokens=False,
+             return_tensors="pt"
+         ).to("cuda")
      else:
+         inputs = tokenizer(
+             input_text,
+             add_special_tokens=False,
+             return_tensors="pt"
+         ).to("cuda")

+     # Set up streaming: TextIteratorStreamer queues decoded tokens for iteration;
+     # the plain TextStreamer only prints to stdout and cannot be looped over
+     text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     buffer = ""
+
+     def generate():
+         model.generate(
+             **inputs,
+             streamer=text_streamer,
+             max_new_tokens=max_new_tokens,
+             use_cache=True,
+             temperature=1.5,
+             min_p=0.1
+         )
+
+     thread = Thread(target=generate)
      thread.start()

+     for new_text in text_streamer:
          buffer += new_text
          time.sleep(0.01)
          yield buffer

+ # Setup Gradio interface
+ demo = gr.ChatInterface(
+     fn=bot_streaming,
+     title="Document Analyzer",
+     examples=[
+         [{"text": "Which era does this piece belong to? Give details about the era.", "files":["./examples/rococo.jpg"]}, 200],
+         [{"text": "Where do the droughts happen according to this diagram?", "files":["./examples/weather_events.png"]}, 250],
+         [{"text": "What happens when you take out white cat from this chain?", "files":["./examples/ai2d_test.jpg"]}, 250],
+         [{"text": "How long does it take from invoice date to due date? Be short and concise.", "files":["./examples/invoice.png"]}, 250],
+         [{"text": "Where to find this monument? Can you give me other recommendations around the area?", "files":["./examples/wat_arun.jpg"]}, 250],
      ],
+     textbox=gr.MultimodalTextbox(),
+     additional_inputs=[
+         gr.Slider(
+             minimum=10,
+             maximum=2048,  # raised to match the 2048 default; a maximum of 500 would make the initial value invalid
+             value=2048,
+             step=10,
+             label="Maximum number of new tokens to generate",
+         )
+     ],
+     cache_examples=False,
+     description="MllM",
+     stop_btn="Stop Generation",
+     fill_height=True,
+     multimodal=True
+ )
+
  demo.launch(debug=True)
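
For context, here is a minimal standalone sketch of the streaming pattern this version of app.py relies on: model.generate blocks until generation finishes, so it runs in a background thread while the caller drains a TextIteratorStreamer and yields the growing reply. The checkpoint name mirrors the one above; the stream_reply helper, its arguments, and the max_new_tokens=256 cap are illustrative assumptions, not part of the commit.

from threading import Thread

from PIL import Image
from transformers import TextIteratorStreamer
from unsloth import FastVisionModel

# Same checkpoint as the Space; any Unsloth-compatible vision model loads the same way.
ckpt = "Daemontatox/DocumentLlama"
model, tokenizer = FastVisionModel.from_pretrained(ckpt, load_in_4bit=True)
FastVisionModel.for_inference(model)

def stream_reply(image: Image.Image, prompt: str):
    """Hypothetical helper: yield the partial reply string as tokens arrive."""
    messages = [{
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": prompt}],
    }]
    text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    inputs = tokenizer(image, text, add_special_tokens=False, return_tensors="pt").to("cuda")

    # The streamer queues decoded tokens; generate() blocks, so it runs in a
    # worker thread while this generator consumes the stream.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    Thread(target=model.generate,
           kwargs=dict(**inputs, streamer=streamer, max_new_tokens=256)).start()

    buffer = ""
    for piece in streamer:
        buffer += piece
        yield buffer

This is what lets the Gradio callback yield buffer incrementally: the printing-only TextStreamer has no iterator interface, so the iterator variant is required whenever partial text needs to go anywhere other than stdout.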