|
import gradio as gr |
|
import pixeltable as pxt |
|
from pixeltable.functions.mistralai import chat_completions |
|
from datetime import datetime |
|
from textblob import TextBlob |
|
import nltk |
|
from nltk.tokenize import word_tokenize |
|
from nltk.corpus import stopwords |
|
import os |
|
import getpass |
|
import re |
|
|
|
|
|
# Fetch the NLTK data used by extract_keywords(): the Punkt tokenizer behind
# word_tokenize() and the stopword lists behind stopwords.words().
# NLTK 3.9 and later load Punkt from the 'punkt_tab' resource instead of the
# pickled 'punkt' model, so both are requested to cover old and new releases.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
nltk.download('stopwords', quiet=True)
|
|
|
|
|
# Ask for the Mistral key interactively only when the environment does not
# already provide one. os.environ values are always strings, so a None lookup
# means the variable is absent.
if os.environ.get('MISTRAL_API_KEY') is None:
    os.environ['MISTRAL_API_KEY'] = getpass.getpass('Mistral AI API Key:')
|
|
|
|
|
@pxt.udf
def get_sentiment_score(text: str) -> float:
    """Return TextBlob's sentiment polarity for *text*.

    TextBlob documents polarity as a float from -1.0 (negative) to 1.0
    (positive).
    """
    analysis = TextBlob(text).sentiment
    polarity = analysis.polarity
    return polarity
|
|
|
@pxt.udf
def extract_keywords(text: str, num_keywords: int = 5) -> list:
    """Return the most frequent non-stopword tokens in *text*.

    The text is lower-cased and tokenized with NLTK's ``word_tokenize``.
    Tokens that are not purely alphanumeric (e.g. punctuation), or that appear
    in NLTK's English stopword list, are discarded.

    Args:
        text: Input text to analyze.
        num_keywords: Maximum number of keywords to return; a non-positive
            value yields an empty list.

    Returns:
        Up to ``num_keywords`` distinct lowercase tokens, most frequent first.
        Tokens with equal counts keep the order in which they first appear.
    """
    from collections import Counter

    stop_words = frozenset(stopwords.words('english'))
    tokens = word_tokenize(text.lower())
    # Counter tallies in one pass; the previous ``keywords.count`` sort key
    # rescanned the token list for every distinct word (O(n^2)). Ties were
    # also broken by set iteration order, which is hash-randomized for
    # strings. most_common() breaks ties by first occurrence instead.
    counts = Counter(
        token for token in tokens if token.isalnum() and token not in stop_words
    )
    return [word for word, _ in counts.most_common(num_keywords)]
|
|
|
@pxt.udf
def calculate_readability(text: str) -> float:
    """Return the Flesch Reading Ease score of *text* (higher = easier to read).

    Score = 206.835 - 1.015 * (words / sentences) - 84.6 * (syllables / words).

    Counting rules:
    - Words are runs of word characters.
    - A sentence is counted wherever a word character is immediately followed
      by '.', '!' or '?'. At least one sentence is assumed, so unterminated
      text still scores.
    - Syllables are estimated heuristically from vowel groups.

    Text with no words returns 206.835, as before, instead of dividing by zero.
    """
    words = re.findall(r'\w+', text)
    if not words:
        return 206.835
    sentences = len(re.findall(r'\w+[.!?]', text)) or 1

    def estimate_syllables(word: str) -> int:
        # Each run of vowels counts as one syllable. A trailing silent 'e'
        # ("make") is dropped unless it ends "-le"/"-ee" or is the only vowel
        # group ("the"). Words with no vowels (e.g. digits) count as one.
        lowered = word.lower()
        count = len(re.findall(r'[aeiouy]+', lowered))
        if count > 1 and lowered.endswith('e') and not lowered.endswith(('le', 'ee')):
            count -= 1
        return max(count, 1)

    words_per_sentence = len(words) / sentences
    syllables_per_word = sum(estimate_syllables(w) for w in words) / len(words)
    # The original omitted the syllable term, which pinned scores near 200
    # regardless of word complexity.
    return 206.835 - 1.015 * words_per_sentence - 84.6 * syllables_per_word
|
|
|
def _optional_int(value):
    """Convert a numeric UI value to ``int``, passing ``None`` through unchanged."""
    return None if value is None else int(value)


def _parse_stop_sequences(stop):
    """Split a comma-separated string into trimmed, non-empty stop sequences, or None."""
    if not stop:
        return None
    sequences = [part.strip() for part in stop.split(',') if part.strip()]
    return sequences or None


def _create_prompt_table():
    """Drop and recreate the ``mistral_prompts`` table and return it."""
    pxt.drop_table('mistral_prompts', ignore_errors=True)
    return pxt.create_table('mistral_prompts', {
        'task': pxt.StringType(),
        'system': pxt.StringType(),
        'input_text': pxt.StringType(),
        'timestamp': pxt.TimestampType(),
        'temperature': pxt.FloatType(),
        'top_p': pxt.FloatType(),
        'max_tokens': pxt.IntType(),
        # The UI lets these be left blank (the examples pass None), so the
        # columns must accept missing values.
        'min_tokens': pxt.IntType(nullable=True),
        'stop': pxt.StringType(nullable=True),
        'random_seed': pxt.IntType(nullable=True),
        'safe_prompt': pxt.BoolType(),
    })


def _add_analysis_columns(t, common_params):
    """Add computed columns for both models' completions, response text, and metrics."""
    t['open_mistral_nemo'] = chat_completions(model='open-mistral-nemo', **common_params)
    t['mistral_medium'] = chat_completions(model='mistral-medium', **common_params)

    t['omn_response'] = t.open_mistral_nemo.choices[0].message.content
    t['ml_response'] = t.mistral_medium.choices[0].message.content

    # "large_*" columns describe the mistral-medium response; "open_*" columns
    # describe the open-mistral-nemo response.
    t['large_sentiment_score'] = get_sentiment_score(t.ml_response)
    t['large_keywords'] = extract_keywords(t.ml_response)
    t['large_readability_score'] = calculate_readability(t.ml_response)
    t['open_sentiment_score'] = get_sentiment_score(t.omn_response)
    t['open_keywords'] = extract_keywords(t.omn_response)
    t['open_readability_score'] = calculate_readability(t.omn_response)


def run_inference_and_analysis(task, system_prompt, input_text, temperature, top_p,
                               max_tokens, min_tokens, stop, random_seed, safe_prompt):
    """Run one prompt through two Mistral models and score both responses.

    Steps:
    1. Recreate the ``mistral_prompts`` Pixeltable table.
    2. Insert one row holding the prompt and its settings.
    3. Add computed columns that call ``open-mistral-nemo`` and
       ``mistral-medium`` through ``chat_completions``.
    4. Derive sentiment, keyword and readability columns from each response.

    NOTE(review): the table is dropped and recreated on every call, so
    overlapping calls would interfere with each other.

    Args:
        task: Free-form task label stored with the row.
        system_prompt: System message sent to both models.
        input_text: User message sent to both models.
        temperature: Sampling temperature forwarded to the API.
        top_p: Nucleus-sampling value forwarded to the API.
        max_tokens: Completion token cap; ``None`` means 300.
        min_tokens: Optional minimum completion length; ``None`` leaves it unset.
        stop: Comma-separated stop sequences; whitespace around entries and
            blank entries are ignored.
        random_seed: Optional seed; ``None`` leaves it unset.
        safe_prompt: Forwarded as the ``safe_prompt`` flag.

    Returns:
        Tuple ``(omn_response, ml_response, large_sentiment, open_sentiment,
        large_keywords, open_keywords, large_readability, open_readability)``.
        "omn"/"open" refer to open-mistral-nemo; "ml"/"large" refer to
        mistral-medium.
    """
    # Gradio number widgets may deliver floats (e.g. 300.0) or None. The int
    # columns and the API parameters need real ints, or None where optional.
    max_tokens = 300 if max_tokens is None else int(max_tokens)
    min_tokens = _optional_int(min_tokens)
    random_seed = _optional_int(random_seed)
    temperature = float(temperature)
    top_p = float(top_p)

    t = _create_prompt_table()
    t.insert([{
        'task': task,
        'system': system_prompt,
        'input_text': input_text,
        'timestamp': datetime.now(),
        'temperature': temperature,
        'top_p': top_p,
        'max_tokens': max_tokens,
        'min_tokens': min_tokens,
        'stop': stop,
        'random_seed': random_seed,
        'safe_prompt': safe_prompt,
    }])

    msgs = [
        {'role': 'system', 'content': t.system},
        {'role': 'user', 'content': t.input_text},
    ]
    # Sampling settings are passed as literal values. The stored columns
    # record them for inspection but do not drive the calls.
    common_params = {
        'messages': msgs,
        'temperature': temperature,
        'top_p': top_p,
        'max_tokens': max_tokens,
        'min_tokens': min_tokens,
        'stop': _parse_stop_sequences(stop),
        'random_seed': random_seed,
        'safe_prompt': safe_prompt,
    }
    _add_analysis_columns(t, common_params)

    results = t.select(
        t.omn_response, t.ml_response,
        t.large_sentiment_score, t.open_sentiment_score,
        t.large_keywords, t.open_keywords,
        t.large_readability_score, t.open_readability_score,
    ).tail(1)

    return (
        results['omn_response'][0],
        results['ml_response'][0],
        results['large_sentiment_score'][0],
        results['open_sentiment_score'][0],
        results['large_keywords'][0],
        results['open_keywords'][0],
        results['large_readability_score'][0],
        results['open_readability_score'][0],
    )
|
|
|
def gradio_interface():
    """Build the "LLM Prompt Studio" Gradio app.

    Layout:
    - Left column: task, prompts, advanced sampling settings, examples and a
      submit button.
    - Right column: both models' responses and their sentiment, keyword and
      readability metrics.

    Returns:
        The ``gr.Blocks`` app; call ``.launch()`` on it to serve the UI.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# LLM Prompt Studio")

        # Output components are created up front with render=False so that
        # gr.Examples (in the left column) can reference them. They are placed
        # in the right column below via .render(). Referencing them before
        # assignment, as before, raised UnboundLocalError.
        omn_response = gr.Textbox(label="Open-Mistral-Nemo Response", render=False)
        ml_response = gr.Textbox(label="Mistral-Medium Response", render=False)
        large_sentiment = gr.Number(label="Mistral-Medium Sentiment", render=False)
        open_sentiment = gr.Number(label="Open-Mistral-Nemo Sentiment", render=False)
        large_keywords = gr.Textbox(label="Mistral-Medium Keywords", render=False)
        open_keywords = gr.Textbox(label="Open-Mistral-Nemo Keywords", render=False)
        large_readability = gr.Number(label="Mistral-Medium Readability", render=False)
        open_readability = gr.Number(label="Open-Mistral-Nemo Readability", render=False)
        # Order matches the tuple returned by run_inference_and_analysis.
        outputs = [
            omn_response, ml_response,
            large_sentiment, open_sentiment,
            large_keywords, open_keywords,
            large_readability, open_readability,
        ]

        with gr.Row():
            with gr.Column():
                task = gr.Textbox(label="Task")
                system_prompt = gr.Textbox(label="System Prompt", lines=3)
                input_text = gr.Textbox(label="Input Text", lines=3)

                with gr.Accordion("Advanced Settings", open=False):
                    temperature = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.1, label="Temperature")
                    top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="Top P")
                    # precision=0 makes Gradio deliver ints for these integer settings.
                    max_tokens = gr.Number(label="Max Tokens", value=300, precision=0)
                    min_tokens = gr.Number(label="Min Tokens", value=None, precision=0)
                    stop = gr.Textbox(label="Stop Sequences (comma-separated)")
                    random_seed = gr.Number(label="Random Seed", value=None, precision=0)
                    safe_prompt = gr.Checkbox(label="Safe Prompt", value=False)

                # Order matches run_inference_and_analysis's parameters.
                inputs = [
                    task, system_prompt, input_text,
                    temperature, top_p, max_tokens,
                    min_tokens, stop, random_seed,
                    safe_prompt,
                ]

                examples = [
                    ["Sentiment Analysis",
                     "You are an AI trained to analyze the sentiment of text. Provide a detailed analysis of the emotional tone, highlighting key phrases that indicate sentiment.",
                     "The new restaurant downtown exceeded all my expectations. The food was exquisite, the service impeccable, and the ambiance was perfect for a romantic evening. I can't wait to go back!",
                     0.3, 0.95, 200, None, "", None, False],
                    ["Story Generation",
                     "You are a creative writer. Generate a short, engaging story based on the given prompt. Include vivid descriptions and an unexpected twist.",
                     "In a world where dreams are shared, a young girl discovers she can manipulate other people's dreams.",
                     0.9, 0.8, 500, 300, "The end", None, False],
                ]

                gr.Examples(
                    examples=examples,
                    inputs=inputs,
                    outputs=outputs,
                    fn=run_inference_and_analysis,
                )

                submit_btn = gr.Button("Run Analysis")

            with gr.Column():
                omn_response.render()
                ml_response.render()

                with gr.Row():
                    large_sentiment.render()
                    open_sentiment.render()

                with gr.Row():
                    large_keywords.render()
                    open_keywords.render()

                with gr.Row():
                    large_readability.render()
                    open_readability.render()

        submit_btn.click(
            run_inference_and_analysis,
            inputs=inputs,
            outputs=outputs,
        )

    return demo
|
|
|
if __name__ == "__main__":
    # Script entry point: build the Blocks app, then start Gradio's server.
    app = gradio_interface()
    app.launch()