pradeepsengarr committed
Commit a75db23 · verified · 1 Parent(s): 16e2c10

Update app.py

Files changed (1)
  app.py +7 -5
app.py CHANGED
@@ -32,7 +32,6 @@
 # else:
 #     st.warning("Please enter a query!")
 
-
 import os
 import streamlit as st
 import torch
@@ -48,11 +47,14 @@ if not hf_token:
 # Check if CUDA is available
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Load the model and tokenizer with token authentication
+# Load model in 8-bit mode (less RAM usage)
 MODEL_NAME = "google/gemma-2b-it"
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
 model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME, token=hf_token, torch_dtype=torch.float16 if device == "cuda" else torch.float32, device_map="auto"
+    MODEL_NAME,
+    token=hf_token,
+    device_map="auto",
+    load_in_8bit=True  # 👈 This reduces RAM usage!
 )
 
 # Streamlit UI
@@ -63,10 +65,10 @@ if st.button("Generate Code"):
     if user_input:
         with st.spinner("⏳ Generating response... Please wait!"):
             inputs = tokenizer(user_input, return_tensors="pt").to(device)
-            output = model.generate(**inputs, max_new_tokens=100)
+            output = model.generate(**inputs, max_new_tokens=50)
             response = tokenizer.decode(output[0], skip_special_tokens=True)
 
             st.subheader("📝 Generated Code:")
-            st.code(response, language="python")  # Display code properly
+            st.code(response, language="python")
     else:
         st.warning("⚠️ Please enter a query!")