qhuang0805123 commited on
Commit
c587c68
·
verified ·
1 Parent(s): 2502356

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +163 -0
main.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uvicorn
3
+ from fastapi import FastAPI, HTTPException, Depends
4
+ from typing import List, Tuple, Optional, Union, Dict, Any
5
+ import torch
6
+
7
+ from config import ServerConfig
8
+ from guard_chat import SafeModelPipeline
9
+ from pydantic import BaseModel
10
+ from dataclasses import dataclass, field
11
+ import logging
12
+
13
# Initialize FastAPI app
app = FastAPI(title="Safe Chat Model API")


# Global variables for model and config.
# Both are populated once by the startup event handler; every endpoint
# treats a None value as "server not initialized" and returns an error.
pipeline: Optional[SafeModelPipeline] = None
config: Optional[ServerConfig] = None
20
+
21
class ChatRequest(BaseModel):
    """Request body for POST /chat.

    Carries the running conversation plus per-request generation
    parameters (sampling knobs mirror the HF `generate` arguments).
    """

    # Chat history as role/content dicts — presumably the usual
    # {"role": ..., "content": ...} chat format; confirm against
    # SafeModelPipeline.generate_response.
    messages: List[Dict[str, str]]
    conversation_id: Optional[str] = None  # To track different conversations
    max_new_tokens: Optional[int] = 80  # generation length cap
    temperature: Optional[float] = 0.7  # sampling temperature
    do_sample: Optional[bool] = True  # greedy decoding when False
    top_p: Optional[float] = 0.9  # nucleus sampling threshold
    top_k: Optional[int] = 50  # top-k sampling cutoff
29
+
30
class ChatResponse(BaseModel):
    """Response body for POST /chat.

    Mirrors the object returned by the pipeline: the generated text plus
    the safety-classifier verdicts for the input and the output.
    """

    response: str  # generated assistant text
    input_safety: str  # safety label for the user input
    output_safety: str  # safety label for the generated output
    filtered: bool  # True when the response was filtered/replaced
    conversation_id: str  # echoes/assigns the conversation identifier
36
+
37
+
38
+
39
+
40
+
41
def get_config():
    """FastAPI dependency: hand back the ServerConfig loaded at startup.

    Raises HTTPException(500) while the startup event has not yet
    populated the module-level ``config``.
    """
    if config is not None:
        return config
    raise HTTPException(status_code=500, detail="Server not initialized")
46
+
47
+
48
+
49
@app.post("/update_config")
async def update_config(
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    config: ServerConfig = Depends(get_config)
):
    """Update generation parameters on the live pipeline.

    Only the parameters that are supplied (non-None) are changed; the
    rest keep their current values.

    Raises:
        HTTPException(500): if the pipeline is not initialized or an
            unexpected error occurs while applying the update.
    """
    try:
        if pipeline is None:
            raise HTTPException(status_code=500, detail="Model not initialized")

        # Apply only the knobs the caller actually provided.
        if temperature is not None:
            pipeline.args.temperature = temperature
        if top_k is not None:
            pipeline.args.top_k = top_k
        if top_p is not None:
            pipeline.args.top_p = top_p
        if max_new_tokens is not None:
            pipeline.args.max_new_tokens = max_new_tokens

        return {"message": "Configuration updated successfully"}

    except HTTPException:
        # BUG FIX: previously the generic handler below caught our own
        # HTTPException and re-wrapped it via str(e), losing the intended
        # status/detail. Re-raise it untouched instead.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
77
+
78
@app.get("/health")
async def health_check():
    """Health check endpoint: 200 when fully initialized, 503 otherwise."""
    ready = pipeline is not None and config is not None
    if not ready:
        raise HTTPException(status_code=503, detail="Server not fully initialized")
    return {"status": "healthy"}
84
+
85
+
86
@app.on_event("startup")
async def startup_event():
    """Initialize the server on startup.

    Loads the YAML config (path from the CONFIG_PATH env var, default
    "config.yaml") and builds the SafeModelPipeline, storing both in the
    module-level globals that the endpoints read.

    Raises:
        Exception: re-raised after logging so the server fails fast
            instead of serving with an uninitialized pipeline.
    """
    global pipeline, config

    try:
        # Load config from YAML
        config_path = os.getenv("CONFIG_PATH", "config.yaml")
        config = ServerConfig.from_yaml(config_path)

        # Initialize pipeline with config
        pipeline = SafeModelPipeline(
            model_args=config.to_chat_arguments(),
            system_prompt=config.system_prompt,
            max_history_tokens=config.max_history_tokens
        )

    except Exception as e:
        # CONSISTENCY FIX: the rest of the file logs via `logging`; use it
        # here too (with the traceback) instead of a bare print().
        logging.error(f"Error initializing server: {str(e)}", exc_info=True)
        raise
106
+
107
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Chat endpoint that handles conversation and calls the safe model pipeline.

    Validates the request, delegates generation to the pipeline, and maps
    the pipeline result onto a ChatResponse.

    Raises:
        HTTPException(400): no messages in the request.
        HTTPException(500): pipeline missing, pipeline failure, or an
            empty/invalid pipeline result.
    """
    logging.info(f"Received request: {request}")  # Log the incoming request

    if pipeline is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    try:
        # Validate request
        if not request.messages:
            raise HTTPException(status_code=400, detail="No messages provided")

        logging.info(f"Processing messages: {request.messages}")  # Log the messages

        # Call the safe model pipeline's generate_response method
        try:
            response = await pipeline.generate_response(request)
            logging.info(f"Generated response: {response}")  # Log the response
        except Exception as e:
            logging.error(f"Pipeline error: {str(e)}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Pipeline error: {str(e)}")

        # If response is None or invalid, raise an error
        if not response:
            raise HTTPException(status_code=500, detail="Failed to generate response")

        return ChatResponse(
            response=response.response,
            conversation_id=response.conversation_id,
            input_safety=response.input_safety,
            output_safety=response.output_safety,
            filtered=response.filtered
        )

    except HTTPException:
        # BUG FIX: the generic handler below used to catch the 400
        # ("No messages provided") and the inner pipeline 500 and re-wrap
        # them as a generic 500, destroying the intended status codes.
        raise
    except Exception as e:
        logging.error(f"Error in chat endpoint: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
145
+
146
if __name__ == "__main__":
    import argparse

    # Command-line options for running the server directly.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--host", type=str, default="0.0.0.0",
        help="Host to run the server on",
    )
    cli.add_argument(
        "--port", type=int, default=8000,
        help="Port to run the server on",
    )
    cli.add_argument(
        "--config", type=str, default="config.yaml",
        help="Path to config file",
    )
    opts = cli.parse_args()

    # Hand the config path to the startup event via the environment.
    os.environ["CONFIG_PATH"] = opts.config

    # Launch the ASGI server.
    uvicorn.run(app, host=opts.host, port=opts.port)