Lhumpal committed
Commit ffd9ec7 · verified · 1 Parent(s): 134c2d8

Update app.py

Files changed (1):
  1. app.py +76 -18
app.py CHANGED
@@ -1,23 +1,81 @@
- from fastapi import FastAPI
- from fastapi.responses import JSONResponse
- from fastapi import Request
- from huggingface_hub import InferenceClient

- app = FastAPI()

- @app.post("/")
- async def greet_json(request: Request):
-     input_data = await request.json()
-     # number = input_data.get("number")

-     # tripled_number = number * 2
-     # return {"message": f"Your input number is: {number}, your doubled number is: {tripled_number}"}
-     user_input = input_data.get("user_input")

-     client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-     # Get the response from the model
-     response = client(user_input)

-     # assistant_response = client.text_generation(user_input)
-     assistant_response = "I am assistant."
-     return {"assistant_message": f"Your input message is: {user_input}, assistant_response is: {response}"}
+ # from fastapi import FastAPI
+ # from fastapi.responses import JSONResponse
+ # from fastapi import Request
+ # from huggingface_hub import InferenceClient

+ # app = FastAPI()

+ # @app.post("/")
+ # async def greet_json(request: Request):
+ #     input_data = await request.json()
+ #     # number = input_data.get("number")

+ #     # tripled_number = number * 2
+ #     # return {"message": f"Your input number is: {number}, your doubled number is: {tripled_number}"}
+ #     user_input = input_data.get("user_input")

+ #     client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+ #     # Get the response from the model
+ #     response = client(user_input)

+ #     # assistant_response = client.text_generation(user_input)
+ #     assistant_response = "I am assistant."
+ #     return {"assistant_message": f"Your input message is: {user_input}, assistant_response is: {response}"}
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from huggingface_hub import InferenceClient
+ from typing import List, Tuple
+
+ # Initialize the Inference client
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+
+ # Initialize FastAPI
+ app = FastAPI()
+
+ # Request schema for the /chat endpoint
+ class Message(BaseModel):
+     message: str
+     history: List[Tuple[str, str]]
+     system_message: str
+     max_tokens: int
+     temperature: float
+     top_p: float
+
+ def generate_response(message, history: List[Tuple[str, str]], system_message, max_tokens, temperature, top_p):
+     # Build the chat transcript: system prompt first, then alternating user/assistant turns
+     messages = [{"role": "system", "content": system_message}]
+
+     for user_turn, assistant_turn in history:
+         if user_turn:
+             messages.append({"role": "user", "content": user_turn})
+         if assistant_turn:
+             messages.append({"role": "assistant", "content": assistant_turn})
+
+     messages.append({"role": "user", "content": message})
+
+     # Stream the completion and accumulate tokens into a single response string
+     response = ""
+     for chunk in client.chat_completion(
+         messages,
+         max_tokens=max_tokens,
+         stream=True,
+         temperature=temperature,
+         top_p=top_p,
+     ):
+         token = chunk.choices[0].delta.content
+         if token:  # guard against chunks with no content (e.g. role or finish deltas)
+             response += token
+     return response
+
+ @app.post("/chat")
+ async def chat_response(msg: Message):
+     try:
+         response = generate_response(
+             msg.message,
+             msg.history,
+             msg.system_message,
+             msg.max_tokens,
+             msg.temperature,
+             msg.top_p,
+         )
+         return {"response": response}
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
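
For reference, a minimal sketch of exercising the new /chat endpoint from Python. The host and port here are assumptions (Spaces commonly serve on 7860, and the requests library is not part of this repo); adjust both to the actual deployment:

import requests

# Hypothetical base URL; replace with the deployed Space's address.
payload = {
    "message": "What is the capital of France?",
    "history": [],                      # prior (user, assistant) turn pairs
    "system_message": "You are a friendly chatbot.",
    "max_tokens": 512,
    "temperature": 0.7,
    "top_p": 0.95,
}
resp = requests.post("http://localhost:7860/chat", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["response"])

Note that generate_response consumes the model stream server-side and returns the accumulated string, so the client receives the full reply only once generation finishes.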