import gradio as gr
import spaces
from transformers import pipeline
import torch
import logging
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Predefined list of models to compare (can be expanded)
model_options = {
"Foundation-Sec-8B": pipeline("text-generation", model="fdtn-ai/Foundation-Sec-8B"),
}
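# Note: each pipeline above is instantiated once at import time (on CPU by default);
# generate_text_local moves it to the GPU per request and back to the CPU afterwards.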
# Define the response function
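# @spaces.GPU requests a GPU for the duration of the call when running on
# Hugging Face Spaces (ZeroGPU); outside that environment it is effectively a pass-through.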
@spaces.GPU
def generate_text_local(model_pipeline, prompt):
"""Local text generation"""
try:
logger.info(f"Running local text generation with {model_pipeline.path}")
# Move model to GPU (entire pipeline)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_pipeline.model = model_pipeline.model.to(device)
# Set other pipeline components to use GPU
if hasattr(model_pipeline, "device"):
model_pipeline.device = device
# Record device information
device_info = next(model_pipeline.model.parameters()).device
logger.info(f"Model {model_pipeline.path} is running on device: {device_info}")
outputs = model_pipeline(
prompt,
max_new_tokens=3,  # forwarded to model.generate(max_new_tokens=...)
do_sample=True,
temperature=0.1,
top_p=0.9,
clean_up_tokenization_spaces=True,  # tidy up the echoed prompt portion
)
# Move model back to CPU
model_pipeline.model = model_pipeline.model.to("cpu")
if hasattr(model_pipeline, "device"):
model_pipeline.device = torch.device("cpu")
return outputs[0]["generated_text"].replace(prompt, "").strip()
except Exception as e:
logger.error(f"Error in local text generation with {model_pipeline.path}: {str(e)}")
return f"Error: {str(e)}"
# Build Gradio app
def create_demo():
with gr.Blocks() as demo:
gr.Markdown("# AI Model Comparison Tool 🌟")
gr.Markdown(
"""
Compare responses from two AI models side-by-side.
Select two models, ask a question, and compare their responses in real time!
"""
)
# Input Section
with gr.Row():
system_message = gr.Textbox(
value="You are a helpful assistant providing answers for technical and customer support queries.",
label="System message"
)
user_message = gr.Textbox(label="Your question", placeholder="Type your question here...")
with gr.Row():
max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
top_p = gr.Slider(
minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
)
# Model Selection Section
selected_models = gr.CheckboxGroup(
choices=list(model_options.keys()),
label="Select exactly two model to compare",
value=["Foundation-Sec-8B"], # Default models
)
# Dynamic Response Section
response_box1 = gr.Textbox(label="Response from Model 1", interactive=False)
#response_box2 = gr.Textbox(label="Response from Model 2", interactive=False)
# Function to generate responses
def generate_responses(
message, system_message, max_tokens, temperature, top_p, selected_models
):
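# Only the first selected model and the raw message are used for now;
# system_message, max_tokens, temperature and top_p are accepted but not yet
# forwarded to generate_text_local (see the commented-out two-model path below).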
#if len(selected_models) != 2:
# return "Error: Please select exactly two models to compare.", ""
responses = generate_text_local(
#message, [], system_message, max_tokens, temperature, top_p, selected_models
model_options[selected_models[0]],
message
)
#return responses.get(selected_models[0], ""), responses.get(selected_models[1], "")
return responses
# Add a button for generating responses
submit_button = gr.Button("Generate Responses")
submit_button.click(
generate_responses,
inputs=[user_message, system_message, max_tokens, temperature, top_p, selected_models],
#outputs=[response_box1, response_box2], # Link to response boxes
outputs=[response_box1]
)
return demo
if __name__ == "__main__":
demo = create_demo()
demo.launch()