from dataclasses import dataclass, field

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import seaborn as sns
from plotly.subplots import make_subplots
from scipy.stats import goodness_of_fit, rv_continuous


def _format_optional(value: float | None, spec: str = ".4f") -> str:
    """Format *value* with the format *spec*, or return ``"N/A"`` when it is None."""
    return "N/A" if value is None else format(value, spec)


@dataclass
class DistributionSummary:
    """
    Summary of a single distribution fit against a dataset.

    Attributes
    ----------
    distribution_object : rv_continuous
        The scipy distribution object used for fitting.
    distribution_name : str
        Human-readable name of the distribution.
    args_fit_params : tuple
        Initial guess positional parameters passed to fit() (empty tuple if none).
    kwds_fit_params : dict
        Keyword arguments that were passed to fit() (fixed params, etc.).
    fit_result_params : tuple
        The actual fitted parameters returned by fit() (empty tuple until fit()
        is called).
    statistic_method : str
        GoF statistic identifier used in validate() (e.g. 'ad', 'ks').
    test_result : object
        Result object from scipy.stats.goodness_of_fit; None until validate()
        is called. Exposes .statistic and .pvalue.

    Computed properties
    -------------------
    is_fitted : bool – True once fit_result_params has been populated
    pvalue : float | None – p-value from the GoF test
    gof_statistic : float | None – test statistic from the GoF test
    mean : float – mean of the fitted distribution
    std : float – standard deviation of the fitted distribution
    var : float – variance of the fitted distribution
    """

    distribution_object: rv_continuous
    distribution_name: str
    args_fit_params: tuple = field(default_factory=tuple)
    kwds_fit_params: dict = field(default_factory=dict)
    fit_result_params: tuple = field(default_factory=tuple)
    statistic_method: str = "ad"
    test_result: object = None

    # ── convenience properties ────────────────────────────────────────────────

    @property
    def is_fitted(self) -> bool:
        """True once fit() has stored parameters for this distribution."""
        return bool(self.fit_result_params)

    @property
    def pvalue(self) -> float | None:
        """p-value from the goodness-of-fit test, or None if not yet run."""
        return self.test_result.pvalue if self.test_result is not None else None

    @property
    def gof_statistic(self) -> float | None:
        """Test statistic from the goodness-of-fit test, or None if not yet run."""
        return self.test_result.statistic if self.test_result is not None else None

    @property
    def mean(self) -> float:
        """Mean of the fitted distribution."""
        return float(self.distribution_object.mean(*self.fit_result_params))

    @property
    def std(self) -> float:
        """Standard deviation of the fitted distribution."""
        return float(self.distribution_object.std(*self.fit_result_params))

    @property
    def var(self) -> float:
        """Variance of the fitted distribution."""
        return float(self.distribution_object.var(*self.fit_result_params))

    def __repr__(self) -> str:
        if self.is_fitted:
            moments = f" mean={self.mean:.4f} std={self.std:.4f}"
        else:
            # Before fit() there are no parameters: shape-parameterised
            # distributions (e.g. gamma) raise TypeError when asked for moments
            # without them, and loc/scale-only ones would report the standard
            # form's moments, which say nothing about the data.
            moments = " mean=N/A std=N/A"
        return (
            f"DistributionSummary({self.distribution_name})"
            f" params={self.fit_result_params}"
            f"{moments}"
            f" GoF[{self.statistic_method}]={_format_optional(self.gof_statistic)}"
            f" p={_format_optional(self.pvalue)}"
        )

    def __str__(self) -> str:
        return self.__repr__()


class Fitter:
    """
    Fits and evaluates multiple distributions against a dataset.

    Parameters
    ----------
    dist_list : list[rv_continuous]
        Distributions to fit. Any iterable is accepted; it is materialised
        into a list once.
    statistic_method : str
        Goodness-of-fit statistic passed to goodness_of_fit (default 'ad').
    **kwargs :
        Per-distribution initial guesses and fixed parameters, keyed as:
          - ``<dist.name>_args``   : tuple of positional initial guesses
          - ``<dist.name>_params`` : dict of keyword fixed parameters

    Example:
        Fitter(
            [gamma, weibull_min],
            gamma_args=(2.0,),
            gamma_params={'floc': 0},
            weibull_min_args=(1.5, 0.0, 1.0),
        )
    """

    def __init__(
        self, dist_list: list[rv_continuous], statistic_method: str = "ad", **kwargs
    ):
        # Materialise once and iterate the stored list: iterating the original
        # argument again would see nothing if it was a one-shot iterable.
        self.dist_list = list(dist_list)
        self._dist: dict[str, DistributionSummary] = {}
        for dist in self.dist_list:
            self._dist[dist.name] = DistributionSummary(
                distribution_object=dist,
                distribution_name=dist.name,
                args_fit_params=kwargs.get(f"{dist.name}_args", ()),
                kwds_fit_params=kwargs.get(f"{dist.name}_params", {}),
                statistic_method=statistic_method,
                test_result=None,
            )

    # ── mapping-style access ──────────────────────────────────────────────────

    @staticmethod
    def _key_name(key: str | rv_continuous) -> str:
        """Return the string name for a distribution name or rv_continuous object."""
        return key.name if isinstance(key, rv_continuous) else key

    def _resolve_key(self, key: str | rv_continuous) -> str:
        """Resolve a distribution name or rv_continuous object to its string key.

        Raises
        ------
        KeyError
            If no distribution with that name was registered in __init__.
        """
        name = self._key_name(key)
        if name not in self._dist:
            available = ", ".join(self._dist)
            raise KeyError(f"Distribution '{name}' not found. Available: {available}")
        return name

    def __contains__(self, key: str | rv_continuous) -> bool:
        return self._key_name(key) in self._dist

    def __getitem__(self, key: str | rv_continuous) -> DistributionSummary:
        return self._dist[self._resolve_key(key)]

    def __setitem__(
        self, key: str | rv_continuous, summary: DistributionSummary
    ) -> None:
        """Replace the summary of an already-registered distribution.

        Raises
        ------
        TypeError
            If *summary* is not a DistributionSummary.
        KeyError
            If *key* does not name a registered distribution.
        """
        if not isinstance(summary, DistributionSummary):
            raise TypeError(
                f"Expected DistributionSummary, got {type(summary).__name__}."
            )
        self._dist[self._resolve_key(key)] = summary

    def __iter__(self):
        """Iterate over ``(name, summary)`` pairs in registration order.

        Each call returns a fresh, independent iterator, so nested loops and
        loops abandoned with ``break`` do not affect later iterations.
        """
        return iter(self._dist.items())

    def __next__(self) -> tuple[str, DistributionSummary]:
        """Legacy stateful iteration, kept so direct ``next(fitter)`` calls still work.

        ``for`` loops use the independent iterator returned by ``__iter__``.
        """
        if not hasattr(self, "_iter_index"):
            self._iter_index = 0
        if self._iter_index >= len(self._dist):
            del self._iter_index
            raise StopIteration
        key = list(self._dist)[self._iter_index]
        self._iter_index += 1
        return key, self._dist[key]

    # ── fitting and validation ────────────────────────────────────────────────

    def _require_data(self) -> np.ndarray:
        """Return the flattened data cached by fit(); raise if fit() has not run."""
        if not hasattr(self, "_last_data_flat"):
            raise RuntimeError("No data available. Call fit() first.")
        return self._last_data_flat

    def fit(self, data: np.ndarray) -> dict[str, DistributionSummary]:
        """
        Fit every distribution to *data* via MLE.

        Parameters
        ----------
        data : array-like
            Input data, flattened to 1-D. Only the absolute value is used.

        Returns
        -------
        dict[str, DistributionSummary]
            Mapping of distribution name → summary. test_result is reset to
            None, because any previous GoF result refers to the old data; call
            validate() again.

        Raises
        ------
        ValueError
            If *data* contains no values.
        """
        data_flat = np.abs(data).flatten()
        if data_flat.size == 0:
            raise ValueError("Cannot fit distributions to empty data.")
        self._last_data_flat = data_flat
        for dist in self.dist_list:
            summary = self._dist[dist.name]
            summary.fit_result_params = dist.fit(
                data_flat, *summary.args_fit_params, **summary.kwds_fit_params
            )
            summary.test_result = None
        return dict(self._dist)

    def validate(self, **kwargs) -> dict[str, DistributionSummary]:
        """
        Run the goodness-of-fit test on every distribution against the data
        cached by the last fit() call.

        Parameters
        ----------
        **kwargs :
            Extra keyword arguments forwarded to goodness_of_fit() for every
            distribution (e.g. n_mc_samples=100).

        Returns
        -------
        dict[str, DistributionSummary]
            Mapping of distribution name → summary, with test_result populated.

        Raises
        ------
        RuntimeError
            If fit() has not been called.
        """
        data_flat = self._require_data()
        for dist in self.dist_list:
            summary = self._dist[dist.name]
            # NOTE(review): kwds_fit_params (e.g. floc=0) are not forwarded, so
            # goodness_of_fit estimates every parameter not given through its
            # known_params / fit_params arguments. Confirm this is intended.
            summary.test_result = goodness_of_fit(
                dist, data_flat, statistic=summary.statistic_method, **kwargs
            )
        return dict(self._dist)

    def summary(self) -> dict[str, DistributionSummary]:
        """
        Print one line per distribution (parameters, moments and GoF results).

        Returns
        -------
        dict[str, DistributionSummary]
            Mapping of distribution name → summary that was printed.
        """
        for entry in self._dist.values():
            print(entry)
        return dict(self._dist)

    # ── plotting ──────────────────────────────────────────────────────────────

    @staticmethod
    def _plotting_positions(n: int, method: str) -> np.ndarray:
        """Return the n plotting positions for the given formula ('hazen'/'filliben')."""
        ranks = np.arange(1, n + 1)
        if method == "hazen":
            return (ranks - 0.5) / n
        # filliben: interior formula, with dedicated end-point values.
        positions = (ranks - 0.3175) / (n + 0.365)
        positions[0] = 1 - 0.5 ** (1 / n)
        positions[-1] = 0.5 ** (1 / n)
        return positions

    def plot_qq_plots(self, method: str = "hazen", alpha: float = 0.05) -> go.Figure:
        """
        Generate QQ plots for each fitted distribution against the data.

        Requires that fit() and validate() have been called to populate
        parameters and test results. Distributions without a test result are
        skipped with a printed notice.

        Parameters
        ----------
        method : str
            Plotting positions formula. Either 'hazen' (default) or 'filliben'.
              - 'hazen'   : p_i = (i - 0.5) / n
              - 'filliben': p_1 = 1 - 0.5^(1/n), p_n = 0.5^(1/n),
                            p_i = (i - 0.3175) / (n + 0.365) for 1 < i < n
        alpha : float
            Significance level for the statistic annotation: it is drawn green
            when the p-value exceeds *alpha*, red otherwise (default 0.05).

        Returns
        -------
        plotly.graph_objects.Figure
            Two-column grid with one subplot per registered distribution.

        Raises
        ------
        RuntimeError
            If fit() has not been called.
        ValueError
            If *method* is not 'hazen' or 'filliben'.
        """
        data_flat = self._require_data()
        if method not in ("hazen", "filliben"):
            raise ValueError(f"method must be 'hazen' or 'filliben', got '{method}'.")

        # The empirical side is the same for every distribution: compute once.
        sorted_data = np.sort(data_flat)
        plotting_positions = self._plotting_positions(sorted_data.size, method)

        # Two columns and as many rows as needed (ceil division). At least one
        # row, so make_subplots is valid even with no distributions.
        num_cols = 2
        num_rows = max(1, -(-len(self._dist) // num_cols))
        fig = make_subplots(
            rows=num_rows,
            cols=num_cols,
            subplot_titles=list(self._dist),
        )

        for index, (dist_name, summary) in enumerate(self._dist.items()):
            if summary.test_result is None:
                print(
                    f"Distribution '{dist_name}' has not been validated yet. Skipping QQ plot."
                )
                continue

            row, col = (index // num_cols) + 1, (index % num_cols) + 1
            theoretical_quantiles = summary.distribution_object.ppf(
                plotting_positions, *summary.fit_result_params
            )
            fig.add_trace(
                go.Scatter(
                    x=theoretical_quantiles,
                    y=sorted_data,
                    mode="markers",
                    name=dist_name,
                ),
                row=row,
                col=col,
            )

            # Reference line y = x spanning both axes' ranges.
            min_val = min(theoretical_quantiles.min(), sorted_data.min())
            max_val = max(theoretical_quantiles.max(), sorted_data.max())
            fig.add_trace(
                go.Scatter(
                    x=[min_val, max_val],
                    y=[min_val, max_val],
                    mode="lines",
                    name="y=x",
                    line=dict(dash="dash"),
                ),
                row=row,
                col=col,
            )
            fig.update_xaxes(title_text="Theoretical Quantiles", row=row, col=col)
            fig.update_yaxes(title_text="Empirical Quantiles", row=row, col=col)

            # GoF statistic in the bottom-right corner of the subplot.
            fig.add_annotation(
                x=0.95,
                y=0.05,
                xref="x domain",
                yref="y domain",
                text=f"{summary.statistic_method}={summary.gof_statistic:.4f}",
                showarrow=False,
                font=dict(size=10, color="green" if summary.pvalue > alpha else "red"),
                row=row,
                col=col,
            )

        fig.update_layout(
            title=f"QQ Plots of Fitted Distributions ({method})",
            showlegend=False,
            autosize=True,
        )
        if num_rows > 3:
            # Give extra rows room instead of squeezing them into the default height.
            fig.update_layout(height=300 * num_rows)
        return fig

    def histogram_with_fits(self) -> go.Figure:
        """
        Generate a plotly histogram of the data with the PDF of each fitted
        distribution overlaid.

        Requires that fit() has been called; validate() is optional (missing
        GoF values are shown as N/A in the hover text).

        Raises
        ------
        RuntimeError
            If fit() has not been called.
        """
        data_flat = self._require_data()
        x = np.linspace(0, data_flat.max(), 1000)

        fig = go.Figure(layout=go.Layout(hovermode="x unified"))
        # The histogram is excluded from hover; the hover box only lists the
        # fitted curves with their p-value and GoF statistic.
        fig.add_trace(
            go.Histogram(
                x=data_flat,
                name="Data",
                opacity=0.3,
                histnorm="probability density",
                hoverinfo="skip",
                marker_color="blue",
            )
        )

        for dist_name, summary in self._dist.items():
            if not summary.is_fitted:
                print(
                    f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay."
                )
                continue
            pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params)
            # Every point of the curve carries the same label, so the unified
            # hover box shows each distribution's GoF results at any x.
            label = (
                f"{summary.distribution_name}"
                f" p-value: {_format_optional(summary.pvalue)}"
                f" GoF: {_format_optional(summary.gof_statistic)}"
            )
            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=pdf_values,
                    mode="lines",
                    name=f"{summary.distribution_name} Fit",
                    hovertext=[label] * len(x),
                    hoverinfo="text",
                )
            )

        fig.update_layout(xaxis=dict(showgrid=True), yaxis=dict(showgrid=True))
        # Small, light-gray, sans-serif title anchored at the top left.
        fig.update_layout(
            title=dict(
                text="Histogram of Data with Fitted Distributions",
                x=0.02,
                y=0.95,
                xanchor="left",
                yanchor="top",
                font=dict(size=20, color="darkgray", family="sans-serif"),
            ),
            xaxis_title="Value",
            yaxis_title="Density",
            autosize=True,
        )
        return fig

    def histogram_with_fits_seaborn(self):
        """
        Generate a seaborn/matplotlib histogram of the data with the PDF of
        each fitted distribution overlaid.

        Requires that fit() has been called; validate() is optional (a missing
        p-value is shown as N/A in the legend).

        Returns
        -------
        matplotlib.figure.Figure
            The figure containing the single axes.

        Raises
        ------
        RuntimeError
            If fit() has not been called.
        """
        data_flat = self._require_data()
        x = np.linspace(0, data_flat.max(), 1000)

        fig, ax = plt.subplots(figsize=(10, 6))
        sns.histplot(
            data_flat,
            bins=int(np.sqrt(len(data_flat))),
            kde=False,
            stat="density",
            color="blue",
            alpha=0.2,
            ax=ax,
        )

        for dist_name, summary in self._dist.items():
            if not summary.is_fitted:
                print(
                    f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay."
                )
                continue
            pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params)
            ax.plot(
                x,
                pdf_values,
                label=f"{summary.distribution_name} --- p={_format_optional(summary.pvalue)}",
            )

        # Small, light-gray, sans-serif title aligned to the left.
        ax.set_title(
            "Histogram of Data with Fitted Distributions",
            fontsize=8,
            loc="left",
            color="darkgray",
            fontfamily="sans-serif",
        )
        ax.set_xlabel("Value")
        ax.set_ylabel("Density")
        ax.legend()
        ax.grid(True)
        return fig