Steph254 committed
Commit b94b847 · verified · 1 Parent(s): 5e6f6cd

Update app.py

Files changed (1)
  1. app.py +83 -42
app.py CHANGED
@@ -1,63 +1,84 @@
 import os
 import gradio as gr
-from transformers import LlamaTokenizer, AutoModelForCausalLM
 import torch
 import json
+from transformers import AutoTokenizer
 
 # Set Hugging Face Token for Authentication (ensure it's set in your environment)
 HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
 
-# Load Llama 3.2 (QLoRA) Model on CPU
+# Function to load Llama model
+def load_llama_model(model_name):
+    from transformers import LlamaForCausalLM, LlamaTokenizer
+
+    # Use AutoTokenizer which will handle various tokenizer types
+    tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_TOKEN, use_fast=False)
+
+    # Use the LlamaForCausalLM class which can properly load the consolidated.00.pth format
+    model = LlamaForCausalLM.from_pretrained(
+        model_name,
+        token=HUGGINGFACE_TOKEN,
+        torch_dtype=torch.float16,  # Use float16 to reduce memory usage on CPU
+        low_cpu_mem_usage=True,     # Optimize for low memory usage
+        device_map="cpu"
+    )
+
+    return tokenizer, model
+
+# Load Llama 3.2 model
 MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
-tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME, token=HUGGINGFACE_TOKEN)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME,
-    token=HUGGINGFACE_TOKEN,
-    device_map="cpu"
-)
-
-# Load Llama Guard for content moderation on CPU
+tokenizer, model = load_llama_model(MODEL_NAME)
+
+# Load Llama Guard for content moderation
 LLAMA_GUARD_NAME = "meta-llama/Llama-Guard-3-1B-INT4"
-guard_tokenizer = LlamaTokenizer.from_pretrained(LLAMA_GUARD_NAME, token=HUGGINGFACE_TOKEN)
-guard_model = AutoModelForCausalLM.from_pretrained(
-    LLAMA_GUARD_NAME,
-    token=HUGGINGFACE_TOKEN,
-    device_map="cpu"
-)
+guard_tokenizer, guard_model = load_llama_model(LLAMA_GUARD_NAME)
 
 # Define Prompt Templates
 PROMPTS = {
-    "project_analysis": """Analyze this project description and generate:
+    "project_analysis": """<|begin_of_text|><|prompt|>Analyze this project description and generate:
     1. Project timeline with milestones
     2. Required technology stack
     3. Potential risks
     4. Team composition
     5. Cost estimation
 
-    Project: {project_description}""",
+    Project: {project_description}<|completion|>""",
 
-    "code_generation": """Generate implementation code for this feature:
+    "code_generation": """<|begin_of_text|><|prompt|>Generate implementation code for this feature:
     {feature_description}
 
     Considerations:
     - Use {programming_language}
     - Follow {coding_standards}
     - Include error handling
-    - Add documentation""",
+    - Add documentation<|completion|>""",
 
-    "risk_analysis": """Predict potential risks for this project plan:
+    "risk_analysis": """<|begin_of_text|><|prompt|>Predict potential risks for this project plan:
     {project_data}
 
-    Format output as JSON with risk types, probabilities, and mitigation strategies"""
+    Format output as JSON with risk types, probabilities, and mitigation strategies<|completion|>"""
 }
 
 # Function: Content Moderation using Llama Guard
 def moderate_input(user_input):
-    inputs = guard_tokenizer(user_input, return_tensors="pt", max_length=512, truncation=True)
-    outputs = guard_model.generate(inputs.input_ids, max_length=512)
+    # Llama Guard specific prompt format
+    prompt = f"""<|begin_of_text|><|user|>
+    Input: {user_input}
+    Please verify that this input doesn't violate any content policies.
+    <|assistant|>"""
+
+    inputs = guard_tokenizer(prompt, return_tensors="pt", truncation=True)
+
+    with torch.no_grad():  # Disable gradient calculation for inference
+        outputs = guard_model.generate(
+            inputs.input_ids,
+            max_length=256,
+            temperature=0.1
+        )
+
     response = guard_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-    if "flagged" in response.lower():
+    if "flagged" in response.lower() or "violated" in response.lower() or "policy violation" in response.lower():
         return "⚠️ Content flagged by Llama Guard. Please modify your input."
     return None  # Safe input, proceed normally
 
@@ -69,14 +90,16 @@ def generate_response(prompt_type, **kwargs):
     if moderation_warning:
         return moderation_warning  # Stop processing if flagged
 
-    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
 
-    outputs = model.generate(
-        inputs.input_ids,
-        max_length=1024,
-        temperature=0.7 if prompt_type == "project_analysis" else 0.5,
-        top_p=0.9
-    )
+    with torch.no_grad():  # Disable gradient calculation for inference
+        outputs = model.generate(
+            inputs.input_ids,
+            max_length=1024,
+            temperature=0.7 if prompt_type == "project_analysis" else 0.5,
+            top_p=0.9,
+            do_sample=True
+        )
 
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
@@ -92,7 +115,12 @@ def generate_code(feature_desc, lang="Python", standards="PEP8"):
 def predict_risks(project_data):
     risks = generate_response("risk_analysis", project_data=project_data)
     try:
-        return json.loads(risks)  # Convert to structured JSON if valid
+        # Try to extract JSON part from the response
+        import re
+        json_match = re.search(r'\{.*\}', risks, re.DOTALL)
+        if json_match:
+            return json.loads(json_match.group(0))
+        return {"error": "Could not parse JSON response"}
     except json.JSONDecodeError:
         return {"error": "Invalid JSON response. Please refine your input."}
 
@@ -104,7 +132,7 @@ def create_gradio_interface():
     # Project Analysis Tab
     with gr.Tab("Project Setup"):
         project_input = gr.Textbox(label="Project Description", lines=5, placeholder="Describe your project...")
-        project_output = gr.JSON(label="Project Analysis")
+        project_output = gr.Textbox(label="Project Analysis", lines=15)  # Changed from JSON to Textbox for better formatting
         analyze_btn = gr.Button("Analyze Project")
         analyze_btn.click(analyze_project, inputs=project_input, outputs=project_output)
 
@@ -137,14 +165,27 @@ def create_gradio_interface():
            chat_history.append((message, moderation_warning))
            return "", chat_history
 
-        prompt = f"""Project Management Chat:
-        Context: {message}
-        Chat History: {chat_history}
-        User: {message}
-        AI:"""
+        # Format chat history for context
+        history_text = ""
+        for i, (usr, ai) in enumerate(chat_history[-3:]):  # Use last 3 messages for context
+            history_text += f"User: {usr}\nAI: {ai}\n"
+
+        prompt = f"""<|begin_of_text|><|prompt|>Project Management Chat:
+        Context: {message}
+        Chat History: {history_text}
+        User: {message}<|completion|>"""
+
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
+
+        with torch.no_grad():
+            outputs = model.generate(
+                inputs.input_ids,
+                max_length=1024,
+                temperature=0.7,
+                top_p=0.9,
+                do_sample=True
+            )
 
-        inputs = tokenizer(prompt, return_tensors="pt")
-        outputs = model.generate(inputs.input_ids, max_length=1024)
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         chat_history.append((message, response))
         return "", chat_history
@@ -157,4 +198,4 @@ def create_gradio_interface():
 # Run Gradio App
 if __name__ == "__main__":
     interface = create_gradio_interface()
-    interface.launch(share=True)
+    interface.launch(share=True)
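For reviewers who want to exercise the updated module outside the Gradio UI, a minimal smoke-test sketch follows. It assumes the Space repo is checked out locally, HUGGINGFACE_TOKEN is exported, and access to the gated meta-llama checkpoints has been granted; the input strings are illustrative only, not part of the commit.

# Minimal smoke test (sketch); run from the repo root with HUGGINGFACE_TOKEN set.
import app  # importing app.py loads both models on CPU, which can take a while

# Llama Guard moderation: expected to return None for benign input,
# or a warning string when the input is flagged.
print(app.moderate_input("Plan a two-week sprint for a small web app."))

# One end-to-end generation using the "project_analysis" prompt template.
print(app.generate_response(
    "project_analysis",
    project_description="An internal tool for tracking lab inventory."
))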