# Phi-4-mini chat playground — Gradio app (originally hosted on Hugging Face Spaces).
import os | |
import gradio as gr | |
from azure.ai.inference import ChatCompletionsClient | |
from azure.core.credentials import AzureKeyCredential | |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
# Hugging Face identifier of the local copy of the model.
model_path = "microsoft/Phi-4-mini-instruct"
# NOTE(review): `hf_model` (and the transformers/torch imports above) are never
# used anywhere below — all inference goes through the Azure client. Loading
# this model costs significant startup time and RAM; confirm whether it can be
# removed.
hf_model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",       # let accelerate place layers on available devices
    torch_dtype="auto",      # pick the checkpoint's native dtype
    trust_remote_code=True,  # model repo ships custom modeling code
)
# Azure Inference setup.
# NOTE(review): os.getenv returns None when the variable is unset, in which
# case client construction / requests will fail with a confusing error —
# consider failing fast with a clear message.
url = os.getenv("Azure_Endpoint")
api_key = AzureKeyCredential(os.getenv("Azure_API_KEY"))
# Initialize the ChatCompletionsClient
client = ChatCompletionsClient(
    endpoint=url,
    credential=api_key,
    # NOTE(review): streaming is normally a per-request option (the request
    # payload below also sets "stream": True) — confirm the constructor
    # actually accepts/uses this keyword.
    stream=True
)
# Get and print model information (optional; best-effort, failures are logged only)
try:
    model_info = client.get_model_info()
    print("Model name:", model_info.model_name)
    print("Model type:", model_info.model_type)
    print("Model provider name:", model_info.model_provider_name)
except Exception as e:
    print("Could not get model info:", str(e))
# Configuration parameters — initial values for the UI sliders below.
default_temperature = 0.7   # sampling temperature (0 = deterministic)
default_max_tokens = 4096   # cap on generated tokens per reply
default_top_p = 0.1         # nucleus-sampling probability mass
# Example prompts that users can try (shown in the right-hand panel).
example_prompts = [
    "Explain internet to a medieval knight.",
    "Share some ideas about the best vegetables to start growing in February and March. I'd love to know which ones thrive when planted early in the season!",
    "I'd like to buy a new car. Start by asking me about my budget and which features I care most about, then provide a recommendation.",
    "I'm thinking about moving to a new city. Can you help me plan the move?",
    "I have $20,000 in my savings account, where I receive a 4% profit per year and payments twice a year. Can you please tell me how long it will take for me to become a millionaire?",
]
def get_azure_response(message, chat_history, temperature, max_tokens, top_p):
    """
    Request a streamed chat completion from the Azure-hosted Phi-4 model.

    Returns the streaming response object on success, or an error string
    of the form "Error: ..." if the request fails.
    """
    # Conversation always starts with the fixed system prompt.
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant specialized in financial advice and planning."}
    ]
    # Replay prior (user, assistant) exchanges in order.
    for user_turn, assistant_turn in chat_history:
        conversation.append({"role": "user", "content": user_turn})
        if assistant_turn:  # skip empty placeholder replies
            conversation.append({"role": "assistant", "content": assistant_turn})
    # Append the message currently being sent.
    conversation.append({"role": "user", "content": message})

    request_body = {
        "messages": conversation,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "presence_penalty": 0,
        "frequency_penalty": 0,
        "stream": True,
    }

    try:
        print("Sending request to Azure...")
        return client.complete(request_body)
    except Exception as e:
        print(f"Error getting response: {str(e)}")
        return f"Error: {str(e)}"
# CSS for custom styling of the Gradio UI (container width, header typography,
# chat panel shadow, emoji buttons, and the examples/parameters side panels).
# The string is passed verbatim to gr.Blocks(css=...).
custom_css = """
.container {
    max-width: 1200px !important;
    margin-left: auto !important;
    margin-right: auto !important;
    padding-top: 0rem !important;
}
.header {
    text-align: center;
    margin-bottom: 0rem;
}
.header h1 {
    font-size: 2.5rem !important;
    font-weight: 700 !important;
    color: #1a5276 !important;
    margin-bottom: 0.5rem !important;
}
.header p {
    font-size: 1.2rem !important;
    color: #34495e !important;
}
.chatbot-container {
    border-radius: 10px !important;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1) !important;
    overflow: hidden !important;
}
.emoji-button {
    background: none !important;
    border: none !important;
    padding: 0.2rem 0.5rem !important;
    font-size: 1.5rem !important;
    cursor: pointer !important;
    transition: transform 0.2s !important;
}
.emoji-button:hover {
    transform: scale(1.2) !important;
}
.message-input {
    margin-top: 1rem !important;
    display: flex !important;
    align-items: center !important;
}
.footer {
    margin-top: 2rem;
    text-align: center;
    font-size: 0.9rem;
    color: #7f8c8d;
}
.parameters-section {
    background-color: #f8f9fa !important;
    padding: 1rem !important;
    border-radius: 8px !important;
    margin-bottom: 1rem !important;
}
.examples-section {
    background-color: #e8f4f8 !important;
    padding: 1rem !important;
    border-radius: 8px !important;
    margin-bottom: 1rem !important;
}
.right-panel {
    padding-left: 1rem !important;
}
"""
# Create the Gradio interface with a modern, professional design.
# Fixes: the button/footer labels were mojibake (UTF-8 emoji decoded as
# TIS-620: "๐ค"→📤, "๐๏ธ"→🗑️, "๐"→🔄, "ยฉ"→©) and the textbox
# placeholder read "Start type here ...".
with gr.Blocks(css=custom_css, title="Phi-4-mini Playground") as demo:
    with gr.Column(elem_classes="container"):
        # Header section
        with gr.Column(elem_classes="header"):
            gr.Markdown("# Phi-4-mini Playground")
        # Main content with side-by-side layout
        with gr.Row():
            # Left column for chat
            with gr.Column(scale=7):
                # Main chat interface
                with gr.Column(elem_classes="chatbot-container"):
                    chatbot = gr.Chatbot(
                        height=600,  # Increased height to match right panel
                        bubble_full_width=False,
                        show_label=False,
                        avatar_images=(None, "https://upload.wikimedia.org/wikipedia/commons/d/d3/Phi-integrated-information-symbol.png")
                    )
                    with gr.Row(elem_classes="message-input"):
                        msg = gr.Textbox(
                            label="Your message",
                            placeholder="Start typing here ...",
                            lines=2,
                            show_label=False,
                            container=False,
                            scale=8
                        )
                        send_btn = gr.Button("📤 Send", variant="primary", scale=1)
                    with gr.Row():
                        clear = gr.Button("🗑️ Clear", variant="secondary", scale=1)
                        regenerate = gr.Button("🔄 Regenerate", variant="secondary", scale=1)
            # Right column for examples and settings
            with gr.Column(scale=3, elem_classes="right-panel"):
                # Examples section: clicking an example fills the textbox.
                with gr.Column(elem_classes="examples-section"):
                    examples = gr.Examples(
                        examples=example_prompts,
                        inputs=msg,
                        examples_per_page=4
                    )
                # Model parameters section
                with gr.Column(elem_classes="parameters-section"):
                    gr.Markdown("### Advanced Settings")
                    temp_slider = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=default_temperature,
                        step=0.1,
                        label="Temperature",
                        info="Higher = more creative, lower = more focused"
                    )
                    top_p_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=default_top_p,
                        step=0.1,
                        label="Top P",
                        info="Controls diversity of responses"
                    )
                    max_tokens_slider = gr.Slider(
                        minimum=100,
                        maximum=32000,
                        value=default_max_tokens,
                        step=100,
                        label="Max Tokens",
                        info="Maximum length of response"
                    )
        # Footer
        with gr.Column(elem_classes="footer"):
            gr.Markdown("Powered by Microsoft [Phi-4 mini model](https://aka.ms/phi-4-mini/azure) on Azure AI. © 2025")
# Simplified chat function that handles both sending and receiving messages.
def chat(message, history, temperature, max_tokens, top_p):
    """Send *message* to the Azure model and stream the reply into *history*.

    Yields ``(textbox_value, history)`` pairs so Gradio can clear the input
    box and update the chat display incrementally.

    Fixes over the previous version (this function is a generator, so plain
    ``return value`` statements are discarded by Gradio):
    - blank input now yields once, so the textbox is still cleared;
    - an error string returned by get_azure_response is shown directly
      instead of being iterated character-by-character (which raised an
      AttributeError that masked the real error);
    - errors during streaming are yielded, so they actually reach the UI.
    """
    if not message.strip():
        yield "", history
        return

    # Get response from Azure
    response = get_azure_response(message, history, temperature, max_tokens, top_p)

    # get_azure_response returns an "Error: ..." string (not a stream) on
    # failure; surface it as the assistant reply.
    if isinstance(response, str):
        history.append((message, response))
        yield "", history
        return

    # Add a placeholder exchange that we fill in as chunks arrive.
    history.append((message, ""))
    response_index = len(history) - 1
    full_response = ""
    try:
        print("Streaming response from Azure...")
        for chunk in response:
            if chunk.choices:
                content = chunk.choices[0].delta.content
                if content:
                    full_response += content
                    # Update the response in place and push it to the UI.
                    history[response_index] = (message, full_response)
                    yield "", history
        print("Streaming completed")
    except Exception as e:
        error_message = f"Error: {str(e)}"
        print(error_message)
        # Must *yield* (not return) so the error is displayed.
        history[response_index] = (message, error_message)
        yield "", history
# Reset handler for the "Clear" button.
def clear_conversation():
    """Empty the chat history and restore all sliders to their defaults."""
    slider_defaults = (default_temperature, default_max_tokens, default_top_p)
    return ([], *slider_defaults)
# Function to regenerate the last response.
def regenerate_response(history, temperature, max_tokens, top_p):
    """Re-run the model on the most recent user message, streaming a new reply.

    Yields updated history lists for Gradio to render.

    Fixes over the previous version (this function is a generator, so plain
    ``return value`` statements are discarded by Gradio): the empty-history
    and error paths now *yield*, so the UI always receives an update, and an
    error string returned by get_azure_response is displayed instead of being
    iterated character-by-character.
    """
    if not history:
        yield history
        return

    last_user_message = history[-1][0]
    # Drop the last exchange before asking again.
    history = history[:-1]

    # Get new response
    response = get_azure_response(last_user_message, history, temperature, max_tokens, top_p)

    # On failure get_azure_response returns an "Error: ..." string, not a stream.
    if isinstance(response, str):
        history.append((last_user_message, response))
        yield history
        return

    # Placeholder exchange, filled in as chunks arrive.
    history.append((last_user_message, ""))
    response_index = len(history) - 1
    full_response = ""
    try:
        for chunk in response:
            if chunk.choices:
                content = chunk.choices[0].delta.content
                if content:
                    full_response += content
                    history[response_index] = (last_user_message, full_response)
                    yield history
    except Exception as e:
        # Must *yield* (not return) so the error is displayed.
        history[response_index] = (last_user_message, f"Error: {str(e)}")
        yield history
# Set up event handlers wiring the UI to the handler functions.
# NOTE(review): Gradio event listeners must be attached inside the Blocks
# context — in the original (indentation-mangled) app these lines live inside
# the `with gr.Blocks(...) as demo:` block above; confirm placement.
# Enter submits; the Send button does the same.
msg.submit(chat, [msg, chatbot, temp_slider, max_tokens_slider, top_p_slider], [msg, chatbot])
send_btn.click(chat, [msg, chatbot, temp_slider, max_tokens_slider, top_p_slider], [msg, chatbot])
# Clear resets the chat and all three sliders.
clear.click(clear_conversation, None, [chatbot, temp_slider, max_tokens_slider, top_p_slider])
# Regenerate replays the last user message.
regenerate.click(regenerate_response, [chatbot, temp_slider, max_tokens_slider, top_p_slider], [chatbot])
# Launch the app
demo.launch()  # Set share=True to generate a public URL for testing