File size: 3,442 Bytes
27df4ce
c1f44fc
27df4ce
 
50526d9
 
27df4ce
 
62446de
48f3abc
27df4ce
897869f
27df4ce
 
 
 
 
897869f
529bd38
 
 
 
 
 
27df4ce
 
50526d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddf0c20
a653ebc
27df4ce
 
04baa0a
27df4ce
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import re 
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
import spacy
from spacy.matcher import Matcher 
# Device on which the captioning model is placed at import time.
device='cpu'

# Pre/post-processor (image -> pixel tensor, token ids -> text) loaded from
# the "microsoft/git-base" checkpoint.
processor = AutoProcessor.from_pretrained("microsoft/git-base")
# Causal-LM captioning model loaded from "nkasmanoff/git-planet" and moved to `device`.
model = AutoModelForCausalLM.from_pretrained("nkasmanoff/git-planet").to(device)

# spaCy pipeline shared by get_entities() and get_relation() for parsing captions.
nlp = spacy.load('en_core_web_sm')

def predict(image, max_length=64, device='cpu'):
    """Caption an image and condense the caption into a knowledge triplet.

    The image is run through the module-level ``processor`` and ``model`` to
    generate a caption; the caption is then parsed with ``get_relation`` and
    ``get_entities`` and formatted as ``"<entity1>-<relation>-><entity2>"``.

    Args:
        image: Image accepted by ``processor(images=...)``, or ``None`` when
            the caller supplied no image (the UI input is optional).
        max_length: Maximum length passed to ``model.generate``.
        device: Device the pixel tensor is moved to. The model itself is
            placed on the module-level ``device``, so the two should agree.

    Returns:
        The formatted triplet string, or an empty string when ``image`` is
        ``None``.
    """
    if image is None:
        # Nothing was uploaded: there is nothing to caption, and passing None
        # on to the processor would only raise.
        return ""

    pixel_values = processor(images=image, return_tensors="pt").to(device).pixel_values
    generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)
    # batch_decode returns one string per sequence; a single image was given.
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    relation = get_relation(generated_caption)
    first_entity, second_entity = get_entities(generated_caption)

    return f"{first_entity}-{relation}->{second_entity}"


def get_entities(sent):
    """Extract a subject/object entity pair from a sentence.

    The sentence is parsed with the module-level ``nlp`` pipeline and scanned
    token by token. Compound words and modifiers seen before a subject or
    object token are attached to it as a prefix.

    Args:
        sent: Sentence text to parse.

    Returns:
        A two-element list ``[subject, object]``. Each entry is an empty
        string when no token with a matching dependency label was found; if
        several tokens match, the last one wins.
    """
    ent1 = ""
    ent2 = ""

    prv_tok_dep = ""  # dependency tag of previous token in the sentence
    prv_tok_text = ""  # previous token in the sentence

    prefix = ""
    modifier = ""

    for tok in nlp(sent):
        # Punctuation carries no information here and must not update the
        # "previous token" state either.
        if tok.dep_ == "punct":
            continue

        # Compound word: remember it, merged with a directly preceding
        # compound token if there was one.
        if tok.dep_ == "compound":
            prefix = tok.text
            if prv_tok_dep == "compound":
                prefix = prv_tok_text + " " + tok.text

        # Modifier (any label ending in "mod"): remember it, merged with a
        # directly preceding compound token if there was one.
        if tok.dep_.endswith("mod"):
            modifier = tok.text
            if prv_tok_dep == "compound":
                modifier = prv_tok_text + " " + tok.text

        # Subject token. Use substring membership: `str.find(...) == True`
        # compared the returned index with True and therefore only matched
        # labels where "subj" starts at offset 1.
        if "subj" in tok.dep_:
            # Join only the non-empty parts so an absent modifier/prefix does
            # not leave doubled spaces inside the entity.
            ent1 = " ".join(filter(None, (modifier, prefix, tok.text)))
            prefix = ""
            modifier = ""

        # Object token (same membership test as for subjects). The prefix and
        # modifier are intentionally not reset here, as in the original.
        if "obj" in tok.dep_:
            ent2 = " ".join(filter(None, (modifier, prefix, tok.text)))

        # Track the current token for the compound checks on the next one.
        prv_tok_dep = tok.dep_
        prv_tok_text = tok.text

    return [ent1.strip(), ent2.strip()]




def get_relation(sent):
    """Return the relation (predicate) phrase of a sentence.

    The sentence is parsed with the module-level ``nlp`` pipeline and matched
    against a pattern consisting of the ROOT token optionally followed by a
    preposition, an agent and an adjective. The text of the last match
    reported by the matcher is returned.

    Args:
        sent: Sentence text to parse.

    Returns:
        The matched span's text, or an empty string when nothing matches
        (for example when ``sent`` is empty).
    """
    doc = nlp(sent)

    # Matcher class object
    matcher = Matcher(nlp.vocab)

    # ROOT token, then optional preposition / agent / adjective.
    pattern = [
        {'DEP': 'ROOT'},
        {'DEP': 'prep', 'OP': "?"},
        {'DEP': 'agent', 'OP': "?"},
        {'POS': 'ADJ', 'OP': "?"},
    ]

    matcher.add('matching_pattern', patterns=[pattern])
    matches = matcher(doc)
    if not matches:
        # Indexing the last match of an empty result would raise IndexError.
        return ""

    # Each match is (match_id, start, end); use the last one reported.
    _, start, end = matches[-1]
    return doc[start:end].text



# UI components for the demo.
# NOTE: `input` shadows the Python builtin of the same name for the rest of this module.
# NOTE(review): `gr.inputs` / `gr.outputs` look like the legacy Gradio component
# namespaces — confirm the installed gradio version still provides them.
# The image is delivered to `predict` as a PIL image and is optional, so the
# callback may be invoked without an upload.
input = gr.inputs.Image(label="Please upload a remote sensing image", type = 'pil', optional=True)
output = gr.outputs.Textbox(type="text",label="Captions")


title = "Satellite Image Captioning"

# Wire `predict` to the image input and text output defined above.
interface = gr.Interface(
        fn=predict,
        inputs = input,
        theme="grass",
        outputs=output,
        title=title,
    )
# Launched unconditionally at module level (there is no `__main__` guard),
# so importing this file also starts the app.
interface.launch(debug=True)