File size: 2,568 Bytes
1d569c6
d5be079
 
1d569c6
0a023ed
ccc5c3a
0a023ed
 
ccc5c3a
 
0a023ed
 
c766fcb
ccc5c3a
c766fcb
 
 
 
0a023ed
 
c766fcb
0a023ed
ccc5c3a
0a023ed
c766fcb
 
 
 
ccc5c3a
 
0a023ed
 
c766fcb
 
 
 
0a023ed
ccc5c3a
c766fcb
0a023ed
 
 
 
 
 
ccc5c3a
0a023ed
 
d5be079
0a023ed
 
 
 
 
 
d5be079
c766fcb
 
d5be079
0a023ed
d5be079
0a023ed
ccc5c3a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import logging

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

class CodeGenerator:
    """Wrap a Hugging Face seq2seq checkpoint for prompt-to-text generation.

    NOTE(review): ``Salesforce/codet5-base`` appears to be a pre-trained
    (span-denoising) checkpoint rather than an instruction-tuned one, so
    free-form chat prompts may yield weak output -- confirm this default is
    the intended model.
    """

    # Failed generations are reported to the UI as a short string; this logger
    # keeps the full traceback in the server log so failures stay debuggable.
    _logger = logging.getLogger(__name__)

    def __init__(self, model_name="Salesforce/codet5-base", device=None):
        """Load the tokenizer and model and move the model to a device.

        Args:
            model_name: Hub id or local path passed to ``from_pretrained``.
            device: Torch device string such as ``"cpu"`` or ``"cuda:1"``.
                ``None`` (the default) picks ``"cuda"`` when available and
                ``"cpu"`` otherwise.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model.to(self.device)

    def generate_code(self, prompt, max_length=100):
        """Generate one completion for *prompt*.

        Args:
            prompt: Input text fed to the model.
            max_length: Passed through to ``model.generate`` as ``max_length``.

        Returns:
            The decoded output with special tokens removed. If anything raises,
            returns ``"Error generating code: <message>"`` instead. The error is
            returned rather than raised so the chat UI can display it; the
            traceback is logged.
        """
        try:
            # Calling the tokenizer (rather than ``encode``) also produces the
            # attention mask, which is then handed to ``generate`` explicitly.
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            output = self.model.generate(
                **inputs, max_length=max_length, num_return_sequences=1
            )
            return self.tokenizer.decode(output[0], skip_special_tokens=True)
        except Exception as e:  # deliberately broad: best-effort UI boundary
            self._logger.exception("Code generation failed")
            return f"Error generating code: {e}"

class ChatHandler:
    """Hold a chat transcript and fill it with replies from a code generator.

    The transcript is stored on the instance, so all callers that share one
    ``ChatHandler`` also share one conversation.
    """

    def __init__(self, code_generator):
        """Start with an empty transcript.

        Args:
            code_generator: Object exposing ``generate_code(prompt)`` that
                returns a string (e.g. ``CodeGenerator``).
        """
        self.history = []  # list of {"role": ..., "content": ...} dicts
        self.code_generator = code_generator

    def handle_message(self, message):
        """Answer *message* and append both turns to the transcript.

        Args:
            message: User text. ``None``, empty and whitespace-only input is
                ignored and leaves the transcript unchanged.

        Returns:
            ``("", history)``: an empty string that clears the input box, and
            the transcript in ``role``/``content`` message format.
        """
        # Tolerate ``None``: the UI's clear action can set the textbox value to
        # ``None``, and calling ``.strip()`` on it would raise.
        if not message or not message.strip():
            return "", self.history
        response = self.code_generator.generate_code(message)
        self.history.append({"role": "user", "content": message})
        self.history.append({"role": "assistant", "content": response})
        return "", self.history

    def clear_history(self):
        """Discard the transcript and return a fresh empty list."""
        self.history = []
        return []

def create_gradio_interface(launch=True):
    """Build the CodeT5 chat UI and, by default, launch it.

    The model is loaded once and shared by every session, but chat transcripts
    are not. Each request builds a ``ChatHandler`` from the transcript that the
    ``Chatbot`` component sends for the calling session. As a result, one
    visitor's conversation is never shown to, or cleared by, another visitor.

    Args:
        launch: When true (the default), call ``demo.launch()`` before
            returning.

    Returns:
        The constructed ``gr.Blocks`` app.
    """
    code_generator = CodeGenerator()

    def respond(message, history):
        # A throwaway handler seeded with this session's transcript replaces a
        # single shared ChatHandler whose history every user would see.
        handler = ChatHandler(code_generator)
        handler.history = list(history or [])
        return handler.handle_message(message)

    def clear_chat():
        # Reset only the calling session's textbox and chat window.
        return None, []

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# S-Dreamer Salesforce/codet5-base Chat Interface")

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(type="messages", height=400)
                message_input = gr.Textbox(label="Enter your code-related query", placeholder="Type your message here...")
                submit_button = gr.Button("Submit")

            with gr.Column(scale=1):
                gr.Markdown("## Features")
                features = ["Code generation", "Code completion", "Code explanation", "Error correction"]
                for feature in features:
                    gr.Markdown(f"- {feature}")
                clear_button = gr.Button("Clear Chat")

        # The chatbot is both an input (this session's history) and an output.
        submit_button.click(respond, inputs=[message_input, chatbot], outputs=[message_input, chatbot])
        # Pressing Enter in the textbox behaves like clicking Submit.
        message_input.submit(respond, inputs=[message_input, chatbot], outputs=[message_input, chatbot])
        clear_button.click(clear_chat, inputs=[], outputs=[message_input, chatbot])

    if launch:
        demo.launch()
    return demo

if __name__ == "__main__":
    # Start the web UI only when executed as a script, never on import.
    create_gradio_interface()