# Trading/shared_utils.py
"""Shared helpers for trading pattern scripts."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
import pandas as pd
import pyodbc
try:
import pywt
except ImportError: # pragma: no cover - optional dependency
pywt = None
DEFAULT_CONFIG_PATH = Path("config/pattern_knn_config.json")


def load_config(path: Optional[Path] = None) -> Dict:
    """Load the JSON configuration that holds operational parameters."""
    cfg_path = Path(path or DEFAULT_CONFIG_PATH)
    if not cfg_path.exists():
        raise FileNotFoundError(f"Missing configuration file: {cfg_path}")
    with cfg_path.open("r", encoding="utf-8") as fh:
        return json.load(fh)


def require_section(config: Dict, section: str) -> Dict:
    sect = config.get(section)
    if not isinstance(sect, dict):
        raise KeyError(f"Missing '{section}' section in configuration file")
    return sect


def require_value(section: Dict, key: str, section_name: str) -> Any:
    if key not in section:
        raise KeyError(f"Missing key '{key}' inside '{section_name}' section of configuration file")
    return section[key]
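
# Usage sketch for the config helpers above (illustrative section/key names,
# not prescribed by the module):
#     cfg = load_config()
#     knn_cfg = require_section(cfg, "knn")
#     k = require_value(knn_cfg, "k", "knn")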


def detect_column(df: pd.DataFrame, candidates: Sequence[str]) -> Optional[str]:
    """Return the first column whose name matches one of the candidates (case-insensitive).

    Exact matches are preferred; substring matches are used as a fallback.
    """
    low = {c.lower(): c for c in df.columns}
    for cand in candidates:
        cl = cand.lower()
        if cl in low:
            return low[cl]
    for cand in candidates:
        cl = cand.lower()
        for col in df.columns:
            if cl in col.lower():
                return col
    return None
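
# Example (illustrative candidate names): locate a close-price column
# regardless of casing or vendor-specific naming.
#     col = detect_column(df, ["Close", "AdjClose", "Px_Last"])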


def read_connection_txt(path: Path | str = "connection.txt") -> str:
    """Build a SQLAlchemy connection URL from a key=value connection.txt file."""
    params: Dict[str, str] = {}
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Missing connection.txt at {path}")
    for line in path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        k, v = line.split("=", 1)
        params[k.strip().lower()] = v.strip()
    username = params.get("username")
    password = params.get("password")
    host = params.get("host")
    port = params.get("port", "1433")
    database = params.get("database")
    if not all([username, password, host, database]):
        raise ValueError("Incomplete connection.txt: username/password/host/database are required.")
    installed = list(pyodbc.drivers())
    driver_q = (
        "ODBC+Driver+18+for+SQL+Server"
        if "ODBC Driver 18 for SQL Server" in installed
        else "ODBC+Driver+17+for+SQL+Server"
    )
    return f"mssql+pyodbc://{username}:{password}@{host}:{port}/{database}?driver={driver_q}"


def z_norm(arr: np.ndarray) -> Optional[np.ndarray]:
    """Z-normalize an array; return None for empty or near-constant input."""
    arr = np.asarray(arr, dtype=float)
    if arr.size == 0:
        return None
    mu = arr.mean()
    sd = arr.std()
    if sd < 1e-12:
        return None
    return (arr - mu) / (sd + 1e-12)
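
# Example (illustrative): z_norm(np.array([1.0, 2.0, 3.0])) yields a window
# centered at 0 with unit standard deviation; z_norm(np.ones(5)) yields None
# because a constant window carries no pattern shape.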


def wavelet_denoise(
    series: pd.Series,
    wavelet: str = "db3",
    level: int = 3,
    mode: str = "symmetric",
    threshold_mode: str = "soft",
) -> Optional[pd.Series]:
    """Denoise/reshape the series with a wavelet decomposition.

    Keeps the original index length; if PyWavelets is missing the function
    returns None so callers can gracefully fall back to the raw signal.
    """
    if pywt is None:
        print("[WARN] pywt not installed: skipping wavelet filtering.")
        return None
    s = pd.to_numeric(series, errors="coerce")
    if s.dropna().empty:
        return None
    w = pywt.Wavelet(wavelet)
    max_level = pywt.dwt_max_level(len(s.dropna()), w.dec_len)
    lvl = max(1, min(level, max_level)) if max_level > 0 else 1
    valid = s.dropna()
    coeffs = pywt.wavedec(valid.values, w, mode=mode, level=lvl)
    # Universal threshold (Donoho-Johnstone): noise sigma estimated from the
    # finest-scale detail coefficients via the median absolute deviation.
    sigma = np.median(np.abs(coeffs[-1])) / 0.6745 if len(coeffs[-1]) > 0 else 0.0
    thresh = sigma * np.sqrt(2 * np.log(len(valid))) if sigma > 0 else 0.0
    if thresh <= 0:
        coeffs_f = coeffs
    else:
        def _safe_thresh(c: np.ndarray) -> np.ndarray:
            if c is None or c.size == 0:
                return c
            if threshold_mode == "hard":
                return pywt.threshold(c, value=thresh, mode="hard")
            # Soft threshold without divide-by-zero warnings.
            mag = np.abs(c)
            mask = mag > thresh
            out = np.zeros_like(c)
            out[mask] = np.sign(c[mask]) * (mag[mask] - thresh)
            return out

        coeffs_f = [coeffs[0]] + [_safe_thresh(c) for c in coeffs[1:]]
    rec = pywt.waverec(coeffs_f, w, mode=mode)
    rec = rec[: len(valid)]
    filt = pd.Series(rec, index=valid.index)
    # Realign to the original index.
    return filt.reindex(s.index).interpolate(limit_direction="both")
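
# Usage sketch (illustrative variable names): smooth a price series before
# pattern extraction, falling back to the raw series when PyWavelets is absent.
#     smoothed = wavelet_denoise(close, wavelet="db3", level=3)
#     signal = smoothed if smoothed is not None else close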


def build_pattern_library(
    ret_series: pd.Series,
    wp: int,
    ha: int,
    embargo: Optional[int] = None,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """Create the normalized pattern windows and their realized outcomes.

    Args:
        ret_series: Series of returns (ordered oldest→latest).
        wp: Window length for the pattern.
        ha: Holding horizon used to compute the outcome.
        embargo: Optional number of most-recent observations to exclude when
            building the library (useful to avoid leakage when reusing the
            same series for inference).
    """
    x = ret_series.dropna().values
    n = len(x)
    if n < wp + ha + 10:
        return None, None
    embargo = int(embargo or 0)
    usable_n = n - max(0, embargo)
    if usable_n <= wp + ha:
        return None, None
    wins: List[np.ndarray] = []
    outs: List[float] = []
    last_start = usable_n - wp - ha
    if last_start <= 0:
        return None, None
    for t in range(0, last_start + 1):
        win = x[t : t + wp]
        winzn = z_norm(win)
        if winzn is None:
            continue
        # Outcome: cumulative return over the ha bars following the window.
        outcome = np.sum(x[t + wp : t + wp + ha])
        wins.append(winzn)
        outs.append(outcome)
    if not wins:
        return None, None
    return np.array(wins), np.array(outs)
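
# Usage sketch (illustrative parameters): build a library of 20-bar patterns
# with a 5-bar outcome horizon, embargoing the latest 5 observations so the
# same series can be reused for inference without leakage.
#     lib_wins, lib_out = build_pattern_library(rets, wp=20, ha=5, embargo=5)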


def predict_from_library(
    curr_win: np.ndarray,
    lib_wins: np.ndarray,
    lib_out: np.ndarray,
    k: int = 25,
) -> Tuple[float, float, np.ndarray]:
    """k-NN forecast: median outcome of the k library windows closest to
    the current normalized window, plus the mean neighbor distance and
    the neighbor indices."""
    dists = np.linalg.norm(lib_wins - curr_win, axis=1)
    idx = np.argsort(dists)[: min(k, len(dists))]
    return float(np.median(lib_out[idx])), float(np.mean(dists[idx])), idx
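
# Usage sketch (illustrative): forecast from the most recent 20-bar window.
#     curr = z_norm(rets.dropna().values[-20:])
#     if curr is not None and lib_wins is not None:
#         pred, avg_dist, neighbors = predict_from_library(curr, lib_wins, lib_out, k=25)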


def characterize_window(
    ret_series: pd.Series,
    wp: int,
    z_rev: float = 2.0,
    z_vol: float = 2.0,
    std_comp_pct: float = 0.15,
) -> Tuple[Optional[str], float]:
    """Label the latest wp-bar window with a heuristic regime tag and a
    confidence score in [0, 1]."""
    x = ret_series.dropna().values
    if len(x) < max(wp, 30):
        return None, 0.0
    win = x[-wp:]
    mu, sd = win.mean(), win.std()
    if sd < 1e-12:
        return "compression", 0.5
    last = win[-1]
    z_last = (last - mu) / (sd + 1e-12)
    abs_z_last = abs(z_last)
    last3 = win[-3:] if len(win) >= 3 else win
    sum3 = np.sum(last3)
    # Percentile of the current window's volatility within rolling history.
    if len(x) > 3 * wp:
        roll_std = pd.Series(x).rolling(wp).std().dropna().values
        if len(roll_std) > 20:
            pct = (roll_std < np.std(win)).mean()
        else:
            pct = 0.5
    else:
        pct = 0.5
    if pct < std_comp_pct:
        return "compression", float(1.0 - pct)
    if abs(sum3) > 2 * sd / np.sqrt(3) and np.sign(last3).sum() in (3, -3):
        conf = min(1.0, abs(sum3) / (sd + 1e-12))
        return "momentum_burst", float(conf)
    mean_prev = np.mean(win[:-1]) if len(win) > 1 else 0.0
    if abs_z_last >= z_rev and np.sign(last) != np.sign(mean_prev):
        conf = min(1.0, abs_z_last / 3.0)
        return "reversal_candidate", float(conf)
    if abs_z_last >= z_vol:
        conf = min(1.0, abs_z_last / 3.0)
        return "vol_spike", float(conf)
    return None, 0.0
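
# Labels returned above, in priority order: "compression", "momentum_burst",
# "reversal_candidate", "vol_spike", or (None, 0.0) when no regime fires, e.g.:
#     label, conf = characterize_window(rets, wp=20)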


def hurst_rs(series: pd.Series) -> Optional[float]:
    """Single-window rescaled-range (R/S) estimate of the Hurst exponent."""
    x = pd.to_numeric(series.dropna(), errors="coerce").astype(float).values
    n = len(x)
    if n < 100:
        return None
    x = x - x.mean()
    z = np.cumsum(x)
    r = z.max() - z.min()
    s = x.std(ddof=1)
    if s <= 0 or r <= 0:
        return None
    h = np.log(r / s) / np.log(n)
    if not np.isfinite(h):
        return None
    return float(h)
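
# Standard R/S reading: H ≈ 0.5 is consistent with a random walk, H > 0.5
# with persistence (trending), and H < 0.5 with anti-persistence (mean
# reversion). Note this is a single-window estimate, not a regression over
# multiple window sizes, so treat it as a rough indicator.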


def build_hurst_map(
    returns_long: pd.DataFrame,
    lookback: Optional[int] = None,
    min_length: int = 100,
) -> Dict[str, float]:
    """Compute a per-ISIN Hurst exponent from a long-format returns frame
    with 'Date', 'ISIN' and 'Ret' columns."""
    if returns_long.empty:
        return {}
    ret_wide = returns_long.pivot(index="Date", columns="ISIN", values="Ret").sort_index()
    hurst_map: Dict[str, float] = {}
    for isin in ret_wide.columns:
        series = ret_wide[isin].dropna().astype(float)
        if len(series) < max(1, int(min_length)):
            continue
        window = len(series) if lookback is None else min(len(series), int(lookback))
        if window <= 0:
            continue
        h_val = hurst_rs(series.iloc[-window:])
        if h_val is None or not np.isfinite(h_val):
            continue
        hurst_map[str(isin)] = float(h_val)
    return hurst_map
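
# Usage sketch (illustrative long-format input with the expected columns):
#     returns_long = pd.DataFrame({"Date": dates, "ISIN": isins, "Ret": rets})
#     hmap = build_hurst_map(returns_long, lookback=500)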


__all__ = [
    "build_hurst_map",
    "build_pattern_library",
    "characterize_window",
    "detect_column",
    "hurst_rs",
    "load_config",
    "predict_from_library",
    "read_connection_txt",
    "require_section",
    "require_value",
    "wavelet_denoise",
    "z_norm",
]