amiguel committed · verified · Commit ac19c17 · Parent: 026c97a

Update app.py

Files changed (1): app.py (+52 −7)

app.py CHANGED
@@ -36,11 +36,24 @@ with st.sidebar:
 if "messages" not in st.session_state:
     st.session_state.messages = []

+# File processing function
 @st.cache_data
 def process_file(uploaded_file):
-    # Existing file processing logic
-    pass
+    if uploaded_file is None:
+        return ""
+
+    try:
+        if uploaded_file.type == "application/pdf":
+            pdf_reader = PyPDF2.PdfReader(uploaded_file)
+            return "\n".join([page.extract_text() for page in pdf_reader.pages])
+        elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
+            df = pd.read_excel(uploaded_file)
+            return df.to_markdown()
+    except Exception as e:
+        st.error(f"📄 Error processing file: {str(e)}")
+        return ""

+# Model loading function
 @st.cache_resource
 def load_model(hf_token):
     try:
@@ -48,16 +61,13 @@ def load_model(hf_token):
             st.error("🔐 Authentication required! Please provide a Hugging Face token.")
             return None

-        # Login to Hugging Face Hub
         login(token=hf_token)

-        # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
             MODEL_NAME,
             token=hf_token
         )

-        # Load model with KV caching support
         model = AutoModelForCausalLM.from_pretrained(
             MODEL_NAME,
             device_map="auto",
@@ -71,7 +81,43 @@ def load_model(hf_token):
         st.error(f"🤖 Model loading failed: {str(e)}")
         return None

-# In the main chat handling section:
+# Generation function with KV caching
+def generate_with_kv_cache(prompt, file_context, use_cache=True):
+    full_prompt = f"Analyze this context:\n{file_context}\n\nQuestion: {prompt}\nAnswer:"
+
+    streamer = TextIteratorStreamer(
+        tokenizer,
+        skip_prompt=True,
+        skip_special_tokens=True
+    )
+
+    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
+
+    generation_kwargs = {
+        **inputs,
+        "max_new_tokens": 1024,
+        "temperature": 0.7,
+        "top_p": 0.9,
+        "repetition_penalty": 1.1,
+        "do_sample": True,
+        "use_cache": use_cache,
+        "streamer": streamer
+    }
+
+    Thread(target=model.generate, kwargs=generation_kwargs).start()
+    return streamer
+
+# Display chat messages
+for message in st.session_state.messages:
+    try:
+        avatar = "👀" if message["role"] == "user" else "🤖"
+        with st.chat_message(message["role"], avatar=avatar):
+            st.markdown(message["content"])
+    except:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
+
+# Chat input handling
 if prompt := st.chat_input("Ask your inspection question..."):
     if not hf_token:
         st.error("🔑 Authentication required!")
@@ -89,7 +135,6 @@ if prompt := st.chat_input("Ask your inspection question..."):
     model = st.session_state.model
     tokenizer = st.session_state.tokenizer

-
     # Add user message
     with st.chat_message("user", avatar="👀"):
         st.markdown(prompt)
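A few notes on the new code. The added paths reference PyPDF2, pd, Thread, and TextIteratorStreamer, none of which appear in the visible hunks, so the import block at the top of app.py presumably already provides them. For reference, the imports those names imply:

    import PyPDF2
    import pandas as pd
    from threading import Thread
    from huggingface_hub import login
    from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer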
 
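process_file now routes PDFs through PyPDF2 and .xlsx workbooks through pandas, with @st.cache_data keying on the uploaded file's bytes so a rerun with the same file skips extraction. (Note that a file of any other type falls through both branches and implicitly returns None rather than an empty string.) The uploader widget sits outside these hunks; a minimal sketch of the presumed wiring, where the file_uploader label and the file_context name are assumptions, not part of the commit:

    # Hypothetical wiring -- the sidebar uploader is not shown in this diff.
    uploaded_file = st.sidebar.file_uploader(
        "Upload inspection documents",  # label is an assumption
        type=["pdf", "xlsx"],           # the two content types process_file handles
    )
    file_context = process_file(uploaded_file)  # cached across reruns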
 
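load_model is cached with @st.cache_resource, while the chat handler reads the model and tokenizer out of st.session_state; the glue between the two is outside the visible hunks. A sketch of what it presumably looks like, assuming load_model returns a (model, tokenizer) pair on success and None on failure (its return statement is cut off at the hunk boundary):

    # Hypothetical glue -- not part of this diff.
    if "model" not in st.session_state:
        loaded = load_model(hf_token)
        if loaded is not None:
            st.session_state.model, st.session_state.tokenizer = loaded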
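generate_with_kv_cache forwards use_cache to model.generate, which is what the KV caching in the function's name amounts to: with the flag on, each decoding step reuses the attention keys and values already computed for the prefix instead of re-encoding the whole sequence per token. A purely illustrative way to compare the two modes (not from the commit, and sampled outputs vary in length, so treat timings as rough):

    import time

    for flag in (True, False):
        start = time.perf_counter()
        for _ in generate_with_kv_cache("Summarize the context.", file_context, use_cache=flag):
            pass  # draining the streamer blocks until the generate thread finishes
        print(f"use_cache={flag}: {time.perf_counter() - start:.1f}s")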
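Finally, the chat handler shown here stops right after rendering the user message; the assistant turn that consumes the returned TextIteratorStreamer is not in these hunks. Because generate_with_kv_cache starts model.generate on a background Thread and returns the streamer immediately, the consumer can simply iterate over it as decoded text chunks arrive. A sketch, with everything beyond generate_with_kv_cache assumed:

    # Hypothetical assistant turn -- not shown in this commit.
    with st.chat_message("assistant", avatar="🤖"):
        placeholder = st.empty()
        response = ""
        for chunk in generate_with_kv_cache(prompt, file_context):
            response += chunk
            placeholder.markdown(response + "▌")  # cursor while streaming
        placeholder.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})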