import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import List, Dict

class LlamaDemo:
    def __init__(self):
        self.model_name = "meta-llama/Llama-2-7b-chat-hf"
        # Initialize in lazy loading fashion
        self._model = None
        self._tokenizer = None
        
    @property
    def model(self):
        if self._model is None:
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                device_map="auto"
            )
        return self._model
    
    @property
    def tokenizer(self):
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        return self._tokenizer

    def generate_response(self, prompt: str, max_length: int = 512) -> str:
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Generate response; max_length caps the number of newly generated tokens,
        # matching the "Maximum response length" setting in the sidebar
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens instead of string-replacing the
        # prompt, which can fail when the tokenizer normalizes the input text
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
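
    # A minimal sketch, not part of the original demo: instruct/chat checkpoints
    # generally expect chat-formatted input rather than a raw prompt string. This
    # helper (its name is an assumption) formats the running conversation with the
    # tokenizer's built-in chat template and could be used to build the prompt
    # passed to generate_response.
    def build_chat_prompt(self, history: List[Dict[str, str]]) -> str:
        # history is a list of {"role": ..., "content": ...} dicts, the same shape
        # as st.session_state.chat_history in main()
        return self.tokenizer.apply_chat_template(
            history,
            tokenize=False,
            add_generation_prompt=True,
        )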

def main():
    st.set_page_config(
        page_title="Llama 3.1 Demo",
        page_icon="🦙",
        layout="wide"
    )
    
    st.title("🦙 Llama 3.1 Demo")
    
    # Initialize session state
    if 'llama' not in st.session_state:
        st.session_state.llama = LlamaDemo()
    
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
        
    # Chat interface
    with st.container():
        # Display chat history
        for message in st.session_state.chat_history:
            role = message["role"]
            content = message["content"]
            
            with st.chat_message(role):
                st.write(content)
    
        # Input for new message
        if prompt := st.chat_input("What would you like to discuss?"):
            # Add user message to chat history
            st.session_state.chat_history.append({
                "role": "user",
                "content": prompt
            })
            
            with st.chat_message("user"):
                st.write(prompt)
            
            # Show assistant response
            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                
                with st.spinner("Generating response..."):
                    # Read the sidebar's max_length via session_state; the sidebar is
                    # rendered later in the script, so its local variable is not in
                    # scope here (falls back to 512 on the very first run)
                    response = st.session_state.llama.generate_response(
                        prompt,
                        max_length=st.session_state.get("max_length", 512)
                    )
                    message_placeholder.write(response)
                    
                # Add assistant response to chat history
                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": response
                })
    
    # Sidebar with settings
    with st.sidebar:
        st.header("Settings")
        # Keyed so the chat handler above can read the current value from
        # st.session_state on the next rerun
        st.slider("Maximum response length", 64, 1024, 512, key="max_length")
        
        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()

if __name__ == "__main__":
    main()
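
# Rough usage sketch (assumptions: the file is saved as app.py, a GPU with enough
# VRAM for the fp16 weights is available, and the Hugging Face account has been
# granted access to the gated meta-llama checkpoint):
#   pip install streamlit transformers torch accelerate
#   huggingface-cli login
#   streamlit run app.py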