|
import streamlit as st |
|
import duckdb |
|
import pandas as pd |
|
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode |
|
from st_link_analysis import st_link_analysis, NodeStyle, EdgeStyle |
|
from graph_builder import StLinkBuilder |
|
|
|
|
|
# One row per node type in the event graph, given as the positional
# NodeStyle arguments: (label, colour, caption property, icon name).
_NODE_STYLE_SPECS = (
    ("EVENT", "#FF7F3E", "name", "description"),
    ("PERSON", "#4CAF50", "name", "person"),
    ("NAME", "#2A629A", "name", "badge"),
    ("ORGANIZATION", "#9C27B0", "name", "business"),
    ("LOCATION", "#2196F3", "name", "place"),
    ("THEME", "#FFC107", "name", "sell"),
    ("COUNT", "#795548", "name", "inventory"),
    ("AMOUNT", "#607D8B", "name", "wallet"),
)

# Styles handed to st_link_analysis, in the same order as the table above.
NODE_STYLES = [NodeStyle(*spec) for spec in _NODE_STYLE_SPECS]
|
|
|
|
|
# Directed relationship types drawn in the event graph; every edge type uses
# its "label" property as the caption.
EDGE_STYLES = [
    EdgeStyle(relation, caption="label", directed=True)
    for relation in ("MENTIONED_IN", "LOCATED_IN", "CATEGORIZED_AS")
]
|
|
|
|
|
# GKG record field names grouped into display sections. Insertion order is
# significant: sections are iterated (and thus displayed) in this order.
GDELT_CATEGORIES = dict([
    ("Metadata", ["GKGRECORDID", "DATE", "SourceCommonName",
                  "DocumentIdentifier", "V2.1Quotations", "tone"]),
    ("Persons", ["V2EnhancedPersons", "V1Persons"]),
    ("Organizations", ["V2EnhancedOrganizations", "V1Organizations"]),
    ("Locations", ["V2EnhancedLocations", "V1Locations"]),
    ("Themes", ["V2EnhancedThemes", "V1Themes"]),
    ("Names", ["V2.1AllNames"]),
    ("Counts", ["V2.1Counts", "V1Counts"]),
    ("Amounts", ["V2.1Amounts"]),
    ("V2GCAM", ["V2GCAM"]),
    ("V2.1EnhancedDates", ["V2.1EnhancedDates"]),
])
|
|
|
def initialize_db(
    parquet_glob="hf://datasets/dwb2023/gdelt-gkg-march2020-v2@~parquet/default/negative_tone/*.parquet",
):
    """Open a DuckDB connection and register the ``negative_tone`` view.

    Args:
        parquet_glob: Path or glob of the Parquet files backing the view.
            Defaults to the dataset's ``negative_tone`` Parquet files on the
            Hugging Face Hub (``hf://`` URL).

    Returns:
        The open DuckDB connection. The caller owns it and is responsible for
        closing it (it can be used as a context manager).

    Raises:
        Whatever DuckDB raises if the view cannot be created. The connection
        is closed before the error propagates, so it is never leaked.
    """
    con = duckdb.connect()
    # The path is embedded in a SQL string literal, so double any single
    # quotes to keep the literal well-formed.
    escaped_glob = parquet_glob.replace("'", "''")
    try:
        con.execute(f"""
            CREATE VIEW negative_tone AS (
                SELECT *
                FROM read_parquet('{escaped_glob}')
            );
        """)
    except Exception:
        # Don't leak the freshly opened connection on failure.
        con.close()
        raise
    return con
|
|
|
def fetch_data(con, source_filter=None,
               start_date=None, end_date=None, limit=50, include_all_columns=False):
    """Fetch rows from the ``negative_tone`` view, applying optional filters.

    Filter values are passed as bound query parameters; only the fixed column
    list and the integer LIMIT are inlined into the SQL text.

    Args:
        con: Open DuckDB connection on which the ``negative_tone`` view exists.
        source_filter: Case-insensitive substring matched against
            ``SourceCommonName`` (via ``ILIKE``). Falsy values skip the filter.
        start_date: Inclusive lower bound compared against ``DATE``. Falsy
            values skip the filter.
        end_date: Inclusive upper bound compared against ``DATE``. Falsy
            values skip the filter.
        limit: Maximum number of rows. Falsy values (0/None) mean no limit.
        include_all_columns: Select every column (``*``) instead of the
            summary subset.

    Returns:
        The query result as returned by ``fetchdf()`` (a pandas DataFrame),
        or an empty ``pd.DataFrame()`` if the query raised. In that case the
        error is also shown in the app via ``st.error``.
    """
    # NOTE(review): DATE is compared with the raw value passed in (the UI
    # supplies YYYYMMDD strings); confirm this matches the column's type and
    # format in the dataset.
    if include_all_columns:
        columns = "*"
    else:
        # Identifiers containing '.' must be double-quoted. A single-quoted
        # 'V2.1Quotations' is a string literal and would select that constant
        # text instead of the column.
        columns = (
            'GKGRECORDID, DATE, SourceCommonName, tone, DocumentIdentifier, '
            '"V2.1Quotations", SourceCollectionIdentifier'
        )

    conditions = []
    params = []
    if source_filter:
        conditions.append("SourceCommonName ILIKE ?")
        params.append(f"%{source_filter}%")
    if start_date:
        conditions.append("DATE >= ?")
        params.append(start_date)
    if end_date:
        conditions.append("DATE <= ?")
        params.append(end_date)

    query = f"SELECT {columns} FROM negative_tone WHERE TRUE"
    query += "".join(f" AND {condition}" for condition in conditions)
    if limit:
        # LIMIT is inlined rather than bound. int() guarantees that only a
        # number ever reaches the SQL text.
        query += f" LIMIT {int(limit)}"

    try:
        return con.execute(query, params).fetchdf()
    except Exception as e:
        # UI boundary: report the failure in the app instead of crashing it.
        st.error(f"Query execution failed: {str(e)}")
        return pd.DataFrame()
|
|
|
def _first_selected_row(selected):
    """Return the first selected grid row as a dict, or None if there is none.

    Handles a selection given either as a pandas DataFrame or as a list of
    row dicts. Any other value, including ``None``, counts as "no selection".
    """
    if isinstance(selected, pd.DataFrame):
        return None if selected.empty else selected.iloc[0].to_dict()
    if isinstance(selected, list) and selected:
        return selected[0]
    return None


def render_data_grid(df):
    """Render ``df`` in an interactive AgGrid and return the selected row.

    Every column of ``df`` is displayed, each with built-in filtering,
    sorting and resizing enabled. Single-row selection is enabled without
    checkboxes.

    Args:
        df: DataFrame to display.

    Returns:
        The selected row as a dict, or ``None`` when no row is selected.
    """
    st.subheader("Search and Filter Records")

    builder = GridOptionsBuilder.from_dataframe(df)
    builder.configure_default_column(filter=True, sortable=True, resizable=True)
    builder.configure_selection('single', use_checkbox=False)

    grid_response = AgGrid(
        df,
        gridOptions=builder.build(),
        update_mode=GridUpdateMode.SELECTION_CHANGED,
        height=400,
        fit_columns_on_grid_load=True,
    )

    # 'selected_rows' may arrive as a DataFrame or as a list of dicts;
    # normalise both forms to a single dict (or None).
    return _first_selected_row(grid_response.get('selected_rows'))
|
|
|
def render_graph(record):
    """Draw the entity graph for one GDELT record.

    A subheader naming the record's GKGRECORDID (or 'Unknown') is written
    first. The record is then wrapped in a one-row DataFrame, turned into
    graph elements by ``StLinkBuilder.build_graph`` and shown with
    ``st_link_analysis`` using the ``fcose`` layout and the module-level
    node and edge styles.

    Args:
        record: The record to draw. It must support ``.get`` (e.g. a pandas
            Series or a dict).

    Returns:
        Whatever ``st_link_analysis`` returns.
    """
    heading_id = record.get('GKGRECORDID', 'Unknown')
    st.subheader(f"Event Graph: {heading_id}")

    builder = StLinkBuilder()
    single_row = pd.DataFrame([record])
    elements = builder.build_graph(single_row)

    display_options = dict(
        layout="fcose",
        node_styles=NODE_STYLES,
        edge_styles=EDGE_STYLES,
    )
    return st_link_analysis(elements=elements, **display_options)
|
|
|
def render_raw_data(record):
    """Show the record's raw field values, grouped by GDELT_CATEGORIES.

    Each category gets its own expander. Inside it, every listed field that
    is present in ``record`` is shown as a bold name, its value as plain
    text, and a divider. Fields missing from ``record`` are skipped.
    """
    st.header("Full Record Details")
    for section, field_names in GDELT_CATEGORIES.items():
        with st.expander(section):
            present = [name for name in field_names if name in record]
            for name in present:
                st.markdown(f"**{name}:**")
                st.text(record[name])
                st.divider()
|
|
|
def main():
    """Run the COVID Event Graph Explorer page.

    Reads search filters from the sidebar and loads matching records from the
    ``negative_tone`` view in a single query. The records are listed in a
    selectable grid. Once a row is selected, its graph and raw field values
    are rendered.
    """
    st.title("COVID Event Graph Explorer")
    st.markdown("""
    **Interactive Event Graph Viewer**

    Filter and select individual COVID-19 event records to display their detailed graph representations. Analyze relationships between events and associated entities using the interactive graph below.
    """)

    # Source column -> display name for the summary grid.
    grid_columns = {
        'GKGRECORDID': 'ID',
        'DATE': 'Date',
        'SourceCommonName': 'Source',
        'tone': 'Tone',
        'DocumentIdentifier': 'Doc ID',
        'SourceCollectionIdentifier': 'Source Collection ID',
    }

    # Leaving the ``with`` block closes the DuckDB connection, so no explicit
    # close() is needed.
    with initialize_db() as con:
        with st.sidebar:
            st.header("Search Filters")
            source = st.text_input("Filter by source name")
            start_date = st.text_input("Start date (YYYYMMDD)", "20200314")
            end_date = st.text_input("End date (YYYYMMDD)", "20200315")
            limit = st.slider("Number of results to display", 10, 500, 100)

        # One query feeds both the grid and the detail views. Two separate
        # LIMIT queries without ORDER BY are not guaranteed to return the same
        # rows, so the selected ID could be missing from a second result set.
        df_full = fetch_data(
            con=con,
            source_filter=source,
            start_date=start_date,
            end_date=end_date,
            limit=limit,
            include_all_columns=True,
        )

        # fetch_data returns an empty DataFrame both when nothing matches and
        # when the query fails (it has already reported the error).
        if df_full.empty:
            st.warning("No matching records found.")
            return

        grid_df = df_full[list(grid_columns)].rename(columns=grid_columns)
        selected_row = render_data_grid(grid_df)
        if not selected_row:
            st.info("Use the grid filters above to search and select a record.")
            return

        matches = df_full[df_full['GKGRECORDID'] == selected_row['ID']]
        if matches.empty:
            st.warning("The selected record is not in the current results; please select again.")
            return

        full_record = matches.iloc[0]
        render_graph(full_record)
        render_raw_data(full_record)
|
|
|
# Streamlit executes the script as ``__main__``, so the app still starts under
# ``streamlit run``. Importing the module (e.g. from tests) no longer
# launches the UI.
if __name__ == "__main__":
    main()
|
|