# ai-assistant / app.py
# Author: Altayebhssab
# Commit 81c8402: "new update"
import os

import requests
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Page configuration. Streamlit requires st.set_page_config to be the first
# Streamlit command of the script -- the spinner that @st.cache_resource shows
# while load_model() runs counts as one -- so it is issued before the model
# is loaded.
st.set_page_config(
    page_title="AI Code Assistant",
    page_icon="πŸ€–",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Hugging Face checkpoint used for local (non-Sidecar) generation.
MODEL_NAME = "Salesforce/codet5-base"


@st.cache_resource
def load_model(model_name=MODEL_NAME):
    """Load the tokenizer and seq2seq model for *model_name*.

    Decorated with ``st.cache_resource`` so the weights are loaded once and
    reused across reruns instead of being reloaded on every interaction.

    Args:
        model_name: Hugging Face model id or local path; defaults to
            ``MODEL_NAME``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model


# Initialize model and tokenizer (served from the resource cache on reruns).
tokenizer, model = load_model()

# Base URL of the optional Sidecar generation server. Override it with the
# SIDECAR_URL environment variable; the default is the original local address.
SIDECAR_URL = os.environ.get("SIDECAR_URL", "http://127.0.0.1:42424")
# Apply custom CSS for modern design
def local_css(file_name):
    """Inject the stylesheet at *file_name* into the page as a <style> tag.

    The stylesheet is optional cosmetic customisation. If the file does not
    exist, nothing is injected and the app keeps running instead of crashing
    on startup.

    Args:
        file_name: Path to a CSS file, read as UTF-8.
    """
    try:
        with open(file_name, encoding="utf-8") as f:
            css = f.read()
    except FileNotFoundError:
        # Optional file: run with Streamlit's default styling.
        return
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)


# Load custom CSS file (add your own CSS styling in 'styles.css'; optional).
local_css("styles.css")
# Header: a centered blue banner carrying the app title and tagline.
_HEADER_HTML = """
<div style="text-align: center; padding: 20px; background-color: #1E88E5; color: white; border-radius: 8px;">
<h1>πŸ€– AI Code Assistant</h1>
<p>Your assistant for generating and optimizing code with AI.</p>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
# Sidebar: an "Options" heading followed by the radio that picks which
# section the main area renders (stored in `section`).
_SIDEBAR_HEADING_HTML = """
<div style="text-align: center; margin-bottom: 20px;">
<h2>βš™οΈ Options</h2>
</div>
"""
_SECTIONS = (
    "Generate Code",
    "Train Model",
    "Prompt Engineer",
    "Optimize Model",
    "Sidecar Integration",
)

st.sidebar.markdown(_SIDEBAR_HEADING_HTML, unsafe_allow_html=True)
section = st.sidebar.radio("Choose a Section", _SECTIONS)
# Main Content Section
if section == "Generate Code":
st.markdown("<h2 style='text-align: center;'>πŸ“ Generate Code from Description</h2>", unsafe_allow_html=True)
st.write("Provide a description, and the AI will generate the corresponding Python code.")
prompt = st.text_area(
"Enter your description:",
"Write a Python function to reverse a string.",
placeholder="Enter a detailed code description...",
height=150
)
if st.button("πŸš€ Generate Code"):
with st.spinner("Generating code..."):
try:
response = requests.post(f"{SIDECAR_URL}/generate", json={"prompt": prompt})
if response.status_code == 200:
code = response.json().get("code", "No response from Sidecar.")
else:
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(inputs["input_ids"], max_length=100)
code = tokenizer.decode(outputs[0], skip_special_tokens=True)
st.code(code, language="python")
except Exception as e:
st.error(f"Error: {e}")
elif section == "Train Model":
st.markdown("<h2 style='text-align: center;'>πŸ“š Train the Model</h2>", unsafe_allow_html=True)
st.write("Upload your dataset to fine-tune the AI model.")
uploaded_file = st.file_uploader("Upload Dataset (JSON/CSV):")
if uploaded_file:
st.success("Dataset uploaded successfully!")
if st.button("Start Training"):
with st.spinner("Training in progress..."):
st.success("Model training completed successfully!")
elif section == "Prompt Engineer":
st.markdown("<h2 style='text-align: center;'>βš™οΈ Prompt Engineering</h2>", unsafe_allow_html=True)
st.write("Experiment with different prompts to improve code generation.")
prompt = st.text_area("Enter your prompt:", "Explain the following code: def add(a, b): return a + b")
if st.button("Test Prompt"):
with st.spinner("Testing prompt..."):
try:
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(inputs["input_ids"], max_length=100)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
st.write("**Model Output:**")
st.code(response, language="text")
except Exception as e:
st.error(f"Error: {e}")
elif section == "Optimize Model":
st.markdown("<h2 style='text-align: center;'>πŸš€ Optimize Model Performance</h2>", unsafe_allow_html=True)
st.write("Adjust model parameters for improved performance.")
lr = st.slider("Learning Rate:", 1e-5, 1e-3, value=1e-4, step=1e-5)
batch_size = st.slider("Batch Size:", 1, 64, value=16)
epochs = st.slider("Number of Epochs:", 1, 10, value=3)
if st.button("Apply Optimization Settings"):
st.success(f"Settings applied: LR={lr}, Batch Size={batch_size}, Epochs={epochs}")
elif section == "Sidecar Integration":
st.markdown("<h2 style='text-align: center;'>πŸ”— Sidecar Integration</h2>", unsafe_allow_html=True)
st.write("Test the Sidecar server connection.")
if st.button("Ping Sidecar"):
try:
response = requests.get(f"{SIDECAR_URL}/ping")
if response.status_code == 200:
st.success("Sidecar server is running!")
else:
st.error("Sidecar is not responding.")
except Exception as e:
st.error(f"Error: {e}")