Daryl Lim committed on
Commit d700fcc · 1 Parent(s): 55a0a7d

Update app.py

Files changed (1)
  1. app.py +27 -11
app.py CHANGED
@@ -22,7 +22,7 @@ from langchain.schema import Document
 
 # Transformers imports for IBM Granite model
 import spaces
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
 # Initialize IBM Granite model and tokenizer
 print("Loading Granite model and tokenizer...")
@@ -31,12 +31,17 @@ model_name = "ibm-granite/granite-3.3-8b-instruct"
 # Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
+# Create quantization config
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True,  # Use 4-bit quantization for better memory efficiency
+    bnb_4bit_compute_dtype=torch.bfloat16  # Use bfloat16 for computation with 4-bit quantization
+)
+
 # Load model with optimization for GPU
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
     device_map="auto",
-    torch_dtype=torch.bfloat16,
-    load_in_8bit=True  # Use 8-bit quantization for memory efficiency
+    quantization_config=quantization_config
 )
 print("Model loaded successfully!")
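
To make the effect of the quantization change concrete, here is a minimal, self-contained sketch of loading and prompting the same checkpoint with the new 4-bit BitsAndBytesConfig. It assumes the bitsandbytes package and a CUDA GPU are available; the prompt and generation settings are illustrative, not the app's own.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "ibm-granite/granite-3.3-8b-instruct"

# 4-bit weights keep the 8B model inside a single-GPU memory budget;
# bfloat16 is used for the actual compute.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quantization_config,
)

# Illustrative prompt; the app builds its summarization prompt from retrieved chunks.
prompt = "Summarize in one sentence: FAISS is a library for efficient similarity search."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))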
 
@@ -198,6 +203,10 @@ def process_document_chunks(texts, batch_size=8):
     except Exception as e:
         print(f"Error in document processing: {str(e)}")
         # Fallback to basic processing if optimization fails
+        embeddings = HuggingFaceEmbeddings(
+            model_name="nomic-ai/nomic-embed-text-v1",
+            model_kwargs={'trust_remote_code': True}
+        )
         return FAISS.from_documents(texts, embeddings)
 
 # Main function to process document and generate summary
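
The fix above re-creates the embedder inside the except branch, so the fallback no longer depends on whatever state the failed optimized path left behind. Below is a hypothetical stand-in for process_document_chunks showing the overall shape of that pattern; the batched indexing in the try block is an assumption (the app's actual batching logic is not part of this hunk), and import paths vary by LangChain version.

from langchain.embeddings import HuggingFaceEmbeddings  # or langchain_community.embeddings
from langchain.vectorstores import FAISS                # or langchain_community.vectorstores

def build_vector_store(texts, batch_size=8):
    """Hypothetical sketch: index document chunks, falling back to a single-shot build."""
    embeddings = HuggingFaceEmbeddings(
        model_name="nomic-ai/nomic-embed-text-v1",
        model_kwargs={'trust_remote_code': True}
    )
    try:
        # Assumed optimized path: build the index incrementally in small batches.
        store = FAISS.from_documents(texts[:batch_size], embeddings)
        for start in range(batch_size, len(texts), batch_size):
            store.add_documents(texts[start:start + batch_size])
        return store
    except Exception as e:
        print(f"Error in document processing: {str(e)}")
        # Fallback to basic processing if optimization fails; rebuild the embedder
        # so this branch never references a half-initialized object.
        embeddings = HuggingFaceEmbeddings(
            model_name="nomic-ai/nomic-embed-text-v1",
            model_kwargs={'trust_remote_code': True}
        )
        return FAISS.from_documents(texts, embeddings)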
@@ -359,10 +368,11 @@ def create_gradio_interface():
 
     # Add interactivity to show/hide appropriate count selector
     def update_count_visibility(length_type):
-        return {
-            sentence_count: length_type == "Sentences",
-            paragraph_count: length_type == "Paragraphs"
-        }
+        is_sentences = length_type == "Sentences"
+        return [
+            gr.update(visible=is_sentences),  # For sentence_count
+            gr.update(visible=not is_sentences)  # For paragraph_count
+        ]
 
     length_type.change(
         fn=update_count_visibility,
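
The old callback returned a dict of component-to-boolean, which Gradio would treat as new component values rather than visibility flags; returning one gr.update(visible=...) per output component is what actually shows and hides the selectors. A stripped-down, runnable sketch of that pattern (component choices and ranges are illustrative):

import gradio as gr

with gr.Blocks() as demo:
    length_type = gr.Radio(["Sentences", "Paragraphs"], value="Sentences", label="Summary length")
    sentence_count = gr.Slider(1, 10, value=3, step=1, label="Number of sentences", visible=True)
    paragraph_count = gr.Slider(1, 5, value=1, step=1, label="Number of paragraphs", visible=False)

    def update_count_visibility(length_type):
        # Return one gr.update per output component, in the same order as `outputs=` below.
        is_sentences = length_type == "Sentences"
        return [gr.update(visible=is_sentences), gr.update(visible=not is_sentences)]

    length_type.change(
        fn=update_count_visibility,
        inputs=length_type,
        outputs=[sentence_count, paragraph_count],
    )

if __name__ == "__main__":
    demo.launch()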
@@ -370,15 +380,21 @@ def create_gradio_interface():
         outputs=[sentence_count, paragraph_count]
     )
 
-    # Function to convert paragraph count from string to int and handle capitalized length types
+    # Function to handle form submission properly
    def process_document_wrapper(file, length_type, sentence_count, paragraph_count):
         # Convert capitalized length_type to lowercase for processing
         length_type_lower = length_type.lower()
 
         if length_type_lower == "sentences":
-            return process_document(file, length_type_lower, int(sentence_count))
+            count = int(sentence_count)
         else:
-            return process_document(file, length_type_lower, int(paragraph_count))
+            # Handle potential type issues with paragraph_count
+            if isinstance(paragraph_count, bool):
+                count = 1  # Default if boolean
+            else:
+                count = int(paragraph_count)
+
+        return process_document(file, length_type_lower, count)
 
     submit_btn.click(
         fn=process_document_wrapper,
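
The isinstance(paragraph_count, bool) guard presumably works around the hidden selector's value arriving in an unexpected type; the exact cause is not visible in this diff. If that defensive conversion grows, it could be factored into a small helper like the hypothetical one below, which is easy to unit-test in isolation.

def coerce_count(value, default=1):
    """Best-effort conversion of a Gradio widget value to a positive int (hypothetical helper)."""
    if isinstance(value, bool) or value is None:
        return default  # booleans/None typically come from hidden or unset widgets
    try:
        return max(1, int(float(value)))  # handles 3, 3.0 and "3"
    except (TypeError, ValueError):
        return default

assert coerce_count(3) == 3
assert coerce_count("4") == 4
assert coerce_count(True) == 1
assert coerce_count(None, default=2) == 2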
@@ -402,4 +418,4 @@ def create_gradio_interface():
 # Launch the application
 if __name__ == "__main__":
     app = create_gradio_interface()
-    app.launch()
+    app.launch()