REFACTOR:

Refactored the Fitter class; it now includes an item getter and setter.
ADD:
etc/tests/ directory with unit tests for the code (test_fitter.py)
This commit is contained in:
2026-04-08 21:48:19 -03:00
parent bcd8f25a62
commit d053ebf02c
5 changed files with 529 additions and 291 deletions

View File

@@ -5,6 +5,7 @@ import plotly.graph_objects as go
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sns import seaborn as sns
from scipy.stats import rv_continuous, goodness_of_fit from scipy.stats import rv_continuous, goodness_of_fit
from plotly.subplots import make_subplots
@dataclass @dataclass
@@ -44,79 +45,41 @@ class DistributionSummary:
args_fit_params: tuple = field(default_factory=tuple) args_fit_params: tuple = field(default_factory=tuple)
kwds_fit_params: dict = field(default_factory=dict) kwds_fit_params: dict = field(default_factory=dict)
fit_result_params: tuple = field(default_factory=tuple) fit_result_params: tuple = field(default_factory=tuple)
statistic_method: str = 'ad' statistic_method: str = "ad"
test_result: object = None test_result: object = None
# ── properties ──────────────────────────────────────────────── # ── convenience properties ────────────────────────────────────────────────
@property @property
def pvalue(self) -> float | None: def pvalue(self) -> float | None:
"""p-value from the goodness-of-fit test. """p-value from the goodness-of-fit test, or None if not yet run."""
Returns
-------
float or None
The p-value produced by the GoF test, or None if validate() has
not been called yet.
"""
return self.test_result.pvalue if self.test_result is not None else None return self.test_result.pvalue if self.test_result is not None else None
@property @property
def gof_statistic(self) -> float | None: def gof_statistic(self) -> float | None:
"""Test statistic from the goodness-of-fit test. """Test statistic from the goodness-of-fit test, or None if not yet run."""
Returns
-------
float or None
The GoF statistic value, or None if validate() has not been
called yet.
"""
return self.test_result.statistic if self.test_result is not None else None return self.test_result.statistic if self.test_result is not None else None
@property
def mean(self) -> float:
    """Return the mean of the fitted distribution as a Python float.

    Returns
    -------
    float
        ``distribution_object.mean`` evaluated at ``fit_result_params``.
    """
    params = self.fit_result_params
    value = self.distribution_object.mean(*params)
    return float(value)
@property
def std(self) -> float:
    """Return the standard deviation of the fitted distribution as a float.

    Returns
    -------
    float
        ``distribution_object.std`` evaluated at ``fit_result_params``.
    """
    params = self.fit_result_params
    value = self.distribution_object.std(*params)
    return float(value)
@property
def var(self) -> float:
    """Return the variance of the fitted distribution as a Python float.

    Returns
    -------
    float
        ``distribution_object.var`` evaluated at ``fit_result_params``.
    """
    params = self.fit_result_params
    value = self.distribution_object.var(*params)
    return float(value)
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return a concise string representation of the distribution summary.
Returns
-------
str
Single-line string showing the distribution name, fitted
parameters, mean, standard deviation, GoF statistic, and p-value.
"""
pval_str = f"{self.pvalue:.4f}" if self.pvalue is not None else "N/A" pval_str = f"{self.pvalue:.4f}" if self.pvalue is not None else "N/A"
stat_str = f"{self.gof_statistic:.4f}" if self.gof_statistic is not None else "N/A" stat_str = (
f"{self.gof_statistic:.4f}" if self.gof_statistic is not None else "N/A"
)
return ( return (
f"DistributionSummary({self.distribution_name})" f"DistributionSummary({self.distribution_name})"
f" params={self.fit_result_params}" f" params={self.fit_result_params}"
@@ -125,13 +88,6 @@ class DistributionSummary:
) )
def __str__(self) -> str:
    """Return the same text as ``__repr__``.

    Returns
    -------
    str
        Whatever ``self.__repr__()`` produces.
    """
    text = self.__repr__()
    return text
@@ -156,347 +112,329 @@ class Fitter:
gamma_params={'floc': 0}, gamma_params={'floc': 0},
weibull_min_args=(1.5, 0.0, 1.0), weibull_min_args=(1.5, 0.0, 1.0),
) )
Access fitted distributions
---------------------------
Distributions can be accessed by name (str) or by the rv_continuous object:
fitter = Fitter([gamma, weibull_min])
fitter.fit(data)
# Access by name
summary = fitter['gamma']
# Access by rv_continuous object
summary = fitter[gamma]
# Check if a distribution is present
gamma in fitter # True
'weibull_min' in fitter # True
# Override a DistributionSummary
fitter['gamma'] = new_summary
""" """
def __init__(
    self, dist_list: list[rv_continuous], statistic_method: str = "ad", **kwargs
):
    """Initialise the Fitter and build one summary object per distribution.

    Parameters
    ----------
    dist_list : list[rv_continuous]
        Distributions to fit. Any iterable is accepted; it is copied into
        a list exactly once.
    statistic_method : str, optional
        Goodness-of-fit statistic stored on every summary (default
        ``"ad"``).
    **kwargs :
        Per-distribution fit inputs, keyed as ``<dist.name>_args`` (tuple
        of positional arguments) or ``<dist.name>_params`` (dict of
        keyword arguments). Missing keys default to ``()`` and ``{}``.
    """
    self._dist: dict[str, "DistributionSummary"] = {}
    # Materialise once and iterate the copy: looping over the original
    # argument after list() would see nothing if it were a generator.
    self.dist_list = list(dist_list)
    for dist in self.dist_list:
        self._dist[dist.name] = DistributionSummary(
            distribution_object=dist,
            distribution_name=dist.name,
            args_fit_params=kwargs.get(f"{dist.name}_args", ()),
            kwds_fit_params=kwargs.get(f"{dist.name}_params", {}),
            statistic_method=statistic_method,
            test_result=None,
        )
#── Getter and setter with flexible keys (str or rv_continuous) ────────────────────────────────
def _resolve_key(self, key: str | rv_continuous) -> str: def _resolve_key(self, key: str | rv_continuous) -> str:
"""Resolve a distribution name or ``rv_continuous`` object to its string key. """Resolve a distribution name or rv_continuous object to its string key."""
Parameters
----------
key : str or rv_continuous
Either the distribution's string name or its ``rv_continuous``
object (whose ``.name`` attribute is used).
Returns
-------
str
The string key used in the internal ``_dist`` dictionary.
Raises
------
KeyError
If the resolved name is not found among the registered
distributions.
"""
name = key.name if isinstance(key, rv_continuous) else key name = key.name if isinstance(key, rv_continuous) else key
if name not in self._dist: if name not in self._dist:
available = ', '.join(self._dist) available = ", ".join(self._dist)
raise KeyError(f"Distribution '{name}' not found. Available: {available}") raise KeyError(f"Distribution '{name}' not found. Available: {available}")
return name return name
def __contains__(self, key: str | rv_continuous) -> bool: def __contains__(self, key: str | rv_continuous) -> bool:
"""Check whether a distribution is registered with this Fitter.
Parameters
----------
key : str or rv_continuous
Distribution name or ``rv_continuous`` object to look up.
Returns
-------
bool
True if the distribution is registered, False otherwise.
"""
name = key.name if isinstance(key, rv_continuous) else key name = key.name if isinstance(key, rv_continuous) else key
return name in self._dist return name in self._dist
def __getitem__(self, key: str | rv_continuous) -> DistributionSummary: def __getitem__(self, key: str | rv_continuous) -> DistributionSummary:
"""Retrieve the :class:`DistributionSummary` for the given distribution.
Parameters
----------
key : str or rv_continuous
Distribution name or ``rv_continuous`` object to retrieve.
Returns
-------
DistributionSummary
The summary object associated with the requested distribution.
Raises
------
KeyError
If the distribution is not registered.
"""
return self._dist[self._resolve_key(key)] return self._dist[self._resolve_key(key)]
def __setitem__(
    self, key: str | rv_continuous, summary: "DistributionSummary"
) -> None:
    """Replace the summary stored for an already registered distribution.

    Parameters
    ----------
    key : str or rv_continuous
        Distribution name or ``rv_continuous`` instance.
    summary : DistributionSummary
        Replacement summary object.

    Raises
    ------
    TypeError
        If ``summary`` is not a ``DistributionSummary`` (checked before the
        key is resolved).
    KeyError
        Propagated from ``_resolve_key`` when the distribution is unknown.
    """
    if isinstance(summary, DistributionSummary):
        name = self._resolve_key(key)
        self._dist[name] = summary
        return
    raise TypeError(f"Expected DistributionSummary, got {type(summary).__name__}.")
def __iter__(self):
    """Return a fresh iterator over ``(name, summary)`` pairs.

    Each call yields an independent iterator over ``_dist.items()``, so a
    loop that exits early does not affect the next loop and nested loops
    over the same Fitter do not interfere with each other.

    Returns
    -------
    iterator of tuple
        ``(name, summary)`` pairs in the insertion order of ``_dist``.
    """
    return iter(self._dist.items())

def __next__(self):
    """Advance the legacy instance-level cursor and return the next pair.

    Kept so that direct ``next(fitter)`` calls keep working; ``for`` loops
    no longer use it because ``__iter__`` returns its own iterator. The
    cursor lives in ``_iter_index`` on the instance and is removed again
    once the end is reached.

    Returns
    -------
    tuple
        ``(name, summary)`` at the current cursor position.

    Raises
    ------
    StopIteration
        When the cursor has passed the last registered distribution.
    """
    cursor = getattr(self, "_iter_index", 0)
    entries = list(self._dist.items())
    if cursor >= len(entries):
        # Reset so that a later next() sequence starts from the beginning.
        if hasattr(self, "_iter_index"):
            del self._iter_index
        raise StopIteration
    self._iter_index = cursor + 1
    return entries[cursor]
def fit(self, data: np.ndarray) -> "dict[str, DistributionSummary]":
    """Fit every registered distribution to *data*.

    Parameters
    ----------
    data : array-like
        Input data. The absolute value is taken and the result is
        flattened to one dimension before fitting; the flattened array is
        kept on the instance as ``_last_data_flat``.

    Returns
    -------
    dict[str, DistributionSummary]
        The Fitter's own mapping of distribution name to summary, with
        ``fit_result_params`` updated in place (``test_result`` is left
        untouched).
    """
    data_flat = np.abs(data).flatten()
    self._last_data_flat = data_flat

    for dist in self.dist_list:
        entry = self._dist[dist.name]
        # The summary is mutated in place, so no re-assignment into _dist
        # is needed.
        entry.fit_result_params = dist.fit(
            data_flat, *entry.args_fit_params, **entry.kwds_fit_params
        )
    return self._dist
def validate(self, **kwargs) -> "dict[str, DistributionSummary]":
    """Run the goodness-of-fit test for every registered distribution.

    Parameters
    ----------
    **kwargs :
        Extra keyword arguments forwarded to ``goodness_of_fit``
        (e.g. ``n_mc_samples=100``).

    Returns
    -------
    dict[str, DistributionSummary]
        The Fitter's own mapping of distribution name to summary, with
        ``test_result`` populated in place.

    Raises
    ------
    RuntimeError
        If ``fit()`` has not stored any data on this instance yet.
    """
    if not hasattr(self, "_last_data_flat"):
        raise RuntimeError("No data available. Call fit() first.")

    data_flat = self._last_data_flat
    for dist in self.dist_list:
        entry = self._dist[dist.name]
        # NOTE(review): only the distribution, the data and the statistic
        # are passed; the stored fit_result_params / fixed fit keywords are
        # not forwarded, so the test presumably estimates its own
        # parameters - confirm this is intended.
        entry.test_result = goodness_of_fit(
            dist, data_flat, statistic=entry.statistic_method, **kwargs
        )
    return self._dist
def summary(self) -> "dict[str, DistributionSummary]":
    """Print one line per registered distribution and return the summaries.

    Returns
    -------
    dict[str, DistributionSummary]
        The Fitter's own mapping of distribution name to summary. Each
        summary is also written to stdout via ``print``.
    """
    for entry in self._dist.values():
        print(entry)
    return self._dist
def plot_qq_plots(self, method: str = "hazen") -> "go.Figure":
    """Build a grid of QQ plots, one subplot per registered distribution.

    Distributions whose ``test_result`` is still None are skipped with a
    printed notice; their subplot stays empty.

    Parameters
    ----------
    method : str
        Plotting-position formula, ``'hazen'`` (default) or ``'filliben'``.

        - ``'hazen'``   : p_i = (i - 0.5) / n
        - ``'filliben'``: p_1 = 1 - 0.5^(1/n), p_n = 0.5^(1/n),
          p_i = (i - 0.3175) / (n + 0.365) for 1 < i < n

    Returns
    -------
    plotly.graph_objects.Figure
        Figure with two columns and as many rows as needed to hold every
        registered distribution.

    Raises
    ------
    RuntimeError
        If ``fit()`` has not stored any data on this instance yet.
    ValueError
        If ``method`` is neither ``'hazen'`` nor ``'filliben'``.
    """
    if not hasattr(self, "_last_data_flat"):
        raise RuntimeError("No data available. Call fit() first.")
    if method not in ("hazen", "filliben"):
        raise ValueError(f"method must be 'hazen' or 'filliben', got '{method}'.")

    # Empirical quantiles and plotting positions depend only on the data,
    # so they are computed once instead of once per distribution.
    sorted_data = np.sort(self._last_data_flat)
    n = len(sorted_data)
    ranks = np.arange(1, n + 1)
    if method == "hazen":
        plotting_positions = (ranks - 0.5) / n
    else:  # filliben
        plotting_positions = (ranks - 0.3175) / (n + 0.365)
        plotting_positions[0] = 1 - 0.5 ** (1 / n)
        plotting_positions[-1] = 0.5 ** (1 / n)

    # Ceil-divide so every distribution gets a cell; at least one row so an
    # empty registry still yields a valid (blank) figure.
    num_cols = 2
    num_rows = max(1, -(-len(self._dist) // num_cols))
    fig = make_subplots(
        rows=num_rows,
        cols=num_cols,
        subplot_titles=list(self._dist),
    )

    for index, (dist_name, summary) in enumerate(self._dist.items()):
        if summary.test_result is None:
            print(
                f"Distribution '{dist_name}' has not been validated yet. Skipping QQ plot."
            )
            continue

        # Grid cell is fixed by registration order (1-based for plotly).
        row = index // num_cols + 1
        col = index % num_cols + 1

        theoretical_quantiles = summary.distribution_object.ppf(
            plotting_positions, *summary.fit_result_params
        )
        fig.add_trace(
            go.Scatter(
                x=theoretical_quantiles,
                y=sorted_data,
                mode="markers",
                name=dist_name,
            ),
            row=row,
            col=col,
        )

        # Reference line y = x spanning both axes' data range.
        min_val = min(theoretical_quantiles.min(), sorted_data.min())
        max_val = max(theoretical_quantiles.max(), sorted_data.max())
        fig.add_trace(
            go.Scatter(
                x=[min_val, max_val],
                y=[min_val, max_val],
                mode="lines",
                name="y=x",
                line=dict(dash="dash"),
            ),
            row=row,
            col=col,
        )
        fig.update_xaxes(title_text="Theoretical Quantiles", row=row, col=col)
        fig.update_yaxes(title_text="Empirical Quantiles", row=row, col=col)

        # GoF statistic in the bottom-right corner: green when p > 0.05,
        # red otherwise.
        fig.add_annotation(
            x=0.95,
            y=0.05,
            xref="x domain",
            yref="y domain",
            text=f"{summary.statistic_method}={summary.gof_statistic:.4f}",
            showarrow=False,
            font=dict(size=10, color="green" if summary.pvalue > 0.05 else "red"),
            row=row,
            col=col,
        )

    fig.update_layout(
        title=f"QQ Plots of Fitted Distributions ({method})",
        showlegend=False,
        autosize=True,
    )
    return fig
def histogram_with_fits(self) -> "go.Figure":
    """Return an interactive histogram with the fitted PDFs overlaid (Plotly).

    The data are shown as a probability-density histogram and every
    distribution with non-empty ``fit_result_params`` gets a line trace of
    its PDF. Hovering a curve shows its p-value and GoF statistic, or
    ``N/A`` when the goodness-of-fit test has not been run. Distributions
    without fitted parameters are skipped with a printed notice.

    Returns
    -------
    plotly.graph_objects.Figure
        The assembled figure.

    Raises
    ------
    RuntimeError
        If ``fit()`` has not stored any data on this instance yet.
    """
    if not hasattr(self, "_last_data_flat"):
        raise RuntimeError("No data available. Call fit() first.")

    data_flat = self._last_data_flat
    x = np.linspace(0, data_flat.max(), 1000)

    fig = go.Figure(layout=go.Layout(hovermode="x unified"))
    # The histogram itself is excluded from hover; only the curves report.
    fig.add_trace(
        go.Histogram(
            x=data_flat,
            name="Data",
            opacity=0.3,
            histnorm="probability density",
            hoverinfo="skip",
            marker_color="blue",
        )
    )

    for dist_name, summary in self._dist.items():
        if not summary.fit_result_params:
            print(
                f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay."
            )
            continue

        pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params)

        # pvalue / gof_statistic are None until validate() has run;
        # formatting None with ':.4f' would raise TypeError.
        pval_str = f"{summary.pvalue:.4f}" if summary.pvalue is not None else "N/A"
        stat_str = (
            f"{summary.gof_statistic:.4f}"
            if summary.gof_statistic is not None
            else "N/A"
        )
        label = (
            f"<b>{summary.distribution_name}</b> p-value: {pval_str} GoF: {stat_str}"
        )
        fig.add_trace(
            go.Scatter(
                x=x,
                y=pdf_values,
                mode="lines",
                name=f"{summary.distribution_name} Fit",
                # Same text at every x so the unified hover shows the
                # fit quality instead of the density value.
                hovertext=[label] * len(x),
                hoverinfo="text",
            )
        )

    fig.update_layout(xaxis=dict(showgrid=True), yaxis=dict(showgrid=True))
    # Small, light-grey title anchored to the top-left corner.
    fig.update_layout(
        title=dict(
            text="Histogram of Data with Fitted Distributions",
            x=0.02,
            y=0.95,
            xanchor="left",
            yanchor="top",
            font=dict(size=20, color="darkgray", family="sans-serif"),
        ),
        xaxis_title="Value",
        yaxis_title="Density",
        autosize=True,
    )
    return fig
def histogram_with_fits_seaborn(self) -> "plt.Figure":
    """Return a static histogram with the fitted PDFs overlaid (Seaborn).

    The data are drawn as a density histogram with ``int(sqrt(n))`` bins
    and every distribution with non-empty ``fit_result_params`` gets a PDF
    line. The legend label carries the p-value, or ``N/A`` when the
    goodness-of-fit test has not been run. Distributions without fitted
    parameters are skipped with a printed notice.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that owns the plotted axes.

    Raises
    ------
    RuntimeError
        If ``fit()`` has not stored any data on this instance yet.
    """
    if not hasattr(self, "_last_data_flat"):
        raise RuntimeError("No data available. Call fit() first.")

    data_flat = self._last_data_flat
    x = np.linspace(0, data_flat.max(), 1000)

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.histplot(
        data_flat,
        bins=int(np.sqrt(len(data_flat))),
        kde=False,
        stat="density",
        color="blue",
        alpha=0.2,
        ax=ax,
    )

    for dist_name, summary in self._dist.items():
        if not summary.fit_result_params:
            print(
                f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay."
            )
            continue

        pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params)
        # pvalue is None until validate() has run; formatting None with
        # ':.4f' would raise TypeError.
        pval_str = f"{summary.pvalue:.4f}" if summary.pvalue is not None else "N/A"
        ax.plot(
            x,
            pdf_values,
            label=f"{summary.distribution_name} --- p={pval_str}",
        )

    # Small, light-grey title aligned to the left.
    ax.set_title(
        "Histogram of Data with Fitted Distributions",
        fontsize=8,
        loc="left",
        color="darkgray",
        fontfamily="sans-serif",
    )
    ax.set_xlabel("Value")
    ax.set_ylabel("Density")
    ax.legend()
    ax.grid(True)
    return fig

267
etc/tests/test_fitter.py Normal file
View File

@@ -0,0 +1,267 @@
import numpy as np
import pytest
from scipy.stats import norm, gamma, expon, weibull_min
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from fitting.fitter import DistributionSummary, Fitter
# ── Data ──────────────────────────────────────────────────────────────────
# Seeded generator so the fixture arrays are reproducible across runs.
# NOTE: both arrays are drawn from this one generator in sequence, so
# reordering the two lines below changes the sampled values.
RNG = np.random.default_rng(42)
# 200 gamma samples (shape=2.0, scale=1.5); the input for the fit tests.
GAMMA_DATA = RNG.gamma(shape=2.0, scale=1.5, size=200)
# 200 normal samples (loc=5.0, scale=1.0); not referenced by the tests
# visible in this chunk.
NORM_DATA = RNG.normal(loc=5.0, scale=1.0, size=200)
# ── DistributionSummary ───────────────────────────────────────────────────────
class TestDistributionSummary:
    """Checks for the derived properties and repr of DistributionSummary."""

    def _make_summary(self, dist=gamma, fit_params=(2.0, 0.0, 1.5)):
        """Build a summary with fixed parameters and no GoF result."""
        return DistributionSummary(
            distribution_object=dist,
            distribution_name=dist.name,
            fit_result_params=fit_params,
        )

    def test_mean_std_var_are_floats(self):
        """The three moment properties return plain floats."""
        summary = self._make_summary()
        for attr in ("mean", "std", "var"):
            assert isinstance(getattr(summary, attr), float)

    def test_var_equals_std_squared(self):
        """Variance and squared standard deviation agree."""
        summary = self._make_summary()
        expected = pytest.approx(summary.var, rel=1e-6)
        assert expected == summary.std**2

    def test_pvalue_none_before_validate(self):
        """Without a test result the p-value is None."""
        assert self._make_summary().pvalue is None

    def test_gof_statistic_none_before_validate(self):
        """Without a test result the GoF statistic is None."""
        assert self._make_summary().gof_statistic is None

    def test_repr_contains_name(self):
        """The repr mentions the distribution name."""
        text = repr(self._make_summary())
        assert "gamma" in text

    def test_repr_shows_na_when_no_test(self):
        """The repr shows the N/A placeholder when no test was run."""
        text = repr(self._make_summary())
        assert "N/A" in text

    def test_default_statistic_method(self):
        """The statistic method defaults to 'ad'."""
        assert self._make_summary().statistic_method == "ad"
# ── Fitter construction ───────────────────────────────────────────────────────
class TestFitterConstruction:
    """Checks for Fitter construction and its mapping-style access."""

    def test_distributions_registered(self):
        """Every distribution passed in is reachable by name."""
        fitter = Fitter([gamma, expon])
        for name in ("gamma", "expon"):
            assert name in fitter

    def test_getitem_by_name(self):
        """Indexing by string name returns a summary."""
        fitter = Fitter([gamma])
        assert isinstance(fitter["gamma"], DistributionSummary)

    def test_getitem_by_object(self):
        """Indexing by the rv_continuous object returns a summary."""
        fitter = Fitter([gamma])
        assert isinstance(fitter[gamma], DistributionSummary)

    def test_getitem_missing_raises(self):
        """Indexing an unregistered name raises KeyError."""
        fitter = Fitter([gamma])
        with pytest.raises(KeyError):
            _ = fitter["norm"]

    def test_contains_true(self):
        """Membership works for both the object and its name."""
        fitter = Fitter([gamma])
        assert gamma in fitter
        assert "gamma" in fitter

    def test_contains_false(self):
        """An unregistered distribution is reported as absent."""
        fitter = Fitter([gamma])
        assert norm not in fitter

    def test_kwargs_args_stored(self):
        """Per-distribution kwargs land on the matching summary."""
        fitter = Fitter([gamma], gamma_args=(2.0,), gamma_params={"floc": 0})
        entry_args = fitter["gamma"].args_fit_params
        assert entry_args == (2.0,)
        entry_kwds = fitter["gamma"].kwds_fit_params
        assert entry_kwds == {"floc": 0}

    def test_setitem_valid(self):
        """Assigning a summary back stores that very object."""
        fitter = Fitter([gamma])
        original = fitter["gamma"]
        fitter["gamma"] = original
        assert fitter["gamma"] is original

    def test_setitem_invalid_type(self):
        """Assigning a non-summary raises TypeError."""
        fitter = Fitter([gamma])
        with pytest.raises(TypeError):
            fitter["gamma"] = "not_a_summary"

    def test_setitem_missing_key_raises(self):
        """Assigning under an unregistered name raises KeyError."""
        fitter = Fitter([gamma])
        dummy = DistributionSummary(distribution_object=norm, distribution_name="norm")
        with pytest.raises(KeyError):
            fitter["norm"] = dummy
# ── Fitter iteration ──────────────────────────────────────────────────────────
class TestFitterIteration:
    """Checks for iterating a Fitter as (name, summary) pairs."""

    def test_iterates_all_distributions(self):
        """Iteration visits every registered distribution."""
        fitter = Fitter([gamma, expon, weibull_min])
        names = []
        for name, _ in fitter:
            names.append(name)
        assert set(names) == {"gamma", "expon", "weibull_min"}

    def test_iterate_twice(self):
        """Two consecutive full passes yield the same names."""
        fitter = Fitter([gamma, expon])
        first = [name for name, _ in fitter]
        second = [name for name, _ in fitter]
        assert first == second

    def test_iterator_yields_summary_instances(self):
        """The second element of each pair is a DistributionSummary."""
        fitter = Fitter([gamma])
        for _, entry in fitter:
            assert isinstance(entry, DistributionSummary)
# ── Fitter.fit ────────────────────────────────────────────────────────────────
class TestFitterFit:
    """Checks for Fitter.fit."""

    def test_fit_populates_params(self):
        """Fitting stores a non-empty parameter tuple."""
        fitter = Fitter([gamma], gamma_params={"floc": 0})
        fitter.fit(GAMMA_DATA)
        fitted = fitter["gamma"].fit_result_params
        assert len(fitted) > 0

    def test_fit_uses_abs_value(self):
        """All-negative input still fits."""
        fitter = Fitter([gamma], gamma_params={"floc": 0})
        neg_data = -np.abs(GAMMA_DATA)
        fitter.fit(neg_data)
        fitted = fitter["gamma"].fit_result_params
        assert len(fitted) > 0

    def test_fit_flattens_2d(self):
        """2-D input is stored as a 1-D array."""
        fitter = Fitter([gamma], gamma_params={"floc": 0})
        data_2d = GAMMA_DATA.reshape(10, 20)
        fitter.fit(data_2d)
        assert hasattr(fitter, "_last_data_flat")
        assert fitter._last_data_flat.ndim == 1

    def test_fit_multiple_dists(self):
        """Every registered distribution receives parameters."""
        fitter = Fitter(
            [gamma, expon], gamma_params={"floc": 0}, expon_params={"floc": 0}
        )
        fitter.fit(GAMMA_DATA)
        for name in ("gamma", "expon"):
            assert len(fitter[name].fit_result_params) > 0
# ── Fitter.validate ───────────────────────────────────────────────────────────
class TestFitterValidate:
    """Checks for Fitter.validate."""

    def test_validate_without_fit_raises(self):
        """Validating before fitting raises RuntimeError mentioning fit()."""
        fitter = Fitter([gamma])
        with pytest.raises(RuntimeError, match="fit\\(\\)"):
            fitter.validate()

    def test_validate_populates_test_result(self):
        """After validation a test result is stored."""
        fitter = Fitter([gamma], gamma_params={"floc": 0})
        fitter.fit(GAMMA_DATA)
        fitter.validate(n_mc_samples=99)
        assert fitter["gamma"].test_result is not None

    def test_validate_pvalue_in_range(self):
        """The p-value lies in the closed unit interval."""
        fitter = Fitter([gamma], gamma_params={"floc": 0})
        fitter.fit(GAMMA_DATA)
        fitter.validate(n_mc_samples=99)
        pval = fitter["gamma"].pvalue
        assert 0.0 <= pval <= 1.0

    def test_validate_gof_statistic_positive(self):
        """The GoF statistic is non-negative."""
        fitter = Fitter([gamma], gamma_params={"floc": 0})
        fitter.fit(GAMMA_DATA)
        fitter.validate(n_mc_samples=99)
        statistic = fitter["gamma"].gof_statistic
        assert statistic >= 0.0

    def test_validate_custom_statistic(self):
        """A non-default statistic method also yields a result."""
        fitter = Fitter([gamma], statistic_method="ks", gamma_params={"floc": 0})
        fitter.fit(GAMMA_DATA)
        fitter.validate(n_mc_samples=99)
        assert fitter["gamma"].test_result is not None
# ── Fitter.summary ────────────────────────────────────────────────────────────
class TestFitterSummary:
    """Checks for Fitter.summary."""

    def test_summary_runs_without_error(self, capsys):
        """The printed summary mentions the distribution name."""
        fitter = Fitter([gamma], gamma_params={"floc": 0})
        fitter.fit(GAMMA_DATA)
        fitter.validate(n_mc_samples=99)
        fitter.summary()
        printed = capsys.readouterr().out
        assert "gamma" in printed
# ── Fitter.plot_qq_plots ──────────────────────────────────────────────────────
class TestFitterPlotQQ:
    """Checks for Fitter.plot_qq_plots."""

    def setup_method(self):
        """Provide a fitted and validated Fitter for each test."""
        fitter = Fitter([gamma], gamma_params={"floc": 0})
        fitter.fit(GAMMA_DATA)
        fitter.validate(n_mc_samples=99)
        self.f = fitter

    def test_qq_without_fit_raises(self):
        """Plotting before fitting raises RuntimeError mentioning fit()."""
        unfitted = Fitter([gamma])
        with pytest.raises(RuntimeError, match="fit\\(\\)"):
            unfitted.plot_qq_plots()

    def test_qq_invalid_method_raises(self):
        """An unknown plotting-position method raises ValueError."""
        with pytest.raises(ValueError, match="method"):
            self.f.plot_qq_plots(method="invalid")

    def test_qq_hazen_returns_figure(self):
        """The 'hazen' method returns a Plotly figure."""
        import plotly.graph_objects as go

        figure = self.f.plot_qq_plots(method="hazen")
        assert isinstance(figure, go.Figure)

    def test_qq_filliben_returns_figure(self):
        """The 'filliben' method returns a Plotly figure."""
        import plotly.graph_objects as go

        figure = self.f.plot_qq_plots(method="filliben")
        assert isinstance(figure, go.Figure)
# ── Fitter.histogram_with_fits ────────────────────────────────────────────────
class TestFitterHistogram:
    """Checks for the Plotly and Seaborn histogram helpers."""

    def setup_method(self):
        """Provide a fitted and validated Fitter for each test."""
        fitter = Fitter([gamma], gamma_params={"floc": 0})
        fitter.fit(GAMMA_DATA)
        fitter.validate(n_mc_samples=99)
        self.f = fitter

    def test_histogram_without_fit_raises(self):
        """Plotting before fitting raises RuntimeError mentioning fit()."""
        unfitted = Fitter([gamma])
        with pytest.raises(RuntimeError, match="fit\\(\\)"):
            unfitted.histogram_with_fits()

    def test_histogram_returns_figure(self):
        """The Plotly helper returns a Plotly figure."""
        import plotly.graph_objects as go

        figure = self.f.histogram_with_fits()
        assert isinstance(figure, go.Figure)

    def test_histogram_seaborn_returns_figure(self):
        """The Seaborn helper returns a Matplotlib figure."""
        import matplotlib.pyplot as plt

        figure = self.f.histogram_with_fits_seaborn()
        assert isinstance(figure, plt.Figure)
        plt.close("all")

View File

@@ -21,13 +21,16 @@ def stacked_plot(data):
data = np.squeeze(data) data = np.squeeze(data)
mean_dp = np.mean(np.abs(data), axis=1) mean_dp = np.mean(np.abs(data), axis=1)
fig = make_subplots(rows=2, cols=1, row_heights=[0.3, 0.7], shared_xaxes=True, fig = make_subplots(
vertical_spacing=0.01) rows=2, cols=1, row_heights=[0.3, 0.7], shared_xaxes=True, vertical_spacing=0.01
)
fig.add_trace(go.Scatter(y=mean_dp, name='Mean Power'), row=1, col=1) fig.add_trace(go.Scatter(y=mean_dp, name="Mean Power"), row=1, col=1)
fig.add_trace(go.Heatmap(z=np.abs(data).T, showscale=False, name='Heat Map'), row=2, col=1) fig.add_trace(
go.Heatmap(z=np.abs(data).T, showscale=False, name="Heat Map"), row=2, col=1
)
fig.update_layout(title='Mean DP Power and 2D Map', autosize=True) fig.update_layout(title="Mean DP Power and 2D Map", autosize=True)
fig.update_xaxes(visible=False, row=2, col=1) fig.update_xaxes(visible=False, row=2, col=1)
fig.update_yaxes(visible=False, row=2, col=1) fig.update_yaxes(visible=False, row=2, col=1)
@@ -98,6 +101,8 @@ def plot_cdfs(data_list, labels):
fig = go.Figure() fig = go.Figure()
for data, label in zip(data_list, labels): for data, label in zip(data_list, labels):
sorted_data, cdf = calculate_cdf(data) sorted_data, cdf = calculate_cdf(data)
fig.add_trace(go.Scatter(x=sorted_data, y=cdf, mode='lines', name=label)) fig.add_trace(go.Scatter(x=sorted_data, y=cdf, mode="lines", name=label))
fig.update_layout(title='CDF of Data', xaxis_title='Value', yaxis_title='CDF', autosize=True) fig.update_layout(
title="CDF of Data", xaxis_title="Value", yaxis_title="CDF", autosize=True
)
return fig return fig

View File

@@ -16,6 +16,7 @@ dependencies = [
"statsmodels", "statsmodels",
"matplotlib", "matplotlib",
"seaborn", "seaborn",
"ruff>=0.15.9",
] ]
classifiers = ["Private :: Do Not Upload"] classifiers = ["Private :: Do Not Upload"]

29
uv.lock generated
View File

@@ -1,5 +1,5 @@
version = 1 version = 1
revision = 2 revision = 3
requires-python = ">=3.11" requires-python = ">=3.11"
resolution-markers = [ resolution-markers = [
"python_full_version >= '3.14' and sys_platform == 'win32'", "python_full_version >= '3.14' and sys_platform == 'win32'",
@@ -136,6 +136,7 @@ dependencies = [
{ name = "numpy" }, { name = "numpy" },
{ name = "pathlib2" }, { name = "pathlib2" },
{ name = "plotly" }, { name = "plotly" },
{ name = "ruff" },
{ name = "scipy" }, { name = "scipy" },
{ name = "seaborn" }, { name = "seaborn" },
{ name = "statsmodels" }, { name = "statsmodels" },
@@ -151,6 +152,7 @@ requires-dist = [
{ name = "numpy" }, { name = "numpy" },
{ name = "pathlib2" }, { name = "pathlib2" },
{ name = "plotly" }, { name = "plotly" },
{ name = "ruff", specifier = ">=0.15.9" },
{ name = "scipy" }, { name = "scipy" },
{ name = "seaborn" }, { name = "seaborn" },
{ name = "statsmodels" }, { name = "statsmodels" },
@@ -1562,6 +1564,31 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d1/b7/b95708304cd49b7b6f82fdd039f1748b66ec2b21d6a45180910802f1abf1/rpds_py-0.30.0-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:ac37f9f516c51e5753f27dfdef11a88330f04de2d564be3991384b2f3535d02e", size = 562191, upload-time = "2025-11-30T20:24:36.853Z" }, { url = "https://files.pythonhosted.org/packages/d1/b7/b95708304cd49b7b6f82fdd039f1748b66ec2b21d6a45180910802f1abf1/rpds_py-0.30.0-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:ac37f9f516c51e5753f27dfdef11a88330f04de2d564be3991384b2f3535d02e", size = 562191, upload-time = "2025-11-30T20:24:36.853Z" },
] ]
[[package]]
name = "ruff"
version = "0.15.9"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/e6/97/e9f1ca355108ef7194e38c812ef40ba98c7208f47b13ad78d023caa583da/ruff-0.15.9.tar.gz", hash = "sha256:29cbb1255a9797903f6dde5ba0188c707907ff44a9006eb273b5a17bfa0739a2", size = 4617361, upload-time = "2026-04-02T18:17:20.829Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0b/1f/9cdfd0ac4b9d1e5a6cf09bedabdf0b56306ab5e333c85c87281273e7b041/ruff-0.15.9-py3-none-linux_armv6l.whl", hash = "sha256:6efbe303983441c51975c243e26dff328aca11f94b70992f35b093c2e71801e1", size = 10511206, upload-time = "2026-04-02T18:16:41.574Z" },
{ url = "https://files.pythonhosted.org/packages/3d/f6/32bfe3e9c136b35f02e489778d94384118bb80fd92c6d92e7ccd97db12ce/ruff-0.15.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:4965bac6ac9ea86772f4e23587746f0b7a395eccabb823eb8bfacc3fa06069f7", size = 10923307, upload-time = "2026-04-02T18:17:08.645Z" },
{ url = "https://files.pythonhosted.org/packages/ca/25/de55f52ab5535d12e7aaba1de37a84be6179fb20bddcbe71ec091b4a3243/ruff-0.15.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eaf05aad70ca5b5a0a4b0e080df3a6b699803916d88f006efd1f5b46302daab8", size = 10316722, upload-time = "2026-04-02T18:16:44.206Z" },
{ url = "https://files.pythonhosted.org/packages/48/11/690d75f3fd6278fe55fff7c9eb429c92d207e14b25d1cae4064a32677029/ruff-0.15.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9439a342adb8725f32f92732e2bafb6d5246bd7a5021101166b223d312e8fc59", size = 10623674, upload-time = "2026-04-02T18:16:50.951Z" },
{ url = "https://files.pythonhosted.org/packages/bd/ec/176f6987be248fc5404199255522f57af1b4a5a1b57727e942479fec98ad/ruff-0.15.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9c5e6faf9d97c8edc43877c3f406f47446fc48c40e1442d58cfcdaba2acea745", size = 10351516, upload-time = "2026-04-02T18:16:57.206Z" },
{ url = "https://files.pythonhosted.org/packages/b2/fc/51cffbd2b3f240accc380171d51446a32aa2ea43a40d4a45ada67368fbd2/ruff-0.15.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b34a9766aeec27a222373d0b055722900fbc0582b24f39661aa96f3fe6ad901", size = 11150202, upload-time = "2026-04-02T18:17:06.452Z" },
{ url = "https://files.pythonhosted.org/packages/d6/d4/25292a6dfc125f6b6528fe6af31f5e996e19bf73ca8e3ce6eb7fa5b95885/ruff-0.15.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:89dd695bc72ae76ff484ae54b7e8b0f6b50f49046e198355e44ea656e521fef9", size = 11988891, upload-time = "2026-04-02T18:17:18.575Z" },
{ url = "https://files.pythonhosted.org/packages/13/e1/1eebcb885c10e19f969dcb93d8413dfee8172578709d7ee933640f5e7147/ruff-0.15.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ce187224ef1de1bd225bc9a152ac7102a6171107f026e81f317e4257052916d5", size = 11480576, upload-time = "2026-04-02T18:16:52.986Z" },
{ url = "https://files.pythonhosted.org/packages/ff/6b/a1548ac378a78332a4c3dcf4a134c2475a36d2a22ddfa272acd574140b50/ruff-0.15.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2b0c7c341f68adb01c488c3b7d4b49aa8ea97409eae6462d860a79cf55f431b6", size = 11254525, upload-time = "2026-04-02T18:17:02.041Z" },
{ url = "https://files.pythonhosted.org/packages/42/aa/4bb3af8e61acd9b1281db2ab77e8b2c3c5e5599bf2a29d4a942f1c62b8d6/ruff-0.15.9-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:55cc15eee27dc0eebdfcb0d185a6153420efbedc15eb1d38fe5e685657b0f840", size = 11204072, upload-time = "2026-04-02T18:17:13.581Z" },
{ url = "https://files.pythonhosted.org/packages/69/48/d550dc2aa6e423ea0bcc1d0ff0699325ffe8a811e2dba156bd80750b86dc/ruff-0.15.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:a6537f6eed5cda688c81073d46ffdfb962a5f29ecb6f7e770b2dc920598997ed", size = 10594998, upload-time = "2026-04-02T18:16:46.369Z" },
{ url = "https://files.pythonhosted.org/packages/63/47/321167e17f5344ed5ec6b0aa2cff64efef5f9e985af8f5622cfa6536043f/ruff-0.15.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:6d3fcbca7388b066139c523bda744c822258ebdcfbba7d24410c3f454cc9af71", size = 10359769, upload-time = "2026-04-02T18:17:10.994Z" },
{ url = "https://files.pythonhosted.org/packages/67/5e/074f00b9785d1d2c6f8c22a21e023d0c2c1817838cfca4c8243200a1fa87/ruff-0.15.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:058d8e99e1bfe79d8a0def0b481c56059ee6716214f7e425d8e737e412d69677", size = 10850236, upload-time = "2026-04-02T18:16:48.749Z" },
{ url = "https://files.pythonhosted.org/packages/76/37/804c4135a2a2caf042925d30d5f68181bdbd4461fd0d7739da28305df593/ruff-0.15.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:8e1ddb11dbd61d5983fa2d7d6370ef3eb210951e443cace19594c01c72abab4c", size = 11358343, upload-time = "2026-04-02T18:16:55.068Z" },
{ url = "https://files.pythonhosted.org/packages/88/3d/1364fcde8656962782aa9ea93c92d98682b1ecec2f184e625a965ad3b4a6/ruff-0.15.9-py3-none-win32.whl", hash = "sha256:bde6ff36eaf72b700f32b7196088970bf8fdb2b917b7accd8c371bfc0fd573ec", size = 10583382, upload-time = "2026-04-02T18:17:04.261Z" },
{ url = "https://files.pythonhosted.org/packages/4c/56/5c7084299bd2cacaa07ae63a91c6f4ba66edc08bf28f356b24f6b717c799/ruff-0.15.9-py3-none-win_amd64.whl", hash = "sha256:45a70921b80e1c10cf0b734ef09421f71b5aa11d27404edc89d7e8a69505e43d", size = 11744969, upload-time = "2026-04-02T18:16:59.611Z" },
{ url = "https://files.pythonhosted.org/packages/03/36/76704c4f312257d6dbaae3c959add2a622f63fcca9d864659ce6d8d97d3d/ruff-0.15.9-py3-none-win_arm64.whl", hash = "sha256:0694e601c028fd97dc5c6ee244675bc241aeefced7ef80cd9c6935a871078f53", size = 11005870, upload-time = "2026-04-02T18:17:15.773Z" },
]
[[package]] [[package]]
name = "scipy" name = "scipy"
version = "1.17.1" version = "1.17.1"