Spaces:
Runtime error
Runtime error
import re | |
import gradio as gr | |
from PIL import Image | |
from transformers import AutoProcessor, AutoModelForCausalLM | |
import spacy | |
from spacy.matcher import Matcher | |
# All inference in this app runs on the CPU.
device = "cpu"

# The processor comes from the "microsoft/git-base" checkpoint, while the
# captioning weights are loaded from "nkasmanoff/sky-scribe".
processor = AutoProcessor.from_pretrained("microsoft/git-base")
model = AutoModelForCausalLM.from_pretrained("nkasmanoff/sky-scribe")
model = model.to(device)

# spaCy pipeline shared by get_entities() and get_relation().
nlp = spacy.load("en_core_web_sm")
def predict(image, max_length=50, device='cpu'):
    """Generate a caption for an image.

    Only the caption is returned; the knowledge-triplet helpers
    (``get_entities`` / ``get_relation``) are not applied to it here.

    Args:
        image: Image handed to the module-level ``processor`` (the UI
            supplies a PIL image). ``None`` means no image was provided.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        device: Device the processed inputs are moved to before generation.
            It should match the device the module-level ``model`` lives on.

    Returns:
        The decoded caption string, or ``""`` when ``image`` is ``None``.
    """
    # The upload is optional in the UI, so there may be nothing to caption;
    # return an empty caption instead of crashing inside the processor.
    if image is None:
        return ""

    inputs = processor(images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        pixel_values=inputs.pixel_values,
        max_length=max_length,
    )
    # batch_decode yields one string per generated sequence; a single image
    # was submitted, so the first entry is the caption.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
def get_entities(sent):
    """Extract a subject/object entity pair from a sentence.

    The sentence is parsed with the module-level ``nlp`` pipeline and its
    tokens are scanned in order. Compound words and modifiers seen along the
    way are remembered and prepended to the next subject or object token.

    Args:
        sent: Sentence text to parse.

    Returns:
        A two-element list ``[ent1, ent2]``. ``ent1`` is built from the last
        token whose dependency label contains ``"subj"`` and ``ent2`` from
        the last token whose label contains ``"obj"``. Either element is
        ``""`` when no such token was found.
    """
    ent1 = ""
    ent2 = ""

    prv_tok_dep = ""   # dependency tag of the previous non-punctuation token
    prv_tok_text = ""  # text of the previous non-punctuation token

    prefix = ""
    modifier = ""

    for tok in nlp(sent):
        # Punctuation contributes nothing and must not become the "previous"
        # token either, so skip it entirely.
        if tok.dep_ == "punct":
            continue

        # Compound word: remember it, merged with a directly preceding
        # compound if there was one.
        if tok.dep_ == "compound":
            prefix = tok.text
            if prv_tok_dep == "compound":
                prefix = prv_tok_text + " " + tok.text

        # Modifier: same treatment as compounds.
        if tok.dep_.endswith("mod"):
            modifier = tok.text
            if prv_tok_dep == "compound":
                modifier = prv_tok_text + " " + tok.text

        # Subject: substring membership rather than `find(...) == True`,
        # which only held when the match happened to start at index 1.
        if "subj" in tok.dep_:
            # Join only the non-empty parts so a missing modifier or prefix
            # does not leave stray spaces inside the entity.
            ent1 = " ".join(filter(None, (modifier, prefix, tok.text)))
            prefix = ""
            modifier = ""

        # Object.
        if "obj" in tok.dep_:
            ent2 = " ".join(filter(None, (modifier, prefix, tok.text)))

        prv_tok_dep = tok.dep_
        prv_tok_text = tok.text

    return [ent1.strip(), ent2.strip()]
def get_relation(sent):
    """Return the relation phrase of a sentence.

    The sentence is parsed with the module-level ``nlp`` pipeline and matched
    against a pattern consisting of the ROOT token optionally followed by a
    ``prep`` token, an ``agent`` token and an adjective.

    Args:
        sent: Sentence text to parse.

    Returns:
        The text of the last match reported by the matcher, or ``""`` when
        the pattern does not match at all (for example an empty sentence).
    """
    doc = nlp(sent)

    matcher = Matcher(nlp.vocab)
    pattern = [
        {'DEP': 'ROOT'},
        {'DEP': 'prep', 'OP': "?"},
        {'DEP': 'agent', 'OP': "?"},
        {'POS': 'ADJ', 'OP': "?"},
    ]
    matcher.add('matching_pattern', patterns=[pattern])

    matches = matcher(doc)
    # Indexing the last match of an empty result would raise IndexError.
    if not matches:
        return ""

    # Each match is a (match_id, start, end) triple; use the last one.
    _, start, end = matches[-1]
    return doc[start:end].text
# UI wiring: one image in, one text box out, served by predict().
#
# Components are created through the top-level gr.Image / gr.Textbox classes;
# the legacy gr.inputs / gr.outputs namespaces (and their `optional` argument)
# are deprecated and no longer available in current Gradio releases.
# NOTE(review): `input` shadows the builtin of the same name; the module-level
# names are kept unchanged here.
input = gr.Image(label="Please upload an image", type="pil")
output = gr.Textbox(label="Captions")

title = "Satellite Image Knowledge Extraction"
description = "Provide an image, receive back a triplet that can be used to form a knowledge graph."

interface = gr.Interface(
    fn=predict,
    inputs=input,
    theme="grass",
    outputs=output,
    title=title,
    description=description,
)

interface.launch(debug=True)