[MAIN] Change workdir files, add docstring in functions

This commit is contained in:
2026-03-25 16:37:56 -03:00
parent be50b41b78
commit bcd8f25a62
8 changed files with 357 additions and 30 deletions

502
etc/fitting/fitter.py Normal file
View File

@@ -0,0 +1,502 @@
from dataclasses import dataclass, field
import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import rv_continuous, goodness_of_fit
@dataclass
class DistributionSummary:
"""
Summary of a single distribution fit against a dataset.
Attributes
----------
distribution_object : rv_continuous
The scipy distribution object used for fitting.
distribution_name : str
Human-readable name of the distribution.
args_fit_params : tuple
Initial guess positional parameters passed to fit() (empty tuple if none).
kwds_fit_params : dict
Keyword arguments that were passed to fit() (fixed params, etc.).
fit_result_params : tuple
The actual fitted parameters returned by fit() (empty tuple until fit() is called).
statistic_method : str
GoF statistic identifier used in validate() (e.g. 'ad', 'ks').
test_result : object
Result object from scipy.stats.goodness_of_fit; None until
validate() is called. Exposes .statistic and .pvalue.
Computed properties
-------------------
pvalue : float | None p-value from the GoF test
gof_statistic : float | None test statistic from the GoF test
mean : float mean of the fitted distribution
std : float standard deviation of the fitted distribution
var : float variance of the fitted distribution
"""
distribution_object: rv_continuous
distribution_name: str
args_fit_params: tuple = field(default_factory=tuple)
kwds_fit_params: dict = field(default_factory=dict)
fit_result_params: tuple = field(default_factory=tuple)
statistic_method: str = 'ad'
test_result: object = None
# ── properties ────────────────────────────────────────────────
@property
def pvalue(self) -> float | None:
"""p-value from the goodness-of-fit test.
Returns
-------
float or None
The p-value produced by the GoF test, or None if validate() has
not been called yet.
"""
return self.test_result.pvalue if self.test_result is not None else None
@property
def gof_statistic(self) -> float | None:
"""Test statistic from the goodness-of-fit test.
Returns
-------
float or None
The GoF statistic value, or None if validate() has not been
called yet.
"""
return self.test_result.statistic if self.test_result is not None else None
@property
def mean(self) -> float:
"""Mean of the fitted distribution.
Returns
-------
float
Mean computed from the fitted distribution parameters.
"""
return self.distribution_object.mean(*self.fit_result_params)
@property
def std(self) -> float:
"""Standard deviation of the fitted distribution.
Returns
-------
float
Standard deviation computed from the fitted distribution parameters.
"""
return self.distribution_object.std(*self.fit_result_params)
@property
def var(self) -> float:
"""Variance of the fitted distribution.
Returns
-------
float
Variance computed from the fitted distribution parameters.
"""
return self.distribution_object.var(*self.fit_result_params)
def __repr__(self) -> str:
"""Return a concise string representation of the distribution summary.
Returns
-------
str
Single-line string showing the distribution name, fitted
parameters, mean, standard deviation, GoF statistic, and p-value.
"""
pval_str = f"{self.pvalue:.4f}" if self.pvalue is not None else "N/A"
stat_str = f"{self.gof_statistic:.4f}" if self.gof_statistic is not None else "N/A"
return (
f"DistributionSummary({self.distribution_name})"
f" params={self.fit_result_params}"
f" mean={self.mean:.4f} std={self.std:.4f}"
f" GoF[{self.statistic_method}]={stat_str} p={pval_str}"
)
def __str__(self) -> str:
"""Return the same string representation as ``__repr__``.
Returns
-------
str
Delegates to :meth:`__repr__`.
"""
return self.__repr__()
class Fitter:
"""
Fits and evaluates multiple distributions against a dataset.
Parameters
----------
dist_list : list[rv_continuous]
Distributions to fit.
statistic_method : str
Goodness-of-fit statistic passed to goodness_of_fit (default 'ad').
**kwargs :
Per-distribution initial guesses and fixed parameters, keyed as:
- ``<dist.name>_args`` : tuple of positional initial guesses
- ``<dist.name>_params`` : dict of keyword fixed parameters
Example:
Fitter(
[gamma, weibull_min],
gamma_args=(2.0,),
gamma_params={'floc': 0},
weibull_min_args=(1.5, 0.0, 1.0),
)
Access fitted distributions
---------------------------
Distributions can be accessed by name (str) or by the rv_continuous object:
fitter = Fitter([gamma, weibull_min])
fitter.fit(data)
# Access by name
summary = fitter['gamma']
# Access by rv_continuous object
summary = fitter[gamma]
# Check if a distribution is present
gamma in fitter # True
'weibull_min' in fitter # True
# Override a DistributionSummary
fitter['gamma'] = new_summary
"""
def __init__(self, dist_list: list[rv_continuous], statistic_method: str = 'ad', **kwargs):
"""Initialise the Fitter and build per-distribution summary objects.
Parameters
----------
dist_list : list[rv_continuous]
Distributions to fit.
statistic_method : str, optional
Goodness-of-fit statistic passed to ``goodness_of_fit``
(default ``'ad'`` for Anderson-Darling).
**kwargs :
Per-distribution initial guesses and fixed parameters, keyed as
``<dist.name>_args`` (tuple) or ``<dist.name>_params`` (dict).
"""
self._dist: dict[str, DistributionSummary] = {}
self.dist_list = list(dist_list)
for dist in dist_list:
self._dist[dist.name] = DistributionSummary(
distribution_object=dist,
distribution_name=dist.name,
args_fit_params=kwargs.get(f'{dist.name}_args', ()),
kwds_fit_params=kwargs.get(f'{dist.name}_params', {}),
statistic_method=statistic_method,
test_result=None,
)
#── Getter and setter with flexible keys (str or rv_continuous) ────────────────────────────────
def _resolve_key(self, key: str | rv_continuous) -> str:
"""Resolve a distribution name or ``rv_continuous`` object to its string key.
Parameters
----------
key : str or rv_continuous
Either the distribution's string name or its ``rv_continuous``
object (whose ``.name`` attribute is used).
Returns
-------
str
The string key used in the internal ``_dist`` dictionary.
Raises
------
KeyError
If the resolved name is not found among the registered
distributions.
"""
name = key.name if isinstance(key, rv_continuous) else key
if name not in self._dist:
available = ', '.join(self._dist)
raise KeyError(f"Distribution '{name}' not found. Available: {available}")
return name
def __contains__(self, key: str | rv_continuous) -> bool:
"""Check whether a distribution is registered with this Fitter.
Parameters
----------
key : str or rv_continuous
Distribution name or ``rv_continuous`` object to look up.
Returns
-------
bool
True if the distribution is registered, False otherwise.
"""
name = key.name if isinstance(key, rv_continuous) else key
return name in self._dist
def __getitem__(self, key: str | rv_continuous) -> DistributionSummary:
"""Retrieve the :class:`DistributionSummary` for the given distribution.
Parameters
----------
key : str or rv_continuous
Distribution name or ``rv_continuous`` object to retrieve.
Returns
-------
DistributionSummary
The summary object associated with the requested distribution.
Raises
------
KeyError
If the distribution is not registered.
"""
return self._dist[self._resolve_key(key)]
def __setitem__(self, key: str | rv_continuous, summary: DistributionSummary) -> None:
"""Override the :class:`DistributionSummary` for an existing distribution.
Parameters
----------
key : str or rv_continuous
Distribution name or ``rv_continuous`` object to update.
summary : DistributionSummary
Replacement summary object.
Raises
------
TypeError
If ``summary`` is not a :class:`DistributionSummary` instance.
KeyError
If the distribution is not registered.
"""
if not isinstance(summary, DistributionSummary):
raise TypeError(f"Expected DistributionSummary, got {type(summary).__name__}.")
self._dist[self._resolve_key(key)] = summary
def fit(self, data: np.ndarray) -> None:
"""Fit every distribution to *data* via MLE.
Parameters
----------
data : array-like
Input data. Only the absolute value is used; the array is
flattened before fitting.
Returns
-------
None
Fitted parameters are stored in-place inside each
:class:`DistributionSummary` held by this Fitter.
"""
data_flat = np.abs(data).flatten()
self._last_data_flat = data_flat
for dist in self.dist_list:
_summary = self._dist[dist.name]
fit_params = dist.fit(data_flat, *_summary.args_fit_params, **_summary.kwds_fit_params)
_summary.fit_result_params = fit_params
self._dist[dist.name] = _summary
def validate(self, **kwargs) -> None:
"""Run the goodness-of-fit test on every previously fitted distribution.
Parameters
----------
**kwargs :
Extra keyword arguments forwarded to ``scipy.stats.goodness_of_fit``
(e.g. ``n_mc_samples=100``).
Returns
-------
None
Test results are stored in-place inside each
:class:`DistributionSummary` held by this Fitter.
Raises
------
RuntimeError
If :meth:`fit` has not been called before this method.
"""
if not hasattr(self, '_last_data_flat'):
raise RuntimeError("No data available. Call fit() first.")
data_flat = self._last_data_flat
for dist in self.dist_list:
_summary = self._dist[dist.name]
test_result = goodness_of_fit(
dist,
data_flat,
statistic=_summary.statistic_method,
**kwargs
)
_summary.test_result = test_result
self._dist[dist.name] = _summary
def summary(self) -> None:
"""Print a one-line summary for each registered distribution.
Returns
-------
None
Output is written to stdout via ``print``.
"""
for dist_name, summary in self._dist.items():
print(summary)
def plot_qq_plots(self) -> None:
"""Generate QQ plots for each fitted distribution against the data.
A separate interactive Plotly figure is displayed for every
distribution that has been both fitted and validated. Distributions
that have not yet been validated are skipped with a printed warning.
Returns
-------
None
Figures are rendered inline / in a browser via ``fig.show()``.
Raises
------
RuntimeError
If :meth:`fit` has not been called before this method.
"""
if not hasattr(self, '_last_data_flat'):
raise RuntimeError("No data available. Call fit() first.")
data_flat = self._last_data_flat
for dist_name, summary in self._dist.items():
if summary.test_result is None:
print(f"Distribution '{dist_name}' has not been validated yet. Skipping QQ plot.")
continue
sorted_data = np.sort(data_flat)
theoretical_quantiles = summary.distribution_object.ppf(
(np.arange(1, len(sorted_data) + 1) - 0.5) / len(sorted_data),
*summary.fit_result_params
)
fig = go.Figure()
fig.add_trace(go.Scatter(x=theoretical_quantiles, y=sorted_data, mode='markers', name='Data vs. Fit'))
fig.add_trace(go.Scatter(x=theoretical_quantiles, y=theoretical_quantiles, mode='lines', name='Ideal Fit', line=dict(dash='dash')))
fig.update_layout(
title=f'QQ Plot for {summary.distribution_name}',
xaxis_title='Theoretical Quantiles',
yaxis_title='Empirical Quantiles',
autosize=True,
)
fig.show()
def histogram_with_fits(self) -> go.Figure:
"""Return an interactive histogram with overlaid fitted PDFs (Plotly).
Builds a probability-density histogram of the data and overlays a
line trace for the PDF of each fitted distribution. Hover text shows
the p-value and GoF statistic for each curve. Distributions that
have not yet been fitted are skipped with a printed warning.
Returns
-------
plotly.graph_objects.Figure
Interactive Plotly figure ready to display with ``fig.show()``.
Raises
------
RuntimeError
If :meth:`fit` has not been called before this method.
"""
if not hasattr(self, '_last_data_flat'):
raise RuntimeError("No data available. Call fit() first.")
data_flat = self._last_data_flat
x = np.linspace(0, data_flat.max(), 1000)
fig = go.Figure(layout=go.Layout(hovermode='x unified'))
fig.add_trace(go.Histogram(
x=data_flat, name='Data', opacity=0.3,
histnorm='probability density', hoverinfo='skip', marker_color='blue',
))
for dist_name, summary in self._dist.items():
if not summary.fit_result_params:
print(f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay.")
continue
pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params)
fig.add_trace(go.Scatter(x=x, y=pdf_values, mode='lines', name=f'{summary.distribution_name} Fit'))
hover_text = [
f"<b>{summary.distribution_name}</b> p-value: {summary.pvalue:.4f} GoF: {summary.gof_statistic:.4f}"
for _ in x
]
fig.data[-1].update(hovertext=hover_text, hoverinfo="text")
fig.update_layout(
xaxis=dict(showgrid=True),
yaxis=dict(showgrid=True),
title=dict(
text='Histogram of Data with Fitted Distributions',
x=0.02, y=0.95, xanchor='left', yanchor='top',
font=dict(size=20, color='darkgray', family='sans-serif'),
),
xaxis_title='Value',
yaxis_title='Density',
autosize=True,
)
return fig
def histogram_with_fits_seaborn(self) -> plt.Figure:
"""Return a static histogram with overlaid fitted PDFs (Matplotlib/Seaborn).
Builds a probability-density histogram using Seaborn and overlays a
line for the PDF of each fitted distribution. The legend entry for
each distribution includes its p-value. Distributions that have not
yet been fitted are skipped with a printed warning.
Returns
-------
matplotlib.figure.Figure
Matplotlib figure object.
Raises
------
RuntimeError
If :meth:`fit` has not been called before this method.
"""
if not hasattr(self, '_last_data_flat'):
raise RuntimeError("No data available. Call fit() first.")
data_flat = self._last_data_flat
x = np.linspace(0, data_flat.max(), 1000)
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(data_flat, bins=int(np.sqrt(len(data_flat))), kde=False, stat='density', color='blue', alpha=0.2, ax=ax)
for dist_name, summary in self._dist.items():
if not summary.fit_result_params:
print(f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay.")
continue
pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params)
ax.plot(x, pdf_values, label=f'{summary.distribution_name} --- p={summary.pvalue:.4f}')
ax.set_title('Histogram of Data with Fitted Distributions', fontsize=8, loc='left', color='darkgray', fontfamily='sans-serif')
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.legend()
ax.grid(True)
return fig