Files
Clutter_chuva/clutter_chuva/fitting/fitter.py

320 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from dataclasses import dataclass, field
import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import rv_continuous, goodness_of_fit
@dataclass
class DistributionSummary:
"""
Summary of a single distribution fit against a dataset.
Attributes
----------
distribution_object : rv_continuous
The scipy distribution object used for fitting.
distribution_name : str
Human-readable name of the distribution.
args_fit_params : tuple
Initial guess positional parameters passed to fit() (empty tuple if none).
kwds_fit_params : dict
Keyword arguments that were passed to fit() (fixed params, etc.).
fit_result_params : tuple
The actual fitted parameters returned by fit() (empty tuple until fit() is called).
statistic_method : str
GoF statistic identifier used in validate() (e.g. 'ad', 'ks').
test_result : object
Result object from scipy.stats.goodness_of_fit; None until
validate() is called. Exposes .statistic and .pvalue.
Computed properties
-------------------
pvalue : float | None p-value from the GoF test
gof_statistic : float | None test statistic from the GoF test
mean : float mean of the fitted distribution
std : float standard deviation of the fitted distribution
var : float variance of the fitted distribution
"""
distribution_object: rv_continuous
distribution_name: str
args_fit_params: tuple = field(default_factory=tuple)
kwds_fit_params: dict = field(default_factory=dict)
fit_result_params: tuple = field(default_factory=tuple)
statistic_method: str = 'ad'
test_result: object = None
# ── properties ────────────────────────────────────────────────
@property
def pvalue(self) -> float | None:
"""p-value from the goodness-of-fit test, or None if not yet run."""
return self.test_result.pvalue if self.test_result is not None else None
@property
def gof_statistic(self) -> float | None:
"""Test statistic from the goodness-of-fit test, or None if not yet run."""
return self.test_result.statistic if self.test_result is not None else None
@property
def mean(self) -> float:
"""Mean of the fitted distribution."""
return self.distribution_object.mean(*self.fit_result_params)
@property
def std(self) -> float:
"""Standard deviation of the fitted distribution."""
return self.distribution_object.std(*self.fit_result_params)
@property
def var(self) -> float:
"""Variance of the fitted distribution."""
return self.distribution_object.var(*self.fit_result_params)
def __repr__(self) -> str:
pval_str = f"{self.pvalue:.4f}" if self.pvalue is not None else "N/A"
stat_str = f"{self.gof_statistic:.4f}" if self.gof_statistic is not None else "N/A"
return (
f"DistributionSummary({self.distribution_name})"
f" params={self.fit_result_params}"
f" mean={self.mean:.4f} std={self.std:.4f}"
f" GoF[{self.statistic_method}]={stat_str} p={pval_str}"
)
def __str__(self) -> str:
return self.__repr__()
class Fitter:
    """
    Fits and evaluates multiple distributions against a dataset.

    Parameters
    ----------
    dist_list : list[rv_continuous]
        Distributions to fit.
    statistic_method : str
        Goodness-of-fit statistic passed to goodness_of_fit (default 'ad').
    **kwargs :
        Per-distribution initial guesses and fixed parameters, keyed as:
        - ``<dist.name>_args`` : tuple of positional initial guesses
        - ``<dist.name>_params`` : dict of keyword fixed parameters

        Example:
            Fitter(
                [gamma, weibull_min],
                gamma_args=(2.0,),
                gamma_params={'floc': 0},
                weibull_min_args=(1.5, 0.0, 1.0),
            )

    Access fitted distributions
    ---------------------------
    Distributions can be accessed by name (str) or by the rv_continuous object:

        fitter = Fitter([gamma, weibull_min])
        fitter.fit(data)

        # Access by name
        summary = fitter['gamma']

        # Access by rv_continuous object
        summary = fitter[gamma]

        # Check if a distribution is present
        gamma in fitter          # True
        'weibull_min' in fitter  # True

        # Override a DistributionSummary
        fitter['gamma'] = new_summary
    """

    def __init__(self, dist_list: list[rv_continuous], statistic_method: str = 'ad', **kwargs):
        self._dist: dict[str, DistributionSummary] = {}
        # Materialise once and iterate over the stored copy, so a one-shot
        # iterable (e.g. a generator) is not already exhausted when the
        # summaries are built.
        self.dist_list = list(dist_list)
        for dist in self.dist_list:
            self._dist[dist.name] = DistributionSummary(
                distribution_object=dist,
                distribution_name=dist.name,
                args_fit_params=kwargs.get(f'{dist.name}_args', ()),
                kwds_fit_params=kwargs.get(f'{dist.name}_params', {}),
                statistic_method=statistic_method,
                test_result=None,
            )

    # -- Internal helpers ---------------------------------------------
    @staticmethod
    def _key_name(key: str | rv_continuous) -> str:
        """Return the dictionary key for a name or an rv_continuous object."""
        return key.name if isinstance(key, rv_continuous) else key

    def _require_data(self) -> np.ndarray:
        """Return the data stored by fit(); raise RuntimeError if fit() never ran."""
        data_flat = getattr(self, '_last_data_flat', None)
        if data_flat is None:
            raise RuntimeError("No data available. Call fit() first.")
        return data_flat

    @staticmethod
    def _format_metric(value: float | None) -> str:
        """Format a GoF metric with 4 decimals, or 'N/A' when it is not available."""
        return f"{value:.4f}" if value is not None else "N/A"

    @staticmethod
    def _gof_param_kwargs(dist: rv_continuous, summary: DistributionSummary) -> dict:
        """
        Translate a stored fit into goodness_of_fit() parameter dictionaries.

        Parameters that were fixed in ``summary.kwds_fit_params`` (``floc``,
        ``fscale``, ``f<index>``, ``f<shape>``, ``fix_<shape>``) are reported
        as ``known_params``; every other fitted value goes to ``fit_params``.
        Returns an empty dict when the stored parameters do not line up with
        the distribution's parameter names (e.g. nothing has been fitted).
        """
        shape_names = dist.shapes.replace(',', ' ').split() if dist.shapes else []
        param_names = shape_names + ['loc', 'scale']
        fitted_values = summary.fit_result_params
        if len(fitted_values) != len(param_names):
            return {}

        fixed_keys = summary.kwds_fit_params
        known: dict = {}
        fitted: dict = {}
        for index, (name, value) in enumerate(zip(param_names, fitted_values)):
            if name in ('loc', 'scale'):
                aliases = (f'f{name}',)
            else:
                aliases = (f'f{index}', f'f{name}', f'fix_{name}')
            target = known if any(alias in fixed_keys for alias in aliases) else fitted
            target[name] = value

        gof_kwargs: dict = {}
        if known:
            gof_kwargs['known_params'] = known
        if fitted:
            gof_kwargs['fit_params'] = fitted
        return gof_kwargs

    # -- Getter and setter with flexible keys (str or rv_continuous) --
    def _resolve_key(self, key: str | rv_continuous) -> str:
        """Resolve a distribution name or rv_continuous object to its string key."""
        name = self._key_name(key)
        if name not in self._dist:
            available = ', '.join(self._dist)
            raise KeyError(f"Distribution '{name}' not found. Available: {available}")
        return name

    def __contains__(self, key: str | rv_continuous) -> bool:
        return self._key_name(key) in self._dist

    def __getitem__(self, key: str | rv_continuous) -> DistributionSummary:
        return self._dist[self._resolve_key(key)]

    def __setitem__(self, key: str | rv_continuous, summary: DistributionSummary) -> None:
        if not isinstance(summary, DistributionSummary):
            raise TypeError(f"Expected DistributionSummary, got {type(summary).__name__}.")
        self._dist[self._resolve_key(key)] = summary

    # -- Fitting and validation ---------------------------------------
    def fit(self, data: np.ndarray) -> None:
        """
        Fit every distribution to *data* using each distribution's fit() method.

        Parameters
        ----------
        data : array-like
            Input data. Only the absolute value is used, flattened to 1-D.
            The flattened array is kept for validate() and the plot methods.
        """
        data_flat = np.abs(data).flatten()
        self._last_data_flat = data_flat
        for dist in self.dist_list:
            summary = self._dist[dist.name]
            # The summary is mutated in place; no need to re-insert it.
            summary.fit_result_params = dist.fit(
                data_flat, *summary.args_fit_params, **summary.kwds_fit_params
            )

    def validate(self, **kwargs) -> None:
        """
        Run the goodness-of-fit test on every previously fitted distribution.

        The parameters stored by fit() are forwarded to goodness_of_fit() so
        that the test evaluates the same fit that is stored and plotted:
        parameters the user fixed are passed as ``known_params`` and the
        remaining fitted values as ``fit_params``. Without this,
        goodness_of_fit() would refit the data with every parameter free.

        Parameters
        ----------
        **kwargs :
            Extra keyword arguments forwarded to goodness_of_fit()
            (e.g. n_mc_samples=100). If ``known_params`` or ``fit_params`` is
            supplied here, it is used as given and nothing is derived from
            the stored fit.

        Raises
        ------
        RuntimeError
            If fit() has not been called yet.
        """
        data_flat = self._require_data()
        caller_sets_params = 'known_params' in kwargs or 'fit_params' in kwargs
        for dist in self.dist_list:
            summary = self._dist[dist.name]
            param_kwargs = {} if caller_sets_params else self._gof_param_kwargs(dist, summary)
            summary.test_result = goodness_of_fit(
                dist,
                data_flat,
                statistic=summary.statistic_method,
                **param_kwargs,
                **kwargs,
            )

    def summary(self) -> None:
        """Print a summary of all fitted distributions."""
        for dist_summary in self._dist.values():
            print(dist_summary)

    # -- Plots --------------------------------------------------------
    def plot_qq_plots(self) -> None:
        """
        Generate QQ plots for each fitted distribution against the data.

        Requires fit() and validate() to have been called; distributions
        without a test result are skipped with a printed notice.

        Raises
        ------
        RuntimeError
            If fit() has not been called yet.
        """
        data_flat = self._require_data()
        # Both the ordered sample and the plotting positions are identical
        # for every distribution, so compute them once.
        sorted_data = np.sort(data_flat)
        n_points = len(sorted_data)
        probabilities = (np.arange(1, n_points + 1) - 0.5) / n_points

        for dist_name, summary in self._dist.items():
            if summary.test_result is None:
                print(f"Distribution '{dist_name}' has not been validated yet. Skipping QQ plot.")
                continue
            theoretical_quantiles = summary.distribution_object.ppf(
                probabilities, *summary.fit_result_params
            )
            fig = go.Figure()
            fig.add_trace(go.Scatter(
                x=theoretical_quantiles, y=sorted_data,
                mode='markers', name='Data vs. Fit',
            ))
            fig.add_trace(go.Scatter(
                x=theoretical_quantiles, y=theoretical_quantiles,
                mode='lines', name='Ideal Fit', line=dict(dash='dash'),
            ))
            fig.update_layout(
                title=f'QQ Plot for {summary.distribution_name}',
                xaxis_title='Theoretical Quantiles',
                yaxis_title='Empirical Quantiles',
                autosize=True,
            )
            fig.show()

    def histogram_with_fits(self) -> go.Figure:
        """
        Histogram of the data with overlaid PDFs (Plotly).

        Requires fit() to have been called. The hover text shows the p-value
        and GoF statistic, or 'N/A' when validate() has not been run yet.

        Raises
        ------
        RuntimeError
            If fit() has not been called yet.
        """
        data_flat = self._require_data()
        x = np.linspace(0, data_flat.max(), 1000)
        fig = go.Figure(layout=go.Layout(hovermode='x unified'))
        fig.add_trace(go.Histogram(
            x=data_flat, name='Data', opacity=0.3,
            histnorm='probability density', hoverinfo='skip', marker_color='blue',
        ))
        for dist_name, summary in self._dist.items():
            if not summary.fit_result_params:
                print(f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay.")
                continue
            pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params)
            fig.add_trace(go.Scatter(x=x, y=pdf_values, mode='lines', name=f'{summary.distribution_name} Fit'))
            # pvalue / gof_statistic are None until validate() runs; format
            # them defensively so a fit-only workflow can still be plotted.
            label = (
                f"<b>{summary.distribution_name}</b>"
                f" p-value: {self._format_metric(summary.pvalue)}"
                f" GoF: {self._format_metric(summary.gof_statistic)}"
            )
            hover_text = [label] * len(x)
            fig.data[-1].update(hovertext=hover_text, hoverinfo="text")
        fig.update_layout(
            xaxis=dict(showgrid=True),
            yaxis=dict(showgrid=True),
            title=dict(
                text='Histogram of Data with Fitted Distributions',
                x=0.02, y=0.95, xanchor='left', yanchor='top',
                font=dict(size=20, color='darkgray', family='sans-serif'),
            ),
            xaxis_title='Value',
            yaxis_title='Density',
            autosize=True,
        )
        return fig

    def histogram_with_fits_seaborn(self) -> plt.Figure:
        """
        Histogram of the data with overlaid PDFs (Matplotlib/Seaborn).

        Requires fit() to have been called. Legend entries show the p-value,
        or 'N/A' when validate() has not been run yet.

        Raises
        ------
        RuntimeError
            If fit() has not been called yet.
        """
        data_flat = self._require_data()
        x = np.linspace(0, data_flat.max(), 1000)
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.histplot(
            data_flat, bins=int(np.sqrt(len(data_flat))), kde=False,
            stat='density', color='blue', alpha=0.2, ax=ax,
        )
        for dist_name, summary in self._dist.items():
            if not summary.fit_result_params:
                print(f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay.")
                continue
            pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params)
            ax.plot(
                x, pdf_values,
                label=f'{summary.distribution_name} --- p={self._format_metric(summary.pvalue)}',
            )
        ax.set_title(
            'Histogram of Data with Fitted Distributions',
            fontsize=8, loc='left', color='darkgray', fontfamily='sans-serif',
        )
        ax.set_xlabel('Value')
        ax.set_ylabel('Density')
        ax.legend()
        ax.grid(True)
        return fig