# sky-scribe / app.py
# (Hugging Face Space page header captured with the file, not code:
#  "nkasmanoff's picture", commit "Update app.py" 9e79a2c, raw, history blame, 3.59 kB)
import re
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
import spacy
from spacy.matcher import Matcher
# Device the captioning model is moved to once, at import time.
device = 'cpu'

# Pre-/post-processing (image transforms, token decoding) comes from the base
# GIT checkpoint; the generation weights come from the sky-scribe checkpoint.
processor = AutoProcessor.from_pretrained("microsoft/git-base")
model = AutoModelForCausalLM.from_pretrained("nkasmanoff/sky-scribe")
model = model.to(device)

# English spaCy pipeline whose dependency parse is used by
# get_entities() and get_relation().
nlp = spacy.load('en_core_web_sm')
def predict(image, max_length=50, device='cpu'):
    """Generate a caption for one image with the sky-scribe model.

    Args:
        image: The image to caption, as delivered by the Gradio input
            component (declared with ``type='pil'``). It is None when the
            user submits without uploading anything.
        max_length: Maximum generated sequence length, forwarded to
            ``model.generate``.
        device: Device the preprocessed pixel tensor is moved to. The model
            itself is moved to the module-level ``device`` when it is loaded,
            so this value should match that one.

    Returns:
        The decoded caption of the first generated sequence, or "" when
        ``image`` is None.
    """
    # The input component is optional, so Gradio may call us with no image;
    # the processor rejects a call with neither text nor images.
    if image is None:
        return ""

    inputs = processor(images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=max_length)
    # skip_special_tokens keeps BOS/EOS/padding markers out of the caption text.
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # TODO: knowledge-triplet output is disabled. It was built as
    # f"'{ent1}'---{get_relation(caption)}--->'{ent2}'" with
    # ent1, ent2 = get_entities(caption).
    return caption
def get_entities(sent):
    """Extract a (subject, object) entity pair from a sentence.

    Walks the spaCy dependency parse of ``sent`` token by token. It remembers
    the most recent compound-word prefix and the most recent modifier
    (dependency label ending in "mod"), and attaches them to:

    * the last subject-like token (label containing "subj"), and
    * the last object-like token (label containing "obj").

    Punctuation tokens are skipped entirely.

    Args:
        sent: The sentence to analyse (in this app, a generated caption).

    Returns:
        A two-element list ``[subject, object]``. An entry is "" when no
        matching token was found.
    """
    ent1 = ""
    ent2 = ""
    prv_tok_dep = ""   # dependency tag of the previous non-punctuation token
    prv_tok_text = ""  # text of the previous non-punctuation token
    prefix = ""
    modifier = ""

    for tok in nlp(sent):
        # Punctuation neither contributes to an entity nor counts as "previous".
        if tok.dep_ == "punct":
            continue

        if tok.dep_ == "compound":
            prefix = tok.text
            # Chain two consecutive compound words: "<previous> <current>".
            if prv_tok_dep == "compound":
                prefix = prv_tok_text + " " + tok.text

        if tok.dep_.endswith("mod"):
            modifier = tok.text
            # A modifier directly preceded by a compound keeps that word too.
            if prv_tok_dep == "compound":
                modifier = prv_tok_text + " " + tok.text

        # Membership test on purpose: str.find() returns an index, not a bool.
        if "subj" in tok.dep_:
            # Join only non-empty parts so a missing prefix/modifier doesn't
            # leave a double space inside the entity.
            ent1 = " ".join(part for part in (modifier, prefix, tok.text) if part)
            # Reset so the subject's qualifiers don't leak into the object.
            prefix = ""
            modifier = ""

        if "obj" in tok.dep_:
            ent2 = " ".join(part for part in (modifier, prefix, tok.text) if part)

        prv_tok_dep = tok.dep_
        prv_tok_text = tok.text

    return [ent1.strip(), ent2.strip()]
def get_relation(sent):
    """Return the main relation phrase of a sentence.

    The relation span is matched in the spaCy parse as:

    * the ROOT token, optionally followed by
    * a preposition, then an agent, then an adjective (each optional, in
      that order).

    When the matcher reports several spans, the last one is used.

    Args:
        sent: The sentence to analyse.

    Returns:
        The text of the matched span, or "" when nothing matched
        (e.g. an empty sentence).
    """
    doc = nlp(sent)

    matcher = Matcher(nlp.vocab)
    pattern = [{'DEP': 'ROOT'},
               {'DEP': 'prep', 'OP': "?"},
               {'DEP': 'agent', 'OP': "?"},
               {'POS': 'ADJ', 'OP': "?"}]
    matcher.add('matching_pattern', patterns=[pattern])

    matches = matcher(doc)
    # With no match, matches[-1] would raise IndexError on the empty list.
    if not matches:
        return ""

    # Each match is a (match_id, start, end) triple of token offsets.
    _, start, end = matches[-1]
    return doc[start:end].text
# Gradio UI: one optional PIL image in, the caption text returned by predict() out.
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
# namespaces (dropped in Gradio 4); confirm the pinned gradio version still
# provides them before upgrading.
image_input = gr.inputs.Image(label="Please upload an image", type='pil', optional=True)  # not `input`: avoid shadowing the builtin
caption_output = gr.outputs.Textbox(type="text", label="Captions")
title = "Satellite Image Knowledge Extraction"
# predict() returns only the generated caption, so describe that rather than
# a knowledge triplet the app does not currently produce.
description = "Provide an image, receive back a generated caption."
interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    theme="grass",
    outputs=caption_output,
    title=title,
    description=description,
)
interface.launch(debug=True)