# "Spaces: Sleeping" — Hugging Face Spaces status-banner residue from a page
# scrape; not part of the program. Kept as a comment so the file parses.
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import requests


# NOTE(review): emoji string literals throughout this file look mojibake-garbled
# (e.g. "π€") — verify the source file's encoding before shipping.
@st.cache_resource
def load_model():
    """Load and cache the CodeT5 tokenizer and model.

    Decorated with ``st.cache_resource`` so the expensive download and
    instantiation happen once per server process instead of on every
    Streamlit rerun (the original reloaded the model on each script run).

    Returns:
        tuple: ``(tokenizer, model)`` for ``"Salesforce/codet5-base"``.
    """
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5-base")
    return tokenizer, model


# Initialize model and tokenizer (cached across reruns).
tokenizer, model = load_model()
# Sidecar settings — local companion service used for code generation.
SIDECAR_URL = "http://127.0.0.1:42424"

# Page configuration (must be the first Streamlit page call).
st.set_page_config(
    page_title="AI Code Assistant",
    page_icon="π€",
    layout="wide",
    initial_sidebar_state="expanded"
)


def local_css(file_name):
    """Inject a local CSS file into the page as a <style> block.

    Args:
        file_name (str): Path to the CSS file.

    A missing stylesheet is reported as a non-fatal warning instead of
    crashing the whole app with an unhandled ``FileNotFoundError`` (the
    original ``open()`` call would abort the entire Streamlit run).
    """
    try:
        with open(file_name, encoding="utf-8") as f:
            st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
    except FileNotFoundError:
        st.warning(f"Stylesheet '{file_name}' not found; using default styling.")


# Load custom CSS file (add your own CSS styling in 'styles.css').
local_css("styles.css")
# ---- Header ----
_HEADER_HTML = """
<div style="text-align: center; padding: 20px; background-color: #1E88E5; color: white; border-radius: 8px;">
<h1>π€ AI Code Assistant</h1>
<p>Your assistant for generating and optimizing code with AI.</p>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)

# ---- Sidebar ----
_SIDEBAR_HTML = """
<div style="text-align: center; margin-bottom: 20px;">
<h2>βοΈ Options</h2>
</div>
"""
st.sidebar.markdown(_SIDEBAR_HTML, unsafe_allow_html=True)

# Section picker: drives the if/elif dispatch in the main content below.
_SECTIONS = (
    "Generate Code",
    "Train Model",
    "Prompt Engineer",
    "Optimize Model",
    "Sidecar Integration",
)
section = st.sidebar.radio("Choose a Section", _SECTIONS)
# ---- Main content: dispatch on the sidebar selection ----
if section == "Generate Code":
    st.markdown("<h2 style='text-align: center;'>π Generate Code from Description</h2>", unsafe_allow_html=True)
    st.write("Provide a description, and the AI will generate the corresponding Python code.")
    prompt = st.text_area(
        "Enter your description:",
        "Write a Python function to reverse a string.",
        placeholder="Enter a detailed code description...",
        height=150
    )
    if st.button("π Generate Code"):
        with st.spinner("Generating code..."):
            try:
                code = None
                # Prefer the Sidecar service; fall back to the local model when
                # it is unreachable OR returns a non-200 status. The timeout
                # keeps a hung Sidecar from blocking the Streamlit worker
                # forever (the original had no timeout, and a connection error
                # skipped the fallback entirely).
                try:
                    response = requests.post(
                        f"{SIDECAR_URL}/generate", json={"prompt": prompt}, timeout=10
                    )
                    if response.status_code == 200:
                        code = response.json().get("code", "No response from Sidecar.")
                except requests.RequestException:
                    code = None  # Sidecar unavailable -> use local model below.
                if code is None:
                    inputs = tokenizer(prompt, return_tensors="pt")
                    outputs = model.generate(inputs["input_ids"], max_length=100)
                    code = tokenizer.decode(outputs[0], skip_special_tokens=True)
                st.code(code, language="python")
            except Exception as e:
                # Boundary catch: surface any remaining failure in the UI.
                st.error(f"Error: {e}")
elif section == "Train Model":
    st.markdown("<h2 style='text-align: center;'>π Train the Model</h2>", unsafe_allow_html=True)
    st.write("Upload your dataset to fine-tune the AI model.")
    uploaded_file = st.file_uploader("Upload Dataset (JSON/CSV):")
    if uploaded_file:
        st.success("Dataset uploaded successfully!")
        # NOTE(review): placeholder — no actual fine-tuning is performed here;
        # the success message fires immediately. Wire up a real training job
        # before relying on this section.
        if st.button("Start Training"):
            with st.spinner("Training in progress..."):
                st.success("Model training completed successfully!")
elif section == "Prompt Engineer":
    st.markdown("<h2 style='text-align: center;'>βοΈ Prompt Engineering</h2>", unsafe_allow_html=True)
    st.write("Experiment with different prompts to improve code generation.")
    prompt = st.text_area("Enter your prompt:", "Explain the following code: def add(a, b): return a + b")
    if st.button("Test Prompt"):
        with st.spinner("Testing prompt..."):
            try:
                # Run the prompt through the local model only (no Sidecar here).
                inputs = tokenizer(prompt, return_tensors="pt")
                outputs = model.generate(inputs["input_ids"], max_length=100)
                response = tokenizer.decode(outputs[0], skip_special_tokens=True)
                st.write("**Model Output:**")
                st.code(response, language="text")
            except Exception as e:
                st.error(f"Error: {e}")
elif section == "Optimize Model":
    st.markdown("<h2 style='text-align: center;'>π Optimize Model Performance</h2>", unsafe_allow_html=True)
    st.write("Adjust model parameters for improved performance.")
    # NOTE(review): these settings are display-only — nothing consumes them.
    lr = st.slider("Learning Rate:", 1e-5, 1e-3, value=1e-4, step=1e-5)
    batch_size = st.slider("Batch Size:", 1, 64, value=16)
    epochs = st.slider("Number of Epochs:", 1, 10, value=3)
    if st.button("Apply Optimization Settings"):
        st.success(f"Settings applied: LR={lr}, Batch Size={batch_size}, Epochs={epochs}")
elif section == "Sidecar Integration":
    st.markdown("<h2 style='text-align: center;'>π Sidecar Integration</h2>", unsafe_allow_html=True)
    st.write("Test the Sidecar server connection.")
    if st.button("Ping Sidecar"):
        try:
            # Short timeout so a dead Sidecar fails fast instead of hanging.
            response = requests.get(f"{SIDECAR_URL}/ping", timeout=5)
            if response.status_code == 200:
                st.success("Sidecar server is running!")
            else:
                st.error("Sidecar is not responding.")
        except requests.RequestException as e:
            st.error(f"Error: {e}")