[FIX] fixes max_len and min_len dynamic.
Browse files
tool.py
CHANGED
@@ -17,7 +17,7 @@ class TranscriptSummarizer(Tool):
|
|
17 |
|
18 |
def __init__(self, *args, **kwargs):
|
19 |
super().__init__(*args, **kwargs)
|
20 |
-
self.summarizer = pipeline("summarization", model="facebook/
|
21 |
self.api_url = "https://api-inference.huggingface.co/models/ZB-Tech/Text-to-Image"
|
22 |
self.headers = {"Authorization": f"Bearer {os.getenv('HF_API_KEY')}"}
|
23 |
|
@@ -27,12 +27,30 @@ class TranscriptSummarizer(Tool):
|
|
27 |
|
28 |
def forward(self, transcript: str) -> str:
|
29 |
try:
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
image_bytes = self.query({"inputs": image_prompt})
|
34 |
image = Image.open(io.BytesIO(image_bytes))
|
35 |
-
|
|
|
|
|
|
|
36 |
image.save(image_url) # Save the image to a file
|
37 |
return f"{summary}\n\nImage URL: {image_url}" # Return the file path
|
38 |
except Exception as e:
|
|
|
17 |
|
18 |
def __init__(self, *args, **kwargs):
|
19 |
super().__init__(*args, **kwargs)
|
20 |
+
self.summarizer = pipeline("summarization", model="facebook/usin")
|
21 |
self.api_url = "https://api-inference.huggingface.co/models/ZB-Tech/Text-to-Image"
|
22 |
self.headers = {"Authorization": f"Bearer {os.getenv('HF_API_KEY')}"}
|
23 |
|
|
|
27 |
|
28 |
def forward(self, transcript: str) -> str:
|
29 |
try:
|
30 |
+
transcript_length = len(transcript)
|
31 |
+
|
32 |
+
def get_summary_lengths(length):
|
33 |
+
if length <= 1000:
|
34 |
+
max_length = 300
|
35 |
+
min_length = 100
|
36 |
+
elif length <= 3000:
|
37 |
+
max_length = 750
|
38 |
+
min_length = 250
|
39 |
+
else:
|
40 |
+
max_length = 1500
|
41 |
+
min_length = 500
|
42 |
+
return max_length, min_length
|
43 |
+
|
44 |
+
max_length, min_length = get_summary_lengths(transcript_length)
|
45 |
+
summary = self.summarizer(transcript, max_length=max_length, min_length=min_length, do_sample=False)[0]['summary_text']
|
46 |
+
key_entities = summary.split()[:3] # Extract first 3 words as key entities
|
47 |
+
image_prompt = f"Generate an image related to: {' '.join(key_entities)}, cartoon style"
|
48 |
image_bytes = self.query({"inputs": image_prompt})
|
49 |
image = Image.open(io.BytesIO(image_bytes))
|
50 |
+
image_folder = "Image"
|
51 |
+
if not os.path.exists(image_folder):
|
52 |
+
os.makedirs(image_folder)
|
53 |
+
image_url = os.path.join(image_folder, "image.jpg") # Specify the folder path
|
54 |
image.save(image_url) # Save the image to a file
|
55 |
return f"{summary}\n\nImage URL: {image_url}" # Return the file path
|
56 |
except Exception as e:
|