Rishi Desai commited on
Commit
69f7712
·
1 Parent(s): fb6c5e9

first patch at fixing manual entry

Browse files
Files changed (2) hide show
  1. demo.py +170 -22
  2. prompt.py +1 -3
demo.py CHANGED
@@ -8,6 +8,7 @@ from pathlib import Path
8
  import shutil
9
 
10
  from main import process_images
 
11
 
12
  # Maximum number of images
13
  MAX_IMAGES = 30
@@ -304,6 +305,30 @@ def get_css_styles():
304
  margin-bottom: 10px;
305
  border-left: 3px solid #4CAF50;
306
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  </style>
308
  """
309
 
@@ -398,7 +423,7 @@ def create_captioning_area():
398
  def setup_event_handlers(
399
  image_upload, stored_image_paths, captioning_area, caption_btn, download_btn,
400
  download_output, status_text, image_rows, image_components, caption_components,
401
- batch_category_checkbox, batch_by_category
402
  ):
403
  """Set up all event handlers for the UI"""
404
  # Combined outputs for the upload function
@@ -428,8 +453,8 @@ def setup_event_handlers(
428
  outputs=[batch_by_category]
429
  )
430
 
431
- # Set up captioning button
432
- caption_btn.click(
433
  on_captioning_start,
434
  inputs=None,
435
  outputs=[status_text, caption_btn]
@@ -443,6 +468,17 @@ def setup_event_handlers(
443
  outputs=[status_text, caption_btn, download_btn]
444
  )
445
 
 
 
 
 
 
 
 
 
 
 
 
446
  # Set up download button
447
  download_btn.click(
448
  create_zip_from_ui,
@@ -465,30 +501,142 @@ def build_ui():
465
  with gr.Blocks() as demo:
466
  gr.Markdown("# Image Auto-captioner for LoRA Training")
467
 
468
- # Store uploaded images
469
- stored_image_paths = gr.State([])
470
- batch_by_category = gr.State(False) # State to track if batch by category is enabled
471
 
472
- # Create a two-column layout for the entire interface
473
- with gr.Row():
474
- # Create upload area in left column
475
- _, image_upload = create_upload_area()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
 
477
- # Create config area in right column
478
- _, batch_category_checkbox, caption_btn, download_btn, download_output, status_text = create_config_area()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
 
480
  # Add CSS styling
481
  gr.HTML(get_css_styles())
482
-
483
- # Create captioning area (initially hidden)
484
- captioning_area, image_rows, image_components, caption_components = create_captioning_area()
485
-
486
- # Set up event handlers
487
- setup_event_handlers(
488
- image_upload, stored_image_paths, captioning_area, caption_btn, download_btn,
489
- download_output, status_text, image_rows, image_components, caption_components,
490
- batch_category_checkbox, batch_by_category
491
- )
492
 
493
  return demo
494
 
 
8
  import shutil
9
 
10
  from main import process_images
11
+ from prompt import optimize_prompt
12
 
13
  # Maximum number of images
14
  MAX_IMAGES = 30
 
305
  margin-bottom: 10px;
306
  border-left: 3px solid #4CAF50;
307
  }
308
+
309
+ /* Tab styling */
310
+ .tabs {
311
+ margin-top: 20px;
312
+ }
313
+
314
+ /* Prompt optimization tab styling */
315
+ .optimization-status {
316
+ margin-top: 10px;
317
+ padding: 8px;
318
+ border-radius: 4px;
319
+ background-color: #f8f9fa;
320
+ }
321
+
322
+ /* Input/output boxes for prompt optimization */
323
+ .prompt-box {
324
+ margin-bottom: 15px;
325
+ }
326
+
327
+ /* Make optimize button stand out */
328
+ .optimize-btn {
329
+ margin-top: 10px;
330
+ margin-bottom: 15px;
331
+ }
332
  </style>
333
  """
334
 
 
423
  def setup_event_handlers(
424
  image_upload, stored_image_paths, captioning_area, caption_btn, download_btn,
425
  download_output, status_text, image_rows, image_components, caption_components,
426
+ batch_category_checkbox, batch_by_category, shared_captions=None
427
  ):
428
  """Set up all event handlers for the UI"""
429
  # Combined outputs for the upload function
 
453
  outputs=[batch_by_category]
454
  )
455
 
456
+ # Set up captioning button chain
457
+ caption_chain = caption_btn.click(
458
  on_captioning_start,
459
  inputs=None,
460
  outputs=[status_text, caption_btn]
 
468
  outputs=[status_text, caption_btn, download_btn]
469
  )
470
 
471
+ # If shared_captions is provided, add an additional handler to update it
472
+ if shared_captions is not None:
473
+ def extract_valid_captions(*caption_values):
474
+ return [c for c in caption_values if c and c.strip()]
475
+
476
+ caption_chain.success(
477
+ extract_valid_captions,
478
+ inputs=caption_components,
479
+ outputs=[shared_captions]
480
+ )
481
+
482
  # Set up download button
483
  download_btn.click(
484
  create_zip_from_ui,
 
501
  with gr.Blocks() as demo:
502
  gr.Markdown("# Image Auto-captioner for LoRA Training")
503
 
504
+ # Store generated captions for sharing between tabs
505
+ shared_captions = gr.State([])
 
506
 
507
+ # Create tabs for different functionality
508
+ with gr.Tabs() as tabs:
509
+ with gr.TabItem("Image Captioning") as captioning_tab:
510
+ # Store uploaded images
511
+ stored_image_paths = gr.State([])
512
+ batch_by_category = gr.State(False) # State to track if batch by category is enabled
513
+
514
+ # Create a two-column layout for the entire interface
515
+ with gr.Row():
516
+ # Create upload area in left column
517
+ _, image_upload = create_upload_area()
518
+
519
+ # Create config area in right column
520
+ _, batch_category_checkbox, caption_btn, download_btn, download_output, status_text = create_config_area()
521
+
522
+ # Create captioning area (initially hidden)
523
+ captioning_area, image_rows, image_components, caption_components = create_captioning_area()
524
+
525
+ # Set up event handlers with shared captions
526
+ setup_event_handlers(
527
+ image_upload, stored_image_paths, captioning_area, caption_btn, download_btn,
528
+ download_output, status_text, image_rows, image_components, caption_components,
529
+ batch_category_checkbox, batch_by_category, shared_captions
530
+ )
531
 
532
+ with gr.TabItem("Prompt Optimization") as prompt_tab:
533
+ with gr.Row():
534
+ with gr.Column(scale=1):
535
+ # Left side for caption input
536
+ gr.Markdown("### Upload Captions")
537
+ gr.Markdown("Upload caption files (.txt) or enter captions manually", elem_classes="file-types-info")
538
+
539
+ captions_upload = gr.File(
540
+ file_count="multiple",
541
+ label="Upload caption files",
542
+ file_types=[".txt"],
543
+ type="filepath",
544
+ elem_classes="file-upload-container"
545
+ )
546
+
547
+ manual_captions = gr.Textbox(
548
+ label="Or enter captions manually (one per line)",
549
+ lines=5,
550
+ placeholder="Enter captions here, one per line",
551
+ elem_classes="prompt-box"
552
+ )
553
+
554
+ # Add button to use captions from image captioning tab
555
+ use_generated_captions = gr.Button("Use Captions from Manual Entry", variant="secondary")
556
+
557
+ # Function to update manual captions with shared ones
558
+ def fill_with_shared_captions(captions_list):
559
+ if not captions_list or len(captions_list) == 0:
560
+ return "No captions available. Generate captions in the Image Captioning tab first."
561
+ return "\n".join(captions_list)
562
+
563
+ # Connect button to fill manual captions area
564
+ use_generated_captions.click(
565
+ fill_with_shared_captions,
566
+ inputs=[shared_captions],
567
+ outputs=[manual_captions]
568
+ )
569
+
570
+ with gr.Column(scale=1):
571
+ # Right side for prompt input and output
572
+ gr.Markdown("### Optimize Prompt")
573
+
574
+ user_prompt = gr.Textbox(
575
+ label="Enter your prompt",
576
+ lines=3,
577
+ placeholder="Enter the prompt you want to optimize",
578
+ elem_classes="prompt-box"
579
+ )
580
+
581
+ optimize_btn = gr.Button("Optimize Prompt", variant="primary", elem_classes="optimize-btn")
582
+
583
+ optimized_prompt = gr.Textbox(
584
+ label="Optimized Prompt",
585
+ lines=5,
586
+ placeholder="Optimized prompt will appear here",
587
+ elem_classes="prompt-box"
588
+ )
589
+
590
+ optimization_status = gr.Markdown("Enter a prompt and upload captions to begin", elem_classes="optimization-status")
591
+
592
+ # Function to handle optimization
593
+ def run_optimization(prompt, caption_files, manual_caption_text):
594
+ if not prompt or prompt.strip() == "":
595
+ return "", "Please enter a prompt to optimize"
596
+
597
+ # Handle different input sources for captions
598
+ caption_list = []
599
+
600
+ if manual_caption_text and manual_caption_text.strip():
601
+ # Use manually entered captions
602
+ caption_list = [line.strip() for line in manual_caption_text.split("\n") if line.strip()]
603
+
604
+ elif caption_files and len(caption_files) > 0:
605
+ # Read captions from uploaded files
606
+ for file_path in caption_files:
607
+ if os.path.exists(file_path) and file_path.lower().endswith('.txt'):
608
+ with open(file_path, 'r', encoding='utf-8') as f:
609
+ content = f.read().strip()
610
+ if content:
611
+ caption_list.append(content)
612
+
613
+ if not caption_list:
614
+ return "", "Please upload caption files or enter captions manually"
615
+
616
+ try:
617
+ # Call the optimize_prompt function from prompt.py
618
+ result = optimize_prompt(prompt, captions_list=caption_list)
619
+ return result, "✅ Prompt optimization complete"
620
+ except Exception as e:
621
+ return "", f"❌ Error optimizing prompt: {str(e)}"
622
+
623
+ # Add info about prompt optimization
624
+ gr.Markdown("""
625
+ **About Prompt Optimization:**
626
+ - This feature helps you craft prompts that match the style of your training captions
627
+ - Upload caption files, enter captions manually, or use captions from the Image Captioning tab
628
+ - Enter a simple prompt and the system will optimize it to match your training style
629
+ """, elem_classes=["category-info"])
630
+
631
+ # Connect the optimize button to the optimization function
632
+ optimize_btn.click(
633
+ run_optimization,
634
+ inputs=[user_prompt, captions_upload, manual_captions],
635
+ outputs=[optimized_prompt, optimization_status]
636
+ )
637
 
638
  # Add CSS styling
639
  gr.HTML(get_css_styles())
 
 
 
 
 
 
 
 
 
 
640
 
641
  return demo
642
 
prompt.py CHANGED
@@ -1,9 +1,7 @@
1
  import os
2
  import argparse
3
  from pathlib import Path
4
- from caption import get_system_prompt, get_together_client, extract_captions
5
-
6
- MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
7
 
8
  def optimize_prompt(user_prompt, captions_dir=None, captions_list=None):
9
  """Optimize a user prompt to follow the same format as training captions.
 
1
  import os
2
  import argparse
3
  from pathlib import Path
4
+ from caption import get_system_prompt, get_together_client, extract_captions, MODEL_ID
 
 
5
 
6
  def optimize_prompt(user_prompt, captions_dir=None, captions_list=None):
7
  """Optimize a user prompt to follow the same format as training captions.