"""Invoice AI Extractor: a Streamlit app that pulls key fields out of invoices.

A LayoutLM document-question-answering model is asked a fixed set of
questions about an uploaded invoice (PDF or image) or a clipboard image. The
answers are shown with their confidence scores and can be exported as JSON,
CSV or plain text.

Run with ``streamlit run <this file>``.
"""

import csv
import hashlib
import io
import json
import logging

import fitz  # PyMuPDF
import streamlit as st
import torch
from PIL import Image, ImageGrab
from transformers import pipeline

logger = logging.getLogger(__name__)

# --- Configuration ---

# transformers pipelines take a CUDA device index, or -1 for CPU.
DEVICE = 0 if torch.cuda.is_available() else -1
MODEL_NAME = "faisalraza/layoutlm-invoices"

# Resolution used to rasterize PDF pages (PDF user space is 72 units/inch).
PDF_RENDER_DPI = 300

# Answers whose score does not exceed this are discarded.
MIN_CONFIDENCE = 0.1

# Questions put to the model; each answer's field name is derived from the
# question text by _field_name().
QUESTIONS = (
    "What is the invoice number?",
    "What is the invoice date?",
    "What is the total amount?",
    "What is the company name?",
    "What is the due date?",
    "What is the tax amount?",
)

# Export format -> MIME type for the download button. Dict order is the order
# shown in the format selector.
EXPORT_MIME_TYPES = {
    "JSON": "application/json",
    "CSV": "text/csv",
    "TXT": "text/plain",
}

# Confidence band (see _confidence_level) -> Streamlit markdown text colour.
_CONFIDENCE_COLORS = {"high": "green", "medium": "orange", "low": "red"}

# Session-state keys and their initial values. ``extracted_info`` is None
# while the current document has not been run through the model yet, and a
# (possibly empty) dict afterwards.
_SESSION_DEFAULTS = {
    "processed_image": None,
    "extracted_info": None,
    "document_key": None,
}


# --- Model & document helpers ---


@st.cache_resource
def load_model():
    """Build the document-question-answering pipeline.

    ``st.cache_resource`` keeps the returned pipeline across reruns and
    sessions, so the model is loaded only once.
    """
    return pipeline(
        "document-question-answering",
        model=MODEL_NAME,
        device=DEVICE,
    )


def process_pdf(pdf_file, dpi=PDF_RENDER_DPI):
    """Rasterize the first page of a PDF.

    Args:
        pdf_file: Readable binary file-like object holding the PDF bytes,
            positioned at the start.
        dpi: Render resolution in dots per inch.

    Returns:
        ``(image, page_count)``: a PIL image of page 1 and the total number
        of pages in the document.

    Raises:
        ValueError: If the PDF contains no pages.
    """
    zoom = dpi / 72  # fitz renders at 72 dpi for a zoom factor of 1
    with fitz.open(stream=pdf_file.read(), filetype="pdf") as pdf_document:
        page_count = pdf_document.page_count
        if page_count == 0:
            raise ValueError("PDF has no pages")
        pixmap = pdf_document[0].get_pixmap(matrix=fitz.Matrix(zoom, zoom))
        png_bytes = pixmap.tobytes("png")
    return Image.open(io.BytesIO(png_bytes)), page_count


def process_image(uploaded_file):
    """Turn an uploaded PDF or image file into ``(PIL image, page_count)``.

    PDFs are rasterized via process_pdf(); anything else is opened with PIL
    and reported as a single page.
    """
    uploaded_file.seek(0)
    is_pdf = uploaded_file.type == "application/pdf" or uploaded_file.name.lower().endswith(
        ".pdf"
    )
    if is_pdf:
        return process_pdf(uploaded_file)
    return Image.open(uploaded_file), 1


def get_clipboard_image():
    """Return ``(image, 1)`` if the clipboard holds an image, else ``(None, 0)``.

    NOTE: PIL's ImageGrab reads the clipboard of the machine running the
    Streamlit *server*, not the visitor's browser, so this is only useful
    when the app runs locally. Platforms where clipboard access is
    unsupported raise, which is treated as "no image".
    """
    try:
        clip = ImageGrab.grabclipboard()
    except Exception:
        logger.debug("Clipboard grab failed", exc_info=True)
        return None, 0
    # grabclipboard() may also return None or a list of file names.
    return (clip, 1) if isinstance(clip, Image.Image) else (None, 0)


def prepare_export_data(extracted_info, format_type):
    """Serialize extracted field values for download.

    Args:
        extracted_info: Mapping of field name -> ``{"value": ..., ...}``.
        format_type: ``"JSON"``, ``"CSV"``; anything else produces TXT.

    Returns:
        The export as a string. CSV is a header row of field names followed by
        one row of quoted values (embedded quotes are escaped per RFC 4180),
        with no trailing newline.
    """
    values = {field: data["value"] for field, data in extracted_info.items()}
    if format_type == "JSON":
        return json.dumps(values, indent=2)
    if format_type == "CSV":
        buffer = io.StringIO()
        csv.writer(buffer, lineterminator="\n").writerow(list(values))
        csv.writer(buffer, quoting=csv.QUOTE_ALL, lineterminator="\n").writerow(
            list(values.values())
        )
        # Drop the final line terminator to keep the previous two-line layout.
        return buffer.getvalue()[:-1]
    return "\n".join(f"{field}: {value}" for field, value in values.items())


def _field_name(question):
    """Derive a display field name: 'What is the due date?' -> 'Due Date'."""
    return question.replace("What is the ", "").replace("?", "").title()


def _best_answer(response):
    """Return ``(answer, score)`` for the top pipeline answer, or None.

    Accepts either a list of answer dicts or a single answer dict. An empty
    response or a blank answer yields None; a missing ``"score"`` raises
    KeyError (handled by the caller).
    """
    if isinstance(response, dict):
        response = [response]
    if not response:
        return None
    top = response[0]
    answer = (top.get("answer") or "").strip()
    if not answer:
        return None
    return answer, top["score"]


def extract_information(
    model, image, questions, progress_bar, status_text, min_confidence=MIN_CONFIDENCE
):
    """Ask the model each question about *image* and collect confident answers.

    Args:
        model: Callable document-QA pipeline, invoked as
            ``model(image=..., question=...)``.
        image: The document image.
        questions: Sequence of question strings.
        progress_bar: Streamlit progress element, advanced per question.
        status_text: Streamlit placeholder showing the current question.
        min_confidence: Answers scoring at or below this are dropped.

    Returns:
        Dict mapping field name (see _field_name) to
        ``{"value": answer, "confidence": score}``. Questions whose model call
        fails are logged and skipped, so a partial result is still returned.
    """
    extracted_info = {}
    total = len(questions)
    for position, question in enumerate(questions, start=1):
        progress_bar.progress(position / total)
        status_text.text(f"Processing: {question} ({position}/{total})")
        try:
            best = _best_answer(model(image=image, question=question))
        except Exception:
            # Best effort: one failing question must not abort the others.
            logger.exception("Extraction failed for question %r", question)
            continue
        if best is None:
            continue
        answer, confidence = best
        if confidence > min_confidence:
            extracted_info[_field_name(question)] = {
                "value": answer,
                "confidence": confidence,
            }
    return extracted_info


def _confidence_level(confidence):
    """Bucket a score into ``"high"`` (>0.7), ``"medium"`` (>0.4) or ``"low"``."""
    if confidence > 0.7:
        return "high"
    if confidence > 0.4:
        return "medium"
    return "low"


# --- UI ---


def _init_session_state():
    """Create any missing session-state keys with their defaults."""
    for key, default in _SESSION_DEFAULTS.items():
        if key not in st.session_state:
            st.session_state[key] = default


def _load_document(image, document_key):
    """Store a newly loaded document and mark it as not yet extracted."""
    st.session_state.processed_image = image.convert("RGB")
    st.session_state.document_key = document_key
    st.session_state.extracted_info = None


def _handle_input(uploaded_file, check_clipboard):
    """Load a document from the uploader or, failing that, the clipboard."""
    if uploaded_file is not None:
        # Streamlit reruns the whole script on every widget interaction, so
        # only (re)process when the uploaded bytes actually changed; otherwise
        # e.g. switching the export format would re-run the model.
        document_key = "upload:" + hashlib.sha256(uploaded_file.getvalue()).hexdigest()
        if document_key == st.session_state.document_key:
            return
        try:
            image, _ = process_image(uploaded_file)
            _load_document(image, document_key)
        except Exception as e:
            st.error(f"Error processing file: {e}")
    elif check_clipboard:
        clipboard_image, _ = get_clipboard_image()
        if clipboard_image is None:
            st.warning("No image found in clipboard")
            return
        document_key = "clipboard:" + hashlib.sha256(clipboard_image.tobytes()).hexdigest()
        _load_document(clipboard_image, document_key)
        st.success("Image loaded from clipboard")


def _render_fields(extracted_info):
    """Show each extracted value next to a colour-coded confidence badge."""
    for field, data in extracted_info.items():
        conf_col, val_col = st.columns([1, 4])
        with val_col:
            # The document id in the key keeps a new document from reusing a
            # previous document's widget state.
            st.text_input(
                field,
                data["value"],
                disabled=True,
                key=f"input_{st.session_state.document_key}_{field}",
            )
        with conf_col:
            confidence = data["confidence"]
            color = _CONFIDENCE_COLORS[_confidence_level(confidence)]
            st.markdown(f":{color}[**{confidence:.1%}**]")


def _render_export(extracted_info):
    """Offer the extracted values for download in the selected format."""
    st.markdown("### 📥 Export")
    export_format = st.selectbox("Format", list(EXPORT_MIME_TYPES))
    st.download_button(
        "Download",
        prepare_export_data(extracted_info, export_format),
        file_name=f"invoice_data.{export_format.lower()}",
        mime=EXPORT_MIME_TYPES[export_format],
    )


def _render_results(model):
    """Show the current document and its extracted fields (extracting once)."""
    image = st.session_state.processed_image
    image_col, info_col = st.columns([1, 1])
    with image_col:
        st.image(image, caption="Document", use_container_width=True)
    with info_col:
        st.markdown("### 📊 Extracted Information")
        if st.session_state.extracted_info is None:
            progress_bar = st.progress(0)
            status_text = st.empty()
            st.session_state.extracted_info = extract_information(
                model, image, QUESTIONS, progress_bar, status_text
            )
            status_text.empty()
            progress_bar.empty()

        extracted_info = st.session_state.extracted_info
        if not extracted_info:
            st.warning("Could not extract information. Please ensure the document is clear.")
            return
        _render_fields(extracted_info)
        _render_export(extracted_info)


def main():
    """Lay out the page and drive upload -> extraction -> display/export."""
    st.set_page_config(
        page_title="Invoice AI | by Arif Dogan",
        page_icon="🧾",
        layout="wide",
        initial_sidebar_state="collapsed",
    )
    _init_session_state()

    # Native elements render the header/footer without custom CSS or
    # unsafe_allow_html.
    st.title("🧾 Invoice AI Extractor")
    st.caption("Powered by LayoutLM")

    model = load_model()

    upload_col, clipboard_col = st.columns([2, 1])
    with upload_col:
        uploaded_file = st.file_uploader(
            "Drop invoice (PDF, JPG, PNG)", type=["pdf", "jpg", "jpeg", "png"]
        )
    with clipboard_col:
        st.write("Or paste from clipboard (Ctrl/Cmd + V)")
        check_clipboard = st.button("📎 Check Clipboard")

    _handle_input(uploaded_file, check_clipboard)

    if st.session_state.processed_image is not None:
        try:
            _render_results(model)
        except Exception as e:
            st.error(f"Error during processing: {e}")

    st.markdown("---")
    st.caption("Created by Arif Dogan | 🤗 Hugging Face")


# `streamlit run` executes this file as the __main__ module.
if __name__ == "__main__":
    main()