Miquel Farré committed on
Commit c51a7f7 · Parent: e6b8427

adding api calls with fallback

Files changed (3):
  1. app.py +4 -2
  2. e2bqwen.py +186 -23
  3. requirements.txt +2 -1
app.py CHANGED
@@ -22,8 +22,10 @@ TMP_DIR = './tmp/'
 if not os.path.exists(TMP_DIR):
     os.makedirs(TMP_DIR)
 
-model = QwenVLAPIModel()
-login(token=os.getenv("HUGGINGFACE_API_KEY"))
+hf_token = os.getenv("HUGGINGFACE_API_KEY")
+login(token=hf_token)
+model = QwenVLAPIModel(hf_token=hf_token)
+
 
 custom_css = """
 /* Your existing CSS */
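Note on the app.py change: the token is now read once from `HUGGINGFACE_API_KEY` and shared by `login()` and the model constructor. A minimal sketch of that wiring, with a hypothetical fail-fast guard that is not part of the commit:

    import os

    from huggingface_hub import login

    hf_token = os.getenv("HUGGINGFACE_API_KEY")
    if hf_token is None:
        # Hypothetical guard (not in the commit): surface a clear error
        # at startup instead of letting login() fail with a vaguer one.
        raise RuntimeError("HUGGINGFACE_API_KEY is not set")

    login(token=hf_token)
    # model = QwenVLAPIModel(hf_token=hf_token)  # as wired up in app.py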
e2bqwen.py CHANGED
@@ -435,48 +435,175 @@ REMEMBER TO ALWAYS CLICK IN THE MIDDLE OF THE TEXT, NOT ON THE SIDE, NOT UNDER.
 
 
 
+# class QwenVLAPIModel(Model):
+#     """Model wrapper for Qwen2.5VL API"""
+
+#     def __init__(
+#         self,
+#         model_path: str = "Qwen/Qwen2.5-VL-72B-Instruct",
+#         provider: str = "hyperbolic"
+#     ):
+#         super().__init__()
+#         self.model_path = model_path
+#         self.model_id = model_path
+#         self.provider = provider
+
+#         self.client = InferenceClient(
+#             provider=self.provider,
+#         )
+
+#     def __call__(
+#         self,
+#         messages: List[Dict[str, Any]],
+#         stop_sequences: Optional[List[str]] = None,
+#         **kwargs
+#     ) -> ChatMessage:
+#         """Convert a list of messages to an API request and return the response"""
+#         # # Count images in messages - debug
+#         # image_count = 0
+#         # for msg in messages:
+#         #     if isinstance(msg.get("content"), list):
+#         #         for item in msg["content"]:
+#         #             if isinstance(item, dict) and item.get("type") == "image":
+#         #                 image_count += 1
+
+#         # print(f"QwenVLAPIModel received {len(messages)} messages with {image_count} images")
+
+#         # Format the messages for the API
+
+#         formatted_messages = []
+
+#         for msg in messages:
+#             role = msg["role"]
+#             if isinstance(msg["content"], list):
+#                 content = []
+#                 for item in msg["content"]:
+#                     if item["type"] == "text":
+#                         content.append({"type": "text", "text": item["text"]})
+#                     elif item["type"] == "image":
+#                         # Handle image path or direct image object
+#                         if isinstance(item["image"], str):
+#                             # Image is a path
+#                             with open(item["image"], "rb") as image_file:
+#                                 base64_image = base64.b64encode(image_file.read()).decode("utf-8")
+#                         else:
+#                             # Image is a PIL image or similar object
+#                             img_byte_arr = io.BytesIO()
+#                             item["image"].save(img_byte_arr, format="PNG")
+#                             base64_image = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
+
+#                         content.append({
+#                             "type": "image_url",
+#                             "image_url": {
+#                                 "url": f"data:image/png;base64,{base64_image}"
+#                             }
+#                         })
+#             else:
+#                 content = [{"type": "text", "text": msg["content"]}]
+
+#             formatted_messages.append({"role": role, "content": content})
+
+#         # Make the API request
+#         completion = self.client.chat.completions.create(
+#             model=self.model_path,
+#             messages=formatted_messages,
+#             max_tokens=kwargs.get("max_new_tokens", 512),
+#             temperature=kwargs.get("temperature", 0.7),
+#             top_p=kwargs.get("top_p", 0.9),
+#         )
+
+#         # Extract the response text
+#         output_text = completion.choices[0].message.content
+
+#         return ChatMessage(role=MessageRole.ASSISTANT, content=output_text)
+
+#     def to_dict(self) -> Dict[str, Any]:
+#         """Convert the model to a dictionary"""
+#         return {
+#             "class": self.__class__.__name__,
+#             "model_path": self.model_path,
+#             "provider": self.provider,
+#             # We don't save the API key for security reasons
+#         }
+
+#     @classmethod
+#     def from_dict(cls, data: Dict[str, Any]) -> "QwenVLAPIModel":
+#         """Create a model from a dictionary"""
+#         return cls(
+#             model_path=data.get("model_path", "Qwen/Qwen2.5-VL-72B-Instruct"),
+#             provider=data.get("provider", "hyperbolic"),
+#         )
 class QwenVLAPIModel(Model):
-    """Model wrapper for Qwen2.5VL API"""
+    """Model wrapper for Qwen2.5VL API with fallback mechanism"""
 
     def __init__(
         self,
-        model_path: str = "Qwen/Qwen2.5-VL-72B-Instruct",
-        provider: str = "hyperbolic"
+        model_path: str = "Qwen/Qwen2.5-VL-72B-Instruct",
+        provider: str = "hyperbolic",
+        hf_token: str = None,
+        hf_base_url: str = "https://n5wr7lfx6wp94tvl.us-east-1.aws.endpoints.huggingface.cloud/v1/"
     ):
         super().__init__()
         self.model_path = model_path
         self.model_id = model_path
         self.provider = provider
+        self.hf_token = hf_token
+        self.hf_base_url = hf_base_url
 
-        self.client = InferenceClient(
+        # Initialize hyperbolic client
+        self.hyperbolic_client = InferenceClient(
             provider=self.provider,
         )
 
+        # Initialize HF OpenAI-compatible client if token is provided
+        self.hf_client = None
+        if hf_token:
+            from openai import OpenAI
+            self.hf_client = OpenAI(
+                base_url=self.hf_base_url,
+                api_key=self.hf_token
+            )
+
     def __call__(
         self,
         messages: List[Dict[str, Any]],
         stop_sequences: Optional[List[str]] = None,
         **kwargs
     ) -> ChatMessage:
-        """Convert a list of messages to an API request and return the response"""
-        # # Count images in messages - debug
-        # image_count = 0
-        # for msg in messages:
-        #     if isinstance(msg.get("content"), list):
-        #         for item in msg["content"]:
-        #             if isinstance(item, dict) and item.get("type") == "image":
-        #                 image_count += 1
+        """Convert a list of messages to an API request with fallback mechanism"""
 
-        # print(f"QwenVLAPIModel received {len(messages)} messages with {image_count} images")
+        # Format messages once for both APIs
+        formatted_messages = self._format_messages(messages)
+
+        # First try the HF endpoint if available
+        if self.hf_client:
+            try:
+                completion = self._call_hf_endpoint(
+                    formatted_messages,
+                    stop_sequences,
+                    **kwargs
+                )
+                return ChatMessage(role=MessageRole.ASSISTANT, content=completion)
+            except Exception as e:
+                print(f"HF endpoint failed with error: {e}. Falling back to hyperbolic.")
+                # Continue to fallback
+
+        # Fallback to hyperbolic
+        try:
+            return self._call_hyperbolic(formatted_messages, stop_sequences, **kwargs)
+        except Exception as e:
+            raise Exception(f"Both endpoints failed. Last error: {e}")
+
+    def _format_messages(self, messages: List[Dict[str, Any]]):
+        """Format messages for API requests - works for both endpoints"""
 
-        # Format the messages for the API
-
         formatted_messages = []
 
         for msg in messages:
             role = msg["role"]
+            content = []
+
             if isinstance(msg["content"], list):
-                content = []
                 for item in msg["content"]:
                     if item["type"] == "text":
                         content.append({"type": "text", "text": item["text"]})
@@ -499,14 +626,48 @@ class QwenVLAPIModel(Model):
                             }
                         })
             else:
+                # Plain text message
                 content = [{"type": "text", "text": msg["content"]}]
-
+
             formatted_messages.append({"role": role, "content": content})
 
-        # Make the API request
-        completion = self.client.chat.completions.create(
-            model=self.model_path,
-            messages=formatted_messages,
+        return formatted_messages
+
+    def _call_hf_endpoint(self, formatted_messages, stop_sequences=None, **kwargs):
+        """Call the Hugging Face OpenAI-compatible endpoint"""
+
+        # Extract parameters with defaults
+        max_tokens = kwargs.get("max_new_tokens", 512)
+        temperature = kwargs.get("temperature", 0.7)
+        top_p = kwargs.get("top_p", 0.9)
+        stream = kwargs.get("stream", False)
+
+        completion = self.hf_client.chat.completions.create(
+            model="tgi",  # Model name for the endpoint
+            messages=formatted_messages,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            stream=stream,
+            stop=stop_sequences
+        )
+
+        if stream:
+            # For streaming responses, return a generator
+            def stream_generator():
+                for chunk in completion:
+                    yield chunk.choices[0].delta.content or ""
+            return stream_generator()
+        else:
+            # For non-streaming, return the full text
+            return completion.choices[0].message.content
+
+    def _call_hyperbolic(self, formatted_messages, stop_sequences=None, **kwargs):
+        """Call the hyperbolic API"""
+
+        completion = self.hyperbolic_client.chat.completions.create(
+            model=self.model_path,
+            messages=formatted_messages,
             max_tokens=kwargs.get("max_new_tokens", 512),
             temperature=kwargs.get("temperature", 0.7),
             top_p=kwargs.get("top_p", 0.9),
@@ -523,7 +684,8 @@ class QwenVLAPIModel(Model):
             "class": self.__class__.__name__,
             "model_path": self.model_path,
             "provider": self.provider,
-            # We don't save the API key for security reasons
+            "hf_base_url": self.hf_base_url,
+            # We don't save the API keys for security reasons
         }
 
     @classmethod
@@ -532,4 +694,5 @@ class QwenVLAPIModel(Model):
         return cls(
             model_path=data.get("model_path", "Qwen/Qwen2.5-VL-72B-Instruct"),
             provider=data.get("provider", "hyperbolic"),
-        )
+            hf_base_url=data.get("hf_base_url", "https://n5wr7lfx6wp94tvl.us-east-1.aws.endpoints.huggingface.cloud/v1/"),
+        )
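Review note on the new class: `__call__` formats the messages once via `_format_messages`, tries `_call_hf_endpoint` when an `hf_token` was supplied, and only then falls back to the hyperbolic provider. A self-contained sketch of that control flow, using illustrative stand-in callables (`call_primary` and `call_fallback` are not names from the commit, and the real method skips the primary entirely when no `hf_token` was given):

    from typing import Callable

    def call_with_fallback(
        payload: dict,
        call_primary: Callable[[dict], str],
        call_fallback: Callable[[dict], str],
    ) -> str:
        """Same shape as QwenVLAPIModel.__call__: try the primary
        endpoint, log and fall back on failure, raise if both fail."""
        try:
            return call_primary(payload)
        except Exception as e:
            print(f"Primary endpoint failed with error: {e}. Falling back.")
        try:
            return call_fallback(payload)
        except Exception as e:
            raise Exception(f"Both endpoints failed. Last error: {e}")

Formatting the messages once up front is what keeps the two clients interchangeable: both speak the OpenAI-style chat-completions schema, so only the transport differs between the HF endpoint and the hyperbolic provider.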
requirements.txt CHANGED
@@ -1,4 +1,5 @@
 e2b_desktop
 smolagents
 Pillow
-huggingface_hub
+huggingface_hub
+openai
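The new `openai` requirement backs the OpenAI-compatible client created in `QwenVLAPIModel.__init__`. A quick sanity check that both client libraries resolve, assuming the packages were installed from this file:

    # Both client libraries the model can now use should import cleanly.
    from huggingface_hub import InferenceClient
    from openai import OpenAI

    print("clients available:", InferenceClient.__name__, OpenAI.__name__)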