import gradio as gr
import torch
from transformers import pipeline
# Model ID to serve (replace with the exact checkpoint you want)
MODEL_ID = "manycore-research/SpatialLM-Llama-1B"
# Load the text-generation pipeline with device_map="auto" to use GPU if available
generator = pipeline(
    "text-generation",
    model=MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)
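# Note (assumption, not in the original file): torch.float16 is only a good
# fit when a GPU is present; on CPU-only hardware many fp16 ops are slow or
# unsupported. A minimal sketch of a safer dtype choice:
#
#     dtype = torch.float16 if torch.cuda.is_available() else torch.float32
#     generator = pipeline("text-generation", model=MODEL_ID,
#                          torch_dtype=dtype, device_map="auto")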
def generate_response(prompt, max_new_tokens=512, temperature=0.7):
    # Wrap the prompt in the Llama 2 style [INST] chat format
    formatted_prompt = f"<s>[INST] {prompt} [/INST]"
    output = generator(
        formatted_prompt,
        max_new_tokens=max_new_tokens,  # cap generated tokens, not total length
        temperature=temperature,
        do_sample=True,
        top_p=0.95,
        num_return_sequences=1,
    )
    generated_text = output[0]["generated_text"]
    # The pipeline returns prompt + completion; keep only the text after [/INST]
    response = generated_text.split("[/INST]")[-1].strip()
    return response
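# Alternative (a sketch, assuming the chosen checkpoint ships a chat
# template): transformers tokenizers expose apply_chat_template, which builds
# the model's own prompt format instead of hard-coding [INST] tags. The
# helper name below is illustrative, not part of the original app.
def format_with_chat_template(prompt):
    messages = [{"role": "user", "content": prompt}]
    return generator.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )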
with gr.Blocks() as demo:
    gr.Markdown("# Chat with SpatialLM-Llama-1B")
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(lines=3, placeholder="Type your message here...", label="Your Message")
            submit_btn = gr.Button("Submit")
        with gr.Column():
            output = gr.Textbox(lines=10, label="Model Response")
    submit_btn.click(fn=generate_response, inputs=user_input, outputs=output)
if __name__ == "__main__":
    demo.launch()
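# Usage: `python app.py` serves the demo at Gradio's default local URL
# (http://127.0.0.1:7860). On a busy Space, calling demo.queue() before
# launch() queues concurrent requests rather than running them all at once.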