Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from collections.abc import Iterator | |
from transformers import Gemma3ForCausalLM, AutoTokenizer, TextIteratorStreamer | |
import time | |
import spaces | |
from threading import Thread | |
import gradio as gr | |
# Token budgets. MAX_INPUT_TOKEN_LENGTH caps the prompt length fed to the
# model in generate_text; MAX_MAX_NEW_TOKENS / DEFAULT_MAX_NEW_TOKENS are not
# read in the visible code (presumably they configure a UI control — confirm).
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = 4096

_MODEL_ID = "google/gemma-3-4b-it"

# Load model weights and tokenizer once at import time, timing the whole load.
start_time = time.time()
model = Gemma3ForCausalLM.from_pretrained(
    _MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()  # inference only: switch off training-mode layers in place
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
load_time = time.time() - start_time
print(f"Model loaded in {load_time:.2f} seconds")
# NOTE(review): `spaces` is imported but no @spaces.GPU decorator is visible on
# this function; on ZeroGPU hardware generation presumably needs it — confirm
# whether it is applied elsewhere (e.g. where the Gradio UI wires this in).
def generate_text(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Stream a sampled model reply to ``message``, given the earlier turns.

    The conversation (history plus the new user message) is rendered with the
    tokenizer's chat template. It is then left-truncated to the most recent
    ``MAX_INPUT_TOKEN_LENGTH`` tokens and generated on a background thread.
    Decoded text is consumed here through a ``TextIteratorStreamer``.

    Args:
        message: The new user message.
        chat_history: Prior turns as ``{"role": ..., "content": ...}`` dicts;
            presumably already in the shape the chat template expects —
            confirm against the caller.
        max_new_tokens: Maximum number of tokens to generate (1024 matches
            ``DEFAULT_MAX_NEW_TOKENS``).
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling probability mass passed to ``model.generate``.
        top_k: Top-k sampling cutoff passed to ``model.generate``.
        repetition_penalty: Repetition penalty passed to ``model.generate``.

    Yields:
        The cumulative reply text decoded so far, once per streamed chunk.

    Raises:
        Presumably ``queue.Empty`` if no new text arrives within the
        streamer's 30-second timeout (e.g. if generation fails on the worker
        thread) — behavior of TextIteratorStreamer, confirm for the installed
        transformers version.
    """
    conversation = [*chat_history, {"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    # Keep only the newest tokens so long histories fit the input budget.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    # generate() blocks until completion, so run it on a worker thread and
    # read decoded text from the streamer here as it is produced.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    output: list[str] = []
    for text in streamer:
        output.append(text)
        # Streamed chunks already include their own whitespace; joining with
        # " " (as before) inserted spurious spaces inside words.
        yield "".join(output)
    # The streamer is exhausted only once generate() has finished, so this
    # join returns promptly; it just ensures the worker is reaped.
    thread.join()