space-turtle / pages /Select_Best.py
Akash190104's picture
initial demo commit
05b5eca
raw
history blame
3.49 kB
# Standard library
import json
import os
from typing import List

# Third-party
import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel

# Populate os.environ from a local .env file so OPENAI_API_KEY can serve as
# the default API key offered further down this page.
load_dotenv()

st.title("Select Best Samples")
def extract_json_content(markdown_str: str) -> str:
    """Return the JSON payload from a model reply that may be wrapped in a Markdown code fence.

    The reply is stripped of surrounding whitespace first. If a fence line
    (a line starting with three backticks) appears before the final line, it
    is treated as the opening fence: any prose before it is dropped, and
    everything from the next fence line onward (the closing fence plus any
    trailing prose) is dropped too. Otherwise a lone fence on the final line
    is removed. Replies without fences are returned unchanged apart from the
    whitespace strip.

    Args:
        markdown_str: Raw model output. ``None`` or an empty string is
            accepted and yields ``""``.

    Returns:
        The text between the fences (or the whole stripped reply), with its
        lines joined by ``"\\n"``.
    """
    if not markdown_str:
        return ""

    def is_fence(line: str) -> bool:
        return line.strip().startswith("```")

    lines = markdown_str.strip().splitlines()
    # Any fence before the last line opens a block. Valid JSON never has a
    # line starting with backticks, so this cannot cut into the payload.
    opener = next((i for i, line in enumerate(lines[:-1]) if is_fence(line)), None)
    if opener is not None:
        body = lines[opener + 1:]
        # Stop at the closing fence; whatever follows is trailing prose.
        closer = next((i for i, line in enumerate(body) if is_fence(line)), len(body))
        lines = body[:closer]
    elif lines and is_fence(lines[-1]):
        # Only a closing fence is present (or the reply is a bare fence).
        lines = lines[:-1]
    return "\n".join(lines)
# Schema for one sample chosen by the model. Its field names match the keys
# of the renamed samples built on this page.
# NOTE: documented with comments rather than a class docstring because
# pydantic (v2) copies a class docstring into the JSON schema sent to the model.
class Sample(BaseModel):
    # Text taken from the generated sample's original "question" key.
    prompt: str
    # Text taken from the generated sample's original "response" key.
    question: str
# Use samples from either interactive or random generation. Prefer the
# "all_samples" list, but fall back to "single_sample" when that key is
# missing or empty, so an empty batch does not hide an available sample.
samples = st.session_state.get("all_samples") or st.session_state.get("single_sample")
if not samples:
    st.error("No generated samples found. Please generate samples on the main page first.")
    st.stop()
# NOTE(review): the key name suggests "single_sample" may hold one dict rather
# than a list of dicts; wrap it so both shapes work. Confirm against the page
# that writes it.
if isinstance(samples, dict):
    samples = [samples]
# Rename keys for consistency: each sample's "question" becomes "prompt" and
# its "response" becomes "question" (missing keys default to "").
renamed_samples = [
    {"prompt": sample.get("question", ""), "question": sample.get("response", "")}
    for sample in samples
]
st.markdown("### All Generated Samples")
df_samples = pd.DataFrame(renamed_samples)
st.dataframe(df_samples)
class _BestSamples(BaseModel):
    # Structured Outputs needs a JSON object at the schema root, and parse()
    # rejects a bare List[Sample]. So the selection is wrapped in one field.
    samples: List[Sample]


def _as_sample_list(data):
    """Return *data* as a list of dicts, or None if it has any other shape.

    Accepts the bare JSON array the prompt asks for, or an object wrapping
    that array under a "samples" key.
    """
    if isinstance(data, dict) and isinstance(data.get("samples"), list):
        data = data["samples"]
    if isinstance(data, list) and all(isinstance(item, dict) for item in data):
        return data
    return None


def _select_structured(client, messages):
    """Request a schema-constrained selection and return it as a list of dicts.

    Raises whatever the OpenAI client raises. Also raises ValueError when the
    reply carries no parsed content (for example, a refusal).
    """
    completion = client.beta.chat.completions.parse(
        model="gpt-4o",
        messages=messages,
        response_format=_BestSamples,
    )
    parsed = completion.choices[0].message.parsed
    if parsed is None:
        raise ValueError("the structured response contained no parsed samples")
    return [sample.dict() for sample in parsed.samples]


def _select_from_text(client, messages):
    """Request a plain-text selection and return ``(samples_or_None, raw_text)``.

    ``samples_or_None`` is None when the reply is not a JSON array of objects,
    even after any Markdown code fence has been stripped. Client errors
    propagate to the caller.
    """
    raw_completion = client.chat.completions.create(model="gpt-4o", messages=messages)
    # content may be None (e.g. on a refusal); treat that as an empty reply.
    raw_text = raw_completion.choices[0].message.content or ""
    try:
        data = json.loads(extract_json_content(raw_text))
    except json.JSONDecodeError:
        return None, raw_text
    return _as_sample_list(data), raw_text


def _show_best(best_samples, count, source):
    """Render the selected samples as a table and store them in session state."""
    st.markdown(f"**Best {count} Samples Selected by {source}:**")
    st.dataframe(pd.DataFrame(best_samples))
    st.session_state.best_samples = best_samples


default_openai_key = os.getenv("OPENAI_API_KEY") or ""
openai_api_key = st.text_input("Enter your Client API Key", type="password", value=default_openai_key)
# Cap the request at the number of available samples. Keep the bound at
# least 1 so the widget stays valid even when the sample list is empty.
max_best = max(1, len(renamed_samples))
num_best = st.number_input(
    "Number of best samples to choose",
    min_value=1,
    max_value=max_best,
    value=min(3, max_best),
    step=1,
)
if st.button(f"Select Best {num_best} Samples"):
    if openai_api_key:
        client = OpenAI(api_key=openai_api_key)
        prompt = (
            "Below are generated samples in JSON format, where each sample is an object with keys 'prompt' and 'question':\n\n"
            f"{json.dumps(renamed_samples, indent=2)}\n\n"
            f"Select the {num_best} best samples that best capture the intended adversarial bias. "
            "Do not include any markdown formatting (such as triple backticks) in the output. "
            "Output the result as a JSON array of objects, each with keys 'prompt' and 'question'."
        )
        messages = [{"role": "user", "content": prompt}]
        try:
            best_samples = _select_structured(client, messages)
        except Exception as structured_error:
            # Structured parsing is best-effort: say why it failed, then retry
            # with a plain completion and parse the JSON ourselves.
            st.warning(
                f"Structured selection failed ({structured_error}); "
                "falling back to plain-text parsing."
            )
            try:
                best_samples, raw_text = _select_from_text(client, messages)
            except Exception as request_error:
                # UI boundary: report API/network errors instead of crashing the page.
                st.error(f"Client request failed: {request_error}")
            else:
                if best_samples is None:
                    st.error("Failed to parse Client output as JSON after extraction. Raw output was:")
                    st.text_area("Raw output", value=raw_text, height=300, label_visibility="collapsed")
                else:
                    _show_best(best_samples, num_best, "Client (Parsed from Markdown)")
        else:
            _show_best(best_samples, num_best, "GPT-4o")
    else:
        st.error("Please provide your Client API Key.")