mjavaid committed on
Commit
5253b6b
·
1 Parent(s): ee5632d

first commit

Browse files
Files changed (1) hide show
  1. app.py +23 -33
app.py CHANGED
@@ -1,13 +1,12 @@
1
- import spaces
2
  import gradio as gr
3
  from transformers import pipeline
4
  import torch
5
  import os
 
6
 
7
  hf_token = os.environ["HF_TOKEN"]
8
 
9
  # Load the Gemma 3 pipeline.
10
- # Gemma 3 is a multimodal model that accepts text and image inputs.
11
  pipe = pipeline(
12
  "image-text-to-text",
13
  model="google/gemma-3-4b-it",
@@ -16,53 +15,44 @@ pipe = pipeline(
16
  use_auth_token=hf_token
17
  )
18
  @spaces.GPU
19
- def generate_response(user_text, user_image, history):
 
 
 
 
 
20
  messages = [
21
  {
22
  "role": "system",
23
  "content": [{"type": "text", "text": "You are a helpful assistant."}]
24
  }
25
  ]
26
- user_content = []
27
- if user_image is not None:
28
- user_content.append({"type": "image", "image": user_image})
29
  if user_text:
30
  user_content.append({"type": "text", "text": user_text})
31
  messages.append({"role": "user", "content": user_content})
32
 
33
- # Call the pipeline with the provided messages.
34
  output = pipe(text=messages, max_new_tokens=200)
35
-
36
- print(output)
37
- print(output[0]["generated_text"][-1]["content"])
38
 
39
- # Attempt to extract the generated content using the expected structure.
40
  try:
41
  response = output[0]["generated_text"][-1]["content"]
42
- history.append((user_text, response))
43
-
44
  except (KeyError, IndexError, TypeError):
45
- # Fallback: return the raw output as a string.
46
- #print(response)
47
- pass
48
- #response = str(output)
49
 
50
- return history, history
51
-
52
- with gr.Blocks() as demo:
53
- gr.Markdown("# Gemma 3 Chat Interface")
54
- gr.Markdown(
55
- "This interface lets you chat with the Gemma 3 model. "
56
- "You can type a message and optionally attach an image."
57
- )
58
- # Specify type="messages" to avoid deprecation warnings.
59
- chatbot = gr.Chatbot(type="messages")
60
- with gr.Row():
61
- txt = gr.Textbox(show_label=False, placeholder="Type your message here...", container=False)
62
- img = gr.Image(type="pil", label="Attach an image (optional)")
63
- state = gr.State([])
64
 
65
- txt.submit(generate_response, inputs=[txt, img, state], outputs=[chatbot, state])
 
 
 
 
 
 
 
 
 
66
 
67
  if __name__ == "__main__":
68
- demo.launch()
 
 
1
  import gradio as gr
2
  from transformers import pipeline
3
  import torch
4
  import os
5
+ import spaces
6
 
7
  hf_token = os.environ["HF_TOKEN"]
8
 
9
  # Load the Gemma 3 pipeline.
 
10
  pipe = pipeline(
11
  "image-text-to-text",
12
  model="google/gemma-3-4b-it",
 
15
  use_auth_token=hf_token
16
  )
17
@spaces.GPU
def generate_response(user_text, user_image):
    """Run the Gemma 3 pipeline on the user's text and mandatory image.

    Returns the model's reply as a string, or an error message when no
    image was supplied.
    """
    # The image input is required; bail out early without calling the model.
    if user_image is None:
        return "Error: An image upload is mandatory."

    # Assemble the chat messages: a fixed system prompt followed by the
    # user's image and (optionally) text.
    system_message = {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}],
    }
    user_content = [{"type": "image", "image": user_image}]
    if user_text:
        user_content.append({"type": "text", "text": user_text})
    messages = [system_message, {"role": "user", "content": user_content}]

    # Invoke the multimodal pipeline.
    output = pipe(text=messages, max_new_tokens=200)

    # Pull out the assistant's final message; fall back to the raw output
    # string when the result does not have the expected structure.
    try:
        return output[0]["generated_text"][-1]["content"]
    except (KeyError, IndexError, TypeError):
        return str(output)
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
# Simple single-turn UI: one text box plus a mandatory image upload.
iface = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(label="Message", placeholder="Type your message here..."),
        # Fix: the `source` keyword was removed in Gradio 4.x; the current
        # parameter is `sources`, which takes a list of allowed input modes.
        gr.Image(type="pil", label="Upload an Image", sources=["upload"]),
    ],
    outputs=gr.Textbox(label="Response"),
    title="Gemma 3 Simple Interface",
    description="Enter your message and upload an image (image upload is mandatory) to get a response."
)
56
 
57
# Launch the Gradio app only when this file is executed directly
# (not when it is imported as a module).
if __name__ == "__main__":
    iface.launch()