File size: 2,673 Bytes
854f61d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from typing import Annotated

from fastapi import FastAPI, Form, UploadFile
from pydantic import BaseModel
from hamilton import driver
from pandas import DataFrame

from data_module import data_pipeline, embedding_pipeline, vectorstore
from classification_module import semantic_similarity, dio_support_detector
from enforcement_module import policy_enforcement_decider

# Alias the decouple reader so the `config` dict below does not shadow it.
# (Previously `config = {...}` rebound the name after the dict literal was
# evaluated, so any later `config("KEY")` call would raise TypeError.)
from decouple import config as env_config

app = FastAPI()

# Shared Hamilton driver configuration. Secrets and service URLs are read
# from the environment via python-decouple so they stay out of source control.
config = {
    "loader": "pd",  # data-loading backend
    "embedding_service": "openai",
    "api_key": env_config("OPENAI_API_KEY"),
    "model_name": "text-embedding-ada-002",
    "mistral_public_url": env_config("MISTRAL_PUBLIC_URL"),
    "ner_public_url": env_config("NER_PUBLIC_URL"),
}

# Driver for the detection flow: data loading -> embeddings -> vector store
# -> similarity classification.
dr = (
    driver.Builder()
    .with_config(config)
    .with_modules(
        data_pipeline,
        embedding_pipeline,
        vectorstore,
        semantic_similarity,
        dio_support_detector,
    )
    .build()
)

# Separate driver for policy-enforcement decisions.
dr_enforcement = (
    driver.Builder()
    .with_config(config)
    .with_modules(policy_enforcement_decider)
    .build()
)

class RadicalizationDetectionRequest(BaseModel):
    """Request body for the /detect_radicalization endpoint."""
    # Free-form text to be analyzed by the detection pipeline.
    user_text: str
    
class PolicyEnforcementRequest(BaseModel):
    """Request body for the /generate_policy_enforcement endpoint."""
    # Free-form text the enforcement decision is based on.
    user_text: str
    # Context about the detected violation; schema defined by the caller —
    # TODO confirm expected keys against policy_enforcement_decider.
    violation_context: dict

class RadicalizationDetectionResponse(BaseModel):
    """Response body for the /detect_radicalization endpoint."""
    # Raw pipeline output keyed by Hamilton node name.
    values: dict
    
class PolicyEnforcementResponse(BaseModel):
    """Response body for the /generate_policy_enforcement endpoint."""
    # Raw pipeline output keyed by Hamilton node name.
    values: dict
    
@app.post("/detect_radicalization")
def detect_radicalization(
    request: RadicalizationDetectionRequest,
) -> RadicalizationDetectionResponse:
    """Run the detection pipeline over ``request.user_text``.

    Executes the Hamilton driver up to the ``detect_glorification`` node and
    returns its output wrapped in the response model. Leftover debug
    ``print`` calls were removed from the request path.
    """
    results = dr.execute(
        final_vars=["detect_glorification"],
        inputs={"project_root": ".", "user_input": request.user_text},
    )
    # Depending on the final node, Hamilton may hand back a DataFrame;
    # the response schema requires a plain dict.
    if isinstance(results, DataFrame):
        results = results.to_dict(orient="dict")
    return RadicalizationDetectionResponse(values=results)

@app.post("/generate_policy_enforcement")
def generate_policy_enforcement(
    request: PolicyEnforcementRequest,
) -> PolicyEnforcementResponse:
    """Produce an enforcement decision for a detected violation.

    Executes the enforcement Hamilton driver up to the
    ``get_enforcement_decision`` node and returns its output wrapped in the
    response model. Leftover debug ``print`` calls were removed from the
    request path.
    """
    results = dr_enforcement.execute(
        final_vars=["get_enforcement_decision"],
        inputs={
            "project_root": ".",
            "user_input": request.user_text,
            "violation_context": request.violation_context,
        },
    )
    # Depending on the final node, Hamilton may hand back a DataFrame;
    # the response schema requires a plain dict.
    if isinstance(results, DataFrame):
        results = results.to_dict(orient="dict")
    return PolicyEnforcementResponse(values=results)



if __name__ == "__main__":
    # Development entry point: serve the app on all interfaces, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)  # specify host and port