File size: 4,456 Bytes
280ee80
ac6b0ec
280ee80
 
cab69db
280ee80
 
 
cab69db
280ee80
 
 
 
55c951d
 
 
280ee80
 
 
 
 
cab69db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280ee80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cab69db
280ee80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac6b0ec
 
 
 
280ee80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40c3a36
280ee80
 
 
 
 
 
 
cab69db
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import gradio as gr
import json
import os
import sys
from fastapi import FastAPI
from pydantic import BaseModel
from hamilton import driver
from pandas import DataFrame
from fastapi.middleware.cors import CORSMiddleware

# Add the src directory to the Python path
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))

from src.data_module import data_pipeline, embedding_pipeline, vectorstore
from src.classification_module import semantic_similarity, dio_support_detector
from src.enforcement_module import policy_enforcement_decider

from decouple import config

# FastAPI application object; the two POST endpoints below are registered on it.
app = FastAPI()

# Allow cross-origin requests from any origin, with any method and any header.
# NOTE(review): the Gradio handlers in this file call the endpoint functions
# directly in-process rather than over HTTP, so confirm whether this middleware
# is still required. Also confirm that wildcard origins combined with
# allow_credentials=True is intended -- it is a very permissive setup.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Settings handed to both Hamilton drivers through with_config().
# NOTE: this statement rebinds the module-level name `config`. Up to here it
# refers to the `config` callable imported from decouple; the three
# config("...") calls on the right-hand side are evaluated before the
# assignment, so they still reach decouple. After this statement `config` is
# the dict below, and decouple's function is no longer reachable by that name.
config = {
    "loader": "pd",
    "embedding_service": "openai",
    "api_key": config("OPENAI_API_KEY"),
    "model_name": "text-embedding-ada-002",
    "mistral_public_url": config("MISTRAL_PUBLIC_URL"),
    "ner_public_url": config("NER_PUBLIC_URL"),
}

# Driver over the data, embedding, vector store, semantic similarity and
# support-detector modules; used by detect_radicalization().
dr = (
    driver.Builder()
    .with_config(config)
    .with_modules(data_pipeline, embedding_pipeline, vectorstore, semantic_similarity, dio_support_detector)
    .build()
)

# Separate driver containing only the policy enforcement module; used by
# generate_policy_enforcement().
dr_enforcement = (
    driver.Builder()
    .with_config(config)
    .with_modules(policy_enforcement_decider)
    .build()
)

# Request body accepted by the /detect_radicalization endpoint.
class RadicalizationDetectionRequest(BaseModel):
    # Text to analyse; passed to the driver as the "user_input" input.
    user_text: str
    
# Request body accepted by the /generate_policy_enforcement endpoint.
class PolicyEnforcementRequest(BaseModel):
    # Text under review; passed to the driver as the "user_input" input.
    user_text: str
    # Free-form context passed to the driver as the "violation_context" input.
    # The expected keys are not defined in this file.
    violation_context: dict

# Response body returned by the /detect_radicalization endpoint.
class RadicalizationDetectionResponse(BaseModel):
    # Result of the driver execution, converted to a plain dict when the
    # driver returned a DataFrame.
    values: dict
    
# Response body returned by the /generate_policy_enforcement endpoint.
class PolicyEnforcementResponse(BaseModel):
    # Result of the enforcement driver execution, converted to a plain dict
    # when the driver returned a DataFrame.
    values: dict

@app.post("/detect_radicalization")
def detect_radicalization(
    request: RadicalizationDetectionRequest
) -> RadicalizationDetectionResponse:
    # Runs the "detect_glorification" node of the main driver on the submitted
    # text and wraps whatever it produces in the response model.
    pipeline_inputs = {"project_root": ".", "user_input": request.user_text}
    outcome = dr.execute(
        final_vars=["detect_glorification"],
        inputs=pipeline_inputs,
    )
    # A DataFrame result is flattened to nested dicts; anything else is
    # passed through unchanged.
    payload = outcome.to_dict(orient="dict") if isinstance(outcome, DataFrame) else outcome
    return RadicalizationDetectionResponse(values=payload)

@app.post("/generate_policy_enforcement")
def generate_policy_enforcement(
    request: PolicyEnforcementRequest
) -> PolicyEnforcementResponse:
    # Runs the "get_enforcement_decision" node of the enforcement driver with
    # the submitted text and violation context, then wraps the result.
    pipeline_inputs = {
        "project_root": ".",
        "user_input": request.user_text,
        "violation_context": request.violation_context,
    }
    outcome = dr_enforcement.execute(
        final_vars=["get_enforcement_decision"],
        inputs=pipeline_inputs,
    )
    # A DataFrame result is flattened to nested dicts; anything else is
    # passed through unchanged.
    payload = outcome.to_dict(orient="dict") if isinstance(outcome, DataFrame) else outcome
    return PolicyEnforcementResponse(values=payload)

# Gradio Interface Functions
def gradio_detect_radicalization(user_text: str):
    # Gradio handler: wraps the raw text in the request model, calls the
    # endpoint function directly (no HTTP round trip) and returns its
    # `values` dict for the JSON output component.
    return detect_radicalization(
        RadicalizationDetectionRequest(user_text=user_text)
    ).values

def gradio_generate_policy_enforcement(user_text: str, violation_context: str):
    # Gradio handler for the policy enforcement tab.
    #
    # user_text:         text entered in the first input.
    # violation_context: JSON text entered in the second input; it must decode
    #                    to a JSON object because PolicyEnforcementRequest
    #                    declares `violation_context` as a dict.
    #
    # Returns the `values` dict of the enforcement response, or a dict with a
    # single "error" key when the context cannot be used.
    try:
        context_dict = json.loads(violation_context)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers a non-string input such as None.
        return {"error": "Invalid JSON format for violation_context"}

    # json.loads also accepts arrays, strings, numbers, booleans and null.
    # Reject those here so the user gets an error payload instead of an
    # exception raised while building the request model.
    if not isinstance(context_dict, dict):
        return {"error": "violation_context must be a JSON object"}

    request = PolicyEnforcementRequest(user_text=user_text, violation_context=context_dict)
    response = generate_policy_enforcement(request)
    return response.values

# Tab 1: radicalization detection -- one text box in, JSON out.
iface = gr.Interface(
    title="Radicalization Detection",
    description="Enter text to detect glorification or radicalization.",
    fn=gradio_detect_radicalization,
    inputs="text",
    outputs="json",
)

# Tab 2: policy enforcement -- the user text plus a multi-line box for the
# JSON-encoded violation context, JSON out.
iface2 = gr.Interface(
    title="Policy Enforcement Decision",
    description="Enter user text and context to generate a policy enforcement decision.",
    fn=gradio_generate_policy_enforcement,
    inputs=[
        "text",
        gr.Textbox(placeholder="Enter JSON-formatted violation context", lines=5),
    ],
    outputs="json",
)

# Present both interfaces as tabs of a single Gradio app.
iface_combined = gr.TabbedInterface(
    [iface, iface2],
    ["Detect Radicalization", "Policy Enforcement"],
)

if __name__ == "__main__":
    # Running this file as a script starts only the Gradio server, bound to
    # all interfaces on port 7860. The Gradio handlers call the endpoint
    # functions in-process, so nothing here serves the FastAPI `app` routes.
    iface_combined.launch(server_port=7860, server_name="0.0.0.0")