Files
Clutter_chuva/etc/fitting/fitter.py

503 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from dataclasses import dataclass, field
import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import rv_continuous, goodness_of_fit
@dataclass
class DistributionSummary:
"""
Summary of a single distribution fit against a dataset.
Attributes
----------
distribution_object : rv_continuous
The scipy distribution object used for fitting.
distribution_name : str
Human-readable name of the distribution.
args_fit_params : tuple
Initial guess positional parameters passed to fit() (empty tuple if none).
kwds_fit_params : dict
Keyword arguments that were passed to fit() (fixed params, etc.).
fit_result_params : tuple
The actual fitted parameters returned by fit() (empty tuple until fit() is called).
statistic_method : str
GoF statistic identifier used in validate() (e.g. 'ad', 'ks').
test_result : object
Result object from scipy.stats.goodness_of_fit; None until
validate() is called. Exposes .statistic and .pvalue.
Computed properties
-------------------
pvalue : float | None p-value from the GoF test
gof_statistic : float | None test statistic from the GoF test
mean : float mean of the fitted distribution
std : float standard deviation of the fitted distribution
var : float variance of the fitted distribution
"""
distribution_object: rv_continuous
distribution_name: str
args_fit_params: tuple = field(default_factory=tuple)
kwds_fit_params: dict = field(default_factory=dict)
fit_result_params: tuple = field(default_factory=tuple)
statistic_method: str = 'ad'
test_result: object = None
# ── properties ────────────────────────────────────────────────
@property
def pvalue(self) -> float | None:
"""p-value from the goodness-of-fit test.
Returns
-------
float or None
The p-value produced by the GoF test, or None if validate() has
not been called yet.
"""
return self.test_result.pvalue if self.test_result is not None else None
@property
def gof_statistic(self) -> float | None:
"""Test statistic from the goodness-of-fit test.
Returns
-------
float or None
The GoF statistic value, or None if validate() has not been
called yet.
"""
return self.test_result.statistic if self.test_result is not None else None
@property
def mean(self) -> float:
"""Mean of the fitted distribution.
Returns
-------
float
Mean computed from the fitted distribution parameters.
"""
return self.distribution_object.mean(*self.fit_result_params)
@property
def std(self) -> float:
"""Standard deviation of the fitted distribution.
Returns
-------
float
Standard deviation computed from the fitted distribution parameters.
"""
return self.distribution_object.std(*self.fit_result_params)
@property
def var(self) -> float:
"""Variance of the fitted distribution.
Returns
-------
float
Variance computed from the fitted distribution parameters.
"""
return self.distribution_object.var(*self.fit_result_params)
def __repr__(self) -> str:
"""Return a concise string representation of the distribution summary.
Returns
-------
str
Single-line string showing the distribution name, fitted
parameters, mean, standard deviation, GoF statistic, and p-value.
"""
pval_str = f"{self.pvalue:.4f}" if self.pvalue is not None else "N/A"
stat_str = f"{self.gof_statistic:.4f}" if self.gof_statistic is not None else "N/A"
return (
f"DistributionSummary({self.distribution_name})"
f" params={self.fit_result_params}"
f" mean={self.mean:.4f} std={self.std:.4f}"
f" GoF[{self.statistic_method}]={stat_str} p={pval_str}"
)
def __str__(self) -> str:
"""Return the same string representation as ``__repr__``.
Returns
-------
str
Delegates to :meth:`__repr__`.
"""
return self.__repr__()
class Fitter:
    """
    Fits and evaluates multiple distributions against a dataset.

    Parameters
    ----------
    dist_list : list[rv_continuous]
        Distributions to fit. Any iterable is accepted; it is materialised
        into a list once.
    statistic_method : str
        Goodness-of-fit statistic passed to goodness_of_fit (default 'ad').
    **kwargs :
        Per-distribution initial guesses and fixed parameters, keyed as:
        - ``<dist.name>_args``   : tuple of positional initial guesses
        - ``<dist.name>_params`` : dict of keyword fixed parameters

        Example:
            Fitter(
                [gamma, weibull_min],
                gamma_args=(2.0,),
                gamma_params={'floc': 0},
                weibull_min_args=(1.5, 0.0, 1.0),
            )

    Access fitted distributions
    ---------------------------
    Distributions can be accessed by name (str) or by the rv_continuous object:

        fitter = Fitter([gamma, weibull_min])
        fitter.fit(data)
        # Access by name
        summary = fitter['gamma']
        # Access by rv_continuous object
        summary = fitter[gamma]
        # Check if a distribution is present
        gamma in fitter          # True
        'weibull_min' in fitter  # True
        # Override a DistributionSummary
        fitter['gamma'] = new_summary
    """

    def __init__(self, dist_list: list[rv_continuous], statistic_method: str = 'ad', **kwargs):
        """Initialise the Fitter and build per-distribution summary objects.

        Parameters
        ----------
        dist_list : list[rv_continuous]
            Distributions to fit.
        statistic_method : str, optional
            Goodness-of-fit statistic passed to ``goodness_of_fit``
            (default ``'ad'`` for Anderson-Darling).
        **kwargs :
            Per-distribution initial guesses and fixed parameters, keyed as
            ``<dist.name>_args`` (tuple) or ``<dist.name>_params`` (dict).
            Keys that match no distribution are ignored.
        """
        self._dist: dict[str, DistributionSummary] = {}
        self.dist_list = list(dist_list)
        # Absolute, flattened data from the most recent fit(); None until then.
        self._last_data_flat = None
        # Iterate the materialised list: the raw argument may be a one-shot
        # iterator that list() above has already exhausted.
        for dist in self.dist_list:
            self._dist[dist.name] = DistributionSummary(
                distribution_object=dist,
                distribution_name=dist.name,
                # Copy so later mutation of the caller's objects cannot
                # silently change what fit() receives.
                args_fit_params=tuple(kwargs.get(f'{dist.name}_args', ())),
                kwds_fit_params=dict(kwargs.get(f'{dist.name}_params', {})),
                statistic_method=statistic_method,
                test_result=None,
            )

    # ── Internal helpers ──────────────────────────────────────────
    @staticmethod
    def _key_name(key: str | rv_continuous):
        """Return ``key.name`` for an ``rv_continuous`` object, else *key* itself."""
        return key.name if isinstance(key, rv_continuous) else key

    @staticmethod
    def _format_number(value) -> str:
        """Format *value* with four decimals, or ``"N/A"`` when it is None."""
        return "N/A" if value is None else f"{value:.4f}"

    def _require_data(self) -> np.ndarray:
        """Return the data stored by :meth:`fit`.

        Raises
        ------
        RuntimeError
            If :meth:`fit` has not been called yet.
        """
        # getattr keeps the check valid for instances created without __init__.
        data_flat = getattr(self, '_last_data_flat', None)
        if data_flat is None:
            raise RuntimeError("No data available. Call fit() first.")
        return data_flat

    # ── Getter and setter with flexible keys (str or rv_continuous) ──
    def _resolve_key(self, key: str | rv_continuous) -> str:
        """Resolve a distribution name or ``rv_continuous`` object to its string key.

        Parameters
        ----------
        key : str or rv_continuous
            Either the distribution's string name or its ``rv_continuous``
            object (whose ``.name`` attribute is used).

        Returns
        -------
        str
            The string key used in the internal ``_dist`` dictionary.

        Raises
        ------
        KeyError
            If the resolved name is not found among the registered
            distributions.
        """
        name = self._key_name(key)
        if name not in self._dist:
            available = ', '.join(self._dist)
            raise KeyError(f"Distribution '{name}' not found. Available: {available}")
        return name

    def __contains__(self, key: str | rv_continuous) -> bool:
        """Check whether a distribution is registered with this Fitter.

        Parameters
        ----------
        key : str or rv_continuous
            Distribution name or ``rv_continuous`` object to look up.

        Returns
        -------
        bool
            True if the distribution is registered, False otherwise.
        """
        return self._key_name(key) in self._dist

    def __getitem__(self, key: str | rv_continuous) -> DistributionSummary:
        """Retrieve the :class:`DistributionSummary` for the given distribution.

        Parameters
        ----------
        key : str or rv_continuous
            Distribution name or ``rv_continuous`` object to retrieve.

        Returns
        -------
        DistributionSummary
            The summary object associated with the requested distribution.

        Raises
        ------
        KeyError
            If the distribution is not registered.
        """
        return self._dist[self._resolve_key(key)]

    def __setitem__(self, key: str | rv_continuous, summary: DistributionSummary) -> None:
        """Override the :class:`DistributionSummary` for an existing distribution.

        Parameters
        ----------
        key : str or rv_continuous
            Distribution name or ``rv_continuous`` object to update.
        summary : DistributionSummary
            Replacement summary object.

        Raises
        ------
        TypeError
            If ``summary`` is not a :class:`DistributionSummary` instance.
        KeyError
            If the distribution is not registered.
        """
        if not isinstance(summary, DistributionSummary):
            raise TypeError(f"Expected DistributionSummary, got {type(summary).__name__}.")
        self._dist[self._resolve_key(key)] = summary

    # ── Fitting and validation ────────────────────────────────────
    def fit(self, data: np.ndarray) -> None:
        """Fit every registered distribution to *data*.

        Each summary's ``distribution_object.fit`` is called with the stored
        initial guesses and keyword parameters. Any previous goodness-of-fit
        result is cleared, because it described the data of the previous fit.

        Parameters
        ----------
        data : array-like
            Input data. Only the absolute value is used; the array is
            flattened before fitting.

        Returns
        -------
        None
            Fitted parameters are stored in-place inside each
            :class:`DistributionSummary` held by this Fitter.
        """
        data_flat = np.abs(data).flatten()
        self._last_data_flat = data_flat
        # Iterate the registered summaries (not dist_list) so that a summary
        # replaced through __setitem__ is the one that gets fitted.
        for summary in self._dist.values():
            summary.fit_result_params = summary.distribution_object.fit(
                data_flat,
                *summary.args_fit_params,
                **summary.kwds_fit_params,
            )
            summary.test_result = None

    def validate(self, **kwargs) -> None:
        """Run the goodness-of-fit test on every registered distribution.

        The test is run against the data stored by the last :meth:`fit` call.

        NOTE(review): only the distribution, the data and the statistic are
        passed to ``goodness_of_fit``; the fixed parameters stored in
        ``kwds_fit_params`` are not forwarded. Callers who fixed parameters in
        ``fit`` probably want to pass ``known_params`` / ``fit_params`` through
        ``**kwargs`` so the test matches the fit -- confirm against the SciPy
        documentation.

        Parameters
        ----------
        **kwargs :
            Extra keyword arguments forwarded to ``scipy.stats.goodness_of_fit``
            (e.g. ``n_mc_samples=100``).

        Returns
        -------
        None
            Test results are stored in-place inside each
            :class:`DistributionSummary` held by this Fitter.

        Raises
        ------
        RuntimeError
            If :meth:`fit` has not been called before this method.
        """
        data_flat = self._require_data()
        for summary in self._dist.values():
            summary.test_result = goodness_of_fit(
                summary.distribution_object,
                data_flat,
                statistic=summary.statistic_method,
                **kwargs,
            )

    def summary(self) -> None:
        """Print a one-line summary for each registered distribution.

        Returns
        -------
        None
            Output is written to stdout via ``print``.
        """
        for summary in self._dist.values():
            print(summary)

    # ── Plotting ──────────────────────────────────────────────────
    def plot_qq_plots(self) -> None:
        """Generate QQ plots for each fitted distribution against the data.

        A separate interactive Plotly figure is displayed for every
        distribution whose ``test_result`` is set. Distributions that have
        not yet been validated are skipped with a printed warning.

        Returns
        -------
        None
            Figures are displayed via ``fig.show()``.

        Raises
        ------
        RuntimeError
            If :meth:`fit` has not been called before this method.
        """
        data_flat = self._require_data()
        # Both arrays depend only on the data, so compute them once.
        sorted_data = np.sort(data_flat)
        n_points = len(sorted_data)
        # Plotting positions (i - 0.5) / n for i = 1..n.
        probabilities = (np.arange(1, n_points + 1) - 0.5) / n_points

        for dist_name, summary in self._dist.items():
            if summary.test_result is None:
                print(f"Distribution '{dist_name}' has not been validated yet. Skipping QQ plot.")
                continue
            theoretical_quantiles = summary.distribution_object.ppf(
                probabilities,
                *summary.fit_result_params
            )
            fig = go.Figure()
            fig.add_trace(go.Scatter(
                x=theoretical_quantiles, y=sorted_data,
                mode='markers', name='Data vs. Fit',
            ))
            # y = x reference line: points on it mean data and fit quantiles agree.
            fig.add_trace(go.Scatter(
                x=theoretical_quantiles, y=theoretical_quantiles,
                mode='lines', name='Ideal Fit', line=dict(dash='dash'),
            ))
            fig.update_layout(
                title=f'QQ Plot for {summary.distribution_name}',
                xaxis_title='Theoretical Quantiles',
                yaxis_title='Empirical Quantiles',
                autosize=True,
            )
            fig.show()

    def histogram_with_fits(self) -> go.Figure:
        """Return an interactive histogram with overlaid fitted PDFs (Plotly).

        Builds a probability-density histogram of the data and overlays a
        line trace for the PDF of each fitted distribution. Hover text shows
        the p-value and GoF statistic for each curve, or ``N/A`` when
        :meth:`validate` has not been run. Distributions that have not yet
        been fitted are skipped with a printed warning.

        Returns
        -------
        plotly.graph_objects.Figure
            Interactive Plotly figure ready to display with ``fig.show()``.

        Raises
        ------
        RuntimeError
            If :meth:`fit` has not been called before this method.
        """
        data_flat = self._require_data()
        x = np.linspace(0, data_flat.max(), 1000)
        fig = go.Figure(layout=go.Layout(hovermode='x unified'))
        fig.add_trace(go.Histogram(
            x=data_flat, name='Data', opacity=0.3,
            histnorm='probability density', hoverinfo='skip', marker_color='blue',
        ))
        for dist_name, summary in self._dist.items():
            if not summary.fit_result_params:
                print(f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay.")
                continue
            pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params)
            fig.add_trace(go.Scatter(x=x, y=pdf_values, mode='lines', name=f'{summary.distribution_name} Fit'))
            # The label is identical for every point: format it once and
            # tolerate a missing GoF result instead of raising on None.
            label = (
                f"<b>{summary.distribution_name}</b>"
                f" p-value: {self._format_number(summary.pvalue)}"
                f" GoF: {self._format_number(summary.gof_statistic)}"
            )
            fig.data[-1].update(hovertext=[label] * len(x), hoverinfo="text")
        fig.update_layout(
            xaxis=dict(showgrid=True),
            yaxis=dict(showgrid=True),
            title=dict(
                text='Histogram of Data with Fitted Distributions',
                x=0.02, y=0.95, xanchor='left', yanchor='top',
                font=dict(size=20, color='darkgray', family='sans-serif'),
            ),
            xaxis_title='Value',
            yaxis_title='Density',
            autosize=True,
        )
        return fig

    def histogram_with_fits_seaborn(self) -> plt.Figure:
        """Return a static histogram with overlaid fitted PDFs (Matplotlib/Seaborn).

        Builds a probability-density histogram using Seaborn and overlays a
        line for the PDF of each fitted distribution. The legend entry for
        each distribution includes its p-value, or ``N/A`` when
        :meth:`validate` has not been run. Distributions that have not yet
        been fitted are skipped with a printed warning.

        Returns
        -------
        matplotlib.figure.Figure
            Matplotlib figure object.

        Raises
        ------
        RuntimeError
            If :meth:`fit` has not been called before this method.
        """
        data_flat = self._require_data()
        x = np.linspace(0, data_flat.max(), 1000)
        fig, ax = plt.subplots(figsize=(10, 6))
        # Square-root rule for the number of histogram bins.
        sns.histplot(
            data_flat, bins=int(np.sqrt(len(data_flat))), kde=False,
            stat='density', color='blue', alpha=0.2, ax=ax,
        )
        for dist_name, summary in self._dist.items():
            if not summary.fit_result_params:
                print(f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay.")
                continue
            pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params)
            ax.plot(
                x, pdf_values,
                label=f'{summary.distribution_name} --- p={self._format_number(summary.pvalue)}',
            )
        ax.set_title(
            'Histogram of Data with Fitted Distributions',
            fontsize=8, loc='left', color='darkgray', fontfamily='sans-serif',
        )
        ax.set_xlabel('Value')
        ax.set_ylabel('Density')
        ax.legend()
        ax.grid(True)
        return fig