"""Shared helpers for trading pattern scripts."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import pyodbc

try:
    import pywt
except ImportError:  # pragma: no cover - optional dependency
    pywt = None

DEFAULT_CONFIG_PATH = Path("config/pattern_knn_config.json")


def load_config(path: Optional[Path] = None) -> Dict:
    """Load the JSON configuration that holds operational parameters."""
    cfg_path = Path(path or DEFAULT_CONFIG_PATH)
    if not cfg_path.exists():
        raise FileNotFoundError(f"Missing configuration file: {cfg_path}")
    with cfg_path.open("r", encoding="utf-8") as fh:
        return json.load(fh)


def require_section(config: Dict, section: str) -> Dict:
    """Return the named sub-dict of ``config``, or raise if it is missing or not a dict."""
    sect = config.get(section)
    if not isinstance(sect, dict):
        raise KeyError(f"Missing '{section}' section in configuration file")
    return sect


def require_value(section: Dict, key: str, section_name: str) -> Any:
    """Return ``section[key]``, or raise with a message naming the enclosing section."""
    if key not in section:
        raise KeyError(f"Missing key '{key}' inside '{section_name}' section of configuration file")
    return section[key]
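
# Usage sketch (illustrative, not part of the library): the section and key
# names below are assumptions, shown only to document the intended call pattern.
#
#     config = load_config()                   # reads DEFAULT_CONFIG_PATH
#     knn = require_section(config, "knn")     # hypothetical section name
#     k = require_value(knn, "k", "knn")       # raises KeyError if absent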


def detect_column(df: pd.DataFrame, candidates: Sequence[str]) -> Optional[str]:
    """Return the first column whose name matches one of the candidates (case insensitive).

    Exact (case-insensitive) matches win; otherwise the first column whose
    name merely *contains* a candidate is returned.
    """
    low = {c.lower(): c for c in df.columns}
    for cand in candidates:
        cl = cand.lower()
        if cl in low:
            return low[cl]
    for cand in candidates:
        cl = cand.lower()
        for col in df.columns:
            if cl in col.lower():
                return col
    return None
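
# Usage sketch (illustrative): the column names here are invented for the example.
#
#     df = pd.DataFrame({"Close_Price": [1.0], "Vol": [10]})
#     detect_column(df, ["close", "px_last"])   # -> "Close_Price" (substring match)
#     detect_column(df, ["open"])               # -> None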


def read_connection_txt(path: Path | str = "connection.txt") -> str:
    """Build a SQLAlchemy-style mssql+pyodbc URL from a key=value connection file."""
    params: Dict[str, str] = {}
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Missing connection.txt at {path}")
    for line in path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        k, v = line.split("=", 1)
        params[k.strip().lower()] = v.strip()

    username = params.get("username")
    password = params.get("password")
    host = params.get("host")
    port = params.get("port", "1433")
    database = params.get("database")

    if not all([username, password, host, database]):
        raise ValueError("Incomplete connection.txt: username/password/host/database are required.")

    installed = list(pyodbc.drivers())
    driver_q = (
        "ODBC+Driver+18+for+SQL+Server"
        if "ODBC Driver 18 for SQL Server" in installed
        else "ODBC+Driver+17+for+SQL+Server"
    )
    return f"mssql+pyodbc://{username}:{password}@{host}:{port}/{database}?driver={driver_q}"


def z_norm(arr: np.ndarray) -> Optional[np.ndarray]:
    """Z-score normalise ``arr``; return None for empty or (near-)constant input."""
    arr = np.asarray(arr, dtype=float)
    if arr.size == 0:
        return None
    mu = arr.mean()
    sd = arr.std()
    if sd < 1e-12:
        return None
    return (arr - mu) / (sd + 1e-12)
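
# Usage sketch (illustrative):
#
#     z_norm(np.array([1.0, 2.0, 3.0]))   # -> array with mean ~0 and std ~1
#     z_norm(np.array([5.0, 5.0, 5.0]))   # -> None (a flat window carries no shape)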


def wavelet_denoise(
    series: pd.Series,
    wavelet: str = "db3",
    level: int = 3,
    mode: str = "symmetric",
    threshold_mode: str = "soft",
) -> Optional[pd.Series]:
    """Denoise/reshape the series with a wavelet decomposition.

    Keeps the original index length; if PyWavelets is missing the function
    returns None so callers can gracefully fall back to the raw signal.
    """
    if pywt is None:
        print("[WARN] pywt is not installed: skipping wavelet filtering.")
        return None
    s = pd.to_numeric(series, errors="coerce")
    if s.dropna().empty:
        return None

    w = pywt.Wavelet(wavelet)
    max_level = pywt.dwt_max_level(len(s.dropna()), w.dec_len)
    lvl = max(1, min(level, max_level)) if max_level > 0 else 1

    valid = s.dropna()
    coeffs = pywt.wavedec(valid.values, w, mode=mode, level=lvl)
    # Universal threshold (Donoho-Johnstone): sigma estimated from the median
    # absolute finest-scale detail coefficient, threshold = sigma * sqrt(2 ln n).
    sigma = np.median(np.abs(coeffs[-1])) / 0.6745 if len(coeffs[-1]) > 0 else 0.0
    thresh = sigma * np.sqrt(2 * np.log(len(valid))) if sigma > 0 else 0.0
    if thresh <= 0:
        coeffs_f = coeffs
    else:

        def _safe_thresh(c: np.ndarray) -> np.ndarray:
            if c is None or c.size == 0:
                return c
            if threshold_mode == "hard":
                return pywt.threshold(c, value=thresh, mode="hard")
            # Soft threshold without divide-by-zero warnings.
            mag = np.abs(c)
            mask = mag > thresh
            out = np.zeros_like(c)
            out[mask] = np.sign(c[mask]) * (mag[mask] - thresh)
            return out

        coeffs_f = [coeffs[0]] + [_safe_thresh(c) for c in coeffs[1:]]

    rec = pywt.waverec(coeffs_f, w, mode=mode)
    rec = rec[: len(valid)]  # waverec can overshoot the input length by one sample
    filt = pd.Series(rec, index=valid.index)
    # Re-align to the original index, interpolating over the NaN gaps.
    return filt.reindex(s.index).interpolate(limit_direction="both")
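
# Usage sketch (illustrative; requires PyWavelets). The synthetic series below
# is an assumption for demonstration, not project data.
#
#     t = np.linspace(0, 4 * np.pi, 512)
#     noisy = pd.Series(np.sin(t) + 0.3 * np.random.default_rng(0).normal(size=t.size))
#     smooth = wavelet_denoise(noisy, wavelet="db3", level=3)
#     signal = smooth if smooth is not None else noisy   # graceful fallback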


def build_pattern_library(
    ret_series: pd.Series,
    wp: int,
    ha: int,
    embargo: Optional[int] = None,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """Create the normalized pattern windows and their realized outcomes.

    Args:
        ret_series: Series of returns (ordered oldest→latest).
        wp: Window length for the pattern.
        ha: Holding horizon used to compute the outcome.
        embargo: Optional number of most-recent observations to exclude when
            building the library (useful to avoid leakage when reusing the
            same series for inference).
    """
    x = ret_series.dropna().values
    n = len(x)
    if n < wp + ha + 10:
        return None, None
    embargo = int(embargo or 0)
    usable_n = n - max(0, embargo)
    if usable_n <= wp + ha:
        return None, None
    wins: List[np.ndarray] = []
    outs: List[float] = []
    last_start = usable_n - wp - ha
    if last_start <= 0:
        return None, None
    for t in range(0, last_start + 1):
        win = x[t : t + wp]
        winzn = z_norm(win)
        if winzn is None:
            continue
        outcome = np.sum(x[t + wp : t + wp + ha])
        wins.append(winzn)
        outs.append(outcome)
    if not wins:
        return None, None
    return np.array(wins), np.array(outs)
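
# Usage sketch (illustrative data; the wp/ha/embargo values are assumptions):
#
#     rets = pd.Series(np.random.default_rng(0).normal(0, 0.01, 1000))
#     lib_wins, lib_out = build_pattern_library(rets, wp=20, ha=5, embargo=5)
#     # lib_wins: (n_patterns, wp) z-scored windows; lib_out: (n_patterns,)
#     # summed forward returns over the next ha observations.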


def predict_from_library(
    curr_win: np.ndarray,
    lib_wins: np.ndarray,
    lib_out: np.ndarray,
    k: int = 25,
) -> Tuple[float, float, np.ndarray]:
    """k-NN lookup: return the median outcome of the ``k`` nearest library
    windows (by Euclidean distance), the mean distance to those neighbours,
    and the neighbour indices."""
    dists = np.linalg.norm(lib_wins - curr_win, axis=1)
    idx = np.argsort(dists)[: min(k, len(dists))]
    return float(np.median(lib_out[idx])), float(np.mean(dists[idx])), idx
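
# Continuing the sketch above: score the latest window against the library.
#
#     curr = z_norm(rets.values[-20:])
#     if lib_wins is not None and curr is not None:
#         pred, mean_dist, idx = predict_from_library(curr, lib_wins, lib_out, k=25)
#         # pred: median realized outcome of the 25 nearest historical patterns.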


def characterize_window(
    ret_series: pd.Series,
    wp: int,
    z_rev: float = 2.0,
    z_vol: float = 2.0,
    std_comp_pct: float = 0.15,
) -> Tuple[Optional[str], float]:
    """Label the latest ``wp``-length window with a heuristic regime tag.

    Returns one of ("compression", "momentum_burst", "reversal_candidate",
    "vol_spike") together with a confidence in [0, 1], or (None, 0.0) when
    no rule fires or the series is too short.
    """
    x = ret_series.dropna().values
    if len(x) < max(wp, 30):
        return None, 0.0
    win = x[-wp:]
    mu, sd = win.mean(), win.std()
    if sd < 1e-12:
        return "compression", 0.5

    last = win[-1]
    z_last = (last - mu) / (sd + 1e-12)
    abs_z_last = abs(z_last)
    last3 = win[-3:] if len(win) >= 3 else win
    sum3 = np.sum(last3)

    # Percentile of the current window's volatility vs. the rolling history.
    if len(x) > 3 * wp:
        roll_std = pd.Series(x).rolling(wp).std().dropna().values
        if len(roll_std) > 20:
            pct = (roll_std < np.std(win)).mean()
        else:
            pct = 0.5
    else:
        pct = 0.5

    if pct < std_comp_pct:
        return "compression", float(1.0 - pct)

    # Three same-signed returns whose sum is large relative to the window vol.
    if abs(sum3) > 2 * sd / np.sqrt(3) and np.sign(last3).sum() in (3, -3):
        conf = min(1.0, abs(sum3) / (sd + 1e-12))
        return "momentum_burst", float(conf)

    mean_prev = np.mean(win[:-1]) if len(win) > 1 else 0.0
    if abs_z_last >= z_rev and np.sign(last) != np.sign(mean_prev):
        conf = min(1.0, abs_z_last / 3.0)
        return "reversal_candidate", float(conf)

    if abs_z_last >= z_vol:
        conf = min(1.0, abs_z_last / 3.0)
        return "vol_spike", float(conf)

    return None, 0.0
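
# Usage sketch (illustrative): tag the most recent 20-return window.
#
#     rets = pd.Series(np.random.default_rng(1).normal(0, 0.01, 300))
#     label, conf = characterize_window(rets, wp=20)
#     # label is None for an unremarkable window; conf grows with rule strength.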


def hurst_rs(series: pd.Series) -> Optional[float]:
    """Single-window rescaled-range (R/S) estimate of the Hurst exponent.

    Computes H = log(R/S) / log(n) over the whole demeaned series; requires
    at least 100 observations and returns None otherwise (or on degenerate
    input).
    """
    x = pd.to_numeric(series.dropna(), errors="coerce").astype(float).values
    n = len(x)
    if n < 100:
        return None
    x = x - x.mean()
    z = np.cumsum(x)
    r = z.max() - z.min()
    s = x.std(ddof=1)
    if s <= 0 or r <= 0:
        return None
    h = np.log(r / s) / np.log(n)
    if not np.isfinite(h):
        return None
    return float(h)
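
# Usage sketch (illustrative): i.i.d. returns should score near H ~ 0.5 (no
# persistence); trending series score higher, mean-reverting ones lower.
#
#     iid = pd.Series(np.random.default_rng(2).normal(size=1000))
#     hurst_rs(iid)   # -> roughly 0.5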


def build_hurst_map(
    returns_long: pd.DataFrame,
    lookback: Optional[int] = None,
    min_length: int = 100,
) -> Dict[str, float]:
    """Compute a per-ISIN Hurst exponent from a long-format (Date, ISIN, Ret) frame."""
    if returns_long.empty:
        return {}
    ret_wide = returns_long.pivot(index="Date", columns="ISIN", values="Ret").sort_index()
    hurst_map: Dict[str, float] = {}
    for isin in ret_wide.columns:
        series = ret_wide[isin].dropna().astype(float)
        if len(series) < max(1, int(min_length)):
            continue
        window = len(series) if lookback is None else min(len(series), int(lookback))
        if window <= 0:
            continue
        h_val = hurst_rs(series.iloc[-window:])
        if h_val is None or not np.isfinite(h_val):
            continue
        hurst_map[str(isin)] = float(h_val)
    return hurst_map
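
# Usage sketch (illustrative ISINs and data): expects the long format produced
# upstream, with columns named exactly "Date", "ISIN" and "Ret".
#
#     long_df = pd.DataFrame({
#         "Date": pd.date_range("2020-01-01", periods=200).repeat(2),
#         "ISIN": ["AA0000000001", "BB0000000002"] * 200,
#         "Ret": np.random.default_rng(3).normal(0, 0.01, 400),
#     })
#     build_hurst_map(long_df, lookback=150)   # -> {"AA0000000001": ..., ...}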


__all__ = [
    "build_hurst_map",
    "build_pattern_library",
    "characterize_window",
    "detect_column",
    "hurst_rs",
    "load_config",
    "predict_from_library",
    "read_connection_txt",
    "require_section",
    "require_value",
    "wavelet_denoise",
    "z_norm",
]