Spaces:
Runtime error
VarshithaChennamsetti
committed on
Create patent_app.py
patent_app.py +81 -0
patent_app.py
ADDED
@@ -0,0 +1,81 @@
# Import statements
import streamlit as st
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification

import numpy as np
from datasets import load_dataset

# Torch and torch dataloader
import torch
from torch.utils.data import DataLoader

st.title('Patentability Decision App')
# Load the sample patent files (train/validation splits) from the HUPD dataset
dataset_dict = load_dataset('HUPD/hupd',
    name='sample',
    data_files="https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather",
    icpr_label=None,
    train_filing_start_date='2016-01-01',
    train_filing_end_date='2016-01-21',
    val_filing_start_date='2016-01-22',
    val_filing_end_date='2016-01-31',
)
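# Note: per the HUPD dataset card, this call builds 'train' and 'validation'
# splits from the filing-date windows given above; only 'validation' is used below.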
# Label-to-index mapping for the decision status field
decision_to_str = {'REJECTED': 0, 'ACCEPTED': 1, 'PENDING': 2, 'CONT-REJECTED': 3, 'CONT-ACCEPTED': 4, 'CONT-PENDING': 5}

# Helper function: map one example's decision label to its integer index
def map_decision_to_string(example):
    return {'decision': decision_to_str[example['decision']]}

# Re-label the decision field of the validation set
val_set = dataset_dict['validation'].map(map_decision_to_string)
# Remove all decisions the model was not trained on: keep only patents
# whose decision is REJECTED (0) or ACCEPTED (1)
val_set = val_set.filter(lambda e: e['decision'] <= 1)

# Display all patent numbers so the user can select a file
patent_num = st.selectbox("Select a patent based on its number", val_set['patent_number'])
# Get the abstract and claims data to predict.
# st.button is only True on the rerun in which it is clicked, so the click is
# stored in session state; otherwise the nested 'Predict!' button below could
# never be reached.
if patent_num and st.button('Get Data to predict!'):
    st.session_state['show_data'] = True

if st.session_state.get('show_data'):
    # Display the abstract and claims of the selected patent
    val_set = val_set.filter(lambda e: e['patent_number'] == patent_num)

    abstract_text = st.text_area('Abstract', val_set['abstract'][0])
    claims_text = st.text_area('Claims', val_set['claims'][0])
    # Predict on those texts
    if abstract_text and claims_text and st.button('Predict!'):
        # Model/tokenizer name or path to the fine-tuned model
        model_name_or_path = './models/'
        model_name = 'distilbert-base-uncased'
        # Tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Model
        model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)

        # Tokenize the claims section of the validation dataset for prediction
        _SECTION_ = 'claims'
        val_set = val_set.map(lambda e: tokenizer(e[_SECTION_], truncation=True, padding='max_length'), batched=True)
        val_set.set_format(type='torch', columns=['input_ids', 'attention_mask', 'decision'])
        # Create a dataloader and take only the first row of the first batch
        val_dataloader = DataLoader(val_set, batch_size=16)
        batch = next(iter(val_dataloader))
        inputs = batch['input_ids'][:1]        # slicing keeps the batch dimension: (1, seq_len)
        masks = batch['attention_mask'][:1]
        decisions = batch['decision'][:1]
        # Predict
        with torch.no_grad():
            outputs = model(input_ids=inputs, attention_mask=masks, labels=decisions).logits

        # Display prediction: index of the highest logit, mapped back to its label
        prediction = int(np.argmax(outputs.numpy(), axis=-1)[0])
        value = next(label for label, idx in decision_to_str.items() if idx == prediction)
        st.text('This is the predicted decision: ' + value)

        # Patentability score: softmax over the logits gives class probabilities
        probs = torch.softmax(outputs, dim=-1)
        st.text('Probability that it will be rejected : ' + str(probs[0][0].item() * 100))
        st.text('Probability that it will be accepted : ' + str(probs[0][1].item() * 100))
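A note on running this Space locally: the script expects a fine-tuned DistilBERT sequence-classification checkpoint under './models/' (the checkpoint is loaded here, never created, so it must be prepared separately). Assuming that checkpoint is in place and the imported dependencies are installed, the app starts with the standard Streamlit CLI:

    pip install streamlit transformers datasets torch numpy
    streamlit run patent_app.py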