import gradio as gr
import torch
from transformers import pipeline
# Model ID: SpatialLM-Llama-1B, a Llama-based 1B model (replace with the
# exact text-generation checkpoint you want)
MODEL_ID = "manycore-research/SpatialLM-Llama-1B"

# Load the text-generation pipeline; device_map="auto" places the model on a
# GPU when one is available (requires the `accelerate` package)
generator = pipeline(
    "text-generation",
    model=MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)
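# Note (an assumption, not part of the original app): float16 can be slow or
# unsupported on CPU-only machines; passing torch_dtype="auto" instead lets
# transformers pick a dtype the hardware supports.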
def generate_response(prompt, max_length=512, temperature=0.7):
    # Wrap the prompt in Llama 2-style [INST] tags; adjust this if your model
    # expects a different chat format
    formatted_prompt = f"<s>[INST] {prompt} [/INST]"
    output = generator(
        formatted_prompt,
        max_length=max_length,  # total length, including the prompt tokens
        temperature=temperature,
        do_sample=True,
        top_p=0.95,
        num_return_sequences=1,
    )
    generated_text = output[0]["generated_text"]
    # The pipeline echoes the prompt, so keep only the text after [/INST]
    response = generated_text.split("[/INST]")[-1].strip()
    return response
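
# Alternative prompt construction (a minimal sketch, assuming the checkpoint
# ships a chat template): let the tokenizer format the conversation instead of
# hard-coding [INST] tags. Not wired into the UI below; the function name and
# parameters here are illustrative, not from the original app.
def generate_response_via_template(prompt, max_length=512, temperature=0.7):
    messages = [{"role": "user", "content": prompt}]
    formatted = generator.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    output = generator(
        formatted,
        max_length=max_length,
        temperature=temperature,
        do_sample=True,
        top_p=0.95,
        return_full_text=False,  # return only the completion, not the prompt
    )
    return output[0]["generated_text"].strip()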
with gr.Blocks() as demo:
    gr.Markdown("# Chat with SpatialLM-Llama-1B")
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(lines=3, placeholder="Type your message here...", label="Your Message")
            submit_btn = gr.Button("Submit")
        with gr.Column():
            output = gr.Textbox(lines=10, label="Model Response")
    submit_btn.click(fn=generate_response, inputs=user_input, outputs=output)
if __name__ == "__main__":
    demo.launch()