abdullahalioo committed on
Commit a6a8da7 · verified · 1 Parent(s): d6be5f7

Update main.py

Files changed (1)
  1. main.py +34 -20
main.py CHANGED
@@ -1,3 +1,4 @@
+
 from fastapi import FastAPI, Request
 from pydantic import BaseModel
 from fastapi.middleware.cors import CORSMiddleware
@@ -5,14 +6,19 @@ from fastapi.responses import StreamingResponse
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import asyncio
+import logging
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 # FastAPI app
 app = FastAPI()
 
-# CORS Middleware (so JS from browser can access it)
+# CORS Middleware (for frontend access)
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],  # Change "*" to your frontend URL for better security
+    allow_origins=["*"],  # Update to specific frontend URL in production
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
@@ -23,14 +29,20 @@ class Question(BaseModel):
     question: str
 
 # Load the model and tokenizer
-model_name = "Qwen/Qwen2.5-7B-Instruct"  # Use Qwen2.5-7B-Instruct (adjust for VL if needed)
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    torch_dtype=torch.float16,  # Use float16 for GPU memory efficiency
-    device_map="auto",  # Automatically map to GPU/CPU
-    trust_remote_code=True
-)
+model_name = "Qwen/Qwen2.5-7B-Instruct"
+try:
+    logger.info(f"Loading model {model_name}...")
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.float16,
+        device_map="auto",
+        trust_remote_code=True
+    )
+    logger.info("Model loaded successfully.")
+except Exception as e:
+    logger.error(f"Failed to load model: {e}")
+    raise
 
 async def generate_response_chunks(prompt: str):
     try:
@@ -41,9 +53,8 @@ async def generate_response_chunks(prompt: str):
         ]
         inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
 
-        # Asynchronous generator to yield tokens as they are generated
+        # Asynchronous generator to yield tokens
         async def stream_tokens():
-            # Generate tokens one by one
             for output in model.generate(
                 inputs,
                 max_new_tokens=512,
@@ -53,26 +64,29 @@ async def generate_response_chunks(prompt: str):
                 pad_token_id=tokenizer.eos_token_id,
                 return_dict_in_generate=True,
                 output_scores=False,
-                streaming=True  # Enable streaming in model.generate (if supported)
+                streaming=True
             ):
-                # Decode the latest token
-                token_id = output.sequences[0][-1]  # Get the last generated token
+                token_id = output.sequences[0][-1]
                 token_text = tokenizer.decode([token_id], skip_special_tokens=True)
                 if token_text:
                     yield token_text
-                    await asyncio.sleep(0.01)  # Small delay to control streaming speed
-                else:
-                    # Handle special tokens or empty outputs
-                    continue
+                    await asyncio.sleep(0.01)  # Control streaming speed
+            logger.info("Streaming completed.")
 
         return stream_tokens()
 
     except Exception as e:
+        logger.error(f"Error during generation: {e}")
        yield f"Error occurred: {e}"
 
 @app.post("/ask")
 async def ask(question: Question):
+    logger.info(f"Received question: {question.question}")
     return StreamingResponse(
         generate_response_chunks(question.question),
         media_type="text/plain"
-    )
+    )
+
+@app.get("/")
+async def root():
+    return {"message": "Orion AI Chat API is running!"}
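
A note on the streaming loop this commit keeps: model.generate() in released transformers versions does not accept a streaming=True keyword and only returns once generation has finished, so the for output in model.generate(...) loop above is likely to raise a ValueError for the unrecognized argument rather than emit tokens incrementally. Python also rejects return with a value inside a function that yields in its except clause (an async generator), so generate_response_chunks as written would not even compile cleanly. The sketch below shows how per-token streaming is usually wired up with transformers' TextIteratorStreamer; it reuses the module-level tokenizer, model, and asyncio import from main.py, and the message layout and sampling options are assumptions based on the visible parts of the diff, not the author's exact code.

# Sketch only (not part of the commit): stream tokens with TextIteratorStreamer.
from threading import Thread
from transformers import TextIteratorStreamer

async def generate_response_chunks(prompt: str):
    # Assumed message layout; the original system prompt is not shown in the diff.
    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # The streamer yields decoded text pieces as generate() produces tokens.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs=inputs,
        max_new_tokens=512,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
        # any sampling options from the original generate() call would go here
    )
    # generate() blocks until it finishes, so run it in a worker thread
    # and hand the blocking next() calls to the default executor.
    Thread(target=model.generate, kwargs=generation_kwargs, daemon=True).start()
    loop = asyncio.get_running_loop()
    while True:
        piece = await loop.run_in_executor(None, lambda: next(streamer, None))
        if piece is None:
            break
        yield piece

Because this version is itself an async generator, StreamingResponse can consume it directly and the /ask route body does not need to change.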
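The inline comment on allow_origins points at locking CORS down outside of development. A hedged sketch of what that looks like, with a hypothetical frontend origin standing in for the real one:

# Hypothetical production CORS setup; the origin below is a placeholder, not the project's URL.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://your-frontend.example.com"],
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)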
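For quick manual testing, here is a hypothetical client that consumes the plain-text stream from /ask; the host, port, and question are assumptions for a local uvicorn run, not values taken from the repository.

# Hypothetical streaming client for the /ask endpoint.
import httpx

with httpx.stream(
    "POST",
    "http://localhost:8000/ask",
    json={"question": "What is FastAPI?"},
    timeout=None,
) as response:
    # Print chunks as they arrive instead of waiting for the full body.
    for chunk in response.iter_text():
        print(chunk, end="", flush=True)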