Steph254 committed
Commit 79ccf40 · verified · 1 Parent(s): dde7e39

Update app.py

Files changed (1): app.py +51 -72
app.py CHANGED
@@ -5,47 +5,56 @@ import json
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel
 
-# Set Hugging Face Token for Authentication (ensure it's set in your environment)
-HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
+# Set Hugging Face Token for Authentication
+HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")  # Ensure this is set in your environment
 
-# Base model (needed for QLoRA adapter)
-BASE_MODEL = "meta-llama/Llama-3-1B-Instruct"
-QLORA_ADAPTER = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
+# Correct model paths (replace with your actual paths)
+BASE_MODEL = "meta-llama/Llama-3-1B-Instruct"  # Ensure this is the correct identifier
+QLORA_ADAPTER = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"  # Ensure this is correct
+LLAMA_GUARD_NAME = "meta-llama/Llama-Guard-3-1B-INT4"  # Ensure this is correct
 
 # Function to load Llama model
-def load_llama_model():
-    print("Loading base model...")
-    model = AutoModelForCausalLM.from_pretrained(
-        BASE_MODEL,
-        torch_dtype=torch.float32,
-        device_map="cpu",  # Ensure it runs on CPU
-        token=HUGGINGFACE_TOKEN
-    )
-
-    print("Loading tokenizer...")
-    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False, token=HUGGINGFACE_TOKEN)
-
-    print("Loading QLoRA adapter...")
-    model = PeftModel.from_pretrained(
-        model,
-        QLORA_ADAPTER,
-        token=HUGGINGFACE_TOKEN
-    )
-
-    print("Merging LoRA weights...")
-    model = model.merge_and_unload()  # Merge LoRA weights for inference
-
-    return tokenizer, model
+def load_llama_model(model_name, is_guard=False):
+    print(f"Loading model: {model_name}")
+    try:
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            use_fast=False,
+            token=HUGGINGFACE_TOKEN
+        )
+
+        # Load model
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float32,
+            device_map="cpu",  # Ensure it runs on CPU
+            token=HUGGINGFACE_TOKEN
+        )
+
+        # Load QLoRA adapter if applicable
+        if not is_guard and "QLORA" in model_name:
+            print("Loading QLoRA adapter...")
+            model = PeftModel.from_pretrained(
+                model,
+                model_name,
+                token=HUGGINGFACE_TOKEN
+            )
+            print("Merging LoRA weights...")
+            model = model.merge_and_unload()  # Merge LoRA weights for inference
+
+        return tokenizer, model
+    except Exception as e:
+        print(f"Error loading model {model_name}: {e}")
+        raise
 
 # Load Llama 3.2 model
-MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
-tokenizer, model = load_llama_model()
+tokenizer, model = load_llama_model(QLORA_ADAPTER)
 
 # Load Llama Guard for content moderation
-LLAMA_GUARD_NAME = "meta-llama/Llama-Guard-3-1B-INT4"
-guard_tokenizer, guard_model = load_llama_model(LLAMA_GUARD_NAME)
+guard_tokenizer, guard_model = load_llama_model(LLAMA_GUARD_NAME, is_guard=True)
 
-# Define Prompt Templates
+# Define Prompt Templates (same as before)
 PROMPTS = {
     "project_analysis": """<|begin_of_text|><|prompt|>Analyze this project description and generate:
 1. Project timeline with milestones
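
Review note on this hunk: "meta-llama/Llama-3-1B-Instruct" is unlikely to resolve on the Hub (Llama 3 shipped as 8B/70B; the 1B instruct model belongs to the Llama 3.2 family), and the refactored loader passes the QLORA_ADAPTER id straight to AutoModelForCausalLM.from_pretrained, which only works if that repo holds full transformers weights rather than a PEFT adapter. A minimal sketch of the two-step load the old code performed, assuming "meta-llama/Llama-3.2-1B-Instruct" as the base id and a standard PEFT adapter repo (both assumptions, not verified by this commit):

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE = "meta-llama/Llama-3.2-1B-Instruct"                    # assumed base id
ADAPTER = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"  # assumed PEFT adapter repo
TOKEN = os.getenv("HUGGINGFACE_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(BASE, token=TOKEN)
base = AutoModelForCausalLM.from_pretrained(
    BASE, torch_dtype=torch.float32, device_map="cpu", token=TOKEN
)
model = PeftModel.from_pretrained(base, ADAPTER, token=TOKEN)  # attach adapter to base
model = model.merge_and_unload()  # fold LoRA deltas into the base weights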
@@ -53,12 +62,10 @@ PROMPTS = {
 3. Potential risks
 4. Team composition
 5. Cost estimation
-
 Project: {project_description}<|completion|>""",
 
 "code_generation": """<|begin_of_text|><|prompt|>Generate implementation code for this feature:
 {feature_description}
-
 Considerations:
 - Use {programming_language}
 - Follow {coding_standards}
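
Review note on the templates: <|prompt|> and <|completion|> are not special tokens in the Llama 3.x vocabulary, so the tokenizer treats them as plain text rather than chat delimiters. A hedged alternative is to let the tokenizer apply the model's own chat template (this assumes the tokenizer ships one, as the official meta-llama instruct repos do):

# Sketch: format a request via the model's built-in chat template instead of
# hand-rolled <|prompt|>/<|completion|> markers. The content string below is
# an illustrative stand-in for the template text defined in PROMPTS.
project_description = "Build an internal tool for sprint planning."  # example input
messages = [{
    "role": "user",
    "content": f"Analyze this project description and generate a timeline, "
               f"resource allocation, risks, team composition, and cost estimate.\n\n"
               f"Project: {project_description}",
}]
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # append the assistant header so the model replies
    return_tensors="pt",
)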
@@ -67,13 +74,11 @@ Considerations:
 
 "risk_analysis": """<|begin_of_text|><|prompt|>Predict potential risks for this project plan:
 {project_data}
-
 Format output as JSON with risk types, probabilities, and mitigation strategies<|completion|>"""
 }
 
-# Function: Content Moderation using Llama Guard
+# Function: Content Moderation using Llama Guard (same as before)
 def moderate_input(user_input):
-    # Llama Guard specific prompt format
     prompt = f"""<|begin_of_text|><|user|>
 Input: {user_input}
 Please verify that this input doesn't violate any content policies.
@@ -81,7 +86,7 @@ Please verify that this input doesn't violate any content policies.
 
     inputs = guard_tokenizer(prompt, return_tensors="pt", truncation=True)
 
-    with torch.no_grad():  # Disable gradient calculation for inference
+    with torch.no_grad():
         outputs = guard_model.generate(
             inputs.input_ids,
             max_length=256,
@@ -92,19 +97,19 @@ Please verify that this input doesn't violate any content policies.
 
     if "flagged" in response.lower() or "violated" in response.lower() or "policy violation" in response.lower():
         return "⚠️ Content flagged by Llama Guard. Please modify your input."
-    return None  # Safe input, proceed normally
+    return None
 
-# Function: Generate AI responses
+# Function: Generate AI responses (same as before)
 def generate_response(prompt_type, **kwargs):
     prompt = PROMPTS[prompt_type].format(**kwargs)
 
     moderation_warning = moderate_input(prompt)
     if moderation_warning:
-        return moderation_warning  # Stop processing if flagged
+        return moderation_warning
 
     inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
 
-    with torch.no_grad():  # Disable gradient calculation for inference
+    with torch.no_grad():
         outputs = model.generate(
             inputs.input_ids,
             max_length=1024,
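
Review note on both generate() calls: only input_ids is forwarded, so transformers will warn about the missing attention mask, and max_length also counts the prompt tokens, which shrinks the effective reply budget. A possible tweak, not part of this commit, mirroring the body of generate_response:

# Sketch: forward the attention mask and budget only the newly generated tokens.
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=512,                   # reply budget, independent of prompt length
        pad_token_id=tokenizer.eos_token_id,  # silences the pad-token warning
    )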
@@ -115,40 +120,17 @@ def generate_response(prompt_type, **kwargs):
 
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-# Function: Analyze project
-def analyze_project(project_desc):
-    return generate_response("project_analysis", project_description=project_desc)
-
-# Function: Generate code
-def generate_code(feature_desc, lang="Python", standards="PEP8"):
-    return generate_response("code_generation", feature_description=feature_desc, programming_language=lang, coding_standards=standards)
-
-# Function: Predict risks
-def predict_risks(project_data):
-    risks = generate_response("risk_analysis", project_data=project_data)
-    try:
-        # Try to extract JSON part from the response
-        import re
-        json_match = re.search(r'\{.*\}', risks, re.DOTALL)
-        if json_match:
-            return json.loads(json_match.group(0))
-        return {"error": "Could not parse JSON response"}
-    except json.JSONDecodeError:
-        return {"error": "Invalid JSON response. Please refine your input."}
-
-# Gradio UI
+# Gradio UI (same as before)
 def create_gradio_interface():
     with gr.Blocks(title="AI Project Manager", theme=gr.themes.Soft()) as demo:
         gr.Markdown("# 🚀 AI-Powered Project Manager & Code Assistant")
 
-        # Project Analysis Tab
         with gr.Tab("Project Setup"):
             project_input = gr.Textbox(label="Project Description", lines=5, placeholder="Describe your project...")
-            project_output = gr.Textbox(label="Project Analysis", lines=15)  # Changed from JSON to Textbox for better formatting
+            project_output = gr.Textbox(label="Project Analysis", lines=15)
             analyze_btn = gr.Button("Analyze Project")
             analyze_btn.click(analyze_project, inputs=project_input, outputs=project_output)
 
-        # Code Generation Tab
         with gr.Tab("Code Assistant"):
             code_input = gr.Textbox(label="Feature Description", lines=3)
             lang_select = gr.Dropdown(["Python", "JavaScript", "Java", "C++"], label="Language", value="Python")
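
Review note: this hunk deletes analyze_project, generate_code, and predict_risks, yet the click handlers kept below still reference all three, so building the interface will fail with a NameError. If the deletion was intentional cleanup, equivalents still need to exist; a minimal restoration rebuilt from the removed lines above (the re.search is hoisted out of the try block so only json.loads is guarded):

import json
import re

def analyze_project(project_desc):
    return generate_response("project_analysis", project_description=project_desc)

def generate_code(feature_desc, lang="Python", standards="PEP8"):
    return generate_response("code_generation", feature_description=feature_desc,
                             programming_language=lang, coding_standards=standards)

def predict_risks(project_data):
    risks = generate_response("risk_analysis", project_data=project_data)
    json_match = re.search(r'\{.*\}', risks, re.DOTALL)  # extract the JSON blob, if any
    if not json_match:
        return {"error": "Could not parse JSON response"}
    try:
        return json.loads(json_match.group(0))
    except json.JSONDecodeError:
        return {"error": "Invalid JSON response. Please refine your input."}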
@@ -157,14 +139,12 @@ def create_gradio_interface():
             code_btn = gr.Button("Generate Code")
             code_btn.click(generate_code, inputs=[code_input, lang_select, standards_select], outputs=code_output)
 
-        # Risk Analysis Tab
         with gr.Tab("Risk Analysis"):
             risk_input = gr.Textbox(label="Project Plan", lines=5)
             risk_output = gr.JSON(label="Risk Predictions")
             risk_btn = gr.Button("Predict Risks")
             risk_btn.click(predict_risks, inputs=risk_input, outputs=risk_output)
 
-        # Real-time Chatbot for Collaboration
         with gr.Tab("Live Collaboration"):
             gr.Markdown("## Real-time Project Collaboration")
             chat = gr.Chatbot(height=400)
@@ -177,9 +157,8 @@ def create_gradio_interface():
                 chat_history.append((message, moderation_warning))
                 return "", chat_history
 
-            # Format chat history for context
             history_text = ""
-            for i, (usr, ai) in enumerate(chat_history[-3:]):  # Use last 3 messages for context
+            for i, (usr, ai) in enumerate(chat_history[-3:]):
                 history_text += f"User: {usr}\nAI: {ai}\n"
 
             prompt = f"""<|begin_of_text|><|prompt|>Project Management Chat: