Update app.py
app.py CHANGED
@@ -8,6 +8,27 @@ from sklearn.preprocessing import StandardScaler
 from sklearn.model_selection import train_test_split
 import gradio as gr
 import os
+import time
+from fastapi import FastAPI, BackgroundTasks
+from fastapi.middleware.cors import CORSMiddleware
+import asyncio
+
+# FastAPI app
+app = FastAPI()
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Global variables
+model = None
+scaler = None
+latest_report = "Initializing..."
 
 # Define the Dataset class
 class BankNiftyDataset(Dataset):
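A note on the CORS setup added above: allow_origins=["*"] opens the API to any site. If the Space is only ever called from one known frontend, the policy can be narrowed. The sketch below is illustrative only and is not part of this commit; the origin shown is a placeholder.

# Illustrative sketch, not part of this commit: narrow CORS to a known frontend.
# The origin below is a placeholder, not a real deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://your-frontend.example.com"],  # placeholder origin
    allow_credentials=True,
    allow_methods=["GET"],
    allow_headers=["*"],
)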
@@ -43,7 +64,8 @@ class LSTMModel(nn.Module):
         return out
 
 # Function to train the model
-def train_model(
+def train_model(train_loader, val_loader, num_epochs=10):
+    global model
     criterion = nn.MSELoss()
     optimizer = optim.Adam(model.parameters(), lr=0.001)
 
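The new train_model signature drops the model parameter and reaches for the module-level global instead. Passing the model in and returning it keeps the trainer usable outside this file. The following is only a sketch of that alternative, assuming each batch is a (features, targets) pair; the validation pass is omitted for brevity.

# Sketch of an alternative signature, not this commit's code: take the model as an
# argument and return it instead of mutating the global. Assumes each batch yields
# a (features, targets) pair; validation is omitted for brevity.
import torch.nn as nn
import torch.optim as optim

def train_model(model, train_loader, num_epochs=10, lr=0.001):
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()
    for _ in range(num_epochs):
        for features, targets in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(features), targets)
            loss.backward()
            optimizer.step()
    return model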
@@ -97,6 +119,8 @@ def generate_report(predictions, actual_values, signals):
 
 # Function to process data and make predictions
 def predict():
+    global model, scaler, latest_report
+
     # Load the pre-existing CSV file
     csv_path = 'BANKNIFTY_OPTION_CHAIN_data.csv'
     if not os.path.exists(csv_path):
@@ -104,8 +128,11 @@ def predict():
 
     # Load and preprocess data
    data = pd.read_csv(csv_path)
-    scaler
-
+    if scaler is None:
+        scaler = StandardScaler()
+        scaled_data = scaler.fit_transform(data[['open', 'high', 'low', 'close', 'volume', 'oi']])
+    else:
+        scaled_data = scaler.transform(data[['open', 'high', 'low', 'close', 'volume', 'oi']])
     data[['open', 'high', 'low', 'close', 'volume', 'oi']] = scaled_data
 
     # Split data
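Keeping the fitted scaler in a global means it is refit whenever the process restarts, and fit on whichever CSV happens to be present at that moment. If reuse across restarts matters, the fitted scaler can be persisted. The sketch below is an assumption: the helper and the scaler.joblib path are hypothetical and not in app.py.

# Hypothetical helper, not in app.py: persist the fitted scaler so a restarted
# process reuses it instead of refitting. 'scaler.joblib' is a placeholder path.
import os
import joblib
from sklearn.preprocessing import StandardScaler

SCALER_PATH = 'scaler.joblib'  # placeholder filename

def load_or_fit_scaler(frame, columns):
    # Reuse a previously fitted scaler when available; otherwise fit and save one.
    if os.path.exists(SCALER_PATH):
        scaler = joblib.load(SCALER_PATH)
        return scaler, scaler.transform(frame[columns])
    scaler = StandardScaler()
    scaled = scaler.fit_transform(frame[columns])
    joblib.dump(scaler, SCALER_PATH)
    return scaler, scaled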
@@ -120,11 +147,13 @@ def predict():
     val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
 
     # Initialize and train the model
-
-
-
-
-
+    if model is None:
+        input_dim = 6
+        hidden_dim = 64
+        output_dim = len(target_cols)
+        model = LSTMModel(input_dim, hidden_dim, output_dim)
+
+    train_model(train_loader, val_loader)
 
     # Make predictions
     model.eval()
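The None check above avoids re-creating the LSTM on every call, but train_model still runs a full retrain each time predict() executes, and a restarted process starts from scratch. One way to avoid redundant work is to checkpoint the weights. The helpers and checkpoint path below are hypothetical and not in app.py; LSTMModel and the dimensions (input_dim=6, hidden_dim=64, output_dim=len(target_cols)) come from the hunk above.

# Hypothetical checkpoint helpers, not in app.py. CHECKPOINT_PATH is a placeholder.
import os
import torch

CHECKPOINT_PATH = 'lstm_checkpoint.pt'  # placeholder filename

def load_or_init_model(input_dim, hidden_dim, output_dim):
    # Build the LSTM (LSTMModel is defined in app.py) and load saved weights if any.
    model = LSTMModel(input_dim, hidden_dim, output_dim)
    if os.path.exists(CHECKPOINT_PATH):
        model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))
    return model

def save_model(model):
    # Persist only the weights so the next process can skip or resume training.
    torch.save(model.state_dict(), CHECKPOINT_PATH)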
@@ -138,18 +167,40 @@ def predict():
 
     # Generate signals and report
     signals = generate_signals(predictions, actual_values)
-
+    latest_report = generate_report(predictions, actual_values, signals)
 
-    return
+    return latest_report
+
+# Background task to update the model and report
+async def update_model_and_report():
+    global latest_report
+    while True:
+        latest_report = predict()
+        await asyncio.sleep(3600)  # Update every hour
+
+# Startup event to begin the background task
+@app.on_event("startup")
+async def startup_event():
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(update_model_and_report)
+    await background_tasks()
+
+# Gradio interface
+def gradio_interface():
+    return latest_report
 
-# Set up the Gradio interface
 iface = gr.Interface(
-    fn=
+    fn=gradio_interface,
     inputs=None,
-    outputs=gr.Textbox(label="Prediction Report"),
+    outputs=gr.Textbox(label="Latest Prediction Report"),
     title="BankNifty Option Chain Predictor",
-    description="
+    description="This app automatically generates and updates predictions and trading signals based on the latest BankNifty option chain data."
 )
 
-#
-
+# Combine FastAPI and Gradio
+app = gr.mount_gradio_app(app, iface, path="/")
+
+# Run the FastAPI app
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
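One caution about the startup handler in the last hunk: FastAPI's BackgroundTasks is designed for work that runs after a response is sent, and awaiting background_tasks() inside the startup event runs update_model_and_report inline, so its infinite loop keeps startup from completing. A common alternative is to hand the loop to the event loop directly. The sketch below is an alternative, not what this commit does; asyncio.to_thread requires Python 3.9+ and is used here so the synchronous predict() does not block the server while it retrains.

# Alternative sketch, not this commit's code: schedule the refresh loop on the
# event loop so startup returns immediately, and run the blocking predict() in a
# worker thread (asyncio.to_thread requires Python 3.9+).
async def update_model_and_report():
    global latest_report
    while True:
        latest_report = await asyncio.to_thread(predict)
        await asyncio.sleep(3600)  # refresh every hour

@app.on_event("startup")
async def startup_event():
    # Fire-and-forget: the loop runs in the background for the life of the app.
    asyncio.create_task(update_model_and_report())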