abdullahalioo committed
Commit d6be5f7 · verified · 1 Parent(s): 8d16e4e

Create main.py

Files changed (1)
  1. main.py +78 -0
main.py ADDED
@@ -0,0 +1,78 @@
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import StreamingResponse
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ import torch
+ import asyncio
+ import threading
+
+ # FastAPI app
+ app = FastAPI()
+
+ # CORS middleware (so JS from the browser can access the API)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Change "*" to your frontend URL for better security
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Request body model
+ class Question(BaseModel):
+     question: str
+
+ # Load the model and tokenizer
+ model_name = "Qwen/Qwen2.5-7B-Instruct"  # Adjust for a VL variant if needed
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.float16,  # Use float16 for GPU memory efficiency
+     device_map="auto",          # Automatically map to GPU/CPU
+     trust_remote_code=True,
+ )
+
+ async def generate_response_chunks(prompt: str):
+     try:
+         # Prepare the input prompt
+         messages = [
+             {"role": "system", "content": "You are Orion AI assistant created by Abdullah Ali, who is very intelligent, 13 years old, and lives in Lahore."},
+             {"role": "user", "content": prompt},
+         ]
+         inputs = tokenizer.apply_chat_template(
+             messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
+         ).to(model.device)
+
+         # model.generate() has no `streaming` argument; stream instead by running
+         # generation in a background thread and reading decoded text from a
+         # TextIteratorStreamer as it becomes available.
+         streamer = TextIteratorStreamer(
+             tokenizer, skip_prompt=True, skip_special_tokens=True
+         )
+         generation_kwargs = dict(
+             input_ids=inputs,
+             max_new_tokens=512,
+             temperature=0.7,
+             top_p=0.9,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id,
+             streamer=streamer,
+         )
+         thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
+         thread.start()
+
+         # Yield each decoded chunk to the HTTP response as it arrives
+         for token_text in streamer:
+             if token_text:
+                 yield token_text
+                 await asyncio.sleep(0.01)  # Small delay to control streaming speed
+
+     except Exception as e:
+         yield f"Error occurred: {e}"
+
+ @app.post("/ask")
+ async def ask(question: Question):
+     return StreamingResponse(
+         generate_response_chunks(question.question),
+         media_type="text/plain"
+     )
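
A quick way to try the /ask endpoint once the server is running. This is a minimal client sketch, assuming the app is served locally on port 8000 (for example with uvicorn main:app) and that the requests library is installed; the question text is only a placeholder.

import requests

# Stream the plain-text response from the /ask endpoint chunk by chunk
with requests.post(
    "http://localhost:8000/ask",
    json={"question": "Who created you?"},
    stream=True,
) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)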