# Install TA-Lib (see instructions above) then: pip install TA-Lib
import ccxt
import numpy as np
import pandas as pd
import time
from sklearn.neighbors import KNeighborsClassifier
from scipy.linalg import svd
import gradio as gr
import concurrent.futures
import traceback
from datetime import datetime, timezone, timedelta
import logging
import sys
import talib  # Import TA-Lib
import threading

# --- Setup Logging ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - [%(threadName)s:%(funcName)s] - %(message)s',
    stream=sys.stdout
)
logging.getLogger().handlers[0].flush = sys.stdout.flush

# --- Parameters ---
L = 10
LAG = 11
MINUTES_PER_HOUR = 60
PREDICTION_WINDOW_HOURS = 2
TRAINING_WINDOW_HOURS = 12
TOTAL_WINDOW_HOURS = TRAINING_WINDOW_HOURS + PREDICTION_WINDOW_HOURS
K = TRAINING_WINDOW_HOURS * MINUTES_PER_HOUR      # 720
WINDOW = TOTAL_WINDOW_HOURS * MINUTES_PER_HOUR    # 840
FEATURES = ['open', 'high', 'low', 'close', 'volume']
D = 5
OVERLAP_STEP = 60
MIN_TRAINING_EXAMPLES = 20
MAX_COINS_TO_DISPLAY = 10
USE_SYNTHETIC_DATA_FOR_LOW_VOLUME = False
NUM_WORKERS_TRAINING = 4
NUM_WORKERS_PREDICTION = 10

# --- TA & Risk Parameters ---
TA_DATA_POINTS = 200          # Candles needed for TA calculation
RSI_PERIOD = 14
MACD_FAST = 12
MACD_SLOW = 26
MACD_SIGNAL = 9
ATR_PERIOD = 14
CONFIDENCE_THRESHOLD = 0.65   # Min confidence for Rise signal
TP1_ATR_MULTIPLIER = 1.5
TP2_ATR_MULTIPLIER = 3.0
SL_ATR_MULTIPLIER = 1.0

# --- CCXT Initialization ---
try:
    exchange = ccxt.bitget({
        'enableRateLimit': True,
        'rateLimit': 1100,
        'timeout': 45000,
        'options': {'adjustForTimeDifference': True}
    })
    logging.info(f"Initialized {exchange.id} exchange.")
except Exception as e:
    logging.exception("FATAL: Could not initialize CCXT exchange.")
    sys.exit()

# --- Global Caches and Variables ---
markets_cache = None
last_markets_update = None
data_cache = {}
trained_models = {}
last_update_time = datetime.now(timezone.utc)

# --- Functions ---
def format_datetime(dt, default="N/A"):
    # (Keep this function as is)
    if pd.isna(dt) or dt is None:
        return default
    try:
        if isinstance(dt, (datetime, pd.Timestamp)):
            if dt.tzinfo is None:
                dt = dt.replace(tzinfo=timezone.utc)
            return dt.strftime('%Y-%m-%d %H:%M:%S %Z')
        else:
            return str(dt)
    except Exception:
        return default

def get_all_usdt_pairs():
    # (Keep this function as is - no changes needed)
    global markets_cache, last_markets_update
    current_time = time.time()
    cache_duration = 3600  # 1 hour
    if markets_cache is not None and last_markets_update is not None:
        if current_time - last_markets_update < cache_duration:
            logging.info("Using cached markets list.")
            if isinstance(markets_cache, list) and markets_cache:
                return markets_cache
            else:
                logging.warning("Cached market list was invalid, fetching fresh.")
    logging.info("Fetching markets from Bitget...")
    try:
        exchange.load_markets(reload=True)
        all_symbols = list(exchange.markets.keys())
        usdt_pairs = [
            symbol for symbol in all_symbols
            if isinstance(symbol, str)
            and symbol.endswith('/USDT')
            and exchange.markets.get(symbol, {}).get('active', False)
            and exchange.markets.get(symbol, {}).get('spot', False)
            and 'LEVERAGED' not in exchange.markets.get(symbol, {}).get('type', 'spot').upper()
            and not exchange.markets.get(symbol, {}).get('inverse', False)
        ]
        logging.info(f"Found {len(usdt_pairs)} active USDT spot pairs initially.")
        if not usdt_pairs:
            logging.warning("No active USDT spot pairs found.")
            return ['BTC/USDT', 'ETH/USDT', 'SOL/USDT']
        logging.info(f"Fetching tickers for {len(usdt_pairs)} pairs for volume sorting...")
        volumes = {}
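        # Note on the ranking step that follows (illustrative numbers, not live data):
        # each ticker's 24h quote volume is read from ticker['info']['quoteVolume'] when
        # present; otherwise it is approximated as baseVolume * last price, e.g. a pair
        # reporting baseVolume=1500 and last=60000 would be ranked with ~90,000,000 USDT.
        # Pairs are then sorted by this value and only the top MAX_COINS_TO_DISPLAY are kept.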
symbols_to_fetch = usdt_pairs[:] fetched_tickers = {} try: if exchange.has['fetchTickers']: batch_size_tickers = 100 for i in range(0, len(symbols_to_fetch), batch_size_tickers): batch_symbols = symbols_to_fetch[i:i+batch_size_tickers] logging.info(f"Fetching ticker batch {i//batch_size_tickers + 1}/{ (len(symbols_to_fetch) + batch_size_tickers -1)//batch_size_tickers }...") retries = 2 for attempt in range(retries): try: batch_tickers = exchange.fetch_tickers(symbols=batch_symbols) fetched_tickers.update(batch_tickers) time.sleep(exchange.rateLimit / 1000 * 1.5) # Add delay break except (ccxt.RequestTimeout, ccxt.NetworkError) as e_timeout: logging.warning(f"Ticker fetch timeout/network error on attempt {attempt+1}/{retries}: {e_timeout}, retrying after delay...") time.sleep(3 * (attempt + 1)) except ccxt.RateLimitExceeded: logging.warning(f"Rate limit exceeded fetching tickers, sleeping...") time.sleep(10 * (attempt+1)) # Longer sleep for rate limit except Exception as e_ticker: logging.error(f"Error fetching ticker batch (attempt {attempt+1}): {e_ticker}") if attempt == retries - 1: raise # Rethrow last error time.sleep(2 * (attempt + 1)) logging.info(f"Fetched {len(fetched_tickers)} tickers using fetchTickers.") else: raise ccxt.NotSupported("fetchTickers not supported/enabled. Volume sorting requires it.") except Exception as e: logging.exception(f"Could not fetch tickers for volume sorting: {e}. Volume sorting unavailable.") markets_cache = usdt_pairs[:MAX_COINS_TO_DISPLAY] last_markets_update = current_time logging.warning(f"Returning top {len(markets_cache)} unsorted pairs due to ticker error.") return markets_cache for symbol, ticker in fetched_tickers.items(): try: quote_volume = ticker.get('info', {}).get('quoteVolume') # Prefer quoteVolume if available last_price = ticker.get('last') base_volume = ticker.get('baseVolume') # Ensure values are convertible to float before calculation valid_last = last_price is not None valid_base = base_volume is not None valid_quote = quote_volume is not None if valid_quote: volumes[symbol] = float(quote_volume) elif valid_base and valid_last: volumes[symbol] = float(base_volume) * float(last_price) else: volumes[symbol] = 0 except (TypeError, ValueError, KeyError, AttributeError) as e: logging.warning(f"Could not parse volume/price for {symbol} from ticker: {ticker}. Error: {e}") volumes[symbol] = 0 valid_volume_pairs = {k: v for k, v in volumes.items() if v > 0} logging.info(f"Found {len(valid_volume_pairs)} pairs with non-zero volume.") if not valid_volume_pairs: logging.warning("No pairs with valid volume found. Returning default list.") return ['BTC/USDT', 'ETH/USDT', 'SOL/USDT'] sorted_pairs = sorted(valid_volume_pairs.items(), key=lambda item: item[1], reverse=True) num_pairs_to_take = min(MAX_COINS_TO_DISPLAY, len(sorted_pairs)) top_pairs = [pair[0] for pair in sorted_pairs[:num_pairs_to_take]] logging.info(f"Selected Top {len(top_pairs)} pairs by volume. 
Top 5: {[p[0] for p in sorted_pairs[:5]]}") markets_cache = top_pairs last_markets_update = current_time return top_pairs except ccxt.NetworkError as e: logging.error(f"Network error getting USDT pairs: {e}") except ccxt.ExchangeError as e: logging.error(f"Exchange error getting USDT pairs: {e}") except Exception as e: logging.exception("General error getting USDT pairs.") logging.warning("Error fetching markets, returning default fallback list.") return ['BTC/USDT', 'ETH/USDT', 'SOL/USDT', 'BNB/USDT', 'XRP/USDT'] def clean_and_process_ohlcv(ohlcv_list, symbol, expected_candles): # (Keep this function as is - no changes needed) if not ohlcv_list: return None try: df = pd.DataFrame(ohlcv_list, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']) initial_len = len(df) if initial_len == 0: return None df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True) df = df.drop_duplicates(subset=['timestamp']) df = df.sort_values('timestamp') len_after_dupes = len(df) numeric_cols = ['open', 'high', 'low', 'close', 'volume'] for col in numeric_cols: df[col] = pd.to_numeric(df[col], errors='coerce') # Drop rows with NaN in essential price/volume features needed for TA-Lib df = df.dropna(subset=numeric_cols) len_after_na = len(df) df.reset_index(drop=True, inplace=True) logging.debug(f"Data cleaning for {symbol}: Initial Fetched={initial_len}, AfterDupes={len_after_dupes}, AfterNA={len_after_na}") if len(df) >= expected_candles: final_df = df.iloc[-expected_candles:].copy() # Take the most recent ones return final_df else: return None except Exception as e: logging.exception(f"Error processing DataFrame for {symbol}") return None def fetch_historical_data(symbol, timeframe='1m', total_candles=WINDOW): # (Keep this function as is - no changes needed) cache_key = f"{symbol}_{timeframe}_{total_candles}" current_time = time.time() cache_validity_seconds = 300 # 5 minutes if cache_key in data_cache: cache_time, cached_data = data_cache[cache_key] if current_time - cache_time < cache_validity_seconds: if isinstance(cached_data, pd.DataFrame) and len(cached_data) == total_candles: logging.debug(f"Using valid cached data for {symbol} ({len(cached_data)} candles)") return cached_data.copy() else: logging.warning(f"Cache for {symbol} invalid or wrong size ({len(cached_data) if isinstance(cached_data, pd.DataFrame) else 'N/A'} vs {total_candles}), fetching fresh.") if cache_key in data_cache: del data_cache[cache_key] if not exchange.has['fetchOHLCV']: logging.error(f"Exchange {exchange.id} does not support fetchOHLCV.") return None logging.debug(f"Fetching {total_candles} candles for {symbol} (timeframe: {timeframe})") final_df = None fetch_start_time = time.time() duration_ms = exchange.parse_timeframe(timeframe) * 1000 now_ms = exchange.milliseconds() # --- Strategy 1: Try Single Large Fetch --- single_fetch_limit = total_candles + 200 # Buffer single_fetch_since = now_ms - single_fetch_limit * duration_ms try: ohlcv_list = exchange.fetch_ohlcv(symbol, timeframe, limit=single_fetch_limit, since=single_fetch_since) if ohlcv_list: processed_df = clean_and_process_ohlcv(ohlcv_list, symbol, total_candles) if processed_df is not None and len(processed_df) == total_candles: final_df = processed_df except ccxt.RateLimitExceeded as e: logging.warning(f"Rate limit hit during single fetch for {symbol}, falling back: {e}") time.sleep(5) except (ccxt.RequestTimeout, ccxt.NetworkError) as e: logging.warning(f"Timeout/Network error during single fetch for {symbol}, falling back: {e}") time.sleep(2) 
except ccxt.ExchangeNotAvailable as e: logging.error(f"Exchange not available during fetch for {symbol}: {e}") return None except ccxt.AuthenticationError as e: logging.error(f"Authentication error fetching {symbol}: {e}") return None except ccxt.ExchangeError as e: logging.warning(f"Exchange error during single fetch for {symbol}, falling back: {e}") except Exception as e: logging.exception(f"Unexpected error during single fetch for {symbol}, falling back.") # --- Strategy 2: Fallback to Iterative Chunking --- if final_df is None: logging.debug(f"Falling back to iterative chunk fetching for {symbol}.") limit_per_call = exchange.safe_integer(exchange.limits.get('fetchOHLCV', {}), 'max', 1000) limit_per_call = min(limit_per_call, 1000) all_ohlcv_chunks = [] required_start_time_ms = now_ms - (total_candles + 5) * duration_ms current_chunk_end_time_ms = now_ms max_chunk_attempts = 15 attempts = 0 while attempts < max_chunk_attempts: attempts += 1 oldest_ts_in_hand = all_ohlcv_chunks[0][0] if all_ohlcv_chunks else current_chunk_end_time_ms if oldest_ts_in_hand <= required_start_time_ms: logging.debug(f"Chunking: Collected enough historical range for {symbol}.") break fetch_limit = limit_per_call chunk_fetch_since = oldest_ts_in_hand - fetch_limit * duration_ms params = {} try: ohlcv_chunk = exchange.fetch_ohlcv(symbol, timeframe, since=chunk_fetch_since, limit=fetch_limit, params=params) if not ohlcv_chunk: logging.debug(f"Chunking: No more data received for {symbol} from API.") break new_chunk = [c for c in ohlcv_chunk if c[0] < oldest_ts_in_hand] if not new_chunk: break new_chunk.sort(key=lambda x: x[0]) all_ohlcv_chunks = new_chunk + all_ohlcv_chunks if len(new_chunk) < limit_per_call // 20 and attempts > 5: logging.warning(f"Chunking: Received very few new candles ({len(new_chunk)}) repeatedly for {symbol}.") break time.sleep(exchange.rateLimit / 1000 * 1.1) except ccxt.RateLimitExceeded as e: logging.warning(f"Rate limit hit during chunking for {symbol}, sleeping 10s: {e}") time.sleep(10 * (attempts/3 + 1)) except (ccxt.NetworkError, ccxt.RequestTimeout) as e: logging.error(f"Network/Timeout error during chunking for {symbol}: {e}. Stopping.") break except ccxt.ExchangeError as e: logging.error(f"Exchange error during chunking for {symbol}: {e}. Stopping.") break except Exception as e: logging.exception(f"Generic error during chunking fetch for {symbol}") break if attempts >= max_chunk_attempts: logging.warning(f"Max chunk fetch attempts reached for {symbol}.") if all_ohlcv_chunks: processed_df = clean_and_process_ohlcv(all_ohlcv_chunks, symbol, total_candles) if processed_df is not None and len(processed_df) == total_candles: final_df = processed_df else: logging.error(f"No data obtained from chunk fetching for {symbol}.") # --- Final Check and Caching --- if final_df is not None and len(final_df) == total_candles: expected_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume'] if all(col in final_df.columns for col in expected_cols): data_cache[cache_key] = (current_time, final_df.copy()) return final_df else: logging.error(f"Final DataFrame for {symbol} missing expected columns. Won't cache.") return None else: logging.error(f"Failed to fetch exactly {total_candles} candles for {symbol}. Found: {len(final_df) if final_df is not None else 0}") return None # --- Embedding, LLT, Normalize, Training Prep (Largely unchanged) --- # Keep create_embedding, llt_transform, normalize_data, prepare_training_data, train_model # as they don't depend on the TA library choice. 
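# How the embedding/LLT step below works (rough sketch of shapes using the configured
# parameters only -- no live data or extra behaviour is assumed):
#   create_embedding with L=10, LAG=11 turns a length-720 series (K minutes) into a matrix A
#   with 720 - (10 - 1) * 11 = 621 rows; row t holds samples t, t+11, ..., t+99.
#   llt_transform then computes S = A.T @ A (10x10) per feature, takes the right-singular
#   vector of S with the smallest singular value as that feature's "law", and groups the
#   laws by class label (0 = Fall, 1 = Rise). At transform time each window is projected
#   onto the class laws (S @ V) and the D=5 smallest column variances per (feature, class)
#   are kept, giving a len(FEATURES) * 2 * D = 50-dimensional vector for the KNN classifier.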
def create_embedding(data, l=L, lag=LAG): # (Keep this function as is) n = len(data) rows = n - (l - 1) * lag if rows <= 0: logging.debug(f"Cannot create embedding: data length {n} too short for L={l}, Lag={lag}") return np.array([]) A = np.zeros((rows, l)) try: for t in range(rows): indices = t + np.arange(l) * lag A[t] = data[indices] return A except IndexError as e: logging.error(f"IndexError during embedding: n={n}, l={l}, lag={lag}. Error: {e}") return np.array([]) except Exception as e: logging.exception("Error in create_embedding") return np.array([]) def llt_transform(X_train, y_train, X_test): # (Keep this function as is) if not isinstance(X_train, np.ndarray) or X_train.ndim != 3 or \ not isinstance(y_train, np.ndarray) or y_train.ndim != 1 or \ not isinstance(X_test, np.ndarray) or (X_test.size > 0 and X_test.ndim != 3): logging.error(f"LLT input type/shape error.") return np.array([]), np.array([]) if X_train.shape[0] != y_train.shape[0]: logging.error(f"LLT input mismatch: len(X_train) != len(y_train)") return np.array([]), np.array([]) if X_train.size == 0 or y_train.size == 0: logging.error("LLT requires non-empty training data.") return np.array([]), np.array([]) if X_test.size > 0 and X_test.shape[1:] != X_train.shape[1:]: logging.error(f"LLT train/test shape mismatch") return np.array([]), np.array([]) try: num_features = X_train.shape[2] if num_features != len(FEATURES): logging.error(f"LLT: Feature count mismatch.") return np.array([]), np.array([]) V = {j: {'0': [], '1': []} for j in range(num_features)} laws_computed_count = {j: {'0': 0, '1': 0} for j in range(num_features)} for i in range(len(X_train)): label = str(int(y_train[i])) if label not in ['0', '1']: continue for j in range(num_features): feature_data = X_train[i, :, j] A = create_embedding(feature_data, l=L, lag=LAG) if A.shape[0] < L: continue if np.isnan(A).any() or np.isinf(A).any(): continue try: S = A.T @ A if np.isnan(S).any() or np.isinf(S).any(): continue U, s, Vt = svd(S, full_matrices=False) if Vt.shape[0] < L or Vt.shape[1] != L: continue if s[-1] < 1e-9: continue v = Vt[-1] norm = np.linalg.norm(v) if norm < 1e-9: continue V[j][label].append(v / norm) laws_computed_count[j][label] += 1 except np.linalg.LinAlgError: pass except Exception: pass valid_laws_exist = False for j in V: for c in ['0', '1']: if laws_computed_count[j][c] > 0: valid_vecs = [vec for vec in V[j][c] if isinstance(vec, np.ndarray) and vec.shape == (L,)] if not valid_vecs: V[j][c] = np.zeros((L, 0)) continue try: V[j][c] = np.array(valid_vecs).T if V[j][c].shape[0] != L: V[j][c] = np.zeros((L, 0)) else: valid_laws_exist = True except Exception: V[j][c] = np.zeros((L, 0)) else: V[j][c] = np.zeros((L, 0)) if not valid_laws_exist: logging.error("LLT ERROR: No valid laws computed.") return np.array([]), np.array([]) def transform_instance(X_instance): transformed_features = [] if X_instance.ndim != 2 or X_instance.shape[0] != K or X_instance.shape[1] != num_features: return np.zeros(num_features * 2 * D) for j in range(num_features): feature_data = X_instance[:, j] A = create_embedding(feature_data, l=L, lag=LAG) if A.shape[0] < L: transformed_features.extend([0.0] * (2 * D)) continue if np.isnan(A).any() or np.isinf(A).any(): transformed_features.extend([0.0] * (2 * D)) continue try: S = A.T @ A if np.isnan(S).any() or np.isinf(S).any(): transformed_features.extend([0.0] * (2 * D)) continue for c in ['0', '1']: if V[j][c].shape[1] == 0: transformed_features.extend([0.0] * D) continue S_V = S @ V[j][c] if S_V.size == 0 or 
np.isnan(S_V).any() or np.isinf(S_V).any(): transformed_features.extend([0.0] * D) continue variances = np.var(S_V, axis=0) if variances.size == 0: transformed_features.extend([0.0] * D) continue variances = np.nan_to_num(variances, nan=np.finfo(variances.dtype).max, posinf=np.finfo(variances.dtype).max, neginf=np.finfo(variances.dtype).max) num_vars_available = variances.size num_vars_to_select = min(D, num_vars_available) smallest_indices = np.argpartition(variances, num_vars_to_select -1)[:num_vars_to_select] smallest_vars = np.sort(variances[smallest_indices]) padded_vars = np.pad(smallest_vars, (0, D - num_vars_to_select), 'constant', constant_values=0.0) if np.isnan(padded_vars).any() or np.isinf(padded_vars).any(): padded_vars = np.nan_to_num(padded_vars, nan=0.0, posinf=0.0, neginf=0.0) transformed_features.extend(padded_vars) except Exception: current_len = len(transformed_features) expected_len_after_feature = (j + 1) * 2 * D num_missing = expected_len_after_feature - current_len if num_missing > 0: transformed_features.extend([0.0] * num_missing) transformed_features = transformed_features[:expected_len_after_feature] correct_len = num_features * 2 * D if len(transformed_features) != correct_len: if len(transformed_features) < correct_len: transformed_features.extend([0.0] * (correct_len - len(transformed_features))) else: transformed_features = transformed_features[:correct_len] return np.array(transformed_features) X_train_t = np.array([transform_instance(X) for X in X_train]) X_test_t = np.array([]) if X_test.size > 0: X_test_t = np.array([transform_instance(X) for X in X_test]) expected_dim = num_features * 2 * D if X_train_t.shape[0] != len(X_train) or (X_train_t.size > 0 and X_train_t.shape[1] != expected_dim): logging.error(f"LLT Train transform resulted in unexpected shape.") return np.array([]), np.array([]) if X_test.size > 0 and (X_test_t.shape[0] != len(X_test) or (X_test_t.size > 0 and X_test_t.shape[1] != expected_dim)): logging.error(f"LLT Test transform resulted in unexpected shape.") return X_train_t, np.array([]) return X_train_t, X_test_t except Exception as e: logging.exception("Error in llt_transform function") return np.array([]), np.array([]) def normalize_data(df): # (Keep this function as is) normalized_df = df.copy() if not isinstance(df, pd.DataFrame): logging.error("Normalize_data received non-DataFrame input.") return None for feature in FEATURES: if feature == 'timestamp': continue if feature not in df.columns: normalized_df[feature] = 0.0 continue if pd.api.types.is_numeric_dtype(df[feature]): mean = df[feature].mean() std = df[feature].std() if std is not None and not pd.isna(std) and std > 1e-9: normalized_df[feature] = (df[feature] - mean) / std else: normalized_df[feature] = 0.0 if normalized_df[feature].isnull().any(): normalized_df[feature] = normalized_df[feature].fillna(0.0) else: normalized_df[feature] = 0.0 return normalized_df def generate_synthetic_data(symbol, total_candles=WINDOW): # (Keep this function as is) logging.info(f"Generating synthetic data for {symbol} ({total_candles} candles)") np.random.seed(int(time.time() * 1000) % (2**32 - 1)) end_time = pd.Timestamp.now(tz='UTC') timestamps = pd.date_range(end=end_time, periods=total_candles, freq='T') volatility = np.random.uniform(0.005, 0.03) base_price = np.random.uniform(1, 5000) prices = [base_price] for _ in range(1, total_candles): change = np.random.normal(0, volatility / np.sqrt(1440)) prices.append(prices[-1] * (1 + change)) prices = np.maximum(0.01, prices) close_prices = 
np.array(prices) open_prices = close_prices * (1 + np.random.normal(0, volatility / np.sqrt(1440) / 2, total_candles)) high_prices = np.maximum(close_prices, open_prices) * (1 + np.random.uniform(0, volatility / np.sqrt(1440), total_candles)) low_prices = np.minimum(close_prices, open_prices) * (1 - np.random.uniform(0, volatility / np.sqrt(1440), total_candles)) high_prices = np.maximum.reduce([high_prices, open_prices, close_prices]) low_prices = np.minimum.reduce([low_prices, open_prices, close_prices]) volumes = np.random.poisson(base_price * np.random.uniform(1, 10)) * (1 + np.abs(np.diff(close_prices, prepend=close_prices[0])) / close_prices * 5) volumes = np.maximum(1, volumes) df = pd.DataFrame({ 'timestamp': timestamps, 'open': open_prices, 'high': high_prices, 'low': low_prices, 'close': close_prices, 'volume': volumes }) for col in FEATURES: df[col] = pd.to_numeric(df[col]) df.reset_index(drop=True, inplace=True) return df def prepare_training_data(symbol, total_candles_to_fetch=WINDOW + OVERLAP_STEP * 20): # (Keep this function as is) logging.info(f"Preparing training data for {symbol}...") try: required_base_candles = WINDOW estimated_candles_needed = required_base_candles + (MIN_TRAINING_EXAMPLES * 2) * OVERLAP_STEP + 500 fetch_candle_count = max(WINDOW + 500, estimated_candles_needed) logging.info(f"Fetching {fetch_candle_count} candles for {symbol} training prep...") df = fetch_historical_data(symbol, timeframe='1m', total_candles=fetch_candle_count) if df is None or len(df) < WINDOW: logging.error(f"Insufficient data fetched for {symbol} ({len(df) if df is not None else 0} < {WINDOW}).") if USE_SYNTHETIC_DATA_FOR_LOW_VOLUME: logging.warning(f"Attempting synthetic data generation for {symbol}.") df = generate_synthetic_data(symbol, total_candles=WINDOW + OVERLAP_STEP * 10) if df is None or len(df) < WINDOW: logging.error(f"Synthetic data generation failed or insufficient for {symbol}.") return None, None else: logging.info(f"Using synthetic data ({len(df)} points) for {symbol}.") else: return None, None df_normalized = normalize_data(df) if df_normalized is None: logging.error(f"Normalization failed for {symbol}.") return None, None if df_normalized[FEATURES].isnull().any().any(): logging.warning(f"NaN values found after normalization for {symbol}. 
Filling with 0.") df_normalized = df_normalized.fillna(0.0) X, y = [], [] end_index = len(df) start_index = WINDOW num_windows_created = 0 for i in range(end_index, start_index - 1, -OVERLAP_STEP): window_end_idx = i window_start_idx = i - WINDOW if window_start_idx < 0: continue window_orig = df.iloc[window_start_idx:window_end_idx] window_norm = df_normalized.iloc[window_start_idx:window_end_idx] if len(window_orig) != WINDOW or len(window_norm) != WINDOW: continue input_data_norm = window_norm.iloc[:K][FEATURES].values if input_data_norm.shape[0] != K or input_data_norm.shape[1] != len(FEATURES): continue if np.isnan(input_data_norm).any(): continue start_price_iloc_idx = K - 1 end_price_iloc_idx = WINDOW - 1 start_price = window_orig['close'].iloc[start_price_iloc_idx] end_price = window_orig['close'].iloc[end_price_iloc_idx] if pd.isna(start_price) or pd.isna(end_price) or start_price <= 0: continue X.append(input_data_norm) y.append(1 if end_price > start_price else 0) num_windows_created += 1 if not X: logging.error(f"No valid windows created for {symbol}.") return None, None X = np.array(X) y = np.array(y) unique_classes, class_counts = np.unique(y, return_counts=True) class_dist_str = ", ".join([f"Class {cls}: {count}" for cls, count in zip(unique_classes, class_counts)]) logging.info(f"Class distribution BEFORE balancing for {symbol}: {class_dist_str}") if len(unique_classes) < 2: logging.error(f"ONLY ONE CLASS ({unique_classes[0]}) present for {symbol}.") return None, None min_class_count = min(class_counts) if min_class_count * 2 < MIN_TRAINING_EXAMPLES: logging.error(f"Minority class count ({min_class_count}) too low for {symbol}.") return None, None samples_per_class = min_class_count balanced_indices = [] for class_val in unique_classes: class_indices = np.where(y == class_val)[0] num_to_choose = min(samples_per_class, len(class_indices)) chosen_indices = np.random.choice(class_indices, size=num_to_choose, replace=False) balanced_indices.extend(chosen_indices) np.random.shuffle(balanced_indices) X_balanced = X[balanced_indices] y_balanced = y[balanced_indices] final_unique, final_counts = np.unique(y_balanced, return_counts=True) logging.info(f"Balanced dataset for {symbol}: {len(X_balanced)} instances. Final counts: {dict(zip(final_unique, final_counts))}") if len(X_balanced) < MIN_TRAINING_EXAMPLES: logging.error(f"Insufficient data ({len(X_balanced)}) for {symbol} AFTER balancing.") return None, None if X_balanced.ndim != 3 or X_balanced.shape[0] == 0 or X_balanced.shape[1] != K or X_balanced.shape[2] != len(FEATURES): logging.error(f"Final balanced data has unexpected shape {X_balanced.shape} for {symbol}.") return None, None return X_balanced, y_balanced except Exception as e: logging.exception(f"Error preparing training data for {symbol}") return None, None def train_model(symbol): # (Keep this function as is) logging.info(f"--- Attempting to train model for {symbol} ---") np.random.seed(int(time.time()) % (2**32 - 1)) X, y = prepare_training_data(symbol) if X is None or y is None: logging.error(f"Failed to prepare training data for {symbol}. Training aborted.") return None, None, None try: accuracy = -1.0 if len(X) < MIN_TRAINING_EXAMPLES + 2: logging.warning(f"Dataset for {symbol} too small ({len(X)}). 
Training on all data.") X_train, y_train = X, y X_val, y_val = np.array([]), np.array([]) else: indices = np.random.permutation(len(X)) val_size = max(1, int(len(X) * 0.2)) split_idx = len(X) - val_size train_indices, val_indices = indices[:split_idx], indices[split_idx:] if len(train_indices) == 0 or len(val_indices) == 0: logging.error(f"Train/Val split resulted in zero samples. Training on all data.") X_train, y_train = X, y X_val, y_val = np.array([]), np.array([]) else: X_train, X_val = X[train_indices], X[val_indices] y_train, y_val = y[train_indices], y[val_indices] if len(np.unique(y_train)) < 2: logging.error(f"Only one class in TRAINING set after split for {symbol}. Aborting.") return None, None, None if len(np.unique(y_val)) < 2: logging.warning(f"Only one class in VALIDATION set after split for {symbol}.") if X_val.size == 0: X_val_shaped = np.empty((0, K, len(FEATURES))) else: X_val_shaped = X_val X_train_t, X_val_t = llt_transform(X_train, y_train, X_val_shaped) if X_train_t.size == 0: logging.error(f"LLT training transformation failed for {symbol}. Training aborted.") return None, None, None if X_val.size > 0 and X_val_t.size == 0: logging.warning(f"LLT validation transformation failed for {symbol}.") accuracy = -1.0 if np.isnan(X_train_t).any() or np.isinf(X_train_t).any(): logging.error(f"NaN/Inf in LLT transformed TRAINING data for {symbol}. Training aborted.") return None, None, None if X_val_t.size > 0 and (np.isnan(X_val_t).any() or np.isinf(X_val_t).any()): logging.warning(f"NaN/Inf in LLT transformed VALIDATION data for {symbol}.") accuracy = -1.0 n_neighbors = min(5, len(y_train) - 1) if len(y_train) > 1 else 1 n_neighbors = max(1, n_neighbors) if n_neighbors > 1 and n_neighbors % 2 == 0: n_neighbors -= 1 model = KNeighborsClassifier(n_neighbors=n_neighbors, weights='distance') model.fit(X_train_t, y_train) if accuracy != -1.0 and X_val_t.size > 0: try: accuracy = model.score(X_val_t, y_val) logging.info(f"Model for {symbol} trained. Validation Accuracy: {accuracy:.3f}") except Exception as eval_e: logging.exception(f"Error during KNN validation scoring for {symbol}: {eval_e}") accuracy = -1.0 elif accuracy == -1.0: logging.info(f"Model for {symbol} trained. Validation skipped or failed.") else: logging.info(f"Model for {symbol} trained. 
No validation data.") accuracy = -1.0 return model, X_train, y_train except Exception as e: logging.exception(f"Error during model training pipeline for {symbol}") return None, None, None def predict_real_time(symbol, model_data): # (Keep this function as is) if model_data is None: return "Model N/A", 0.0 model, X_train_orig_for_llt, y_train_orig_for_llt = model_data if model is None or X_train_orig_for_llt is None or y_train_orig_for_llt is None: logging.error(f"Invalid model data tuple for prediction on {symbol}") return "Model Error", 0.0 if X_train_orig_for_llt.size == 0 or y_train_orig_for_llt.size == 0: logging.error(f"Training data for LLT laws is empty for {symbol}") return "LLT Data Error", 0.0 try: df = fetch_historical_data(symbol, timeframe='1m', total_candles=K + 60) if df is None or len(df) < K: return "Data Error", 0.0 df_recent = df.iloc[-K:] if len(df_recent) != K: return "Data Error", 0.0 df_recent_normalized = normalize_data(df_recent) if df_recent_normalized is None: return "Norm Error", 0.0 if df_recent_normalized[FEATURES].isnull().any().any(): df_recent_normalized = df_recent_normalized.fillna(0.0) X_predict_input = np.array([df_recent_normalized[FEATURES].values]) _, X_predict_transformed = llt_transform(X_train_orig_for_llt, y_train_orig_for_llt, X_predict_input) if X_predict_transformed.size == 0 or X_predict_transformed.shape[0] != 1: return "Transform Error", 0.0 if np.isnan(X_predict_transformed).any() or np.isinf(X_predict_transformed).any(): X_predict_transformed = np.nan_to_num(X_predict_transformed, nan=0.0, posinf=0.0, neginf=0.0) try: probabilities = model.predict_proba(X_predict_transformed) if probabilities.shape[0] != 1 or probabilities.shape[1] != 2: return "Predict Error", 0.0 prob_class_1 = probabilities[0, 1] prediction_label = "Rise" if prob_class_1 >= 0.5 else "Fall" confidence = prob_class_1 if prediction_label == "Rise" else probabilities[0, 0] return prediction_label, confidence except Exception as knn_e: logging.exception(f"Error during KNN prediction probability for {symbol}") return "Predict Error", 0.0 except Exception as e: logging.exception(f"Error in predict_real_time for {symbol}") return "Error", 0.0 # --- TA Calculation Function (Using TA-Lib) --- def calculate_ta_indicators(df_ta): """ Calculates TA indicators (RSI, MACD, VWAP, ATR) using TA-Lib. Requires df_ta to have 'open', 'high', 'low', 'close', 'volume' columns. """ indicators = {'RSI': np.nan, 'MACD': np.nan, 'MACD_Signal': np.nan, 'MACD_Hist': np.nan, 'VWAP': np.nan, 'ATR': np.nan} required_cols = ['open', 'high', 'low', 'close', 'volume'] min_len_needed = max(RSI_PERIOD, MACD_SLOW, ATR_PERIOD) + 1 # TA-Lib often needs P+1 if df_ta is None or len(df_ta) < min_len_needed: logging.warning(f"Insufficient data ({len(df_ta) if df_ta is not None else 0} < {min_len_needed}) for TA-Lib calculations.") return indicators # Ensure columns exist if not all(col in df_ta.columns for col in required_cols): logging.error(f"Missing required columns for TA-Lib: Have {df_ta.columns}, Need {required_cols}") return indicators # --- Prepare data for TA-Lib (NumPy arrays, handle NaNs) --- df_ta = df_ta[required_cols].copy() # Work on a copy with only needed columns # Check for NaNs BEFORE converting to numpy, TA-Lib generally dislikes them if df_ta.isnull().values.any(): nan_count = df_ta.isnull().sum().sum() logging.warning(f"Found {nan_count} NaN(s) in TA input data. 
Applying ffill()...") df_ta.ffill(inplace=True) # Forward fill NaNs # Check again after ffill - if NaNs remain (e.g., at the start), need more robust handling if df_ta.isnull().values.any(): logging.error(f"NaNs still present after ffill. Cannot proceed with TA-Lib.") return indicators # Return NaNs try: # Convert to NumPy arrays of type float open_p = df_ta['open'].values.astype(float) high_p = df_ta['high'].values.astype(float) low_p = df_ta['low'].values.astype(float) close_p = df_ta['close'].values.astype(float) volume_p = df_ta['volume'].values.astype(float) # --- Calculate Indicators using TA-Lib --- # RSI rsi_values = talib.RSI(close_p, timeperiod=RSI_PERIOD) indicators['RSI'] = rsi_values[-1] if len(rsi_values) > 0 else np.nan # MACD macd_line, signal_line, hist = talib.MACD(close_p, fastperiod=MACD_FAST, slowperiod=MACD_SLOW, signalperiod=MACD_SIGNAL) indicators['MACD'] = macd_line[-1] if len(macd_line) > 0 else np.nan indicators['MACD_Signal'] = signal_line[-1] if len(signal_line) > 0 else np.nan indicators['MACD_Hist'] = hist[-1] if len(hist) > 0 else np.nan # ATR atr_values = talib.ATR(high_p, low_p, close_p, timeperiod=ATR_PERIOD) indicators['ATR'] = atr_values[-1] if len(atr_values) > 0 else np.nan # VWAP (Manual Calculation - TA-Lib doesn't have it built-in) typical_price = (high_p + low_p + close_p) / 3.0 tp_vol = typical_price * volume_p cumulative_volume = np.cumsum(volume_p) # Avoid division by zero if volume is zero for initial periods if cumulative_volume[-1] > 1e-12: # Check if there's significant volume vwap_values = np.cumsum(tp_vol) / np.maximum(cumulative_volume, 1e-12) # Avoid div by zero strictly indicators['VWAP'] = vwap_values[-1] else: indicators['VWAP'] = np.nan # VWAP undefined if no volume # Final check for NaNs in results (TA-Lib might return NaN for initial periods) for key, value in indicators.items(): if pd.isna(value): indicators[key] = np.nan # Ensure consistent NaN representation # logging.debug(f"TA-Lib Indicators calculated: {indicators}") return indicators except Exception as ta_e: logging.exception(f"Error calculating TA indicators using TA-Lib: {ta_e}") return {k: np.nan for k in indicators} # Return NaNs on error # --- Trade Level Calculation (Unchanged) --- def calculate_trade_levels(prediction, confidence, current_price, atr): # (Keep this function as is - no changes needed) levels = {'Entry': np.nan, 'TP1': np.nan, 'TP2': np.nan, 'SL': np.nan} if pd.isna(current_price) or current_price <= 0 or pd.isna(atr) or atr <= 0: return levels if prediction == "Rise" and confidence >= CONFIDENCE_THRESHOLD: entry_price = current_price levels['Entry'] = entry_price levels['TP1'] = entry_price + TP1_ATR_MULTIPLIER * atr levels['TP2'] = entry_price + TP2_ATR_MULTIPLIER * atr levels['SL'] = entry_price - SL_ATR_MULTIPLIER * atr levels['SL'] = max(0.01, levels['SL']) # Add Fall logic here if needed return levels # --- Concurrency Wrappers (Unchanged) --- def train_model_task(coin): # (Keep this function as is) try: result = train_model(coin) if result != (None, None, None): model, X_train_orig, y_train_orig = result return coin, (model, X_train_orig, y_train_orig) else: return coin, None except Exception as e: logging.exception(f"Unhandled exception in train_model_task for {coin}") return coin, None def train_all_models(coin_list=None, num_workers=NUM_WORKERS_TRAINING): # (Keep this function as is) global trained_models start_time = time.time() if coin_list is None or not coin_list: logging.info("No coin list provided, fetching top coins by volume...") try: 
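        # Worked VWAP example for calculate_ta_indicators above (made-up candles, two bars):
        #   bar 1: H=12, L=8, C=10, V=100 -> typical price 10
        #   bar 2: H=16, L=12, C=14, V=300 -> typical price 14
        #   VWAP after bar 2 = (10*100 + 14*300) / (100 + 300) = 5200 / 400 = 13.0
        # TA-Lib has no built-in VWAP, which is why it is accumulated manually there.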
coin_list = get_all_usdt_pairs() if not coin_list: msg = "Failed to fetch coin list even with fallback. Training aborted." logging.error(msg) return msg except Exception as e: msg = f"Error fetching coin list: {e}. Training aborted." logging.exception(msg) return msg logging.info(f"Starting training for {len(coin_list)} coins using {num_workers} workers...") results_log = [] successful_trains = 0 failed_trains = 0 new_models = {} with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers, thread_name_prefix='TrainWorker') as executor: future_to_coin = {executor.submit(train_model_task, coin): coin for coin in coin_list} processed_count = 0 total_coins = len(coin_list) for future in concurrent.futures.as_completed(future_to_coin): processed_count += 1 coin = future_to_coin[future] try: returned_coin, model_data = future.result() if returned_coin == coin and model_data is not None: new_models[returned_coin] = model_data results_log.append(f"✅ {returned_coin}: Model trained successfully.") successful_trains += 1 else: results_log.append(f"❌ {coin}: Model training failed (check logs).") failed_trains += 1 except Exception as e: results_log.append(f"❌ {coin}: Training task generated exception: {e}") failed_trains += 1 logging.exception(f"Exception from training future for {coin}") if processed_count % 10 == 0 or processed_count == total_coins: logging.info(f"Training progress: {processed_count}/{total_coins} coins processed.") logging.getLogger().handlers[0].flush() trained_models.update(new_models) logging.info(f"Updated global models dictionary. Total models now: {len(trained_models)}") end_time = time.time() duration = end_time - start_time completion_message = ( f"Training run completed in {duration:.2f} seconds.\n" f"Successfully trained: {successful_trains}\n" f"Failed to train: {failed_trains}\n" f"Total models available now: {len(trained_models)}" ) logging.info(completion_message) return completion_message + "\n\n" + "\n".join(results_log[-20:]) # --- Update Predictions Table (Mostly Unchanged, uses new TA function) --- def update_predictions_table(): # (This function structure remains the same, it just calls the new calculate_ta_indicators) global last_update_time logging.info("--- Updating Predictions Table ---") start_time = time.time() predictions_data = {} current_models = trained_models.copy() if not current_models: msg = "No models available. Please train first." 
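        # Overview of the stages this function runs (summary of the code below, no new behaviour):
        #   Stage 1: run predict_real_time concurrently for every trained model.
        #   Stage 2: fetch current tickers plus TA_DATA_POINTS (200) one-minute candles per symbol.
        #   Stage 3: compute RSI / MACD / VWAP / ATR with TA-Lib and derive ATR-based trade levels.
        #   Stage 4: sort (actionable Rise >= threshold, other Rise, Fall, errors) and format the
        #            top MAX_COINS_TO_DISPLAY rows for the Gradio table.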
logging.warning(msg) cols = ['Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR'] return pd.DataFrame([], columns=cols), msg symbols_with_models = list(current_models.keys()) logging.info(f"Step 1: Generating predictions for {len(symbols_with_models)} models...") # --- Stage 1: Get Predictions Concurrently --- with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_WORKERS_PREDICTION, thread_name_prefix='PredictWorker') as executor: future_to_coin_pred = {executor.submit(predict_real_time, coin, model_data): coin for coin, model_data in current_models.items()} pred_success = 0 pred_fail = 0 for future in concurrent.futures.as_completed(future_to_coin_pred): coin = future_to_coin_pred[future] try: pred, conf = future.result() if pred not in ["Model N/A", "Model Error", "Data Error", "Norm Error", "LLT Data Error", "Transform Error", "Predict Error", "Error"]: predictions_data[coin] = {'prediction': pred, 'confidence': float(conf)} pred_success += 1 else: predictions_data[coin] = {'prediction': pred, 'confidence': 0.0} pred_fail += 1 except Exception as e: logging.exception(f"Error getting prediction result for {coin}") predictions_data[coin] = {'prediction': "Future Error", 'confidence': 0.0} pred_fail +=1 logging.info(f"Step 1 Complete: Predictions generated ({pred_success} success, {pred_fail} fail).") # --- Stage 2: Fetch Current Tickers & TA Data Concurrently --- symbols_to_fetch_data = list(predictions_data.keys()) if not symbols_to_fetch_data: logging.warning("No symbols with predictions to fetch data for.") cols = ['Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR'] return pd.DataFrame([], columns=cols), "No symbols processed." 
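    # Worked example for the ATR-based levels computed in Stage 3 below (illustrative numbers
    # only): with current price 100 and ATR(14) = 2, a Rise signal at or above
    # CONFIDENCE_THRESHOLD gets Entry = 100, TP1 = 100 + 1.5*2 = 103, TP2 = 100 + 3*2 = 106,
    # SL = max(0.01, 100 - 1*2) = 98. Signals below the threshold get no levels at all.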
logging.info(f"Step 2: Fetching Tickers and {TA_DATA_POINTS} OHLCV candles for {len(symbols_to_fetch_data)} symbols...") tickers_data = {} ohlcv_data = {} try: # Fetch Tickers batch_size_tickers = 100 fetched_tickers_batch = {} for i in range(0, len(symbols_to_fetch_data), batch_size_tickers): batch_symbols = symbols_to_fetch_data[i:i+batch_size_tickers] try: batch_tickers = exchange.fetch_tickers(symbols=batch_symbols) fetched_tickers_batch.update(batch_tickers) time.sleep(exchange.rateLimit / 1000 * 0.5) except Exception as e: logging.error(f"Failed to fetch ticker batch starting with {batch_symbols[0]}: {e}") tickers_data = fetched_tickers_batch logging.info(f"Fetched {len(tickers_data)} tickers.") except Exception as e: logging.exception(f"Error fetching tickers in prediction update: {e}") # Fetch OHLCV for TA with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_WORKERS_PREDICTION, thread_name_prefix='TADataWorker') as executor: future_to_coin_ohlcv = {executor.submit(fetch_historical_data, coin, '1m', TA_DATA_POINTS): coin for coin in symbols_to_fetch_data} for future in concurrent.futures.as_completed(future_to_coin_ohlcv): coin = future_to_coin_ohlcv[future] try: df_ta = future.result() if df_ta is not None and len(df_ta) == TA_DATA_POINTS: # Ensure standard column names expected by calculate_ta_indicators df_ta.columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume'] ohlcv_data[coin] = df_ta except Exception as e: logging.exception(f"Error fetching TA OHLCV data for {coin}") logging.info(f"Step 2 Complete: Fetched TA data for {len(ohlcv_data)} symbols.") # --- Stage 3: Calculate TA & Trade Levels --- logging.info(f"Step 3: Calculating TA (using TA-Lib) and Trade Levels...") final_results = [] processing_time = datetime.now(timezone.utc) for symbol in symbols_to_fetch_data: pred_info = predictions_data.get(symbol, {'prediction': 'Missing Pred', 'confidence': 0.0}) ticker = tickers_data.get(symbol) df_ta = ohlcv_data.get(symbol) # This df should have standard columns now current_price, quote_volume = np.nan, np.nan ta_indicators = {k: np.nan for k in ['RSI', 'MACD', 'MACD_Signal', 'MACD_Hist', 'VWAP', 'ATR']} trade_levels = {k: np.nan for k in ['Entry', 'TP1', 'TP2', 'SL']} entry_time, exit_time = pd.NaT, pd.NaT if ticker and isinstance(ticker, dict): current_price = ticker.get('last', np.nan) quote_volume = ticker.get('info', {}).get('quoteVolume') if quote_volume is None: base_volume = ticker.get('baseVolume') if base_volume is not None and current_price is not None: try: quote_volume = float(base_volume) * float(current_price) except (ValueError, TypeError): quote_volume = np.nan try: current_price = float(current_price) if current_price is not None else np.nan except (ValueError, TypeError): current_price = np.nan try: quote_volume = float(quote_volume) if quote_volume is not None else np.nan except (ValueError, TypeError): quote_volume = np.nan # Calculate TA using the new function if df_ta is not None: ta_indicators = calculate_ta_indicators(df_ta) # Calls the TA-Lib version if pred_info['prediction'] in ["Rise", "Fall"] and not pd.isna(current_price) and not pd.isna(ta_indicators['ATR']): trade_levels = calculate_trade_levels(pred_info['prediction'], pred_info['confidence'], current_price, ta_indicators['ATR']) if not pd.isna(trade_levels['Entry']): entry_time = processing_time exit_time = processing_time + timedelta(hours=PREDICTION_WINDOW_HOURS) final_results.append({ 'coin': symbol.split('/')[0], 'full_symbol': symbol, 'prediction': pred_info['prediction'], 
'confidence': pred_info['confidence'], 'price': current_price, 'volume': quote_volume, 'entry': trade_levels['Entry'], 'entry_time': entry_time, 'exit_time': exit_time, 'tp1': trade_levels['TP1'], 'tp2': trade_levels['TP2'], 'sl': trade_levels['SL'], 'rsi': ta_indicators['RSI'], 'macd_hist': ta_indicators['MACD_Hist'], 'vwap': ta_indicators['VWAP'], 'atr': ta_indicators['ATR'] }) logging.info("Step 3 Complete: TA and Trade Levels calculated.") # --- Stage 4: Sort and Format (Unchanged) --- def sort_key(item): pred, conf = item['prediction'], item['confidence'] if pred == "Rise" and conf >= CONFIDENCE_THRESHOLD and not pd.isna(item['entry']): return (0, -conf) elif pred == "Rise": return (1, -conf) elif pred == "Fall": return (2, -conf) else: return (3, 0) final_results.sort(key=sort_key) formatted_output = [] for i, p in enumerate(final_results[:MAX_COINS_TO_DISPLAY]): formatted_output.append([ i + 1, p['coin'], p['prediction'], f"{p['confidence']:.3f}", f"{p['price']:.4f}" if not pd.isna(p['price']) else "N/A", f"{p['volume']:,.0f}" if not pd.isna(p['volume']) else "N/A", f"{p['entry']:.4f}" if not pd.isna(p['entry']) else "N/A", format_datetime(p['entry_time'], "N/A"), format_datetime(p['exit_time'], "N/A"), f"{p['tp1']:.4f}" if not pd.isna(p['tp1']) else "N/A", f"{p['tp2']:.4f}" if not pd.isna(p['tp2']) else "N/A", f"{p['sl']:.4f}" if not pd.isna(p['sl']) else "N/A", f"{p['rsi']:.2f}" if not pd.isna(p['rsi']) else "N/A", f"{p['macd_hist']:.4f}" if not pd.isna(p['macd_hist']) else "N/A", f"{p['vwap']:.4f}" if not pd.isna(p['vwap']) else "N/A", f"{p['atr']:.4f}" if not pd.isna(p['atr']) else "N/A", ]) output_columns = [ 'Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR' ] output_df = pd.DataFrame(formatted_output, columns=output_columns) end_time = time.time() duration = end_time - start_time last_update_time = processing_time status_message = f"Predictions updated ({len(final_results)} symbols processed) in {duration:.2f}s. Last update: {format_datetime(last_update_time)}" logging.info(status_message) return output_df, status_message # --- Gradio UI Handlers (Unchanged) --- def handle_train_click(coin_input, num_workers): # (Keep this function as is) logging.info(f"Train button clicked. Workers: {num_workers}") coins = None num_workers = int(num_workers) if coin_input and coin_input.strip(): raw_coins = coin_input.replace(',', ' ').split() coins = [] valid = True for c in raw_coins: coin_upper = c.strip().upper() if '/' not in coin_upper: coin_upper += '/USDT' if coin_upper.endswith('/USDT'): coins.append(coin_upper) else: valid = False logging.error(f"Invalid coin format: {c}. Must be SYMBOL or SYMBOL/USDT.") break if not valid: return "Error: Custom coins must be valid SYMBOL or SYMBOL/USDT pairs." 
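        # Input normalization above (sketch): "btc, eth/usdt" becomes ["BTC/USDT", "ETH/USDT"];
        # anything that still does not end in /USDT after upper-casing is rejected outright.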
logging.info(f"Training requested for custom coin list: {coins}") else: logging.info("Training requested for top coins by volume.") train_status = train_all_models(coin_list=coins, num_workers=num_workers) return f"--- Training Run ---:\n{train_status}\n\n---> Press 'Refresh Predictions' <---" def handle_refresh_click(): # (Keep this function as is) logging.info("Refresh button clicked.") try: df, status = update_predictions_table() return df, status except Exception as e: logging.exception("Error during handle_refresh_click") cols = ['Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR'] return pd.DataFrame([], columns=cols), f"Error updating predictions: {e}" # --- Gradio Interface Definition (Unchanged) --- with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown("# Cryptocurrency Prediction & TA Signal Explorer (LLT-KNN + TA-Lib)") # Updated title slightly gr.Markdown(f""" Predicts **{PREDICTION_WINDOW_HOURS}-hour** price direction (Rise/Fall) using LLT-KNN. Displays current price, volume, TA indicators (RSI, MACD, VWAP, ATR calculated using **TA-Lib**), and potential trade levels for **Rise** signals meeting confidence >= **{CONFIDENCE_THRESHOLD}**. TP/SL levels based on **{TP1_ATR_MULTIPLIER}x / {TP2_ATR_MULTIPLIER}x / {SL_ATR_MULTIPLIER}x ATR({ATR_PERIOD})**. **Warning:** Educational. High risk. Not financial advice. Ensure TA-Lib is correctly installed. """) with gr.Row(): with gr.Column(scale=4): prediction_df = gr.Dataframe( headers=[ 'Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR' ], datatype=[ 'number', 'str', 'str', 'str', 'str', 'str', 'str', 'str', 'str', 'str', 'str', 'str', 'str', 'str', 'str', 'str' ], row_count=15, col_count=(16, "fixed"), label="Predictions & TA Signals", wrap=True, ) with gr.Column(scale=1): with gr.Accordion("Train Models", open=True): coin_input = gr.Textbox(label="Train Specific Coins (e.g., BTC, ETH/USDT)", placeholder="Leave empty for top coins by volume") max_workers_slider = gr.Slider(minimum=1, maximum=10, value=NUM_WORKERS_TRAINING, step=1, label="Parallel Training Workers") train_button = gr.Button("Start Training", variant="primary") refresh_button = gr.Button("Refresh Predictions", variant="secondary") status_text = gr.Textbox(label="Status Log", lines=15, interactive=False, max_lines=30) gr.Markdown( """ ## Notes - **TA-Lib**: This version uses the TA-Lib library for indicators. Ensure it's installed correctly (can be tricky). - **Data**: Fetches OHLCV data (Bitget, 1-min). Uses cache. Handles rate limits. - **Training**: Uses past ~14h data (12h train, 2h predict). Normalizes, balances classes, applies LLT, trains KNN. - **Prediction**: Uses latest 12h data for KNN input. - **Trade Levels**: Only shown for 'Rise' predictions above confidence threshold. Based on current price and ATR volatility. **Highly speculative.** - **Sorting**: Table sorted by (Potential Rise Signals > Other Rise > Fall > Errors), then by confidence descending. - **Refresh**: Fetches latest prices/TA and re-evaluates signals. 
""" ) train_button.click(fn=handle_train_click, inputs=[coin_input, max_workers_slider], outputs=status_text) refresh_button.click(fn=handle_refresh_click, inputs=None, outputs=[prediction_df, status_text]) # --- Startup Initialization (Unchanged) --- def initialize_models_on_startup(): # (Keep this function as is) logging.info("----- Initializing Models (Startup Thread) -----") default_coins = ['BTC/USDT', 'ETH/USDT', 'SOL/USDT', 'XRP/USDT', 'DOGE/USDT'] try: initial_status = train_all_models(default_coins, num_workers=2) logging.info("----- Initial Model Training Complete -----") logging.info(initial_status) except Exception as e: logging.exception("Error during startup initialization.") # --- Main Execution (Unchanged) --- if __name__ == "__main__": logging.info("Starting application...") # Check if TA-Lib import worked (basic check) try: # Try accessing a TA-Lib function _ = talib.RSI(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) logging.info("TA-Lib library seems accessible.") except NameError: logging.error("FATAL: TA-Lib library not found or import failed. Please install it correctly.") sys.exit(1) except Exception as ta_init_e: logging.error(f"FATAL: Error testing TA-Lib library: {ta_init_e}. Please check installation.") sys.exit(1) init_thread = threading.Thread(target=initialize_models_on_startup, name="StartupTrainThread", daemon=True) init_thread.start() logging.info("Launching Gradio Interface...") try: demo.launch(server_name="0.0.0.0") except Exception as e: logging.exception("Failed to launch Gradio interface.") finally: logging.info("Gradio Interface stopped.")