Sirawitch committed on
Commit
e65e766
·
verified ·
1 Parent(s): fd410a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -12
app.py CHANGED
@@ -3,23 +3,34 @@ from pydantic import BaseModel
3
  from typing import Optional
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
 
 
 
 
 
 
6
 
7
  app = FastAPI()
8
 
9
- model_name = "scb10x/llama-3-typhoon-v1.5-8b-instruct"
10
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
11
 
12
- # ใช้ BitsAndBytes สำหรับ quantization
13
- config = AutoConfig.from_pretrained(model_name)
14
- config.quantization_config = BitsAndBytesConfig(load_in_8bit=True)
15
 
16
- # โหลดโมเดลด้วย 8-bit quantization
17
- model = AutoModelForCausalLM.from_pretrained(
18
- model_name,
19
- config=config,
20
- device_map="auto",
21
- torch_dtype=torch.float16,
22
- )
 
 
 
 
23
 
24
  class Query(BaseModel):
25
  queryResult: Optional[dict] = None
@@ -44,6 +55,7 @@ async def webhook(query: Query):
44
 
45
  return {"fulfillmentText": ai_response}
46
  except Exception as e:
 
47
  raise HTTPException(status_code=500, detail=str(e))
48
 
49
  if __name__ == "__main__":
 
3
  from typing import Optional
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
6
+ from transformers import BitsAndBytesConfig  # added this import (needed for 8-bit loading below)
7
+ import logging
8
+
9
+ # Configure logging: INFO level, module-scoped logger per stdlib convention
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
 
13
  app = FastAPI()
14
 
15
# Load tokenizer and model once at import time; fail fast (and loudly) if it breaks.
try:
    model_name = "scb10x/llama-3-typhoon-v1.5-8b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Base config kept as a module-level name for backward compatibility.
    config = AutoConfig.from_pretrained(model_name)

    # Load the model with 8-bit BitsAndBytes quantization.
    # FIX: pass `quantization_config=` directly to from_pretrained — this is the
    # documented API; assigning it onto an AutoConfig object is not reliably honored.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        config=config,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map="auto",
        torch_dtype=torch.float16,
    )
    logger.info("Model loaded successfully")
except Exception:
    # logger.exception records the full traceback; re-raise so startup fails visibly.
    logger.exception("Error loading model")
    raise
 
35
  class Query(BaseModel):
36
  queryResult: Optional[dict] = None
 
55
 
56
  return {"fulfillmentText": ai_response}
57
  except Exception as e:
58
+ logger.error(f"Error in webhook: {str(e)}")
59
  raise HTTPException(status_code=500, detail=str(e))
60
 
61
  if __name__ == "__main__":