feat(distributions): add logweibull, ricegamma, and logricegamma
Add three new continuous random variables for log-domain and linear-domain clutter modeling with compound Gamma-Rice structure. Fix numerical stability of k_dist._logpdf and logk._log_kve via a three-regime log(kve) asymptotic (direct / large-z Hankel / large-order Gamma); replace quad-based k_dist._cdf with Gauss-Laguerre quadrature. Fix fitter: use np.asarray instead of np.abs in fit(), pass fit_params to goodness_of_fit so the observed-data statistic reuses fitted params. Skip non-finite quantiles in QQ plots. Add plot_qq_plots_sns(); rename histogram_with_fits_seaborn() to histogram_with_fits_sns(). Add unit tests for logweibull and logricegamma.
This commit is contained in:
@@ -8,6 +8,12 @@ from scipy.stats import rv_continuous, goodness_of_fit
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
|
||||
def set_plot_style():
|
||||
sns.set_style("whitegrid")
|
||||
sns.set_context("paper", font_scale=1.25)
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class DistributionSummary:
|
||||
"""
|
||||
@@ -181,7 +187,7 @@ class Fitter:
|
||||
Mapping of distribution name → summary (test_result is None
|
||||
until validate() is called).
|
||||
"""
|
||||
data_flat = np.abs(data).flatten()
|
||||
data_flat = np.asarray(data).flatten()
|
||||
self._last_data_flat = data_flat
|
||||
|
||||
for dist in self.dist_list:
|
||||
@@ -213,8 +219,16 @@ class Fitter:
|
||||
data_flat = self._last_data_flat
|
||||
for dist in self.dist_list:
|
||||
_summary = self._dist[dist.name]
|
||||
# Build fit_params dict so goodness_of_fit reuses the already-fitted
|
||||
# parameters for the observed-data statistic instead of re-fitting.
|
||||
shape_names = [s.strip() for s in dist.shapes.split(',')] if dist.shapes else []
|
||||
all_param_names = shape_names + ['loc', 'scale']
|
||||
fit_params_dict = dict(zip(all_param_names, _summary.fit_result_params))
|
||||
test_result = goodness_of_fit(
|
||||
dist, data_flat, statistic=_summary.statistic_method, **kwargs
|
||||
dist, data_flat,
|
||||
statistic=_summary.statistic_method,
|
||||
fit_params=fit_params_dict,
|
||||
**kwargs,
|
||||
)
|
||||
_summary.test_result = test_result
|
||||
self._dist[dist.name] = _summary
|
||||
@@ -276,13 +290,21 @@ class Fitter:
|
||||
plotting_positions, *summary.fit_result_params
|
||||
)
|
||||
|
||||
# Drop NaN/inf quantiles that arise when ppf fails to converge
|
||||
valid = np.isfinite(theoretical_quantiles)
|
||||
if not valid.any():
|
||||
print(f"Distribution '{dist_name}': all theoretical quantiles are non-finite. Skipping QQ plot.")
|
||||
continue
|
||||
theoretical_quantiles = theoretical_quantiles[valid]
|
||||
sorted_data_plot = sorted_data[valid]
|
||||
|
||||
# Create QQ plot in each subplot
|
||||
row = (list(self._dist.keys()).index(dist_name) // num_cols) + 1
|
||||
col = (list(self._dist.keys()).index(dist_name) % num_cols) + 1
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=theoretical_quantiles,
|
||||
y=sorted_data,
|
||||
y=sorted_data_plot,
|
||||
mode="markers",
|
||||
name=dist_name,
|
||||
),
|
||||
@@ -290,8 +312,8 @@ class Fitter:
|
||||
col=col,
|
||||
)
|
||||
# Add a reference line y=x
|
||||
min_val = min(theoretical_quantiles.min(), sorted_data.min())
|
||||
max_val = max(theoretical_quantiles.max(), sorted_data.max())
|
||||
min_val = min(theoretical_quantiles.min(), sorted_data_plot.min())
|
||||
max_val = max(theoretical_quantiles.max(), sorted_data_plot.max())
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[min_val, max_val],
|
||||
@@ -322,6 +344,88 @@ class Fitter:
|
||||
showlegend=False,
|
||||
autosize=True,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def plot_qq_plots_sns(self, method: str = "hazen"):
|
||||
"""
|
||||
Generate QQ plots for each fitted distribution using seaborn/matplotlib.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
method : str
|
||||
Plotting positions formula. Either 'hazen' (default) or 'filliben'.
|
||||
"""
|
||||
if not hasattr(self, "_last_data_flat"):
|
||||
raise RuntimeError("No data available. Call fit() first.")
|
||||
if method not in ("hazen", "filliben"):
|
||||
raise ValueError(f"method must be 'hazen' or 'filliben', got '{method}'.")
|
||||
|
||||
set_plot_style()
|
||||
|
||||
dist_names = list(self._dist.keys())
|
||||
num_dists = len(dist_names)
|
||||
num_cols = 2
|
||||
num_rows = (num_dists + 1) // 2
|
||||
|
||||
dot_color = sns.color_palette()[0]
|
||||
|
||||
fig, axes = plt.subplots(num_rows, num_cols, figsize=(6 * num_cols, 5 * num_rows))
|
||||
axes = np.array(axes).flatten()
|
||||
|
||||
sorted_data = np.sort(self._last_data_flat)
|
||||
n = len(sorted_data)
|
||||
i = np.arange(1, n + 1)
|
||||
if method == "hazen":
|
||||
plotting_positions = (i - 0.5) / n
|
||||
else:
|
||||
plotting_positions = (i - 0.3175) / (n + 0.365)
|
||||
plotting_positions[0] = 1 - 0.5 ** (1 / n)
|
||||
plotting_positions[-1] = 0.5 ** (1 / n)
|
||||
|
||||
for idx, (dist_name, summary) in enumerate(self):
|
||||
ax = axes[idx]
|
||||
if summary.test_result is None:
|
||||
ax.set_title(dist_name)
|
||||
ax.text(0.5, 0.5, "Not validated", ha="center", va="center",
|
||||
transform=ax.transAxes)
|
||||
continue
|
||||
|
||||
theoretical_quantiles = summary.distribution_object.ppf(
|
||||
plotting_positions, *summary.fit_result_params
|
||||
)
|
||||
valid = np.isfinite(theoretical_quantiles)
|
||||
if not valid.any():
|
||||
ax.set_title(dist_name)
|
||||
ax.text(0.5, 0.5, "All quantiles non-finite", ha="center", va="center",
|
||||
transform=ax.transAxes)
|
||||
continue
|
||||
|
||||
tq = theoretical_quantiles[valid]
|
||||
sd = sorted_data[valid]
|
||||
|
||||
ax.scatter(tq, sd, color=dot_color, s=8, alpha=0.6, linewidths=0)
|
||||
ref_min = min(tq.min(), sd.min())
|
||||
ref_max = max(tq.max(), sd.max())
|
||||
ax.plot([ref_min, ref_max], [ref_min, ref_max], "k--", linewidth=1)
|
||||
|
||||
pvalue = summary.pvalue
|
||||
stat_color = "green" if pvalue > 0.05 else "red"
|
||||
ax.text(0.95, 0.05,
|
||||
f"{summary.statistic_method}={summary.gof_statistic:.4f}",
|
||||
transform=ax.transAxes, ha="right", va="bottom",
|
||||
fontsize=9, color=stat_color)
|
||||
|
||||
ax.set_title(dist_name)
|
||||
ax.set_xlabel("Theoretical Quantiles")
|
||||
ax.set_ylabel("Empirical Quantiles")
|
||||
|
||||
# Hide unused axes
|
||||
for idx in range(num_dists, len(axes)):
|
||||
axes[idx].set_visible(False)
|
||||
|
||||
fig.suptitle(f"QQ Plots of Fitted Distributions ({method})", y=1.01)
|
||||
plt.tight_layout()
|
||||
return fig
|
||||
|
||||
def histogram_with_fits(self):
|
||||
@@ -389,7 +493,7 @@ class Fitter:
|
||||
|
||||
return fig
|
||||
|
||||
def histogram_with_fits_seaborn(self):
|
||||
def histogram_with_fits_sns(self):
|
||||
"""
|
||||
Generate a histogram of the data with overlaid PDFs of each fitted distribution using seaborn.
|
||||
Requires that fit() has been called to populate parameters.
|
||||
@@ -422,7 +526,7 @@ class Fitter:
|
||||
ax.plot(
|
||||
x,
|
||||
pdf_values,
|
||||
label=f"{summary.distribution_name} --- p={summary.pvalue:.4f}",
|
||||
label=f"{summary.distribution_name}",
|
||||
)
|
||||
# put title in top left, make it smaller, change it font to sans and put in light gray
|
||||
ax.set_title(
|
||||
|
||||
Reference in New Issue
Block a user