adilkh26 committed
Commit 2272006 · verified · 1 Parent(s): bb1f5be

Update app.py

Files changed (1)
  1. app.py +9 -14
app.py CHANGED
@@ -1,6 +1,5 @@
  import gradio as gr
  import torch
- import deepspeed
  from transformers import AutoModelForCausalLM, AutoTokenizer

  # Model name
@@ -9,27 +8,23 @@ model_name = "OpenGVLab/InternVideo2_5_Chat_8B"
  # Load tokenizer
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

- # Enable DeepSpeed Inference (ZeRO-3)
- ds_engine = deepspeed.init_inference(
-     dtype=torch.float16,  # Use float16 for efficiency
-     replace_method="auto",  # Automatically replace ops for inference
-     replace_with_kernel_inject=True
- )
+ # Detect device
+ device = "cuda" if torch.cuda.is_available() else "cpu"

- # Load model with DeepSpeed
+ # Load model
  model = AutoModelForCausalLM.from_pretrained(
-     model_name,
+     model_name,
      trust_remote_code=True,
-     torch_dtype=torch.float16,
-     device_map="auto"  # Auto place on GPU
+     torch_dtype=torch.float16 if device == "cuda" else torch.float32,  # Use float16 on GPU, float32 on CPU
+     device_map="auto" if device == "cuda" else None  # Use GPU if available
  )

- # Apply DeepSpeed to model
- model = ds_engine.module(model)
+ # Move model to device
+ model.to(device)

  # Define inference function
  def chat_with_model(prompt):
-     inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
      output = model.generate(**inputs, max_length=200)
      return tokenizer.decode(output[0], skip_special_tokens=True)
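The changed hunks end at chat_with_model, so the Gradio wiring that app.py presumably contains further down is not shown in this diff. Below is a minimal sketch of how the function could be exposed as a text-in/text-out interface; the component labels, title, and launch() call are assumptions, not part of the commit.

# Minimal sketch (not part of the diff above): exposing chat_with_model via Gradio.
# Labels, title, and launch() arguments are assumed; the real app.py may differ.
import gradio as gr

demo = gr.Interface(
    fn=chat_with_model,                   # inference function defined in app.py
    inputs=gr.Textbox(label="Prompt"),    # free-form text prompt
    outputs=gr.Textbox(label="Response"),
    title="InternVideo2_5_Chat_8B demo",  # assumed title
)

demo.launch()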