cstr committed on
Commit 5611226 · verified · 1 Parent(s): e8e1e46

Update app.py

Files changed (1)
  1. app.py +69 -31
app.py CHANGED
@@ -2653,19 +2653,19 @@ def create_app():
                 return update_model_info(provider, googleai_model)
             return "<p>Model information not available</p>"
 
-        def update_vision_indicator(provider, model_choice):
+        def update_vision_indicator(provider, model_choice=None):
             """Update the vision capability indicator"""
-            # Safety check for None model
+            # Simplified - don't call get_current_model since it causes issues
             if model_choice is None:
-                return False
+                # Just check if the provider generally supports vision
+                return provider in VISION_MODELS and len(VISION_MODELS[provider]) > 0
+
             return is_vision_model(provider, model_choice)
 
-        def update_image_upload_visibility(provider, model_choice):
+        def update_image_upload_visibility(provider, model_choice=None):
             """Show/hide image upload based on model vision capabilities"""
-            # Safety check for None model
-            if model_choice is None:
-                return gr.update(visible=False)
-            is_vision = is_vision_model(provider, model_choice)
+            # Simplified
+            is_vision = update_vision_indicator(provider, model_choice)
             return gr.update(visible=is_vision)
 
         # Search model function
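
The None-branch now reports provider-level capability instead of always returning False. That check assumes VISION_MODELS maps provider names to collections of vision-capable model names; a minimal standalone sketch of that assumption (illustrative provider and model names, not the app's actual table):

# Hypothetical stand-in for the app's VISION_MODELS table
VISION_MODELS = {
    "OpenRouter": ["openai/gpt-4o", "anthropic/claude-3-opus"],
    "Groq": [],
}

def provider_supports_vision(provider):
    # Mirrors the new None-model fallback in update_vision_indicator
    return provider in VISION_MODELS and len(VISION_MODELS[provider]) > 0

print(provider_supports_vision("OpenRouter"))  # True
print(provider_supports_vision("Groq"))        # False - provider listed but has no vision models
print(provider_supports_vision("Poe"))         # False - provider not in the table
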
@@ -2710,19 +2710,26 @@ def create_app():
             default_model = "mistralai/Mistral-7B-Instruct-v0.3" if "mistralai/Mistral-7B-Instruct-v0.3" in all_models else all_models[0] if all_models else None
             return gr.update(choices=all_models, value=default_model)
 
-        def search_poe_models(search_term):
-            """Filter Poe models based on search term"""
-            all_models = list(POE_MODELS.keys())
+        def search_models_generic(search_term, model_dict, default_model=None):
+            """Generic model search function to reduce code duplication"""
+            all_models = list(model_dict.keys())
+            if not all_models:
+                return gr.update(choices=[], value=None)
+
             if not search_term:
-                return gr.update(choices=all_models, value="chinchilla" if "chinchilla" in all_models else all_models[0] if all_models else None)
+                return gr.update(choices=all_models, value=default_model if default_model in all_models else all_models[0])
 
             filtered_models = [model for model in all_models if search_term.lower() in model.lower()]
 
             if filtered_models:
                 return gr.update(choices=filtered_models, value=filtered_models[0])
             else:
-                return gr.update(choices=all_models, value="chinchilla" if "chinchilla" in all_models else all_models[0] if all_models else None)
-
+                return gr.update(choices=all_models, value=default_model if default_model in all_models else all_models[0])
+
+        def search_poe_models(search_term):
+            """Filter Poe models based on search term"""
+            return search_models_generic(search_term, POE_MODELS, "chinchilla")
+
         def search_groq_models(search_term):
             """Filter Groq models based on search term"""
             all_models = list(GROQ_MODELS.keys())
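
The new search_models_generic helper is only wired up for Poe in this commit; the per-provider searches below (Groq and the rest) keep their own copies. As a usage sketch, any of them could presumably be collapsed onto the helper the same way, assuming their model tables are plain dicts keyed by model name (hypothetical refactor, not part of this change):

def search_groq_models(search_term):
    """Filter Groq models based on search term"""
    # default_model left as None: the helper then falls back to the first available model
    return search_models_generic(search_term, GROQ_MODELS)
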
@@ -2839,29 +2846,48 @@ def create_app():
                 cohere_model,
                 together_model,
                 anthropic_model,
+                poe_model,
                 googleai_model
             ]
         ).then(
             fn=update_context_for_provider,
-            inputs=[provider_choice, openrouter_model, openai_model, hf_model, groq_model, cohere_model, together_model, anthropic_model, googleai_model],
+            inputs=[
+                provider_choice,
+                openrouter_model,
+                openai_model,
+                hf_model,
+                groq_model,
+                cohere_model,
+                together_model,
+                anthropic_model,
+                poe_model,
+                googleai_model
+            ],
             outputs=context_display
         ).then(
             fn=update_model_info_for_provider,
-            inputs=[provider_choice, openrouter_model, openai_model, hf_model, groq_model, cohere_model, together_model, anthropic_model, googleai_model],
+            inputs=[
+                provider_choice,
+                openrouter_model,
+                openai_model,
+                hf_model,
+                groq_model,
+                cohere_model,
+                together_model,
+                anthropic_model,
+                poe_model,
+                googleai_model
+            ],
             outputs=model_info_display
         ).then(
-            fn=lambda provider, model: update_vision_indicator(
-                provider,
-                get_current_model(provider, model, None, None, None, None, None, None, None, None)
-            ),
-            inputs=[provider_choice, openrouter_model],
+            # Fix this with correct number of args using a simpler approach
+            fn=lambda provider: update_vision_indicator(provider, None),
+            inputs=provider_choice,
             outputs=is_vision_indicator
         ).then(
-            fn=lambda provider, model: update_image_upload_visibility(
-                provider,
-                get_current_model(provider, model, None, None, None, None, None, None, None, None)
-            ),
-            inputs=[provider_choice, openrouter_model],
+            # Same here
+            fn=lambda provider: update_image_upload_visibility(provider, None),
+            inputs=provider_choice,
             outputs=image_upload_container
         )
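
The lambda rewrites in this chain matter because Gradio passes one positional argument per component listed in inputs; the old lambdas expected two arguments while feeding get_current_model mostly Nones, whereas the new ones take just the provider. A minimal self-contained sketch of the pattern (hypothetical demo components, not the app's real ones):

import gradio as gr

def update_vision_indicator(provider, model_choice=None):
    # Stand-in for the app's function: provider-level check only
    return provider == "OpenRouter"

with gr.Blocks() as demo:
    provider_choice = gr.Dropdown(["OpenRouter", "Groq"], value="OpenRouter", label="Provider")
    is_vision_indicator = gr.Checkbox(label="Vision capable", interactive=False)
    # One input component, so the callback must accept exactly one argument
    provider_choice.change(
        fn=lambda provider: update_vision_indicator(provider, None),
        inputs=provider_choice,
        outputs=is_vision_indicator,
    )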
 
@@ -3136,17 +3162,27 @@ def create_app():
                 transforms=transforms,
                 api_key_override=api_key_override
             )
+
+        def clean_message(message):
+            """Clean the message from style tags"""
+            if isinstance(message, str):
+                import re
+                # Remove style tags
+                message = re.sub(r'<userStyle>.*?</userStyle>', '', message)
+            return message
 
         # Submit button click event
         submit_btn.click(
-            fn=submit_message,
+            fn=lambda *args: submit_message(clean_message(args[0]), *args[1:]),
             inputs=[
                 message, chatbot, provider_choice,
-                openrouter_model, openai_model, hf_model, groq_model, cohere_model, together_model, anthropic_model, poe_model, googleai_model,
+                openrouter_model, openai_model, hf_model, groq_model, cohere_model,
+                together_model, anthropic_model, poe_model, googleai_model,
                 temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty,
                 top_k, min_p, seed, top_a, stream_output, response_format,
                 images, documents, reasoning_effort, system_message, transforms,
-                openrouter_api_key, openai_api_key, hf_api_key, groq_api_key, cohere_api_key, together_api_key, anthropic_api_key, poe_api_key, googleai_api_key
+                openrouter_api_key, openai_api_key, hf_api_key, groq_api_key, cohere_api_key,
+                together_api_key, anthropic_api_key, poe_api_key, googleai_api_key
             ],
             outputs=chatbot,
             show_progress="minimal",
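
For reference, clean_message strips inline <userStyle>...</userStyle> tags before the text reaches submit_message; a quick check of the regex as written:

import re

def clean_message(message):
    """Clean the message from style tags"""
    if isinstance(message, str):
        message = re.sub(r'<userStyle>.*?</userStyle>', '', message)
    return message

print(clean_message("Hi<userStyle>Normal</userStyle> there"))  # -> "Hi there"
print(clean_message(["not", "a", "string"]))                   # non-strings pass through unchanged

Note that without re.DOTALL the pattern leaves a tag pair intact if it happens to span a newline.
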
@@ -3161,11 +3197,13 @@ def create_app():
             fn=submit_message,
             inputs=[
                 message, chatbot, provider_choice,
-                openrouter_model, openai_model, hf_model, groq_model, cohere_model, together_model, anthropic_model, poe_model, googleai_model,
+                openrouter_model, openai_model, hf_model, groq_model, cohere_model,
+                together_model, anthropic_model, poe_model, googleai_model,
                 temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty,
                 top_k, min_p, seed, top_a, stream_output, response_format,
                 images, documents, reasoning_effort, system_message, transforms,
-                openrouter_api_key, openai_api_key, hf_api_key, groq_api_key, cohere_api_key, together_api_key, anthropic_api_key, poe_api_key, googleai_api_key
+                openrouter_api_key, openai_api_key, hf_api_key, groq_api_key, cohere_api_key,
+                together_api_key, anthropic_api_key, poe_api_key, googleai_api_key
             ],
             outputs=chatbot,
             show_progress="minimal",
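
Worth noting: this second handler still binds submit_message directly, so text submitted through it bypasses the clean_message wrapper added to the button click above; if that is unintended, the same wrapper would presumably apply here as well:

            fn=lambda *args: submit_message(clean_message(args[0]), *args[1:]),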
 