File size: 5,916 Bytes
e4bde4c
22507c4
e4bde4c
4cb26b4
22507c4
b7f0e75
22507c4
4cb26b4
e4bde4c
22507c4
 
 
e4bde4c
 
 
 
 
 
 
4cb26b4
22507c4
 
 
 
 
 
 
 
e4bde4c
 
22507c4
 
 
 
 
 
 
 
 
 
4cb26b4
e4bde4c
 
 
22507c4
 
 
 
 
 
 
 
 
 
 
 
e4bde4c
22507c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4bde4c
22507c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4bde4c
22507c4
e4bde4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cb26b4
 
e4bde4c
4cb26b4
e4bde4c
 
 
 
22507c4
 
e4bde4c
 
 
22507c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4bde4c
 
 
22507c4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import time
import streamlit as st
from app.utils import generate_answer, load_llm
from core.types import ThoughtStepsDisplay, BigMessage 
from .app_config import InputConfig, ENV_FILE_PATH, CONFIG_FILE_PATH
from core.prompts.cot import SYSTEM_PROMPT




def config_sidebar(config:InputConfig) -> InputConfig:
    """Draw the sidebar configuration form and copy the entered values onto *config*.

    Every widget is pre-filled from the matching attribute of *config*. After all
    widgets are drawn, their current values are assigned back onto the same
    object, which is then returned. Pressing 'Save config' persists it via
    ``config.save`` to ENV_FILE_PATH / CONFIG_FILE_PATH.
    """
    sidebar = st.sidebar
    sidebar.header('Configuration')

    # Widgets are created in display order (dict displays evaluate left to right);
    # the collected values are applied to the config only after all are drawn.
    entered = {
        'model_name': sidebar.text_input(
            'Model Name: e.g. provider/model-name',
            value=config.model_name,
            placeholder='openai/gpt-3.5-turbo',
        ),
        'model_api_key': sidebar.text_input(
            'API Key: ',
            type='password',
            value=config.model_api_key,
            placeholder='sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx',
        ),
        'max_tokens': sidebar.number_input(
            'Max Tokens per Thought: ',
            value=config.max_tokens,
            min_value=1,
        ),
        'max_steps': sidebar.number_input(
            'Max Thinking Steps: ',
            value=config.max_steps,
            min_value=1,
            step=1,
        ),
        'temperature': sidebar.number_input(
            'Temperature: ',
            value=config.temperature,
            min_value=0.0,
            step=0.1,
            max_value=10.0,
        ),
        'timeout': sidebar.number_input(
            'Timeout(seconds): ',
            value=config.timeout,
            min_value=0.0,
            step=1.0,
        ),
        'sleeptime': sidebar.number_input(
            'Sleep Time(seconds)',
            value=config.sleeptime,
            min_value=0.0,
            step=1.0,
            help='Time between requests to avoid hitting rate limit',
        ),
        'force_max_steps': sidebar.checkbox(
            'Force Max Steps',
            value=config.force_max_steps,
            help="If checked, will generate given number of max steps. If not checked, assistant can stop at few step thinking it has the right answer.",
        ),
    }
    for field_name, field_value in entered.items():
        setattr(config, field_name, field_value)

    if sidebar.button('Save config'):
        config.save(env_file=ENV_FILE_PATH, config_file=CONFIG_FILE_PATH)
        sidebar.success('Config saved!')

    return config


    
def _render_history(history) -> None:
    """Redraw every stored chat message (Streamlit reruns the whole script on each interaction).

    User messages are shown as plain markdown. Assistant messages show each
    intermediate thought in an expander, followed by the final step.
    """
    for message in history:
        with st.chat_message(message.role):
            for thought in message.thoughts:
                print_thought(thought.to_thought_steps_display(), is_final=False)

            if message.content:
                if message.role == 'user':
                    st.markdown(message.content)
                else:
                    print_thought(message.content.to_thought_steps_display(), is_final=True)


def _build_request_messages(history) -> list:
    """Build the LLM request: the system prompt followed by the conversation history.

    Every user message gets ", json format" appended, because the json keyword in
    the user message helps the model produce JSON output. New dicts are created
    so the objects returned by ``to_message()`` are never mutated in place.
    """
    messages = [{
        "role": "system",
        "content": SYSTEM_PROMPT
    }]
    for m in history:
        message = m.to_message()
        if message["role"] == "user":
            message = {**message, "content": f"{message['content']}, json format"}
        messages.append(message)
    return messages


def main():
    """Render the Open-O1 chat page and, when a prompt is submitted, run the thinking loop.

    Chat history is kept in ``st.session_state`` under ``"<tab>_big_messages"`` as
    a list of ``BigMessage``. On a new prompt, the user message is stored and
    shown. ``generate_answer`` is then streamed step by step inside a status box.
    The last step is displayed as the final answer, and it is stored with the
    earlier steps as the assistant's thoughts.
    """
    st.set_page_config(page_title="Open-o1", page_icon="🧠", layout="wide")
    st.title('Open-O1')
    st.write('Welcome to Open-O1!')

    config = InputConfig.load(env_file=ENV_FILE_PATH, config_file=CONFIG_FILE_PATH)
    config = config_sidebar(config=config)
    llm = load_llm(config)

    current_tab = 'o1_tab'
    big_message_attr_name = f"{current_tab}_big_messages"

    if st.sidebar.button('Clear Chat'):
        # pop() with a default never raises, unlike delattr() on a missing key.
        st.session_state.pop(big_message_attr_name, None)

    big_message_attr = set_and_get_state_attr(big_message_attr_name, default_value=[])

    # Redraw the earlier messages.
    _render_history(big_message_attr)

    if prompt := st.chat_input("What is up bro?"):
        big_message_attr.append(BigMessage(role="user", content=prompt, thoughts=[]))

        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            messages = _build_request_messages(big_message_attr)
            thoughts = []
            start_time = time.time()

            with st.status("Thinking...", expanded=True) as status:
                for step in generate_answer(
                    messages=messages,
                    max_steps=config.max_steps,
                    stream=False,
                    max_tokens=config.max_tokens,
                    temperature=config.temperature,
                    sleeptime=config.sleeptime,
                    timeout=config.timeout,
                    llm=llm,
                    force_max_steps=config.force_max_steps,
                    response_format={ "type": "json_object" }
                ):
                    thoughts.append(step)

                    st.write(step.to_thought_steps_display().md())
                    # Add a horizontal rule after each step.
                    st.markdown('---')
                    status.update(label=step.step_title, state="running", expanded=False)

                status.update(
                    label=f"Thought for {time.time()-start_time:.2f} seconds", state="complete", expanded=False
                )

            if not thoughts:
                # generate_answer yielded no steps; thoughts.pop() would raise IndexError.
                st.error('No response was generated. Check the model configuration and try again.')
                return

            # The last step is the final answer; the rest are kept as intermediate thoughts.
            last_step = thoughts.pop()
            print_thought(last_step.to_thought_steps_display(), is_final=True)

            big_message_attr.append(BigMessage(
                role="assistant",
                content=last_step,
                thoughts=thoughts
            ))


    
def set_and_get_state_attr(attr_name: str, default_value: object = None) -> object:
    """Return the session-state entry *attr_name*, first initialising it to *default_value* if absent.

    The default object itself is stored, not a copy. A mutable default such as
    ``[]`` therefore becomes the persistent value that callers mutate in place
    across reruns.
    """
    if attr_name not in st.session_state:
        setattr(st.session_state, attr_name, default_value)
    return getattr(st.session_state, attr_name)


def print_thought(thought:ThoughtStepsDisplay, is_final:bool=False):
    """Render one thinking step as markdown.

    A final step is written directly into the current container. Intermediate
    steps go inside an ``st.expander``, which is collapsed by default and labelled
    with the step's ``step_title``.
    """
    if is_final:
        st.markdown(thought.md())
        return

    with st.expander(f'{thought.step_title}'):
        st.markdown(thought.md())