Makhinur committed
Commit d6b44d2 · verified · 1 Parent(s): 2096e32

Update main.py

Files changed (1)
  1. main.py +59 -68
main.py CHANGED
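This update swaps the text_generation Client (raw calls against the codellama/CodeLlama-34b-Instruct-hf Inference API with a hand-built Llama-2 prompt) for huggingface_hub's InferenceClient chat_completion API against NousResearch/Hermes-3-Llama-3.1-8B, parses the history form field with json.loads instead of eval, and narrows CORS from "*" to a single frontend origin.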
@@ -1,85 +1,76 @@
-import os
-from typing import List, Tuple
-from fastapi import FastAPI, Form, HTTPException
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-from text_generation import Client
-
-# Ensure the HF_TOKEN environment variable is set
-HF_TOKEN = os.environ.get("HF_TOKEN")
-if HF_TOKEN is None:
-    raise ValueError("Please set the HF_TOKEN environment variable.")
-
-# Model and API setup
-model_id = 'codellama/CodeLlama-34b-Instruct-hf'
-API_URL = "https://api-inference.huggingface.co/models/" + model_id
-
-client = Client(
-    API_URL,
-    headers={"Authorization": f"Bearer {HF_TOKEN}"},
-)
-
-EOS_STRING = "</s>"
-EOT_STRING = "<EOT>"
-
-app = FastAPI()
-
-# Allow CORS for your frontend application
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],  # Change this to your frontend's URL in production
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-# Pydantic model for request body
-class ChatRequest(BaseModel):
-    prompt: str
-    history: List[Tuple[str, str]]
-
-DEFAULT_SYSTEM_PROMPT = """\
-You are a helpful, respectful and honest assistant with a deep knowledge of code and software design. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
-"""
-
-def get_prompt(message: str, chat_history: List[Tuple[str, str]],
-               system_prompt: str) -> str:
-    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
-    do_strip = False
-    for user_input, response in chat_history:
-        user_input = user_input.strip() if do_strip else user_input
-        do_strip = True
-        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
-    message = message.strip() if do_strip else message
-    texts.append(f'{message} [/INST]')
-    return ''.join(texts)
-
-@app.post("/generate/")
-async def generate_response(prompt: str = Form(...), history: str = Form(...)):
-    try:
-        chat_history = eval(history)  # Convert history string back to list
-        system_prompt = DEFAULT_SYSTEM_PROMPT
-        message = prompt
-
-        prompt_text = get_prompt(message, chat_history, system_prompt)
-
-        generate_kwargs = dict(
-            max_new_tokens=1024,
-            do_sample=True,
-            top_p=0.9,
-            top_k=50,
-            temperature=0.1,
-        )
-
-        stream = client.generate_stream(prompt_text, **generate_kwargs)
-        output = ""
-        for response in stream:
-            if any([end_token in response.token.text for end_token in [EOS_STRING, EOT_STRING]]):
-                break
-            else:
-                output += response.token.text
-
-        return {"response": output}
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+from fastapi import FastAPI, Request, HTTPException
+from fastapi.responses import JSONResponse, FileResponse
+from fastapi.staticfiles import StaticFiles
+from huggingface_hub import InferenceClient
+import json
+
+app = FastAPI()
+
+client = InferenceClient("NousResearch/Hermes-3-Llama-3.1-8B")
+
+SYSTEM_MESSAGE = (
+    "You are a helpful, respectful and honest assistant. Always answer as helpfully "
+    "as possible, while being safe. Your answers should not include any harmful, "
+    "unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure "
+    "that your responses are socially unbiased and positive in nature.\n\nIf a question "
+    "does not make any sense, or is not factually coherent, explain why instead of "
+    "answering something not correct. If you don't know the answer to a question, please "
+    "don't share false information.\n"
+    "Always respond in the language of the user's prompt."
+)
+MAX_TOKENS = 2000
+TEMPERATURE = 0.7
+TOP_P = 0.95
+
+def respond(message, history: list[tuple[str, str]]):
+    messages = [{"role": "system", "content": SYSTEM_MESSAGE}]
+
+    for user_turn, assistant_turn in history:
+        if user_turn:
+            messages.append({"role": "user", "content": user_turn})
+        if assistant_turn:
+            messages.append({"role": "assistant", "content": assistant_turn})
+
+    messages.append({"role": "user", "content": message})
+
+    response = client.chat_completion(
+        messages,
+        max_tokens=MAX_TOKENS,
+        stream=True,
+        temperature=TEMPERATURE,
+        top_p=TOP_P,
+    )
+
+    for chunk in response:  # Iterate over the streamed chunks
+        yield chunk.choices[0].delta.content or ""  # delta.content may be None
+
+from fastapi.middleware.cors import CORSMiddleware
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["https://artixiban-ll3.static.hf.space"],  # Allow only this origin
+    allow_credentials=True,
+    allow_methods=["*"],  # Allow all methods (GET, POST, etc.)
+    allow_headers=["*"],  # Allow all headers
+)
+
+@app.post("/generate/")
+async def generate(request: Request):
+    allowed_origin = "https://artixiban-ll3.static.hf.space"
+    origin = request.headers.get("origin")
+    if origin != allowed_origin:
+        raise HTTPException(status_code=403, detail="Origin not allowed")
+    form = await request.form()
+    prompt = form.get("prompt")
+    history = json.loads(form.get("history", "[]"))  # Default to empty history
+
+    if not prompt:
+        raise HTTPException(status_code=400, detail="Prompt is required")
+
+    response_generator = respond(prompt, history)
+    final_response = ""
+    for part in response_generator:  # Drain the token stream into one string
+        final_response += part
+
+    return JSONResponse(content={"response": final_response})
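
For reference, a minimal client-side sketch of calling the updated /generate/ endpoint. The Space URL below is a hypothetical placeholder; because generate() checks the Origin header itself, the header must be sent explicitly even from non-browser clients, and history travels as a JSON-encoded list of [user, assistant] pairs.

import json
import requests

API_URL = "https://your-space.hf.space/generate/"  # hypothetical deployment URL

resp = requests.post(
    API_URL,
    data={
        "prompt": "Summarise what a FastAPI middleware does.",
        # history is a JSON-encoded list of [user, assistant] turn pairs
        "history": json.dumps([["Hi!", "Hello! How can I help you?"]]),
    },
    # The endpoint rejects requests whose Origin does not match the
    # allowed value, so the header is set explicitly here.
    headers={"Origin": "https://artixiban-ll3.static.hf.space"},
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["response"])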
76