Add three new continuous random variables for log-domain and linear-domain clutter modeling with compound Gamma-Rice structure. Fix numerical stability of k_dist._logpdf and logk._log_kve via a three-regime log(kve) asymptotic (direct / large-z Hankel / large-order Gamma); replace quad-based k_dist._cdf with Gauss-Laguerre quadrature. Fix fitter: use np.asarray instead of np.abs in fit(), pass fit_params to goodness_of_fit so the observed-data statistic reuses fitted params. Skip non-finite quantiles in QQ plots. Add plot_qq_plots_sns(); rename histogram_with_fits_seaborn() to histogram_with_fits_sns(). Add unit tests for logweibull and logricegamma.
545 lines
20 KiB
Python
545 lines
20 KiB
Python
from dataclasses import dataclass, field
|
||
|
||
import numpy as np
|
||
import plotly.graph_objects as go
|
||
import matplotlib.pyplot as plt
|
||
import seaborn as sns
|
||
from scipy.stats import rv_continuous, goodness_of_fit
|
||
from plotly.subplots import make_subplots
|
||
|
||
|
||
def set_plot_style() -> None:
    """Apply the seaborn look used by the matplotlib-based plots in this module.

    Selects the "whitegrid" style and the "paper" context with fonts scaled
    by 1.25.
    """
    sns.set_style(style="whitegrid")
    sns.set_context(context="paper", font_scale=1.25)
|
||
|
||
|
||
|
||
@dataclass
|
||
class DistributionSummary:
|
||
"""
|
||
Summary of a single distribution fit against a dataset.
|
||
|
||
Attributes
|
||
----------
|
||
distribution_object : rv_continuous
|
||
The scipy distribution object used for fitting.
|
||
distribution_name : str
|
||
Human-readable name of the distribution.
|
||
args_fit_params : tuple
|
||
Initial guess positional parameters passed to fit() (empty tuple if none).
|
||
kwds_fit_params : dict
|
||
Keyword arguments that were passed to fit() (fixed params, etc.).
|
||
fit_result_params : tuple
|
||
The actual fitted parameters returned by fit() (empty tuple until fit() is called).
|
||
statistic_method : object
|
||
GoF statistic identifier used in validate() (e.g. 'ad', 'ks' or a custom callable).
|
||
test_result : object
|
||
Result object from scipy.stats.goodness_of_fit; None until
|
||
validate() is called. Exposes .statistic and .pvalue.
|
||
|
||
Computed properties
|
||
-------------------
|
||
pvalue : float | None – p-value from the GoF test
|
||
gof_statistic : float | None – test statistic from the GoF test
|
||
mean : float – mean of the fitted distribution
|
||
std : float – standard deviation of the fitted distribution
|
||
var : float – variance of the fitted distribution
|
||
"""
|
||
|
||
distribution_object: rv_continuous
|
||
distribution_name: str
|
||
args_fit_params: tuple = field(default_factory=tuple)
|
||
kwds_fit_params: dict = field(default_factory=dict)
|
||
fit_result_params: tuple = field(default_factory=tuple)
|
||
statistic_method: object = "ad"
|
||
test_result: object = None
|
||
|
||
# ── convenience properties ────────────────────────────────────────────────
|
||
|
||
@property
|
||
def pvalue(self) -> float | None:
|
||
"""p-value from the goodness-of-fit test, or None if not yet run."""
|
||
return self.test_result.pvalue if self.test_result is not None else None
|
||
|
||
@property
|
||
def gof_statistic(self) -> float | None:
|
||
"""Test statistic from the goodness-of-fit test, or None if not yet run."""
|
||
return self.test_result.statistic if self.test_result is not None else None
|
||
|
||
@property
|
||
def mean(self) -> float:
|
||
"""Mean of the fitted distribution."""
|
||
return float(self.distribution_object.mean(*self.fit_result_params))
|
||
|
||
@property
|
||
def std(self) -> float:
|
||
"""Standard deviation of the fitted distribution."""
|
||
return float(self.distribution_object.std(*self.fit_result_params))
|
||
|
||
@property
|
||
def var(self) -> float:
|
||
"""Variance of the fitted distribution."""
|
||
return float(self.distribution_object.var(*self.fit_result_params))
|
||
|
||
def __repr__(self) -> str:
|
||
pval_str = f"{self.pvalue:.4f}" if self.pvalue is not None else "N/A"
|
||
stat_str = (
|
||
f"{self.gof_statistic:.4f}" if self.gof_statistic is not None else "N/A"
|
||
)
|
||
return (
|
||
f"DistributionSummary({self.distribution_name})"
|
||
f" params={self.fit_result_params}"
|
||
f" mean={self.mean:.4f} std={self.std:.4f}"
|
||
f" GoF[{self.statistic_method}]={stat_str} p={pval_str}"
|
||
)
|
||
|
||
def __str__(self) -> str:
|
||
return self.__repr__()
|
||
|
||
|
||
class Fitter:
|
||
"""
|
||
Fits and evaluates multiple distributions against a dataset.
|
||
|
||
Parameters
|
||
----------
|
||
dist_list : list[rv_continuous]
|
||
Distributions to fit.
|
||
statistic_method : str
|
||
Goodness-of-fit statistic passed to goodness_of_fit (default 'ad').
|
||
**kwargs :
|
||
Per-distribution initial guesses and fixed parameters, keyed as:
|
||
- ``<dist.name>_args`` : tuple of positional initial guesses
|
||
- ``<dist.name>_params`` : dict of keyword fixed parameters
|
||
Example:
|
||
Fitter(
|
||
[gamma, weibull_min],
|
||
gamma_args=(2.0,),
|
||
gamma_params={'floc': 0},
|
||
weibull_min_args=(1.5, 0.0, 1.0),
|
||
)
|
||
"""
|
||
|
||
def __init__(
|
||
self, dist_list: list[rv_continuous], statistic_method: str = "ad", **kwargs
|
||
):
|
||
self._dist: dict[str, DistributionSummary] = {}
|
||
self.dist_list = list(dist_list) # Ensure it's a list for multiple iterations
|
||
for dist in dist_list:
|
||
self._dist[dist.name] = DistributionSummary(
|
||
distribution_object=dist,
|
||
distribution_name=dist.name,
|
||
args_fit_params=kwargs.get(f"{dist.name}_args", ()),
|
||
kwds_fit_params=kwargs.get(f"{dist.name}_params", {}),
|
||
statistic_method=statistic_method,
|
||
test_result=None,
|
||
)
|
||
|
||
def _resolve_key(self, key: str | rv_continuous) -> str:
|
||
"""Resolve a distribution name or rv_continuous object to its string key."""
|
||
name = key.name if isinstance(key, rv_continuous) else key
|
||
if name not in self._dist:
|
||
available = ", ".join(self._dist)
|
||
raise KeyError(f"Distribution '{name}' not found. Available: {available}")
|
||
return name
|
||
|
||
def __contains__(self, key: str | rv_continuous) -> bool:
|
||
name = key.name if isinstance(key, rv_continuous) else key
|
||
return name in self._dist
|
||
|
||
def __getitem__(self, key: str | rv_continuous) -> DistributionSummary:
|
||
return self._dist[self._resolve_key(key)]
|
||
|
||
def __setitem__(
|
||
self, key: str | rv_continuous, summary: DistributionSummary
|
||
) -> None:
|
||
if not isinstance(summary, DistributionSummary):
|
||
raise TypeError(
|
||
f"Expected DistributionSummary, got {type(summary).__name__}."
|
||
)
|
||
self._dist[self._resolve_key(key)] = summary
|
||
|
||
def __iter__(self):
|
||
return self
|
||
|
||
def __next__(self):
|
||
if not hasattr(self, "_iter_index"):
|
||
self._iter_index = 0
|
||
if self._iter_index >= len(self._dist):
|
||
del self._iter_index
|
||
raise StopIteration
|
||
key = list(self._dist.keys())[self._iter_index]
|
||
summary = self._dist[key]
|
||
self._iter_index += 1
|
||
return key, summary
|
||
|
||
def fit(self, data: np.ndarray) -> dict[str, DistributionSummary]:
|
||
"""
|
||
Fit every distribution to *data* via MLE.
|
||
|
||
Parameters
|
||
----------
|
||
data : array-like
|
||
Input data. Only the absolute value is used.
|
||
Returns
|
||
-------
|
||
dict[str, DistributionSummary]
|
||
Mapping of distribution name → summary (test_result is None
|
||
until validate() is called).
|
||
"""
|
||
data_flat = np.asarray(data).flatten()
|
||
self._last_data_flat = data_flat
|
||
|
||
for dist in self.dist_list:
|
||
_summary = self._dist[dist.name]
|
||
fit_params = dist.fit(
|
||
data_flat, *_summary.args_fit_params, **_summary.kwds_fit_params
|
||
)
|
||
_summary.fit_result_params = fit_params
|
||
self._dist[dist.name] = _summary
|
||
|
||
def validate(self, **kwargs) -> dict[str, DistributionSummary]:
|
||
"""
|
||
Run the goodness-of-fit test on every previously fitted distribution.
|
||
|
||
Parameters
|
||
----------
|
||
**kwargs :
|
||
Extra keyword arguments forwarded to goodness_of_fit()
|
||
(e.g. n_mc_samples=100).
|
||
|
||
Returns
|
||
-------
|
||
dict[str, DistributionSummary]
|
||
Same results dict, with test_result populated for each distribution.
|
||
"""
|
||
if not hasattr(self, "_last_data_flat"):
|
||
raise RuntimeError("No data available. Call fit() first.")
|
||
|
||
data_flat = self._last_data_flat
|
||
for dist in self.dist_list:
|
||
_summary = self._dist[dist.name]
|
||
# Build fit_params dict so goodness_of_fit reuses the already-fitted
|
||
# parameters for the observed-data statistic instead of re-fitting.
|
||
shape_names = [s.strip() for s in dist.shapes.split(',')] if dist.shapes else []
|
||
all_param_names = shape_names + ['loc', 'scale']
|
||
fit_params_dict = dict(zip(all_param_names, _summary.fit_result_params))
|
||
test_result = goodness_of_fit(
|
||
dist, data_flat,
|
||
statistic=_summary.statistic_method,
|
||
fit_params=fit_params_dict,
|
||
**kwargs,
|
||
)
|
||
_summary.test_result = test_result
|
||
self._dist[dist.name] = _summary
|
||
|
||
def summary(self) -> dict[str, DistributionSummary]:
|
||
"""
|
||
Print a summary of all fitted distributions, including parameters and GoF results.
|
||
"""
|
||
for dist_name, summary in self._dist.items():
|
||
print(summary)
|
||
|
||
def plot_qq_plots(self, method: str = "hazen"):
|
||
"""
|
||
Generate QQ plots for each fitted distribution against the data.
|
||
Requires that fit() and validate() have been called to populate parameters and test results.
|
||
|
||
Parameters
|
||
----------
|
||
method : str
|
||
Plotting positions formula. Either 'hazen' (default) or 'filliben'.
|
||
- 'hazen' : p_i = (i - 0.5) / n
|
||
- 'filliben': p_1 = 1 - 0.5^(1/n), p_n = 0.5^(1/n),
|
||
p_i = (i - 0.3175) / (n + 0.365) for 1 < i < n
|
||
"""
|
||
if not hasattr(self, "_last_data_flat"):
|
||
raise RuntimeError("No data available. Call fit() first.")
|
||
if method not in ("hazen", "filliben"):
|
||
raise ValueError(f"method must be 'hazen' or 'filliben', got '{method}'.")
|
||
|
||
data_flat = self._last_data_flat
|
||
# generate subplots with 2 columns and as many rows as needed, but not more than 3 rows, if there are more than 6 distributions, create multiple figures
|
||
num_dists = len(self._dist)
|
||
num_cols = 2
|
||
num_rows = min(3, (num_dists + 1) // 2)
|
||
fig = make_subplots(
|
||
rows=num_rows,
|
||
cols=num_cols,
|
||
subplot_titles=[dist_name for dist_name in self._dist.keys()],
|
||
)
|
||
|
||
for dist_name, summary in self:
|
||
if summary.test_result is None:
|
||
print(
|
||
f"Distribution '{dist_name}' has not been validated yet. Skipping QQ plot."
|
||
)
|
||
continue
|
||
|
||
# Generate theoretical quantiles
|
||
sorted_data = np.sort(data_flat)
|
||
n = len(sorted_data)
|
||
i = np.arange(1, n + 1)
|
||
if method == "hazen":
|
||
plotting_positions = (i - 0.5) / n
|
||
else: # filliben
|
||
plotting_positions = (i - 0.3175) / (n + 0.365)
|
||
plotting_positions[0] = 1 - 0.5 ** (1 / n)
|
||
plotting_positions[-1] = 0.5 ** (1 / n)
|
||
theoretical_quantiles = summary.distribution_object.ppf(
|
||
plotting_positions, *summary.fit_result_params
|
||
)
|
||
|
||
# Drop NaN/inf quantiles that arise when ppf fails to converge
|
||
valid = np.isfinite(theoretical_quantiles)
|
||
if not valid.any():
|
||
print(f"Distribution '{dist_name}': all theoretical quantiles are non-finite. Skipping QQ plot.")
|
||
continue
|
||
theoretical_quantiles = theoretical_quantiles[valid]
|
||
sorted_data_plot = sorted_data[valid]
|
||
|
||
# Create QQ plot in each subplot
|
||
row = (list(self._dist.keys()).index(dist_name) // num_cols) + 1
|
||
col = (list(self._dist.keys()).index(dist_name) % num_cols) + 1
|
||
fig.add_trace(
|
||
go.Scatter(
|
||
x=theoretical_quantiles,
|
||
y=sorted_data_plot,
|
||
mode="markers",
|
||
name=dist_name,
|
||
),
|
||
row=row,
|
||
col=col,
|
||
)
|
||
# Add a reference line y=x
|
||
min_val = min(theoretical_quantiles.min(), sorted_data_plot.min())
|
||
max_val = max(theoretical_quantiles.max(), sorted_data_plot.max())
|
||
fig.add_trace(
|
||
go.Scatter(
|
||
x=[min_val, max_val],
|
||
y=[min_val, max_val],
|
||
mode="lines",
|
||
name="y=x",
|
||
line=dict(dash="dash"),
|
||
),
|
||
row=row,
|
||
col=col,
|
||
)
|
||
fig.update_xaxes(title_text="Theoretical Quantiles", row=row, col=col)
|
||
fig.update_yaxes(title_text="Empirical Quantiles", row=row, col=col)
|
||
# add statistic value in bottom right of each subplot (summary.gof_statistic())
|
||
fig.add_annotation(
|
||
x=0.95,
|
||
y=0.05,
|
||
xref="x domain",
|
||
yref="y domain",
|
||
text=f"{summary.statistic_method}={summary.gof_statistic:.4f}",
|
||
showarrow=False,
|
||
font=dict(size=10, color="green" if summary.pvalue > 0.05 else "red"),
|
||
row=row,
|
||
col=col,
|
||
)
|
||
fig.update_layout(
|
||
title=f"QQ Plots of Fitted Distributions ({method})",
|
||
showlegend=False,
|
||
autosize=True,
|
||
)
|
||
|
||
return fig
|
||
|
||
def plot_qq_plots_sns(self, method: str = "hazen"):
|
||
"""
|
||
Generate QQ plots for each fitted distribution using seaborn/matplotlib.
|
||
|
||
Parameters
|
||
----------
|
||
method : str
|
||
Plotting positions formula. Either 'hazen' (default) or 'filliben'.
|
||
"""
|
||
if not hasattr(self, "_last_data_flat"):
|
||
raise RuntimeError("No data available. Call fit() first.")
|
||
if method not in ("hazen", "filliben"):
|
||
raise ValueError(f"method must be 'hazen' or 'filliben', got '{method}'.")
|
||
|
||
set_plot_style()
|
||
|
||
dist_names = list(self._dist.keys())
|
||
num_dists = len(dist_names)
|
||
num_cols = 2
|
||
num_rows = (num_dists + 1) // 2
|
||
|
||
dot_color = sns.color_palette()[0]
|
||
|
||
fig, axes = plt.subplots(num_rows, num_cols, figsize=(6 * num_cols, 5 * num_rows))
|
||
axes = np.array(axes).flatten()
|
||
|
||
sorted_data = np.sort(self._last_data_flat)
|
||
n = len(sorted_data)
|
||
i = np.arange(1, n + 1)
|
||
if method == "hazen":
|
||
plotting_positions = (i - 0.5) / n
|
||
else:
|
||
plotting_positions = (i - 0.3175) / (n + 0.365)
|
||
plotting_positions[0] = 1 - 0.5 ** (1 / n)
|
||
plotting_positions[-1] = 0.5 ** (1 / n)
|
||
|
||
for idx, (dist_name, summary) in enumerate(self):
|
||
ax = axes[idx]
|
||
if summary.test_result is None:
|
||
ax.set_title(dist_name)
|
||
ax.text(0.5, 0.5, "Not validated", ha="center", va="center",
|
||
transform=ax.transAxes)
|
||
continue
|
||
|
||
theoretical_quantiles = summary.distribution_object.ppf(
|
||
plotting_positions, *summary.fit_result_params
|
||
)
|
||
valid = np.isfinite(theoretical_quantiles)
|
||
if not valid.any():
|
||
ax.set_title(dist_name)
|
||
ax.text(0.5, 0.5, "All quantiles non-finite", ha="center", va="center",
|
||
transform=ax.transAxes)
|
||
continue
|
||
|
||
tq = theoretical_quantiles[valid]
|
||
sd = sorted_data[valid]
|
||
|
||
ax.scatter(tq, sd, color=dot_color, s=8, alpha=0.6, linewidths=0)
|
||
ref_min = min(tq.min(), sd.min())
|
||
ref_max = max(tq.max(), sd.max())
|
||
ax.plot([ref_min, ref_max], [ref_min, ref_max], "k--", linewidth=1)
|
||
|
||
pvalue = summary.pvalue
|
||
stat_color = "green" if pvalue > 0.05 else "red"
|
||
ax.text(0.95, 0.05,
|
||
f"{summary.statistic_method}={summary.gof_statistic:.4f}",
|
||
transform=ax.transAxes, ha="right", va="bottom",
|
||
fontsize=9, color=stat_color)
|
||
|
||
ax.set_title(dist_name)
|
||
ax.set_xlabel("Theoretical Quantiles")
|
||
ax.set_ylabel("Empirical Quantiles")
|
||
|
||
# Hide unused axes
|
||
for idx in range(num_dists, len(axes)):
|
||
axes[idx].set_visible(False)
|
||
|
||
fig.suptitle(f"QQ Plots of Fitted Distributions ({method})", y=1.01)
|
||
plt.tight_layout()
|
||
return fig
|
||
|
||
def histogram_with_fits(self):
|
||
"""
|
||
Generate a histogram of the data with overlaid PDFs of each fitted distribution.
|
||
Requires that fit() has been called to populate parameters.
|
||
"""
|
||
if not hasattr(self, "_last_data_flat"):
|
||
raise RuntimeError("No data available. Call fit() first.")
|
||
|
||
data_flat = self._last_data_flat
|
||
x = np.linspace(0, data_flat.max(), 1000)
|
||
|
||
fig = go.Figure(layout=go.Layout(hovermode="x unified"))
|
||
# do not show data in hoover, only show the distribution name, p-value and GoF statistic for each distribution
|
||
fig.add_trace(
|
||
go.Histogram(
|
||
x=data_flat,
|
||
name="Data",
|
||
opacity=0.3,
|
||
histnorm="probability density",
|
||
hoverinfo="skip",
|
||
marker_color="blue",
|
||
)
|
||
)
|
||
for dist_name, summary in self._dist.items():
|
||
if not summary.fit_result_params:
|
||
print(
|
||
f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay."
|
||
)
|
||
continue
|
||
|
||
pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params)
|
||
# add trace and stack hoover x like stock price, but make the y value shows the p-value and GoF statistic for each distribution
|
||
fig.add_trace(
|
||
go.Scatter(
|
||
x=x,
|
||
y=pdf_values,
|
||
mode="lines",
|
||
name=f"{summary.distribution_name} Fit",
|
||
)
|
||
)
|
||
hover_text = [
|
||
f"<b>{summary.distribution_name}</b> p-value: {summary.pvalue:.4f} GoF: {summary.gof_statistic:.4f}"
|
||
for _ in x
|
||
]
|
||
fig.data[-1].update(hovertext=hover_text, hoverinfo="text")
|
||
|
||
# add grid
|
||
fig.update_layout(xaxis=dict(showgrid=True), yaxis=dict(showgrid=True))
|
||
# put title in top left, make it smaller, change it font to sans and put in light gray
|
||
fig.update_layout(
|
||
title=dict(
|
||
text="Histogram of Data with Fitted Distributions",
|
||
x=0.02,
|
||
y=0.95,
|
||
xanchor="left",
|
||
yanchor="top",
|
||
font=dict(size=20, color="darkgray", family="sans-serif"),
|
||
),
|
||
xaxis_title="Value",
|
||
yaxis_title="Density",
|
||
autosize=True,
|
||
)
|
||
|
||
return fig
|
||
|
||
def histogram_with_fits_sns(self):
|
||
"""
|
||
Generate a histogram of the data with overlaid PDFs of each fitted distribution using seaborn.
|
||
Requires that fit() has been called to populate parameters.
|
||
"""
|
||
if not hasattr(self, "_last_data_flat"):
|
||
raise RuntimeError("No data available. Call fit() first.")
|
||
|
||
data_flat = self._last_data_flat
|
||
x = np.linspace(0, data_flat.max(), 1000)
|
||
|
||
fig, ax = plt.subplots(figsize=(10, 6))
|
||
sns.histplot(
|
||
data_flat,
|
||
bins=int(np.sqrt(len(data_flat))),
|
||
kde=False,
|
||
stat="density",
|
||
color="blue",
|
||
alpha=0.2,
|
||
ax=ax,
|
||
)
|
||
|
||
for dist_name, summary in self._dist.items():
|
||
if not summary.fit_result_params:
|
||
print(
|
||
f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay."
|
||
)
|
||
continue
|
||
|
||
pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params)
|
||
ax.plot(
|
||
x,
|
||
pdf_values,
|
||
label=f"{summary.distribution_name}",
|
||
)
|
||
# put title in top left, make it smaller, change it font to sans and put in light gray
|
||
ax.set_title(
|
||
"Histogram of Data with Fitted Distributions",
|
||
fontsize=8,
|
||
loc="left",
|
||
color="darkgray",
|
||
fontfamily="sans-serif",
|
||
)
|
||
ax.set_xlabel("Value")
|
||
ax.set_ylabel("Density")
|
||
ax.legend()
|
||
ax.grid(True)
|
||
|
||
return fig
|