amiguel committed on
Commit
026c97a
·
verified ·
1 Parent(s): 0eb710b

Update app.py

Files changed (1)
  1. app.py +37 -36
app.py CHANGED
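The hunk below uses several names defined outside it in app.py: MODEL_NAME, login, AutoTokenizer, AutoModelForCausalLM, torch, Thread, and TextIteratorStreamer. As a minimal sketch, these are the imports and constants the changed code appears to assume; the MODEL_NAME value here is a placeholder, not the repo's actual setting:

import torch
from threading import Thread

import streamlit as st
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model ID; the real value is defined elsewhere in app.py.
MODEL_NAME = "your-org/your-model"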
@@ -43,52 +43,53 @@ def process_file(uploaded_file):
 
 @st.cache_resource
 def load_model(hf_token):
-    # Existing model loading logic
-    pass
-
-def generate_with_kv_cache(prompt, file_context, use_cache=True):
-    full_prompt = f"Analyze this context:\n{file_context}\n\nQuestion: {prompt}\nAnswer:"
-
-    streamer = TextIteratorStreamer(
-        tokenizer,
-        skip_prompt=True,
-        skip_special_tokens=True
-    )
-
-    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
-
-    # KV Caching parameters
-    generation_kwargs = {
-        **inputs,
-        "max_new_tokens": 1024,
-        "temperature": 0.7,
-        "top_p": 0.9,
-        "repetition_penalty": 1.1,
-        "do_sample": True,
-        "use_cache": use_cache,  # KV Cache control
-        "streamer": streamer
-    }
-
-    Thread(target=model.generate, kwargs=generation_kwargs).start()
-    return streamer
-
-# Display chat messages
-for message in st.session_state.messages:
-    # Existing message display logic
-    pass
-
-# Chat input handling
+    try:
+        if not hf_token:
+            st.error("🔐 Authentication required! Please provide a Hugging Face token.")
+            return None
+
+        # Login to Hugging Face Hub
+        login(token=hf_token)
+
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(
+            MODEL_NAME,
+            token=hf_token
+        )
+
+        # Load model with KV caching support
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_NAME,
+            device_map="auto",
+            torch_dtype=torch.float16,
+            token=hf_token
+        )
+
+        return model, tokenizer
+
+    except Exception as e:
+        st.error(f"🤖 Model loading failed: {str(e)}")
+        return None
+
+# In the main chat handling section:
 if prompt := st.chat_input("Ask your inspection question..."):
     if not hf_token:
         st.error("🔑 Authentication required!")
         st.stop()
 
-    # Load model
+    # Load model if not already loaded
     if "model" not in st.session_state:
-        st.session_state.model, st.session_state.tokenizer = load_model(hf_token)
+        model_data = load_model(hf_token)
+        if model_data is None:
+            st.error("Failed to load model. Please check your token and try again.")
+            st.stop()
+
+        st.session_state.model, st.session_state.tokenizer = model_data
+
     model = st.session_state.model
     tokenizer = st.session_state.tokenizer
 
+
     # Add user message
     with st.chat_message("user", avatar="👀"):
         st.markdown(prompt)
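The removed generate_with_kv_cache helper handed back a TextIteratorStreamer while model.generate() ran in a background thread; whether an equivalent helper survives elsewhere in app.py is not visible in this hunk. As a minimal sketch only, this is how such a streamer is typically drained into the chat UI, assuming a helper with the old signature and a file_context produced by process_file:

# Hypothetical follow-on to the user message above; assumes the old
# generate_with_kv_cache(prompt, file_context, use_cache) signature.
with st.chat_message("assistant", avatar="🤖"):
    streamer = generate_with_kv_cache(prompt, file_context, use_cache=True)
    # st.write_stream drains the iterator as tokens arrive and returns
    # the concatenated text once generation finishes.
    response = st.write_stream(streamer)

st.session_state.messages.append({"role": "assistant", "content": response})

Keeping use_cache=True reuses each step's cached key/value projections instead of recomputing attention over the whole prefix, which is what makes token-by-token streaming affordable.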