HenryShan committed
Commit e1bcd6a · verified · 1 parent: b5f5ffc

Update app.py

Files changed (1)
  1. app.py +27 -31
app.py CHANGED
@@ -5,26 +5,28 @@ from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
 from deepseek_vl.utils.io import load_pil_images
 from io import BytesIO
 from PIL import Image
+import spaces  # Import spaces for ZeroGPU support
 
 # Load the model and processor
 model_path = "deepseek-ai/deepseek-vl-1.3b-chat"
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 
-# Define the function for image description (CPU version)
-def describe_image(image, user_question="Describe this image in great detail."):
+# Define the function for image description with ZeroGPU support
+@spaces.GPU  # Ensures GPU allocation for this function
+def describe_image(image, user_question="Solve the problem in the image"):
     try:
-        # Convert the PIL Image to a BytesIO object
+        # Convert the PIL Image to a BytesIO object for compatibility
         image_byte_arr = BytesIO()
-        image.save(image_byte_arr, format="PNG")
-        image_byte_arr.seek(0)
+        image.save(image_byte_arr, format="PNG")  # Save image in PNG format
+        image_byte_arr.seek(0)  # Move pointer to the start
 
-        # Define the conversation
+        # Define the conversation, using the user's question
         conversation = [
             {
                 "role": "User",
                 "content": f"<image_placeholder>{user_question}",
-                "images": [image_byte_arr]
+                "images": [image_byte_arr]  # Pass the image byte array instead of an object
             },
             {
                 "role": "Assistant",
@@ -32,35 +34,27 @@ def describe_image(image, user_question="Describe this image in great detail."):
             }
         ]
 
-        # Convert byte array back to PIL image
-        pil_images = [Image.open(BytesIO(image_byte_arr.read()))]
-        image_byte_arr.seek(0)
+        # Convert image byte array back to a PIL image for processing
+        pil_images = [Image.open(BytesIO(image_byte_arr.read()))]  # Convert byte back to PIL Image
+        image_byte_arr.seek(0)  # Reset the byte stream again for reuse
 
-        # Prepare inputs
+        # Load images and prepare the inputs
         prepare_inputs = vl_chat_processor(
             conversations=conversation,
             images=pil_images,
             force_batchify=True
-        )
-
-        # Convert all tensors in prepare_inputs to float32 for CPU compatibility
-        for key in prepare_inputs:
-            if isinstance(prepare_inputs[key], torch.Tensor):
-                prepare_inputs[key] = prepare_inputs[key].to(dtype=torch.float32)
+        ).to('cuda')
 
-        # Load model with CPU and float32 weights
-        vl_gpt = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            trust_remote_code=True
-        ).float().eval()  # Convert all weights to float32
+        # Load and prepare the model
+        vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(torch.bfloat16).cuda().eval()
 
-        # Generate embeddings with CPU
+        # Generate embeddings from the image input
        inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
-        # Generate response with CPU
+        # Generate the model's response
         outputs = vl_gpt.language_model.generate(
             inputs_embeds=inputs_embeds,
-            attention_mask=prepare_inputs["attention_mask"],
+            attention_mask=prepare_inputs.attention_mask,
             pad_token_id=tokenizer.eos_token_id,
             bos_token_id=tokenizer.bos_token_id,
             eos_token_id=tokenizer.eos_token_id,
@@ -69,36 +63,38 @@ def describe_image(image, user_question="Describe this image in great detail."):
             use_cache=True
         )
 
-        # Decode the response
-        answer = tokenizer.decode(outputs[0].tolist(), skip_special_tokens=True)
+        # Decode the generated tokens into text
+        answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
         return answer
 
     except Exception as e:
+        # Provide detailed error information
        return f"Error: {str(e)}"
 
 # Gradio interface
 def gradio_app():
     with gr.Blocks() as demo:
-        gr.Markdown("# Image Description with DeepSeek VL 1.3b 🐬 (CPU Version)")
+        gr.Markdown("# Image Description with DeepSeek VL 1.3b 🐬\n### Upload an image and ask a question about it.")
 
         with gr.Row():
             image_input = gr.Image(type="pil", label="Upload an Image")
             question_input = gr.Textbox(
                 label="Question (optional)",
-                placeholder="Ask a question about the image",
+                placeholder="Ask a question about the image (e.g., 'What is happening in this image?')",
                 lines=2
             )
 
         output_text = gr.Textbox(label="Image Description", interactive=False)
+
         submit_btn = gr.Button("Generate Description")
 
         submit_btn.click(
             fn=describe_image,
-            inputs=[image_input, question_input],
+            inputs=[image_input, question_input],  # Pass both image and question as inputs
             outputs=output_text
         )
 
     demo.launch()
 
-# Launch the app
+# Launch the Gradio app
 gradio_app()
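In short, the update moves app.py from CPU float32 inference to ZeroGPU: importing spaces and decorating describe_image with @spaces.GPU requests a GPU allocation for each call, the processor output is moved to CUDA and the model is loaded in bfloat16 on the GPU, and the generated tokens are copied back to the CPU before decoding. A minimal sketch of how the updated describe_image could be smoke-tested from the same module is given below; the image path and question are illustrative, and it assumes the Space's environment (deepseek_vl, torch with CUDA, gradio, and the spaces package) is available.

from PIL import Image

# Hypothetical smoke test; run in the same module/environment as app.py.
test_image = Image.open("example.png")  # any local RGB image
print(describe_image(test_image, "What is happening in this image?"))

# Without a question, the new default prompt "Solve the problem in the image" is used.
print(describe_image(test_image))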