roshnn24 committed on
Commit
270c639
·
verified ·
1 Parent(s): 167782d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -24
app.py CHANGED
@@ -15,11 +15,16 @@ import re
15
  from werkzeug.utils import secure_filename
16
  import torch
17
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
18
 
19
  app = Flask(__name__)
20
 
21
  PORT = int(os.environ.get("PORT", 7860))
22
 
 
 
 
 
23
  UPLOAD_FOLDER = '/tmp/uploads' # Change to tmp directory for Spaces
24
  ALLOWED_EXTENSIONS = {'py'}
25
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
@@ -28,33 +33,38 @@ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
28
  DATABASE_PATH = '/tmp/chat_database.db'
29
 
30
  # Initialize LangChain with Ollama LLM
31
- from transformers import AutoTokenizer, AutoModelForCausalLM
32
- import torch
33
-
34
- # Load model and tokenizer
35
- model_name = "mistralai/Mistral-7B-Instruct-v0.1"
36
- tokenizer = AutoTokenizer.from_pretrained(model_name)
37
- model = AutoModelForCausalLM.from_pretrained(
38
- model_name,
39
- torch_dtype=torch.float16,
40
- device_map="auto",
41
- load_in_8bit=True
42
- )
43
-
44
- # Create pipeline
45
- pipe = pipeline(
46
- "text-generation",
47
- model=model,
48
- tokenizer=tokenizer,
49
- max_new_tokens=512,
50
- temperature=0.7,
51
- top_p=0.95,
52
- repetition_penalty=1.15
53
- )
54
 
 
 
 
 
 
 
 
 
 
 
55
 
 
 
56
 
57
- llm = HuggingFacePipeline(pipeline=pipe)
 
 
58
 
59
  @contextmanager
60
  def get_db_connection():
 
15
  from werkzeug.utils import secure_filename
16
  import torch
17
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
18
+ from huggingface_hub import login
19
 
20
  app = Flask(__name__)
21
 
22
  PORT = int(os.environ.get("PORT", 7860))
23
 
24
+ hf_token = os.environ.get("HF_TOKEN")
25
+ if hf_token:
26
+ login(HF_TOKEN)
27
+
28
  UPLOAD_FOLDER = '/tmp/uploads' # Change to tmp directory for Spaces
29
  ALLOWED_EXTENSIONS = {'py'}
30
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
 
33
  DATABASE_PATH = '/tmp/chat_database.db'
34
 
35
  # Initialize LangChain with Ollama LLM
36
+ if hf_token:
37
+ model_name = "mistralai/Mistral-7B-Instruct-v0.1"
38
+ else:
39
+ # Fallback to a free, smaller model
40
+ model_name = "microsoft/phi-4"
41
+
42
+ try:
43
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
44
+ model = AutoModelForCausalLM.from_pretrained(
45
+ model_name,
46
+ torch_dtype=torch.float16,
47
+ device_map="auto",
48
+ load_in_8bit=True
49
+ )
 
 
 
 
 
 
 
 
 
50
 
51
+ # Create pipeline
52
+ pipe = pipeline(
53
+ "text-generation",
54
+ model=model,
55
+ tokenizer=tokenizer,
56
+ max_new_tokens=512,
57
+ temperature=0.7,
58
+ top_p=0.95,
59
+ repetition_penalty=1.15
60
+ )
61
 
62
+ # Initialize LangChain with HuggingFacePipeline
63
+ llm = HuggingFacePipeline(pipeline=pipe)
64
 
65
+ except Exception as e:
66
+ print(f"Error loading model: {e}")
67
+ raise
68
 
69
  @contextmanager
70
  def get_db_connection():