chibuzordev commited on
Commit
5c0610c
·
verified ·
1 Parent(s): a4689b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -13,24 +13,24 @@ import gradio as gr
13
 
14
  from rag_pipeline import RAGPipeline
15
  from adversarial_framework import *
 
 
 
 
 
 
16
 
17
- adv_pipeline = AdversarialAttackPipeline(answer_generator=RAGPipeline().generate_answer)
18
 
 
19
  def gradio_wrapper(query, method, k):
20
  stats_text, auc, fig, pert_q, pert_r, adv_r = adv_pipeline.evaluate_adversarial_robustness(
21
  query=query,
22
  method=method,
23
  k=k
24
  )
 
25
 
26
- return (
27
- stats_text,
28
- f"{auc}",
29
- fig,
30
- f"🟠 Perturbed Query:\n\n{pert_q}",
31
- f"🟢 Perturbed Response:\n\n{pert_r}",
32
- f"🔴 Directly Perturbed Response of Normal Output:\n\n{adv_r}"
33
- )
34
 
35
  gr.Interface(
36
  fn=gradio_wrapper,
 
13
 
14
  from rag_pipeline import RAGPipeline
15
  from adversarial_framework import *
16
+ # Load all models and retrievers ONCE
17
+ rag = RAGRetriever(
18
+ embedding_model="paraphrase-MiniLM-L3-v2",
19
+ cross_encoder_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
20
+ generator_model="google/flan-t5-small"
21
+ )
22
 
23
+ adv_pipeline = AdversarialAttackPipeline(answer_generator=rag.generate_answer)
24
 
25
+ # Define the Gradio wrapper
26
  def gradio_wrapper(query, method, k):
27
  stats_text, auc, fig, pert_q, pert_r, adv_r = adv_pipeline.evaluate_adversarial_robustness(
28
  query=query,
29
  method=method,
30
  k=k
31
  )
32
+ return stats_text, f"{auc}", fig, pert_q, pert_r, adv_r
33
 
 
 
 
 
 
 
 
 
34
 
35
  gr.Interface(
36
  fn=gradio_wrapper,