Create app.py
app.py
ADDED
@@ -0,0 +1,1281 @@
# Install TA-Lib (see instructions above) then: pip install TA-Lib
import ccxt
import numpy as np
import pandas as pd
import time
from sklearn.neighbors import KNeighborsClassifier
from scipy.linalg import svd
import gradio as gr
import concurrent.futures
import traceback
from datetime import datetime, timezone, timedelta
import logging
import sys
import talib  # Import TA-Lib
import threading

# --- Setup Logging ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - [%(threadName)s:%(funcName)s] - %(message)s',
    stream=sys.stdout
)
logging.getLogger().handlers[0].flush = sys.stdout.flush

# --- Parameters ---
L = 10
LAG = 11
MINUTES_PER_HOUR = 60
PREDICTION_WINDOW_HOURS = 2
TRAINING_WINDOW_HOURS = 12
TOTAL_WINDOW_HOURS = TRAINING_WINDOW_HOURS + PREDICTION_WINDOW_HOURS
K = TRAINING_WINDOW_HOURS * MINUTES_PER_HOUR  # 720
WINDOW = TOTAL_WINDOW_HOURS * MINUTES_PER_HOUR  # 840
FEATURES = ['open', 'high', 'low', 'close', 'volume']
D = 5
OVERLAP_STEP = 60
MIN_TRAINING_EXAMPLES = 20
MAX_COINS_TO_DISPLAY = 10
USE_SYNTHETIC_DATA_FOR_LOW_VOLUME = False
NUM_WORKERS_TRAINING = 4
NUM_WORKERS_PREDICTION = 10

# --- TA & Risk Parameters ---
TA_DATA_POINTS = 200  # Candles needed for TA calculation
RSI_PERIOD = 14
MACD_FAST = 12
MACD_SLOW = 26
MACD_SIGNAL = 9
ATR_PERIOD = 14
CONFIDENCE_THRESHOLD = 0.65  # Min confidence for Rise signal
TP1_ATR_MULTIPLIER = 1.5
TP2_ATR_MULTIPLIER = 3.0
SL_ATR_MULTIPLIER = 1.0
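
# Editor's note (sketch derived from the constants above, not original code): the model
# trains on K = TRAINING_WINDOW_HOURS * 60 = 720 one-minute candles and is labelled by
# what the close does over the following PREDICTION_WINDOW_HOURS = 2 hours, so each
# labelled window spans WINDOW = 840 candles and consecutive windows start
# OVERLAP_STEP = 60 minutes apart.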
# --- CCXT Initialization ---
try:
    exchange = ccxt.bitget({
        'enableRateLimit': True,
        'rateLimit': 1100,
        'timeout': 45000,
        'options': {'adjustForTimeDifference': True}
    })
    logging.info(f"Initialized {exchange.id} exchange.")
except Exception as e:
    logging.exception("FATAL: Could not initialize CCXT exchange.")
    sys.exit()

# --- Global Caches and Variables ---
markets_cache = None
last_markets_update = None
data_cache = {}
trained_models = {}
last_update_time = datetime.now(timezone.utc)
# --- Functions ---

def format_datetime(dt, default="N/A"):
    # (Keep this function as is)
    if pd.isna(dt) or dt is None:
        return default
    try:
        if isinstance(dt, (datetime, pd.Timestamp)):
            if dt.tzinfo is None:
                dt = dt.replace(tzinfo=timezone.utc)
            return dt.strftime('%Y-%m-%d %H:%M:%S %Z')
        else:
            return str(dt)
    except Exception:
        return default
def get_all_usdt_pairs():
    # (Keep this function as is - no changes needed)
    global markets_cache, last_markets_update
    current_time = time.time()
    cache_duration = 3600  # 1 hour

    if markets_cache is not None and last_markets_update is not None:
        if current_time - last_markets_update < cache_duration:
            logging.info("Using cached markets list.")
            if isinstance(markets_cache, list) and markets_cache:
                return markets_cache
            else:
                logging.warning("Cached market list was invalid, fetching fresh.")

    logging.info("Fetching markets from Bitget...")
    try:
        exchange.load_markets(reload=True)
        all_symbols = list(exchange.markets.keys())
        usdt_pairs = [
            symbol for symbol in all_symbols
            if isinstance(symbol, str)
            and symbol.endswith('/USDT')
            and exchange.markets.get(symbol, {}).get('active', False)
            and exchange.markets.get(symbol, {}).get('spot', False)
            and 'LEVERAGED' not in exchange.markets.get(symbol, {}).get('type', 'spot').upper()
            and not exchange.markets.get(symbol, {}).get('inverse', False)
        ]
        logging.info(f"Found {len(usdt_pairs)} active USDT spot pairs initially.")
        if not usdt_pairs:
            logging.warning("No active USDT spot pairs found.")
            return ['BTC/USDT', 'ETH/USDT', 'SOL/USDT']

        logging.info(f"Fetching tickers for {len(usdt_pairs)} pairs for volume sorting...")
        volumes = {}
        symbols_to_fetch = usdt_pairs[:]
        fetched_tickers = {}
        try:
            if exchange.has['fetchTickers']:
                batch_size_tickers = 100
                for i in range(0, len(symbols_to_fetch), batch_size_tickers):
                    batch_symbols = symbols_to_fetch[i:i+batch_size_tickers]
                    logging.info(f"Fetching ticker batch {i//batch_size_tickers + 1}/{(len(symbols_to_fetch) + batch_size_tickers - 1)//batch_size_tickers}...")
                    retries = 2
                    for attempt in range(retries):
                        try:
                            batch_tickers = exchange.fetch_tickers(symbols=batch_symbols)
                            fetched_tickers.update(batch_tickers)
                            time.sleep(exchange.rateLimit / 1000 * 1.5)  # Add delay
                            break
                        except (ccxt.RequestTimeout, ccxt.NetworkError) as e_timeout:
                            logging.warning(f"Ticker fetch timeout/network error on attempt {attempt+1}/{retries}: {e_timeout}, retrying after delay...")
                            time.sleep(3 * (attempt + 1))
                        except ccxt.RateLimitExceeded:
                            logging.warning("Rate limit exceeded fetching tickers, sleeping...")
                            time.sleep(10 * (attempt + 1))  # Longer sleep for rate limit
                        except Exception as e_ticker:
                            logging.error(f"Error fetching ticker batch (attempt {attempt+1}): {e_ticker}")
                            if attempt == retries - 1: raise  # Rethrow last error
                            time.sleep(2 * (attempt + 1))

                logging.info(f"Fetched {len(fetched_tickers)} tickers using fetchTickers.")
            else:
                raise ccxt.NotSupported("fetchTickers not supported/enabled. Volume sorting requires it.")

        except Exception as e:
            logging.exception(f"Could not fetch tickers for volume sorting: {e}. Volume sorting unavailable.")
            markets_cache = usdt_pairs[:MAX_COINS_TO_DISPLAY]
            last_markets_update = current_time
            logging.warning(f"Returning top {len(markets_cache)} unsorted pairs due to ticker error.")
            return markets_cache

        for symbol, ticker in fetched_tickers.items():
            try:
                quote_volume = ticker.get('info', {}).get('quoteVolume')  # Prefer quoteVolume if available
                last_price = ticker.get('last')
                base_volume = ticker.get('baseVolume')

                # Ensure values are convertible to float before calculation
                valid_last = last_price is not None
                valid_base = base_volume is not None
                valid_quote = quote_volume is not None

                if valid_quote:
                    volumes[symbol] = float(quote_volume)
                elif valid_base and valid_last:
                    volumes[symbol] = float(base_volume) * float(last_price)
                else:
                    volumes[symbol] = 0
            except (TypeError, ValueError, KeyError, AttributeError) as e:
                logging.warning(f"Could not parse volume/price for {symbol} from ticker: {ticker}. Error: {e}")
                volumes[symbol] = 0

        valid_volume_pairs = {k: v for k, v in volumes.items() if v > 0}
        logging.info(f"Found {len(valid_volume_pairs)} pairs with non-zero volume.")

        if not valid_volume_pairs:
            logging.warning("No pairs with valid volume found. Returning default list.")
            return ['BTC/USDT', 'ETH/USDT', 'SOL/USDT']

        sorted_pairs = sorted(valid_volume_pairs.items(), key=lambda item: item[1], reverse=True)
        num_pairs_to_take = min(MAX_COINS_TO_DISPLAY, len(sorted_pairs))
        top_pairs = [pair[0] for pair in sorted_pairs[:num_pairs_to_take]]
        logging.info(f"Selected Top {len(top_pairs)} pairs by volume. Top 5: {[p[0] for p in sorted_pairs[:5]]}")

        markets_cache = top_pairs
        last_markets_update = current_time
        return top_pairs

    except ccxt.NetworkError as e:
        logging.error(f"Network error getting USDT pairs: {e}")
    except ccxt.ExchangeError as e:
        logging.error(f"Exchange error getting USDT pairs: {e}")
    except Exception as e:
        logging.exception("General error getting USDT pairs.")

    logging.warning("Error fetching markets, returning default fallback list.")
    return ['BTC/USDT', 'ETH/USDT', 'SOL/USDT', 'BNB/USDT', 'XRP/USDT']
def clean_and_process_ohlcv(ohlcv_list, symbol, expected_candles):
    # (Keep this function as is - no changes needed)
    if not ohlcv_list:
        return None
    try:
        df = pd.DataFrame(ohlcv_list, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
        initial_len = len(df)
        if initial_len == 0: return None

        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
        df = df.drop_duplicates(subset=['timestamp'])
        df = df.sort_values('timestamp')
        len_after_dupes = len(df)

        numeric_cols = ['open', 'high', 'low', 'close', 'volume']
        for col in numeric_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')

        # Drop rows with NaN in essential price/volume features needed for TA-Lib
        df = df.dropna(subset=numeric_cols)
        len_after_na = len(df)

        df.reset_index(drop=True, inplace=True)

        logging.debug(f"Data cleaning for {symbol}: Initial Fetched={initial_len}, AfterDupes={len_after_dupes}, AfterNA={len_after_na}")

        if len(df) >= expected_candles:
            final_df = df.iloc[-expected_candles:].copy()  # Take the most recent ones
            return final_df
        else:
            return None

    except Exception as e:
        logging.exception(f"Error processing DataFrame for {symbol}")
        return None
def fetch_historical_data(symbol, timeframe='1m', total_candles=WINDOW):
    # (Keep this function as is - no changes needed)
    cache_key = f"{symbol}_{timeframe}_{total_candles}"
    current_time = time.time()
    cache_validity_seconds = 300  # 5 minutes

    if cache_key in data_cache:
        cache_time, cached_data = data_cache[cache_key]
        if current_time - cache_time < cache_validity_seconds:
            if isinstance(cached_data, pd.DataFrame) and len(cached_data) == total_candles:
                logging.debug(f"Using valid cached data for {symbol} ({len(cached_data)} candles)")
                return cached_data.copy()
            else:
                logging.warning(f"Cache for {symbol} invalid or wrong size ({len(cached_data) if isinstance(cached_data, pd.DataFrame) else 'N/A'} vs {total_candles}), fetching fresh.")
                if cache_key in data_cache: del data_cache[cache_key]

    if not exchange.has['fetchOHLCV']:
        logging.error(f"Exchange {exchange.id} does not support fetchOHLCV.")
        return None

    logging.debug(f"Fetching {total_candles} candles for {symbol} (timeframe: {timeframe})")
    final_df = None
    fetch_start_time = time.time()
    duration_ms = exchange.parse_timeframe(timeframe) * 1000
    now_ms = exchange.milliseconds()

    # --- Strategy 1: Try Single Large Fetch ---
    single_fetch_limit = total_candles + 200  # Buffer
    single_fetch_since = now_ms - single_fetch_limit * duration_ms
    try:
        ohlcv_list = exchange.fetch_ohlcv(symbol, timeframe, limit=single_fetch_limit, since=single_fetch_since)
        if ohlcv_list:
            processed_df = clean_and_process_ohlcv(ohlcv_list, symbol, total_candles)
            if processed_df is not None and len(processed_df) == total_candles:
                final_df = processed_df
    except ccxt.RateLimitExceeded as e:
        logging.warning(f"Rate limit hit during single fetch for {symbol}, falling back: {e}")
        time.sleep(5)
    except (ccxt.RequestTimeout, ccxt.NetworkError) as e:
        logging.warning(f"Timeout/Network error during single fetch for {symbol}, falling back: {e}")
        time.sleep(2)
    except ccxt.ExchangeNotAvailable as e:
        logging.error(f"Exchange not available during fetch for {symbol}: {e}")
        return None
    except ccxt.AuthenticationError as e:
        logging.error(f"Authentication error fetching {symbol}: {e}")
        return None
    except ccxt.ExchangeError as e:
        logging.warning(f"Exchange error during single fetch for {symbol}, falling back: {e}")
    except Exception as e:
        logging.exception(f"Unexpected error during single fetch for {symbol}, falling back.")

    # --- Strategy 2: Fallback to Iterative Chunking ---
    if final_df is None:
        logging.debug(f"Falling back to iterative chunk fetching for {symbol}.")
        limit_per_call = exchange.safe_integer(exchange.limits.get('fetchOHLCV', {}), 'max', 1000)
        limit_per_call = min(limit_per_call, 1000)
        all_ohlcv_chunks = []
        required_start_time_ms = now_ms - (total_candles + 5) * duration_ms
        current_chunk_end_time_ms = now_ms
        max_chunk_attempts = 15
        attempts = 0

        while attempts < max_chunk_attempts:
            attempts += 1
            oldest_ts_in_hand = all_ohlcv_chunks[0][0] if all_ohlcv_chunks else current_chunk_end_time_ms
            if oldest_ts_in_hand <= required_start_time_ms:
                logging.debug(f"Chunking: Collected enough historical range for {symbol}.")
                break

            fetch_limit = limit_per_call
            chunk_fetch_since = oldest_ts_in_hand - fetch_limit * duration_ms
            params = {}
            try:
                ohlcv_chunk = exchange.fetch_ohlcv(symbol, timeframe, since=chunk_fetch_since, limit=fetch_limit, params=params)
                if not ohlcv_chunk:
                    logging.debug(f"Chunking: No more data received for {symbol} from API.")
                    break

                new_chunk = [c for c in ohlcv_chunk if c[0] < oldest_ts_in_hand]
                if not new_chunk:
                    break

                new_chunk.sort(key=lambda x: x[0])
                all_ohlcv_chunks = new_chunk + all_ohlcv_chunks

                if len(new_chunk) < limit_per_call // 20 and attempts > 5:
                    logging.warning(f"Chunking: Received very few new candles ({len(new_chunk)}) repeatedly for {symbol}.")
                    break
                time.sleep(exchange.rateLimit / 1000 * 1.1)

            except ccxt.RateLimitExceeded as e:
                logging.warning(f"Rate limit hit during chunking for {symbol}, sleeping 10s: {e}")
                time.sleep(10 * (attempts / 3 + 1))
            except (ccxt.NetworkError, ccxt.RequestTimeout) as e:
                logging.error(f"Network/Timeout error during chunking for {symbol}: {e}. Stopping.")
                break
            except ccxt.ExchangeError as e:
                logging.error(f"Exchange error during chunking for {symbol}: {e}. Stopping.")
                break
            except Exception as e:
                logging.exception(f"Generic error during chunking fetch for {symbol}")
                break

        if attempts >= max_chunk_attempts:
            logging.warning(f"Max chunk fetch attempts reached for {symbol}.")

        if all_ohlcv_chunks:
            processed_df = clean_and_process_ohlcv(all_ohlcv_chunks, symbol, total_candles)
            if processed_df is not None and len(processed_df) == total_candles:
                final_df = processed_df
        else:
            logging.error(f"No data obtained from chunk fetching for {symbol}.")

    # --- Final Check and Caching ---
    if final_df is not None and len(final_df) == total_candles:
        expected_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
        if all(col in final_df.columns for col in expected_cols):
            data_cache[cache_key] = (current_time, final_df.copy())
            return final_df
        else:
            logging.error(f"Final DataFrame for {symbol} missing expected columns. Won't cache.")
            return None
    else:
        logging.error(f"Failed to fetch exactly {total_candles} candles for {symbol}. Found: {len(final_df) if final_df is not None else 0}")
        return None
# --- Embedding, LLT, Normalize, Training Prep (Largely unchanged) ---
# Keep create_embedding, llt_transform, normalize_data, prepare_training_data, train_model
# as they don't depend on the TA library choice.

def create_embedding(data, l=L, lag=LAG):
    # (Keep this function as is)
    n = len(data)
    rows = n - (l - 1) * lag
    if rows <= 0:
        logging.debug(f"Cannot create embedding: data length {n} too short for L={l}, Lag={lag}")
        return np.array([])
    A = np.zeros((rows, l))
    try:
        for t in range(rows):
            indices = t + np.arange(l) * lag
            A[t] = data[indices]
        return A
    except IndexError as e:
        logging.error(f"IndexError during embedding: n={n}, l={l}, lag={lag}. Error: {e}")
        return np.array([])
    except Exception as e:
        logging.exception("Error in create_embedding")
        return np.array([])
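
# Editor's note (worked example, not original code): create_embedding builds a delay
# embedding. With the defaults L=10 and LAG=11 applied to a K=720-point series, the
# matrix has 720 - (10 - 1) * 11 = 621 rows, each row being 10 samples spaced 11 minutes
# apart: A[t] = [x[t], x[t+11], x[t+22], ..., x[t+99]].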
def llt_transform(X_train, y_train, X_test):
    # (Keep this function as is)
    if not isinstance(X_train, np.ndarray) or X_train.ndim != 3 or \
       not isinstance(y_train, np.ndarray) or y_train.ndim != 1 or \
       not isinstance(X_test, np.ndarray) or (X_test.size > 0 and X_test.ndim != 3):
        logging.error("LLT input type/shape error.")
        return np.array([]), np.array([])
    if X_train.shape[0] != y_train.shape[0]:
        logging.error("LLT input mismatch: len(X_train) != len(y_train)")
        return np.array([]), np.array([])
    if X_train.size == 0 or y_train.size == 0:
        logging.error("LLT requires non-empty training data.")
        return np.array([]), np.array([])
    if X_test.size > 0 and X_test.shape[1:] != X_train.shape[1:]:
        logging.error("LLT train/test shape mismatch")
        return np.array([]), np.array([])

    try:
        num_features = X_train.shape[2]
        if num_features != len(FEATURES):
            logging.error("LLT: Feature count mismatch.")
            return np.array([]), np.array([])

        V = {j: {'0': [], '1': []} for j in range(num_features)}
        laws_computed_count = {j: {'0': 0, '1': 0} for j in range(num_features)}

        for i in range(len(X_train)):
            label = str(int(y_train[i]))
            if label not in ['0', '1']: continue
            for j in range(num_features):
                feature_data = X_train[i, :, j]
                A = create_embedding(feature_data, l=L, lag=LAG)
                if A.shape[0] < L: continue
                if np.isnan(A).any() or np.isinf(A).any(): continue
                try:
                    S = A.T @ A
                    if np.isnan(S).any() or np.isinf(S).any(): continue
                    U, s, Vt = svd(S, full_matrices=False)
                    if Vt.shape[0] < L or Vt.shape[1] != L: continue
                    if s[-1] < 1e-9: continue
                    v = Vt[-1]
                    norm = np.linalg.norm(v)
                    if norm < 1e-9: continue
                    V[j][label].append(v / norm)
                    laws_computed_count[j][label] += 1
                except np.linalg.LinAlgError: pass
                except Exception: pass

        valid_laws_exist = False
        for j in V:
            for c in ['0', '1']:
                if laws_computed_count[j][c] > 0:
                    valid_vecs = [vec for vec in V[j][c] if isinstance(vec, np.ndarray) and vec.shape == (L,)]
                    if not valid_vecs:
                        V[j][c] = np.zeros((L, 0))
                        continue
                    try:
                        V[j][c] = np.array(valid_vecs).T
                        if V[j][c].shape[0] != L:
                            V[j][c] = np.zeros((L, 0))
                        else:
                            valid_laws_exist = True
                    except Exception: V[j][c] = np.zeros((L, 0))
                else: V[j][c] = np.zeros((L, 0))

        if not valid_laws_exist:
            logging.error("LLT ERROR: No valid laws computed.")
            return np.array([]), np.array([])

        def transform_instance(X_instance):
            transformed_features = []
            if X_instance.ndim != 2 or X_instance.shape[0] != K or X_instance.shape[1] != num_features:
                return np.zeros(num_features * 2 * D)
            for j in range(num_features):
                feature_data = X_instance[:, j]
                A = create_embedding(feature_data, l=L, lag=LAG)
                if A.shape[0] < L:
                    transformed_features.extend([0.0] * (2 * D))
                    continue
                if np.isnan(A).any() or np.isinf(A).any():
                    transformed_features.extend([0.0] * (2 * D))
                    continue
                try:
                    S = A.T @ A
                    if np.isnan(S).any() or np.isinf(S).any():
                        transformed_features.extend([0.0] * (2 * D))
                        continue
                    for c in ['0', '1']:
                        if V[j][c].shape[1] == 0:
                            transformed_features.extend([0.0] * D)
                            continue
                        S_V = S @ V[j][c]
                        if S_V.size == 0 or np.isnan(S_V).any() or np.isinf(S_V).any():
                            transformed_features.extend([0.0] * D)
                            continue
                        variances = np.var(S_V, axis=0)
                        if variances.size == 0:
                            transformed_features.extend([0.0] * D)
                            continue
                        variances = np.nan_to_num(variances, nan=np.finfo(variances.dtype).max, posinf=np.finfo(variances.dtype).max, neginf=np.finfo(variances.dtype).max)
                        num_vars_available = variances.size
                        num_vars_to_select = min(D, num_vars_available)
                        smallest_indices = np.argpartition(variances, num_vars_to_select - 1)[:num_vars_to_select]
                        smallest_vars = np.sort(variances[smallest_indices])
                        padded_vars = np.pad(smallest_vars, (0, D - num_vars_to_select), 'constant', constant_values=0.0)
                        if np.isnan(padded_vars).any() or np.isinf(padded_vars).any():
                            padded_vars = np.nan_to_num(padded_vars, nan=0.0, posinf=0.0, neginf=0.0)
                        transformed_features.extend(padded_vars)
                except Exception:
                    current_len = len(transformed_features)
                    expected_len_after_feature = (j + 1) * 2 * D
                    num_missing = expected_len_after_feature - current_len
                    if num_missing > 0: transformed_features.extend([0.0] * num_missing)
                    transformed_features = transformed_features[:expected_len_after_feature]

            correct_len = num_features * 2 * D
            if len(transformed_features) != correct_len:
                if len(transformed_features) < correct_len: transformed_features.extend([0.0] * (correct_len - len(transformed_features)))
                else: transformed_features = transformed_features[:correct_len]
            return np.array(transformed_features)

        X_train_t = np.array([transform_instance(X) for X in X_train])
        X_test_t = np.array([])
        if X_test.size > 0: X_test_t = np.array([transform_instance(X) for X in X_test])

        expected_dim = num_features * 2 * D
        if X_train_t.shape[0] != len(X_train) or (X_train_t.size > 0 and X_train_t.shape[1] != expected_dim):
            logging.error("LLT Train transform resulted in unexpected shape.")
            return np.array([]), np.array([])
        if X_test.size > 0 and (X_test_t.shape[0] != len(X_test) or (X_test_t.size > 0 and X_test_t.shape[1] != expected_dim)):
            logging.error("LLT Test transform resulted in unexpected shape.")
            return X_train_t, np.array([])

        return X_train_t, X_test_t
    except Exception as e:
        logging.exception("Error in llt_transform function")
        return np.array([]), np.array([])
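
# Editor's note (interpretive sketch of the code above, not original text): for every
# training window and feature, llt_transform takes the right-singular vector of
# S = A.T @ A with the smallest singular value, i.e. an approximate linear "law" v with
# A @ v ~ 0, and collects these laws per class. A new window is then described, per
# feature and class, by the D smallest column variances of S @ V[class], giving
# len(FEATURES) * 2 * D = 50 features that feed the KNN classifier.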
def normalize_data(df):
    # (Keep this function as is)
    normalized_df = df.copy()
    if not isinstance(df, pd.DataFrame):
        logging.error("Normalize_data received non-DataFrame input.")
        return None
    for feature in FEATURES:
        if feature == 'timestamp': continue
        if feature not in df.columns:
            normalized_df[feature] = 0.0
            continue
        if pd.api.types.is_numeric_dtype(df[feature]):
            mean = df[feature].mean()
            std = df[feature].std()
            if std is not None and not pd.isna(std) and std > 1e-9:
                normalized_df[feature] = (df[feature] - mean) / std
            else:
                normalized_df[feature] = 0.0
            if normalized_df[feature].isnull().any():
                normalized_df[feature] = normalized_df[feature].fillna(0.0)
        else:
            normalized_df[feature] = 0.0
    return normalized_df
def generate_synthetic_data(symbol, total_candles=WINDOW):
    # (Keep this function as is)
    logging.info(f"Generating synthetic data for {symbol} ({total_candles} candles)")
    np.random.seed(int(time.time() * 1000) % (2**32 - 1))
    end_time = pd.Timestamp.now(tz='UTC')
    timestamps = pd.date_range(end=end_time, periods=total_candles, freq='T')
    volatility = np.random.uniform(0.005, 0.03)
    base_price = np.random.uniform(1, 5000)
    prices = [base_price]
    for _ in range(1, total_candles):
        change = np.random.normal(0, volatility / np.sqrt(1440))
        prices.append(prices[-1] * (1 + change))
    prices = np.maximum(0.01, prices)
    close_prices = np.array(prices)
    open_prices = close_prices * (1 + np.random.normal(0, volatility / np.sqrt(1440) / 2, total_candles))
    high_prices = np.maximum(close_prices, open_prices) * (1 + np.random.uniform(0, volatility / np.sqrt(1440), total_candles))
    low_prices = np.minimum(close_prices, open_prices) * (1 - np.random.uniform(0, volatility / np.sqrt(1440), total_candles))
    high_prices = np.maximum.reduce([high_prices, open_prices, close_prices])
    low_prices = np.minimum.reduce([low_prices, open_prices, close_prices])
    volumes = np.random.poisson(base_price * np.random.uniform(1, 10)) * (1 + np.abs(np.diff(close_prices, prepend=close_prices[0])) / close_prices * 5)
    volumes = np.maximum(1, volumes)
    df = pd.DataFrame({
        'timestamp': timestamps, 'open': open_prices, 'high': high_prices,
        'low': low_prices, 'close': close_prices, 'volume': volumes
    })
    for col in FEATURES: df[col] = pd.to_numeric(df[col])
    df.reset_index(drop=True, inplace=True)
    return df
def prepare_training_data(symbol, total_candles_to_fetch=WINDOW + OVERLAP_STEP * 20):
    # (Keep this function as is)
    logging.info(f"Preparing training data for {symbol}...")
    try:
        required_base_candles = WINDOW
        estimated_candles_needed = required_base_candles + (MIN_TRAINING_EXAMPLES * 2) * OVERLAP_STEP + 500
        fetch_candle_count = max(WINDOW + 500, estimated_candles_needed)

        logging.info(f"Fetching {fetch_candle_count} candles for {symbol} training prep...")
        df = fetch_historical_data(symbol, timeframe='1m', total_candles=fetch_candle_count)

        if df is None or len(df) < WINDOW:
            logging.error(f"Insufficient data fetched for {symbol} ({len(df) if df is not None else 0} < {WINDOW}).")
            if USE_SYNTHETIC_DATA_FOR_LOW_VOLUME:
                logging.warning(f"Attempting synthetic data generation for {symbol}.")
                df = generate_synthetic_data(symbol, total_candles=WINDOW + OVERLAP_STEP * 10)
                if df is None or len(df) < WINDOW:
                    logging.error(f"Synthetic data generation failed or insufficient for {symbol}.")
                    return None, None
                else: logging.info(f"Using synthetic data ({len(df)} points) for {symbol}.")
            else: return None, None

        df_normalized = normalize_data(df)
        if df_normalized is None:
            logging.error(f"Normalization failed for {symbol}.")
            return None, None
        if df_normalized[FEATURES].isnull().any().any():
            logging.warning(f"NaN values found after normalization for {symbol}. Filling with 0.")
            df_normalized = df_normalized.fillna(0.0)

        X, y = [], []
        end_index = len(df)
        start_index = WINDOW
        num_windows_created = 0

        for i in range(end_index, start_index - 1, -OVERLAP_STEP):
            window_end_idx = i
            window_start_idx = i - WINDOW
            if window_start_idx < 0: continue

            window_orig = df.iloc[window_start_idx:window_end_idx]
            window_norm = df_normalized.iloc[window_start_idx:window_end_idx]

            if len(window_orig) != WINDOW or len(window_norm) != WINDOW: continue

            input_data_norm = window_norm.iloc[:K][FEATURES].values
            if input_data_norm.shape[0] != K or input_data_norm.shape[1] != len(FEATURES): continue
            if np.isnan(input_data_norm).any(): continue

            start_price_iloc_idx = K - 1
            end_price_iloc_idx = WINDOW - 1
            start_price = window_orig['close'].iloc[start_price_iloc_idx]
            end_price = window_orig['close'].iloc[end_price_iloc_idx]

            if pd.isna(start_price) or pd.isna(end_price) or start_price <= 0: continue

            X.append(input_data_norm)
            y.append(1 if end_price > start_price else 0)
            num_windows_created += 1

        if not X:
            logging.error(f"No valid windows created for {symbol}.")
            return None, None

        X = np.array(X)
        y = np.array(y)
        unique_classes, class_counts = np.unique(y, return_counts=True)
        class_dist_str = ", ".join([f"Class {cls}: {count}" for cls, count in zip(unique_classes, class_counts)])
        logging.info(f"Class distribution BEFORE balancing for {symbol}: {class_dist_str}")

        if len(unique_classes) < 2:
            logging.error(f"ONLY ONE CLASS ({unique_classes[0]}) present for {symbol}.")
            return None, None

        min_class_count = min(class_counts)
        if min_class_count * 2 < MIN_TRAINING_EXAMPLES:
            logging.error(f"Minority class count ({min_class_count}) too low for {symbol}.")
            return None, None

        samples_per_class = min_class_count
        balanced_indices = []
        for class_val in unique_classes:
            class_indices = np.where(y == class_val)[0]
            num_to_choose = min(samples_per_class, len(class_indices))
            chosen_indices = np.random.choice(class_indices, size=num_to_choose, replace=False)
            balanced_indices.extend(chosen_indices)

        np.random.shuffle(balanced_indices)
        X_balanced = X[balanced_indices]
        y_balanced = y[balanced_indices]
        final_unique, final_counts = np.unique(y_balanced, return_counts=True)
        logging.info(f"Balanced dataset for {symbol}: {len(X_balanced)} instances. Final counts: {dict(zip(final_unique, final_counts))}")

        if len(X_balanced) < MIN_TRAINING_EXAMPLES:
            logging.error(f"Insufficient data ({len(X_balanced)}) for {symbol} AFTER balancing.")
            return None, None
        if X_balanced.ndim != 3 or X_balanced.shape[0] == 0 or X_balanced.shape[1] != K or X_balanced.shape[2] != len(FEATURES):
            logging.error(f"Final balanced data has unexpected shape {X_balanced.shape} for {symbol}.")
            return None, None

        return X_balanced, y_balanced
    except Exception as e:
        logging.exception(f"Error preparing training data for {symbol}")
        return None, None
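
# Editor's note (worked labelling example with hypothetical prices, not original code):
# each example is labelled from the raw (un-normalised) closes: y = 1 if the close at the
# end of the 840-candle window is above the close at candle K (minute 720), else y = 0.
# E.g. close[719] = 100.0 and close[839] = 100.4 would yield label 1 ("Rise").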
def train_model(symbol):
    # (Keep this function as is)
    logging.info(f"--- Attempting to train model for {symbol} ---")
    np.random.seed(int(time.time()) % (2**32 - 1))

    X, y = prepare_training_data(symbol)
    if X is None or y is None:
        logging.error(f"Failed to prepare training data for {symbol}. Training aborted.")
        return None, None, None

    try:
        # accuracy stays None until validation scoring succeeds; -1.0 marks a skipped or
        # failed validation (the original sentinel of -1.0 meant scoring could never run).
        accuracy = None
        if len(X) < MIN_TRAINING_EXAMPLES + 2:
            logging.warning(f"Dataset for {symbol} too small ({len(X)}). Training on all data.")
            X_train, y_train = X, y
            X_val, y_val = np.array([]), np.array([])
        else:
            indices = np.random.permutation(len(X))
            val_size = max(1, int(len(X) * 0.2))
            split_idx = len(X) - val_size
            train_indices, val_indices = indices[:split_idx], indices[split_idx:]
            if len(train_indices) == 0 or len(val_indices) == 0:
                logging.error("Train/Val split resulted in zero samples. Training on all data.")
                X_train, y_train = X, y
                X_val, y_val = np.array([]), np.array([])
            else:
                X_train, X_val = X[train_indices], X[val_indices]
                y_train, y_val = y[train_indices], y[val_indices]
                if len(np.unique(y_train)) < 2:
                    logging.error(f"Only one class in TRAINING set after split for {symbol}. Aborting.")
                    return None, None, None
                if len(np.unique(y_val)) < 2:
                    logging.warning(f"Only one class in VALIDATION set after split for {symbol}.")

        if X_val.size == 0: X_val_shaped = np.empty((0, K, len(FEATURES)))
        else: X_val_shaped = X_val

        X_train_t, X_val_t = llt_transform(X_train, y_train, X_val_shaped)

        if X_train_t.size == 0:
            logging.error(f"LLT training transformation failed for {symbol}. Training aborted.")
            return None, None, None
        if X_val.size > 0 and X_val_t.size == 0:
            logging.warning(f"LLT validation transformation failed for {symbol}.")
            accuracy = -1.0
        if np.isnan(X_train_t).any() or np.isinf(X_train_t).any():
            logging.error(f"NaN/Inf in LLT transformed TRAINING data for {symbol}. Training aborted.")
            return None, None, None
        if X_val_t.size > 0 and (np.isnan(X_val_t).any() or np.isinf(X_val_t).any()):
            logging.warning(f"NaN/Inf in LLT transformed VALIDATION data for {symbol}.")
            accuracy = -1.0

        n_neighbors = min(5, len(y_train) - 1) if len(y_train) > 1 else 1
        n_neighbors = max(1, n_neighbors)
        if n_neighbors > 1 and n_neighbors % 2 == 0: n_neighbors -= 1

        model = KNeighborsClassifier(n_neighbors=n_neighbors, weights='distance')
        model.fit(X_train_t, y_train)

        if accuracy is None and X_val_t.size > 0:
            try:
                accuracy = model.score(X_val_t, y_val)
                logging.info(f"Model for {symbol} trained. Validation Accuracy: {accuracy:.3f}")
            except Exception as eval_e:
                logging.exception(f"Error during KNN validation scoring for {symbol}: {eval_e}")
                accuracy = -1.0
        elif accuracy == -1.0:
            logging.info(f"Model for {symbol} trained. Validation skipped or failed.")
        else:
            logging.info(f"Model for {symbol} trained. No validation data.")
            accuracy = -1.0

        return model, X_train, y_train
    except Exception as e:
        logging.exception(f"Error during model training pipeline for {symbol}")
        return None, None, None
def predict_real_time(symbol, model_data):
    # (Keep this function as is)
    if model_data is None: return "Model N/A", 0.0
    model, X_train_orig_for_llt, y_train_orig_for_llt = model_data

    if model is None or X_train_orig_for_llt is None or y_train_orig_for_llt is None:
        logging.error(f"Invalid model data tuple for prediction on {symbol}")
        return "Model Error", 0.0
    if X_train_orig_for_llt.size == 0 or y_train_orig_for_llt.size == 0:
        logging.error(f"Training data for LLT laws is empty for {symbol}")
        return "LLT Data Error", 0.0

    try:
        df = fetch_historical_data(symbol, timeframe='1m', total_candles=K + 60)
        if df is None or len(df) < K:
            return "Data Error", 0.0

        df_recent = df.iloc[-K:]
        if len(df_recent) != K:
            return "Data Error", 0.0

        df_recent_normalized = normalize_data(df_recent)
        if df_recent_normalized is None: return "Norm Error", 0.0
        if df_recent_normalized[FEATURES].isnull().any().any():
            df_recent_normalized = df_recent_normalized.fillna(0.0)

        X_predict_input = np.array([df_recent_normalized[FEATURES].values])
        _, X_predict_transformed = llt_transform(X_train_orig_for_llt, y_train_orig_for_llt, X_predict_input)

        if X_predict_transformed.size == 0 or X_predict_transformed.shape[0] != 1:
            return "Transform Error", 0.0
        if np.isnan(X_predict_transformed).any() or np.isinf(X_predict_transformed).any():
            X_predict_transformed = np.nan_to_num(X_predict_transformed, nan=0.0, posinf=0.0, neginf=0.0)

        try:
            probabilities = model.predict_proba(X_predict_transformed)
            if probabilities.shape[0] != 1 or probabilities.shape[1] != 2:
                return "Predict Error", 0.0

            prob_class_1 = probabilities[0, 1]
            prediction_label = "Rise" if prob_class_1 >= 0.5 else "Fall"
            confidence = prob_class_1 if prediction_label == "Rise" else probabilities[0, 0]
            return prediction_label, confidence

        except Exception as knn_e:
            logging.exception(f"Error during KNN prediction probability for {symbol}")
            return "Predict Error", 0.0

    except Exception as e:
        logging.exception(f"Error in predict_real_time for {symbol}")
        return "Error", 0.0
# --- TA Calculation Function (Using TA-Lib) ---
def calculate_ta_indicators(df_ta):
    """
    Calculates TA indicators (RSI, MACD, VWAP, ATR) using TA-Lib.
    Requires df_ta to have 'open', 'high', 'low', 'close', 'volume' columns.
    """
    indicators = {'RSI': np.nan, 'MACD': np.nan, 'MACD_Signal': np.nan, 'MACD_Hist': np.nan, 'VWAP': np.nan, 'ATR': np.nan}
    required_cols = ['open', 'high', 'low', 'close', 'volume']
    min_len_needed = max(RSI_PERIOD, MACD_SLOW, ATR_PERIOD) + 1  # TA-Lib often needs P+1

    if df_ta is None or len(df_ta) < min_len_needed:
        logging.warning(f"Insufficient data ({len(df_ta) if df_ta is not None else 0} < {min_len_needed}) for TA-Lib calculations.")
        return indicators

    # Ensure columns exist
    if not all(col in df_ta.columns for col in required_cols):
        logging.error(f"Missing required columns for TA-Lib: Have {df_ta.columns}, Need {required_cols}")
        return indicators

    # --- Prepare data for TA-Lib (NumPy arrays, handle NaNs) ---
    df_ta = df_ta[required_cols].copy()  # Work on a copy with only needed columns

    # Check for NaNs BEFORE converting to numpy, TA-Lib generally dislikes them
    if df_ta.isnull().values.any():
        nan_count = df_ta.isnull().sum().sum()
        logging.warning(f"Found {nan_count} NaN(s) in TA input data. Applying ffill()...")
        df_ta.ffill(inplace=True)  # Forward fill NaNs
        # Check again after ffill - if NaNs remain (e.g., at the start), need more robust handling
        if df_ta.isnull().values.any():
            logging.error("NaNs still present after ffill. Cannot proceed with TA-Lib.")
            return indicators  # Return NaNs

    try:
        # Convert to NumPy arrays of type float
        open_p = df_ta['open'].values.astype(float)
        high_p = df_ta['high'].values.astype(float)
        low_p = df_ta['low'].values.astype(float)
        close_p = df_ta['close'].values.astype(float)
        volume_p = df_ta['volume'].values.astype(float)

        # --- Calculate Indicators using TA-Lib ---
        # RSI
        rsi_values = talib.RSI(close_p, timeperiod=RSI_PERIOD)
        indicators['RSI'] = rsi_values[-1] if len(rsi_values) > 0 else np.nan

        # MACD
        macd_line, signal_line, hist = talib.MACD(close_p, fastperiod=MACD_FAST, slowperiod=MACD_SLOW, signalperiod=MACD_SIGNAL)
        indicators['MACD'] = macd_line[-1] if len(macd_line) > 0 else np.nan
        indicators['MACD_Signal'] = signal_line[-1] if len(signal_line) > 0 else np.nan
        indicators['MACD_Hist'] = hist[-1] if len(hist) > 0 else np.nan

        # ATR
        atr_values = talib.ATR(high_p, low_p, close_p, timeperiod=ATR_PERIOD)
        indicators['ATR'] = atr_values[-1] if len(atr_values) > 0 else np.nan

        # VWAP (Manual Calculation - TA-Lib doesn't have it built-in)
        typical_price = (high_p + low_p + close_p) / 3.0
        tp_vol = typical_price * volume_p
        cumulative_volume = np.cumsum(volume_p)
        # Avoid division by zero if volume is zero for initial periods
        if cumulative_volume[-1] > 1e-12:  # Check if there's significant volume
            vwap_values = np.cumsum(tp_vol) / np.maximum(cumulative_volume, 1e-12)  # Avoid div by zero strictly
            indicators['VWAP'] = vwap_values[-1]
        else:
            indicators['VWAP'] = np.nan  # VWAP undefined if no volume

        # Final check for NaNs in results (TA-Lib might return NaN for initial periods)
        for key, value in indicators.items():
            if pd.isna(value):
                indicators[key] = np.nan  # Ensure consistent NaN representation

        # logging.debug(f"TA-Lib Indicators calculated: {indicators}")
        return indicators

    except Exception as ta_e:
        logging.exception(f"Error calculating TA indicators using TA-Lib: {ta_e}")
        return {k: np.nan for k in indicators}  # Return NaNs on error
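
# Editor's usage sketch (hypothetical values, not original code): given a DataFrame of
# the last TA_DATA_POINTS one-minute candles with open/high/low/close/volume columns,
#   ind = calculate_ta_indicators(df_ta)
# returns e.g. {'RSI': 54.2, 'MACD': ..., 'MACD_Signal': ..., 'MACD_Hist': ...,
# 'VWAP': ..., 'ATR': 0.85}; any indicator that cannot be computed is np.nan.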
# --- Trade Level Calculation (Unchanged) ---
def calculate_trade_levels(prediction, confidence, current_price, atr):
    # (Keep this function as is - no changes needed)
    levels = {'Entry': np.nan, 'TP1': np.nan, 'TP2': np.nan, 'SL': np.nan}
    if pd.isna(current_price) or current_price <= 0 or pd.isna(atr) or atr <= 0:
        return levels
    if prediction == "Rise" and confidence >= CONFIDENCE_THRESHOLD:
        entry_price = current_price
        levels['Entry'] = entry_price
        levels['TP1'] = entry_price + TP1_ATR_MULTIPLIER * atr
        levels['TP2'] = entry_price + TP2_ATR_MULTIPLIER * atr
        levels['SL'] = entry_price - SL_ATR_MULTIPLIER * atr
        levels['SL'] = max(0.01, levels['SL'])
    # Add Fall logic here if needed
    return levels
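
# Editor's worked example (hypothetical numbers, derived from the multipliers above):
# for a "Rise" prediction with confidence >= CONFIDENCE_THRESHOLD, price 100.0 and
# ATR 2.0, the levels are Entry = 100.0, TP1 = 100 + 1.5*2 = 103.0,
# TP2 = 100 + 3.0*2 = 106.0, SL = 100 - 1.0*2 = 98.0.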
# --- Concurrency Wrappers (Unchanged) ---
def train_model_task(coin):
    # (Keep this function as is)
    try:
        result = train_model(coin)
        if result != (None, None, None):
            model, X_train_orig, y_train_orig = result
            return coin, (model, X_train_orig, y_train_orig)
        else:
            return coin, None
    except Exception as e:
        logging.exception(f"Unhandled exception in train_model_task for {coin}")
        return coin, None
def train_all_models(coin_list=None, num_workers=NUM_WORKERS_TRAINING):
    # (Keep this function as is)
    global trained_models
    start_time = time.time()
    if coin_list is None or not coin_list:
        logging.info("No coin list provided, fetching top coins by volume...")
        try:
            coin_list = get_all_usdt_pairs()
            if not coin_list:
                msg = "Failed to fetch coin list even with fallback. Training aborted."
                logging.error(msg)
                return msg
        except Exception as e:
            msg = f"Error fetching coin list: {e}. Training aborted."
            logging.exception(msg)
            return msg

    logging.info(f"Starting training for {len(coin_list)} coins using {num_workers} workers...")
    results_log = []
    successful_trains = 0
    failed_trains = 0
    new_models = {}

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers, thread_name_prefix='TrainWorker') as executor:
        future_to_coin = {executor.submit(train_model_task, coin): coin for coin in coin_list}
        processed_count = 0
        total_coins = len(coin_list)
        for future in concurrent.futures.as_completed(future_to_coin):
            processed_count += 1
            coin = future_to_coin[future]
            try:
                returned_coin, model_data = future.result()
                if returned_coin == coin and model_data is not None:
                    new_models[returned_coin] = model_data
                    results_log.append(f"✅ {returned_coin}: Model trained successfully.")
                    successful_trains += 1
                else:
                    results_log.append(f"❌ {coin}: Model training failed (check logs).")
                    failed_trains += 1
            except Exception as e:
                results_log.append(f"❌ {coin}: Training task generated exception: {e}")
                failed_trains += 1
                logging.exception(f"Exception from training future for {coin}")
            if processed_count % 10 == 0 or processed_count == total_coins:
                logging.info(f"Training progress: {processed_count}/{total_coins} coins processed.")
                logging.getLogger().handlers[0].flush()

    trained_models.update(new_models)
    logging.info(f"Updated global models dictionary. Total models now: {len(trained_models)}")

    end_time = time.time()
    duration = end_time - start_time
    completion_message = (
        f"Training run completed in {duration:.2f} seconds.\n"
        f"Successfully trained: {successful_trains}\n"
        f"Failed to train: {failed_trains}\n"
        f"Total models available now: {len(trained_models)}"
    )
    logging.info(completion_message)
    return completion_message + "\n\n" + "\n".join(results_log[-20:])
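
# Editor's usage sketch (hypothetical call, assumes the symbols exist on the exchange):
#   summary = train_all_models(['BTC/USDT', 'ETH/USDT'], num_workers=2)
# Threads are a reasonable choice here because the wall time is dominated by network I/O
# to the exchange rather than by CPU-bound work.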
991 |
+
# --- Update Predictions Table (Mostly Unchanged, uses new TA function) ---
|
992 |
+
def update_predictions_table():
|
993 |
+
# (This function structure remains the same, it just calls the new calculate_ta_indicators)
|
994 |
+
global last_update_time
|
995 |
+
logging.info("--- Updating Predictions Table ---")
|
996 |
+
start_time = time.time()
|
997 |
+
predictions_data = {}
|
998 |
+
current_models = trained_models.copy()
|
999 |
+
|
1000 |
+
if not current_models:
|
1001 |
+
msg = "No models available. Please train first."
|
1002 |
+
logging.warning(msg)
|
1003 |
+
cols = ['Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR']
|
1004 |
+
return pd.DataFrame([], columns=cols), msg
|
1005 |
+
|
1006 |
+
symbols_with_models = list(current_models.keys())
|
1007 |
+
logging.info(f"Step 1: Generating predictions for {len(symbols_with_models)} models...")
|
1008 |
+
# --- Stage 1: Get Predictions Concurrently ---
|
1009 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_WORKERS_PREDICTION, thread_name_prefix='PredictWorker') as executor:
|
1010 |
+
future_to_coin_pred = {executor.submit(predict_real_time, coin, model_data): coin for coin, model_data in current_models.items()}
|
1011 |
+
pred_success = 0
|
1012 |
+
pred_fail = 0
|
1013 |
+
for future in concurrent.futures.as_completed(future_to_coin_pred):
|
1014 |
+
coin = future_to_coin_pred[future]
|
1015 |
+
try:
|
1016 |
+
pred, conf = future.result()
|
1017 |
+
if pred not in ["Model N/A", "Model Error", "Data Error", "Norm Error", "LLT Data Error", "Transform Error", "Predict Error", "Error"]:
|
1018 |
+
predictions_data[coin] = {'prediction': pred, 'confidence': float(conf)}
|
1019 |
+
pred_success += 1
|
1020 |
+
else:
|
1021 |
+
predictions_data[coin] = {'prediction': pred, 'confidence': 0.0}
|
1022 |
+
pred_fail += 1
|
1023 |
+
except Exception as e:
|
1024 |
+
logging.exception(f"Error getting prediction result for {coin}")
|
1025 |
+
predictions_data[coin] = {'prediction': "Future Error", 'confidence': 0.0}
|
1026 |
+
pred_fail +=1
|
1027 |
+
logging.info(f"Step 1 Complete: Predictions generated ({pred_success} success, {pred_fail} fail).")
|
1028 |
+
|
1029 |
+
# --- Stage 2: Fetch Current Tickers & TA Data Concurrently ---
|
1030 |
+
symbols_to_fetch_data = list(predictions_data.keys())
|
1031 |
+
if not symbols_to_fetch_data:
|
1032 |
+
logging.warning("No symbols with predictions to fetch data for.")
|
1033 |
+
cols = ['Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR']
|
1034 |
+
return pd.DataFrame([], columns=cols), "No symbols processed."
|
1035 |
+
|
1036 |
+
logging.info(f"Step 2: Fetching Tickers and {TA_DATA_POINTS} OHLCV candles for {len(symbols_to_fetch_data)} symbols...")
|
1037 |
+
tickers_data = {}
|
1038 |
+
ohlcv_data = {}
|
1039 |
+
try: # Fetch Tickers
|
1040 |
+
batch_size_tickers = 100
|
1041 |
+
fetched_tickers_batch = {}
|
1042 |
+
for i in range(0, len(symbols_to_fetch_data), batch_size_tickers):
|
1043 |
+
batch_symbols = symbols_to_fetch_data[i:i+batch_size_tickers]
|
1044 |
+
try:
|
1045 |
+
batch_tickers = exchange.fetch_tickers(symbols=batch_symbols)
|
1046 |
+
fetched_tickers_batch.update(batch_tickers)
|
1047 |
+
time.sleep(exchange.rateLimit / 1000 * 0.5)
|
1048 |
+
except Exception as e:
|
1049 |
+
logging.error(f"Failed to fetch ticker batch starting with {batch_symbols[0]}: {e}")
|
1050 |
+
tickers_data = fetched_tickers_batch
|
1051 |
+
logging.info(f"Fetched {len(tickers_data)} tickers.")
|
1052 |
+
except Exception as e:
|
1053 |
+
logging.exception(f"Error fetching tickers in prediction update: {e}")
|

    # Fetch OHLCV for TA
    with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_WORKERS_PREDICTION, thread_name_prefix='TADataWorker') as executor:
        future_to_coin_ohlcv = {executor.submit(fetch_historical_data, coin, '1m', TA_DATA_POINTS): coin for coin in symbols_to_fetch_data}
        for future in concurrent.futures.as_completed(future_to_coin_ohlcv):
            coin = future_to_coin_ohlcv[future]
            try:
                df_ta = future.result()
                if df_ta is not None and len(df_ta) == TA_DATA_POINTS:
                    # Ensure standard column names expected by calculate_ta_indicators
                    df_ta.columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
                    ohlcv_data[coin] = df_ta
            except Exception as e:
                logging.exception(f"Error fetching TA OHLCV data for {coin}")
    logging.info(f"Step 2 Complete: Fetched TA data for {len(ohlcv_data)} symbols.")
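    # After this stage ohlcv_data holds, per symbol, a DataFrame of exactly TA_DATA_POINTS
    # one-minute candles with the standardized columns calculate_ta_indicators() expects;
    # symbols whose fetch failed or came back short are simply absent and fall back to NaN TA values.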

    # --- Stage 3: Calculate TA & Trade Levels ---
    logging.info("Step 3: Calculating TA (using TA-Lib) and Trade Levels...")
    final_results = []
    processing_time = datetime.now(timezone.utc)

    for symbol in symbols_to_fetch_data:
        pred_info = predictions_data.get(symbol, {'prediction': 'Missing Pred', 'confidence': 0.0})
        ticker = tickers_data.get(symbol)
        df_ta = ohlcv_data.get(symbol)  # This df should have standard columns now

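        # Defaults: every display field starts as NaN/NaT so a missing ticker or OHLCV
        # fetch degrades to "N/A" in the output table instead of raising.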
        current_price, quote_volume = np.nan, np.nan
        ta_indicators = {k: np.nan for k in ['RSI', 'MACD', 'MACD_Signal', 'MACD_Hist', 'VWAP', 'ATR']}
        trade_levels = {k: np.nan for k in ['Entry', 'TP1', 'TP2', 'SL']}
        entry_time, exit_time = pd.NaT, pd.NaT

        if ticker and isinstance(ticker, dict):
            current_price = ticker.get('last', np.nan)
            quote_volume = ticker.get('info', {}).get('quoteVolume')
            if quote_volume is None:
                base_volume = ticker.get('baseVolume')
                if base_volume is not None and current_price is not None:
                    try: quote_volume = float(base_volume) * float(current_price)
                    except (ValueError, TypeError): quote_volume = np.nan
            try: current_price = float(current_price) if current_price is not None else np.nan
            except (ValueError, TypeError): current_price = np.nan
            try: quote_volume = float(quote_volume) if quote_volume is not None else np.nan
            except (ValueError, TypeError): quote_volume = np.nan

        # Calculate TA using the new function
        if df_ta is not None:
            ta_indicators = calculate_ta_indicators(df_ta)  # Calls the TA-Lib version

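        # Trade levels are computed only for Rise/Fall predictions when both the live price
        # and ATR are available. calculate_trade_levels() (defined earlier in this file) fills
        # Entry/TP1/TP2/SL from those inputs (per the UI text, using the configured ATR
        # multipliers), and the exit time is set one prediction window ahead of processing_time.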
        if pred_info['prediction'] in ["Rise", "Fall"] and not pd.isna(current_price) and not pd.isna(ta_indicators['ATR']):
            trade_levels = calculate_trade_levels(pred_info['prediction'], pred_info['confidence'], current_price, ta_indicators['ATR'])
            if not pd.isna(trade_levels['Entry']):
                entry_time = processing_time
                exit_time = processing_time + timedelta(hours=PREDICTION_WINDOW_HOURS)

        final_results.append({
            'coin': symbol.split('/')[0], 'full_symbol': symbol,
            'prediction': pred_info['prediction'], 'confidence': pred_info['confidence'],
            'price': current_price, 'volume': quote_volume,
            'entry': trade_levels['Entry'], 'entry_time': entry_time, 'exit_time': exit_time,
            'tp1': trade_levels['TP1'], 'tp2': trade_levels['TP2'], 'sl': trade_levels['SL'],
            'rsi': ta_indicators['RSI'], 'macd_hist': ta_indicators['MACD_Hist'],
            'vwap': ta_indicators['VWAP'], 'atr': ta_indicators['ATR']
        })
    logging.info("Step 3 Complete: TA and Trade Levels calculated.")

    # --- Stage 4: Sort and Format (Unchanged) ---
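    # sort_key returns a (bucket, -confidence) tuple: bucket 0 = Rise above the confidence
    # threshold with a valid entry, 1 = other Rise, 2 = Fall, 3 = everything else. Tuples
    # compare element-wise, so e.g. (0, -0.81) sorts ahead of (0, -0.62) and of (1, -0.95).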
    def sort_key(item):
        pred, conf = item['prediction'], item['confidence']
        if pred == "Rise" and conf >= CONFIDENCE_THRESHOLD and not pd.isna(item['entry']): return (0, -conf)
        elif pred == "Rise": return (1, -conf)
        elif pred == "Fall": return (2, -conf)
        else: return (3, 0)
    final_results.sort(key=sort_key)

    formatted_output = []
    for i, p in enumerate(final_results[:MAX_COINS_TO_DISPLAY]):
        formatted_output.append([
            i + 1, p['coin'], p['prediction'], f"{p['confidence']:.3f}",
            f"{p['price']:.4f}" if not pd.isna(p['price']) else "N/A",
            f"{p['volume']:,.0f}" if not pd.isna(p['volume']) else "N/A",
            f"{p['entry']:.4f}" if not pd.isna(p['entry']) else "N/A",
            format_datetime(p['entry_time'], "N/A"), format_datetime(p['exit_time'], "N/A"),
            f"{p['tp1']:.4f}" if not pd.isna(p['tp1']) else "N/A",
            f"{p['tp2']:.4f}" if not pd.isna(p['tp2']) else "N/A",
            f"{p['sl']:.4f}" if not pd.isna(p['sl']) else "N/A",
            f"{p['rsi']:.2f}" if not pd.isna(p['rsi']) else "N/A",
            f"{p['macd_hist']:.4f}" if not pd.isna(p['macd_hist']) else "N/A",
            f"{p['vwap']:.4f}" if not pd.isna(p['vwap']) else "N/A",
            f"{p['atr']:.4f}" if not pd.isna(p['atr']) else "N/A",
        ])

    output_columns = [
        'Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)',
        'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL',
        'RSI', 'MACD Hist', 'VWAP', 'ATR'
    ]
    output_df = pd.DataFrame(formatted_output, columns=output_columns)

    end_time = time.time()
    duration = end_time - start_time
    last_update_time = processing_time
    status_message = f"Predictions updated ({len(final_results)} symbols processed) in {duration:.2f}s. Last update: {format_datetime(last_update_time)}"
    logging.info(status_message)

    return output_df, status_message


# --- Gradio UI Handlers (Unchanged) ---
def handle_train_click(coin_input, num_workers):
    # (Keep this function as is)
    logging.info(f"Train button clicked. Workers: {num_workers}")
    coins = None
    num_workers = int(num_workers)
    if coin_input and coin_input.strip():
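        # e.g. "btc, eth/usdt" -> ['BTC/USDT', 'ETH/USDT'] (illustrative); any entry that
        # does not resolve to a */USDT pair aborts with the error message below.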
        raw_coins = coin_input.replace(',', ' ').split()
        coins = []
        valid = True
        for c in raw_coins:
            coin_upper = c.strip().upper()
            if '/' not in coin_upper: coin_upper += '/USDT'
            if coin_upper.endswith('/USDT'): coins.append(coin_upper)
            else:
                valid = False
                logging.error(f"Invalid coin format: {c}. Must be SYMBOL or SYMBOL/USDT.")
                break
        if not valid: return "Error: Custom coins must be valid SYMBOL or SYMBOL/USDT pairs."
        logging.info(f"Training requested for custom coin list: {coins}")
    else:
        logging.info("Training requested for top coins by volume.")
    train_status = train_all_models(coin_list=coins, num_workers=num_workers)
    return f"--- Training Run ---\n{train_status}\n\n---> Press 'Refresh Predictions' <---"

def handle_refresh_click():
    # (Keep this function as is)
    logging.info("Refresh button clicked.")
    try:
        df, status = update_predictions_table()
        return df, status
    except Exception as e:
        logging.exception("Error during handle_refresh_click")
        cols = ['Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR']
        return pd.DataFrame([], columns=cols), f"Error updating predictions: {e}"

# --- Gradio Interface Definition (Unchanged) ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Cryptocurrency Prediction & TA Signal Explorer (LLT-KNN + TA-Lib)")  # Updated title slightly
    gr.Markdown(f"""
    Predicts **{PREDICTION_WINDOW_HOURS}-hour** price direction (Rise/Fall) using LLT-KNN.
    Displays current price, volume, TA indicators (RSI, MACD, VWAP, ATR calculated using **TA-Lib**), and potential trade levels for **Rise** signals meeting confidence >= **{CONFIDENCE_THRESHOLD}**.
    TP/SL levels are based on **{TP1_ATR_MULTIPLIER}x / {TP2_ATR_MULTIPLIER}x / {SL_ATR_MULTIPLIER}x ATR({ATR_PERIOD})**.
    **Warning:** Educational use only. High risk. Not financial advice. Ensure TA-Lib is correctly installed.
    """)

    with gr.Row():
        with gr.Column(scale=4):
            prediction_df = gr.Dataframe(
                headers=[
                    'Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)',
                    'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL',
                    'RSI', 'MACD Hist', 'VWAP', 'ATR'
                ],
                datatype=[
                    'number', 'str', 'str', 'str', 'str', 'str',
                    'str', 'str', 'str', 'str', 'str', 'str',
                    'str', 'str', 'str', 'str'
                ],
                row_count=15, col_count=(16, "fixed"), label="Predictions & TA Signals", wrap=True,
            )
        with gr.Column(scale=1):
            with gr.Accordion("Train Models", open=True):
                coin_input = gr.Textbox(label="Train Specific Coins (e.g., BTC, ETH/USDT)", placeholder="Leave empty for top coins by volume")
                max_workers_slider = gr.Slider(minimum=1, maximum=10, value=NUM_WORKERS_TRAINING, step=1, label="Parallel Training Workers")
                train_button = gr.Button("Start Training", variant="primary")
            refresh_button = gr.Button("Refresh Predictions", variant="secondary")
            status_text = gr.Textbox(label="Status Log", lines=15, interactive=False, max_lines=30)

    gr.Markdown(
        """
        ## Notes
        - **TA-Lib**: This version uses the TA-Lib library for indicators. Ensure it is installed correctly (installation can be tricky).
        - **Data**: Fetches 1-minute OHLCV data from Bitget. Uses a cache and handles rate limits.
        - **Training**: Uses roughly the past 14h of data (12h for training features, 2h prediction window). Normalizes, balances classes, applies LLT, then trains KNN.
        - **Prediction**: Uses the latest 12h of data as KNN input.
        - **Trade Levels**: Shown only for 'Rise' predictions above the confidence threshold. Based on the current price and ATR volatility. **Highly speculative.**
        - **Sorting**: Table sorted by (Potential Rise Signals > Other Rise > Fall > Errors), then by confidence descending.
        - **Refresh**: Fetches the latest prices/TA and re-evaluates signals.
        """
    )
    train_button.click(fn=handle_train_click, inputs=[coin_input, max_workers_slider], outputs=status_text)
    refresh_button.click(fn=handle_refresh_click, inputs=None, outputs=[prediction_df, status_text])

# --- Startup Initialization (Unchanged) ---
def initialize_models_on_startup():
    # (Keep this function as is)
    logging.info("----- Initializing Models (Startup Thread) -----")
    default_coins = ['BTC/USDT', 'ETH/USDT', 'SOL/USDT', 'XRP/USDT', 'DOGE/USDT']
    try:
        initial_status = train_all_models(default_coins, num_workers=2)
        logging.info("----- Initial Model Training Complete -----")
        logging.info(initial_status)
    except Exception as e:
        logging.exception("Error during startup initialization.")

# --- Main Execution (Unchanged) ---
if __name__ == "__main__":
    logging.info("Starting application...")
    # Check if TA-Lib import worked (basic check)
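    # The probe below only confirms that the 'talib' module was imported and is callable;
    # a NameError here means the import at the top of the file (presumably wrapped in its
    # own try/except) did not succeed. With six inputs and the default 14-bar lookback,
    # RSI should return an all-NaN array, which is sufficient for this availability check.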
    try:
        # Try accessing a TA-Lib function
        _ = talib.RSI(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]))
        logging.info("TA-Lib library seems accessible.")
    except NameError:
        logging.error("FATAL: TA-Lib library not found or import failed. Please install it correctly.")
        sys.exit(1)
    except Exception as ta_init_e:
        logging.error(f"FATAL: Error testing TA-Lib library: {ta_init_e}. Please check installation.")
        sys.exit(1)

    init_thread = threading.Thread(target=initialize_models_on_startup, name="StartupTrainThread", daemon=True)
    init_thread.start()
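    # Initial training for the default coins runs on a daemon thread so the UI can launch
    # immediately; the table will generally stay empty until that training finishes and
    # 'Refresh Predictions' is pressed.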

    logging.info("Launching Gradio Interface...")
    try:
        demo.launch(server_name="0.0.0.0")
    except Exception as e:
        logging.exception("Failed to launch Gradio interface.")
    finally:
        logging.info("Gradio Interface stopped.")