VPCSinfo commited on
Commit
8895970
·
1 Parent(s): 56ac5db

[FIX] fixes max_len and min_len dynamic.

Browse files
Files changed (1) hide show
  1. tool.py +23 -5
tool.py CHANGED
@@ -17,7 +17,7 @@ class TranscriptSummarizer(Tool):
17
 
18
  def __init__(self, *args, **kwargs):
19
  super().__init__(*args, **kwargs)
20
- self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
21
  self.api_url = "https://api-inference.huggingface.co/models/ZB-Tech/Text-to-Image"
22
  self.headers = {"Authorization": f"Bearer {os.getenv('HF_API_KEY')}"}
23
 
@@ -27,12 +27,30 @@ class TranscriptSummarizer(Tool):
27
 
28
  def forward(self, transcript: str) -> str:
29
  try:
30
- summary = self.summarizer(transcript, max_length=2000, min_length=750, do_sample=False)[0]['summary_text']
31
- key_entities = summary.split()[:100] # Extract the first 100 words
32
- image_prompt = f"Generate an image related to: {' '.join(key_entities)}, professional style"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  image_bytes = self.query({"inputs": image_prompt})
34
  image = Image.open(io.BytesIO(image_bytes))
35
- image_url = "Images/image.jpg" # Placeholder, as we can't directly pass PIL Image to Gradio
 
 
 
36
  image.save(image_url) # Save the image to a file
37
  return f"{summary}\n\nImage URL: {image_url}" # Return the file path
38
  except Exception as e:
 
17
 
18
  def __init__(self, *args, **kwargs):
19
  super().__init__(*args, **kwargs)
20
+ self.summarizer = pipeline("summarization", model="facebook/usin")
21
  self.api_url = "https://api-inference.huggingface.co/models/ZB-Tech/Text-to-Image"
22
  self.headers = {"Authorization": f"Bearer {os.getenv('HF_API_KEY')}"}
23
 
 
27
 
28
  def forward(self, transcript: str) -> str:
29
  try:
30
+ transcript_length = len(transcript)
31
+
32
+ def get_summary_lengths(length):
33
+ if length <= 1000:
34
+ max_length = 300
35
+ min_length = 100
36
+ elif length <= 3000:
37
+ max_length = 750
38
+ min_length = 250
39
+ else:
40
+ max_length = 1500
41
+ min_length = 500
42
+ return max_length, min_length
43
+
44
+ max_length, min_length = get_summary_lengths(transcript_length)
45
+ summary = self.summarizer(transcript, max_length=max_length, min_length=min_length, do_sample=False)[0]['summary_text']
46
+ key_entities = summary.split()[:3] # Extract first 3 words as key entities
47
+ image_prompt = f"Generate an image related to: {' '.join(key_entities)}, cartoon style"
48
  image_bytes = self.query({"inputs": image_prompt})
49
  image = Image.open(io.BytesIO(image_bytes))
50
+ image_folder = "Image"
51
+ if not os.path.exists(image_folder):
52
+ os.makedirs(image_folder)
53
+ image_url = os.path.join(image_folder, "image.jpg") # Specify the folder path
54
  image.save(image_url) # Save the image to a file
55
  return f"{summary}\n\nImage URL: {image_url}" # Return the file path
56
  except Exception as e: