giulio98 commited on
Commit
163c806
·
verified ·
1 Parent(s): b5ae53a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -17,7 +17,7 @@ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
17
  from langchain_docling import DoclingLoader
18
  from langchain_docling.loader import ExportType
19
  from langchain_text_splitters import RecursiveCharacterTextSplitter
20
- from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, TextIteratorStreamer
21
  from transformers.models.llama.modeling_llama import rotate_half
22
  import threading
23
  import shutil
@@ -30,9 +30,12 @@ from utils import (
30
 
31
  # Initialize the model and tokenizer.
32
  api_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
33
- model_name = "meta-llama/Llama-3.1-8B-Instruct"
 
34
  tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
35
- model = AutoModelForCausalLM.from_pretrained(model_name, token=api_token, torch_dtype=torch.float16)
 
 
36
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
  model = model.eval()
38
  model.to(device)
 
17
  from langchain_docling import DoclingLoader
18
  from langchain_docling.loader import ExportType
19
  from langchain_text_splitters import RecursiveCharacterTextSplitter
20
+ from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, TextIteratorStreamer, BitsAndBytesConfig
21
  from transformers.models.llama.modeling_llama import rotate_half
22
  import threading
23
  import shutil
 
30
 
31
  # Initialize the model and tokenizer.
32
  api_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
33
+ # model_name = "meta-llama/Llama-3.1-8B-Instruct"
34
+ model_name = "google/gemma-3-27b-it"
35
  tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
36
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
37
+ # model = AutoModelForCausalLM.from_pretrained(model_name, token=api_token, torch_dtype=torch.float16)
38
+ model = AutoModelForCausalLM.from_pretrained(model_name, token=api_token, quantization_config=quantization_config, torch_dtype="auto")
39
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
  model = model.eval()
41
  model.to(device)