|
import os |
|
from dotenv import load_dotenv |
|
load_dotenv() |
|
|
|
import streamlit as st |
|
import pandas as pd |
|
import json |
|
from openai import OpenAI |
|
from pydantic import BaseModel |
|
from typing import List |
|
|
|
# Page heading rendered at the top of this Streamlit page.
st.title("Select Best Samples")
|
|
|
def extract_json_content(markdown_str: str) -> str:
    """Return *markdown_str* with one surrounding Markdown code fence removed.

    Surrounding whitespace is stripped first so that a reply which starts with
    a blank line, or ends with extra blank lines, still has its opening and
    closing fence recognised. Text without a fence is returned unchanged apart
    from that surrounding whitespace.

    Args:
        markdown_str: Raw text, optionally wrapped in a fence such as
            ```` ```json ... ``` ````.

    Returns:
        The text between the fences, with inner lines joined by ``"\\n"``.
    """
    lines = markdown_str.strip().splitlines()
    # Opening fence, possibly carrying a language tag such as "```json".
    if lines and lines[0].strip().startswith("```"):
        lines = lines[1:]
    # Closing fence.
    if lines and lines[-1].strip().startswith("```"):
        lines = lines[:-1]
    return "\n".join(lines)
|
|
|
# Schema for a single sample: an object with the two string keys "prompt" and
# "question", the same keys used for the rows shown on this page and requested
# from the model.
# Documented with comments rather than a docstring on purpose: pydantic copies
# a class docstring into the generated JSON schema as its description, which
# would change what is sent to the API.
class Sample(BaseModel):
    # Filled from the source sample's "question" value when samples are re-keyed.
    prompt: str
    # Filled from the source sample's "response" value when samples are re-keyed.
    question: str
|
|
|
|
|
# Pick up the samples left in session state, preferring the full batch over a
# single sample. Without either there is nothing to rank, so the page reports
# the problem and halts this script run.
if "all_samples" in st.session_state:
    samples = st.session_state.all_samples
elif "single_sample" in st.session_state:
    samples = st.session_state.single_sample
else:
    st.error("No generated samples found. Please generate samples on the main page first.")
    st.stop()

# NOTE(review): the shape stored under these keys is decided on another page
# and is not visible here. If a lone sample was stored as a bare dict, the
# code below would iterate over its keys (strings) and fail on ``.get``; wrap
# it so downstream code always sees a list of sample dicts.
if isinstance(samples, dict):
    samples = [samples]
|
|
|
|
|
def _to_prompt_question_row(entry):
    """Re-key one sample: its "question" becomes "prompt", its "response" becomes "question"."""
    return {
        "prompt": entry.get("question", ""),
        "question": entry.get("response", ""),
    }


# Re-keyed copies of every sample; missing values default to an empty string.
renamed_samples = list(map(_to_prompt_question_row, samples))

# Show the full set as a table under its own heading.
st.markdown("### All Generated Samples")
df_samples = pd.DataFrame(renamed_samples)
st.dataframe(df_samples)
|
|
|
# Pre-fill the key field from the environment; an unset variable yields "".
default_openai_key = os.getenv("OPENAI_API_KEY", "")
openai_api_key = st.text_input(
    "Enter your Client API Key",
    value=default_openai_key,
    type="password",
)

# How many samples the model should be asked to pick (at least one).
num_best = st.number_input(
    "Number of best samples to choose",
    min_value=1,
    value=3,
    step=1,
)
|
|
|
# The SDK's parse() helper only accepts a pydantic model class as
# ``response_format`` (a bare ``List[Sample]`` is rejected with TypeError
# before any request is sent), and Structured Outputs require a JSON object at
# the schema root. The list is therefore wrapped in a single-field model.
# Documented with comments rather than a docstring so nothing extra ends up in
# the schema sent to the API.
class BestSamples(BaseModel):
    samples: List[Sample]


def _parse_samples_json(raw_text):
    """Decode free-form model output into a list of sample dicts.

    Raises:
        ValueError: if the text is not valid JSON (``json.JSONDecodeError`` is
            a ``ValueError``) or is not a JSON array of objects.
    """
    decoded = json.loads(extract_json_content(raw_text))
    if not isinstance(decoded, list) or not all(isinstance(item, dict) for item in decoded):
        raise ValueError("expected a JSON array of objects")
    return decoded


def _show_best_samples(heading, chosen):
    """Render the chosen samples under *heading* and store them in session state."""
    st.markdown(heading)
    st.dataframe(pd.DataFrame(chosen))
    st.session_state.best_samples = chosen


if st.button(f"Select Best {num_best} Samples"):
    if not openai_api_key:
        st.error("Please provide your Client API Key.")
    else:
        client = OpenAI(api_key=openai_api_key)
        prompt = (
            "Below are generated samples in JSON format, where each sample is an object with keys 'prompt' and 'question':\n\n"
            f"{json.dumps(renamed_samples, indent=2)}\n\n"
            f"Select the {num_best} best samples that best capture the intended adversarial bias. "
            "Do not include any markdown formatting (such as triple backticks) in the output. "
            "Output the result as a JSON array of objects, each with keys 'prompt' and 'question'."
        )
        messages = [{"role": "user", "content": prompt}]

        # First attempt: schema-constrained output parsed by the SDK.
        best_samples = None
        try:
            completion = client.beta.chat.completions.parse(
                model="gpt-4o",
                messages=messages,
                response_format=BestSamples,
            )
            parsed = completion.choices[0].message.parsed
            # ``parsed`` is None when the model returns a refusal instead of data.
            if parsed is not None:
                # ``.dict()`` is kept because it exists in both pydantic v1 and v2.
                best_samples = [s.dict() for s in parsed.samples]
        except Exception:
            # Deliberate best effort: whatever went wrong with the structured
            # request, fall through to the plain-text request below.
            best_samples = None

        if best_samples is not None:
            _show_best_samples(
                f"**Best {num_best} Samples Selected by GPT-4o:**",
                best_samples,
            )
        else:
            # Fallback: plain completion whose text is parsed by hand.
            raw_text = None
            try:
                raw_completion = client.chat.completions.create(
                    model="gpt-4o",
                    messages=messages,
                )
                # ``content`` may be None; treat that as empty output.
                raw_text = raw_completion.choices[0].message.content or ""
            except Exception as exc:
                # UI boundary: report API failures (bad key, network, rate
                # limit) instead of letting the page crash with a traceback.
                st.error(f"Request to the Client API failed: {exc}")

            if raw_text is not None:
                try:
                    best_samples = _parse_samples_json(raw_text)
                except ValueError:
                    st.error("Failed to parse Client output as JSON after extraction. Raw output was:")
                    # A non-empty label avoids Streamlit's empty-label warning;
                    # it is collapsed so the layout is unchanged.
                    st.text_area(
                        "Raw output",
                        value=raw_text,
                        height=300,
                        label_visibility="collapsed",
                    )
                else:
                    _show_best_samples(
                        f"**Best {num_best} Samples Selected by Client (Parsed from Markdown):**",
                        best_samples,
                    )