# iisadia's picture
# Update app.py
# 5e13a93 verified
# raw | history | blame
# 8.19 kB
import streamlit as st
import time
import requests
from streamlit.components.v1 import html
import os
from dotenv import load_dotenv
# Voice input dependencies
import torchaudio
import numpy as np
import torch
from io import BytesIO
import hashlib
from audio_recorder_streamlit import audio_recorder
from transformers import pipeline
######################################
# Voice Input Helper Functions
######################################
@st.cache_resource
def load_voice_model():
    """Build and cache the Whisper ASR pipeline.

    ``st.cache_resource`` ensures the model is downloaded and loaded only
    once per server process, not on every Streamlit rerun.
    """
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-base",
    )
    return asr_pipeline
def process_audio(audio_bytes):
    """Decode raw audio bytes into a 16 kHz mono waveform for Whisper.

    Returns a dict with ``raw`` (1-D numpy float array) and
    ``sampling_rate`` keys, the input format the ASR pipeline expects.
    """
    waveform, sample_rate = torchaudio.load(BytesIO(audio_bytes))
    # Collapse multi-channel recordings to mono by averaging the channels.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # Whisper models are trained on 16 kHz audio; resample anything else.
    if sample_rate != 16000:
        resample = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )
        waveform = resample(waveform)
    return {"raw": waveform.numpy().squeeze(), "sampling_rate": 16000}
def get_voice_transcription(state_key):
    """Render a microphone recorder and return the accumulated transcription.

    Transcribed text accumulates in ``st.session_state[state_key]`` across
    recordings. An md5 hash of the last processed clip is kept under
    ``state_key + "_last_hash"`` so the same audio is not re-transcribed on
    every Streamlit rerun (the recorder widget re-emits its bytes each run).

    Args:
        state_key: session-state key under which the running transcript is stored.

    Returns:
        The full transcript accumulated so far (possibly "").
    """
    if state_key not in st.session_state:
        st.session_state[state_key] = ""
    audio_bytes = audio_recorder(
        key=state_key + "_audio",
        pause_threshold=0.8,
        text="🎙️ Speak your message",
        recording_color="#e8b62c",
        neutral_color="#6aa36f"
    )
    if audio_bytes:
        current_hash = hashlib.md5(audio_bytes).hexdigest()
        last_hash_key = state_key + "_last_hash"
        # Only process clips we haven't seen before.
        if st.session_state.get(last_hash_key, "") != current_hash:
            st.session_state[last_hash_key] = current_hash
            try:
                audio_input = process_audio(audio_bytes)
                whisper = load_voice_model()
                transcribed_text = whisper(audio_input)["text"]
                st.info(f"📝 Transcribed: {transcribed_text}")
                # BUG FIX: the old `+= (" " + text).strip()` stripped the
                # separator, so successive clips ran together ("yesno").
                # Join with a single space instead.
                existing = st.session_state[state_key]
                st.session_state[state_key] = (existing + " " + transcribed_text).strip()
                # BUG FIX: st.experimental_rerun() was removed in modern
                # Streamlit; use st.rerun() as the rest of this file does.
                st.rerun()
            except Exception as e:
                st.error(f"Voice input error: {str(e)}")
    return st.session_state[state_key]
######################################
# Game Functions & Styling
######################################
@st.cache_resource
def get_help_agent():
    """Build and cache a small conversational model for in-game help.

    NOTE(review): the "conversational" pipeline task has been removed in
    recent transformers releases — confirm the pinned version supports it.
    """
    agent = pipeline("conversational", model="facebook/blenderbot-400M-distill")
    return agent
def inject_custom_css():
    """Inject the app's global styling: Inter font, gradient headings,
    card-style question/input boxes, and gradient buttons."""
    css = """
    <style>
    @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
    * { font-family: 'Inter', sans-serif; }
    .title { font-size: 2.8rem !important; font-weight: 800 !important;
        background: linear-gradient(45deg, #6C63FF, #3B82F6);
        -webkit-background-clip: text; -webkit-text-fill-color: transparent;
        text-align: center; margin: 1rem 0; }
    .subtitle { font-size: 1.1rem; text-align: center; color: #64748B; margin-bottom: 2.5rem; }
    .question-box { background: white; border-radius: 20px; padding: 2rem; margin: 1.5rem 0;
        box-shadow: 0 10px 25px rgba(0,0,0,0.08); border: 1px solid #e2e8f0; color: black; }
    .input-box { background: white; border-radius: 12px; padding: 1.5rem; margin: 1rem 0;
        box-shadow: 0 4px 6px rgba(0,0,0,0.05); }
    .stTextInput input { border: 2px solid #e2e8f0 !important; border-radius: 10px !important;
        padding: 12px 16px !important; }
    button { background: linear-gradient(45deg, #6C63FF, #3B82F6) !important;
        color: white !important; border-radius: 10px !important;
        padding: 12px 24px !important; font-weight: 600; }
    .final-reveal { font-size: 2.8rem;
        background: linear-gradient(45deg, #6C63FF, #3B82F6);
        -webkit-background-clip: text; -webkit-text-fill-color: transparent;
        text-align: center; margin: 2rem 0; font-weight: 800; }
    </style>
    """
    st.markdown(css, unsafe_allow_html=True)
def show_confetti():
    """Fire a celebratory confetti animation via an embedded HTML component.

    Uses the canvas-confetti library loaded from a CDN; several overlapping
    bursts with different spreads/velocities look fuller than a single shot.
    NOTE(review): the CDN URL below ("[email protected]") looks garbled by an
    email-obfuscation pass — confirm the intended canvas-confetti version.
    """
    html("""
    <canvas id="confetti-canvas" class="confetti"></canvas>
    <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/confetti.browser.min.js"></script>
    <script>
    const count = 200;
    const defaults = { origin: { y: 0.7 }, zIndex: 1050 };
    function fire(particleRatio, opts) {
        confetti(Object.assign({}, defaults, opts, {
            particleCount: Math.floor(count * particleRatio)
        }));
    }
    fire(0.25, { spread: 26, startVelocity: 55 });
    fire(0.2, { spread: 60 });
    fire(0.35, { spread: 100, decay: 0.91, scalar: 0.8 });
    fire(0.1, { spread: 120, startVelocity: 25, decay: 0.92, scalar: 1.2 });
    fire(0.1, { spread: 120, startVelocity: 45 });
    </script>
    """)
def ask_llama(conversation_history, category, is_final_guess=False):
    """Query Groq's chat-completions API for the next question or final guess.

    Args:
        conversation_history: list of ``{"role", "content"}`` message dicts
            exchanged so far (may be empty for the opening question).
        category: the secret's category shown to the model ("Person",
            "Place" or "Object").
        is_final_guess: when True, instruct the model to commit to a single
            final guess instead of asking another question.

    Returns:
        The assistant's reply text, or ``"..."`` if the request failed
        (the error is surfaced to the UI via ``st.error``).
    """
    api_url = "https://api.groq.com/openai/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {os.getenv('GROQ_API_KEY')}",
        "Content-Type": "application/json"
    }
    system_prompt = f"""You're playing 20 questions to guess a {category}. Rules:
1. Ask strategic, non-repeating yes/no questions to narrow down.
2. Use all previous answers smartly.
3. If you're 80%+ sure, say: Final Guess: [your guess]
4. For places: ask about continent, country, landmarks, etc.
5. For people: ask if real, profession, gender, etc.
6. For objects: ask about use, size, material, etc."""
    prompt = f"""Based on these answers about a {category}, provide ONLY your final guess with no extra text:
{conversation_history}""" if is_final_guess else "Ask your next smart yes/no question."
    messages = [{"role": "system", "content": system_prompt}]
    messages += conversation_history
    messages.append({"role": "user", "content": prompt})
    data = {
        # BUG FIX: Groq's published model id is "llama3-70b-8192" (no hyphen
        # after "llama"); "llama-3-70b-8192" is rejected by the API.
        "model": "llama3-70b-8192",
        "messages": messages,
        "temperature": 0.8,
        "max_tokens": 100
    }
    try:
        # A timeout prevents the Streamlit script from hanging indefinitely
        # when the API is unreachable.
        res = requests.post(api_url, headers=headers, json=data, timeout=30)
        res.raise_for_status()
        return res.json()["choices"][0]["message"]["content"]
    except Exception as e:
        st.error(f"❌ LLaMA API error: {e}")
        return "..."
######################################
# Main App Logic Here (UI, Game Loop)
######################################
def main():
    """Streamlit entry point: render the UI and drive the 20-questions loop.

    Game state lives in ``st.session_state``:
      - ``conversation``: full list of role/content message dicts.
      - ``last_bot_msg``: the most recent LLaMA question, re-rendered each run.
      - ``voice_input``: running voice transcript (managed by
        ``get_voice_transcription``).
    """
    load_dotenv()  # make GROQ_API_KEY from a local .env visible to ask_llama
    inject_custom_css()
    st.title("🎮 Guess It! - 20 Questions Game")
    st.markdown("<div class='subtitle'>Think of a person, place, or object. LLaMA will try to guess it!</div>", unsafe_allow_html=True)
    category = st.selectbox("Category of your secret:", ["Person", "Place", "Object"])

    # Initialise game state on first load.
    if "conversation" not in st.session_state:
        st.session_state.conversation = []
        st.session_state.last_bot_msg = ""

    if st.button("🔄 Restart Game"):
        st.session_state.conversation = []
        st.session_state.last_bot_msg = ""
        st.rerun()

    # Ask the opening question when the game has just (re)started.
    if not st.session_state.conversation:
        st.session_state.last_bot_msg = ask_llama([], category)
        st.session_state.conversation.append({"role": "assistant", "content": st.session_state.last_bot_msg})

    st.markdown(f"<div class='question-box'><strong>LLaMA:</strong> {st.session_state.last_bot_msg}</div>", unsafe_allow_html=True)

    # Voice input takes precedence; fall back to the typed answer.
    user_input = get_voice_transcription("voice_input") or st.text_input("💬 Your answer (yes/no/sometimes):")

    if st.button("Submit Answer") and user_input:
        st.session_state.conversation.append({"role": "user", "content": user_input})
        # BUG FIX: clear the voice buffer so a stale transcription is not
        # silently resubmitted as the answer to every following question.
        st.session_state["voice_input"] = ""
        with st.spinner("Thinking..."):
            response = ask_llama(st.session_state.conversation, category)
        st.session_state.last_bot_msg = response
        st.session_state.conversation.append({"role": "assistant", "content": response})
        st.rerun()

    if st.button("🤔 Make Final Guess"):
        with st.spinner("Making final guess..."):
            final_guess = ask_llama(st.session_state.conversation, category, is_final_guess=True)
        st.markdown(f"<div class='final-reveal'>🤯 Final Guess: <strong>{final_guess}</strong></div>", unsafe_allow_html=True)
        show_confetti()


if __name__ == "__main__":
    main()