Alejadro Sanchez-Giraldo commited on
Commit
b930686
·
1 Parent(s): 1afabcf
Files changed (2) hide show
  1. app.py +14 -5
  2. requirements.txt +1 -0
app.py CHANGED
@@ -38,16 +38,25 @@ print("MPS available: ", torch.backends.mps.is_available())
38
  tokenizer = AutoTokenizer.from_pretrained(
39
  "deepseek-ai/deepseek-coder-1.3b-instruct", trust_remote_code=True)
40
  model = AutoModelForCausalLM.from_pretrained(
41
- "deepseek-ai/deepseek-coder-1.3b-instruct", trust_remote_code=True, torch_dtype=torch.bfloat16)
 
 
 
 
42
 
43
  # Disable tokenizers parallelism warning
44
  os.environ["TOKENIZERS_PARALLELISM"] = "True"
45
 
46
 
47
- # Check if MPS (Metal Performance Shaders) is available
48
- device = torch.device(
49
- "mps") if torch.backends.mps.is_available() else torch.device("cpu")
50
- model = model.to(device)
 
 
 
 
 
51
 
52
 
53
  # Function to handle user input and generate a response
 
38
  tokenizer = AutoTokenizer.from_pretrained(
39
  "deepseek-ai/deepseek-coder-1.3b-instruct", trust_remote_code=True)
40
  model = AutoModelForCausalLM.from_pretrained(
41
+ "deepseek-ai/deepseek-coder-1.3b-instruct",
42
+ trust_remote_code=True,
43
+ torch_dtype=torch.float16, # Use float16 for better GPU memory efficiency
44
+ device_map="auto" # Automatically handle device placement
45
+ )
46
 
47
  # Disable tokenizers parallelism warning
48
  os.environ["TOKENIZERS_PARALLELISM"] = "True"
49
 
50
 
51
+ # Configure device
52
+ if torch.cuda.is_available():
53
+ device = torch.device("cuda")
54
+ # Print GPU information
55
+ print(f"Using GPU: {torch.cuda.get_device_name(0)}")
56
+ print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
57
+ else:
58
+ device = torch.device("cpu")
59
+ print("No GPU available, using CPU")
60
 
61
 
62
  # Function to handle user input and generate a response
requirements.txt CHANGED
@@ -6,3 +6,4 @@ minijinja
6
  torch --extra-index-url https://download.pytorch.org/whl/cu118
7
  torchvision
8
  torchaudio
 
 
6
  torch --extra-index-url https://download.pytorch.org/whl/cu118
7
  torchvision
8
  torchaudio
9
+ accelerate==0.26.0