import gradio as gr
import json
import tempfile
import os
import re  # For parsing conversation
from typing import Union, Optional, Dict, List, Tuple  # Dict, List and Tuple for type hints
# Import the actual functions from synthgen
from synthgen import (
    generate_synthetic_text,
    generate_prompts,
    generate_synthetic_conversation,
    generate_corpus_content  # Import the new function
)
# We no longer need to import api_key here or check it directly in app.py

# --- Helper Functions for JSON Generation ---

# Use Union for Python < 3.10 compatibility
def create_json_file(data: object, base_filename: str) -> Union[str, None]:
    """Creates a temporary JSON file and returns its path."""
    try:
        # Create a temporary file with a .json extension
        with tempfile.NamedTemporaryFile(mode='w', suffix=".json", delete=False, encoding='utf-8') as temp_file:
            json.dump(data, temp_file, indent=4, ensure_ascii=False)
            return temp_file.name  # Return the path to the temporary file
    except Exception as e:
        print(f"Error creating JSON file {base_filename}: {e}")
        return None

# Add the missing function definition
def create_text_file(data: str, base_filename: str) -> Union[str, None]:
    """Creates a temporary text file and returns its path."""
    try:
        # Ensure filename ends with .txt
        if not base_filename.lower().endswith(".txt"):
            base_filename += ".txt"  # Append if missing for clarity, though suffix handles it
        # Create a temporary file with a .txt extension
        with tempfile.NamedTemporaryFile(mode='w', suffix=".txt", delete=False, encoding='utf-8') as temp_file:
            temp_file.write(data)
            return temp_file.name  # Return the path to the temporary file
    except Exception as e:
        print(f"Error creating text file {base_filename}: {e}")
        return None

def parse_conversation_string(text: str) -> List[Dict[str, str]]:
"""Parses a multi-line conversation string into a list of message dictionaries.""" | |
messages = [] | |
# Regex to capture "User:" or "Assistant:" at the start of a line, followed by content | |
pattern = re.compile(r"^(User|Assistant):\s*(.*)$", re.IGNORECASE | re.MULTILINE) | |
matches = pattern.finditer(text) | |
for match in matches: | |
role = match.group(1).lower() | |
content = match.group(2).strip() | |
messages.append({"role": role, "content": content}) | |
    # If the regex matched nothing but text exists, the format was unexpected.
    # Warn and return an empty list; callers treat that as a parsing failure and
    # keep the raw text instead. (An alternative would be to wrap the raw text in
    # a single fallback message.)
    if not messages and text:
        print(f"Warning: Could not parse conversation structure for: '{text[:100]}...'")
    return messages
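
# Illustrative example (for reference, not executed): given the text
#   "User: Hi there\nAssistant: Hello! How can I help?"
# parse_conversation_string returns:
#   [{"role": "user", "content": "Hi there"},
#    {"role": "assistant", "content": "Hello! How can I help?"}]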

# Wrapper for text generation (remains largely the same, but error handling is improved in synthgen)
def run_generation(prompt: str, model: str, num_samples: int) -> str:
    """
    Wrapper function for Gradio interface to generate multiple text samples.
    Relies on generate_synthetic_text for API calls and error handling.
    """
    if not prompt:
        return "Error: Please enter a prompt."
    if num_samples <= 0:
        return "Error: Number of samples must be positive."
    output = f"Generating {num_samples} samples using model '{model}'...\n"
    output += "="*20 + "\n\n"
    # generate_synthetic_text now handles API errors internally
    for i in range(num_samples):
        # The function returns the text or an error string starting with "Error:"
        generated_text = generate_synthetic_text(prompt, model)
        output += f"--- Sample {i+1} ---\n"
        output += generated_text + "\n\n"  # Append result directly
    output += "="*20 + "\nGeneration complete (check results above for errors)."
    return output

# Removed the placeholder backend functions (generate_prompts_backend, generate_single_conversation)

# Modified function to handle multiple conversation prompts using the real backend
def run_conversation_generation(system_prompts_text: str, model: str, num_turns: int) -> str:
    """
    Wrapper function for Gradio interface to generate multiple conversations
    based on a list of prompts, calling generate_synthetic_conversation.
    """
    if not system_prompts_text:
        return "Error: Please enter or generate at least one system prompt/topic."
    if num_turns <= 0:
        return "Error: Number of turns must be positive."
    prompts = [p.strip() for p in system_prompts_text.strip().split('\n') if p.strip()]
    if not prompts:
        return "Error: No valid prompts found in the input."
    output = f"Generating {len(prompts)} conversations ({num_turns} turns each) using model '{model}'...\n"
    output += "="*40 + "\n\n"
    for i, prompt in enumerate(prompts):
        # Call the actual function from synthgen.py
        # It handles API calls and returns the conversation or an error string.
        conversation_text = generate_synthetic_conversation(prompt, model, num_turns)
        # We don't need a try-except here because the function itself returns error strings
        # The title is now included within the returned string from the function
        output += f"--- Conversation {i+1}/{len(prompts)} ---\n"
        output += conversation_text + "\n\n"  # Append result directly
    output += "="*40 + "\nGeneration complete (check results above for errors)."
    return output

# Helper function for the Gradio UI to generate prompts using the real backend
def generate_prompts_ui(
    num_prompts: int,
    model: str,
    temperature: float,  # Add settings
    top_p: float,
    max_tokens: int
) -> str:
    """UI Wrapper to call the generate_prompts backend and format for Textbox."""
    # Handle optional settings
    temp_val = temperature if temperature > 0 else None
    top_p_val = top_p if 0 < top_p <= 1 else None
    # max_tokens for prompt generation comes from the UI; fall back to 200 if it is 0
    max_tokens_val = max_tokens if max_tokens > 0 else 200
    if not model:
        return "Error: Please select a model for prompt generation."
    if num_prompts <= 0:
        return "Error: Number of prompts to generate must be positive."
    if num_prompts > 50:
        return "Error: Cannot generate more than 50 prompts at a time."
    print(f"Generating prompts with settings: Temp={temp_val}, Top-P={top_p_val}, MaxTokens={max_tokens_val}")  # Debug print
    try:
        # Call the actual function from synthgen.py, passing settings
        prompts_list = generate_prompts(
            num_prompts,
            model,
            temperature=temp_val,
            top_p=top_p_val,
            max_tokens=max_tokens_val
        )
        return "\n".join(prompts_list)
    except ValueError as e:
        # Catch errors raised by generate_prompts (e.g., API errors, parsing errors)
        return f"Error generating prompts: {e}"
    except Exception as e:
        # Catch any other unexpected errors
        print(f"Unexpected error in generate_prompts_ui: {e}")
        return f"An unexpected error occurred: {e}"

# --- Modified Generation Wrappers ---

# Wrapper for text generation + JSON preparation - RETURNS TUPLE
def run_generation_and_prepare_json(
    prompt: str,
    model: str,
    num_samples: int,
    temperature: float,
    top_p: float,
    max_tokens: int
) -> Tuple[gr.update, gr.update]:  # Return type hint (optional)
    """Generates text samples and prepares a JSON file for download."""
    # Handle optional settings
    temp_val = temperature if temperature > 0 else None
    top_p_val = top_p if 0 < top_p <= 1 else None
    max_tokens_val = max_tokens if max_tokens > 0 else None
    # Handle errors by returning updates for both outputs in a tuple
    if not prompt:
        return (gr.update(value="Error: Please enter a prompt."), gr.update(value=None))
    if num_samples <= 0:
        return (gr.update(value="Error: Number of samples must be positive."), gr.update(value=None))
    output_str = f"Generating {num_samples} samples using model '{model}'...\n"
    output_str += f"(Settings: Temp={temp_val}, Top-P={top_p_val}, MaxTokens={max_tokens_val})\n"
    output_str += "="*20 + "\n\n"
    results_list = []
    for i in range(num_samples):
        generated_text = generate_synthetic_text(
            prompt, model, temperature=temp_val, top_p=top_p_val, max_tokens=max_tokens_val
        )
        output_str += f"--- Sample {i+1} ---\n"
        output_str += generated_text + "\n\n"
        if not generated_text.startswith("Error:"):
            results_list.append(generated_text)
    output_str += "="*20 + "\nGeneration complete (check results above for errors)."
    json_filepath = create_json_file(results_list, "text_samples.json")
    # Return tuple of updates in the order of outputs list
    return (gr.update(value=output_str), gr.update(value=json_filepath))
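
# Note: the text_samples.json written above is simply a JSON list of the successfully
# generated sample strings; samples that came back as "Error: ..." are skipped.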

# Wrapper for conversation generation + JSON preparation - RETURNS TUPLE
def run_conversation_generation_and_prepare_json(
    system_prompts_text: str,
    model: str,
    num_turns: int,
    temperature: float,
    top_p: float,
    max_tokens: int
) -> Tuple[gr.update, gr.update]:  # Return type hint (optional)
    """Generates conversations and prepares a JSON file for download."""
    temp_val = temperature if temperature > 0 else None
    top_p_val = top_p if 0 < top_p <= 1 else None
    max_tokens_val = max_tokens if max_tokens > 0 else None
    # Handle errors by returning updates for both outputs in a tuple
    if not system_prompts_text:
        return (gr.update(value="Error: Please enter or generate at least one system prompt/topic."), gr.update(value=None))
    if num_turns <= 0:
        return (gr.update(value="Error: Number of turns must be positive."), gr.update(value=None))
    prompts = [p.strip() for p in system_prompts_text.strip().split('\n') if p.strip()]
    if not prompts:
        return (gr.update(value="Error: No valid prompts found in the input."), gr.update(value=None))
    output_str = f"Generating {len(prompts)} conversations ({num_turns} turns each) using model '{model}'...\n"
    output_str += f"(Settings: Temp={temp_val}, Top-P={top_p_val}, MaxTokens={max_tokens_val})\n"
    output_str += "="*40 + "\n\n"
    results_list_structured = []
    for i, prompt in enumerate(prompts):
        conversation_text = generate_synthetic_conversation(
            prompt, model, num_turns, temperature=temp_val, top_p=top_p_val, max_tokens=max_tokens_val
        )
        output_str += f"--- Conversation {i+1}/{len(prompts)} ---\n"
        output_str += conversation_text + "\n\n"
        # --- Parsing Logic ---
        core_conversation_text = conversation_text
        if conversation_text.startswith("Error:"):
            core_conversation_text = None
        elif "\n\n" in conversation_text:
            parts = conversation_text.split("\n\n", 1)
            core_conversation_text = parts[1] if len(parts) > 1 else conversation_text
        if core_conversation_text:
            messages = parse_conversation_string(core_conversation_text)
            if messages:
                results_list_structured.append({"prompt": prompt, "messages": messages})
            else:
                results_list_structured.append({"prompt": prompt, "error": "Failed to parse structure.", "raw_text": core_conversation_text})
        elif conversation_text.startswith("Error:"):
            results_list_structured.append({"prompt": prompt, "error": conversation_text})
        else:
            results_list_structured.append({"prompt": prompt, "error": "Could not extract content.", "raw_text": conversation_text})
        # --- End Parsing Logic ---
output_str += "="*40 + "\nGeneration complete (check results above for errors)." | |
json_filepath = create_json_file(results_list_structured, "conversations.json") | |
# Return tuple of updates in the order of outputs list | |
return (gr.update(value=output_str), gr.update(value=json_filepath)) | |
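
# Note: conversations.json holds one object per input prompt, shaped roughly like
#   {"prompt": "...", "messages": [{"role": "user"|"assistant", "content": "..."}]}
# or, when generation/parsing failed, {"prompt": "...", "error": "...", "raw_text": "..."}.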

# Define content_type_labels globally for use in UI and wrapper functions
content_type_labels = {
    "Corpus Snippets": "# Snippets",
    "Short Story": "Approx Words",
    "Article": "Approx Words"
}
content_type_defaults = {
    "Corpus Snippets": 5,
    "Short Story": 1000,  # Match new backend default
    "Article": 1500  # Match new backend default
}

# Wrapper for Corpus/Content Generation
def run_corpus_generation_and_prepare_file(
    topic: str,
    content_type: str,
    length_param: int,
    model: str,
    temperature: float,
    top_p: float,
    max_tokens: int
) -> Tuple[gr.update, gr.update]:
    """Generates corpus/story/article content and prepares a file for download."""
    temp_val = temperature if temperature > 0 else None
    top_p_val = top_p if 0 < top_p <= 1 else None
    max_tokens_val = max_tokens if max_tokens > 0 else None
    # Use the global dictionary for error messages
    label_for_error = content_type_labels.get(content_type, 'Length Param')
    if not topic:
        return (gr.update(value="Error: Please enter a topic."), gr.update(value=None))
    if not content_type:
        return (gr.update(value="Error: Please select a content type."), gr.update(value=None))
    if length_param <= 0:
        return (gr.update(value=f"Error: Please enter a positive value for '{label_for_error}'."), gr.update(value=None))
    print(f"Generating {content_type} about '{topic}'...")
    output_str = f"Generating {content_type} about '{topic}' using model '{model}'...\n"
    output_str += f"(Settings: Temp={temp_val}, Top-P={top_p_val}, MaxTokens={max_tokens_val})\n" + "="*40 + "\n\n"
    generated_content = generate_corpus_content(
        topic=topic, content_type=content_type, length_param=length_param, model=model,
        temperature=temp_val, top_p=top_p_val, max_tokens=max_tokens_val
    )
    output_str += generated_content
    file_path = None
    if not generated_content.startswith("Error:"):
        core_content = generated_content
        if "\n\n" in generated_content:
            parts = generated_content.split("\n\n", 1)
            core_content = parts[1] if len(parts) > 1 else generated_content
        if content_type == "Corpus Snippets":
            snippets = [s.strip() for s in core_content.split('---') if s.strip()]
            if not snippets:
                snippets = [s.strip() for s in core_content.split('\n\n') if s.strip()]
            corpus_data = {"topic": topic, "snippets": snippets}
            file_path = create_json_file(corpus_data, f"{topic}_corpus.json")
        else:
            file_path = create_text_file(core_content, f"{topic}_{content_type.replace(' ','_')}.txt")
    return (gr.update(value=output_str), gr.update(value=file_path))
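
# Note: for "Corpus Snippets" the download is a JSON file of the form
#   {"topic": "...", "snippets": ["...", ...]}
# while "Short Story" and "Article" are saved as plain .txt files.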

# NEW function to update the length parameter label and default value
def update_length_param_ui(content_type: str) -> gr.update:
    """Updates the label and default value of the length parameter input."""
    new_label = content_type_labels.get(content_type, "Length Param")
    new_value = content_type_defaults.get(content_type, 5)  # Default to 5 if type unknown
    return gr.update(label=new_label, value=new_value)

# --- Generation Wrappers ---
# ... (generate_prompts_ui, run_generation_and_prepare_json, run_conversation_generation_and_prepare_json remain the same) ...

# NEW UI Wrapper for generating TOPICS
def generate_topics_ui(
    num_topics: int,
    model: str,
    temperature: float,
    top_p: float,
    max_tokens: int
) -> str:
    """UI Wrapper to generate diverse topics using the AI."""
    temp_val = temperature if temperature > 0 else None
    top_p_val = top_p if 0 < top_p <= 1 else None
    max_tokens_val = max_tokens if max_tokens > 0 else 150  # Limit tokens for the topic list
    if not model:
        return "Error: Please select a model for topic generation."
    if num_topics <= 0:
        return "Error: Number of topics to generate must be positive."
    if num_topics > 50:  # Keep limit reasonable
        return "Error: Cannot generate more than 50 topics at a time."
    print(f"Generating {num_topics} topics with settings: Temp={temp_val}, Top-P={top_p_val}, MaxTokens={max_tokens_val}")
    # Instruction focused on generating topics
    instruction = (
        f"Generate exactly {num_topics} diverse and interesting topics suitable for generating content like articles, stories, or corpus snippets. "
        f"Each topic should be concise (a few words to a short phrase). "
        f"Present each topic on a new line, with no other introductory or concluding text or numbering."
        f"\n\nExamples:\n"
        f"The future of renewable energy\n"
        f"The history of the Silk Road\n"
        f"The impact of social media on mental health"
    )
    system_msg = "You are an expert topic generator. Follow the user's instructions precisely."
    try:
        # Use the core text generation function
        generated_text = generate_synthetic_text(
            instruction,
            model,
            system_message=system_msg,
            temperature=temp_val,
            top_p=top_p_val,
            max_tokens=max_tokens_val
        )
        if generated_text.startswith("Error:"):
            raise ValueError(generated_text)  # Propagate error
        # Split into lines and clean up
        topics_list = [t.strip() for t in generated_text.strip().split('\n') if t.strip()]
        if not topics_list:
            print(f"Warning: Failed to parse topics from generated text. Raw text:\n{generated_text}")
            raise ValueError("AI failed to generate topics in the expected format.")
        # Return newline-separated string for the Textbox
        return "\n".join(topics_list[:num_topics])  # Truncate if needed
    except ValueError as e:
        return f"Error generating topics: {e}"
    except Exception as e:
        print(f"Unexpected error in generate_topics_ui: {e}")
        return f"An unexpected error occurred: {e}"

# Modified Wrapper for Bulk Corpus/Content Generation
def run_bulk_content_generation_and_prepare_json(
    topics_text: str,  # Renamed from topic
    content_type: str,
    length_param: int,
    model: str,
    temperature: float,
    top_p: float,
    max_tokens: int
) -> Tuple[gr.update, gr.update]:
    """Generates content for multiple topics and prepares a JSON file."""
    temp_val = temperature if temperature > 0 else None
    top_p_val = top_p if 0 < top_p <= 1 else None
    max_tokens_val = max_tokens if max_tokens > 0 else None
    # --- Input Validation ---
    if not topics_text:
        return (gr.update(value="Error: Please enter or generate at least one topic."), gr.update(value=None))
    if not content_type:
        return (gr.update(value="Error: Please select a content type."), gr.update(value=None))
    topics = [t.strip() for t in topics_text.strip().split('\n') if t.strip()]
    if not topics:
        return (gr.update(value="Error: No valid topics found in the input."), gr.update(value=None))
    label_for_error = content_type_labels.get(content_type, 'Length Param')
    if length_param <= 0:
        return (gr.update(value=f"Error: Please enter a positive value for '{label_for_error}'."), gr.update(value=None))
    # --- End Validation ---
    output_str = f"Generating {content_type} for {len(topics)} topics using model '{model}'...\n"
    output_str += f"(Settings: Temp={temp_val}, Top-P={top_p_val}, MaxTokens={max_tokens_val})\n" + "="*40 + "\n\n"
    bulk_results = []  # Store results for JSON
    # --- Loop through topics ---
    for i, topic in enumerate(topics):
        print(f"Generating {content_type} for topic {i+1}/{len(topics)}: '{topic}'...")
        output_str += f"--- Topic {i+1}/{len(topics)}: '{topic}' ---\n"
        generated_content_full = generate_corpus_content(  # Returns string including title/error
            topic=topic, content_type=content_type, length_param=length_param, model=model,
            temperature=temp_val, top_p=top_p_val, max_tokens=max_tokens_val
        )
        output_str += generated_content_full + "\n\n"  # Add full result to textbox
        # --- Prepare structured result for JSON ---
        result_entry = {"topic": topic, "content_type": content_type}
        if generated_content_full.startswith("Error:"):
            result_entry["status"] = "error"
            result_entry["error_message"] = generated_content_full
            result_entry["content"] = None
        else:
            result_entry["status"] = "success"
            result_entry["error_message"] = None
            # Extract core content (remove potential title added by backend)
            core_content = generated_content_full
            if "\n\n" in generated_content_full:
                parts = generated_content_full.split("\n\n", 1)
                core_content = parts[1] if len(parts) > 1 else generated_content_full
            if content_type == "Corpus Snippets":
                snippets = [s.strip() for s in core_content.split('---') if s.strip()]
                if not snippets:
                    snippets = [s.strip() for s in core_content.split('\n\n') if s.strip()]
                result_entry["content"] = snippets  # Store list for corpus
            else:
                result_entry["content"] = core_content  # Store string for story/article
        bulk_results.append(result_entry)
        # --- End JSON preparation ---
    # --- Finalize ---
    output_str += "="*40 + f"\nBulk generation complete for {len(topics)} topics."
    json_filepath = create_json_file(bulk_results, f"{content_type.replace(' ','_')}_bulk_results.json")
    return (gr.update(value=output_str), gr.update(value=json_filepath))
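
# Note: each entry in the bulk-results JSON has the shape
#   {"topic": "...", "content_type": "...", "status": "success" | "error",
#    "error_message": "..." | None, "content": <snippet list, text, or None>}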

# --- Gradio Interface Definition ---
with gr.Blocks() as demo:
    gr.Markdown("# Synthetic Data Generator using OpenRouter")
    gr.Markdown(
        "Generate synthetic text samples, conversations, or other content using various models"
    )
    # Removed the api_key_loaded check and warning Markdown

    # Define model choices (can be shared or specific per tab)
    # Consider fetching these dynamically from OpenRouter if possible in the future
    model_choices = [
        "deepseek/deepseek-chat-v3-0324:free",  # Example free model
        "meta-llama/llama-3.3-70b-instruct:free",
        "deepseek/deepseek-r1:free",
        "google/gemini-2.5-pro-exp-03-25:free",
        "qwen/qwen-2.5-72b-instruct:free",
        "featherless/qwerky-72b:free",
        "google/gemma-3-27b-it:free",
        "mistralai/mistral-small-24b-instruct-2501:free",
        "deepseek/deepseek-r1-distill-llama-70b:free",
        "sophosympatheia/rogue-rose-103b-v0.2:free",
        "nvidia/llama-3.1-nemotron-70b-instruct:free",
        "microsoft/phi-3-medium-128k-instruct:free",
        "undi95/toppy-m-7b:free",
        "huggingfaceh4/zephyr-7b-beta:free",
        "openrouter/quasar-alpha"
        # Add more model IDs as needed
    ]
    default_model = model_choices[0] if model_choices else None

    # --- Shared Model Settings ---
    # Use an Accordion for less clutter
    with gr.Accordion("Model Settings (Optional)", open=False):
        # Set reasonable ranges and defaults. Use 0 for Max Tokens/Top-P to signify 'None'/API default.
        temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature", info="Controls randomness. Higher values are more creative, lower are more deterministic. 0 means use API default.")
        top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.05, label="Top-P (Nucleus Sampling)", info="Considers only tokens with cumulative probability mass >= top_p. 0 means use API default.")
        max_tokens_slider = gr.Number(value=0, minimum=0, maximum=8192, step=64, label="Max Tokens", info="Maximum number of tokens to generate in the completion. 0 means use API default.")

    with gr.Tabs():
        with gr.TabItem("Text Generation"):
            with gr.Row():
                prompt_input_text = gr.Textbox(label="Prompt", placeholder="Enter your prompt here (e.g., Generate a short product description for a sci-fi gadget)", lines=3)
            with gr.Row():
                model_input_text = gr.Dropdown(
                    label="OpenRouter Model ID",
                    choices=model_choices,
                    value=default_model
                )
                num_samples_input_text = gr.Number(label="Number of Samples", value=3, minimum=1, maximum=20, step=1)
            generate_button_text = gr.Button("Generate Text Samples")
            output_text = gr.Textbox(label="Generated Samples", lines=15, show_copy_button=True)
            # Add File component for download
            download_file_text = gr.File(label="Download Samples as JSON")
            generate_button_text.click(
                fn=run_generation_and_prepare_json,
                inputs=[
                    prompt_input_text, model_input_text, num_samples_input_text,
                    temperature_slider, top_p_slider, max_tokens_slider  # Add settings inputs
                ],
                outputs=[output_text, download_file_text]
            )
with gr.TabItem("Conversation Generation"): | |
gr.Markdown("Enter one system prompt/topic per line below, or use the 'Generate Prompts' button.") | |
with gr.Row(): | |
# Textbox for multiple prompts | |
prompt_input_conv = gr.Textbox( | |
label="Prompts (one per line)", | |
lines=5, # Make it multi-line | |
placeholder="Enter prompts here, one per line...\ne.g., Act as a pirate discussing treasure maps.\nDiscuss the future of space travel." | |
) | |
with gr.Row(): | |
# Input for number of prompts to generate | |
num_prompts_input_conv = gr.Number(label="Number of Prompts to Generate", value=5, minimum=1, maximum=20, step=1) # Keep max reasonable | |
# Button to trigger AI prompt generation | |
generate_prompts_button = gr.Button("Generate Prompts using AI") | |
with gr.Row(): | |
# Model selection for conversation generation AND prompt generation | |
model_input_conv = gr.Dropdown( | |
label="OpenRouter Model ID (for generation)", | |
choices=model_choices, | |
value=default_model | |
) | |
with gr.Row(): | |
# Input for number of turns per conversation | |
num_turns_input_conv = gr.Number(label="Number of Turns per Conversation (approx)", value=5, minimum=1, maximum=20, step=1) # Keep max reasonable | |
# Button to generate the conversations based on the prompts in the Textbox | |
generate_conversations_button = gr.Button("Generate Conversations") | |
output_conv = gr.Textbox(label="Generated Conversations", lines=15, show_copy_button=True) | |
# Add File component for download | |
download_file_conv = gr.File(label="Download Conversations as JSON") | |
# Connect the "Generate Prompts" button to the UI wrapper | |
generate_prompts_button.click( | |
fn=generate_prompts_ui, # Use the wrapper that calls the real function | |
inputs=[ | |
num_prompts_input_conv, model_input_conv, | |
temperature_slider, top_p_slider, max_tokens_slider # Add settings inputs | |
], | |
outputs=prompt_input_conv | |
) | |
# Connect the "Generate Conversations" button to the real function wrapper | |
generate_conversations_button.click( | |
fn=run_conversation_generation_and_prepare_json, # Use the wrapper that calls the real function | |
inputs=[ | |
prompt_input_conv, model_input_conv, num_turns_input_conv, | |
temperature_slider, top_p_slider, max_tokens_slider # Add settings inputs | |
], | |
outputs=[output_conv, download_file_conv] # Output to both Textbox and File | |
) | |

        # --- Content Generation Tab (Modified for Bulk) ---
        with gr.TabItem("Bulk Content Generation"):
            output_content = gr.Textbox(label="Generated Content (Log)", lines=15, show_copy_button=True)
            # Output is now always JSON
            download_file_content = gr.File(label="Download Results as JSON")
            gr.Markdown("Enter one topic per line below, or use the 'Generate Topics' button.")
            with gr.Row():
                # Changed to multi-line Textbox
                topic_input_content = gr.Textbox(
                    label="Topics (one per line)",
                    lines=5,
                    placeholder="Enter topics here, one per line...\ne.g., The future of renewable energy\nThe history of the Silk Road"
                )
            # --- Topic Generation ---
            with gr.Accordion("Topic Generation Options", open=False):
                with gr.Row():
                    num_topics_input = gr.Number(label="# Topics to Generate", value=5, minimum=1, maximum=50, step=1)
                    # Use shared model selector below and settings
                    generate_topics_button = gr.Button("Generate Topics using AI")
            # --- Generation Settings ---
            with gr.Row():
                content_type_choices = list(content_type_labels.keys())
                content_type_input = gr.Dropdown(
                    label="Content Type", choices=content_type_choices, value=content_type_choices[0]
                )
                default_length_label = content_type_labels[content_type_choices[0]]
                default_length_value = content_type_defaults[content_type_choices[0]]
                length_param_input = gr.Number(
                    label=default_length_label, value=default_length_value, minimum=1, step=1
                )
            with gr.Row():
                model_input_content = gr.Dropdown(label="Model", choices=model_choices, value=default_model)
            # Button to trigger bulk generation
            generate_content_button = gr.Button("Generate Bulk Content")

            # --- Event Listeners ---
            # Listener to update length param UI
            content_type_input.change(
                fn=update_length_param_ui,
                inputs=content_type_input,
                outputs=length_param_input
            )
            # Listener for topic generation button
            generate_topics_button.click(
                fn=generate_topics_ui,
                inputs=[  # Pass necessary inputs for topic generation
                    num_topics_input, model_input_content,  # Use this tab's model selector
                    temperature_slider, top_p_slider, max_tokens_slider
                ],
                outputs=topic_input_content  # Output generated topics to the textbox
            )
            # Listener for main generation button
            generate_content_button.click(
                fn=run_bulk_content_generation_and_prepare_json,  # Use the new bulk wrapper
                inputs=[
                    topic_input_content, content_type_input, length_param_input,
                    model_input_content,
                    temperature_slider, top_p_slider, max_tokens_slider
                ],
                outputs=[output_content, download_file_content]
            )

# Launch the Gradio app
if __name__ == "__main__":
    print("Launching Gradio App...")
    print("Make sure the OPENROUTER_API_KEY environment variable is set.")
    # Use share=True for temporary public link if running locally and need to test
    demo.launch()  # share=True
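
# Typical local run (assumes synthgen.py sits alongside this file and that an
# OpenRouter API key has been exported beforehand; the key value below is a placeholder):
#   export OPENROUTER_API_KEY="sk-or-..."
#   python app.py
# Pass share=True to demo.launch() above to get a temporary public Gradio link.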