Spaces:
Running
Running
import gradio as gr | |
import torch | |
from transformers import pipeline | |
# Model identifier handed to the pipeline below.
# NOTE(review): the previous comment described this as a Llama 3 8B Instruct
# model, but the ID set here is a Falcon checkpoint -- confirm which is intended.
MODEL_ID = "tiiuae/falcon-40b"

# Build the text-generation pipeline once, at import time. Weights are requested
# in float16; device_map="auto" is meant to use a GPU when one is available.
generator = pipeline(
    task="text-generation",
    model=MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
)
def generate_response(prompt, max_length=512, temperature=0.7):
    """Generate one sampled reply to ``prompt`` using the module-level pipeline.

    Args:
        prompt: The user's message. It is wrapped in ``[INST] ... [/INST]``
            tags before being passed to ``generator``.
        max_length: Forwarded unchanged to the pipeline as ``max_length``.
            NOTE(review): in transformers this limit is documented as covering
            the prompt tokens as well as the generated ones -- confirm that is
            the intended budget for long messages.
        temperature: Sampling temperature forwarded to the pipeline.

    Returns:
        The generated text with surrounding whitespace stripped, or an empty
        string when ``prompt`` is ``None``, empty, or whitespace-only.
    """
    # A blank message has nothing to answer; skip the model call entirely.
    if not prompt or not prompt.strip():
        return ""

    # NOTE(review): the original comment called this "Llama 3 instruct style";
    # confirm these tags match the chat format expected by the model that
    # MODEL_ID actually loads.
    formatted_prompt = f"<s>[INST] {prompt} [/INST]"

    results = generator(
        formatted_prompt,
        max_length=max_length,
        temperature=temperature,
        do_sample=True,
        top_p=0.95,
        num_return_sequences=1,
        # Ask for the newly generated text only. Recovering the reply by
        # splitting the full text on "[/INST]" drops the start of the answer
        # whenever the model's own output contains that tag.
        return_full_text=False,
    )
    return results[0]["generated_text"].strip()
# Gradio UI: a heading, then one row holding two columns -- the message box
# with its submit button in the first, the response box in the second.
# NOTE(review): the heading and response label say "Llama 3", while MODEL_ID
# in this file names a Falcon checkpoint -- confirm which is intended.
with gr.Blocks() as demo:
    gr.Markdown("# Chat with Llama 3 (8B Instruct)")

    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here...",
                lines=3,
            )
            submit_btn = gr.Button("Submit")

        with gr.Column():
            output = gr.Textbox(label="Llama 3 Response", lines=10)

    # Clicking the button passes the message text to generate_response and
    # shows the returned string in the response box.
    submit_btn.click(generate_response, inputs=[user_input], outputs=[output])


# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()