import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import base64
st.set_page_config(page_title="LIA Demo", layout="wide")
# Model selection (STUBBED behavior)
# model_option = st.selectbox(
#     "Choose a Gemma to reveal hidden truths:",
#     ["gemma-2b-it (Instruct)", "gemma-2b", "gemma-7b", "gemma-7b-it"],
#     index=0,
#     help="Stubbed selection – only gemma-2b-it will load for now.",
# )
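# A label-to-checkpoint mapping like this could back the selector once it is
# enabled (these four are the real gated Gemma repos on the Hub). Sketch only,
# not part of the original app:
# GEMMA_IDS = {
#     "gemma-2b-it (Instruct)": "google/gemma-2b-it",
#     "gemma-2b": "google/gemma-2b",
#     "gemma-7b": "google/gemma-7b",
#     "gemma-7b-it": "google/gemma-7b-it",
# }
# model_id = GEMMA_IDS[model_option]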
st.markdown("<h1 style='text-align: center;'>Ask LeoNardo!</h1>", unsafe_allow_html=True)
# Load both GIFs in base64 format
def load_gif_base64(path):
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")
# still_gem_b64 = load_gif_base64("assets/stillGem.gif")
# rotating_gem_b64 = load_gif_base64("assets/rotatingGem.gif")
# Placeholder for GIF HTML
gif_html = st.empty()
caption = st.empty()
# Initially show still gem
# gif_html.markdown(
#     f"<div style='text-align:center;'><img src='data:image/gif;base64,{still_gem_b64}' width='300'></div>",
#     unsafe_allow_html=True,
# )
gif_html.markdown(
    f"<div style='text-align:center;'><img src='https://media0.giphy.com/media/v1.Y2lkPTc5MGI3NjExYTRxYzI2bXJmY3N2bXBtMHJtOGV3NW9vZ3l3M3czbGYybGpkeWQ1YSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9cw/3uPWb5EYVvxdfoREQm/giphy.gif' width='300'></div>",
    unsafe_allow_html=True,
)
@st.cache_resource
def load_model():
    # Gemma is gated, so this demo currently loads a DeepSeek checkpoint instead.
    # model_id = "google/gemma-2b-it"
    # tokenizer = AutoTokenizer.from_pretrained(model_id, token=True)
    # model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    # model_id = "deepseek-ai/deepseek-llm-7b-chat"
    # NOTE: DeepSeek-V3 is a very large MoE model; the smaller commented-out
    # checkpoints above are far safer on limited hardware.
    model_id = "deepseek-ai/DeepSeek-V3-0324"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # device_map=None,
        # torch_dtype=torch.float32
        device_map="auto",          # spread layers across available devices
        torch_dtype=torch.float16,  # half precision to reduce memory use
        trust_remote_code=True,
    )
    # model.to("cpu")
    return tokenizer, model
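
# A possible variant (not in the original app): on smaller GPUs, 4-bit
# quantization can make the 7B-class checkpoints above fit. Sketch only,
# assuming the optional bitsandbytes package is installed:
# from transformers import BitsAndBytesConfig
# quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
# model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=quant_config)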
tokenizer, model = load_model()
prompt = st.text_area("Enter your prompt:", "What is Leonardo, the company with the red logo?")
# Example prompt selector
# examples = {
#     "🧠 Summary": "Summarize the history of AI in 5 bullet points.",
#     "💻 Code": "Write a Python function to sort a list using bubble sort.",
#     "📜 Poem": "Write a haiku about large language models.",
#     "🤖 Explain": "Explain what a transformer is in simple terms.",
#     "🔍 Fact": "Who won the FIFA World Cup in 2022?"
# }
# selected_example = st.selectbox("Choose a Gemma to consult:", list(examples.keys()) + ["✍️ Custom input"])
# Sampling controls (would go before generation)
# col1, col2, col3 = st.columns(3)
# with col1:
#     temperature = st.slider("Temperature", 0.1, 1.5, 1.0)
# with col2:
#     max_tokens = st.slider("Max tokens", 50, 500, 100)
# with col3:
#     top_p = st.slider("Top-p (nucleus sampling)", 0.1, 1.0, 0.95)
# if selected_example != "✍️ Custom input":
#     prompt = examples[selected_example]
# else:
#     prompt = st.text_area("Enter your prompt:")
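# If the sliders above were enabled, their values could replace the hard-coded
# sampling settings in the generate() call below, e.g. (sketch, not active code):
# outputs = model.generate(**inputs, max_new_tokens=max_tokens,
#                          temperature=temperature, top_p=top_p, do_sample=True)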
if st.button("Generate"):
    # Swap to rotating GIF
    # gif_html.markdown(
    #     f"<div style='text-align:center;'><img src='data:image/gif;base64,{rotating_gem_b64}' width='300'></div>",
    #     unsafe_allow_html=True,
    # )
    gif_html.markdown(
        f"<div style='text-align:center;'><img src='https://media2.giphy.com/media/v1.Y2lkPTc5MGI3NjExMXViMm02MnR6bGJ4c2h3ajYzdWNtNXNtYnNic3lnN2xyZzlzbm9seSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9cw/k32ddF9WVs44OUaZAm/giphy.gif' width='300'></div>",
        unsafe_allow_html=True,
    )
    caption.markdown("<p style='text-align: center; margin-top: 20px;'>LeoNardo is thinking... 🌀</p>", unsafe_allow_html=True)
    # Generate text
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # max_new_tokens=100,
            max_new_tokens=200,
            do_sample=True,  # without this, temperature/top_p are ignored
            temperature=1.0,
            top_p=0.95,
        )
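    # A possible enhancement (not in the original app): stream tokens as they
    # are generated using transformers' TextIteratorStreamer, so the answer
    # appears incrementally instead of all at once. Sketch only:
    # from threading import Thread
    # from transformers import TextIteratorStreamer
    # streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Thread(target=model.generate, kwargs=dict(**inputs, max_new_tokens=200, streamer=streamer)).start()
    # out_box, text = st.empty(), ""
    # for chunk in streamer:
    #     text += chunk
    #     out_box.write(text)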
    # Back to still
    # gif_html.markdown(
    #     f"<div style='text-align:center;'><img src='data:image/gif;base64,{still_gem_b64}' width='300'></div>",
    #     unsafe_allow_html=True,
    # )
    gif_html.markdown(
        f"<div style='text-align:center;'><img src='https://media0.giphy.com/media/v1.Y2lkPTc5MGI3NjExYTRxYzI2bXJmY3N2bXBtMHJtOGV3NW9vZ3l3M3czbGYybGpkeWQ1YSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9cw/3uPWb5EYVvxdfoREQm/giphy.gif' width='300'></div>",
        unsafe_allow_html=True,
    )
    caption.empty()
    # Decode only the newly generated tokens so the prompt is not echoed back
    result = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    st.markdown("### ✨ Output:")
    st.write(result)
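
# Note: a Space running this app needs at least streamlit, transformers, and
# torch in requirements.txt, plus accelerate for device_map="auto".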