import os import time from typing import Optional, Tuple, List, Dict import numpy as np import pandas as pd import gradio as gr import torch import plotly.graph_objects as go from chronos import Chronos2Pipeline # ========================= # Config # ========================= MODEL_ID_DEFAULT = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-2") DATA_DIR = "data" OUT_DIR = "/tmp" DEFAULT_FREQ = "D" # se il CSV non ha timestamp, generiamo daily # ========================= # Utils: files + device # ========================= def available_test_csv() -> List[str]: if not os.path.isdir(DATA_DIR): return [] return sorted([f for f in os.listdir(DATA_DIR) if f.lower().endswith(".csv")]) def pick_device(ui_choice: str) -> str: return "cuda" if (ui_choice or "").startswith("cuda") and torch.cuda.is_available() else "cpu" # ========================= # Sample series # ========================= def make_sample_df( n: int, seed: int, trend: float, season_period: int, season_amp: float, noise: float, freq: str = DEFAULT_FREQ, start: str = "2020-01-01", ) -> pd.DataFrame: rng = np.random.default_rng(int(seed)) t = np.arange(int(n), dtype=np.float32) y = ( float(trend) * t + float(season_amp) * np.sin(2 * np.pi * t / max(1, int(season_period))) + rng.normal(0.0, float(noise), size=int(n)) ).astype(np.float32) if float(np.min(y)) < 0: y -= float(np.min(y)) ts = pd.date_range(start=start, periods=int(n), freq=freq) return pd.DataFrame({"id": 0, "timestamp": ts, "target": y}) # ========================= # CSV loader -> context_df format (id,timestamp,target) # ========================= def _guess_timestamp_column(df: pd.DataFrame) -> Optional[str]: # prova colonne con nome tipico for c in df.columns: lc = str(c).lower() if lc in ["ds", "date", "datetime", "timestamp", "time"]: return c # prova parsing: se una colonna ha tanti valori parseabili a datetime for c in df.columns: if df[c].dtype == object: parsed = pd.to_datetime(df[c], errors="coerce", utc=False) if parsed.notna().sum() >= max(10, int(0.6 * len(df))): return c return None def _guess_numeric_target_column(df: pd.DataFrame, user_col: Optional[str]) -> str: if user_col and user_col.strip(): col = user_col.strip() if col not in df.columns: raise ValueError(f"Colonna '{col}' non trovata. Disponibili: {list(df.columns)}") return col # numeric dtype first numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])] if numeric_cols: return numeric_cols[0] # try coercion best = None best_count = 0 for c in df.columns: coerced = pd.to_numeric(df[c], errors="coerce") cnt = coerced.notna().sum() if cnt > best_count: best = c best_count = cnt if best is None or best_count < 10: raise ValueError("Non trovo una colonna numerica valida (>=10 valori) nel CSV.") return best def load_context_df_from_csv(path: str, user_target_col: Optional[str], user_time_col: Optional[str], freq: str) -> Tuple[pd.DataFrame, str, Optional[str]]: df = pd.read_csv(path) if df.shape[0] < 10: raise ValueError("Serie troppo corta (minimo consigliato: 10 righe).") target_col = _guess_numeric_target_column(df, user_target_col) time_col = user_time_col.strip() if (user_time_col and user_time_col.strip()) else _guess_timestamp_column(df) # target y = pd.to_numeric(df[target_col], errors="coerce").dropna().astype(np.float32).to_numpy() if len(y) < 10: raise ValueError("Troppi NaN: la colonna target ha meno di 10 valori numerici.") # timestamp if time_col and time_col in df.columns: ts = pd.to_datetime(df[time_col], errors="coerce") # allinea su target non-NaN (stesso mask del target coercito) mask = pd.to_numeric(df[target_col], errors="coerce").notna() ts = ts[mask] ts = ts.dropna() # se timestamp troppo sporchi, fallback a range if len(ts) < 10: time_col = None if not time_col: ts = pd.date_range(start="2020-01-01", periods=len(y), freq=freq) context_df = pd.DataFrame({"id": 0, "timestamp": ts[: len(y)], "target": y[: len(ts)]}) context_df = context_df.sort_values("timestamp").reset_index(drop=True) return context_df, target_col, (time_col if time_col else None) # ========================= # Pipeline cache # ========================= _PIPE = None _META = {"model_id": None, "device": None} def get_pipeline(model_id: str, device: str) -> Chronos2Pipeline: global _PIPE, _META model_id = (model_id or MODEL_ID_DEFAULT).strip() device = "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu" if _PIPE is None or _META["model_id"] != model_id or _META["device"] != device: _PIPE = Chronos2Pipeline.from_pretrained(model_id, device_map=device) _META = {"model_id": model_id, "device": device} return _PIPE # ========================= # Metrics # ========================= def mae(y_true: np.ndarray, y_pred: np.ndarray) -> float: return float(np.mean(np.abs(y_true - y_pred))) def rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float: return float(np.sqrt(np.mean((y_true - y_pred) ** 2))) def mape(y_true: np.ndarray, y_pred: np.ndarray) -> float: denom = np.maximum(1e-8, np.abs(y_true)) return float(np.mean(np.abs((y_true - y_pred) / denom)) * 100.0) def coverage(y_true: np.ndarray, low: np.ndarray, high: np.ndarray) -> float: return float(np.mean((y_true >= low) & (y_true <= high)) * 100.0) def avg_width(low: np.ndarray, high: np.ndarray) -> float: return float(np.mean(high - low)) # ========================= # Plotly # ========================= def plot_forecast(context_df: pd.DataFrame, pred_df: pd.DataFrame, q_low: float, q_high: float, title: str) -> go.Figure: ctx = context_df.copy() pred = pred_df.copy() fig = go.Figure() fig.add_trace(go.Scatter(x=ctx["timestamp"], y=ctx["target"], mode="lines", name="History")) # pred_df from predict_df typically has: # - timestamp # - predictions (median or q=0.5) # - columns for quantiles like "0.1", "0.9" if "predictions" in pred.columns: y_med = pred["predictions"].to_numpy() else: # fallback: if "0.5" exists y_med = pred.get("0.5", pred.iloc[:, -1]).to_numpy() fig.add_trace(go.Scatter(x=pred["timestamp"], y=y_med, mode="lines", name="Forecast (median)")) low_col = f"{q_low:.1f}".rstrip("0").rstrip(".") high_col = f"{q_high:.1f}".rstrip("0").rstrip(".") # columns in pred_df are often exactly "0.1", "0.5", "0.9" as strings if str(q_low) in pred.columns: low_series = pred[str(q_low)].to_numpy() elif low_col in pred.columns: low_series = pred[low_col].to_numpy() else: low_series = None if str(q_high) in pred.columns: high_series = pred[str(q_high)].to_numpy() elif high_col in pred.columns: high_series = pred[high_col].to_numpy() else: high_series = None if low_series is not None and high_series is not None: fig.add_trace(go.Scatter( x=pred["timestamp"], y=high_series, mode="lines", line=dict(width=0), showlegend=False, hoverinfo="skip" )) fig.add_trace(go.Scatter( x=pred["timestamp"], y=low_series, mode="lines", fill="tonexty", line=dict(width=0), name=f"Band [{q_low:.2f}, {q_high:.2f}]" )) fig.update_layout( title=title, hovermode="x unified", margin=dict(l=10, r=10, t=55, b=10), legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0), xaxis_title="timestamp", yaxis_title="value", ) return fig def kpi_card(label: str, value: str, hint: str = "") -> str: hint_html = f"
{hint}
" if hint else "" return f"""
{label}
{value}
{hint_html}
""" def kpi_grid(cards: List[str]) -> str: return f"""
{''.join(cards)}
""" def explain_natural(context_df: pd.DataFrame, pred_df: pd.DataFrame, q_low: float, q_high: float, backtest_metrics: Optional[Dict[str, float]]) -> str: ctx_y = context_df["target"].to_numpy(dtype=float) if "predictions" in pred_df.columns: med = pred_df["predictions"].to_numpy(dtype=float) elif "0.5" in pred_df.columns: med = pred_df["0.5"].to_numpy(dtype=float) else: med = pred_df.iloc[:, -1].to_numpy(dtype=float) base = float(np.mean(ctx_y)) delta = float(med[-1] - med[0]) pct = (delta / max(1e-6, base)) * 100.0 if abs(pct) < 2: trend_txt = "sostanzialmente stabile" elif pct > 0: trend_txt = "in crescita" else: trend_txt = "in calo" txt = f"""### 🧠 Spiegazione Nei prossimi **{len(med)} step**, la previsione mediana è **{trend_txt}** (variazione complessiva ≈ **{pct:+.1f}%** rispetto al livello medio storico). - **Ultimo valore mediano previsto:** **{med[-1]:.2f}** """ # band, if present low_key = str(q_low) high_key = str(q_high) if low_key in pred_df.columns and high_key in pred_df.columns: low = pred_df[low_key].to_numpy(dtype=float) high = pred_df[high_key].to_numpy(dtype=float) txt += f"- **Intervallo [{q_low:.0%}–{q_high:.0%}] ultimo step:** **[{low[-1]:.2f} – {high[-1]:.2f}]**\n" txt += f"- **Larghezza media banda:** **{avg_width(low, high):.2f}**\n" else: txt += "- **Banda di incertezza:** non disponibile (manca nel pred_df).\n" if backtest_metrics: txt += f""" ### 🧪 Backtest (holdout) - **MAE:** {backtest_metrics["mae"]:.3f} - **RMSE:** {backtest_metrics["rmse"]:.3f} - **MAPE:** {backtest_metrics["mape"]:.2f}% - **Coverage banda:** {backtest_metrics["coverage"]:.1f}% """ return txt # ========================= # Run core (predict_df) # ========================= def run_dashboard( input_mode: str, test_csv_name: str, upload_csv, target_col: str, time_col: str, freq: str, n: int, seed: int, trend: float, season_period: int, season_amp: float, noise: float, prediction_length: int, q_low: float, q_high: float, do_backtest: bool, holdout: int, device_ui: str, model_id: str, ): if q_low >= q_high: raise gr.Error("Quantile low deve essere < quantile high.") device = pick_device(device_ui) pipe = get_pipeline(model_id, device) # ---- build context_df if input_mode == "Test CSV": if not test_csv_name: raise gr.Error("Seleziona un Test CSV.") csv_path = os.path.join(DATA_DIR, test_csv_name) if not os.path.exists(csv_path): raise gr.Error(f"Non trovo {csv_path}") context_df, used_target, used_time = load_context_df_from_csv(csv_path, target_col, time_col, freq) source = f"Test CSV: {test_csv_name} • target={used_target} • time={used_time or 'generated'}" elif input_mode == "Upload CSV": if upload_csv is None: raise gr.Error("Carica un CSV.") context_df, used_target, used_time = load_context_df_from_csv(upload_csv.name, target_col, time_col, freq) source = f"Upload CSV • target={used_target} • time={used_time or 'generated'}" else: context_df = make_sample_df(n, seed, trend, season_period, season_amp, noise, freq=freq) source = "Sample series" if len(context_df) < 10: raise gr.Error("Serie troppo corta.") if do_backtest and holdout >= len(context_df): raise gr.Error("Holdout deve essere più piccolo della lunghezza dello storico.") quantiles = sorted(list(set([float(q_low), 0.5, float(q_high)]))) t0 = time.time() # ---- forecast (future_df not needed if no covariates) pred_df = pipe.predict_df( context_df, prediction_length=int(prediction_length), quantile_levels=quantiles, id_column="id", timestamp_column="timestamp", target="target", ) latency = time.time() - t0 # ---- exports forecast_path = os.path.join(OUT_DIR, "chronos2_forecast_df.csv") pred_df.to_csv(forecast_path, index=False) # ---- backtest backtest_metrics = None backtest_path = None backtest_df_out = pd.DataFrame() backtest_fig = go.Figure().update_layout(title="Backtest disabled", margin=dict(l=10, r=10, t=55, b=10)) if do_backtest: train_df = context_df.iloc[:-int(holdout)].copy() true_df = context_df.iloc[-int(holdout):].copy() bt_pred_df = pipe.predict_df( train_df, prediction_length=int(holdout), quantile_levels=quantiles, id_column="id", timestamp_column="timestamp", target="target", ) # extract arrays y_true = true_df["target"].to_numpy(dtype=float) if "predictions" in bt_pred_df.columns: y_hat = bt_pred_df["predictions"].to_numpy(dtype=float) elif "0.5" in bt_pred_df.columns: y_hat = bt_pred_df["0.5"].to_numpy(dtype=float) else: y_hat = bt_pred_df.iloc[:, -1].to_numpy(dtype=float) # band if str(q_low) in bt_pred_df.columns and str(q_high) in bt_pred_df.columns: low = bt_pred_df[str(q_low)].to_numpy(dtype=float) high = bt_pred_df[str(q_high)].to_numpy(dtype=float) cov = coverage(y_true, low, high) else: low = y_hat.copy() high = y_hat.copy() cov = float("nan") backtest_metrics = { "mae": mae(y_true, y_hat), "rmse": rmse(y_true, y_hat), "mape": mape(y_true, y_hat), "coverage": cov, } # plot backtest fig = go.Figure() fig.add_trace(go.Scatter(x=train_df["timestamp"], y=train_df["target"], mode="lines", name="Train")) fig.add_trace(go.Scatter(x=true_df["timestamp"], y=true_df["target"], mode="lines", name="True (holdout)")) fig.add_trace(go.Scatter(x=bt_pred_df["timestamp"], y=y_hat, mode="lines", name="Pred (median)")) if str(q_low) in bt_pred_df.columns and str(q_high) in bt_pred_df.columns: fig.add_trace(go.Scatter( x=bt_pred_df["timestamp"], y=high, mode="lines", line=dict(width=0), showlegend=False, hoverinfo="skip" )) fig.add_trace(go.Scatter( x=bt_pred_df["timestamp"], y=low, mode="lines", fill="tonexty", line=dict(width=0), name=f"Band [{q_low:.2f}, {q_high:.2f}]" )) fig.update_layout( title="Backtest (holdout) — interactive", hovermode="x unified", margin=dict(l=10, r=10, t=55, b=10), legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0), xaxis_title="timestamp", yaxis_title="value", ) backtest_fig = fig backtest_path = os.path.join(OUT_DIR, "chronos2_backtest_df.csv") bt_pred_df.to_csv(backtest_path, index=False) backtest_df_out = bt_pred_df # ---- main plot forecast_fig = plot_forecast(context_df, pred_df, q_low, q_high, f"Forecast — {source}") # ---- KPIs cards = [ kpi_card("Device", device.upper(), f"cuda_available={torch.cuda.is_available()}"), kpi_card("Model", (model_id or MODEL_ID_DEFAULT), "Chronos-2"), kpi_card("Latency", f"{latency:.2f}s", "predict_df()"), kpi_card("History", str(len(context_df)), "points"), kpi_card("Horizon", str(prediction_length), "steps"), kpi_card("Quantiles", f"{q_low:.2f}, 0.50, {q_high:.2f}", "levels"), ] kpis_html = kpi_grid(cards) explanation_md = explain_natural(context_df, pred_df, q_low, q_high, backtest_metrics) info = { "source": source, "history_points": int(len(context_df)), "prediction_length": int(prediction_length), "quantile_levels": quantiles, "backtest": bool(do_backtest), "holdout": int(holdout) if do_backtest else None, } return ( kpis_html, explanation_md, forecast_fig, backtest_fig, pred_df, backtest_df_out, forecast_path, backtest_path, info, ) # ========================= # UI # ========================= css = """ .gradio-container { max-width: 1200px !important; } /* KPI grid */ .kpi-grid{ display: grid; grid-template-columns: repeat(auto-fit, minmax(190px, 1fr)); gap: 14px; padding: 10px 8px; /* <-- spazio “esterno” */ margin-top: 6px; /* <-- separa dal titolo / contenuto sopra */ } /* opzionale: un filo di aria sotto ogni card */ .kpi-grid > div{ min-height: 84px; } """ with gr.Blocks(title="Chronos-2 • Forecast Dashboard (predict_df)", css=css) as demo: gr.Markdown("# Chronos-2 Dashboard") with gr.Row(): with gr.Column(scale=1, min_width=360): gr.Markdown("## Input") input_mode = gr.Radio(["Sample", "Test CSV", "Upload CSV"], value="Sample", label="Sorgente") test_csv_name = gr.Dropdown(choices=available_test_csv(), label="Test CSV (data/)") upload_csv = gr.File(label="Upload CSV", file_types=[".csv"]) target_col = gr.Textbox(label="Colonna target (opzionale)", placeholder="es: value") time_col = gr.Textbox(label="Colonna timestamp (opzionale)", placeholder="es: timestamp / date / ds") freq = gr.Dropdown(["D", "H", "W", "M"], value=DEFAULT_FREQ, label="Freq (se timestamp mancante)") gr.Markdown("## Sistema") device_ui = gr.Dropdown( ["cpu", "cuda (se disponibile)"], value="cuda (se disponibile)" if torch.cuda.is_available() else "cpu", label="Device", ) model_id = gr.Textbox(value=MODEL_ID_DEFAULT, label="Model ID") with gr.Accordion("Sample generator", open=False): n = gr.Slider(60, 2000, value=300, step=10, label="History length") seed = gr.Number(value=42, precision=0, label="Seed") trend = gr.Slider(0.0, 0.2, value=0.03, step=0.005, label="Trend") season_period = gr.Slider(2, 240, value=14, step=1, label="Season period") season_amp = gr.Slider(0.0, 12.0, value=3.0, step=0.1, label="Season amplitude") noise = gr.Slider(0.0, 6.0, value=0.8, step=0.05, label="Noise") gr.Markdown("## Forecast") prediction_length = gr.Slider(1, 365, value=30, step=1, label="Prediction length") q_low = gr.Slider(0.01, 0.49, value=0.10, step=0.01, label="Quantile low") q_high = gr.Slider(0.51, 0.99, value=0.90, step=0.01, label="Quantile high") gr.Markdown("## Backtest") do_backtest = gr.Checkbox(value=True, label="Esegui backtest holdout") holdout = gr.Slider(5, 365, value=30, step=1, label="Holdout points") run_btn = gr.Button("Run", variant="primary") with gr.Column(scale=2): kpis = gr.HTML() with gr.Tabs(): with gr.Tab("Forecast"): forecast_plot = gr.Plot() forecast_table = gr.Dataframe(interactive=False) with gr.Tab("Backtest"): backtest_plot = gr.Plot() backtest_table = gr.Dataframe(interactive=False) with gr.Tab("Spiegazione"): explanation = gr.Markdown() with gr.Tab("Export"): forecast_download = gr.File(label="Forecast CSV") backtest_download = gr.File(label="Backtest CSV") with gr.Tab("Info"): info = gr.JSON() run_btn.click( fn=run_dashboard, inputs=[ input_mode, test_csv_name, upload_csv, target_col, time_col, freq, n, seed, trend, season_period, season_amp, noise, prediction_length, q_low, q_high, do_backtest, holdout, device_ui, model_id, ], outputs=[ kpis, explanation, forecast_plot, backtest_plot, forecast_table, backtest_table, forecast_download, backtest_download, info, ], ) demo.queue() demo.launch(ssr_mode=False)