Lhumpal committed · verified
Commit 7f349bb · 1 Parent(s): ffd9ec7

Update app.py

Files changed (1):
  1. app.py +31 -45
app.py CHANGED
@@ -24,58 +24,44 @@
  from fastapi import FastAPI, HTTPException
  from pydantic import BaseModel
  from huggingface_hub import InferenceClient
- from typing import List, Tuple

- # Initialize the Inference client
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
- # Initialize FastAPI
  app = FastAPI()

- class Message(BaseModel):
-     message: str
-     history: List[Tuple[str, str]]
-     system_message: str
-     max_tokens: int
-     temperature: float
-     top_p: float
-
- def generate_response(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-         response += token
-     return response
-
- @app.post("/chat")
- async def chat_response(msg: Message):
-     try:
-         response = generate_response(
-             msg.message,
-             msg.history,
-             msg.system_message,
-             msg.max_tokens,
-             msg.temperature,
-             msg.top_p,
-         )
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+
+ class ChatRequest(BaseModel):
+     message: str
+     history: list[tuple[str, str]] = []
+     system_message: str = "You are a friendly Chatbot."
+     max_tokens: int = 512
+     temperature: float = 0.7
+     top_p: float = 0.95
+
+ class ChatResponse(BaseModel):
+     response: str
+
+ @app.post("/chat", response_model=ChatResponse)
+ async def chat(request: ChatRequest):
+     try:
+         messages = [{"role": "system", "content": request.system_message}]
+         for val in request.history:
+             if val[0]:
+                 messages.append({"role": "user", "content": val[0]})
+             if val[1]:
+                 messages.append({"role": "assistant", "content": val[1]})
+         messages.append({"role": "user", "content": request.message})
+
+         response = ""
+         for message in client.chat_completion(
+             messages,
+             max_tokens=request.max_tokens,
+             stream=True,
+             temperature=request.temperature,
+             top_p=request.top_p,
+         ):
+             token = message.choices[0].delta.content
+             response += token
+
          return {"response": response}
      except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
-
+         raise HTTPException(status_code=500, detail=str(e))
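
After this change the app exposes a single JSON endpoint: the handler consumes the streamed completion server-side and returns the accumulated text in one body, so callers do not receive a token stream. A minimal client sketch, assuming the app is served locally with `uvicorn app:app --port 8000` (the host, port, and use of the `requests` library are illustrative, not part of the commit):

```python
# Hypothetical caller for the new /chat endpoint; assumes the app from this
# commit is running locally via `uvicorn app:app --port 8000`.
import requests

payload = {
    "message": "What is the capital of France?",
    # Prior turns as (user, assistant) pairs; any omitted fields fall back to
    # the defaults declared on ChatRequest (system_message, max_tokens, etc.).
    "history": [["Hi there!", "Hello! How can I help you today?"]],
    "temperature": 0.5,
}

resp = requests.post("http://localhost:8000/chat", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["response"])  # one accumulated string, not a token stream
```

One caveat: `delta.content` on streamed chunks is optional and can be `None` for some events, so the `response += token` accumulation may merit an `if token:` guard in a follow-up commit.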