Spaces:
Runtime error
Runtime error
"""
Demo is based on https://scikit-learn.org/stable/auto_examples/applications/plot_stock_market.html
"""
import sys | |
import numpy as np | |
import pandas as pd | |
# Ticker symbol -> company display name for every stock in the demo.
# Each ticker is also the file name of a CSV fetched from scikit-learn's
# examples-data repository further down.
symbol_dict = dict(
    TOT="Total",
    XOM="Exxon",
    CVX="Chevron",
    COP="ConocoPhillips",
    VLO="Valero Energy",
    MSFT="Microsoft",
    IBM="IBM",
    TWX="Time Warner",
    CMCSA="Comcast",
    CVC="Cablevision",
    YHOO="Yahoo",
    DELL="Dell",
    HPQ="HP",
    AMZN="Amazon",
    TM="Toyota",
    CAJ="Canon",
    SNE="Sony",
    F="Ford",
    HMC="Honda",
    NAV="Navistar",
    NOC="Northrop Grumman",
    BA="Boeing",
    KO="Coca Cola",
    MMM="3M",
    MCD="McDonald's",
    PEP="Pepsi",
    K="Kellogg",
    UN="Unilever",
    MAR="Marriott",
    PG="Procter Gamble",
    CL="Colgate-Palmolive",
    GE="General Electrics",
    WFC="Wells Fargo",
    JPM="JPMorgan Chase",
    AIG="AIG",
    AXP="American express",
    BAC="Bank of America",
    GS="Goldman Sachs",
    AAPL="Apple",
    SAP="SAP",
    CSCO="Cisco",
    TXN="Texas Instruments",
    XRX="Xerox",
    WMT="Wal-Mart",
    HD="Home Depot",
    GSK="GlaxoSmithKline",
    PFE="Pfizer",
    SNY="Sanofi-Aventis",
    NVS="Novartis",
    KMB="Kimberly-Clark",
    R="Ryder",
    GD="General Dynamics",
    RTN="Raytheon",
    CVS="CVS",
    CAT="Caterpillar",
    DD="DuPont de Nemours",
)

# Sort the (ticker, name) pairs by ticker and split them into two parallel
# string arrays; both columns share one array, so they share one dtype.
_ticker_table = np.array(sorted(symbol_dict.items()))
symbols = _ticker_table[:, 0]
names = _ticker_table[:, 1]
# Download one CSV of daily quotes per ticker from scikit-learn's
# examples-data repository (requires network access when the script starts).
url = (
    "https://raw.githubusercontent.com/scikit-learn/examples-data/"
    "master/financial-data/{}.csv"
)
quotes = []
for symbol in symbols:
    print(f"Fetching quote history for {symbol!r}", file=sys.stderr)
    quotes.append(pd.read_csv(url.format(symbol)))

# One row per stock: stack each quote table's "open"/"close" column.
open_prices = np.vstack([quote["open"] for quote in quotes])
close_prices = np.vstack([quote["close"] for quote in quotes])

# The daily variations of the quotes are what carry the most information.
variation = close_prices - open_prices
from sklearn import covariance

# Sparse inverse-covariance (precision) estimation. The regularization strength
# is selected by cross-validation over 10 log-spaced candidates between
# 10**-1.5 and 10**1.
alphas = np.logspace(-1.5, 1, num=10)
edge_model = covariance.GraphicalLassoCV(alphas=alphas)

# Standardize the time series: using correlations rather than covariance is
# more efficient for structure recovery. `variation` has one row per stock, so
# after transposing each column of X is one stock, and dividing by the
# per-column standard deviation scales every stock to unit variance.
X = variation.copy().T
X /= X.std(axis=0)
edge_model.fit(X)
from sklearn import cluster

# Group the stocks with affinity propagation, feeding it the covariance matrix
# estimated by the graphical lasso. Only the per-stock labels are kept; the
# first return value is discarded.
_, labels = cluster.affinity_propagation(
    edge_model.covariance_,
    random_state=0,
)
# Largest cluster label; the UI below lists clusters 0..n_labels inclusive.
n_labels = np.max(labels)
# Finding a low-dimension embedding for visualization: find the best position of
# the nodes (the stocks) in 3D space. The figure is a plotly Scatter3d that
# reads x, y and z from embedding rows 0, 1 and 2, so three components are
# required (with two, embedding[2] raised IndexError).
from sklearn import manifold

node_position_model = manifold.LocallyLinearEmbedding(
    n_components=3, eigen_solver="dense", n_neighbors=6
)
# X.T has one row per stock; transposing the result gives one row per axis,
# i.e. embedding has shape (3, n_stocks).
embedding = node_position_model.fit_transform(X.T).T
import matplotlib.pyplot as plt | |
from matplotlib.collections import LineCollection | |
import plotly.graph_objs as go | |
def visualize_stocks():
    """Build an interactive 3D Plotly figure of the stock-market structure.

    Reads the module-level results computed above: ``edge_model`` (the fitted
    GraphicalLassoCV), ``embedding`` (node coordinates, one row per axis),
    ``labels`` (cluster label per stock) and ``names`` (company names).

    * Nodes: one marker per stock, coloured by cluster label with the
      "Viridis" colorscale and sized by ``35 * d**2`` where
      ``d = 1 / sqrt(diag(precision))``. Hovering shows the company name.
    * Edges: one line per pair of stocks whose absolute partial correlation
      exceeds 0.02. The line colour maps that value through the "Hot"
      colorscale, and the width is ``25 * |partial correlation|``.

    If ``embedding`` has fewer than three rows, the missing axes are filled
    with zeros so the 3D scatter can still be drawn.

    Returns:
        go.Figure: the figure shown by the Gradio ``Plot`` output.
    """
    # Normalize the precision matrix P to P_ij / sqrt(P_ii * P_jj):
    # the first product scales columns, the second scales rows.
    partial_correlations = edge_model.precision_.copy()
    d = 1 / np.sqrt(np.diag(partial_correlations))
    partial_correlations *= d
    partial_correlations *= d[:, np.newaxis]
    # Upper triangle only (k=1): each pair is considered once, no self-loops.
    non_zero = np.abs(np.triu(partial_correlations, k=1)) > 0.02

    # Scatter3d needs x, y and z for every node; zero-fill any axis the
    # embedding does not provide instead of failing with IndexError.
    n_axes = min(3, embedding.shape[0])
    coords = np.zeros((3, embedding.shape[1]))
    coords[:n_axes] = embedding[:n_axes]
    xs, ys, zs = coords

    # Plot the nodes using the coordinates of our embedding.
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=xs,
                y=ys,
                z=zs,
                mode="markers",
                marker=dict(size=35 * d**2, color=labels, colorscale="Viridis"),
                hovertext=names,
                hovertemplate="%{hovertext}<br>",
            )
        ]
    )

    # Plot the edges: one line trace per retained pair of stocks.
    start_idx, end_idx = np.where(non_zero)
    for start, stop in zip(start_idx, end_idx):
        strength = np.abs(partial_correlations[start, stop])
        fig.add_trace(
            go.Scatter3d(
                x=[xs[start], xs[stop]],
                y=[ys[start], ys[stop]],
                z=[zs[start], zs[stop]],
                mode="lines",
                line=dict(
                    color=strength,
                    colorscale="Hot",
                    width=10 * strength * 2.5,
                ),
                hoverinfo="none",  # no hover labels on the edge segments
                showlegend=False,
            )
        )
    return fig
import gradio as gr | |
title = " π Visualizing the stock market structure π"
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    # The stock count is taken from the symbol table so the text stays in sync.
    gr.Markdown(
        f" Data covers {len(symbols)} stocks over the period 2003 - 2008 <br>"
    )
    gr.Markdown(
        " Stocks that move together are grouped into the same cluster <br>"
    )
    gr.Markdown(
        " **[Demo is based on sklearn docs](https://scikit-learn.org/stable/auto_examples/applications/plot_stock_market.html)**"
    )
    # One line per cluster listing the companies assigned to it.
    for i in range(n_labels + 1):
        gr.Markdown(f"Cluster {i + 1}: {', '.join(names[labels == i])}")
    btn = gr.Button(value="Visualize")
    # Clicking renders the 3D figure returned by visualize_stocks().
    btn.click(
        visualize_stocks, outputs=gr.Plot(label="Visualizing stock into clusters")
    )
    gr.Markdown("## In progress")
demo.launch()