xuandin commited on
Commit
d7db16b
·
verified ·
1 Parent(s): a486265

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -0
app.py CHANGED
@@ -18,6 +18,18 @@ os.environ["MKL_NUM_THREADS"] = str(psutil.cpu_count(logical=False))
18
  # Set device globally
19
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  @st.cache_data
22
  def preprocess_text(text):
23
  # Add any text cleaning or normalization here
@@ -309,6 +321,10 @@ with st.sidebar:
309
  # Main content
310
  tabs = st.tabs(["🔍 Kiểm chứng", "📊 Lịch sử", "ℹ️ Thông tin"])
311
 
 
 
 
 
312
  # --- Tab Verify ---
313
  with tabs[0]:
314
  col1, col2 = st.columns([2, 1])
 
18
  # Set device globally
19
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
+ @st.cache_resource()
22
+ def load_model(model_name, model_class, is_bc=False, device=None):
23
+ if device is None:
24
+ device = "cuda" if torch.cuda.is_available() else "cpu"
25
+
26
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
27
+ model = model_class.from_pretrained(model_name, num_labels=3 if not is_bc else 2)
28
+ model.eval()
29
+
30
+ model.to(device)
31
+ return tokenizer, model
32
+
33
  @st.cache_data
34
  def preprocess_text(text):
35
  # Add any text cleaning or normalization here
 
321
  # Main content
322
  tabs = st.tabs(["🔍 Kiểm chứng", "📊 Lịch sử", "ℹ️ Thông tin"])
323
 
324
+ tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering, device=DEVICE)
325
+ tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True, device=DEVICE)
326
+ tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification, device=DEVICE)
327
+
328
  # --- Tab Verify ---
329
  with tabs[0]:
330
  col1, col2 = st.columns([2, 1])