Steph254 committed (verified)
Commit 102f341 · Parent(s): 883f158

Update app.py

Files changed (1): app.py (+16 -10)
app.py CHANGED

@@ -1,18 +1,24 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import json
 from datetime import datetime
 
-# Load Llama 3 model (quantized for CPU hosting)
-MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
+# Load Llama 3.2 (QLoRA) Model on CPU
+MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map="auto")
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
+    device_map="cpu"  # Force CPU usage
+)
 
-# Load Llama Guard for content moderation
+# Load Llama Guard for content moderation on CPU
 LLAMA_GUARD_NAME = "meta-llama/Llama-Guard-3-1B-INT4"
 guard_tokenizer = AutoTokenizer.from_pretrained(LLAMA_GUARD_NAME)
-guard_model = AutoModelForCausalLM.from_pretrained(LLAMA_GUARD_NAME, torch_dtype=torch.float16, device_map="auto")
+guard_model = AutoModelForCausalLM.from_pretrained(
+    LLAMA_GUARD_NAME,
+    device_map="cpu"
+)
 
 # Define Prompt Templates
 PROMPTS = {
@@ -50,7 +56,7 @@ def moderate_input(user_input):
         return "⚠️ Content flagged by Llama Guard. Please modify your input."
     return None  # Safe input, proceed normally
 
-# Function: Generate AI responses (Project Analysis, Code, or Risks)
+# Function: Generate AI responses
 def generate_response(prompt_type, **kwargs):
     prompt = PROMPTS[prompt_type].format(**kwargs)
 
@@ -58,11 +64,11 @@ def generate_response(prompt_type, **kwargs):
     if moderation_warning:
         return moderation_warning  # Stop processing if flagged
 
-    inputs = tokenizer(prompt, return_tensors="pt", max_length=2048, truncation=True)
+    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
 
     outputs = model.generate(
         inputs.input_ids,
-        max_length=2048,
+        max_length=1024,
         temperature=0.7 if prompt_type == "project_analysis" else 0.5,
         top_p=0.9
     )
@@ -133,7 +139,7 @@ def create_gradio_interface():
         AI:"""
 
         inputs = tokenizer(prompt, return_tensors="pt")
-        outputs = model.generate(inputs.input_ids, max_length=2048)
+        outputs = model.generate(inputs.input_ids, max_length=1024)
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         chat_history.append((message, response))
         return "", chat_history