File size: 4,598 Bytes
0a8cafa
297f353
74f37a5
 
 
7841db2
74f37a5
 
 
 
 
 
7841db2
10dfcb1
 
74f37a5
10dfcb1
 
7841db2
602d4aa
74f37a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7841db2
74f37a5
 
 
 
 
 
 
 
 
0a8cafa
4fc9e70
7841db2
f3d87e2
 
10dfcb1
 
9038518
 
10dfcb1
f3d87e2
10dfcb1
ae21d92
10dfcb1
 
 
 
 
 
 
 
 
 
 
 
 
 
4b78c6c
 
 
74f37a5
 
4b78c6c
10dfcb1
ae21d92
4b78c6c
74f37a5
10dfcb1
4b78c6c
10dfcb1
4b78c6c
10dfcb1
74f37a5
 
 
 
fd96719
74f37a5
10dfcb1
74f37a5
 
5c30376
 
10dfcb1
 
4b78c6c
74f37a5
 
10dfcb1
 
f3d87e2
0a8cafa
 
7841db2
5c30376
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import gradio as gr
import spaces
from transformers import pipeline
import torch
import logging

# Logging: a single root configuration at INFO level; records from every logger
# that propagates to the root share this layout.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger for this app.
logger = logging.getLogger(__name__)

# Models offered in the UI: display name -> text-generation pipeline.
# Each pipeline is constructed here, once, at import time (before any request
# is served). Add more entries to offer more models for comparison.
model_options = {}
model_options["Foundation-Sec-8B"] = pipeline(
    "text-generation",
    model="fdtn-ai/Foundation-Sec-8B",
)

def _pipeline_label(model_pipeline):
    """Return a printable name for *model_pipeline* for use in log messages.

    Prefers a ``path`` attribute, then the wrapped model's ``name_or_path``,
    then the pipeline's class name. Never raises, so it is safe to call from an
    ``except`` handler.
    NOTE(review): transformers pipelines may not expose ``path`` at all —
    confirm; the fallbacks keep logging working either way.
    """
    label = getattr(model_pipeline, "path", None)
    if label is None:
        label = getattr(getattr(model_pipeline, "model", None), "name_or_path", None)
    return str(label) if label is not None else type(model_pipeline).__name__


def _restore_cpu(model_pipeline, label):
    """Move the pipeline's model (and its recorded device) back to CPU.

    Failures are logged rather than raised so that cleanup can never mask the
    result or error of the generation call.
    """
    try:
        model_pipeline.model = model_pipeline.model.to("cpu")
        if hasattr(model_pipeline, "device"):
            model_pipeline.device = torch.device("cpu")
    except Exception:
        logger.exception(f"Failed to move model {label} back to CPU")


@spaces.GPU
def generate_text_local(model_pipeline, prompt, max_new_tokens=3, temperature=0.1, top_p=0.9):
    """Generate a sampled continuation of *prompt* with a text-generation pipeline.

    For the duration of the call the pipeline's model is moved to CUDA when it
    is available (otherwise it stays on CPU). It is moved back to CPU
    afterwards on both the success and the error path.

    Args:
        model_pipeline: A transformers ``text-generation`` pipeline, e.g. a value
            of ``model_options``.
        prompt: Text the model should continue.
        max_new_tokens: Maximum number of tokens to generate (converted to int).
        temperature: Sampling temperature; sampling is always enabled.
        top_p: Nucleus-sampling cumulative-probability cutoff.

    Returns:
        Only the newly generated text (the prompt is not echoed), stripped of
        surrounding whitespace, or ``"Error: <message>"`` if generation fails.
    """
    label = _pipeline_label(model_pipeline)
    try:
        logger.info(f"Running local text generation with {label}")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_pipeline.model = model_pipeline.model.to(device)
        # Keep the pipeline's own device attribute in sync with the model;
        # presumably the pipeline uses it when placing input tensors.
        if hasattr(model_pipeline, "device"):
            model_pipeline.device = device

        device_info = next(model_pipeline.model.parameters()).device
        logger.info(f"Model {label} is running on device: {device_info}")

        outputs = model_pipeline(
            prompt,
            # int(): slider values may arrive as floats (hedged — depends on Gradio).
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            clean_up_tokenization_spaces=True,  # tidy spacing in the decoded text
            # Ask the pipeline for the continuation only, instead of removing the
            # prompt from the full text afterwards (a plain str.replace also deleted
            # any repeat of the prompt inside the answer).
            return_full_text=False,
        )
        return outputs[0]["generated_text"].strip()
    except Exception as e:
        logger.exception(f"Error in local text generation with {label}: {str(e)}")
        return f"Error: {str(e)}"
    finally:
        # Always release the GPU copy, even when generation failed.
        _restore_cpu(model_pipeline, label)

# Build Gradio app
def create_demo():
    """Build the Gradio Blocks UI.

    The UI has prompt inputs, sampling sliders, a model picker, and a response
    box. Submitting runs the first selected model through
    ``generate_text_local`` with the slider settings.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# AI Model Comparison Tool 🌟")
        gr.Markdown(
            """
            Compare responses from two AI models side-by-side.  
            Select two models, ask a question, and compare their responses in real time!
            """
        )

        # Input Section
        with gr.Row():
            system_message = gr.Textbox(
                value="You are a helpful assistant providing answers for technical and customer support queries.",
                label="System message"
            )
            user_message = gr.Textbox(label="Your question", placeholder="Type your question here...")

        with gr.Row():
            max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
            temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
            )

        # Model Selection Section
        selected_models = gr.CheckboxGroup(
            choices=list(model_options.keys()),
            label="Select exactly two model to compare",
            value=["Foundation-Sec-8B"],  # Default models
        )

        # Response Section.
        # TODO: add a second response box once two-model comparison is implemented.
        response_box1 = gr.Textbox(label="Response from Model 1", interactive=False)

        def generate_responses(
            message, system_message, max_tokens, temperature, top_p, selected_models
        ):
            """Click handler: run the first selected model on *message* using the slider values.

            NOTE(review): *system_message* is collected by the UI but is not sent
            to the model; the prompt is the user message alone.
            """
            if not selected_models:
                # Indexing an empty selection would otherwise raise IndexError.
                return "Error: Please select a model."
            return generate_text_local(
                model_options[selected_models[0]],
                message,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            )

        submit_button = gr.Button("Generate Responses")
        submit_button.click(
            generate_responses,
            inputs=[user_message, system_message, max_tokens, temperature, top_p, selected_models],
            outputs=[response_box1],
        )

    return demo

if __name__ == "__main__":
    demo = create_demo()
    demo.launch()