import os
import time
from typing import Optional, Tuple, List, Dict
import numpy as np
import pandas as pd
import gradio as gr
import torch
import plotly.graph_objects as go
from chronos import Chronos2Pipeline
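# Gradio dashboard for Chronos-2 zero-shot forecasting: build a context series (synthetic sample,
# bundled test CSV, or uploaded CSV), forecast with Chronos2Pipeline.predict_df, optionally run a
# holdout backtest, and render Plotly charts, KPI cards, a plain-language summary, and CSV exports.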
# =========================
# Config
# =========================
MODEL_ID_DEFAULT = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-2")
DATA_DIR = "data"
OUT_DIR = "/tmp"
DEFAULT_FREQ = "D"  # if the CSV has no timestamp column, generate a daily index
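# The checkpoint can be overridden via the environment, e.g. (filename assumed):
#   CHRONOS_MODEL_ID=<hf-hub-model-id> python app.py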
# =========================
# Utils: files + device
# =========================
def available_test_csv() -> List[str]:
if not os.path.isdir(DATA_DIR):
return []
return sorted([f for f in os.listdir(DATA_DIR) if f.lower().endswith(".csv")])
def pick_device(ui_choice: str) -> str:
return "cuda" if (ui_choice or "").startswith("cuda") and torch.cuda.is_available() else "cpu"
# =========================
# Sample series
# =========================
def make_sample_df(
n: int,
seed: int,
trend: float,
season_period: int,
season_amp: float,
noise: float,
freq: str = DEFAULT_FREQ,
start: str = "2020-01-01",
) -> pd.DataFrame:
rng = np.random.default_rng(int(seed))
t = np.arange(int(n), dtype=np.float32)
y = (
float(trend) * t
+ float(season_amp) * np.sin(2 * np.pi * t / max(1, int(season_period)))
+ rng.normal(0.0, float(noise), size=int(n))
).astype(np.float32)
if float(np.min(y)) < 0:
y -= float(np.min(y))
ts = pd.date_range(start=start, periods=int(n), freq=freq)
return pd.DataFrame({"id": 0, "timestamp": ts, "target": y})
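# Example (illustrative): make_sample_df(300, seed=42, trend=0.03, season_period=14, season_amp=3.0, noise=0.8)
# yields a 300-row long-format frame with the columns predict_df expects: id (constant 0),
# timestamp (daily from 2020-01-01) and target (float32).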
# =========================
# CSV loader -> context_df format (id,timestamp,target)
# =========================
def _guess_timestamp_column(df: pd.DataFrame) -> Optional[str]:
    # first, try columns with typical timestamp names
for c in df.columns:
lc = str(c).lower()
if lc in ["ds", "date", "datetime", "timestamp", "time"]:
return c
    # otherwise, try parsing: pick a column where most values parse as datetimes
for c in df.columns:
if df[c].dtype == object:
parsed = pd.to_datetime(df[c], errors="coerce", utc=False)
if parsed.notna().sum() >= max(10, int(0.6 * len(df))):
return c
return None
def _guess_numeric_target_column(df: pd.DataFrame, user_col: Optional[str]) -> str:
if user_col and user_col.strip():
col = user_col.strip()
if col not in df.columns:
raise ValueError(f"Colonna '{col}' non trovata. Disponibili: {list(df.columns)}")
return col
# numeric dtype first
numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
if numeric_cols:
return numeric_cols[0]
# try coercion
best = None
best_count = 0
for c in df.columns:
coerced = pd.to_numeric(df[c], errors="coerce")
cnt = coerced.notna().sum()
if cnt > best_count:
best = c
best_count = cnt
if best is None or best_count < 10:
raise ValueError("Non trovo una colonna numerica valida (>=10 valori) nel CSV.")
return best
def load_context_df_from_csv(path: str, user_target_col: Optional[str], user_time_col: Optional[str], freq: str) -> Tuple[pd.DataFrame, str, Optional[str]]:
df = pd.read_csv(path)
if df.shape[0] < 10:
raise ValueError("Serie troppo corta (minimo consigliato: 10 righe).")
target_col = _guess_numeric_target_column(df, user_target_col)
time_col = user_time_col.strip() if (user_time_col and user_time_col.strip()) else _guess_timestamp_column(df)
    # target (coerced to numeric; NaNs are dropped below, jointly with the timestamps)
    num = pd.to_numeric(df[target_col], errors="coerce")
    if int(num.notna().sum()) < 10:
        raise ValueError("Too many NaNs: the target column has fewer than 10 numeric values.")
    # timestamp
    if time_col and time_col not in df.columns:
        raise ValueError(f"Timestamp column '{time_col}' not found. Available: {list(df.columns)}")
    if time_col:
        ts_parsed = pd.to_datetime(df[time_col], errors="coerce")
        # keep only rows where both the target and the timestamp are valid, so the two stay aligned
        mask = num.notna() & ts_parsed.notna()
        if int(mask.sum()) < 10:
            # timestamps too dirty: fall back to a generated range
            time_col = None
        else:
            y = num[mask].astype(np.float32).to_numpy()
            ts = ts_parsed[mask].to_numpy()
    if not time_col:
        y = num.dropna().astype(np.float32).to_numpy()
        ts = pd.date_range(start="2020-01-01", periods=len(y), freq=freq)
    context_df = pd.DataFrame({"id": 0, "timestamp": ts, "target": y})
    context_df = context_df.sort_values("timestamp").reset_index(drop=True)
    return context_df, target_col, (time_col if time_col else None)
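# Illustrative inputs this loader handles: a "date,value" (or "ds,y") CSV, a CSV with only a numeric
# column (timestamps are then generated at `freq`), or a wider CSV whose target/timestamp columns
# are named explicitly in the UI.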
# =========================
# Pipeline cache
# =========================
_PIPE = None
_META = {"model_id": None, "device": None}
def get_pipeline(model_id: str, device: str) -> Chronos2Pipeline:
global _PIPE, _META
model_id = (model_id or MODEL_ID_DEFAULT).strip()
device = "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu"
if _PIPE is None or _META["model_id"] != model_id or _META["device"] != device:
_PIPE = Chronos2Pipeline.from_pretrained(model_id, device_map=device)
_META = {"model_id": model_id, "device": device}
return _PIPE
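# Note: from_pretrained downloads/loads the checkpoint, so the pipeline is cached in a module-level
# singleton and only rebuilt when the model id or device changes between runs.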
# =========================
# Metrics
# =========================
def mae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
return float(np.mean(np.abs(y_true - y_pred)))
def rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))
def mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
denom = np.maximum(1e-8, np.abs(y_true))
return float(np.mean(np.abs((y_true - y_pred) / denom)) * 100.0)
def coverage(y_true: np.ndarray, low: np.ndarray, high: np.ndarray) -> float:
return float(np.mean((y_true >= low) & (y_true <= high)) * 100.0)
def avg_width(low: np.ndarray, high: np.ndarray) -> float:
return float(np.mean(high - low))
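# Worked example (illustrative): coverage(np.array([1., 2., 3.]), low=np.zeros(3), high=np.full(3, 2.0))
# -> 2 of the 3 points fall inside the band, i.e. ~66.7; avg_width for the same band is 2.0.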
# =========================
# Plotly
# =========================
def plot_forecast(context_df: pd.DataFrame, pred_df: pd.DataFrame, q_low: float, q_high: float, title: str) -> go.Figure:
ctx = context_df.copy()
pred = pred_df.copy()
fig = go.Figure()
fig.add_trace(go.Scatter(x=ctx["timestamp"], y=ctx["target"], mode="lines", name="History"))
# pred_df from predict_df typically has:
# - timestamp
# - predictions (median or q=0.5)
# - columns for quantiles like "0.1", "0.9"
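    # Illustrative layout (an assumption; exact column names depend on the chronos-forecasting version):
    #   id  timestamp    0.1    predictions    0.9
    #   0   2021-01-01   9.8    10.4           11.1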
if "predictions" in pred.columns:
y_med = pred["predictions"].to_numpy()
else:
# fallback: if "0.5" exists
y_med = pred.get("0.5", pred.iloc[:, -1]).to_numpy()
fig.add_trace(go.Scatter(x=pred["timestamp"], y=y_med, mode="lines", name="Forecast (median)"))
low_col = f"{q_low:.1f}".rstrip("0").rstrip(".")
high_col = f"{q_high:.1f}".rstrip("0").rstrip(".")
# columns in pred_df are often exactly "0.1", "0.5", "0.9" as strings
if str(q_low) in pred.columns:
low_series = pred[str(q_low)].to_numpy()
elif low_col in pred.columns:
low_series = pred[low_col].to_numpy()
else:
low_series = None
if str(q_high) in pred.columns:
high_series = pred[str(q_high)].to_numpy()
elif high_col in pred.columns:
high_series = pred[high_col].to_numpy()
else:
high_series = None
if low_series is not None and high_series is not None:
fig.add_trace(go.Scatter(
x=pred["timestamp"], y=high_series,
mode="lines", line=dict(width=0), showlegend=False, hoverinfo="skip"
))
fig.add_trace(go.Scatter(
x=pred["timestamp"], y=low_series,
mode="lines", fill="tonexty", line=dict(width=0),
name=f"Band [{q_low:.2f}, {q_high:.2f}]"
))
fig.update_layout(
title=title,
hovermode="x unified",
margin=dict(l=10, r=10, t=55, b=10),
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
xaxis_title="timestamp",
yaxis_title="value",
)
return fig
def kpi_card(label: str, value: str, hint: str = "") -> str:
    # minimal inline-styled markup for a single KPI tile; the grid layout comes from the .kpi-grid CSS below
    hint_html = f"<div style='font-size:12px;opacity:0.7'>{hint}</div>" if hint else ""
    return f"""
    <div style='border:1px solid rgba(128,128,128,0.35);border-radius:10px;padding:10px 12px'>
        <div style='font-size:12px;opacity:0.7'>{label}</div>
        <div style='font-size:22px;font-weight:600'>{value}</div>
        {hint_html}
    </div>
    """

def kpi_grid(cards: List[str]) -> str:
    # wrapper targeted by the .kpi-grid rules in the custom CSS
    return f"<div class='kpi-grid'>{''.join(cards)}</div>"
def explain_natural(context_df: pd.DataFrame, pred_df: pd.DataFrame, q_low: float, q_high: float, backtest_metrics: Optional[Dict[str, float]]) -> str:
ctx_y = context_df["target"].to_numpy(dtype=float)
if "predictions" in pred_df.columns:
med = pred_df["predictions"].to_numpy(dtype=float)
elif "0.5" in pred_df.columns:
med = pred_df["0.5"].to_numpy(dtype=float)
else:
med = pred_df.iloc[:, -1].to_numpy(dtype=float)
base = float(np.mean(ctx_y))
delta = float(med[-1] - med[0])
pct = (delta / max(1e-6, base)) * 100.0
    if abs(pct) < 2:
        trend_txt = "essentially stable"
    elif pct > 0:
        trend_txt = "increasing"
    else:
        trend_txt = "decreasing"
txt = f"""### 🧠 Spiegazione
Nei prossimi **{len(med)} step**, la previsione mediana è **{trend_txt}** (variazione complessiva ≈ **{pct:+.1f}%** rispetto al livello medio storico).
- **Ultimo valore mediano previsto:** **{med[-1]:.2f}**
"""
# band, if present
low_key = str(q_low)
high_key = str(q_high)
if low_key in pred_df.columns and high_key in pred_df.columns:
low = pred_df[low_key].to_numpy(dtype=float)
high = pred_df[high_key].to_numpy(dtype=float)
txt += f"- **Intervallo [{q_low:.0%}–{q_high:.0%}] ultimo step:** **[{low[-1]:.2f} – {high[-1]:.2f}]**\n"
txt += f"- **Larghezza media banda:** **{avg_width(low, high):.2f}**\n"
else:
txt += "- **Banda di incertezza:** non disponibile (manca nel pred_df).\n"
if backtest_metrics:
txt += f"""
### 🧪 Backtest (holdout)
- **MAE:** {backtest_metrics["mae"]:.3f}
- **RMSE:** {backtest_metrics["rmse"]:.3f}
- **MAPE:** {backtest_metrics["mape"]:.2f}%
- **Coverage banda:** {backtest_metrics["coverage"]:.1f}%
"""
return txt
# =========================
# Run core (predict_df)
# =========================
def run_dashboard(
input_mode: str,
test_csv_name: str,
upload_csv,
target_col: str,
time_col: str,
freq: str,
n: int,
seed: int,
trend: float,
season_period: int,
season_amp: float,
noise: float,
prediction_length: int,
q_low: float,
q_high: float,
do_backtest: bool,
holdout: int,
device_ui: str,
model_id: str,
):
if q_low >= q_high:
raise gr.Error("Quantile low deve essere < quantile high.")
device = pick_device(device_ui)
pipe = get_pipeline(model_id, device)
# ---- build context_df
if input_mode == "Test CSV":
if not test_csv_name:
raise gr.Error("Seleziona un Test CSV.")
csv_path = os.path.join(DATA_DIR, test_csv_name)
if not os.path.exists(csv_path):
raise gr.Error(f"Non trovo {csv_path}")
context_df, used_target, used_time = load_context_df_from_csv(csv_path, target_col, time_col, freq)
source = f"Test CSV: {test_csv_name} • target={used_target} • time={used_time or 'generated'}"
elif input_mode == "Upload CSV":
if upload_csv is None:
raise gr.Error("Carica un CSV.")
context_df, used_target, used_time = load_context_df_from_csv(upload_csv.name, target_col, time_col, freq)
source = f"Upload CSV • target={used_target} • time={used_time or 'generated'}"
else:
context_df = make_sample_df(n, seed, trend, season_period, season_amp, noise, freq=freq)
source = "Sample series"
if len(context_df) < 10:
raise gr.Error("Serie troppo corta.")
if do_backtest and holdout >= len(context_df):
raise gr.Error("Holdout deve essere più piccolo della lunghezza dello storico.")
quantiles = sorted(list(set([float(q_low), 0.5, float(q_high)])))
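    # e.g. q_low=0.10, q_high=0.90 -> [0.1, 0.5, 0.9] (set() drops duplicates, sorted() keeps the levels ordered)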
t0 = time.time()
# ---- forecast (future_df not needed if no covariates)
pred_df = pipe.predict_df(
context_df,
prediction_length=int(prediction_length),
quantile_levels=quantiles,
id_column="id",
timestamp_column="timestamp",
target="target",
)
latency = time.time() - t0
# ---- exports
forecast_path = os.path.join(OUT_DIR, "chronos2_forecast_df.csv")
pred_df.to_csv(forecast_path, index=False)
# ---- backtest
backtest_metrics = None
backtest_path = None
backtest_df_out = pd.DataFrame()
backtest_fig = go.Figure().update_layout(title="Backtest disabled", margin=dict(l=10, r=10, t=55, b=10))
if do_backtest:
train_df = context_df.iloc[:-int(holdout)].copy()
true_df = context_df.iloc[-int(holdout):].copy()
bt_pred_df = pipe.predict_df(
train_df,
prediction_length=int(holdout),
quantile_levels=quantiles,
id_column="id",
timestamp_column="timestamp",
target="target",
)
# extract arrays
y_true = true_df["target"].to_numpy(dtype=float)
if "predictions" in bt_pred_df.columns:
y_hat = bt_pred_df["predictions"].to_numpy(dtype=float)
elif "0.5" in bt_pred_df.columns:
y_hat = bt_pred_df["0.5"].to_numpy(dtype=float)
else:
y_hat = bt_pred_df.iloc[:, -1].to_numpy(dtype=float)
# band
if str(q_low) in bt_pred_df.columns and str(q_high) in bt_pred_df.columns:
low = bt_pred_df[str(q_low)].to_numpy(dtype=float)
high = bt_pred_df[str(q_high)].to_numpy(dtype=float)
cov = coverage(y_true, low, high)
else:
low = y_hat.copy()
high = y_hat.copy()
cov = float("nan")
backtest_metrics = {
"mae": mae(y_true, y_hat),
"rmse": rmse(y_true, y_hat),
"mape": mape(y_true, y_hat),
"coverage": cov,
}
# plot backtest
fig = go.Figure()
fig.add_trace(go.Scatter(x=train_df["timestamp"], y=train_df["target"], mode="lines", name="Train"))
fig.add_trace(go.Scatter(x=true_df["timestamp"], y=true_df["target"], mode="lines", name="True (holdout)"))
fig.add_trace(go.Scatter(x=bt_pred_df["timestamp"], y=y_hat, mode="lines", name="Pred (median)"))
if str(q_low) in bt_pred_df.columns and str(q_high) in bt_pred_df.columns:
fig.add_trace(go.Scatter(
x=bt_pred_df["timestamp"], y=high, mode="lines", line=dict(width=0),
showlegend=False, hoverinfo="skip"
))
fig.add_trace(go.Scatter(
x=bt_pred_df["timestamp"], y=low, mode="lines", fill="tonexty",
line=dict(width=0), name=f"Band [{q_low:.2f}, {q_high:.2f}]"
))
fig.update_layout(
title="Backtest (holdout) — interactive",
hovermode="x unified",
margin=dict(l=10, r=10, t=55, b=10),
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
xaxis_title="timestamp",
yaxis_title="value",
)
backtest_fig = fig
backtest_path = os.path.join(OUT_DIR, "chronos2_backtest_df.csv")
bt_pred_df.to_csv(backtest_path, index=False)
backtest_df_out = bt_pred_df
# ---- main plot
forecast_fig = plot_forecast(context_df, pred_df, q_low, q_high, f"Forecast — {source}")
# ---- KPIs
cards = [
kpi_card("Device", device.upper(), f"cuda_available={torch.cuda.is_available()}"),
kpi_card("Model", (model_id or MODEL_ID_DEFAULT), "Chronos-2"),
kpi_card("Latency", f"{latency:.2f}s", "predict_df()"),
kpi_card("History", str(len(context_df)), "points"),
kpi_card("Horizon", str(prediction_length), "steps"),
kpi_card("Quantiles", f"{q_low:.2f}, 0.50, {q_high:.2f}", "levels"),
]
kpis_html = kpi_grid(cards)
explanation_md = explain_natural(context_df, pred_df, q_low, q_high, backtest_metrics)
info = {
"source": source,
"history_points": int(len(context_df)),
"prediction_length": int(prediction_length),
"quantile_levels": quantiles,
"backtest": bool(do_backtest),
"holdout": int(holdout) if do_backtest else None,
}
return (
kpis_html,
explanation_md,
forecast_fig,
backtest_fig,
pred_df,
backtest_df_out,
forecast_path,
backtest_path,
info,
)
# =========================
# UI
# =========================
css = """
.gradio-container { max-width: 1200px !important; }
/* KPI grid */
.kpi-grid{
display: grid;
grid-template-columns: repeat(auto-fit, minmax(190px, 1fr));
gap: 14px;
  padding: 10px 8px;   /* outer spacing */
  margin-top: 6px;     /* separates it from the title / content above */
}
/* optional: a little extra room for each card */
.kpi-grid > div{
min-height: 84px;
}
"""
with gr.Blocks(title="Chronos-2 • Forecast Dashboard (predict_df)", css=css) as demo:
gr.Markdown("# Chronos-2 Dashboard")
with gr.Row():
with gr.Column(scale=1, min_width=360):
gr.Markdown("## Input")
            input_mode = gr.Radio(["Sample", "Test CSV", "Upload CSV"], value="Sample", label="Source")
test_csv_name = gr.Dropdown(choices=available_test_csv(), label="Test CSV (data/)")
upload_csv = gr.File(label="Upload CSV", file_types=[".csv"])
            target_col = gr.Textbox(label="Target column (optional)", placeholder="e.g. value")
            time_col = gr.Textbox(label="Timestamp column (optional)", placeholder="e.g. timestamp / date / ds")
            freq = gr.Dropdown(["D", "H", "W", "M"], value=DEFAULT_FREQ, label="Freq (if timestamp missing)")
gr.Markdown("## Sistema")
device_ui = gr.Dropdown(
["cpu", "cuda (se disponibile)"],
value="cuda (se disponibile)" if torch.cuda.is_available() else "cpu",
label="Device",
)
model_id = gr.Textbox(value=MODEL_ID_DEFAULT, label="Model ID")
with gr.Accordion("Sample generator", open=False):
n = gr.Slider(60, 2000, value=300, step=10, label="History length")
seed = gr.Number(value=42, precision=0, label="Seed")
trend = gr.Slider(0.0, 0.2, value=0.03, step=0.005, label="Trend")
season_period = gr.Slider(2, 240, value=14, step=1, label="Season period")
season_amp = gr.Slider(0.0, 12.0, value=3.0, step=0.1, label="Season amplitude")
noise = gr.Slider(0.0, 6.0, value=0.8, step=0.05, label="Noise")
gr.Markdown("## Forecast")
prediction_length = gr.Slider(1, 365, value=30, step=1, label="Prediction length")
q_low = gr.Slider(0.01, 0.49, value=0.10, step=0.01, label="Quantile low")
q_high = gr.Slider(0.51, 0.99, value=0.90, step=0.01, label="Quantile high")
gr.Markdown("## Backtest")
            do_backtest = gr.Checkbox(value=True, label="Run holdout backtest")
holdout = gr.Slider(5, 365, value=30, step=1, label="Holdout points")
run_btn = gr.Button("Run", variant="primary")
with gr.Column(scale=2):
kpis = gr.HTML()
with gr.Tabs():
with gr.Tab("Forecast"):
forecast_plot = gr.Plot()
forecast_table = gr.Dataframe(interactive=False)
with gr.Tab("Backtest"):
backtest_plot = gr.Plot()
backtest_table = gr.Dataframe(interactive=False)
with gr.Tab("Spiegazione"):
explanation = gr.Markdown()
with gr.Tab("Export"):
forecast_download = gr.File(label="Forecast CSV")
backtest_download = gr.File(label="Backtest CSV")
with gr.Tab("Info"):
info = gr.JSON()
run_btn.click(
fn=run_dashboard,
inputs=[
input_mode, test_csv_name, upload_csv,
target_col, time_col, freq,
n, seed, trend, season_period, season_amp, noise,
prediction_length, q_low, q_high,
do_backtest, holdout,
device_ui, model_id,
],
outputs=[
kpis,
explanation,
forecast_plot,
backtest_plot,
forecast_table,
backtest_table,
forecast_download,
backtest_download,
info,
],
)
demo.queue()
demo.launch(ssr_mode=False)