Spaces:
Sleeping
Sleeping
import streamlit as st | |
import datetime | |
import tempfile | |
import black | |
from streamlit_ace import st_ace | |
from streamlit_extras.colored_header import colored_header | |
from streamlit_extras.add_vertical_space import add_vertical_space | |
import re | |
from typing import Optional, Dict, List | |
import ast | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import torch | |
# Load model and tokenizer once per server process (cached by Streamlit).
@st.cache_resource(show_spinner=False)
def load_model_and_tokenizer():
    """Return the Qwen chat model and its tokenizer, loading them only once.

    Streamlit re-executes the whole script on every interaction, and
    ``generate_response`` calls this function for every request.  Without
    ``st.cache_resource`` the checkpoint would be re-read from disk on each of
    those calls; the decorator keeps a single shared instance instead.

    Returns:
        ``(model, tokenizer)``; the model is moved to CUDA when a GPU is
        available.
    """
    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    if torch.cuda.is_available():
        model = model.to("cuda")
    return model, tokenizer
def clear_chat():
    """Reset the chat transcript and stamp a new session id.

    Only keys already present in ``st.session_state`` are reset; missing keys
    are not created here.
    """
    state = st.session_state
    resets = {
        "messages": lambda: [],
        "current_session": lambda: datetime.datetime.now().strftime("%Y%m%d_%H%M%S"),
    }
    for key, make_value in resets.items():
        if key in state:
            state[key] = make_value()
def handle_file_upload():
    """Render the file uploader and return a text version of the upload.

    Returns:
        ``None`` when no file is uploaded.  For a PDF, a one-line note with the
        file name (the PDF contents are not extracted).  For any other file,
        its contents decoded as UTF-8, or ``"Binary file uploaded"`` when the
        bytes are not valid UTF-8.
    """
    uploaded_file = st.file_uploader(
        "Upload a file",
        type=["txt", "pdf", "py", "json", "csv"],
        help="Upload a file to discuss with the AI",
    )
    if uploaded_file is None:
        return None
    if uploaded_file.type == "application/pdf":
        # Only the name is reported, so don't pull the PDF bytes into memory.
        return f"Uploaded PDF: {uploaded_file.name}"
    # getvalue() returns the full buffer regardless of the read cursor, unlike
    # read(), which returns b"" once the same object has been consumed.
    try:
        return uploaded_file.getvalue().decode("utf-8")
    except UnicodeDecodeError:
        return "Binary file uploaded"
def generate_response(prompt: str, temperature: float, max_tokens: int, system_prompt: str) -> str:
    """Generate one reply from the Qwen chat model.

    The system and user turns are rendered with the tokenizer's own chat
    template rather than a hand-rolled "System:/User:/Assistant:" layout.  Only
    the newly generated tokens are decoded, so the prompt is never echoed back
    and a reply containing the word "Assistant:" is not truncated.

    Args:
        prompt: The user's message.
        temperature: Sampling temperature; a value <= 0 selects greedy
            decoding (transformers rejects temperature 0 when sampling).
        max_tokens: Maximum number of new tokens to generate.
        system_prompt: Instruction placed in the system turn.

    Returns:
        The model's reply with surrounding whitespace stripped, or
        ``"Error: <message>"`` if tokenization or generation fails.  In that
        case the error is also shown in the UI via ``st.error``.
    """
    model, tokenizer = load_model_and_tokenizer()
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    try:
        chat_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(chat_text, return_tensors="pt").to(model.device)

        generation_kwargs = {"max_new_tokens": int(max_tokens)}
        if temperature > 0:
            generation_kwargs["do_sample"] = True
            generation_kwargs["temperature"] = float(temperature)
        else:
            # The UI slider allows 0.0; sampling with temperature 0 raises, so
            # fall back to greedy decoding instead.
            generation_kwargs["do_sample"] = False

        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        with torch.inference_mode():
            outputs = model.generate(**inputs, pad_token_id=pad_token_id, **generation_kwargs)

        # generate() returns prompt + continuation; keep only the continuation.
        prompt_length = inputs["input_ids"].shape[1]
        new_tokens = outputs[0][prompt_length:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        st.error(f"Error generating response: {str(e)}")
        return f"Error: {str(e)}"
class CodeAnalyzer:
    """Stateless helpers for inspecting Python code embedded in chat text.

    Every method is a ``@staticmethod``, so it works when called on the class
    (``CodeAnalyzer.extract_code_blocks(text)``) and on an instance
    (``CodeAnalyzer().get_context(code)``).
    """

    # One whole fenced block: ``` + optional info string, newline, body, newline, ```.
    _FENCE_RE = re.compile(r"```([^\n`]*)\n(.*?)\n```", re.DOTALL)
    # Info strings treated as Python; an untagged fence also counts as Python.
    _PYTHON_TAGS = frozenset({"", "python", "py", "python3"})

    @staticmethod
    def extract_code_blocks(text: str) -> List[str]:
        """Return the bodies of fenced Python code blocks in *text*, in order.

        Fences tagged ``python``, ``py`` or ``python3`` (case-insensitive) and
        untagged fences are returned.  Fences tagged with another language are
        skipped.  Each fence is matched as a whole, so the closing fence of a
        non-Python block is never mistaken for the opening of a new block.
        Empty or ``None`` text yields ``[]``.
        """
        if not text:
            return []
        return [
            body
            for tag, body in CodeAnalyzer._FENCE_RE.findall(text)
            if tag.strip().lower() in CodeAnalyzer._PYTHON_TAGS
        ]

    @staticmethod
    def is_code_complete(code: str) -> bool:
        """Return True if *code* parses as a complete Python module."""
        try:
            ast.parse(code)
        except (SyntaxError, ValueError):
            # ValueError: source containing null bytes on Python < 3.12.
            return False
        return True

    @staticmethod
    def get_context(code: str) -> Dict:
        """Collect the names assigned and the functions/classes defined in *code*.

        Returns:
            ``{'variables': [...], 'functions': [...], 'classes': [...]}`` in
            ``ast.walk`` order, duplicates kept.  ``async def`` functions are
            listed under ``'functions'``.  If *code* cannot be parsed (or is
            not a string), all three lists are empty.
        """
        context = {
            'variables': [],
            'functions': [],
            'classes': []
        }
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError, TypeError):
            # Best-effort analysis: unparseable input just yields no context.
            return context
        for node in ast.walk(tree):
            if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
                context['variables'].append(node.id)
            elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                context['functions'].append(node.name)
            elif isinstance(node, ast.ClassDef):
                context['classes'].append(node.name)
        return context
class CodeCompletion:
    """Ask the chat model to continue a piece of Python code."""

    def __init__(self):
        pass

    def get_completion_suggestions(self, code: str, context: Dict) -> str:
        """Return the model's continuation of *code*.

        Args:
            code: The Python source to continue.
            context: Mapping with optional ``'variables'``, ``'functions'`` and
                ``'classes'`` name lists (as produced by
                ``CodeAnalyzer.get_context``).  Missing keys, or an empty/None
                mapping, are treated as empty lists instead of raising KeyError.

        Returns:
            The generated text.  Sampling uses temperature 0.3 and a 500-token
            budget.
        """
        context = context or {}
        variables = ', '.join(context.get('variables', []))
        functions = ', '.join(context.get('functions', []))
        classes = ', '.join(context.get('classes', []))
        prompt = f"""Given the following code context:
Code:
{code}
Context:
Variables: {variables}
Functions: {functions}
Classes: {classes}
Please complete or continue this code in a natural way."""
        return generate_response(prompt, 0.3, 500, "You are a Python coding assistant. Provide only code completion, no explanations.")
def handle_code_continuation(incomplete_code: str) -> str:
    """Ask the model to finish a code fragment that does not yet parse.

    Sampling is fixed at temperature 0.3 with a 500-token budget.
    """
    instruction = (
        "Complete the following Python code:\n"
        f"{incomplete_code}\n"
        "Provide only the completion part that would make this code syntactically complete and logical."
    )
    persona = "You are a Python coding assistant. Complete the code naturally."
    return generate_response(instruction, 0.3, 500, persona)
def format_code(code: str) -> str:
    """Format Python source with black, returning it unchanged on failure.

    Formatting is best-effort.  If black raises (for example, on source it
    cannot parse), *code* is returned as-is so the editor never loses text.
    Empty or ``None`` input is returned immediately.
    """
    if not code:
        return code
    try:
        return black.format_str(code, mode=black.FileMode())
    except Exception:
        # Narrower than a bare ``except:``: KeyboardInterrupt and SystemExit
        # still propagate.
        return code
def init_session_state():
    """Seed ``st.session_state`` with a default for every key not yet set.

    Keys that already exist keep their current value, so this can run on
    every script rerun.  Each default comes from a factory, so mutable values
    are fresh objects and the timestamp is only computed when needed.
    """
    defaults = (
        ("messages", list),
        ("sessions", dict),
        ("current_session", lambda: datetime.datetime.now().strftime("%Y%m%d_%H%M%S")),
        ("system_prompt", lambda: "You are a helpful AI assistant."),
        ("saved_code_snippets", list),
        ("code_context", dict),
        ("current_code_block", lambda: None),
        ("code_history", list),
        ("last_code_state", lambda: None),
    )
    state = st.session_state
    for key, factory in defaults:
        if key not in state:
            state[key] = factory()
def setup_page_config():
    """Configure the Streamlit page (title, icon, layout) and inject custom CSS."""
    page_options = dict(
        page_title="Qwen Coder Chat",
        page_icon="🤖",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.set_page_config(**page_options)

    # Styles for the chat bubbles, chat input, editor and snippet panels.
    custom_css = """
<style>
.main {
max-width: 1200px;
margin: 0 auto;
padding: 2rem;
}
.stChatMessage {
background-color: #ffffff;
border-radius: 8px;
padding: 1rem;
margin: 0.5rem 0;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}
.stChatInputContainer {
border-radius: 8px;
border: 1px solid #e0e0e0;
padding: 0.5rem;
background-color: #ffffff;
}
.code-editor {
border-radius: 8px;
margin: 1rem 0;
border: 1px solid #e0e0e0;
}
.code-snippet {
background-color: #f8fafc;
padding: 1rem;
border-radius: 8px;
margin: 0.5rem 0;
}
.completion-suggestion {
background-color: #f1f5f9;
padding: 0.5rem;
border-left: 3px solid #0284c7;
margin: 0.25rem 0;
}
</style>
"""
    st.markdown(custom_css, unsafe_allow_html=True)
def code_editor_section():
    """Render the code-editor tab: an Ace editor plus format / suggest buttons.

    "Format Code" runs black over the editor contents and stores the result in
    ``st.session_state.current_code_block``, which is the editor's ``value``.
    It then reruns the script so the editor receives the formatted text.
    "Get Completion Suggestions" asks the model to continue the code and shows
    the reply below the buttons.
    """
    st.subheader("📝 Code Editor")
    code_content = st_ace(
        value=st.session_state.current_code_block or "",
        language="python",
        theme="monokai",
        key="code_editor",
        height=300,
        show_gutter=True,
        wrap=True,
        auto_update=True,
    )
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Format Code") and code_content:
            st.session_state.current_code_block = format_code(code_content)
            # The editor above was already drawn with the old value; rerun so it
            # is rendered again with the formatted code as its value.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()
    with col2:
        if st.button("Get Completion Suggestions"):
            if code_content:
                # Call through the class, not an instance: the helper is written
                # without a ``self`` parameter, so an instance call would pass the
                # instance in place of the code and raise TypeError.
                context = CodeAnalyzer.get_context(code_content)
                with st.spinner("Generating suggestions..."):
                    suggestions = CodeCompletion().get_completion_suggestions(code_content, context)
                st.code(suggestions, language="python")
            else:
                st.warning("Write some code in the editor first.")
def _render_sidebar():
    """Draw the sidebar controls; return (temperature, max_tokens, system_prompt)."""
    with st.sidebar:
        colored_header(label="Model Settings", description="Configure your chat parameters", color_name="blue-70")
        with st.expander("Advanced Settings", expanded=False):
            temperature = st.slider("Temperature", 0.0, 2.0, 0.7, 0.1)
            max_tokens = st.number_input("Max Tokens", 50, 4096, 2048)
            system_prompt = st.text_area("System Prompt", st.session_state.system_prompt)
        if st.button("Clear Chat"):
            clear_chat()
    return temperature, max_tokens, system_prompt


def _record_upload(uploaded_content):
    """Append an uploaded file to the chat once, not on every script rerun.

    The uploader keeps returning the same file on each rerun, so the last
    recorded content is remembered in ``st.session_state`` and skipped if seen
    again.  Clearing the uploader resets that memory.
    """
    if not uploaded_content:
        st.session_state.last_uploaded_content = None
        return
    if st.session_state.get("last_uploaded_content") == uploaded_content:
        return
    st.session_state.last_uploaded_content = uploaded_content
    st.session_state.messages.append({
        "role": "user",
        "content": f"I've uploaded the following content:\n\n{uploaded_content}"
    })


def _render_chat_history():
    """Replay stored messages, offering to complete unfinished assistant code."""
    for msg_index, message in enumerate(st.session_state.messages):
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
            if message["role"] != "assistant":
                continue
            code_blocks = CodeAnalyzer.extract_code_blocks(message["content"])
            for block_index, code in enumerate(code_blocks):
                if CodeAnalyzer.is_code_complete(code):
                    continue
                st.info("This code block appears to be incomplete. Would you like to complete it?")
                # Key by position: a key derived from len(code) collides when two
                # blocks have the same length, which Streamlit rejects.
                if st.button("Complete Code", key=f"complete_{msg_index}_{block_index}"):
                    completion = handle_code_continuation(code)
                    st.code(completion, language="python")


def _handle_chat_input(temperature, max_tokens, system_prompt):
    """Read a new message, answer it, and append both turns to the history."""
    if prompt := st.chat_input("Message (use @ to attach a file)"):
        with st.chat_message("user"):
            st.markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = generate_response(prompt, temperature, max_tokens, system_prompt)
            st.markdown(response)
            # Remember the most recent code block the assistant produced.
            code_blocks = CodeAnalyzer.extract_code_blocks(response)
            if code_blocks:
                st.session_state.last_code_state = code_blocks[-1]
        st.session_state.messages.append({"role": "assistant", "content": response})


def main():
    """Build the page: sidebar settings, a chat tab, a code-editor tab, a footer."""
    setup_page_config()
    init_session_state()
    # Warm up the model up front so the first chat message doesn't pay the load.
    with st.spinner("Loading Qwen2.5-1.5B-Instruct model..."):
        load_model_and_tokenizer()

    temperature, max_tokens, system_prompt = _render_sidebar()

    st.title("🤖 Qwen2.5-Coder Chat")
    # Name the checkpoint actually loaded by load_model_and_tokenizer.
    st.caption("Powered by Qwen2.5-1.5B-Instruct")

    tab1, tab2 = st.tabs(["Chat", "Code Editor"])
    with tab1:
        _record_upload(handle_file_upload())
        _render_chat_history()
        _handle_chat_input(temperature, max_tokens, system_prompt)
    with tab2:
        code_editor_section()

    # Footer
    add_vertical_space(2)
    st.markdown("---")
    st.markdown("Made with ❤️ using Streamlit and Qwen2.5-Coder")


if __name__ == "__main__":
    main()