Update main.py
main.py
CHANGED
@@ -1,85 +1,76 @@
- import
- from
- from fastapi import
- from
-
- from text_generation import Client

-
- HF_TOKEN = os.environ.get("HF_TOKEN")
- if HF_TOKEN is None:
-     raise ValueError("Please set the HF_TOKEN environment variable.")

-
- model_id = 'codellama/CodeLlama-34b-Instruct-hf'
- API_URL = "https://api-inference.huggingface.co/models/" + model_id

-
-
-
)

-
-

-

- # Allow CORS for your frontend application
app.add_middleware(
    CORSMiddleware,
-     allow_origins=["
    allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
)

- # Pydantic model for request body
- class ChatRequest(BaseModel):
-     prompt: str
-     history: List[Tuple[str, str]]
-
- DEFAULT_SYSTEM_PROMPT = """\
- You are a helpful, respectful and honest assistant with a deep knowledge of code and software design. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
- """
-
- def get_prompt(message: str, chat_history: List[Tuple[str, str]],
-                system_prompt: str) -> str:
-     texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
-     do_strip = False
-     for user_input, response in chat_history:
-         user_input = user_input.strip() if do_strip else user_input
-         do_strip = True
-         texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
-     message = message.strip() if do_strip else message
-     texts.append(f'{message} [/INST]')
-     return ''.join(texts)
-
@app.post("/generate/")
- async def
-
-
-
-

-

-
-
-
-
-         top_k=50,
-         temperature=0.1,
-     )
-
-     stream = client.generate_stream(prompt_text, **generate_kwargs)
-     output = ""
-     for response in stream:
-         if any([end_token in response.token.text for end_token in [EOS_STRING, EOT_STRING]]):
-             break
-         else:
-             output += response.token.text

-

-     except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
+ from fastapi import FastAPI, Request, HTTPException
+ from fastapi.responses import JSONResponse, FileResponse
+ from fastapi.staticfiles import StaticFiles
+ from huggingface_hub import InferenceClient
+ import json

+ app = FastAPI()

+ client = InferenceClient("NousResearch/Hermes-3-Llama-3.1-8B")

+ SYSTEM_MESSAGE = (
+     "You are a helpful, respectful and honest assistant. Always answer as helpfully "
+     "as possible, while being safe. Your answers should not include any harmful, "
+     "unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure "
+     "that your responses are socially unbiased and positive in nature.\n\nIf a question "
+     "does not make any sense, or is not factually coherent, explain why instead of "
+     "answering something not correct. If you don't know the answer to a question, please "
+     "don't share false information."
+     "Always respond in the language of user prompt for each prompt ."
)
+ MAX_TOKENS = 2000
+ TEMPERATURE = 0.7
+ TOP_P = 0.95

+ def respond(message, history: list[tuple[str, str]]):
+     messages = [{"role": "system", "content": SYSTEM_MESSAGE}]
+
+     for val in history:
+         if val[0]:
+             messages.append({"role": "user", "content": val[0]})
+         if val[1]:
+             messages.append({"role": "assistant", "content": val[1]})

+     messages.append({"role": "user", "content": message})
+
+     response = client.chat_completion(
+         messages,
+         max_tokens=MAX_TOKENS,
+         stream=True,
+         temperature=TEMPERATURE,
+         top_p=TOP_P,
+     )
+
+     for message in response:  # Handle regular iteration
+         yield message.choices[0].delta.content
+
+ from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
+     allow_origins=["https://artixiban-ll3.static.hf.space"],  # Allow only this origin
    allow_credentials=True,
+     allow_methods=["*"],  # Allow all methods (GET, POST, etc.)
+     allow_headers=["*"],  # Allow all headers
)

@app.post("/generate/")
+ async def generate(request: Request):
+     allowed_origin = "https://artixiban-ll3.static.hf.space"
+     origin = request.headers.get("origin")
+     if origin != allowed_origin:
+         raise HTTPException(status_code=403, detail="Origin not allowed")
+     form = await request.form()
+     prompt = form.get("prompt")
+     history = json.loads(form.get("history", "[]"))  # Default to empty history

+     if not prompt:
+         raise HTTPException(status_code=400, detail="Prompt is required")

+     response_generator = respond(prompt, history)
+     final_response = ""
+     for part in response_generator:
+         final_response += part

+     return JSONResponse(content={"response": final_response})
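For reference, a minimal client-side sketch of how the updated /generate/ endpoint could be called. This is not part of the commit: the base URL is an assumed local address and the prompt/history values are made up. The endpoint expects form fields "prompt" and "history" (a JSON-encoded list of [user, assistant] pairs) and rejects requests whose Origin header is not https://artixiban-ll3.static.hf.space.

# Hypothetical usage sketch (not part of this commit).
# Assumes the FastAPI app is running locally on port 8000.
import json
import requests

BASE_URL = "http://localhost:8000"  # assumed host

response = requests.post(
    f"{BASE_URL}/generate/",
    data={
        "prompt": "Summarize what FastAPI middleware does.",
        "history": json.dumps([["Hello", "Hi! How can I help you today?"]]),
    },
    # The endpoint checks the Origin header and returns 403 for anything else.
    headers={"Origin": "https://artixiban-ll3.static.hf.space"},
)
response.raise_for_status()
print(response.json()["response"])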