# Hugging Face Space: mood-based music recommender with voice chat.
# (Removed "Spaces: Sleeping" status-banner text captured from the Space page;
# it was not part of the program.)
# Third-party dependencies: UI, hosted inference, local ASR, and TTS.
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import pipeline
import edge_tts

# Standard library.
import tempfile
import asyncio

# Remote text-generation backend (Mistral-7B instruct) on the HF Inference API.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")

# Local automatic-speech-recognition pipeline for transcribing microphone input.
asr = pipeline("automatic-speech-recognition", "facebook/wav2vec2-base-960h")

# Greeting shown (and spoken) when a chat session starts.
INITIAL_MESSAGE = "Hi! I'm your music buddy—tell me about your mood and the type of tunes you're in the mood for today!"
def speech_to_text(speech):
    """Transcribe an audio input to text via the module-level ASR pipeline."""
    transcription = asr(speech)
    return transcription["text"]
def classify_mood(input_string):
    """Classify the user's mood from keywords in *input_string*.

    Args:
        input_string: Free-form user text; matching is case-insensitive
            substring search.

    Returns:
        A ``(mood, found)`` tuple — ``(word, True)`` for the first matching
        mood keyword, or ``(None, False)`` when none is present.
    """
    input_string = input_string.lower()
    # Bug fix: the original iterated a `set`, whose order is arbitrary under
    # hash randomization, so text mentioning several moods (e.g. "happy and
    # sad") could classify differently between runs. A tuple gives a fixed,
    # deterministic priority order.
    mood_words = ("happy", "sad", "instrumental", "party")
    for word in mood_words:
        if word in input_string:
            return word, True
    return None, False
def generate(prompt, history, temperature=0.1, max_new_tokens=2048, top_p=0.8, repetition_penalty=1.0):
    """Query the LLM with the formatted conversation and post-process its reply.

    Args:
        prompt: The user's latest message.
        history: List of (user, assistant) exchange tuples.
        temperature: Sampling temperature, clamped to a 1e-2 floor.
        max_new_tokens: Generation length cap.
        top_p: Nucleus-sampling cutoff.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        A playlist confirmation when the model emits a bare mood word, a
        fallback question when it answers "Unclear", otherwise the model's
        reply verbatim (stripped).
    """
    # Clamp temperature to a small positive floor so sampling stays valid.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    sampling_options = {
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "seed": 42,
    }

    token_stream = client.text_generation(
        format_prompt(prompt, history),
        **sampling_options,
        stream=True,
        details=True,
        return_full_text=False,
    )
    # Accumulate the streamed tokens into the full reply.
    output = "".join(chunk.token.text for chunk in token_stream)

    reply = output.strip()
    lowered = reply.lower()
    # A bare mood word means the user confirmed a playlist choice.
    if lowered in ("happy", "sad", "instrumental", "party"):
        return f"Playing {reply.capitalize()} playlist for you!"
    if lowered == "unclear":
        return "I'm having trouble determining your mood. Could you tell me more explicitly how you're feeling?"
    return reply
def format_prompt(message, history):
    """Build the full LLM prompt: system instructions, prior turns, new message.

    Args:
        message: The user's latest utterance.
        history: List of (user, assistant) exchange tuples.

    Returns:
        The prompt string, ending with "Assistant:" for the model to complete.
    """
    fixed_prompt = """
You are a smart mood analyzer for a music recommendation system. Your goal is to determine the user's current mood and suggest an appropriate music playlist. Follow these instructions carefully:
1. Engage in a conversation to understand the user's mood. Don't assume their mood based on activities or preferences.
2. Classify the mood into one of four categories: Happy, Sad, Instrumental, or Party.
3. If the mood is unclear, ask relevant follow-up questions. Do not classify prematurely.
4. Before suggesting a playlist, always ask for confirmation. For example: "It sounds like you might be in a [mood] mood. Would you like me to play a [mood] playlist for you?"
5. Only respond with a single mood word (Happy, Sad, Instrumental, or Party) if the user explicitly confirms they want that type of playlist.
6. If you can't determine the mood after 5 exchanges, respond with "Unclear".
7. Stay on topic and focus on understanding the user's current emotional state.
Remember: Your primary goal is accurate mood classification and appropriate music suggestion. Always get confirmation before playing a playlist.
"""
    parts = [fixed_prompt, "\n"]
    # Replay the conversation so far.
    for turn_index, (user_text, bot_text) in enumerate(history):
        parts.append(f"User: {user_text}\nAssistant: {bot_text}\n")
        # After the 4th recorded exchange (0-indexed), warn the model that
        # this is its last chance before it must answer "Unclear".
        if turn_index == 3:
            parts.append("Note: This is the last exchange. If the mood is still unclear, respond with 'Unclear'.\n")
    parts.append(f"User: {message}\nAssistant:")
    return "".join(parts)
async def text_to_speech(text):
    """Synthesize *text* with edge-tts and return the path of a temp .wav file.

    The file is created with delete=False so it survives after the handle
    closes; Gradio reads it from disk afterwards. The caller (or the OS temp
    cleaner) is responsible for removing it.
    """
    synthesizer = edge_tts.Communicate(text)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as handle:
        output_path = handle.name
    await synthesizer.save(output_path)
    return output_path
def process_input(input_text, history):
    """Handle one user turn: generate a reply and append it to the history.

    Returns (state, chatbot display, cleared textbox, cleared audio input);
    empty input leaves the conversation untouched.
    """
    if input_text:
        bot_reply = generate(input_text, history)
        history.append((input_text, bot_reply))
    return history, history, "", None
async def generate_audio(history):
    """Synthesize speech for the most recent assistant reply.

    Returns the audio file path, or None when the history is empty.
    """
    if not history:
        return None
    latest_reply = history[-1][1]
    return await text_to_speech(latest_reply)
async def init_chat():
    """Seed the conversation with the greeting and its synthesized audio."""
    greeting_history = [("", INITIAL_MESSAGE)]
    greeting_audio = await text_to_speech(INITIAL_MESSAGE)
    return greeting_history, greeting_history, greeting_audio
# ---- Gradio interface wiring ----
with gr.Blocks() as demo:
    gr.Markdown("# Mood-Based Music Recommender with Continuous Voice Chat")

    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type your message here or use the microphone to speak...")
    audio_output = gr.Audio(label="AI Response", autoplay=True)
    state = gr.State([])

    with gr.Row():
        submit = gr.Button("Send")
        voice_input = gr.Audio(sources="microphone", type="filepath", label="Voice Input")

    # Speak the greeting as soon as the page loads.
    demo.load(init_chat, outputs=[state, chatbot, audio_output])

    # Text entry: the Enter key and the Send button share one pipeline —
    # process the message, then voice the reply.
    for text_event in (
        msg.submit(process_input, inputs=[msg, state], outputs=[state, chatbot, msg, voice_input]),
        submit.click(process_input, inputs=[msg, state], outputs=[state, chatbot, msg, voice_input]),
    ):
        text_event.then(generate_audio, inputs=[state], outputs=[audio_output])

    # Voice entry: transcribe the recording into the textbox, then reuse the
    # same text pipeline and voice the reply.
    voice_input.stop_recording(
        lambda audio_path: speech_to_text(audio_path) if audio_path else "",
        inputs=[voice_input],
        outputs=[msg],
    ).then(
        process_input, inputs=[msg, state], outputs=[state, chatbot, msg, voice_input]
    ).then(
        generate_audio, inputs=[state], outputs=[audio_output]
    )

if __name__ == "__main__":
    # share=True publishes a temporary public URL for the demo.
    demo.launch(share=True)