File size: 4,456 Bytes
280ee80 ac6b0ec 280ee80 cab69db 280ee80 cab69db 280ee80 55c951d 280ee80 cab69db 280ee80 cab69db 280ee80 ac6b0ec 280ee80 40c3a36 280ee80 cab69db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import gradio as gr
import json
import os
import sys
from fastapi import FastAPI
from pydantic import BaseModel
from hamilton import driver
from pandas import DataFrame
from fastapi.middleware.cors import CORSMiddleware
# Add the src directory to the Python path
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
from src.data_module import data_pipeline, embedding_pipeline, vectorstore
from src.classification_module import semantic_similarity, dio_support_detector
from src.enforcement_module import policy_enforcement_decider
from decouple import config
# FastAPI application that exposes the two pipeline endpoints defined below.
# NOTE(review): the __main__ entry point of this file only launches the Gradio
# UI; nothing here serves `app` over HTTP — confirm how it is deployed.
app = FastAPI()

# Cross-origin settings so a browser client on another origin can call the API.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended; an explicit origin list is safer if
# cookies/credentials ever matter.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# Shared Hamilton configuration consumed by both drivers built below.
# The secret and the two service URLs are read through python-decouple's
# `config(...)` while the right-hand side is evaluated.
# NOTE(review): this assignment rebinds the imported decouple `config`
# callable to the resulting dict, so `config("...")` is unusable afterwards in
# this module; consider a distinct name (e.g. `pipeline_config`).
config = dict(
    loader="pd",
    embedding_service="openai",
    api_key=config("OPENAI_API_KEY"),
    model_name="text-embedding-ada-002",
    mistral_public_url=config("MISTRAL_PUBLIC_URL"),
    ner_public_url=config("NER_PUBLIC_URL"),
)
def _build_driver(*modules):
    """Return a Hamilton driver over ``modules`` using the module-level ``config``."""
    builder = driver.Builder().with_config(config)
    return builder.with_modules(*modules).build()


# Driver for the detection DAG (queried by the /detect_radicalization endpoint).
dr = _build_driver(
    data_pipeline,
    embedding_pipeline,
    vectorstore,
    semantic_similarity,
    dio_support_detector,
)

# Separate driver for the enforcement DAG (queried by /generate_policy_enforcement).
dr_enforcement = _build_driver(policy_enforcement_decider)
# Request body for POST /detect_radicalization.
class RadicalizationDetectionRequest(BaseModel):
    # Text to analyse; forwarded to the detection pipeline as `user_input`.
    user_text: str
# Request body for POST /generate_policy_enforcement.
class PolicyEnforcementRequest(BaseModel):
    # Text under review; forwarded to the enforcement pipeline as `user_input`.
    user_text: str
    # Context mapping forwarded unchanged as the `violation_context` pipeline
    # input. NOTE(review): its expected keys are defined by the enforcement
    # pipeline, which is not visible here.
    violation_context: dict
# Response body for POST /detect_radicalization.
class RadicalizationDetectionResponse(BaseModel):
    # Driver output for `detect_glorification`; a DataFrame result is first
    # converted with `to_dict(orient="dict")`.
    values: dict
# Response body for POST /generate_policy_enforcement.
class PolicyEnforcementResponse(BaseModel):
    # Driver output for `get_enforcement_decision`; a DataFrame result is
    # first converted with `to_dict(orient="dict")`.
    values: dict
@app.post("/detect_radicalization")
def detect_radicalization(
    request: RadicalizationDetectionRequest,
) -> RadicalizationDetectionResponse:
    """Run the detection DAG on ``request.user_text``.

    Requests the ``detect_glorification`` output from ``dr`` and wraps the
    result in the response model. A pandas DataFrame result is converted to a
    nested ``{column: {index: value}}`` dict first; any other result is passed
    through unchanged.
    """
    # NOTE(review): "project_root" is "." (the current working directory),
    # while the src/ path setup uses this file's directory — confirm the
    # process is always started from the project root.
    pipeline_inputs = {"project_root": ".", "user_input": request.user_text}
    outputs = dr.execute(
        final_vars=["detect_glorification"],
        inputs=pipeline_inputs,
    )
    payload = (
        outputs.to_dict(orient="dict") if isinstance(outputs, DataFrame) else outputs
    )
    return RadicalizationDetectionResponse(values=payload)
@app.post("/generate_policy_enforcement")
def generate_policy_enforcement(
    request: PolicyEnforcementRequest,
) -> PolicyEnforcementResponse:
    """Run the enforcement DAG for ``request``.

    Requests the ``get_enforcement_decision`` output from ``dr_enforcement``,
    feeding it the user text and the violation context, and wraps the result
    in the response model. A pandas DataFrame result is converted to a nested
    ``{column: {index: value}}`` dict first; any other result is passed through.
    """
    # NOTE(review): "project_root" is CWD-relative — confirm the working
    # directory at startup.
    pipeline_inputs = {
        "project_root": ".",
        "user_input": request.user_text,
        "violation_context": request.violation_context,
    }
    outputs = dr_enforcement.execute(
        final_vars=["get_enforcement_decision"],
        inputs=pipeline_inputs,
    )
    payload = (
        outputs.to_dict(orient="dict") if isinstance(outputs, DataFrame) else outputs
    )
    return PolicyEnforcementResponse(values=payload)
# Gradio Interface Functions
def gradio_detect_radicalization(user_text: str):
    """Gradio adapter: call the detection endpoint in-process and return its values dict."""
    return detect_radicalization(
        RadicalizationDetectionRequest(user_text=user_text)
    ).values
def gradio_generate_policy_enforcement(user_text: str, violation_context: str):
    """Gradio adapter for the policy-enforcement endpoint.

    Parses ``violation_context`` as JSON, calls ``generate_policy_enforcement``
    in-process and returns its ``values`` dict.

    Args:
        user_text: Text under review.
        violation_context: JSON text that must decode to a JSON object.

    Returns:
        The enforcement results dict, or an ``{"error": ...}`` dict when
        ``violation_context`` is not valid JSON text or does not decode to a
        JSON object.
    """
    try:
        context_dict = json.loads(violation_context)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers a missing (None) or otherwise non-text value.
        return {"error": "Invalid JSON format for violation_context"}
    # Valid JSON that is not an object (list, number, string, null) would
    # otherwise hit the request model's `dict` field and surface as a raw
    # validation error instead of a readable message.
    if not isinstance(context_dict, dict):
        return {"error": "violation_context must be a JSON object"}
    request = PolicyEnforcementRequest(user_text=user_text, violation_context=context_dict)
    return generate_policy_enforcement(request).values
# Define the Gradio interface
iface = gr.Interface(
    title="Radicalization Detection",
    description="Enter text to detect glorification or radicalization.",
    # One free-text box in, the handler's values dict rendered as JSON out.
    fn=gradio_detect_radicalization,
    inputs="text",
    outputs="json",
)
# Second interface for policy enforcement
iface2 = gr.Interface(
    title="Policy Enforcement Decision",
    description="Enter user text and context to generate a policy enforcement decision.",
    fn=gradio_generate_policy_enforcement,
    # Inputs map positionally onto the handler's (user_text, violation_context).
    inputs=[
        "text",
        gr.Textbox(lines=5, placeholder="Enter JSON-formatted violation context"),
    ],
    outputs="json",
)
# Combine the interfaces in a Tabbed interface
# Both interfaces as tabs of one app; tab names pair with interfaces by position.
iface_combined = gr.TabbedInterface(
    interface_list=[iface, iface2],
    tab_names=["Detect Radicalization", "Policy Enforcement"],
)
if __name__ == "__main__":
    # Only the Gradio UI is started here; its handlers call the endpoint
    # functions above directly, in-process. The FastAPI `app` is NOT served by
    # this entry point — run it separately (e.g. under an ASGI server) if the
    # HTTP API is needed.
    _launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    iface_combined.launch(**_launch_options)
|