Spaces:
Running
Running
| import os | |
| import time | |
| from typing import Optional, Tuple, List, Dict | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
| import torch | |
| import plotly.graph_objects as go | |
| from chronos import Chronos2Pipeline | |
| # ========================= | |
| # Config | |
| # ========================= | |
# Hugging Face model id; can be overridden with the CHRONOS_MODEL_ID env var.
MODEL_ID_DEFAULT = os.environ.get("CHRONOS_MODEL_ID", "amazon/chronos-2")
# Directory scanned for the bundled test CSV files.
DATA_DIR = "data"
# Directory where the forecast/backtest CSV exports are written.
OUT_DIR = "/tmp"
# When the CSV has no timestamp column, daily timestamps are generated.
DEFAULT_FREQ = "D"
| # ========================= | |
| # Utils: files + device | |
| # ========================= | |
def available_test_csv() -> List[str]:
    """Return the sorted names of the ``.csv`` files found in ``DATA_DIR``.

    The extension match is case-insensitive. An empty list is returned when
    ``DATA_DIR`` is not an existing directory.
    """
    if os.path.isdir(DATA_DIR):
        names = [name for name in os.listdir(DATA_DIR) if name.lower().endswith(".csv")]
        names.sort()
        return names
    return []
def pick_device(ui_choice: str) -> str:
    """Translate the device dropdown value into ``"cuda"`` or ``"cpu"``.

    ``"cuda"`` is returned only when the choice starts with ``"cuda"`` and
    torch reports CUDA as available; every other case (including ``None`` or
    an empty choice) yields ``"cpu"``.
    """
    wants_cuda = (ui_choice or "").startswith("cuda")
    if wants_cuda and torch.cuda.is_available():
        return "cuda"
    return "cpu"
| # ========================= | |
| # Sample series | |
| # ========================= | |
def make_sample_df(
    n: int,
    seed: int,
    trend: float,
    season_period: int,
    season_amp: float,
    noise: float,
    freq: str = DEFAULT_FREQ,
    start: str = "2020-01-01",
) -> pd.DataFrame:
    """Generate a synthetic series as an ``id``/``timestamp``/``target`` frame.

    The target is the sum of a linear trend, a sine seasonality and Gaussian
    noise drawn from a generator seeded with ``seed``. If the minimum of the
    series is negative, the whole series is shifted up so the minimum is 0.

    Args:
        n: Number of points to generate.
        seed: Seed for the random generator.
        trend: Slope applied to the step index.
        season_period: Period of the sine component (clamped to at least 1).
        season_amp: Amplitude of the sine component.
        noise: Standard deviation of the Gaussian noise.
        freq: Frequency passed to ``pd.date_range``.
        start: First timestamp of the generated range.

    Returns:
        A DataFrame with a constant ``id`` of 0, the generated timestamps and
        the float32 ``target`` values.
    """
    length = int(n)
    period = max(1, int(season_period))
    generator = np.random.default_rng(int(seed))
    steps = np.arange(length, dtype=np.float32)

    linear = float(trend) * steps
    seasonal = float(season_amp) * np.sin(2 * np.pi * steps / period)
    jitter = generator.normal(0.0, float(noise), size=length)
    values = (linear + seasonal + jitter).astype(np.float32)

    # Shift the series so it never goes below zero.
    lowest = float(np.min(values))
    if lowest < 0:
        values -= lowest

    stamps = pd.date_range(start=start, periods=length, freq=freq)
    return pd.DataFrame({"id": 0, "timestamp": stamps, "target": values})
| # ========================= | |
| # CSV loader -> context_df format (id,timestamp,target) | |
| # ========================= | |
| def _guess_timestamp_column(df: pd.DataFrame) -> Optional[str]: | |
| # prova colonne con nome tipico | |
| for c in df.columns: | |
| lc = str(c).lower() | |
| if lc in ["ds", "date", "datetime", "timestamp", "time"]: | |
| return c | |
| # prova parsing: se una colonna ha tanti valori parseabili a datetime | |
| for c in df.columns: | |
| if df[c].dtype == object: | |
| parsed = pd.to_datetime(df[c], errors="coerce", utc=False) | |
| if parsed.notna().sum() >= max(10, int(0.6 * len(df))): | |
| return c | |
| return None | |
| def _guess_numeric_target_column(df: pd.DataFrame, user_col: Optional[str]) -> str: | |
| if user_col and user_col.strip(): | |
| col = user_col.strip() | |
| if col not in df.columns: | |
| raise ValueError(f"Colonna '{col}' non trovata. Disponibili: {list(df.columns)}") | |
| return col | |
| # numeric dtype first | |
| numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])] | |
| if numeric_cols: | |
| return numeric_cols[0] | |
| # try coercion | |
| best = None | |
| best_count = 0 | |
| for c in df.columns: | |
| coerced = pd.to_numeric(df[c], errors="coerce") | |
| cnt = coerced.notna().sum() | |
| if cnt > best_count: | |
| best = c | |
| best_count = cnt | |
| if best is None or best_count < 10: | |
| raise ValueError("Non trovo una colonna numerica valida (>=10 valori) nel CSV.") | |
| return best | |
def load_context_df_from_csv(
    path: str,
    user_target_col: Optional[str],
    user_time_col: Optional[str],
    freq: str,
) -> Tuple[pd.DataFrame, str, Optional[str]]:
    """Load a CSV and convert it to the ``id``/``timestamp``/``target`` format.

    Args:
        path: Path of the CSV file to read.
        user_target_col: Target column chosen by the user; when blank the
            column is guessed with ``_guess_numeric_target_column``.
        user_time_col: Timestamp column chosen by the user; when blank the
            column is guessed with ``_guess_timestamp_column``.
        freq: Frequency used to generate timestamps when the CSV has no
            usable timestamp column.

    Returns:
        ``(context_df, target_col, time_col)`` where ``time_col`` is ``None``
        when the timestamps were generated instead of read from the CSV.

    Raises:
        ValueError: If the CSV has fewer than 10 rows, a requested column is
            missing, or the target has fewer than 10 numeric values.
    """
    df = pd.read_csv(path)
    if df.shape[0] < 10:
        raise ValueError("Serie troppo corta (minimo consigliato: 10 righe).")

    target_col = _guess_numeric_target_column(df, user_target_col)

    requested_time = user_time_col.strip() if user_time_col else ""
    if requested_time:
        # An explicit but unknown column used to leave the timestamps
        # undefined and crash later; report it like a missing target column.
        if requested_time not in df.columns:
            raise ValueError(
                f"Colonna '{requested_time}' non trovata. Disponibili: {list(df.columns)}"
            )
        time_col = requested_time
    else:
        time_col = _guess_timestamp_column(df)

    # Target: coerce to numbers, rows that fail become NaN.
    target = pd.to_numeric(df[target_col], errors="coerce")
    valid_target = target.notna()
    if int(valid_target.sum()) < 10:
        raise ValueError("Troppi NaN: la colonna target ha meno di 10 valori numerici.")

    if time_col is not None:
        stamps = pd.to_datetime(df[time_col], errors="coerce")
        # Keep only rows where BOTH values are valid, so each target stays
        # paired with the timestamp of its own row.
        paired = valid_target & stamps.notna()
        if int(paired.sum()) >= 10:
            context_df = pd.DataFrame(
                {
                    "id": 0,
                    "timestamp": stamps[paired],
                    "target": target[paired].astype(np.float32),
                }
            )
            context_df = context_df.sort_values("timestamp").reset_index(drop=True)
            return context_df, target_col, time_col
        # Timestamps too dirty: fall through to a generated range.

    y = target[valid_target].astype(np.float32).to_numpy()
    ts = pd.date_range(start="2020-01-01", periods=len(y), freq=freq)
    context_df = pd.DataFrame({"id": 0, "timestamp": ts, "target": y})
    return context_df, target_col, None
| # ========================= | |
| # Pipeline cache | |
| # ========================= | |
# Module-level cache: the loaded pipeline and the settings it was loaded with.
_PIPE = None
_META = {"model_id": None, "device": None}


def get_pipeline(model_id: str, device: str) -> Chronos2Pipeline:
    """Return the cached pipeline, (re)loading it when the settings change.

    ``model_id`` falls back to ``MODEL_ID_DEFAULT`` when empty and is
    stripped; ``device`` is downgraded to ``"cpu"`` unless it is ``"cuda"``
    and torch reports CUDA as available. The pipeline is loaded with
    ``Chronos2Pipeline.from_pretrained`` only when nothing is cached yet or
    the resolved model id / device differ from the cached ones.
    """
    global _PIPE, _META
    resolved_model = (model_id or MODEL_ID_DEFAULT).strip()
    use_cuda = device == "cuda" and torch.cuda.is_available()
    resolved_device = "cuda" if use_cuda else "cpu"

    wanted = {"model_id": resolved_model, "device": resolved_device}
    if _PIPE is None or _META != wanted:
        _PIPE = Chronos2Pipeline.from_pretrained(resolved_model, device_map=resolved_device)
        _META = wanted
    return _PIPE
| # ========================= | |
| # Metrics | |
| # ========================= | |
def mae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the mean absolute error between ``y_true`` and ``y_pred``."""
    abs_errors = np.abs(y_true - y_pred)
    return float(np.mean(abs_errors))
def rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the root mean squared error between ``y_true`` and ``y_pred``."""
    residuals = y_true - y_pred
    mean_squared = np.mean(residuals ** 2)
    return float(np.sqrt(mean_squared))
def mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the mean absolute percentage error, expressed in percent.

    The denominator ``|y_true|`` is floored at ``1e-8`` so zero targets do
    not cause a division by zero.
    """
    floor = 1e-8
    scale = np.maximum(floor, np.abs(y_true))
    relative_errors = np.abs((y_true - y_pred) / scale)
    return float(np.mean(relative_errors) * 100.0)
def coverage(y_true: np.ndarray, low: np.ndarray, high: np.ndarray) -> float:
    """Return the percentage of ``y_true`` values inside ``[low, high]``.

    Both bounds are inclusive.
    """
    inside = (low <= y_true) & (y_true <= high)
    return float(np.mean(inside) * 100.0)
def avg_width(low: np.ndarray, high: np.ndarray) -> float:
    """Return the mean width ``high - low`` of the band."""
    widths = high - low
    return float(np.mean(widths))
| # ========================= | |
| # Plotly | |
| # ========================= | |
def plot_forecast(context_df: pd.DataFrame, pred_df: pd.DataFrame, q_low: float, q_high: float, title: str) -> go.Figure:
    """Build the forecast figure: history, median forecast and quantile band.

    Args:
        context_df: Historical data with ``timestamp`` and ``target`` columns.
        pred_df: Forecast with a ``timestamp`` column, a median column
            (``predictions``, else ``"0.5"``, else the last column) and,
            optionally, one column per quantile level.
        q_low: Lower quantile level of the band.
        q_high: Upper quantile level of the band.
        title: Figure title.

    Returns:
        The Plotly figure. The band is drawn only when both quantile columns
        are found in ``pred_df``.
    """

    def _quantile_values(level: float):
        """Return the values of the column labelled with ``level``, or None."""
        exact = str(level)
        if exact in pred_df.columns:
            return pred_df[exact].to_numpy()
        # Tolerate labels such as "0.10" or float noise by comparing labels
        # numerically. The previous one-decimal formatting could round 0.45
        # to "0.5" and silently pick the median column as a band edge.
        for col in pred_df.columns:
            try:
                col_level = float(col)
            except (TypeError, ValueError):
                continue
            if np.isclose(col_level, float(level), rtol=0.0, atol=1e-9):
                return pred_df[col].to_numpy()
        return None

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=context_df["timestamp"], y=context_df["target"], mode="lines", name="History"))

    # pred_df from predict_df typically has:
    # - timestamp
    # - predictions (median or q=0.5)
    # - columns for quantiles like "0.1", "0.9"
    if "predictions" in pred_df.columns:
        y_med = pred_df["predictions"].to_numpy()
    else:
        # Fallback: "0.5" if it exists, otherwise the last column.
        y_med = pred_df.get("0.5", pred_df.iloc[:, -1]).to_numpy()
    fig.add_trace(go.Scatter(x=pred_df["timestamp"], y=y_med, mode="lines", name="Forecast (median)"))

    low_series = _quantile_values(q_low)
    high_series = _quantile_values(q_high)
    if low_series is not None and high_series is not None:
        # Invisible upper edge first, then the lower edge filled up to it.
        fig.add_trace(go.Scatter(
            x=pred_df["timestamp"], y=high_series,
            mode="lines", line=dict(width=0), showlegend=False, hoverinfo="skip"
        ))
        fig.add_trace(go.Scatter(
            x=pred_df["timestamp"], y=low_series,
            mode="lines", fill="tonexty", line=dict(width=0),
            name=f"Band [{q_low:.2f}, {q_high:.2f}]"
        ))

    fig.update_layout(
        title=title,
        hovermode="x unified",
        margin=dict(l=10, r=10, t=55, b=10),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
        xaxis_title="timestamp",
        yaxis_title="value",
    )
    return fig
def kpi_card(label: str, value: str, hint: str = "") -> str:
    """Render one KPI card as an HTML snippet.

    Args:
        label: Small caption shown at the top of the card.
        value: Main value shown in bold.
        hint: Optional secondary text; the hint line is omitted when empty.

    Returns:
        The HTML of the card. ``label``, ``value`` and ``hint`` are
        HTML-escaped because some of them come from user-editable fields
        (e.g. the model id textbox) and must not be interpreted as markup.
    """
    import html

    safe_label = html.escape(str(label))
    safe_value = html.escape(str(value))
    hint_html = (
        f"<div style='opacity:.75;font-size:12px;margin-top:6px;'>{html.escape(str(hint))}</div>"
        if hint
        else ""
    )
    return f"""
    <div style="border:1px solid rgba(255,255,255,.12); border-radius:16px; padding:14px 16px;
    background: rgba(255,255,255,.04);">
    <div style="font-size:12px;opacity:.8;">{safe_label}</div>
    <div style="font-size:22px;font-weight:700;margin-top:4px;">{safe_value}</div>
    {hint_html}
    </div>
    """
def kpi_grid(cards: List[str]) -> str:
    """Wrap the given card snippets in the ``kpi-grid`` container div."""
    body = "".join(cards)
    return f"""
<div class="kpi-grid">
{body}
</div>
"""
def explain_natural(context_df: pd.DataFrame, pred_df: pd.DataFrame, q_low: float, q_high: float, backtest_metrics: Optional[Dict[str, float]]) -> str:
    """Build the Markdown explanation of the forecast (and of the backtest).

    Args:
        context_df: Historical data with a ``target`` column.
        pred_df: Forecast; the median is read from ``predictions``, else
            ``"0.5"``, else the last column. The band is read from the
            columns named ``str(q_low)`` and ``str(q_high)`` when present.
        q_low: Lower quantile level of the band.
        q_high: Upper quantile level of the band.
        backtest_metrics: Optional dict with ``mae``, ``rmse``, ``mape`` and
            ``coverage``; when falsy the backtest section is omitted.

    Returns:
        Markdown text describing the trend of the median forecast relative to
        the historical mean level, the uncertainty band and the metrics.
    """
    ctx_y = context_df["target"].to_numpy(dtype=float)
    if "predictions" in pred_df.columns:
        med = pred_df["predictions"].to_numpy(dtype=float)
    elif "0.5" in pred_df.columns:
        med = pred_df["0.5"].to_numpy(dtype=float)
    else:
        med = pred_df.iloc[:, -1].to_numpy(dtype=float)

    base = float(np.mean(ctx_y))
    delta = float(med[-1] - med[0])
    # Use the magnitude of the historical level: with a negative mean the
    # plain max() collapsed to 1e-6 and produced absurd percentages.
    pct = (delta / max(1e-6, abs(base))) * 100.0

    if abs(pct) < 2:
        trend_txt = "sostanzialmente stabile"
    elif pct > 0:
        trend_txt = "in crescita"
    else:
        trend_txt = "in calo"

    txt = f"""### 🧠 Spiegazione
Nei prossimi **{len(med)} step**, la previsione mediana è **{trend_txt}** (variazione complessiva ≈ **{pct:+.1f}%** rispetto al livello medio storico).
- **Ultimo valore mediano previsto:** **{med[-1]:.2f}**
"""

    # Band, if present.
    low_key = str(q_low)
    high_key = str(q_high)
    if low_key in pred_df.columns and high_key in pred_df.columns:
        low = pred_df[low_key].to_numpy(dtype=float)
        high = pred_df[high_key].to_numpy(dtype=float)
        txt += f"- **Intervallo [{q_low:.0%}–{q_high:.0%}] ultimo step:** **[{low[-1]:.2f} – {high[-1]:.2f}]**\n"
        txt += f"- **Larghezza media banda:** **{avg_width(low, high):.2f}**\n"
    else:
        txt += "- **Banda di incertezza:** non disponibile (manca nel pred_df).\n"

    if backtest_metrics:
        txt += f"""
### 🧪 Backtest (holdout)
- **MAE:** {backtest_metrics["mae"]:.3f}
- **RMSE:** {backtest_metrics["rmse"]:.3f}
- **MAPE:** {backtest_metrics["mape"]:.2f}%
- **Coverage banda:** {backtest_metrics["coverage"]:.1f}%
"""
    return txt
| # ========================= | |
| # Run core (predict_df) | |
| # ========================= | |
def run_dashboard(
    input_mode: str,
    test_csv_name: str,
    upload_csv,
    target_col: str,
    time_col: str,
    freq: str,
    n: int,
    seed: int,
    trend: float,
    season_period: int,
    season_amp: float,
    noise: float,
    prediction_length: int,
    q_low: float,
    q_high: float,
    do_backtest: bool,
    holdout: int,
    device_ui: str,
    model_id: str,
):
    """Run the forecast (and optional holdout backtest) behind the Run button.

    The history is built from the selected source (sample generator, bundled
    test CSV or uploaded CSV), forecast with ``predict_df`` and exported to
    CSV under ``OUT_DIR``. When ``do_backtest`` is set, the last ``holdout``
    points are held out, forecast from the remaining history and scored.

    Returns:
        A 9-tuple matching the UI outputs: KPI HTML, explanation Markdown,
        forecast figure, backtest figure, forecast DataFrame, backtest
        DataFrame, forecast CSV path, backtest CSV path (``None`` without
        backtest) and an info dict.

    Raises:
        gr.Error: On invalid user input (quantiles, missing/unknown file or
            column, series too short, invalid holdout).
    """
    if q_low >= q_high:
        raise gr.Error("Quantile low deve essere < quantile high.")

    def _load_csv(csv_file: str):
        # Loader validation failures are user-correctable: surface them as
        # gr.Error like the other input checks instead of a raw exception.
        try:
            return load_context_df_from_csv(csv_file, target_col, time_col, freq)
        except ValueError as exc:
            raise gr.Error(str(exc)) from exc

    # ---- build context_df (validated before the model is loaded)
    if input_mode == "Test CSV":
        if not test_csv_name:
            raise gr.Error("Seleziona un Test CSV.")
        csv_path = os.path.join(DATA_DIR, test_csv_name)
        if not os.path.exists(csv_path):
            raise gr.Error(f"Non trovo {csv_path}")
        context_df, used_target, used_time = _load_csv(csv_path)
        source = f"Test CSV: {test_csv_name} • target={used_target} • time={used_time or 'generated'}"
    elif input_mode == "Upload CSV":
        if upload_csv is None:
            raise gr.Error("Carica un CSV.")
        # The upload may be a file-like wrapper exposing .name or a plain path.
        upload_path = getattr(upload_csv, "name", upload_csv)
        context_df, used_target, used_time = _load_csv(upload_path)
        source = f"Upload CSV • target={used_target} • time={used_time or 'generated'}"
    else:
        context_df = make_sample_df(n, seed, trend, season_period, season_amp, noise, freq=freq)
        source = "Sample series"

    if len(context_df) < 10:
        raise gr.Error("Serie troppo corta.")
    if do_backtest:
        holdout = int(holdout)
        # A holdout below 1 would leave an empty training slice.
        if holdout < 1:
            raise gr.Error("Holdout deve essere almeno 1.")
        if holdout >= len(context_df):
            raise gr.Error("Holdout deve essere più piccolo della lunghezza dello storico.")

    device = pick_device(device_ui)
    pipe = get_pipeline(model_id, device)

    quantiles = sorted({float(q_low), 0.5, float(q_high)})
    horizon = int(prediction_length)

    # ---- forecast (future_df not needed if no covariates)
    t0 = time.time()
    pred_df = pipe.predict_df(
        context_df,
        prediction_length=horizon,
        quantile_levels=quantiles,
        id_column="id",
        timestamp_column="timestamp",
        target="target",
    )
    latency = time.time() - t0

    # ---- exports
    forecast_path = os.path.join(OUT_DIR, "chronos2_forecast_df.csv")
    pred_df.to_csv(forecast_path, index=False)

    # ---- backtest
    backtest_metrics = None
    backtest_path = None
    backtest_df_out = pd.DataFrame()
    backtest_fig = go.Figure().update_layout(title="Backtest disabled", margin=dict(l=10, r=10, t=55, b=10))

    if do_backtest:
        train_df = context_df.iloc[:-holdout].copy()
        true_df = context_df.iloc[-holdout:].copy()
        bt_pred_df = pipe.predict_df(
            train_df,
            prediction_length=holdout,
            quantile_levels=quantiles,
            id_column="id",
            timestamp_column="timestamp",
            target="target",
        )

        # extract arrays
        y_true = true_df["target"].to_numpy(dtype=float)
        if "predictions" in bt_pred_df.columns:
            y_hat = bt_pred_df["predictions"].to_numpy(dtype=float)
        elif "0.5" in bt_pred_df.columns:
            y_hat = bt_pred_df["0.5"].to_numpy(dtype=float)
        else:
            y_hat = bt_pred_df.iloc[:, -1].to_numpy(dtype=float)

        # band
        has_band = str(q_low) in bt_pred_df.columns and str(q_high) in bt_pred_df.columns
        if has_band:
            low = bt_pred_df[str(q_low)].to_numpy(dtype=float)
            high = bt_pred_df[str(q_high)].to_numpy(dtype=float)
            cov = coverage(y_true, low, high)
        else:
            low = y_hat.copy()
            high = y_hat.copy()
            cov = float("nan")

        backtest_metrics = {
            "mae": mae(y_true, y_hat),
            "rmse": rmse(y_true, y_hat),
            "mape": mape(y_true, y_hat),
            "coverage": cov,
        }

        # plot backtest
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=train_df["timestamp"], y=train_df["target"], mode="lines", name="Train"))
        fig.add_trace(go.Scatter(x=true_df["timestamp"], y=true_df["target"], mode="lines", name="True (holdout)"))
        fig.add_trace(go.Scatter(x=bt_pred_df["timestamp"], y=y_hat, mode="lines", name="Pred (median)"))
        if has_band:
            fig.add_trace(go.Scatter(
                x=bt_pred_df["timestamp"], y=high, mode="lines", line=dict(width=0),
                showlegend=False, hoverinfo="skip"
            ))
            fig.add_trace(go.Scatter(
                x=bt_pred_df["timestamp"], y=low, mode="lines", fill="tonexty",
                line=dict(width=0), name=f"Band [{q_low:.2f}, {q_high:.2f}]"
            ))
        fig.update_layout(
            title="Backtest (holdout) — interactive",
            hovermode="x unified",
            margin=dict(l=10, r=10, t=55, b=10),
            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
            xaxis_title="timestamp",
            yaxis_title="value",
        )
        backtest_fig = fig

        backtest_path = os.path.join(OUT_DIR, "chronos2_backtest_df.csv")
        bt_pred_df.to_csv(backtest_path, index=False)
        backtest_df_out = bt_pred_df

    # ---- main plot
    forecast_fig = plot_forecast(context_df, pred_df, q_low, q_high, f"Forecast — {source}")

    # ---- KPIs
    cards = [
        kpi_card("Device", device.upper(), f"cuda_available={torch.cuda.is_available()}"),
        kpi_card("Model", (model_id or MODEL_ID_DEFAULT), "Chronos-2"),
        kpi_card("Latency", f"{latency:.2f}s", "predict_df()"),
        kpi_card("History", str(len(context_df)), "points"),
        kpi_card("Horizon", str(prediction_length), "steps"),
        kpi_card("Quantiles", f"{q_low:.2f}, 0.50, {q_high:.2f}", "levels"),
    ]
    kpis_html = kpi_grid(cards)

    explanation_md = explain_natural(context_df, pred_df, q_low, q_high, backtest_metrics)

    info = {
        "source": source,
        "history_points": int(len(context_df)),
        "prediction_length": horizon,
        "quantile_levels": quantiles,
        "backtest": bool(do_backtest),
        "holdout": int(holdout) if do_backtest else None,
    }

    return (
        kpis_html,
        explanation_md,
        forecast_fig,
        backtest_fig,
        pred_df,
        backtest_df_out,
        forecast_path,
        backtest_path,
        info,
    )
| # ========================= | |
| # UI | |
| # ========================= | |
# Custom CSS: page width cap plus the grid used by the KPI cards.
css = """
.gradio-container { max-width: 1200px !important; }
/* KPI grid */
.kpi-grid{
display: grid;
grid-template-columns: repeat(auto-fit, minmax(190px, 1fr));
gap: 14px;
padding: 10px 8px; /* <-- spazio “esterno” */
margin-top: 6px; /* <-- separa dal titolo / contenuto sopra */
}
/* opzionale: un filo di aria sotto ogni card */
.kpi-grid > div{
min-height: 84px;
}
"""

with gr.Blocks(title="Chronos-2 • Forecast Dashboard (predict_df)", css=css) as demo:
    gr.Markdown("# ⏱️ Chronos-2 Dashboard — **predict_df** edition (stabile)")

    with gr.Row():
        # ---- left column: every input control plus the Run button
        with gr.Column(scale=1, min_width=360):
            gr.Markdown("## Input")
            input_mode = gr.Radio(["Sample", "Test CSV", "Upload CSV"], value="Sample", label="Sorgente")
            test_csv_name = gr.Dropdown(choices=available_test_csv(), label="Test CSV (data/)")
            upload_csv = gr.File(label="Upload CSV", file_types=[".csv"])
            target_col = gr.Textbox(label="Colonna target (opzionale)", placeholder="es: value")
            time_col = gr.Textbox(label="Colonna timestamp (opzionale)", placeholder="es: timestamp / date / ds")
            freq = gr.Dropdown(["D", "H", "W", "M"], value=DEFAULT_FREQ, label="Freq (se timestamp mancante)")

            gr.Markdown("## Sistema")
            device_ui = gr.Dropdown(
                ["cpu", "cuda (se disponibile)"],
                value="cuda (se disponibile)" if torch.cuda.is_available() else "cpu",
                label="Device",
            )
            model_id = gr.Textbox(value=MODEL_ID_DEFAULT, label="Model ID")

            # Parameters of the synthetic series, used with the "Sample" source.
            with gr.Accordion("Sample generator", open=False):
                n = gr.Slider(60, 2000, value=300, step=10, label="History length")
                seed = gr.Number(value=42, precision=0, label="Seed")
                trend = gr.Slider(0.0, 0.2, value=0.03, step=0.005, label="Trend")
                season_period = gr.Slider(2, 240, value=14, step=1, label="Season period")
                season_amp = gr.Slider(0.0, 12.0, value=3.0, step=0.1, label="Season amplitude")
                noise = gr.Slider(0.0, 6.0, value=0.8, step=0.05, label="Noise")

            gr.Markdown("## Forecast")
            prediction_length = gr.Slider(1, 365, value=30, step=1, label="Prediction length")
            q_low = gr.Slider(0.01, 0.49, value=0.10, step=0.01, label="Quantile low")
            q_high = gr.Slider(0.51, 0.99, value=0.90, step=0.01, label="Quantile high")

            gr.Markdown("## Backtest")
            do_backtest = gr.Checkbox(value=True, label="Esegui backtest holdout")
            holdout = gr.Slider(5, 365, value=30, step=1, label="Holdout points")

            run_btn = gr.Button("Run", variant="primary")

        # ---- right column: KPI cards and the result tabs
        with gr.Column(scale=2):
            kpis = gr.HTML()
            with gr.Tabs():
                with gr.Tab("Forecast"):
                    forecast_plot = gr.Plot()
                    forecast_table = gr.Dataframe(interactive=False)
                with gr.Tab("Backtest"):
                    backtest_plot = gr.Plot()
                    backtest_table = gr.Dataframe(interactive=False)
                with gr.Tab("Spiegazione"):
                    explanation = gr.Markdown()
                with gr.Tab("Export"):
                    forecast_download = gr.File(label="Forecast CSV")
                    backtest_download = gr.File(label="Backtest CSV")
                with gr.Tab("Info"):
                    info = gr.JSON()

    # Order must match the parameters of run_dashboard.
    run_inputs = [
        input_mode, test_csv_name, upload_csv,
        target_col, time_col, freq,
        n, seed, trend, season_period, season_amp, noise,
        prediction_length, q_low, q_high,
        do_backtest, holdout,
        device_ui, model_id,
    ]
    # Order must match the tuple returned by run_dashboard.
    run_outputs = [
        kpis,
        explanation,
        forecast_plot,
        backtest_plot,
        forecast_table,
        backtest_table,
        forecast_download,
        backtest_download,
        info,
    ]
    run_btn.click(fn=run_dashboard, inputs=run_inputs, outputs=run_outputs)

demo.queue()
demo.launch(ssr_mode=False)