import streamlit as st
import os
from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Set up the Streamlit app layout
st.title("HuggingFace Model Chat App")
st.write("Enter a HuggingFace model name, set parameters, and chat with the model!")

# Sidebar for model selection and parameters
with st.sidebar:
    st.header("Model and Parameters")
    model_name = st.text_input("HuggingFace Model Name (e.g., meta-llama/Llama-2-7b-chat-hf)", value="meta-llama/Llama-2-7b-chat-hf")
    system_prompt = st.text_area("System Prompt", value="You are a helpful AI assistant.", height=100)
    temperature = st.slider("Temperature (Randomness)", min_value=0.1, max_value=2.0, value=0.7, step=0.1)
    top_p = st.slider("Top-p (Nucleus Sampling)", min_value=0.1, max_value=1.0, value=0.9, step=0.05)
    max_length = st.slider("Max Response Length", min_value=50, max_value=500, value=200, step=10)
    load_model_button = st.button("Load Model")

# Initialize session state for model, tokenizer, and chat history
if 'model' not in st.session_state:
    st.session_state.model = None
    st.session_state.tokenizer = None
    st.session_state.chat_history = []
    st.session_state.model_loaded = False

# Function to save query and response to markdown file
def save_to_md(query, response):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"chat_{timestamp}.md"
    os.makedirs("chat_history", exist_ok=True)
    with open(os.path.join("chat_history", filename), "w", encoding="utf-8") as f:
        f.write(f"# Chat Log\n\n**Query:** {query}\n\n**Response:** {response}\n")
    return filename

# Load model and tokenizer when button is clicked
if load_model_button:
    with st.spinner("Loading model... This may take a while."):
        try:
            hf_token = os.getenv("HUGGINGFACE_TOKEN")
            st.session_state.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
            st.session_state.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                token=hf_token
            )
            st.session_state.model_loaded = True
            st.success("Model loaded successfully!")
        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            st.session_state.model_loaded = False

# Chat interface
if st.session_state.model_loaded:
    st.header("Chat with the Model")
    user_input = st.text_area("Your Message", height=100)
    send_button = st.button("Send")

    # Display chat history
    st.subheader("Chat History")
    for chat in st.session_state.chat_history:
        st.markdown(f"**You:** {chat['query']}")
        st.markdown(f"**Model:** {chat['response']}")
        st.markdown(f"**Saved as:** {chat['filename']}")
        st.markdown("---")

    # Process user input and generate response
    if send_button and user_input:
        try:
            # Prepare input with system prompt
            full_input = f"{system_prompt}\n\nUser: {user_input}\nAssistant: "
            # Move inputs to the same device the model was placed on by device_map="auto"
            inputs = st.session_state.tokenizer(full_input, return_tensors="pt").to(st.session_state.model.device)

            # Generate response
            with st.spinner("Generating response..."):
                outputs = st.session_state.model.generate(
                    **inputs,
                    max_new_tokens=max_length,  # cap the generated tokens only, not prompt + response
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    pad_token_id=st.session_state.tokenizer.eos_token_id
                )
                # Decode only the newly generated tokens so the prompt is not echoed back
                generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
                response = st.session_state.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

            # Save to markdown
            filename = save_to_md(user_input, response)

            # Update chat history
            st.session_state.chat_history.append({
                "query": user_input,
                "response": response,
                "filename": filename
            })

            # Rerun to update the display
            st.rerun()
        except Exception as e:
            st.error(f"Error generating response: {str(e)}")
else:
    st.info("Please load a model to start chatting.")
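
# Usage note (a sketch, assuming this file is saved as app.py and that streamlit,
# transformers, torch, and accelerate are installed; accelerate is required for
# device_map="auto"):
#
#     export HUGGINGFACE_TOKEN=<your token>   # only needed for gated models such as Llama-2
#     streamlit run app.py
#
# Each query/response pair is saved to chat_history/chat_<timestamp>.md.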