from dataclasses import dataclass, field import numpy as np import plotly.graph_objects as go import matplotlib.pyplot as plt import seaborn as sns from scipy.stats import rv_continuous, goodness_of_fit @dataclass class DistributionSummary: """ Summary of a single distribution fit against a dataset. Attributes ---------- distribution_object : rv_continuous The scipy distribution object used for fitting. distribution_name : str Human-readable name of the distribution. args_fit_params : tuple Initial guess positional parameters passed to fit() (empty tuple if none). kwds_fit_params : dict Keyword arguments that were passed to fit() (fixed params, etc.). fit_result_params : tuple The actual fitted parameters returned by fit() (empty tuple until fit() is called). statistic_method : str GoF statistic identifier used in validate() (e.g. 'ad', 'ks'). test_result : object Result object from scipy.stats.goodness_of_fit; None until validate() is called. Exposes .statistic and .pvalue. Computed properties ------------------- pvalue : float | None – p-value from the GoF test gof_statistic : float | None – test statistic from the GoF test mean : float – mean of the fitted distribution std : float – standard deviation of the fitted distribution var : float – variance of the fitted distribution """ distribution_object: rv_continuous distribution_name: str args_fit_params: tuple = field(default_factory=tuple) kwds_fit_params: dict = field(default_factory=dict) fit_result_params: tuple = field(default_factory=tuple) statistic_method: str = 'ad' test_result: object = None # ── properties ──────────────────────────────────────────────── @property def pvalue(self) -> float | None: """p-value from the goodness-of-fit test. Returns ------- float or None The p-value produced by the GoF test, or None if validate() has not been called yet. """ return self.test_result.pvalue if self.test_result is not None else None @property def gof_statistic(self) -> float | None: """Test statistic from the goodness-of-fit test. Returns ------- float or None The GoF statistic value, or None if validate() has not been called yet. """ return self.test_result.statistic if self.test_result is not None else None @property def mean(self) -> float: """Mean of the fitted distribution. Returns ------- float Mean computed from the fitted distribution parameters. """ return self.distribution_object.mean(*self.fit_result_params) @property def std(self) -> float: """Standard deviation of the fitted distribution. Returns ------- float Standard deviation computed from the fitted distribution parameters. """ return self.distribution_object.std(*self.fit_result_params) @property def var(self) -> float: """Variance of the fitted distribution. Returns ------- float Variance computed from the fitted distribution parameters. """ return self.distribution_object.var(*self.fit_result_params) def __repr__(self) -> str: """Return a concise string representation of the distribution summary. Returns ------- str Single-line string showing the distribution name, fitted parameters, mean, standard deviation, GoF statistic, and p-value. """ pval_str = f"{self.pvalue:.4f}" if self.pvalue is not None else "N/A" stat_str = f"{self.gof_statistic:.4f}" if self.gof_statistic is not None else "N/A" return ( f"DistributionSummary({self.distribution_name})" f" params={self.fit_result_params}" f" mean={self.mean:.4f} std={self.std:.4f}" f" GoF[{self.statistic_method}]={stat_str} p={pval_str}" ) def __str__(self) -> str: """Return the same string representation as ``__repr__``. Returns ------- str Delegates to :meth:`__repr__`. """ return self.__repr__() class Fitter: """ Fits and evaluates multiple distributions against a dataset. Parameters ---------- dist_list : list[rv_continuous] Distributions to fit. statistic_method : str Goodness-of-fit statistic passed to goodness_of_fit (default 'ad'). **kwargs : Per-distribution initial guesses and fixed parameters, keyed as: - ``_args`` : tuple of positional initial guesses - ``_params`` : dict of keyword fixed parameters Example: Fitter( [gamma, weibull_min], gamma_args=(2.0,), gamma_params={'floc': 0}, weibull_min_args=(1.5, 0.0, 1.0), ) Access fitted distributions --------------------------- Distributions can be accessed by name (str) or by the rv_continuous object: fitter = Fitter([gamma, weibull_min]) fitter.fit(data) # Access by name summary = fitter['gamma'] # Access by rv_continuous object summary = fitter[gamma] # Check if a distribution is present gamma in fitter # True 'weibull_min' in fitter # True # Override a DistributionSummary fitter['gamma'] = new_summary """ def __init__(self, dist_list: list[rv_continuous], statistic_method: str = 'ad', **kwargs): """Initialise the Fitter and build per-distribution summary objects. Parameters ---------- dist_list : list[rv_continuous] Distributions to fit. statistic_method : str, optional Goodness-of-fit statistic passed to ``goodness_of_fit`` (default ``'ad'`` for Anderson-Darling). **kwargs : Per-distribution initial guesses and fixed parameters, keyed as ``_args`` (tuple) or ``_params`` (dict). """ self._dist: dict[str, DistributionSummary] = {} self.dist_list = list(dist_list) for dist in dist_list: self._dist[dist.name] = DistributionSummary( distribution_object=dist, distribution_name=dist.name, args_fit_params=kwargs.get(f'{dist.name}_args', ()), kwds_fit_params=kwargs.get(f'{dist.name}_params', {}), statistic_method=statistic_method, test_result=None, ) #── Getter and setter with flexible keys (str or rv_continuous) ──────────────────────────────── def _resolve_key(self, key: str | rv_continuous) -> str: """Resolve a distribution name or ``rv_continuous`` object to its string key. Parameters ---------- key : str or rv_continuous Either the distribution's string name or its ``rv_continuous`` object (whose ``.name`` attribute is used). Returns ------- str The string key used in the internal ``_dist`` dictionary. Raises ------ KeyError If the resolved name is not found among the registered distributions. """ name = key.name if isinstance(key, rv_continuous) else key if name not in self._dist: available = ', '.join(self._dist) raise KeyError(f"Distribution '{name}' not found. Available: {available}") return name def __contains__(self, key: str | rv_continuous) -> bool: """Check whether a distribution is registered with this Fitter. Parameters ---------- key : str or rv_continuous Distribution name or ``rv_continuous`` object to look up. Returns ------- bool True if the distribution is registered, False otherwise. """ name = key.name if isinstance(key, rv_continuous) else key return name in self._dist def __getitem__(self, key: str | rv_continuous) -> DistributionSummary: """Retrieve the :class:`DistributionSummary` for the given distribution. Parameters ---------- key : str or rv_continuous Distribution name or ``rv_continuous`` object to retrieve. Returns ------- DistributionSummary The summary object associated with the requested distribution. Raises ------ KeyError If the distribution is not registered. """ return self._dist[self._resolve_key(key)] def __setitem__(self, key: str | rv_continuous, summary: DistributionSummary) -> None: """Override the :class:`DistributionSummary` for an existing distribution. Parameters ---------- key : str or rv_continuous Distribution name or ``rv_continuous`` object to update. summary : DistributionSummary Replacement summary object. Raises ------ TypeError If ``summary`` is not a :class:`DistributionSummary` instance. KeyError If the distribution is not registered. """ if not isinstance(summary, DistributionSummary): raise TypeError(f"Expected DistributionSummary, got {type(summary).__name__}.") self._dist[self._resolve_key(key)] = summary def fit(self, data: np.ndarray) -> None: """Fit every distribution to *data* via MLE. Parameters ---------- data : array-like Input data. Only the absolute value is used; the array is flattened before fitting. Returns ------- None Fitted parameters are stored in-place inside each :class:`DistributionSummary` held by this Fitter. """ data_flat = np.abs(data).flatten() self._last_data_flat = data_flat for dist in self.dist_list: _summary = self._dist[dist.name] fit_params = dist.fit(data_flat, *_summary.args_fit_params, **_summary.kwds_fit_params) _summary.fit_result_params = fit_params self._dist[dist.name] = _summary def validate(self, **kwargs) -> None: """Run the goodness-of-fit test on every previously fitted distribution. Parameters ---------- **kwargs : Extra keyword arguments forwarded to ``scipy.stats.goodness_of_fit`` (e.g. ``n_mc_samples=100``). Returns ------- None Test results are stored in-place inside each :class:`DistributionSummary` held by this Fitter. Raises ------ RuntimeError If :meth:`fit` has not been called before this method. """ if not hasattr(self, '_last_data_flat'): raise RuntimeError("No data available. Call fit() first.") data_flat = self._last_data_flat for dist in self.dist_list: _summary = self._dist[dist.name] test_result = goodness_of_fit( dist, data_flat, statistic=_summary.statistic_method, **kwargs ) _summary.test_result = test_result self._dist[dist.name] = _summary def summary(self) -> None: """Print a one-line summary for each registered distribution. Returns ------- None Output is written to stdout via ``print``. """ for dist_name, summary in self._dist.items(): print(summary) def plot_qq_plots(self) -> None: """Generate QQ plots for each fitted distribution against the data. A separate interactive Plotly figure is displayed for every distribution that has been both fitted and validated. Distributions that have not yet been validated are skipped with a printed warning. Returns ------- None Figures are rendered inline / in a browser via ``fig.show()``. Raises ------ RuntimeError If :meth:`fit` has not been called before this method. """ if not hasattr(self, '_last_data_flat'): raise RuntimeError("No data available. Call fit() first.") data_flat = self._last_data_flat for dist_name, summary in self._dist.items(): if summary.test_result is None: print(f"Distribution '{dist_name}' has not been validated yet. Skipping QQ plot.") continue sorted_data = np.sort(data_flat) theoretical_quantiles = summary.distribution_object.ppf( (np.arange(1, len(sorted_data) + 1) - 0.5) / len(sorted_data), *summary.fit_result_params ) fig = go.Figure() fig.add_trace(go.Scatter(x=theoretical_quantiles, y=sorted_data, mode='markers', name='Data vs. Fit')) fig.add_trace(go.Scatter(x=theoretical_quantiles, y=theoretical_quantiles, mode='lines', name='Ideal Fit', line=dict(dash='dash'))) fig.update_layout( title=f'QQ Plot for {summary.distribution_name}', xaxis_title='Theoretical Quantiles', yaxis_title='Empirical Quantiles', autosize=True, ) fig.show() def histogram_with_fits(self) -> go.Figure: """Return an interactive histogram with overlaid fitted PDFs (Plotly). Builds a probability-density histogram of the data and overlays a line trace for the PDF of each fitted distribution. Hover text shows the p-value and GoF statistic for each curve. Distributions that have not yet been fitted are skipped with a printed warning. Returns ------- plotly.graph_objects.Figure Interactive Plotly figure ready to display with ``fig.show()``. Raises ------ RuntimeError If :meth:`fit` has not been called before this method. """ if not hasattr(self, '_last_data_flat'): raise RuntimeError("No data available. Call fit() first.") data_flat = self._last_data_flat x = np.linspace(0, data_flat.max(), 1000) fig = go.Figure(layout=go.Layout(hovermode='x unified')) fig.add_trace(go.Histogram( x=data_flat, name='Data', opacity=0.3, histnorm='probability density', hoverinfo='skip', marker_color='blue', )) for dist_name, summary in self._dist.items(): if not summary.fit_result_params: print(f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay.") continue pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params) fig.add_trace(go.Scatter(x=x, y=pdf_values, mode='lines', name=f'{summary.distribution_name} Fit')) hover_text = [ f"{summary.distribution_name} p-value: {summary.pvalue:.4f} GoF: {summary.gof_statistic:.4f}" for _ in x ] fig.data[-1].update(hovertext=hover_text, hoverinfo="text") fig.update_layout( xaxis=dict(showgrid=True), yaxis=dict(showgrid=True), title=dict( text='Histogram of Data with Fitted Distributions', x=0.02, y=0.95, xanchor='left', yanchor='top', font=dict(size=20, color='darkgray', family='sans-serif'), ), xaxis_title='Value', yaxis_title='Density', autosize=True, ) return fig def histogram_with_fits_seaborn(self) -> plt.Figure: """Return a static histogram with overlaid fitted PDFs (Matplotlib/Seaborn). Builds a probability-density histogram using Seaborn and overlays a line for the PDF of each fitted distribution. The legend entry for each distribution includes its p-value. Distributions that have not yet been fitted are skipped with a printed warning. Returns ------- matplotlib.figure.Figure Matplotlib figure object. Raises ------ RuntimeError If :meth:`fit` has not been called before this method. """ if not hasattr(self, '_last_data_flat'): raise RuntimeError("No data available. Call fit() first.") data_flat = self._last_data_flat x = np.linspace(0, data_flat.max(), 1000) fig, ax = plt.subplots(figsize=(10, 6)) sns.histplot(data_flat, bins=int(np.sqrt(len(data_flat))), kde=False, stat='density', color='blue', alpha=0.2, ax=ax) for dist_name, summary in self._dist.items(): if not summary.fit_result_params: print(f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay.") continue pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params) ax.plot(x, pdf_values, label=f'{summary.distribution_name} --- p={summary.pvalue:.4f}') ax.set_title('Histogram of Data with Fitted Distributions', fontsize=8, loc='left', color='darkgray', fontfamily='sans-serif') ax.set_xlabel('Value') ax.set_ylabel('Density') ax.legend() ax.grid(True) return fig