Daryl Lim committed on
Commit · f096bc8
Parent(s): 5750f60
Update app.py
app.py CHANGED
@@ -132,18 +132,35 @@ def generate_summary(chunks: List[Document], length_type="sentences", length_count
         length_type: Either "sentences" or "paragraphs"
         length_count: Number of sentences (1-10) or paragraphs (1-3)
     """
+    # Print debug information to track what parameters are being used
+    print(f"Generating summary with length_type={length_type}, length_count={length_count}")
+
+    # Ensure length_count is an integer
+    try:
+        length_count = int(length_count)
+    except (ValueError, TypeError):
+        # Default to 3 if conversion fails
+        print(f"Failed to convert length_count to int: {length_count}, using default 3")
+        length_count = 3
+
+    # Apply limits based on type
+    if length_type == "sentences":
+        length_count = max(1, min(10, length_count))  # Limit to 1-10 sentences
+    else:  # paragraphs
+        length_count = max(1, min(3, length_count))  # Limit to 1-3 paragraphs
+
     # Concatenate the retrieved chunks
     combined_text = " ".join([chunk.page_content for chunk in chunks])

-    #
+    # Use a more direct instruction to enforce the length constraint
     if length_type == "sentences":
-        length_instruction = f"
+        length_instruction = f"Your summary must be EXACTLY {length_count} sentence{'s' if length_count > 1 else ''}. Not more, not less."
     else:  # paragraphs
-        length_instruction = f"
+        length_instruction = f"Your summary must be EXACTLY {length_count} paragraph{'s' if length_count > 1 else ''}. Not more, not less."

-    # Construct the prompt
+    # Construct the prompt with clearer instructions
     prompt = f"""<instruction>
-Knowledge Cutoff Date: April 2024. You are Granite, developed by IBM. You are a helpful AI assistant. {length_instruction} Your response should only include the
+Knowledge Cutoff Date: April 2024. You are Granite, developed by IBM. You are a helpful AI assistant. Summarize the following text. {length_instruction} Your response should only include the summary. Do not provide any further explanation.
 </instruction>

 <text>
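The clamping and instruction-building logic added above can be exercised in isolation. A minimal sketch (the helper name build_length_instruction is hypothetical; the coercion, clamps, and wording mirror the diff):

    def build_length_instruction(length_type, length_count):
        # Coerce to int, defaulting to 3 on failure (as in the diff)
        try:
            length_count = int(length_count)
        except (ValueError, TypeError):
            length_count = 3
        # Clamp to the UI ranges: 1-10 sentences, 1-3 paragraphs
        if length_type == "sentences":
            length_count, unit = max(1, min(10, length_count)), "sentence"
        else:  # paragraphs
            length_count, unit = max(1, min(3, length_count)), "paragraph"
        plural = "s" if length_count > 1 else ""
        return f"Your summary must be EXACTLY {length_count} {unit}{plural}. Not more, not less."

    print(build_length_instruction("sentences", "5"))
    # -> Your summary must be EXACTLY 5 sentences. Not more, not less.
    print(build_length_instruction("paragraphs", 99))  # clamped to 3
    # -> Your summary must be EXACTLY 3 paragraphs. Not more, not less.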
@@ -154,16 +171,18 @@ Knowledge Cutoff Date: April 2024. You are Granite, developed by IBM. You are a
     # Calculate appropriate max_new_tokens based on length requirements
     # Approximate tokens: ~15 tokens per sentence, ~75 tokens per paragraph
     if length_type == "sentences":
-        max_tokens = length_count *
+        max_tokens = length_count * 30  # Increased slightly for flexibility
     else:  # paragraphs
-        max_tokens = length_count *
+        max_tokens = length_count * 120  # Increased slightly for flexibility

     # Ensure minimum tokens and add buffer
-    max_tokens = max(100, min(
+    max_tokens = max(100, min(1500, max_tokens + 50))

     # Generate the summary using the IBM Granite model
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

+    print(f"Using max_new_tokens={max_tokens}")
+
     with torch.no_grad():
         output = model.generate(
             **inputs,
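The token budget is simple arithmetic; a standalone sketch of the new bounds (the helper name token_budget is hypothetical, the constants are the diff's):

    def token_budget(length_type, length_count):
        # ~30 tokens per sentence or ~120 per paragraph, plus a 50-token
        # buffer, clamped to [100, 1500] as in the diff
        per_unit = 30 if length_type == "sentences" else 120
        return max(100, min(1500, length_count * per_unit + 50))

    print(token_budget("sentences", 1))    # 100 (floor: 30 + 50 = 80 raises to 100)
    print(token_budget("sentences", 10))   # 350
    print(token_budget("paragraphs", 3))   # 410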
@@ -385,14 +404,35 @@ def create_gradio_interface():
         # Convert capitalized length_type to lowercase for processing
         length_type_lower = length_type.lower()

+        print(f"Processing with length_type={length_type}, sentence_count={sentence_count}, paragraph_count={paragraph_count}")
+
+        # Determine count based on the selected length type
         if length_type_lower == "sentences":
-
+            # For sentences, use the slider value directly
+            try:
+                count = int(sentence_count)
+                count = max(1, min(10, count))  # Ensure within range 1-10
+                print(f"Using sentence count: {count}")
+            except (ValueError, TypeError):
+                print(f"Invalid sentence count: {sentence_count}, using default 3")
+                count = 3
         else:
-            #
-
-
-
-
+            # For paragraphs, convert from string to int if needed
+            try:
+                # Check if paragraph_count is a string (from radio button)
+                if isinstance(paragraph_count, str):
+                    count = int(paragraph_count)
+                # Check if it's a boolean (from visibility toggle)
+                elif isinstance(paragraph_count, bool):
+                    count = 1  # Default if boolean
+                else:
+                    count = int(paragraph_count)
+
+                count = max(1, min(3, count))  # Ensure within range 1-3
+                print(f"Using paragraph count: {count}")
+            except (ValueError, TypeError):
+                print(f"Invalid paragraph count: {paragraph_count}, using default 1")
+                count = 1

         return process_document(file, length_type_lower, count)
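One subtlety in the paragraph branch: Gradio radio values arrive as strings, and Python's bool is a subclass of int, so the explicit isinstance(..., bool) check keeps a stray toggle value from being treated as a real count. A quick sketch of the same normalization on representative inputs:

    for raw in ("2", 3, True, None):
        try:
            count = 1 if isinstance(raw, bool) else int(raw)  # bool checked first
            count = max(1, min(3, count))                     # clamp to 1-3
        except (ValueError, TypeError):
            count = 1                                         # default on bad input
        print(f"{raw!r} -> {count}")
    # '2' -> 2, 3 -> 3, True -> 1, None -> 1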