Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -13,24 +13,24 @@ import gradio as gr
|
|
13 |
|
14 |
from rag_pipeline import RAGPipeline
|
15 |
from adversarial_framework import *
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
-
adv_pipeline = AdversarialAttackPipeline(answer_generator=
|
18 |
|
|
|
19 |
def gradio_wrapper(query, method, k):
|
20 |
stats_text, auc, fig, pert_q, pert_r, adv_r = adv_pipeline.evaluate_adversarial_robustness(
|
21 |
query=query,
|
22 |
method=method,
|
23 |
k=k
|
24 |
)
|
|
|
25 |
|
26 |
-
return (
|
27 |
-
stats_text,
|
28 |
-
f"{auc}",
|
29 |
-
fig,
|
30 |
-
f"🟠 Perturbed Query:\n\n{pert_q}",
|
31 |
-
f"🟢 Perturbed Response:\n\n{pert_r}",
|
32 |
-
f"🔴 Directly Perturbed Response of Normal Output:\n\n{adv_r}"
|
33 |
-
)
|
34 |
|
35 |
gr.Interface(
|
36 |
fn=gradio_wrapper,
|
|
|
13 |
|
14 |
from rag_pipeline import RAGPipeline
|
15 |
from adversarial_framework import *
|
16 |
+
# Load all models and retrievers ONCE
|
17 |
+
rag = RAGRetriever(
|
18 |
+
embedding_model="paraphrase-MiniLM-L3-v2",
|
19 |
+
cross_encoder_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
|
20 |
+
generator_model="google/flan-t5-small"
|
21 |
+
)
|
22 |
|
23 |
+
adv_pipeline = AdversarialAttackPipeline(answer_generator=rag.generate_answer)
|
24 |
|
25 |
+
# Define the Gradio wrapper
|
26 |
def gradio_wrapper(query, method, k):
|
27 |
stats_text, auc, fig, pert_q, pert_r, adv_r = adv_pipeline.evaluate_adversarial_robustness(
|
28 |
query=query,
|
29 |
method=method,
|
30 |
k=k
|
31 |
)
|
32 |
+
return stats_text, f"{auc}", fig, pert_q, pert_r, adv_r
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
gr.Interface(
|
36 |
fn=gradio_wrapper,
|