|
import streamlit as st |
|
from paddleocr import PaddleOCR |
|
from PIL import ImageDraw, ImageFont |
|
import torch |
|
from transformers import AutoProcessor,LayoutLMv3ForTokenClassification |
|
import numpy as np |
|
|
|
# Model identifier handed to ``from_pretrained`` for both the processor and the
# token-classification model used in this file.
model_Hugging_path = "Noureddinesa/Output_LayoutLMv3_v5"
|
|
|
|
|
|
|
|
|
def Labels():
    """Return the label vocabulary as two lookup tables.

    Returns:
        tuple: ``(id2label, label2id)``. ``id2label`` maps each label's
        position in the fixed list below to its name; ``label2id`` is the
        inverse mapping.
    """
    labels = ['InvNum', 'InvDate', 'Fourni', 'TTC', 'TVA', 'TT', 'Autre']
    id2label = dict(enumerate(labels))
    label2id = {label: index for index, label in enumerate(labels)}
    return id2label, label2id
|
|
|
|
|
|
|
def Paddle():
    """Build and return a PaddleOCR engine configured with ``lang='fr'``.

    A new engine is constructed on every call.
    """
    # NOTE(review): the angle classifier is disabled here, while Preprocess
    # calls ``ocr(..., cls=True)`` -- confirm which of the two is intended.
    # NOTE(review): ``rec=False`` is passed to the constructor, yet callers read
    # recognized text from the results -- confirm this flag has the intended
    # effect for the installed PaddleOCR version.
    return PaddleOCR(use_angle_cls=False, lang='fr', rec=False)
|
|
|
def processbbox(BBOX, width, height):
    """Convert an OCR quadrilateral into an integer box on a 0-1000 grid.

    Args:
        BBOX: Sequence of corner points, each indexable as ``[x, y]``. Only
            corners 0 and 2 are read (presumably top-left and bottom-right --
            confirm against the OCR engine's point order).
        width: Image width, in the same units as the x coordinates.
        height: Image height, in the same units as the y coordinates.

    Returns:
        list[int]: ``[x0, y0, x1, y1]`` expressed in thousandths of the image
        size, truncated to integers and clamped to the inclusive range
        0..1000.
    """
    scaled = (
        1000 * BBOX[0][0] / width,
        1000 * BBOX[0][1] / height,
        1000 * BBOX[2][0] / width,
        1000 * BBOX[2][1] / height,
    )
    # Clamp: detected points can fall slightly outside the image, and
    # layout models expect every coordinate to stay within 0..1000.
    return [min(1000, max(0, int(value))) for value in scaled]
|
|
|
|
|
def Preprocess(image):
    """Run OCR on an image and collect its words with normalized boxes.

    Args:
        image: Image object exposing ``.size`` as ``(width, height)`` and
            convertible with ``np.array`` (e.g. a PIL image).

    Returns:
        dict: ``{'image': image, 'tokens': [...], 'bboxes': [...]}`` where
        ``tokens`` holds the recognized strings and ``bboxes`` the matching
        boxes produced by ``processbbox``. Both lists are empty when the OCR
        engine reports no text.
    """
    image_array = np.array(image)
    ocr = Paddle()
    width, height = image.size
    results = ocr.ocr(image_array, cls=True)

    # Only the first entry of the result is used (a single image is passed).
    # Some PaddleOCR versions return ``[None]`` when nothing is detected, so
    # treat a missing/None page as "no lines" instead of iterating over None.
    page = results[0] if results else None

    test_dict = {'image': image, 'tokens': [], "bboxes": []}
    for item in page or []:
        # item[0] is the quadrilateral, item[1][0] the recognized text.
        bbox = processbbox(item[0], width, height)
        test_dict['tokens'].append(item[1][0])
        test_dict['bboxes'].append(bbox)

    # Debug output kept from the original implementation.
    print(test_dict['bboxes'])
    print(test_dict['tokens'])
    return test_dict
|
|
|
|
|
|
|
def Encode(image):
    """OCR the image and encode its words and boxes with the model processor.

    Args:
        image: Image forwarded to ``Preprocess``.

    Returns:
        tuple: ``(encoding, offset_mapping, words)`` -- the processor output
        with ``offset_mapping`` removed from it, the removed offset mapping,
        and the list of OCR tokens.
    """
    example = Preprocess(image)
    words = example["tokens"]

    # OCR has already been done above, so the processor's own OCR is disabled.
    processor = AutoProcessor.from_pretrained(model_Hugging_path, apply_ocr=False)
    encoding = processor(
        example["image"],
        words,
        boxes=example["bboxes"],
        return_offsets_mapping=True,
        truncation=True,
        max_length=512,
        padding="max_length",
        return_tensors="pt",
    )

    # The offsets are only needed for post-processing, not as a model input.
    offset_mapping = encoding.pop('offset_mapping')
    return encoding, offset_mapping, words
|
|
|
|
|
def unnormalize_box(bbox, width, height):
    """Scale a box from the 0-1000 grid back to image coordinates.

    Args:
        bbox: Indexable ``[x0, y0, x1, y1]`` in thousandths of the image size.
        width: Target image width.
        height: Target image height.

    Returns:
        list: ``[x0, y0, x1, y1]`` where x values are multiplied by
        ``width / 1000`` and y values by ``height / 1000``.
    """
    # x coordinates scale with the width, y coordinates with the height.
    extents = (width, height, width, height)
    return [extent * (bbox[index] / 1000) for index, extent in enumerate(extents)]
|
|
|
|
|
def Run_model(image):
    """Classify the OCR tokens of an image with the token-classification model.

    Args:
        image: Image forwarded to ``Encode``; its ``.size`` is used to scale
            the boxes back to image coordinates.

    Returns:
        tuple: ``(true_predictions, true_boxes, words)`` -- label names and
        un-normalized boxes for every token whose offset starts at 0, plus the
        list of OCR tokens returned by ``Encode``.
    """
    encoding, offset_mapping, words = Encode(image)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = LayoutLMv3ForTokenClassification.from_pretrained(model_Hugging_path)
    model.to(device)

    # The inputs have to live on the same device as the model, otherwise the
    # forward pass fails as soon as CUDA is available.
    inputs = {name: tensor.to(device) for name, tensor in encoding.items()}

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding.bbox.squeeze().tolist()

    width, height = image.size
    id2label, _ = Labels()

    # A token whose character offset does not start at 0 is treated as a
    # continuation piece of a word and is dropped.
    # NOTE(review): special/padding tokens presumably carry offset (0, 0) and so
    # also pass this filter -- Get_Json's ``words[i-1]`` indexing appears to rely
    # on the leading one; confirm before changing the filter.
    is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0

    true_predictions = [
        id2label[pred]
        for idx, pred in enumerate(predictions)
        if not is_subword[idx]
    ]
    true_boxes = [
        unnormalize_box(box, width, height)
        for idx, box in enumerate(token_boxes)
        if not is_subword[idx]
    ]
    return true_predictions, true_boxes, words
|
|
|
|
|
|
|
def Get_Json(true_predictions, words):
    """Collect the extracted invoice fields into a dict with display names.

    The prediction at position ``i`` is paired with ``words[i - 1]``.
    Positions that have no matching word (position 0, and any position past
    the end of ``words``) are ignored. When a label occurs several times, the
    last matching word wins.

    Args:
        true_predictions: Sequence of predicted label names.
        words: Sequence of OCR tokens.

    Returns:
        dict: Display name (e.g. ``'Total TTC'``) -> extracted word, in order
        of first appearance of each label.
    """
    field_labels = ('InvNum', 'Fourni', 'InvDate', 'TT', 'TTC', 'TVA')
    key_mapping = {'InvNum':'Numéro de facture','Fourni':'Fournisseur', 'InvDate':'Date Facture','TT':'Total HT','TTC':'Total TTC','TVA':'TVA'}

    Results = {}
    for position, label in enumerate(true_predictions):
        word_index = position - 1
        # Guard the shifted index: -1 would silently read the last word and
        # an index past the end would raise IndexError.
        if label in field_labels and 0 <= word_index < len(words):
            Results[label] = words[word_index]

    return {key_mapping.get(key, key): value for key, value in Results.items()}
|
|
|
|
|
|
|
def Draw(image):
    """Annotate the image with the predicted fields and return them.

    Every prediction other than ``"Autre"`` gets an outlined box, a filled
    label strip shifted above it, and the label name written in white. The
    drawing is done directly on ``image``.

    Args:
        image: Image forwarded to ``Run_model`` and drawn on in place.

    Returns:
        tuple: ``(image, Results)`` -- the annotated image and the dict built
        by ``Get_Json``.
    """
    true_predictions, true_boxes, words = Run_model(image)
    draw = ImageDraw.Draw(image)

    label2color = {
        'InvNum': 'blue',
        'InvDate': 'green',
        'Fourni': 'orange',
        'TTC':'purple',
        'TVA': 'magenta',
        'TT': 'red',
        'Autre': 'black'
    }

    rectangle_thickness = 4
    label_x_offset = 20
    label_y_offset = -30
    custom_font_size = 25
    font_path = "arial.ttf"

    # The named font is not available on every system; fall back to Pillow's
    # built-in font rather than failing the whole annotation step.
    try:
        custom_font = ImageFont.truetype(font_path, custom_font_size)
    except OSError:
        custom_font = ImageFont.load_default()

    for predicted_label, box in zip(true_predictions, true_boxes):
        if predicted_label == "Autre":
            continue

        # Unknown labels are drawn in black.
        color = label2color.get(predicted_label, 'black')

        draw.rectangle(box, outline=color, width=rectangle_thickness)
        # Filled strip with the same size as the box, shifted by the label offset.
        draw.rectangle((box[0], box[1]+ label_y_offset,box[2],box[3]+ label_y_offset), fill=color)
        draw.text((box[0] + label_x_offset, box[1] + label_y_offset), text=predicted_label, fill='white', font=custom_font)

    Results = Get_Json(true_predictions, words)
    return image, Results
|
|
|
|
|
|
|
|
|
def Add_Results(data):
    """Show one sidebar text input per entry and store the input back.

    Args:
        data: Dict of field name -> current value. It is updated in place with
            whatever each ``st.sidebar.text_input`` call returns.
    """
    # Only existing keys are reassigned, so the dict never changes size while
    # it is being iterated.
    for key in data:
        data[key] = st.sidebar.text_input(key, data[key])
|
|
|
|
|
|
|
def check_if_changed(original_values, updated_values):
    """Tell whether any original entry has a different value after editing.

    Args:
        original_values: Dict of key -> value before editing.
        updated_values: Dict that must contain every key of
            ``original_values``.

    Returns:
        bool: ``True`` as soon as one key maps to a different value in
        ``updated_values``, ``False`` otherwise (including for an empty
        ``original_values``).

    Raises:
        KeyError: If a key of ``original_values`` is missing from
            ``updated_values``.
    """
    return any(
        updated_values[key] != value
        for key, value in original_values.items()
    )
|
|
|
|
|
|
|
def Update(Results):
    """Render editable sidebar inputs for the known fields present in ``Results``.

    Args:
        Results: Dict keyed by display name (e.g. ``"Total HT"``).

    Returns:
        dict: The same keys as found in ``Results`` (restricted to the known
        fields, in the fixed order below), each mapped to the value returned
        by its sidebar text input.
    """
    # (key in Results, label shown on the widget), in display order.
    # NOTE(review): the last field is stored under "Total TTC" but its widget
    # is labelled "TTC" -- kept as is; confirm whether that is intentional.
    fields = (
        ("Fournisseur", "Fournisseur"),
        ("Date Facture", "Date Facture"),
        ("Numéro de facture", "Numéro de facture"),
        ("Total HT", "Total HT"),
        ("TVA", "TVA"),
        ("Total TTC", "TTC"),
    )

    New_results = {}
    for key, label in fields:
        if key in Results:
            New_results[key] = st.sidebar.text_input(label, value=Results[key])
    return New_results
|
|
|
|
|
|
|
def Change_Image(image1, image2):
    """Display one of two images and let a sidebar button switch between them.

    The choice is kept in ``st.session_state.current_image`` and starts as
    ``'image1'``. Each press of the sidebar button labelled ``'Remove'`` flips
    it to the other image.

    Args:
        image1: Image shown with the caption ``'Output'``.
        image2: Image shown with the caption ``'Image initiale'``.
    """
    if 'current_image' not in st.session_state:
        st.session_state.current_image = 'image1'

    if st.sidebar.button('Remove'):
        showing_first = st.session_state.current_image == 'image1'
        st.session_state.current_image = 'image2' if showing_first else 'image1'

    if st.session_state.current_image == 'image1':
        st.image(image1, caption='Output', use_column_width=True)
    else:
        st.image(image2, caption='Image initiale', use_column_width=True)