Phoenix21 commited on
Commit
e857f76
·
verified ·
1 Parent(s): f7b276d

removed gemini llm

Browse files
Files changed (1) hide show
  1. pipeline.py +19 -13
pipeline.py CHANGED
@@ -164,14 +164,14 @@ RATE_LIMIT_REQUESTS = 60
164
  CACHE_SIZE_LIMIT = 1000
165
 
166
  # Google Gemini (primary)
167
- GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
168
- gemini_llm = ChatGoogleGenerativeAI(
169
- model="gemini-2.0-flash",
170
- temperature=0.5,
171
- max_tokens=None,
172
- timeout=None,
173
- max_retries=2,
174
- )
175
 
176
  # Fallback
177
  fallback_groq_api_key = os.environ.get("GROQ_API_KEY_FALLBACK", "GROQ_API_KEY")
@@ -450,15 +450,21 @@ class PipelineState:
450
 
451
  # Specialized chain for self-harm
452
  from prompts import selfharm_prompt
453
- self.self_harm_chain = LLMChain(llm=gemini_llm, prompt=selfharm_prompt, verbose=False)
 
 
 
454
 
455
  # NEW: chain for frustration/harsh queries
456
  from prompts import frustration_prompt
457
- self.frustration_chain = LLMChain(llm=gemini_llm, prompt=frustration_prompt, verbose=False)
 
 
458
 
459
  # NEW: chain for ethical conflict queries
460
  from prompts import ethical_conflict_prompt
461
- self.ethical_conflict_chain = LLMChain(llm=gemini_llm, prompt=ethical_conflict_prompt, verbose=False)
 
462
 
463
  # Build brand & wellness vectorstores
464
  brand_csv = "BrandAI.csv"
@@ -473,8 +479,8 @@ class PipelineState:
473
  self.gemini_llm = gemini_llm
474
  self.groq_fallback_llm = groq_fallback_llm
475
 
476
- self.brand_rag_chain = build_rag_chain2(brand_vs, self.gemini_llm)
477
- self.wellness_rag_chain = build_rag_chain(wellness_vs, self.gemini_llm)
478
 
479
  self.brand_rag_chain_fallback = build_rag_chain2(brand_vs, self.groq_fallback_llm)
480
  self.wellness_rag_chain_fallback = build_rag_chain(wellness_vs, self.groq_fallback_llm)
 
164
  CACHE_SIZE_LIMIT = 1000
165
 
166
  # Google Gemini (primary)
167
+ # GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
168
+ # gemini_llm = ChatGoogleGenerativeAI(
169
+ # model="gemini-2.0-flash",
170
+ # temperature=0.5,
171
+ # max_tokens=None,
172
+ # timeout=None,
173
+ # max_retries=2,
174
+ # )
175
 
176
  # Fallback
177
  fallback_groq_api_key = os.environ.get("GROQ_API_KEY_FALLBACK", "GROQ_API_KEY")
 
450
 
451
  # Specialized chain for self-harm
452
  from prompts import selfharm_prompt
453
+ # self.self_harm_chain = LLMChain(llm=gemini_llm, prompt=selfharm_prompt, verbose=False)
454
+
455
+ self.self_harm_chain = LLMChain(llm=groq_fallback_llm, prompt=selfharm_prompt, verbose=False)
456
+
457
 
458
  # NEW: chain for frustration/harsh queries
459
  from prompts import frustration_prompt
460
+ # self.frustration_chain = LLMChain(llm=gemini_llm, prompt=frustration_prompt, verbose=False)
461
+ self.frustration_chain = LLMChain(llm=groq_fallback_llm, prompt=frustration_prompt, verbose=False)
462
+
463
 
464
  # NEW: chain for ethical conflict queries
465
  from prompts import ethical_conflict_prompt
466
+ # self.ethical_conflict_chain = LLMChain(llm=gemini_llm, prompt=ethical_conflict_prompt, verbose=False)
467
+ self.ethical_conflict_chain = LLMChain(llm=groq_fallback_llm, prompt=ethical_conflict_prompt, verbose=False)
468
 
469
  # Build brand & wellness vectorstores
470
  brand_csv = "BrandAI.csv"
 
479
  self.gemini_llm = gemini_llm
480
  self.groq_fallback_llm = groq_fallback_llm
481
 
482
+ # self.brand_rag_chain = build_rag_chain2(brand_vs, self.gemini_llm)
483
+ # self.wellness_rag_chain = build_rag_chain(wellness_vs, self.gemini_llm)
484
 
485
  self.brand_rag_chain_fallback = build_rag_chain2(brand_vs, self.groq_fallback_llm)
486
  self.wellness_rag_chain_fallback = build_rag_chain(wellness_vs, self.groq_fallback_llm)