Spaces:
Runtime error
Runtime error
File size: 2,568 Bytes
1d569c6 d5be079 1d569c6 0a023ed ccc5c3a 0a023ed ccc5c3a 0a023ed c766fcb ccc5c3a c766fcb 0a023ed c766fcb 0a023ed ccc5c3a 0a023ed c766fcb ccc5c3a 0a023ed c766fcb 0a023ed ccc5c3a c766fcb 0a023ed ccc5c3a 0a023ed d5be079 0a023ed d5be079 c766fcb d5be079 0a023ed d5be079 0a023ed ccc5c3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
class CodeGenerator:
    """Wrap a Hugging Face seq2seq checkpoint (CodeT5 by default) for text-to-code generation."""

    def __init__(self, model_name="Salesforce/codet5-base"):
        """Load the tokenizer and model for *model_name* and move the model to the best device.

        Args:
            model_name: Hugging Face Hub id or local path of a seq2seq checkpoint.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        # Prefer the GPU when one is visible to torch; otherwise run on CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def generate_code(self, prompt, max_length=100):
        """Generate a single completion for *prompt*.

        Args:
            prompt: Natural-language or code prompt to feed the encoder.
            max_length: Maximum length (in tokens) passed to ``model.generate``.

        Returns:
            The decoded output with special tokens removed. If anything fails,
            the string ``"Error generating code: <reason>"`` is returned instead.
        """
        try:
            # Calling the tokenizer (instead of .encode) also produces the
            # attention mask. truncation=True clips prompts longer than the
            # tokenizer's model_max_length instead of passing an oversized sequence.
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
            output = self.model.generate(**inputs, max_length=max_length, num_return_sequences=1)
            return self.tokenizer.decode(output[0], skip_special_tokens=True)
        except Exception as e:
            # Deliberate best-effort: the error text is shown in the chat
            # rather than crashing the UI callback.
            return f"Error generating code: {e}"
class ChatHandler:
    """Keep a chat transcript and delegate each user turn to a code generator.

    The transcript is a list of ``{"role": ..., "content": ...}`` dicts,
    alternating ``"user"`` and ``"assistant"`` entries.
    """

    def __init__(self, code_generator):
        """Start with an empty transcript.

        Args:
            code_generator: Object exposing ``generate_code(prompt) -> str``
                (presumably a ``CodeGenerator``; not type-checked here).
        """
        self.history = []
        self.code_generator = code_generator

    def handle_message(self, message):
        """Generate a reply to *message* and append both turns to the transcript.

        Args:
            message: The user's text. ``None``, empty, or whitespace-only
                input is ignored.

        Returns:
            A tuple ``("", history)``. The empty string clears the input box;
            ``history`` is this handler's (possibly updated) transcript list.
        """
        # Gradio can deliver None for a textbox whose value was reset to None,
        # so check for that before calling .strip().
        if not message or not message.strip():
            return "", self.history
        response = self.code_generator.generate_code(message)
        self.history.append({"role": "user", "content": message})
        self.history.append({"role": "assistant", "content": response})
        return "", self.history

    def clear_history(self):
        """Drop the transcript and return an empty list for the chat display."""
        self.history = []
        return []
def create_gradio_interface():
    """Build the CodeT5 chat UI, launch it, and return the ``gr.Blocks`` app.

    Chat history is kept per browser session. Each request rebuilds the
    transcript from the Chatbot component's current value, instead of one
    ``ChatHandler.history`` list being shared by every visitor of the app.
    """
    code_generator = CodeGenerator()

    def respond(message, history):
        """Handle one user turn; returns (cleared textbox value, updated transcript)."""
        # Use a throwaway handler seeded with this session's transcript, so
        # one user's messages never show up in another user's chat.
        handler = ChatHandler(code_generator)
        handler.history = list(history or [])
        # The textbox value can arrive as None; treat that as an empty message.
        return handler.handle_message(message or "")

    def clear():
        """Reset the textbox and this session's transcript."""
        return "", []

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# S-Dreamer Salesforce/codet5-base Chat Interface")
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(type="messages", height=400)
                message_input = gr.Textbox(
                    label="Enter your code-related query",
                    placeholder="Type your message here...",
                )
                submit_button = gr.Button("Submit")
            with gr.Column(scale=1):
                gr.Markdown("## Features")
                features = ["Code generation", "Code completion", "Code explanation", "Error correction"]
                for feature in features:
                    gr.Markdown(f"- {feature}")
                clear_button = gr.Button("Clear Chat")

        # The chatbot is an input as well as an output, so each session's own
        # transcript is passed to respond().
        submit_button.click(respond, inputs=[message_input, chatbot], outputs=[message_input, chatbot])
        # Pressing Enter in the textbox behaves like clicking Submit.
        message_input.submit(respond, inputs=[message_input, chatbot], outputs=[message_input, chatbot])
        clear_button.click(clear, inputs=[], outputs=[message_input, chatbot])

    demo.launch()
    return demo
if __name__ == "__main__":
    # Build the UI and start the Gradio server when run as a script.
    create_gradio_interface()