Files
Clutter_chuva/etc/fitting/fitter.py
neutonsevero aacfe3f977 ADD:
AIC statistic added
2026-04-08 22:53:33 -03:00

441 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from dataclasses import dataclass, field
import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import rv_continuous, goodness_of_fit
from plotly.subplots import make_subplots
@dataclass
class DistributionSummary:
"""
Summary of a single distribution fit against a dataset.
Attributes
----------
distribution_object : rv_continuous
The scipy distribution object used for fitting.
distribution_name : str
Human-readable name of the distribution.
args_fit_params : tuple
Initial guess positional parameters passed to fit() (empty tuple if none).
kwds_fit_params : dict
Keyword arguments that were passed to fit() (fixed params, etc.).
fit_result_params : tuple
The actual fitted parameters returned by fit() (empty tuple until fit() is called).
statistic_method : object
GoF statistic identifier used in validate() (e.g. 'ad', 'ks' or a custom callable).
test_result : object
Result object from scipy.stats.goodness_of_fit; None until
validate() is called. Exposes .statistic and .pvalue.
Computed properties
-------------------
pvalue : float | None p-value from the GoF test
gof_statistic : float | None test statistic from the GoF test
mean : float mean of the fitted distribution
std : float standard deviation of the fitted distribution
var : float variance of the fitted distribution
"""
distribution_object: rv_continuous
distribution_name: str
args_fit_params: tuple = field(default_factory=tuple)
kwds_fit_params: dict = field(default_factory=dict)
fit_result_params: tuple = field(default_factory=tuple)
statistic_method: object = "ad"
test_result: object = None
# ── convenience properties ────────────────────────────────────────────────
@property
def pvalue(self) -> float | None:
"""p-value from the goodness-of-fit test, or None if not yet run."""
return self.test_result.pvalue if self.test_result is not None else None
@property
def gof_statistic(self) -> float | None:
"""Test statistic from the goodness-of-fit test, or None if not yet run."""
return self.test_result.statistic if self.test_result is not None else None
@property
def mean(self) -> float:
"""Mean of the fitted distribution."""
return float(self.distribution_object.mean(*self.fit_result_params))
@property
def std(self) -> float:
"""Standard deviation of the fitted distribution."""
return float(self.distribution_object.std(*self.fit_result_params))
@property
def var(self) -> float:
"""Variance of the fitted distribution."""
return float(self.distribution_object.var(*self.fit_result_params))
def __repr__(self) -> str:
pval_str = f"{self.pvalue:.4f}" if self.pvalue is not None else "N/A"
stat_str = (
f"{self.gof_statistic:.4f}" if self.gof_statistic is not None else "N/A"
)
return (
f"DistributionSummary({self.distribution_name})"
f" params={self.fit_result_params}"
f" mean={self.mean:.4f} std={self.std:.4f}"
f" GoF[{self.statistic_method}]={stat_str} p={pval_str}"
)
def __str__(self) -> str:
return self.__repr__()
class Fitter:
    """
    Fit and evaluate multiple distributions against a dataset.

    Parameters
    ----------
    dist_list : iterable of rv_continuous
        Distributions to fit. The iterable is materialised into a list once,
        so generators are accepted.
    statistic_method : str
        Goodness-of-fit statistic passed to goodness_of_fit (default 'ad').
    **kwargs :
        Per-distribution fit() inputs, keyed as:
        - ``<dist.name>_args``   : tuple of positional starting values for the
          shape parameters (at most one value per shape parameter)
        - ``<dist.name>_params`` : dict of keyword arguments forwarded to fit(),
          e.g. fixed parameters (``floc``, ``fscale``, ``f0``/``fa``/``fix_a``)
          or ``loc``/``scale`` starting values
        Keys that match no distribution name are ignored.

    Example
    -------
        Fitter(
            [gamma, weibull_min],
            gamma_args=(2.0,),
            gamma_params={'floc': 0},
            weibull_min_args=(1.5,),
            weibull_min_params={'loc': 0.0, 'scale': 1.0},
        )

    Iterating over a Fitter yields ``(name, DistributionSummary)`` pairs.
    """

    def __init__(
        self, dist_list: list[rv_continuous], statistic_method: str = "ad", **kwargs
    ):
        self._dist: dict[str, DistributionSummary] = {}
        # Materialise once and iterate the list (not the argument), otherwise a
        # generator argument would already be exhausted by list() here.
        self.dist_list = list(dist_list)
        for dist in self.dist_list:
            self._dist[dist.name] = DistributionSummary(
                distribution_object=dist,
                distribution_name=dist.name,
                args_fit_params=kwargs.get(f"{dist.name}_args", ()),
                kwds_fit_params=kwargs.get(f"{dist.name}_params", {}),
                statistic_method=statistic_method,
                test_result=None,
            )

    # ── mapping-style access ──────────────────────────────────────────────────
    @staticmethod
    def _key_name(key: str | rv_continuous) -> str:
        """Return the string key for a distribution name or rv_continuous object."""
        return key.name if isinstance(key, rv_continuous) else key

    def _resolve_key(self, key: str | rv_continuous) -> str:
        """Resolve a distribution name or rv_continuous object to an existing key.

        Raises
        ------
        KeyError
            If no distribution with that name was registered.
        """
        name = self._key_name(key)
        if name not in self._dist:
            available = ", ".join(self._dist)
            raise KeyError(f"Distribution '{name}' not found. Available: {available}")
        return name

    def __contains__(self, key: str | rv_continuous) -> bool:
        return self._key_name(key) in self._dist

    def __getitem__(self, key: str | rv_continuous) -> DistributionSummary:
        return self._dist[self._resolve_key(key)]

    def __setitem__(
        self, key: str | rv_continuous, summary: DistributionSummary
    ) -> None:
        """Replace the summary of an already-registered distribution."""
        if not isinstance(summary, DistributionSummary):
            raise TypeError(
                f"Expected DistributionSummary, got {type(summary).__name__}."
            )
        self._dist[self._resolve_key(key)] = summary

    def __iter__(self):
        """Iterate over ``(name, summary)`` pairs in registration order.

        A fresh iterator is returned on every call, so an early ``break`` or an
        exception inside one loop cannot affect the next loop.
        """
        return iter(self._dist.items())

    def __next__(self):
        """Legacy cursor kept for callers that call ``next(fitter)`` directly.

        ``for`` loops go through ``__iter__`` and never touch this cursor.
        """
        if not hasattr(self, "_iter_index"):
            self._iter_index = 0
        if self._iter_index >= len(self._dist):
            del self._iter_index
            raise StopIteration
        key = list(self._dist)[self._iter_index]
        self._iter_index += 1
        return key, self._dist[key]

    # ── fitting and validation ────────────────────────────────────────────────
    def fit(self, data: np.ndarray) -> dict[str, DistributionSummary]:
        """
        Fit every distribution to *data* via its ``fit`` method (MLE by default).

        Parameters
        ----------
        data : array-like
            Input data of any shape. It is flattened and only the absolute
            value is used.

        Returns
        -------
        dict[str, DistributionSummary]
            Mapping of distribution name → summary (test_result is None
            until validate() is called).

        Raises
        ------
        ValueError
            If *data* contains no values.
        """
        data_flat = np.abs(data).flatten()
        if data_flat.size == 0:
            raise ValueError("Cannot fit distributions to empty data.")
        self._last_data_flat = data_flat
        for summary in self._dist.values():
            # Use the summary's own distribution object so a summary replaced
            # via __setitem__ is fitted with the distribution it describes.
            summary.fit_result_params = summary.distribution_object.fit(
                data_flat, *summary.args_fit_params, **summary.kwds_fit_params
            )
        return dict(self._dist)

    @staticmethod
    def _gof_param_split(dist, fit_kwds: dict, fitted) -> tuple[dict, dict]:
        """Split fitted values into (known_params, fit_params) for goodness_of_fit.

        Parameters fixed in fit() (``floc``, ``fscale``, ``f<i>``,
        ``f<shape>``, ``fix_<shape>``) become *known_params*; all other fitted
        values become *fit_params*. Returns two empty dicts when *fitted* does
        not contain one value per parameter (e.g. not fitted yet).
        """
        shape_names = (
            [s.strip() for s in dist.shapes.split(",")] if dist.shapes else []
        )
        param_names = shape_names + ["loc", "scale"]
        if len(fitted) != len(param_names):
            return {}, {}
        aliases = {"floc": "loc", "fscale": "scale"}
        for idx, shape in enumerate(shape_names):
            aliases[f"f{idx}"] = shape
            aliases[f"f{shape}"] = shape
            aliases[f"fix_{shape}"] = shape
        fixed = {aliases[k] for k in fit_kwds if k in aliases}
        values = dict(zip(param_names, fitted))
        known = {name: val for name, val in values.items() if name in fixed}
        estimated = {name: val for name, val in values.items() if name not in fixed}
        return known, estimated

    def validate(self, **kwargs) -> dict[str, DistributionSummary]:
        """
        Run the goodness-of-fit test on every previously fitted distribution.

        The parameters estimated by fit() are passed to goodness_of_fit as
        ``fit_params`` and parameters fixed in fit() (e.g. ``floc=0``) as
        ``known_params``. The test therefore evaluates the same model that
        fit() produced instead of silently re-fitting every parameter freely.

        Parameters
        ----------
        **kwargs :
            Extra keyword arguments forwarded to goodness_of_fit()
            (e.g. n_mc_samples=100). Explicit ``known_params``/``fit_params``
            here override the automatically derived ones.

        Returns
        -------
        dict[str, DistributionSummary]
            Same mapping as fit(), with test_result populated for each distribution.

        Raises
        ------
        RuntimeError
            If fit() has not been called.
        """
        if not hasattr(self, "_last_data_flat"):
            raise RuntimeError("No data available. Call fit() first.")
        data_flat = self._last_data_flat
        for summary in self._dist.values():
            dist = summary.distribution_object
            known, estimated = self._gof_param_split(
                dist, summary.kwds_fit_params, summary.fit_result_params
            )
            gof_kwargs = {
                "known_params": known or None,
                "fit_params": estimated or None,
            }
            gof_kwargs.update(kwargs)  # caller-supplied arguments take precedence
            summary.test_result = goodness_of_fit(
                dist, data_flat, statistic=summary.statistic_method, **gof_kwargs
            )
        return dict(self._dist)

    def summary(self) -> None:
        """
        Print one line per distribution with parameters, moments and GoF results.
        """
        for dist_summary in self._dist.values():
            print(dist_summary)

    # ── plotting ──────────────────────────────────────────────────────────────
    @staticmethod
    def _format_optional(value) -> str:
        """Format a float to 4 decimals, or 'N/A' when it is None (e.g. not validated)."""
        return "N/A" if value is None else f"{value:.4f}"

    @staticmethod
    def _plotting_positions(n: int, method: str) -> np.ndarray:
        """Return the n plotting positions for the 'hazen' or 'filliben' formula."""
        i = np.arange(1, n + 1)
        if method == "hazen":
            return (i - 0.5) / n
        positions = (i - 0.3175) / (n + 0.365)
        # Filliben uses dedicated end-point formulas.
        positions[0] = 1 - 0.5 ** (1 / n)
        positions[-1] = 0.5 ** (1 / n)
        return positions

    def plot_qq_plots(self, method: str = "hazen"):
        """
        Generate QQ plots for each fitted distribution against the data.

        Requires that fit() and validate() have been called; distributions
        without a goodness-of-fit result are reported and skipped.

        Parameters
        ----------
        method : str
            Plotting positions formula. Either 'hazen' (default) or 'filliben'.
            - 'hazen'   : p_i = (i - 0.5) / n
            - 'filliben': p_1 = 1 - 0.5^(1/n), p_n = 0.5^(1/n),
                          p_i = (i - 0.3175) / (n + 0.365) for 1 < i < n

        Returns
        -------
        plotly.graph_objects.Figure
            Two-column grid with one subplot per distribution (as many rows as
            needed): empirical vs theoretical quantiles, a dashed y = x line,
            and the GoF statistic (green when p > 0.05, red otherwise).

        Raises
        ------
        RuntimeError
            If fit() has not been called.
        ValueError
            If *method* is not 'hazen' or 'filliben'.
        """
        if not hasattr(self, "_last_data_flat"):
            raise RuntimeError("No data available. Call fit() first.")
        if method not in ("hazen", "filliben"):
            raise ValueError(f"method must be 'hazen' or 'filliben', got '{method}'.")

        # The empirical quantiles and plotting positions depend only on the
        # data, so compute them once rather than once per distribution.
        sorted_data = np.sort(self._last_data_flat)
        plotting_positions = self._plotting_positions(len(sorted_data), method)

        num_cols = 2
        # Enough rows for every distribution (at least one so the grid is
        # valid); capping the row count would make row indices out of range.
        num_rows = max(1, (len(self._dist) + num_cols - 1) // num_cols)
        fig = make_subplots(
            rows=num_rows,
            cols=num_cols,
            subplot_titles=list(self._dist),
        )

        for position, (dist_name, summary) in enumerate(self._dist.items()):
            if summary.test_result is None:
                print(
                    f"Distribution '{dist_name}' has not been validated yet. Skipping QQ plot."
                )
                continue
            theoretical_quantiles = summary.distribution_object.ppf(
                plotting_positions, *summary.fit_result_params
            )
            row = position // num_cols + 1
            col = position % num_cols + 1
            fig.add_trace(
                go.Scatter(
                    x=theoretical_quantiles,
                    y=sorted_data,
                    mode="markers",
                    name=dist_name,
                ),
                row=row,
                col=col,
            )
            # Reference line y = x spanning both axes' ranges.
            min_val = min(theoretical_quantiles.min(), sorted_data.min())
            max_val = max(theoretical_quantiles.max(), sorted_data.max())
            fig.add_trace(
                go.Scatter(
                    x=[min_val, max_val],
                    y=[min_val, max_val],
                    mode="lines",
                    name="y=x",
                    line=dict(dash="dash"),
                ),
                row=row,
                col=col,
            )
            fig.update_xaxes(title_text="Theoretical Quantiles", row=row, col=col)
            fig.update_yaxes(title_text="Empirical Quantiles", row=row, col=col)
            # GoF statistic in the bottom-right corner of the subplot.
            fig.add_annotation(
                x=0.95,
                y=0.05,
                xref="x domain",
                yref="y domain",
                text=f"{summary.statistic_method}={summary.gof_statistic:.4f}",
                showarrow=False,
                font=dict(size=10, color="green" if summary.pvalue > 0.05 else "red"),
                row=row,
                col=col,
            )

        fig.update_layout(
            title=f"QQ Plots of Fitted Distributions ({method})",
            showlegend=False,
            autosize=True,
        )
        return fig

    def histogram_with_fits(self):
        """
        Plot a density histogram of the data with the PDF of each fitted distribution.

        Requires that fit() has been called. The p-value and GoF statistic shown
        on hover read 'N/A' until validate() has been run.

        Returns
        -------
        plotly.graph_objects.Figure

        Raises
        ------
        RuntimeError
            If fit() has not been called.
        """
        if not hasattr(self, "_last_data_flat"):
            raise RuntimeError("No data available. Call fit() first.")
        data_flat = self._last_data_flat
        x = np.linspace(0, data_flat.max(), 1000)
        fig = go.Figure(layout=go.Layout(hovermode="x unified"))
        # The histogram itself is excluded from hover; only the fitted curves
        # report their name, p-value and GoF statistic.
        fig.add_trace(
            go.Histogram(
                x=data_flat,
                name="Data",
                opacity=0.3,
                histnorm="probability density",
                hoverinfo="skip",
                marker_color="blue",
            )
        )
        for dist_name, summary in self._dist.items():
            if not summary.fit_result_params:
                print(
                    f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay."
                )
                continue
            pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params)
            hover_text = (
                f"<b>{summary.distribution_name}</b> "
                f"p-value: {self._format_optional(summary.pvalue)} "
                f"GoF: {self._format_optional(summary.gof_statistic)}"
            )
            # A single hovertext string applies to every point of the trace.
            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=pdf_values,
                    mode="lines",
                    name=f"{summary.distribution_name} Fit",
                    hovertext=hover_text,
                    hoverinfo="text",
                )
            )
        fig.update_layout(xaxis=dict(showgrid=True), yaxis=dict(showgrid=True))
        # Small, light-gray, sans-serif title anchored in the top-left corner.
        fig.update_layout(
            title=dict(
                text="Histogram of Data with Fitted Distributions",
                x=0.02,
                y=0.95,
                xanchor="left",
                yanchor="top",
                font=dict(size=20, color="darkgray", family="sans-serif"),
            ),
            xaxis_title="Value",
            yaxis_title="Density",
            autosize=True,
        )
        return fig

    def histogram_with_fits_seaborn(self):
        """
        Matplotlib/seaborn version of histogram_with_fits().

        Requires that fit() has been called. Legend p-values read 'N/A' until
        validate() has been run.

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        RuntimeError
            If fit() has not been called.
        """
        if not hasattr(self, "_last_data_flat"):
            raise RuntimeError("No data available. Call fit() first.")
        data_flat = self._last_data_flat
        x = np.linspace(0, data_flat.max(), 1000)
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.histplot(
            data_flat,
            # Square-root rule, but never fewer than one bin.
            bins=max(1, int(np.sqrt(len(data_flat)))),
            kde=False,
            stat="density",
            color="blue",
            alpha=0.2,
            ax=ax,
        )
        for dist_name, summary in self._dist.items():
            if not summary.fit_result_params:
                print(
                    f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay."
                )
                continue
            pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params)
            ax.plot(
                x,
                pdf_values,
                label=f"{summary.distribution_name} --- p={self._format_optional(summary.pvalue)}",
            )
        # Small, light-gray, sans-serif title aligned to the left.
        ax.set_title(
            "Histogram of Data with Fitted Distributions",
            fontsize=8,
            loc="left",
            color="darkgray",
            fontfamily="sans-serif",
        )
        ax.set_xlabel("Value")
        ax.set_ylabel("Density")
        ax.legend()
        ax.grid(True)
        return fig