# gemma-examples / app.py
import os
from threading import Thread

import gradio as gr
import torch
from huggingface_hub import login
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
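# --- Earlier local prototype using Ollama, left commented out below for reference ---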
# from PIL import Image
# import requests
# import torch
# import os
# from transformers import Gemma3ForConditionalGeneration, AutoProcessor
# print("hey")
# # Set the cache directory
# cache_dir = "F:\\huggingface_cache"
# # Set environment variables for good measure
# # os.environ["TRANSFORMERS_CACHE"] = cache_dir
# # os.environ["HF_HOME"] = cache_dir
# # os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
# # Model ID
# model_id = "gemma3:latest"
# from ollama import chat
# from ollama import ChatResponse
# def _get_response(message):
#     messages = [
#         {
#             'role': 'user',
#             'content': message,
#         },
#     ]
#     response: ChatResponse = chat(model=model_id, messages=messages)
#     return response.message.content
# import requests
# import base64
# # Function to encode image to Base64
# def encode_image_to_base64(image_path):
#     with open(image_path, "rb") as image_file:
#         return base64.b64encode(image_file.read()).decode("utf-8")
# def image_process():
#     # Path to your image
#     image_path = r"F:\HF\gemma-examples\WhatsApp Image 2025-03-21 at 10.05.06 PM.jpeg"  # Replace with your image path
#     # Encode the image
#     image_base64 = encode_image_to_base64(image_path)
#     # Ollama API endpoint
#     OLLAMA_URL = "http://localhost:11434/api/generate"
#     # Payload for the API request
#     payload = {
#         "model": model_id,  # Specify the model version
#         "prompt": "The given image contains handwritten English text; read it carefully and extract all the text in it.",
#         "images": [image_base64],  # List of Base64-encoded images
#         "stream": False
#     }
#     # Headers for the request
#     headers = {
#         "Content-Type": "application/json"
#     }
#     # Send the POST request
#     response = requests.post(OLLAMA_URL, json=payload, headers=headers)
#     # Check the response
#     if response.status_code == 200:
#         data = response.json()
#         print("Response from Gemma 3:")
#         print(data.get("response", "No response field in the API response."))
#     else:
#         print(f"Error: {response.status_code}")
#         print(response.text)
#     return response.text
# def _hit_endpoint(name):
#     import requests
#     import json
#     # Define the URL of the Ollama server
#     OLLAMA_URL = "http://localhost:11434/api/generate"
#     # Define the request payload
#     payload = {
#         "model": model_id,  # Change this to your desired model
#         "prompt": name,
#         "stream": False
#     }
#     # Make the request
#     response = requests.post(OLLAMA_URL, json=payload)
#     # Parse and print the response
#     if response.status_code == 200:
#         data = response.json()
#         print(data["response"])  # Extract the generated text
#         return data["response"]
#     else:
#         print(f"Error: {response.status_code} - {response.text}")
#         return "An error occurred!"
# Authenticate with the Hugging Face Hub; the token is read from the hf_token secret.
login(token=os.getenv("hf_token"))
model_id = os.getenv("MODEL_ID", "google/gemma-3-12b-it")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
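# Optional sanity check (an illustrative addition, not part of the original app): report
# where device_map="auto" placed the weights and which dtype they use.
print(f"Loaded {model_id} on {model.device} (dtype: {model.dtype})")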
def run_fn(message):
    # Wrap the plain-text prompt in the chat format the Gemma 3 processor expects.
    # The same format also accepts image entries for multimodal prompts; see the
    # sketch after this function.
    messages_list = [
        {"role": "user", "content": [{"type": "text", "text": message}]},
    ]
    inputs = processor.apply_chat_template(
        messages_list,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device=model.device, dtype=torch.bfloat16)

    # Run generation in a background thread so the streamer can be consumed here and
    # partial text yielded to Gradio as it is produced.
    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    max_new_tokens = 100
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    # Accumulate the decoded chunks and yield the running text.
    text = ""
    for chunk in streamer:
        text += chunk
        yield text
    thread.join()
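# Hedged sketch of the multimodal path hinted at in run_fn's comment: Gemma 3 accepts
# an image entry in the same chat format. The example URL comes from the original
# commented-out conversation; the function itself is an illustrative assumption and is
# not wired into the served app.
def describe_image(message, image_url="https://www.ilankelman.org/stopsigns/australia.jpg"):
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_url},
                {"type": "text", "text": message},
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device=model.device, dtype=torch.bfloat16)
    outputs = model.generate(**inputs, max_new_tokens=100)
    # Decode only the newly generated tokens, skipping the prompt.
    return processor.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)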
def greet(name):
    # Forward the streamed partial outputs so Gradio updates the textbox live.
    yield from run_fn(name)
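# Illustrative direct usage outside Gradio (hypothetical prompt; run from a Python
# shell, not executed by the app):
#
#     for partial in run_fn("Describe the Gemma 3 model family in one sentence."):
#         print(partial, end="\r")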
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
demo.launch()