ahsanr committed
Commit bb0a1f7 · verified · 1 Parent(s): d28a77f

Create app.py

Files changed (1)
app.py (+1281, -0)
app.py ADDED
@@ -0,0 +1,1281 @@
# NOTE: the TA-Lib Python wrapper needs the underlying TA-Lib C library installed first, then: pip install TA-Lib
import ccxt
import numpy as np
import pandas as pd
import time
from sklearn.neighbors import KNeighborsClassifier
from scipy.linalg import svd
import gradio as gr
import concurrent.futures
import traceback
from datetime import datetime, timezone, timedelta
import logging
import sys
import talib  # Import TA-Lib
import threading

# --- Setup Logging ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - [%(threadName)s:%(funcName)s] - %(message)s',
    stream=sys.stdout
)
logging.getLogger().handlers[0].flush = sys.stdout.flush

# --- Parameters ---
L = 10
LAG = 11
MINUTES_PER_HOUR = 60
PREDICTION_WINDOW_HOURS = 2
TRAINING_WINDOW_HOURS = 12
TOTAL_WINDOW_HOURS = TRAINING_WINDOW_HOURS + PREDICTION_WINDOW_HOURS
K = TRAINING_WINDOW_HOURS * MINUTES_PER_HOUR  # 720
WINDOW = TOTAL_WINDOW_HOURS * MINUTES_PER_HOUR  # 840
FEATURES = ['open', 'high', 'low', 'close', 'volume']
D = 5
OVERLAP_STEP = 60
MIN_TRAINING_EXAMPLES = 20
MAX_COINS_TO_DISPLAY = 10
USE_SYNTHETIC_DATA_FOR_LOW_VOLUME = False
NUM_WORKERS_TRAINING = 4
NUM_WORKERS_PREDICTION = 10

# --- TA & Risk Parameters ---
TA_DATA_POINTS = 200  # Candles needed for TA calculation
RSI_PERIOD = 14
MACD_FAST = 12
MACD_SLOW = 26
MACD_SIGNAL = 9
ATR_PERIOD = 14
CONFIDENCE_THRESHOLD = 0.65  # Min confidence for Rise signal
TP1_ATR_MULTIPLIER = 1.5
TP2_ATR_MULTIPLIER = 3.0
SL_ATR_MULTIPLIER = 1.0

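# Worked example (illustrative, not part of the original logic): with the
# multipliers above, a hypothetical long entry at 100.0 USDT and ATR(14) = 2.0
# would give TP1 = 100.0 + 1.5 * 2.0 = 103.0, TP2 = 100.0 + 3.0 * 2.0 = 106.0,
# and SL = 100.0 - 1.0 * 2.0 = 98.0 (see calculate_trade_levels below).
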
# --- CCXT Initialization ---
try:
    exchange = ccxt.bitget({
        'enableRateLimit': True,
        'rateLimit': 1100,
        'timeout': 45000,
        'options': {'adjustForTimeDifference': True}
    })
    logging.info(f"Initialized {exchange.id} exchange.")
except Exception as e:
    logging.exception("FATAL: Could not initialize CCXT exchange.")
    sys.exit()

# --- Global Caches and Variables ---
markets_cache = None
last_markets_update = None
data_cache = {}
trained_models = {}
last_update_time = datetime.now(timezone.utc)

# --- Functions ---

def format_datetime(dt, default="N/A"):
    # (Keep this function as is)
    if pd.isna(dt) or dt is None:
        return default
    try:
        if isinstance(dt, (datetime, pd.Timestamp)):
            if dt.tzinfo is None:
                dt = dt.replace(tzinfo=timezone.utc)
            return dt.strftime('%Y-%m-%d %H:%M:%S %Z')
        else:
            return str(dt)
    except Exception:
        return default

def get_all_usdt_pairs():
    # (Keep this function as is - no changes needed)
    global markets_cache, last_markets_update
    current_time = time.time()
    cache_duration = 3600  # 1 hour

    if markets_cache is not None and last_markets_update is not None:
        if current_time - last_markets_update < cache_duration:
            logging.info("Using cached markets list.")
            if isinstance(markets_cache, list) and markets_cache:
                return markets_cache
            else:
                logging.warning("Cached market list was invalid, fetching fresh.")

    logging.info("Fetching markets from Bitget...")
    try:
        exchange.load_markets(reload=True)
        all_symbols = list(exchange.markets.keys())
        usdt_pairs = [
            symbol for symbol in all_symbols
            if isinstance(symbol, str)
            and symbol.endswith('/USDT')
            and exchange.markets.get(symbol, {}).get('active', False)
            and exchange.markets.get(symbol, {}).get('spot', False)
            and 'LEVERAGED' not in exchange.markets.get(symbol, {}).get('type', 'spot').upper()
            and not exchange.markets.get(symbol, {}).get('inverse', False)
        ]
        logging.info(f"Found {len(usdt_pairs)} active USDT spot pairs initially.")
        if not usdt_pairs:
            logging.warning("No active USDT spot pairs found.")
            return ['BTC/USDT', 'ETH/USDT', 'SOL/USDT']

        logging.info(f"Fetching tickers for {len(usdt_pairs)} pairs for volume sorting...")
        volumes = {}
        symbols_to_fetch = usdt_pairs[:]
        fetched_tickers = {}
        try:
            if exchange.has['fetchTickers']:
                batch_size_tickers = 100
                for i in range(0, len(symbols_to_fetch), batch_size_tickers):
                    batch_symbols = symbols_to_fetch[i:i+batch_size_tickers]
                    logging.info(f"Fetching ticker batch {i//batch_size_tickers + 1}/{(len(symbols_to_fetch) + batch_size_tickers - 1)//batch_size_tickers}...")
                    retries = 2
                    for attempt in range(retries):
                        try:
                            batch_tickers = exchange.fetch_tickers(symbols=batch_symbols)
                            fetched_tickers.update(batch_tickers)
                            time.sleep(exchange.rateLimit / 1000 * 1.5)  # Add delay
                            break
                        except (ccxt.RequestTimeout, ccxt.NetworkError) as e_timeout:
                            logging.warning(f"Ticker fetch timeout/network error on attempt {attempt+1}/{retries}: {e_timeout}, retrying after delay...")
                            time.sleep(3 * (attempt + 1))
                        except ccxt.RateLimitExceeded:
                            logging.warning("Rate limit exceeded fetching tickers, sleeping...")
                            time.sleep(10 * (attempt + 1))  # Longer sleep for rate limit
                        except Exception as e_ticker:
                            logging.error(f"Error fetching ticker batch (attempt {attempt+1}): {e_ticker}")
                            if attempt == retries - 1: raise  # Rethrow last error
                            time.sleep(2 * (attempt + 1))

                logging.info(f"Fetched {len(fetched_tickers)} tickers using fetchTickers.")
            else:
                raise ccxt.NotSupported("fetchTickers not supported/enabled. Volume sorting requires it.")

        except Exception as e:
            logging.exception(f"Could not fetch tickers for volume sorting: {e}. Volume sorting unavailable.")
            markets_cache = usdt_pairs[:MAX_COINS_TO_DISPLAY]
            last_markets_update = current_time
            logging.warning(f"Returning top {len(markets_cache)} unsorted pairs due to ticker error.")
            return markets_cache

        for symbol, ticker in fetched_tickers.items():
            try:
                quote_volume = ticker.get('info', {}).get('quoteVolume')  # Prefer quoteVolume if available
                last_price = ticker.get('last')
                base_volume = ticker.get('baseVolume')

                # Ensure values are convertible to float before calculation
                valid_last = last_price is not None
                valid_base = base_volume is not None
                valid_quote = quote_volume is not None

                if valid_quote:
                    volumes[symbol] = float(quote_volume)
                elif valid_base and valid_last:
                    volumes[symbol] = float(base_volume) * float(last_price)
                else:
                    volumes[symbol] = 0
            except (TypeError, ValueError, KeyError, AttributeError) as e:
                logging.warning(f"Could not parse volume/price for {symbol} from ticker: {ticker}. Error: {e}")
                volumes[symbol] = 0

        valid_volume_pairs = {k: v for k, v in volumes.items() if v > 0}
        logging.info(f"Found {len(valid_volume_pairs)} pairs with non-zero volume.")

        if not valid_volume_pairs:
            logging.warning("No pairs with valid volume found. Returning default list.")
            return ['BTC/USDT', 'ETH/USDT', 'SOL/USDT']

        sorted_pairs = sorted(valid_volume_pairs.items(), key=lambda item: item[1], reverse=True)
        num_pairs_to_take = min(MAX_COINS_TO_DISPLAY, len(sorted_pairs))
        top_pairs = [pair[0] for pair in sorted_pairs[:num_pairs_to_take]]
        logging.info(f"Selected Top {len(top_pairs)} pairs by volume. Top 5: {[p[0] for p in sorted_pairs[:5]]}")

        markets_cache = top_pairs
        last_markets_update = current_time
        return top_pairs

    except ccxt.NetworkError as e:
        logging.error(f"Network error getting USDT pairs: {e}")
    except ccxt.ExchangeError as e:
        logging.error(f"Exchange error getting USDT pairs: {e}")
    except Exception as e:
        logging.exception("General error getting USDT pairs.")

    logging.warning("Error fetching markets, returning default fallback list.")
    return ['BTC/USDT', 'ETH/USDT', 'SOL/USDT', 'BNB/USDT', 'XRP/USDT']

def clean_and_process_ohlcv(ohlcv_list, symbol, expected_candles):
    # (Keep this function as is - no changes needed)
    if not ohlcv_list:
        return None
    try:
        df = pd.DataFrame(ohlcv_list, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
        initial_len = len(df)
        if initial_len == 0: return None

        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
        df = df.drop_duplicates(subset=['timestamp'])
        df = df.sort_values('timestamp')
        len_after_dupes = len(df)

        numeric_cols = ['open', 'high', 'low', 'close', 'volume']
        for col in numeric_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')

        # Drop rows with NaN in essential price/volume features needed for TA-Lib
        df = df.dropna(subset=numeric_cols)
        len_after_na = len(df)

        df.reset_index(drop=True, inplace=True)

        logging.debug(f"Data cleaning for {symbol}: Initial Fetched={initial_len}, AfterDupes={len_after_dupes}, AfterNA={len_after_na}")

        if len(df) >= expected_candles:
            final_df = df.iloc[-expected_candles:].copy()  # Take the most recent ones
            return final_df
        else:
            return None

    except Exception as e:
        logging.exception(f"Error processing DataFrame for {symbol}")
        return None

def fetch_historical_data(symbol, timeframe='1m', total_candles=WINDOW):
    # (Keep this function as is - no changes needed)
    cache_key = f"{symbol}_{timeframe}_{total_candles}"
    current_time = time.time()
    cache_validity_seconds = 300  # 5 minutes

    if cache_key in data_cache:
        cache_time, cached_data = data_cache[cache_key]
        if current_time - cache_time < cache_validity_seconds:
            if isinstance(cached_data, pd.DataFrame) and len(cached_data) == total_candles:
                logging.debug(f"Using valid cached data for {symbol} ({len(cached_data)} candles)")
                return cached_data.copy()
            else:
                logging.warning(f"Cache for {symbol} invalid or wrong size ({len(cached_data) if isinstance(cached_data, pd.DataFrame) else 'N/A'} vs {total_candles}), fetching fresh.")
                if cache_key in data_cache: del data_cache[cache_key]

    if not exchange.has['fetchOHLCV']:
        logging.error(f"Exchange {exchange.id} does not support fetchOHLCV.")
        return None

    logging.debug(f"Fetching {total_candles} candles for {symbol} (timeframe: {timeframe})")
    final_df = None
    fetch_start_time = time.time()
    duration_ms = exchange.parse_timeframe(timeframe) * 1000
    now_ms = exchange.milliseconds()

    # --- Strategy 1: Try Single Large Fetch ---
    single_fetch_limit = total_candles + 200  # Buffer
    single_fetch_since = now_ms - single_fetch_limit * duration_ms
    try:
        ohlcv_list = exchange.fetch_ohlcv(symbol, timeframe, limit=single_fetch_limit, since=single_fetch_since)
        if ohlcv_list:
            processed_df = clean_and_process_ohlcv(ohlcv_list, symbol, total_candles)
            if processed_df is not None and len(processed_df) == total_candles:
                final_df = processed_df
    except ccxt.RateLimitExceeded as e:
        logging.warning(f"Rate limit hit during single fetch for {symbol}, falling back: {e}")
        time.sleep(5)
    except (ccxt.RequestTimeout, ccxt.NetworkError) as e:
        logging.warning(f"Timeout/Network error during single fetch for {symbol}, falling back: {e}")
        time.sleep(2)
    except ccxt.ExchangeNotAvailable as e:
        logging.error(f"Exchange not available during fetch for {symbol}: {e}")
        return None
    except ccxt.AuthenticationError as e:
        logging.error(f"Authentication error fetching {symbol}: {e}")
        return None
    except ccxt.ExchangeError as e:
        logging.warning(f"Exchange error during single fetch for {symbol}, falling back: {e}")
    except Exception as e:
        logging.exception(f"Unexpected error during single fetch for {symbol}, falling back.")

    # --- Strategy 2: Fallback to Iterative Chunking ---
    if final_df is None:
        logging.debug(f"Falling back to iterative chunk fetching for {symbol}.")
        limit_per_call = exchange.safe_integer(exchange.limits.get('fetchOHLCV', {}), 'max', 1000)
        limit_per_call = min(limit_per_call, 1000)
        all_ohlcv_chunks = []
        required_start_time_ms = now_ms - (total_candles + 5) * duration_ms
        current_chunk_end_time_ms = now_ms
        max_chunk_attempts = 15
        attempts = 0

        while attempts < max_chunk_attempts:
            attempts += 1
            oldest_ts_in_hand = all_ohlcv_chunks[0][0] if all_ohlcv_chunks else current_chunk_end_time_ms
            if oldest_ts_in_hand <= required_start_time_ms:
                logging.debug(f"Chunking: Collected enough historical range for {symbol}.")
                break

            fetch_limit = limit_per_call
            chunk_fetch_since = oldest_ts_in_hand - fetch_limit * duration_ms
            params = {}
            try:
                ohlcv_chunk = exchange.fetch_ohlcv(symbol, timeframe, since=chunk_fetch_since, limit=fetch_limit, params=params)
                if not ohlcv_chunk:
                    logging.debug(f"Chunking: No more data received for {symbol} from API.")
                    break

                new_chunk = [c for c in ohlcv_chunk if c[0] < oldest_ts_in_hand]
                if not new_chunk:
                    break

                new_chunk.sort(key=lambda x: x[0])
                all_ohlcv_chunks = new_chunk + all_ohlcv_chunks

                if len(new_chunk) < limit_per_call // 20 and attempts > 5:
                    logging.warning(f"Chunking: Received very few new candles ({len(new_chunk)}) repeatedly for {symbol}.")
                    break
                time.sleep(exchange.rateLimit / 1000 * 1.1)

            except ccxt.RateLimitExceeded as e:
                logging.warning(f"Rate limit hit during chunking for {symbol}, sleeping 10s: {e}")
                time.sleep(10 * (attempts / 3 + 1))
            except (ccxt.NetworkError, ccxt.RequestTimeout) as e:
                logging.error(f"Network/Timeout error during chunking for {symbol}: {e}. Stopping.")
                break
            except ccxt.ExchangeError as e:
                logging.error(f"Exchange error during chunking for {symbol}: {e}. Stopping.")
                break
            except Exception as e:
                logging.exception(f"Generic error during chunking fetch for {symbol}")
                break

        if attempts >= max_chunk_attempts:
            logging.warning(f"Max chunk fetch attempts reached for {symbol}.")

        if all_ohlcv_chunks:
            processed_df = clean_and_process_ohlcv(all_ohlcv_chunks, symbol, total_candles)
            if processed_df is not None and len(processed_df) == total_candles:
                final_df = processed_df
        else:
            logging.error(f"No data obtained from chunk fetching for {symbol}.")

    # --- Final Check and Caching ---
    if final_df is not None and len(final_df) == total_candles:
        expected_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
        if all(col in final_df.columns for col in expected_cols):
            data_cache[cache_key] = (current_time, final_df.copy())
            return final_df
        else:
            logging.error(f"Final DataFrame for {symbol} missing expected columns. Won't cache.")
            return None
    else:
        logging.error(f"Failed to fetch exactly {total_candles} candles for {symbol}. Found: {len(final_df) if final_df is not None else 0}")
        return None

# --- Embedding, LLT, Normalize, Training Prep (Largely unchanged) ---
# Keep create_embedding, llt_transform, normalize_data, prepare_training_data, train_model
# as they don't depend on the TA library choice.

def create_embedding(data, l=L, lag=LAG):
    # (Keep this function as is)
    n = len(data)
    rows = n - (l - 1) * lag
    if rows <= 0:
        logging.debug(f"Cannot create embedding: data length {n} too short for L={l}, Lag={lag}")
        return np.array([])
    A = np.zeros((rows, l))
    try:
        for t in range(rows):
            indices = t + np.arange(l) * lag
            A[t] = data[indices]
        return A
    except IndexError as e:
        logging.error(f"IndexError during embedding: n={n}, l={l}, lag={lag}. Error: {e}")
        return np.array([])
    except Exception as e:
        logging.exception("Error in create_embedding")
        return np.array([])

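# Shape sketch (illustrative, not part of the original commit): for a K = 720
# minute window with L = 10 and LAG = 11, rows = 720 - (10 - 1) * 11 = 621, so
# the embedding A has shape (621, 10); row t holds the series at offsets
# t, t + 11, ..., t + 99.
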
def llt_transform(X_train, y_train, X_test):
    # (Keep this function as is)
    if not isinstance(X_train, np.ndarray) or X_train.ndim != 3 or \
       not isinstance(y_train, np.ndarray) or y_train.ndim != 1 or \
       not isinstance(X_test, np.ndarray) or (X_test.size > 0 and X_test.ndim != 3):
        logging.error("LLT input type/shape error.")
        return np.array([]), np.array([])
    if X_train.shape[0] != y_train.shape[0]:
        logging.error("LLT input mismatch: len(X_train) != len(y_train)")
        return np.array([]), np.array([])
    if X_train.size == 0 or y_train.size == 0:
        logging.error("LLT requires non-empty training data.")
        return np.array([]), np.array([])
    if X_test.size > 0 and X_test.shape[1:] != X_train.shape[1:]:
        logging.error("LLT train/test shape mismatch")
        return np.array([]), np.array([])

    try:
        num_features = X_train.shape[2]
        if num_features != len(FEATURES):
            logging.error("LLT: Feature count mismatch.")
            return np.array([]), np.array([])

        V = {j: {'0': [], '1': []} for j in range(num_features)}
        laws_computed_count = {j: {'0': 0, '1': 0} for j in range(num_features)}

        for i in range(len(X_train)):
            label = str(int(y_train[i]))
            if label not in ['0', '1']: continue
            for j in range(num_features):
                feature_data = X_train[i, :, j]
                A = create_embedding(feature_data, l=L, lag=LAG)
                if A.shape[0] < L: continue
                if np.isnan(A).any() or np.isinf(A).any(): continue
                try:
                    S = A.T @ A
                    if np.isnan(S).any() or np.isinf(S).any(): continue
                    U, s, Vt = svd(S, full_matrices=False)
                    if Vt.shape[0] < L or Vt.shape[1] != L: continue
                    if s[-1] < 1e-9: continue
                    v = Vt[-1]
                    norm = np.linalg.norm(v)
                    if norm < 1e-9: continue
                    V[j][label].append(v / norm)
                    laws_computed_count[j][label] += 1
                except np.linalg.LinAlgError: pass
                except Exception: pass

        valid_laws_exist = False
        for j in V:
            for c in ['0', '1']:
                if laws_computed_count[j][c] > 0:
                    valid_vecs = [vec for vec in V[j][c] if isinstance(vec, np.ndarray) and vec.shape == (L,)]
                    if not valid_vecs:
                        V[j][c] = np.zeros((L, 0))
                        continue
                    try:
                        V[j][c] = np.array(valid_vecs).T
                        if V[j][c].shape[0] != L:
                            V[j][c] = np.zeros((L, 0))
                        else:
                            valid_laws_exist = True
                    except Exception: V[j][c] = np.zeros((L, 0))
                else: V[j][c] = np.zeros((L, 0))

        if not valid_laws_exist:
            logging.error("LLT ERROR: No valid laws computed.")
            return np.array([]), np.array([])

        def transform_instance(X_instance):
            transformed_features = []
            if X_instance.ndim != 2 or X_instance.shape[0] != K or X_instance.shape[1] != num_features:
                return np.zeros(num_features * 2 * D)
            for j in range(num_features):
                feature_data = X_instance[:, j]
                A = create_embedding(feature_data, l=L, lag=LAG)
                if A.shape[0] < L:
                    transformed_features.extend([0.0] * (2 * D))
                    continue
                if np.isnan(A).any() or np.isinf(A).any():
                    transformed_features.extend([0.0] * (2 * D))
                    continue
                try:
                    S = A.T @ A
                    if np.isnan(S).any() or np.isinf(S).any():
                        transformed_features.extend([0.0] * (2 * D))
                        continue
                    for c in ['0', '1']:
                        if V[j][c].shape[1] == 0:
                            transformed_features.extend([0.0] * D)
                            continue
                        S_V = S @ V[j][c]
                        if S_V.size == 0 or np.isnan(S_V).any() or np.isinf(S_V).any():
                            transformed_features.extend([0.0] * D)
                            continue
                        variances = np.var(S_V, axis=0)
                        if variances.size == 0:
                            transformed_features.extend([0.0] * D)
                            continue
                        variances = np.nan_to_num(variances, nan=np.finfo(variances.dtype).max, posinf=np.finfo(variances.dtype).max, neginf=np.finfo(variances.dtype).max)
                        num_vars_available = variances.size
                        num_vars_to_select = min(D, num_vars_available)
                        smallest_indices = np.argpartition(variances, num_vars_to_select - 1)[:num_vars_to_select]
                        smallest_vars = np.sort(variances[smallest_indices])
                        padded_vars = np.pad(smallest_vars, (0, D - num_vars_to_select), 'constant', constant_values=0.0)
                        if np.isnan(padded_vars).any() or np.isinf(padded_vars).any():
                            padded_vars = np.nan_to_num(padded_vars, nan=0.0, posinf=0.0, neginf=0.0)
                        transformed_features.extend(padded_vars)
                except Exception:
                    current_len = len(transformed_features)
                    expected_len_after_feature = (j + 1) * 2 * D
                    num_missing = expected_len_after_feature - current_len
                    if num_missing > 0: transformed_features.extend([0.0] * num_missing)
                    transformed_features = transformed_features[:expected_len_after_feature]

            correct_len = num_features * 2 * D
            if len(transformed_features) != correct_len:
                if len(transformed_features) < correct_len: transformed_features.extend([0.0] * (correct_len - len(transformed_features)))
                else: transformed_features = transformed_features[:correct_len]
            return np.array(transformed_features)

        X_train_t = np.array([transform_instance(X) for X in X_train])
        X_test_t = np.array([])
        if X_test.size > 0: X_test_t = np.array([transform_instance(X) for X in X_test])

        expected_dim = num_features * 2 * D
        if X_train_t.shape[0] != len(X_train) or (X_train_t.size > 0 and X_train_t.shape[1] != expected_dim):
            logging.error("LLT Train transform resulted in unexpected shape.")
            return np.array([]), np.array([])
        if X_test.size > 0 and (X_test_t.shape[0] != len(X_test) or (X_test_t.size > 0 and X_test_t.shape[1] != expected_dim)):
            logging.error("LLT Test transform resulted in unexpected shape.")
            return X_train_t, np.array([])

        return X_train_t, X_test_t
    except Exception as e:
        logging.exception("Error in llt_transform function")
        return np.array([]), np.array([])

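# Dimensionality sketch (illustrative, not part of the original commit):
# transform_instance emits D = 5 variance values per (feature, class) pair, so
# each window maps to len(FEATURES) * 2 * D = 5 * 2 * 5 = 50 numbers; that
# 50-dimensional vector is what the KNN classifier is fitted on in train_model.
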
def normalize_data(df):
    # (Keep this function as is)
    normalized_df = df.copy()
    if not isinstance(df, pd.DataFrame):
        logging.error("Normalize_data received non-DataFrame input.")
        return None
    for feature in FEATURES:
        if feature == 'timestamp': continue
        if feature not in df.columns:
            normalized_df[feature] = 0.0
            continue
        if pd.api.types.is_numeric_dtype(df[feature]):
            mean = df[feature].mean()
            std = df[feature].std()
            if std is not None and not pd.isna(std) and std > 1e-9:
                normalized_df[feature] = (df[feature] - mean) / std
            else:
                normalized_df[feature] = 0.0
            if normalized_df[feature].isnull().any():
                normalized_df[feature] = normalized_df[feature].fillna(0.0)
        else:
            normalized_df[feature] = 0.0
    return normalized_df

def generate_synthetic_data(symbol, total_candles=WINDOW):
    # (Keep this function as is)
    logging.info(f"Generating synthetic data for {symbol} ({total_candles} candles)")
    np.random.seed(int(time.time() * 1000) % (2**32 - 1))
    end_time = pd.Timestamp.now(tz='UTC')
    timestamps = pd.date_range(end=end_time, periods=total_candles, freq='T')
    volatility = np.random.uniform(0.005, 0.03)
    base_price = np.random.uniform(1, 5000)
    prices = [base_price]
    for _ in range(1, total_candles):
        change = np.random.normal(0, volatility / np.sqrt(1440))
        prices.append(prices[-1] * (1 + change))
    prices = np.maximum(0.01, prices)
    close_prices = np.array(prices)
    open_prices = close_prices * (1 + np.random.normal(0, volatility / np.sqrt(1440) / 2, total_candles))
    high_prices = np.maximum(close_prices, open_prices) * (1 + np.random.uniform(0, volatility / np.sqrt(1440), total_candles))
    low_prices = np.minimum(close_prices, open_prices) * (1 - np.random.uniform(0, volatility / np.sqrt(1440), total_candles))
    high_prices = np.maximum.reduce([high_prices, open_prices, close_prices])
    low_prices = np.minimum.reduce([low_prices, open_prices, close_prices])
    volumes = np.random.poisson(base_price * np.random.uniform(1, 10)) * (1 + np.abs(np.diff(close_prices, prepend=close_prices[0])) / close_prices * 5)
    volumes = np.maximum(1, volumes)
    df = pd.DataFrame({
        'timestamp': timestamps, 'open': open_prices, 'high': high_prices,
        'low': low_prices, 'close': close_prices, 'volume': volumes
    })
    for col in FEATURES: df[col] = pd.to_numeric(df[col])
    df.reset_index(drop=True, inplace=True)
    return df

def prepare_training_data(symbol, total_candles_to_fetch=WINDOW + OVERLAP_STEP * 20):
    # (Keep this function as is)
    logging.info(f"Preparing training data for {symbol}...")
    try:
        required_base_candles = WINDOW
        estimated_candles_needed = required_base_candles + (MIN_TRAINING_EXAMPLES * 2) * OVERLAP_STEP + 500
        fetch_candle_count = max(WINDOW + 500, estimated_candles_needed)

        logging.info(f"Fetching {fetch_candle_count} candles for {symbol} training prep...")
        df = fetch_historical_data(symbol, timeframe='1m', total_candles=fetch_candle_count)

        if df is None or len(df) < WINDOW:
            logging.error(f"Insufficient data fetched for {symbol} ({len(df) if df is not None else 0} < {WINDOW}).")
            if USE_SYNTHETIC_DATA_FOR_LOW_VOLUME:
                logging.warning(f"Attempting synthetic data generation for {symbol}.")
                df = generate_synthetic_data(symbol, total_candles=WINDOW + OVERLAP_STEP * 10)
                if df is None or len(df) < WINDOW:
                    logging.error(f"Synthetic data generation failed or insufficient for {symbol}.")
                    return None, None
                else: logging.info(f"Using synthetic data ({len(df)} points) for {symbol}.")
            else: return None, None

        df_normalized = normalize_data(df)
        if df_normalized is None:
            logging.error(f"Normalization failed for {symbol}.")
            return None, None
        if df_normalized[FEATURES].isnull().any().any():
            logging.warning(f"NaN values found after normalization for {symbol}. Filling with 0.")
            df_normalized = df_normalized.fillna(0.0)

        X, y = [], []
        end_index = len(df)
        start_index = WINDOW
        num_windows_created = 0

        for i in range(end_index, start_index - 1, -OVERLAP_STEP):
            window_end_idx = i
            window_start_idx = i - WINDOW
            if window_start_idx < 0: continue

            window_orig = df.iloc[window_start_idx:window_end_idx]
            window_norm = df_normalized.iloc[window_start_idx:window_end_idx]

            if len(window_orig) != WINDOW or len(window_norm) != WINDOW: continue

            input_data_norm = window_norm.iloc[:K][FEATURES].values
            if input_data_norm.shape[0] != K or input_data_norm.shape[1] != len(FEATURES): continue
            if np.isnan(input_data_norm).any(): continue

            start_price_iloc_idx = K - 1
            end_price_iloc_idx = WINDOW - 1
            start_price = window_orig['close'].iloc[start_price_iloc_idx]
            end_price = window_orig['close'].iloc[end_price_iloc_idx]

            if pd.isna(start_price) or pd.isna(end_price) or start_price <= 0: continue

            X.append(input_data_norm)
            y.append(1 if end_price > start_price else 0)
            num_windows_created += 1

        if not X:
            logging.error(f"No valid windows created for {symbol}.")
            return None, None

        X = np.array(X)
        y = np.array(y)
        unique_classes, class_counts = np.unique(y, return_counts=True)
        class_dist_str = ", ".join([f"Class {cls}: {count}" for cls, count in zip(unique_classes, class_counts)])
        logging.info(f"Class distribution BEFORE balancing for {symbol}: {class_dist_str}")

        if len(unique_classes) < 2:
            logging.error(f"ONLY ONE CLASS ({unique_classes[0]}) present for {symbol}.")
            return None, None

        min_class_count = min(class_counts)
        if min_class_count * 2 < MIN_TRAINING_EXAMPLES:
            logging.error(f"Minority class count ({min_class_count}) too low for {symbol}.")
            return None, None

        samples_per_class = min_class_count
        balanced_indices = []
        for class_val in unique_classes:
            class_indices = np.where(y == class_val)[0]
            num_to_choose = min(samples_per_class, len(class_indices))
            chosen_indices = np.random.choice(class_indices, size=num_to_choose, replace=False)
            balanced_indices.extend(chosen_indices)

        np.random.shuffle(balanced_indices)
        X_balanced = X[balanced_indices]
        y_balanced = y[balanced_indices]
        final_unique, final_counts = np.unique(y_balanced, return_counts=True)
        logging.info(f"Balanced dataset for {symbol}: {len(X_balanced)} instances. Final counts: {dict(zip(final_unique, final_counts))}")

        if len(X_balanced) < MIN_TRAINING_EXAMPLES:
            logging.error(f"Insufficient data ({len(X_balanced)}) for {symbol} AFTER balancing.")
            return None, None
        if X_balanced.ndim != 3 or X_balanced.shape[0] == 0 or X_balanced.shape[1] != K or X_balanced.shape[2] != len(FEATURES):
            logging.error(f"Final balanced data has unexpected shape {X_balanced.shape} for {symbol}.")
            return None, None

        return X_balanced, y_balanced
    except Exception as e:
        logging.exception(f"Error preparing training data for {symbol}")
        return None, None

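# Labeling sketch (illustrative, not part of the original commit): each
# training window spans WINDOW = 840 one-minute candles; the first K = 720
# normalized candles form the model input, and the label is 1 if the close at
# index WINDOW - 1 exceeds the close at index K - 1 (price higher 2 hours after
# the input window ends), else 0.
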
def train_model(symbol):
    # (Keep this function as is)
    logging.info(f"--- Attempting to train model for {symbol} ---")
    np.random.seed(int(time.time()) % (2**32 - 1))

    X, y = prepare_training_data(symbol)
    if X is None or y is None:
        logging.error(f"Failed to prepare training data for {symbol}. Training aborted.")
        return None, None, None

    try:
        # accuracy doubles as a sentinel: -1.0 means validation was skipped or failed.
        # Initialize to 0.0 so the validation scoring block below can actually run.
        accuracy = 0.0
        if len(X) < MIN_TRAINING_EXAMPLES + 2:
            logging.warning(f"Dataset for {symbol} too small ({len(X)}). Training on all data.")
            X_train, y_train = X, y
            X_val, y_val = np.array([]), np.array([])
        else:
            indices = np.random.permutation(len(X))
            val_size = max(1, int(len(X) * 0.2))
            split_idx = len(X) - val_size
            train_indices, val_indices = indices[:split_idx], indices[split_idx:]
            if len(train_indices) == 0 or len(val_indices) == 0:
                logging.error("Train/Val split resulted in zero samples. Training on all data.")
                X_train, y_train = X, y
                X_val, y_val = np.array([]), np.array([])
            else:
                X_train, X_val = X[train_indices], X[val_indices]
                y_train, y_val = y[train_indices], y[val_indices]
                if len(np.unique(y_train)) < 2:
                    logging.error(f"Only one class in TRAINING set after split for {symbol}. Aborting.")
                    return None, None, None
                if len(np.unique(y_val)) < 2:
                    logging.warning(f"Only one class in VALIDATION set after split for {symbol}.")

        if X_val.size == 0: X_val_shaped = np.empty((0, K, len(FEATURES)))
        else: X_val_shaped = X_val

        X_train_t, X_val_t = llt_transform(X_train, y_train, X_val_shaped)

        if X_train_t.size == 0:
            logging.error(f"LLT training transformation failed for {symbol}. Training aborted.")
            return None, None, None
        if X_val.size > 0 and X_val_t.size == 0:
            logging.warning(f"LLT validation transformation failed for {symbol}.")
            accuracy = -1.0
        if np.isnan(X_train_t).any() or np.isinf(X_train_t).any():
            logging.error(f"NaN/Inf in LLT transformed TRAINING data for {symbol}. Training aborted.")
            return None, None, None
        if X_val_t.size > 0 and (np.isnan(X_val_t).any() or np.isinf(X_val_t).any()):
            logging.warning(f"NaN/Inf in LLT transformed VALIDATION data for {symbol}.")
            accuracy = -1.0

        n_neighbors = min(5, len(y_train) - 1) if len(y_train) > 1 else 1
        n_neighbors = max(1, n_neighbors)
        if n_neighbors > 1 and n_neighbors % 2 == 0: n_neighbors -= 1

        model = KNeighborsClassifier(n_neighbors=n_neighbors, weights='distance')
        model.fit(X_train_t, y_train)

        if accuracy != -1.0 and X_val_t.size > 0:
            try:
                accuracy = model.score(X_val_t, y_val)
                logging.info(f"Model for {symbol} trained. Validation Accuracy: {accuracy:.3f}")
            except Exception as eval_e:
                logging.exception(f"Error during KNN validation scoring for {symbol}: {eval_e}")
                accuracy = -1.0
        elif accuracy == -1.0:
            logging.info(f"Model for {symbol} trained. Validation skipped or failed.")
        else:
            logging.info(f"Model for {symbol} trained. No validation data.")
            accuracy = -1.0

        return model, X_train, y_train
    except Exception as e:
        logging.exception(f"Error during model training pipeline for {symbol}")
        return None, None, None

def predict_real_time(symbol, model_data):
    # (Keep this function as is)
    if model_data is None: return "Model N/A", 0.0
    model, X_train_orig_for_llt, y_train_orig_for_llt = model_data

    if model is None or X_train_orig_for_llt is None or y_train_orig_for_llt is None:
        logging.error(f"Invalid model data tuple for prediction on {symbol}")
        return "Model Error", 0.0
    if X_train_orig_for_llt.size == 0 or y_train_orig_for_llt.size == 0:
        logging.error(f"Training data for LLT laws is empty for {symbol}")
        return "LLT Data Error", 0.0

    try:
        df = fetch_historical_data(symbol, timeframe='1m', total_candles=K + 60)
        if df is None or len(df) < K:
            return "Data Error", 0.0

        df_recent = df.iloc[-K:]
        if len(df_recent) != K:
            return "Data Error", 0.0

        df_recent_normalized = normalize_data(df_recent)
        if df_recent_normalized is None: return "Norm Error", 0.0
        if df_recent_normalized[FEATURES].isnull().any().any():
            df_recent_normalized = df_recent_normalized.fillna(0.0)

        X_predict_input = np.array([df_recent_normalized[FEATURES].values])
        _, X_predict_transformed = llt_transform(X_train_orig_for_llt, y_train_orig_for_llt, X_predict_input)

        if X_predict_transformed.size == 0 or X_predict_transformed.shape[0] != 1:
            return "Transform Error", 0.0
        if np.isnan(X_predict_transformed).any() or np.isinf(X_predict_transformed).any():
            X_predict_transformed = np.nan_to_num(X_predict_transformed, nan=0.0, posinf=0.0, neginf=0.0)

        try:
            probabilities = model.predict_proba(X_predict_transformed)
            if probabilities.shape[0] != 1 or probabilities.shape[1] != 2:
                return "Predict Error", 0.0

            prob_class_1 = probabilities[0, 1]
            prediction_label = "Rise" if prob_class_1 >= 0.5 else "Fall"
            confidence = prob_class_1 if prediction_label == "Rise" else probabilities[0, 0]
            return prediction_label, confidence

        except Exception as knn_e:
            logging.exception(f"Error during KNN prediction probability for {symbol}")
            return "Predict Error", 0.0

    except Exception as e:
        logging.exception(f"Error in predict_real_time for {symbol}")
        return "Error", 0.0

# --- TA Calculation Function (Using TA-Lib) ---
def calculate_ta_indicators(df_ta):
    """
    Calculates TA indicators (RSI, MACD, VWAP, ATR) using TA-Lib.
    Requires df_ta to have 'open', 'high', 'low', 'close', 'volume' columns.
    """
    indicators = {'RSI': np.nan, 'MACD': np.nan, 'MACD_Signal': np.nan, 'MACD_Hist': np.nan, 'VWAP': np.nan, 'ATR': np.nan}
    required_cols = ['open', 'high', 'low', 'close', 'volume']
    min_len_needed = max(RSI_PERIOD, MACD_SLOW, ATR_PERIOD) + 1  # TA-Lib often needs P+1

    if df_ta is None or len(df_ta) < min_len_needed:
        logging.warning(f"Insufficient data ({len(df_ta) if df_ta is not None else 0} < {min_len_needed}) for TA-Lib calculations.")
        return indicators

    # Ensure columns exist
    if not all(col in df_ta.columns for col in required_cols):
        logging.error(f"Missing required columns for TA-Lib: Have {df_ta.columns}, Need {required_cols}")
        return indicators

    # --- Prepare data for TA-Lib (NumPy arrays, handle NaNs) ---
    df_ta = df_ta[required_cols].copy()  # Work on a copy with only needed columns

    # Check for NaNs BEFORE converting to numpy, TA-Lib generally dislikes them
    if df_ta.isnull().values.any():
        nan_count = df_ta.isnull().sum().sum()
        logging.warning(f"Found {nan_count} NaN(s) in TA input data. Applying ffill()...")
        df_ta.ffill(inplace=True)  # Forward fill NaNs
        # Check again after ffill - if NaNs remain (e.g., at the start), need more robust handling
        if df_ta.isnull().values.any():
            logging.error("NaNs still present after ffill. Cannot proceed with TA-Lib.")
            return indicators  # Return NaNs

    try:
        # Convert to NumPy arrays of type float
        open_p = df_ta['open'].values.astype(float)
        high_p = df_ta['high'].values.astype(float)
        low_p = df_ta['low'].values.astype(float)
        close_p = df_ta['close'].values.astype(float)
        volume_p = df_ta['volume'].values.astype(float)

        # --- Calculate Indicators using TA-Lib ---
        # RSI
        rsi_values = talib.RSI(close_p, timeperiod=RSI_PERIOD)
        indicators['RSI'] = rsi_values[-1] if len(rsi_values) > 0 else np.nan

        # MACD
        macd_line, signal_line, hist = talib.MACD(close_p, fastperiod=MACD_FAST, slowperiod=MACD_SLOW, signalperiod=MACD_SIGNAL)
        indicators['MACD'] = macd_line[-1] if len(macd_line) > 0 else np.nan
        indicators['MACD_Signal'] = signal_line[-1] if len(signal_line) > 0 else np.nan
        indicators['MACD_Hist'] = hist[-1] if len(hist) > 0 else np.nan

        # ATR
        atr_values = talib.ATR(high_p, low_p, close_p, timeperiod=ATR_PERIOD)
        indicators['ATR'] = atr_values[-1] if len(atr_values) > 0 else np.nan

        # VWAP (Manual Calculation - TA-Lib doesn't have it built-in)
        typical_price = (high_p + low_p + close_p) / 3.0
        tp_vol = typical_price * volume_p
        cumulative_volume = np.cumsum(volume_p)
        # Avoid division by zero if volume is zero for initial periods
        if cumulative_volume[-1] > 1e-12:  # Check if there's significant volume
            vwap_values = np.cumsum(tp_vol) / np.maximum(cumulative_volume, 1e-12)  # Avoid div by zero strictly
            indicators['VWAP'] = vwap_values[-1]
        else:
            indicators['VWAP'] = np.nan  # VWAP undefined if no volume

        # Final check for NaNs in results (TA-Lib might return NaN for initial periods)
        for key, value in indicators.items():
            if pd.isna(value):
                indicators[key] = np.nan  # Ensure consistent NaN representation

        # logging.debug(f"TA-Lib Indicators calculated: {indicators}")
        return indicators

    except Exception as ta_e:
        logging.exception(f"Error calculating TA indicators using TA-Lib: {ta_e}")
        return {k: np.nan for k in indicators}  # Return NaNs on error

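# Usage sketch (hypothetical, for illustration only): the helper can be
# exercised offline with the synthetic generator defined above, e.g.
#   df_demo = generate_synthetic_data('TEST/USDT', total_candles=TA_DATA_POINTS)
#   print(calculate_ta_indicators(df_demo))
# which prints a dict with 'RSI', 'MACD', 'MACD_Signal', 'MACD_Hist', 'VWAP'
# and 'ATR' keys (NaN wherever a value could not be computed).
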
# --- Trade Level Calculation (Unchanged) ---
def calculate_trade_levels(prediction, confidence, current_price, atr):
    # (Keep this function as is - no changes needed)
    levels = {'Entry': np.nan, 'TP1': np.nan, 'TP2': np.nan, 'SL': np.nan}
    if pd.isna(current_price) or current_price <= 0 or pd.isna(atr) or atr <= 0:
        return levels
    if prediction == "Rise" and confidence >= CONFIDENCE_THRESHOLD:
        entry_price = current_price
        levels['Entry'] = entry_price
        levels['TP1'] = entry_price + TP1_ATR_MULTIPLIER * atr
        levels['TP2'] = entry_price + TP2_ATR_MULTIPLIER * atr
        levels['SL'] = entry_price - SL_ATR_MULTIPLIER * atr
        levels['SL'] = max(0.01, levels['SL'])
    # Add Fall logic here if needed
    return levels

# --- Concurrency Wrappers (Unchanged) ---
def train_model_task(coin):
    # (Keep this function as is)
    try:
        result = train_model(coin)
        if result != (None, None, None):
            model, X_train_orig, y_train_orig = result
            return coin, (model, X_train_orig, y_train_orig)
        else:
            return coin, None
    except Exception as e:
        logging.exception(f"Unhandled exception in train_model_task for {coin}")
        return coin, None

def train_all_models(coin_list=None, num_workers=NUM_WORKERS_TRAINING):
    # (Keep this function as is)
    global trained_models
    start_time = time.time()
    if coin_list is None or not coin_list:
        logging.info("No coin list provided, fetching top coins by volume...")
        try:
            coin_list = get_all_usdt_pairs()
            if not coin_list:
                msg = "Failed to fetch coin list even with fallback. Training aborted."
                logging.error(msg)
                return msg
        except Exception as e:
            msg = f"Error fetching coin list: {e}. Training aborted."
            logging.exception(msg)
            return msg

    logging.info(f"Starting training for {len(coin_list)} coins using {num_workers} workers...")
    results_log = []
    successful_trains = 0
    failed_trains = 0
    new_models = {}

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers, thread_name_prefix='TrainWorker') as executor:
        future_to_coin = {executor.submit(train_model_task, coin): coin for coin in coin_list}
        processed_count = 0
        total_coins = len(coin_list)
        for future in concurrent.futures.as_completed(future_to_coin):
            processed_count += 1
            coin = future_to_coin[future]
            try:
                returned_coin, model_data = future.result()
                if returned_coin == coin and model_data is not None:
                    new_models[returned_coin] = model_data
                    results_log.append(f"✅ {returned_coin}: Model trained successfully.")
                    successful_trains += 1
                else:
                    results_log.append(f"❌ {coin}: Model training failed (check logs).")
                    failed_trains += 1
            except Exception as e:
                results_log.append(f"❌ {coin}: Training task generated exception: {e}")
                failed_trains += 1
                logging.exception(f"Exception from training future for {coin}")
            if processed_count % 10 == 0 or processed_count == total_coins:
                logging.info(f"Training progress: {processed_count}/{total_coins} coins processed.")
                logging.getLogger().handlers[0].flush()

    trained_models.update(new_models)
    logging.info(f"Updated global models dictionary. Total models now: {len(trained_models)}")

    end_time = time.time()
    duration = end_time - start_time
    completion_message = (
        f"Training run completed in {duration:.2f} seconds.\n"
        f"Successfully trained: {successful_trains}\n"
        f"Failed to train: {failed_trains}\n"
        f"Total models available now: {len(trained_models)}"
    )
    logging.info(completion_message)
    return completion_message + "\n\n" + "\n".join(results_log[-20:])

# --- Update Predictions Table (Mostly Unchanged, uses new TA function) ---
def update_predictions_table():
    # (This function structure remains the same, it just calls the new calculate_ta_indicators)
    global last_update_time
    logging.info("--- Updating Predictions Table ---")
    start_time = time.time()
    predictions_data = {}
    current_models = trained_models.copy()

    if not current_models:
        msg = "No models available. Please train first."
        logging.warning(msg)
        cols = ['Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR']
        return pd.DataFrame([], columns=cols), msg

    symbols_with_models = list(current_models.keys())
    logging.info(f"Step 1: Generating predictions for {len(symbols_with_models)} models...")
    # --- Stage 1: Get Predictions Concurrently ---
    with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_WORKERS_PREDICTION, thread_name_prefix='PredictWorker') as executor:
        future_to_coin_pred = {executor.submit(predict_real_time, coin, model_data): coin for coin, model_data in current_models.items()}
        pred_success = 0
        pred_fail = 0
        for future in concurrent.futures.as_completed(future_to_coin_pred):
            coin = future_to_coin_pred[future]
            try:
                pred, conf = future.result()
                if pred not in ["Model N/A", "Model Error", "Data Error", "Norm Error", "LLT Data Error", "Transform Error", "Predict Error", "Error"]:
                    predictions_data[coin] = {'prediction': pred, 'confidence': float(conf)}
                    pred_success += 1
                else:
                    predictions_data[coin] = {'prediction': pred, 'confidence': 0.0}
                    pred_fail += 1
            except Exception as e:
                logging.exception(f"Error getting prediction result for {coin}")
                predictions_data[coin] = {'prediction': "Future Error", 'confidence': 0.0}
                pred_fail += 1
    logging.info(f"Step 1 Complete: Predictions generated ({pred_success} success, {pred_fail} fail).")

    # --- Stage 2: Fetch Current Tickers & TA Data Concurrently ---
    symbols_to_fetch_data = list(predictions_data.keys())
    if not symbols_to_fetch_data:
        logging.warning("No symbols with predictions to fetch data for.")
        cols = ['Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR']
        return pd.DataFrame([], columns=cols), "No symbols processed."

    logging.info(f"Step 2: Fetching Tickers and {TA_DATA_POINTS} OHLCV candles for {len(symbols_to_fetch_data)} symbols...")
    tickers_data = {}
    ohlcv_data = {}
    try:  # Fetch Tickers
        batch_size_tickers = 100
        fetched_tickers_batch = {}
        for i in range(0, len(symbols_to_fetch_data), batch_size_tickers):
            batch_symbols = symbols_to_fetch_data[i:i+batch_size_tickers]
            try:
                batch_tickers = exchange.fetch_tickers(symbols=batch_symbols)
                fetched_tickers_batch.update(batch_tickers)
                time.sleep(exchange.rateLimit / 1000 * 0.5)
            except Exception as e:
                logging.error(f"Failed to fetch ticker batch starting with {batch_symbols[0]}: {e}")
        tickers_data = fetched_tickers_batch
        logging.info(f"Fetched {len(tickers_data)} tickers.")
    except Exception as e:
        logging.exception(f"Error fetching tickers in prediction update: {e}")

    # Fetch OHLCV for TA
    with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_WORKERS_PREDICTION, thread_name_prefix='TADataWorker') as executor:
        future_to_coin_ohlcv = {executor.submit(fetch_historical_data, coin, '1m', TA_DATA_POINTS): coin for coin in symbols_to_fetch_data}
        for future in concurrent.futures.as_completed(future_to_coin_ohlcv):
            coin = future_to_coin_ohlcv[future]
            try:
                df_ta = future.result()
                if df_ta is not None and len(df_ta) == TA_DATA_POINTS:
                    # Ensure standard column names expected by calculate_ta_indicators
                    df_ta.columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
                    ohlcv_data[coin] = df_ta
            except Exception as e:
                logging.exception(f"Error fetching TA OHLCV data for {coin}")
    logging.info(f"Step 2 Complete: Fetched TA data for {len(ohlcv_data)} symbols.")

    # --- Stage 3: Calculate TA & Trade Levels ---
    logging.info("Step 3: Calculating TA (using TA-Lib) and Trade Levels...")
    final_results = []
    processing_time = datetime.now(timezone.utc)

    for symbol in symbols_to_fetch_data:
        pred_info = predictions_data.get(symbol, {'prediction': 'Missing Pred', 'confidence': 0.0})
        ticker = tickers_data.get(symbol)
        df_ta = ohlcv_data.get(symbol)  # This df should have standard columns now

        current_price, quote_volume = np.nan, np.nan
        ta_indicators = {k: np.nan for k in ['RSI', 'MACD', 'MACD_Signal', 'MACD_Hist', 'VWAP', 'ATR']}
        trade_levels = {k: np.nan for k in ['Entry', 'TP1', 'TP2', 'SL']}
        entry_time, exit_time = pd.NaT, pd.NaT

        if ticker and isinstance(ticker, dict):
            current_price = ticker.get('last', np.nan)
            quote_volume = ticker.get('info', {}).get('quoteVolume')
            if quote_volume is None:
                base_volume = ticker.get('baseVolume')
                if base_volume is not None and current_price is not None:
                    try: quote_volume = float(base_volume) * float(current_price)
                    except (ValueError, TypeError): quote_volume = np.nan
            try: current_price = float(current_price) if current_price is not None else np.nan
            except (ValueError, TypeError): current_price = np.nan
            try: quote_volume = float(quote_volume) if quote_volume is not None else np.nan
            except (ValueError, TypeError): quote_volume = np.nan

        # Calculate TA using the new function
        if df_ta is not None:
            ta_indicators = calculate_ta_indicators(df_ta)  # Calls the TA-Lib version

        if pred_info['prediction'] in ["Rise", "Fall"] and not pd.isna(current_price) and not pd.isna(ta_indicators['ATR']):
            trade_levels = calculate_trade_levels(pred_info['prediction'], pred_info['confidence'], current_price, ta_indicators['ATR'])
            if not pd.isna(trade_levels['Entry']):
                entry_time = processing_time
                exit_time = processing_time + timedelta(hours=PREDICTION_WINDOW_HOURS)

        final_results.append({
            'coin': symbol.split('/')[0], 'full_symbol': symbol,
            'prediction': pred_info['prediction'], 'confidence': pred_info['confidence'],
            'price': current_price, 'volume': quote_volume,
            'entry': trade_levels['Entry'], 'entry_time': entry_time, 'exit_time': exit_time,
            'tp1': trade_levels['TP1'], 'tp2': trade_levels['TP2'], 'sl': trade_levels['SL'],
            'rsi': ta_indicators['RSI'], 'macd_hist': ta_indicators['MACD_Hist'],
            'vwap': ta_indicators['VWAP'], 'atr': ta_indicators['ATR']
        })
    logging.info("Step 3 Complete: TA and Trade Levels calculated.")

    # --- Stage 4: Sort and Format (Unchanged) ---
    def sort_key(item):
        pred, conf = item['prediction'], item['confidence']
        if pred == "Rise" and conf >= CONFIDENCE_THRESHOLD and not pd.isna(item['entry']): return (0, -conf)
        elif pred == "Rise": return (1, -conf)
        elif pred == "Fall": return (2, -conf)
        else: return (3, 0)
    final_results.sort(key=sort_key)

    formatted_output = []
    for i, p in enumerate(final_results[:MAX_COINS_TO_DISPLAY]):
        formatted_output.append([
            i + 1, p['coin'], p['prediction'], f"{p['confidence']:.3f}",
            f"{p['price']:.4f}" if not pd.isna(p['price']) else "N/A",
            f"{p['volume']:,.0f}" if not pd.isna(p['volume']) else "N/A",
            f"{p['entry']:.4f}" if not pd.isna(p['entry']) else "N/A",
            format_datetime(p['entry_time'], "N/A"), format_datetime(p['exit_time'], "N/A"),
            f"{p['tp1']:.4f}" if not pd.isna(p['tp1']) else "N/A",
            f"{p['tp2']:.4f}" if not pd.isna(p['tp2']) else "N/A",
            f"{p['sl']:.4f}" if not pd.isna(p['sl']) else "N/A",
            f"{p['rsi']:.2f}" if not pd.isna(p['rsi']) else "N/A",
            f"{p['macd_hist']:.4f}" if not pd.isna(p['macd_hist']) else "N/A",
            f"{p['vwap']:.4f}" if not pd.isna(p['vwap']) else "N/A",
            f"{p['atr']:.4f}" if not pd.isna(p['atr']) else "N/A",
        ])

    output_columns = [
        'Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)',
        'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL',
        'RSI', 'MACD Hist', 'VWAP', 'ATR'
    ]
    output_df = pd.DataFrame(formatted_output, columns=output_columns)

    end_time = time.time()
    duration = end_time - start_time
    last_update_time = processing_time
    status_message = f"Predictions updated ({len(final_results)} symbols processed) in {duration:.2f}s. Last update: {format_datetime(last_update_time)}"
    logging.info(status_message)

    return output_df, status_message

# --- Gradio UI Handlers (Unchanged) ---
def handle_train_click(coin_input, num_workers):
    # (Keep this function as is)
    logging.info(f"Train button clicked. Workers: {num_workers}")
    coins = None
    num_workers = int(num_workers)
    if coin_input and coin_input.strip():
        raw_coins = coin_input.replace(',', ' ').split()
        coins = []
        valid = True
        for c in raw_coins:
            coin_upper = c.strip().upper()
            if '/' not in coin_upper: coin_upper += '/USDT'
            if coin_upper.endswith('/USDT'): coins.append(coin_upper)
            else:
                valid = False
                logging.error(f"Invalid coin format: {c}. Must be SYMBOL or SYMBOL/USDT.")
                break
        if not valid: return "Error: Custom coins must be valid SYMBOL or SYMBOL/USDT pairs."
        logging.info(f"Training requested for custom coin list: {coins}")
    else:
        logging.info("Training requested for top coins by volume.")
    train_status = train_all_models(coin_list=coins, num_workers=num_workers)
    return f"--- Training Run ---:\n{train_status}\n\n---> Press 'Refresh Predictions' <---"

def handle_refresh_click():
    # (Keep this function as is)
    logging.info("Refresh button clicked.")
    try:
        df, status = update_predictions_table()
        return df, status
    except Exception as e:
        logging.exception("Error during handle_refresh_click")
        cols = ['Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR']
        return pd.DataFrame([], columns=cols), f"Error updating predictions: {e}"

# --- Gradio Interface Definition (Unchanged) ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Cryptocurrency Prediction & TA Signal Explorer (LLT-KNN + TA-Lib)")  # Updated title slightly
    gr.Markdown(f"""
    Predicts **{PREDICTION_WINDOW_HOURS}-hour** price direction (Rise/Fall) using LLT-KNN.
    Displays current price, volume, TA indicators (RSI, MACD, VWAP, ATR calculated using **TA-Lib**), and potential trade levels for **Rise** signals meeting confidence >= **{CONFIDENCE_THRESHOLD}**.
    TP/SL levels based on **{TP1_ATR_MULTIPLIER}x / {TP2_ATR_MULTIPLIER}x / {SL_ATR_MULTIPLIER}x ATR({ATR_PERIOD})**.
    **Warning:** Educational. High risk. Not financial advice. Ensure TA-Lib is correctly installed.
    """)

    with gr.Row():
        with gr.Column(scale=4):
            prediction_df = gr.Dataframe(
                headers=[
                    'Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)',
                    'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL',
                    'RSI', 'MACD Hist', 'VWAP', 'ATR'
                ],
                datatype=[
                    'number', 'str', 'str', 'str', 'str', 'str',
                    'str', 'str', 'str', 'str', 'str', 'str',
                    'str', 'str', 'str', 'str'
                ],
                row_count=15, col_count=(16, "fixed"), label="Predictions & TA Signals", wrap=True,
            )
        with gr.Column(scale=1):
            with gr.Accordion("Train Models", open=True):
                coin_input = gr.Textbox(label="Train Specific Coins (e.g., BTC, ETH/USDT)", placeholder="Leave empty for top coins by volume")
                max_workers_slider = gr.Slider(minimum=1, maximum=10, value=NUM_WORKERS_TRAINING, step=1, label="Parallel Training Workers")
                train_button = gr.Button("Start Training", variant="primary")
            refresh_button = gr.Button("Refresh Predictions", variant="secondary")
            status_text = gr.Textbox(label="Status Log", lines=15, interactive=False, max_lines=30)

    gr.Markdown(
        """
        ## Notes
        - **TA-Lib**: This version uses the TA-Lib library for indicators. Ensure it's installed correctly (can be tricky).
        - **Data**: Fetches OHLCV data (Bitget, 1-min). Uses cache. Handles rate limits.
        - **Training**: Uses past ~14h data (12h train, 2h predict). Normalizes, balances classes, applies LLT, trains KNN.
        - **Prediction**: Uses latest 12h data for KNN input.
        - **Trade Levels**: Only shown for 'Rise' predictions above confidence threshold. Based on current price and ATR volatility. **Highly speculative.**
        - **Sorting**: Table sorted by (Potential Rise Signals > Other Rise > Fall > Errors), then by confidence descending.
        - **Refresh**: Fetches latest prices/TA and re-evaluates signals.
        """
    )
    train_button.click(fn=handle_train_click, inputs=[coin_input, max_workers_slider], outputs=status_text)
    refresh_button.click(fn=handle_refresh_click, inputs=None, outputs=[prediction_df, status_text])

# --- Startup Initialization (Unchanged) ---
def initialize_models_on_startup():
    # (Keep this function as is)
    logging.info("----- Initializing Models (Startup Thread) -----")
    default_coins = ['BTC/USDT', 'ETH/USDT', 'SOL/USDT', 'XRP/USDT', 'DOGE/USDT']
    try:
        initial_status = train_all_models(default_coins, num_workers=2)
        logging.info("----- Initial Model Training Complete -----")
        logging.info(initial_status)
    except Exception as e:
        logging.exception("Error during startup initialization.")

# --- Main Execution (Unchanged) ---
if __name__ == "__main__":
    logging.info("Starting application...")
    # Check if TA-Lib import worked (basic check)
    try:
        # Try accessing a TA-Lib function
        _ = talib.RSI(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]))
        logging.info("TA-Lib library seems accessible.")
    except NameError:
        logging.error("FATAL: TA-Lib library not found or import failed. Please install it correctly.")
        sys.exit(1)
    except Exception as ta_init_e:
        logging.error(f"FATAL: Error testing TA-Lib library: {ta_init_e}. Please check installation.")
        sys.exit(1)

    init_thread = threading.Thread(target=initialize_models_on_startup, name="StartupTrainThread", daemon=True)
    init_thread.start()

    logging.info("Launching Gradio Interface...")
    try:
        demo.launch(server_name="0.0.0.0")
    except Exception as e:
        logging.exception("Failed to launch Gradio interface.")
    finally:
        logging.info("Gradio Interface stopped.")