abdullahalioo committed
Commit 2c97dd8 · verified · 1 Parent(s): 09dee7d

Update main.py

Files changed (1)
  1. main.py +30 -65
main.py CHANGED
@@ -1,24 +1,18 @@
-
-from fastapi import FastAPI, Request
+from fastapi import FastAPI
 from pydantic import BaseModel
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
-from transformers import AutoModelForCausalLM, AutoTokenizer
-import torch
+import httpx
 import asyncio
-import logging
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
+import json
 
 # FastAPI app
 app = FastAPI()
 
-# CORS Middleware (for frontend access)
+# CORS Middleware
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],  # Update to specific frontend URL in production
+    allow_origins=["*"],
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
@@ -28,67 +22,38 @@ app.add_middleware(
 class Question(BaseModel):
     question: str
 
-# Load the model and tokenizer
-model_name = "Qwen/Qwen2.5-7B-Instruct"
-try:
-    logger.info(f"Loading model {model_name}...")
-    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        trust_remote_code=True
-    )
-    logger.info("Model loaded successfully.")
-except Exception as e:
-    logger.error(f"Failed to load model: {e}")
-    raise
+# Your OWN Hosted HuggingFace Space URL
+YOUR_SPACE_URL = "https://your-space-name-username.hf.space"  # 🔥 change this!
 
 async def generate_response_chunks(prompt: str):
-    try:
-        # Prepare the input prompt
-        messages = [
-            {"role": "system", "content": "You are Orion AI assistant created by Abdullah Ali, who is very intelligent, 13 years old, and lives in Lahore."},
+    payload = {
+        "messages": [
+            {"role": "system", "content": "You are a Orion AI assistant created by abdullah ali who is very intelligent and he is 13 years old and lives in Lahore."},
             {"role": "user", "content": prompt}
-        ]
-        inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
-
-        # Asynchronous generator to yield tokens
-        async def stream_tokens():
-            for output in model.generate(
-                inputs,
-                max_new_tokens=512,
-                temperature=0.7,
-                top_p=0.9,
-                do_sample=True,
-                pad_token_id=tokenizer.eos_token_id,
-                return_dict_in_generate=True,
-                output_scores=False,
-                streaming=True
-            ):
-                token_id = output.sequences[0][-1]
-                token_text = tokenizer.decode([token_id], skip_special_tokens=True)
-                if token_text:
-                    yield token_text
-                    await asyncio.sleep(0.01)  # Control streaming speed
-            logger.info("Streaming completed.")
-
-        # Yield tokens from stream_tokens
-        async for token in stream_tokens():
-            yield token
-
-    except Exception as e:
-        logger.error(f"Error during generation: {e}")
-        yield f"Error occurred: {e}"
+        ],
+        "temperature": 0.7,
+        "max_tokens": 512,
+        "stream": True  # Tell your server to stream output
+    }
+
+    async with httpx.AsyncClient(timeout=None) as client:
+        async with client.stream("POST", f"{YOUR_SPACE_URL}/v1/chat/completions", json=payload) as response:
+            async for line in response.aiter_lines():
+                if line.strip():
+                    try:
+                        # The server sends stream chunks, decode them
+                        data = json.loads(line)
+                        content = data['choices'][0]['delta']['content']
+                        if content:
+                            for letter in content:
+                                yield letter
+                                await asyncio.sleep(0.01)  # simulate typing
+                    except Exception as e:
+                        yield f"Error decoding stream: {e}"
 
 @app.post("/ask")
 async def ask(question: Question):
-    logger.info(f"Received question: {question.question}")
     return StreamingResponse(
         generate_response_chunks(question.question),
         media_type="text/plain"
     )
-
-@app.get("/")
-async def root():
-    return {"message": "Orion AI Chat API is running!"}
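
For reference, a minimal sketch of how a frontend or script could consume the new streaming /ask endpoint. This client is not part of the commit; the base URL and port, the example question, and the use of httpx on the client side are assumptions for illustration only.

import asyncio
import httpx

API_URL = "http://localhost:7860/ask"  # assumed host/port; point this at wherever main.py is served

async def main() -> None:
    async with httpx.AsyncClient(timeout=None) as client:
        # POST a question and print the plain-text stream as chunks arrive
        async with client.stream("POST", API_URL, json={"question": "Who created you?"}) as response:
            async for chunk in response.aiter_text():
                print(chunk, end="", flush=True)
    print()

if __name__ == "__main__":
    asyncio.run(main())

Since the endpoint streams with media_type="text/plain", the client can print each chunk directly without any JSON decoding.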