From d053ebf02cb5a3057e890224935b18587694091f Mon Sep 17 00:00:00 2001 From: neutonsevero Date: Wed, 8 Apr 2026 21:48:19 -0300 Subject: [PATCH] REFACTOR: Fitter class refactored. Include getter and setter ADD: test/ dir with code tests --- etc/fitting/fitter.py | 504 +++++++++++++++++---------------------- etc/tests/test_fitter.py | 267 +++++++++++++++++++++ etc/tools/plots.py | 19 +- pyproject.toml | 1 + uv.lock | 29 ++- 5 files changed, 529 insertions(+), 291 deletions(-) create mode 100644 etc/tests/test_fitter.py diff --git a/etc/fitting/fitter.py b/etc/fitting/fitter.py index 891ec7f..91916ac 100644 --- a/etc/fitting/fitter.py +++ b/etc/fitting/fitter.py @@ -5,6 +5,7 @@ import plotly.graph_objects as go import matplotlib.pyplot as plt import seaborn as sns from scipy.stats import rv_continuous, goodness_of_fit +from plotly.subplots import make_subplots @dataclass @@ -44,79 +45,41 @@ class DistributionSummary: args_fit_params: tuple = field(default_factory=tuple) kwds_fit_params: dict = field(default_factory=dict) fit_result_params: tuple = field(default_factory=tuple) - statistic_method: str = 'ad' + statistic_method: str = "ad" test_result: object = None - # ── properties ──────────────────────────────────────────────── + # ── convenience properties ──────────────────────────────────────────────── @property def pvalue(self) -> float | None: - """p-value from the goodness-of-fit test. - - Returns - ------- - float or None - The p-value produced by the GoF test, or None if validate() has - not been called yet. - """ + """p-value from the goodness-of-fit test, or None if not yet run.""" return self.test_result.pvalue if self.test_result is not None else None @property def gof_statistic(self) -> float | None: - """Test statistic from the goodness-of-fit test. - - Returns - ------- - float or None - The GoF statistic value, or None if validate() has not been - called yet. - """ + """Test statistic from the goodness-of-fit test, or None if not yet run.""" return self.test_result.statistic if self.test_result is not None else None @property def mean(self) -> float: - """Mean of the fitted distribution. - - Returns - ------- - float - Mean computed from the fitted distribution parameters. - """ - return self.distribution_object.mean(*self.fit_result_params) + """Mean of the fitted distribution.""" + return float(self.distribution_object.mean(*self.fit_result_params)) @property def std(self) -> float: - """Standard deviation of the fitted distribution. - - Returns - ------- - float - Standard deviation computed from the fitted distribution parameters. - """ - return self.distribution_object.std(*self.fit_result_params) + """Standard deviation of the fitted distribution.""" + return float(self.distribution_object.std(*self.fit_result_params)) @property def var(self) -> float: - """Variance of the fitted distribution. - - Returns - ------- - float - Variance computed from the fitted distribution parameters. - """ - return self.distribution_object.var(*self.fit_result_params) + """Variance of the fitted distribution.""" + return float(self.distribution_object.var(*self.fit_result_params)) def __repr__(self) -> str: - """Return a concise string representation of the distribution summary. - - Returns - ------- - str - Single-line string showing the distribution name, fitted - parameters, mean, standard deviation, GoF statistic, and p-value. - """ pval_str = f"{self.pvalue:.4f}" if self.pvalue is not None else "N/A" - stat_str = f"{self.gof_statistic:.4f}" if self.gof_statistic is not None else "N/A" + stat_str = ( + f"{self.gof_statistic:.4f}" if self.gof_statistic is not None else "N/A" + ) return ( f"DistributionSummary({self.distribution_name})" f" params={self.fit_result_params}" @@ -125,13 +88,6 @@ class DistributionSummary: ) def __str__(self) -> str: - """Return the same string representation as ``__repr__``. - - Returns - ------- - str - Delegates to :meth:`__repr__`. - """ return self.__repr__() @@ -156,347 +112,329 @@ class Fitter: gamma_params={'floc': 0}, weibull_min_args=(1.5, 0.0, 1.0), ) - - Access fitted distributions - --------------------------- - Distributions can be accessed by name (str) or by the rv_continuous object: - - fitter = Fitter([gamma, weibull_min]) - fitter.fit(data) - - # Access by name - summary = fitter['gamma'] - - # Access by rv_continuous object - summary = fitter[gamma] - - # Check if a distribution is present - gamma in fitter # True - 'weibull_min' in fitter # True - - # Override a DistributionSummary - fitter['gamma'] = new_summary """ - def __init__(self, dist_list: list[rv_continuous], statistic_method: str = 'ad', **kwargs): - """Initialise the Fitter and build per-distribution summary objects. - - Parameters - ---------- - dist_list : list[rv_continuous] - Distributions to fit. - statistic_method : str, optional - Goodness-of-fit statistic passed to ``goodness_of_fit`` - (default ``'ad'`` for Anderson-Darling). - **kwargs : - Per-distribution initial guesses and fixed parameters, keyed as - ``_args`` (tuple) or ``_params`` (dict). - """ + def __init__( + self, dist_list: list[rv_continuous], statistic_method: str = "ad", **kwargs + ): self._dist: dict[str, DistributionSummary] = {} - self.dist_list = list(dist_list) + self.dist_list = list(dist_list) # Ensure it's a list for multiple iterations for dist in dist_list: self._dist[dist.name] = DistributionSummary( distribution_object=dist, distribution_name=dist.name, - args_fit_params=kwargs.get(f'{dist.name}_args', ()), - kwds_fit_params=kwargs.get(f'{dist.name}_params', {}), + args_fit_params=kwargs.get(f"{dist.name}_args", ()), + kwds_fit_params=kwargs.get(f"{dist.name}_params", {}), statistic_method=statistic_method, test_result=None, ) - #── Getter and setter with flexible keys (str or rv_continuous) ──────────────────────────────── def _resolve_key(self, key: str | rv_continuous) -> str: - """Resolve a distribution name or ``rv_continuous`` object to its string key. - - Parameters - ---------- - key : str or rv_continuous - Either the distribution's string name or its ``rv_continuous`` - object (whose ``.name`` attribute is used). - - Returns - ------- - str - The string key used in the internal ``_dist`` dictionary. - - Raises - ------ - KeyError - If the resolved name is not found among the registered - distributions. - """ + """Resolve a distribution name or rv_continuous object to its string key.""" name = key.name if isinstance(key, rv_continuous) else key if name not in self._dist: - available = ', '.join(self._dist) + available = ", ".join(self._dist) raise KeyError(f"Distribution '{name}' not found. Available: {available}") return name def __contains__(self, key: str | rv_continuous) -> bool: - """Check whether a distribution is registered with this Fitter. - - Parameters - ---------- - key : str or rv_continuous - Distribution name or ``rv_continuous`` object to look up. - - Returns - ------- - bool - True if the distribution is registered, False otherwise. - """ name = key.name if isinstance(key, rv_continuous) else key return name in self._dist def __getitem__(self, key: str | rv_continuous) -> DistributionSummary: - """Retrieve the :class:`DistributionSummary` for the given distribution. - - Parameters - ---------- - key : str or rv_continuous - Distribution name or ``rv_continuous`` object to retrieve. - - Returns - ------- - DistributionSummary - The summary object associated with the requested distribution. - - Raises - ------ - KeyError - If the distribution is not registered. - """ return self._dist[self._resolve_key(key)] - def __setitem__(self, key: str | rv_continuous, summary: DistributionSummary) -> None: - """Override the :class:`DistributionSummary` for an existing distribution. - - Parameters - ---------- - key : str or rv_continuous - Distribution name or ``rv_continuous`` object to update. - summary : DistributionSummary - Replacement summary object. - - Raises - ------ - TypeError - If ``summary`` is not a :class:`DistributionSummary` instance. - KeyError - If the distribution is not registered. - """ + def __setitem__( + self, key: str | rv_continuous, summary: DistributionSummary + ) -> None: if not isinstance(summary, DistributionSummary): - raise TypeError(f"Expected DistributionSummary, got {type(summary).__name__}.") + raise TypeError( + f"Expected DistributionSummary, got {type(summary).__name__}." + ) self._dist[self._resolve_key(key)] = summary + def __iter__(self): + return self + def __next__(self): + if not hasattr(self, "_iter_index"): + self._iter_index = 0 + if self._iter_index >= len(self._dist): + del self._iter_index + raise StopIteration + key = list(self._dist.keys())[self._iter_index] + summary = self._dist[key] + self._iter_index += 1 + return key, summary - def fit(self, data: np.ndarray) -> None: - """Fit every distribution to *data* via MLE. + def fit(self, data: np.ndarray) -> dict[str, DistributionSummary]: + """ + Fit every distribution to *data* via MLE. Parameters ---------- data : array-like - Input data. Only the absolute value is used; the array is - flattened before fitting. - + Input data. Only the absolute value is used. Returns ------- - None - Fitted parameters are stored in-place inside each - :class:`DistributionSummary` held by this Fitter. + dict[str, DistributionSummary] + Mapping of distribution name → summary (test_result is None + until validate() is called). """ data_flat = np.abs(data).flatten() self._last_data_flat = data_flat for dist in self.dist_list: _summary = self._dist[dist.name] - fit_params = dist.fit(data_flat, *_summary.args_fit_params, **_summary.kwds_fit_params) + fit_params = dist.fit( + data_flat, *_summary.args_fit_params, **_summary.kwds_fit_params + ) _summary.fit_result_params = fit_params self._dist[dist.name] = _summary - def validate(self, **kwargs) -> None: - """Run the goodness-of-fit test on every previously fitted distribution. + def validate(self, **kwargs) -> dict[str, DistributionSummary]: + """ + Run the goodness-of-fit test on every previously fitted distribution. Parameters ---------- **kwargs : - Extra keyword arguments forwarded to ``scipy.stats.goodness_of_fit`` - (e.g. ``n_mc_samples=100``). + Extra keyword arguments forwarded to goodness_of_fit() + (e.g. n_mc_samples=100). Returns ------- - None - Test results are stored in-place inside each - :class:`DistributionSummary` held by this Fitter. - - Raises - ------ - RuntimeError - If :meth:`fit` has not been called before this method. + dict[str, DistributionSummary] + Same results dict, with test_result populated for each distribution. """ - if not hasattr(self, '_last_data_flat'): + if not hasattr(self, "_last_data_flat"): raise RuntimeError("No data available. Call fit() first.") data_flat = self._last_data_flat for dist in self.dist_list: _summary = self._dist[dist.name] test_result = goodness_of_fit( - dist, - data_flat, - statistic=_summary.statistic_method, - **kwargs + dist, data_flat, statistic=_summary.statistic_method, **kwargs ) _summary.test_result = test_result self._dist[dist.name] = _summary - def summary(self) -> None: - """Print a one-line summary for each registered distribution. - - Returns - ------- - None - Output is written to stdout via ``print``. + def summary(self) -> dict[str, DistributionSummary]: + """ + Print a summary of all fitted distributions, including parameters and GoF results. """ for dist_name, summary in self._dist.items(): print(summary) - def plot_qq_plots(self) -> None: - """Generate QQ plots for each fitted distribution against the data. - - A separate interactive Plotly figure is displayed for every - distribution that has been both fitted and validated. Distributions - that have not yet been validated are skipped with a printed warning. - - Returns - ------- - None - Figures are rendered inline / in a browser via ``fig.show()``. - - Raises - ------ - RuntimeError - If :meth:`fit` has not been called before this method. + def plot_qq_plots(self, method: str = "hazen"): """ - if not hasattr(self, '_last_data_flat'): + Generate QQ plots for each fitted distribution against the data. + Requires that fit() and validate() have been called to populate parameters and test results. + + Parameters + ---------- + method : str + Plotting positions formula. Either 'hazen' (default) or 'filliben'. + - 'hazen' : p_i = (i - 0.5) / n + - 'filliben': p_1 = 1 - 0.5^(1/n), p_n = 0.5^(1/n), + p_i = (i - 0.3175) / (n + 0.365) for 1 < i < n + """ + if not hasattr(self, "_last_data_flat"): raise RuntimeError("No data available. Call fit() first.") + if method not in ("hazen", "filliben"): + raise ValueError(f"method must be 'hazen' or 'filliben', got '{method}'.") data_flat = self._last_data_flat - for dist_name, summary in self._dist.items(): + # generate subplots with 2 columns and as many rows as needed, but not more than 3 rows, if there are more than 6 distributions, create multiple figures + num_dists = len(self._dist) + num_cols = 2 + num_rows = min(3, (num_dists + 1) // 2) + fig = make_subplots( + rows=num_rows, + cols=num_cols, + subplot_titles=[dist_name for dist_name in self._dist.keys()], + ) + + for dist_name, summary in self: if summary.test_result is None: - print(f"Distribution '{dist_name}' has not been validated yet. Skipping QQ plot.") + print( + f"Distribution '{dist_name}' has not been validated yet. Skipping QQ plot." + ) continue + # Generate theoretical quantiles sorted_data = np.sort(data_flat) + n = len(sorted_data) + i = np.arange(1, n + 1) + if method == "hazen": + plotting_positions = (i - 0.5) / n + else: # filliben + plotting_positions = (i - 0.3175) / (n + 0.365) + plotting_positions[0] = 1 - 0.5 ** (1 / n) + plotting_positions[-1] = 0.5 ** (1 / n) theoretical_quantiles = summary.distribution_object.ppf( - (np.arange(1, len(sorted_data) + 1) - 0.5) / len(sorted_data), - *summary.fit_result_params + plotting_positions, *summary.fit_result_params ) - fig = go.Figure() - fig.add_trace(go.Scatter(x=theoretical_quantiles, y=sorted_data, mode='markers', name='Data vs. Fit')) - fig.add_trace(go.Scatter(x=theoretical_quantiles, y=theoretical_quantiles, mode='lines', name='Ideal Fit', line=dict(dash='dash'))) - fig.update_layout( - title=f'QQ Plot for {summary.distribution_name}', - xaxis_title='Theoretical Quantiles', - yaxis_title='Empirical Quantiles', - autosize=True, + # Create QQ plot in each subplot + row = (list(self._dist.keys()).index(dist_name) // num_cols) + 1 + col = (list(self._dist.keys()).index(dist_name) % num_cols) + 1 + fig.add_trace( + go.Scatter( + x=theoretical_quantiles, + y=sorted_data, + mode="markers", + name=dist_name, + ), + row=row, + col=col, ) - fig.show() + # Add a reference line y=x + min_val = min(theoretical_quantiles.min(), sorted_data.min()) + max_val = max(theoretical_quantiles.max(), sorted_data.max()) + fig.add_trace( + go.Scatter( + x=[min_val, max_val], + y=[min_val, max_val], + mode="lines", + name="y=x", + line=dict(dash="dash"), + ), + row=row, + col=col, + ) + fig.update_xaxes(title_text="Theoretical Quantiles", row=row, col=col) + fig.update_yaxes(title_text="Empirical Quantiles", row=row, col=col) + # add statistic value in bottom right of each subplot (summary.gof_statistic()) + fig.add_annotation( + x=0.95, + y=0.05, + xref="x domain", + yref="y domain", + text=f"{summary.statistic_method}={summary.gof_statistic:.4f}", + showarrow=False, + font=dict(size=10, color="green" if summary.pvalue > 0.05 else "red"), + row=row, + col=col, + ) + fig.update_layout( + title=f"QQ Plots of Fitted Distributions ({method})", + showlegend=False, + autosize=True, + ) + return fig - def histogram_with_fits(self) -> go.Figure: - """Return an interactive histogram with overlaid fitted PDFs (Plotly). - - Builds a probability-density histogram of the data and overlays a - line trace for the PDF of each fitted distribution. Hover text shows - the p-value and GoF statistic for each curve. Distributions that - have not yet been fitted are skipped with a printed warning. - - Returns - ------- - plotly.graph_objects.Figure - Interactive Plotly figure ready to display with ``fig.show()``. - - Raises - ------ - RuntimeError - If :meth:`fit` has not been called before this method. + def histogram_with_fits(self): """ - if not hasattr(self, '_last_data_flat'): + Generate a histogram of the data with overlaid PDFs of each fitted distribution. + Requires that fit() has been called to populate parameters. + """ + if not hasattr(self, "_last_data_flat"): raise RuntimeError("No data available. Call fit() first.") data_flat = self._last_data_flat x = np.linspace(0, data_flat.max(), 1000) - fig = go.Figure(layout=go.Layout(hovermode='x unified')) - fig.add_trace(go.Histogram( - x=data_flat, name='Data', opacity=0.3, - histnorm='probability density', hoverinfo='skip', marker_color='blue', - )) + fig = go.Figure(layout=go.Layout(hovermode="x unified")) + # do not show data in hoover, only show the distribution name, p-value and GoF statistic for each distribution + fig.add_trace( + go.Histogram( + x=data_flat, + name="Data", + opacity=0.3, + histnorm="probability density", + hoverinfo="skip", + marker_color="blue", + ) + ) for dist_name, summary in self._dist.items(): if not summary.fit_result_params: - print(f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay.") + print( + f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay." + ) continue pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params) - fig.add_trace(go.Scatter(x=x, y=pdf_values, mode='lines', name=f'{summary.distribution_name} Fit')) + # add trace and stack hoover x like stock price, but make the y value shows the p-value and GoF statistic for each distribution + fig.add_trace( + go.Scatter( + x=x, + y=pdf_values, + mode="lines", + name=f"{summary.distribution_name} Fit", + ) + ) hover_text = [ f"{summary.distribution_name} p-value: {summary.pvalue:.4f} GoF: {summary.gof_statistic:.4f}" for _ in x ] fig.data[-1].update(hovertext=hover_text, hoverinfo="text") + # add grid + fig.update_layout(xaxis=dict(showgrid=True), yaxis=dict(showgrid=True)) + # put title in top left, make it smaller, change it font to sans and put in light gray fig.update_layout( - xaxis=dict(showgrid=True), - yaxis=dict(showgrid=True), title=dict( - text='Histogram of Data with Fitted Distributions', - x=0.02, y=0.95, xanchor='left', yanchor='top', - font=dict(size=20, color='darkgray', family='sans-serif'), + text="Histogram of Data with Fitted Distributions", + x=0.02, + y=0.95, + xanchor="left", + yanchor="top", + font=dict(size=20, color="darkgray", family="sans-serif"), ), - xaxis_title='Value', - yaxis_title='Density', + xaxis_title="Value", + yaxis_title="Density", autosize=True, ) + return fig - def histogram_with_fits_seaborn(self) -> plt.Figure: - """Return a static histogram with overlaid fitted PDFs (Matplotlib/Seaborn). - - Builds a probability-density histogram using Seaborn and overlays a - line for the PDF of each fitted distribution. The legend entry for - each distribution includes its p-value. Distributions that have not - yet been fitted are skipped with a printed warning. - - Returns - ------- - matplotlib.figure.Figure - Matplotlib figure object. - - Raises - ------ - RuntimeError - If :meth:`fit` has not been called before this method. + def histogram_with_fits_seaborn(self): """ - if not hasattr(self, '_last_data_flat'): + Generate a histogram of the data with overlaid PDFs of each fitted distribution using seaborn. + Requires that fit() has been called to populate parameters. + """ + if not hasattr(self, "_last_data_flat"): raise RuntimeError("No data available. Call fit() first.") data_flat = self._last_data_flat x = np.linspace(0, data_flat.max(), 1000) fig, ax = plt.subplots(figsize=(10, 6)) - sns.histplot(data_flat, bins=int(np.sqrt(len(data_flat))), kde=False, stat='density', color='blue', alpha=0.2, ax=ax) + sns.histplot( + data_flat, + bins=int(np.sqrt(len(data_flat))), + kde=False, + stat="density", + color="blue", + alpha=0.2, + ax=ax, + ) for dist_name, summary in self._dist.items(): if not summary.fit_result_params: - print(f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay.") + print( + f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay." + ) continue pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params) - ax.plot(x, pdf_values, label=f'{summary.distribution_name} --- p={summary.pvalue:.4f}') - - ax.set_title('Histogram of Data with Fitted Distributions', fontsize=8, loc='left', color='darkgray', fontfamily='sans-serif') - ax.set_xlabel('Value') - ax.set_ylabel('Density') + ax.plot( + x, + pdf_values, + label=f"{summary.distribution_name} --- p={summary.pvalue:.4f}", + ) + # put title in top left, make it smaller, change it font to sans and put in light gray + ax.set_title( + "Histogram of Data with Fitted Distributions", + fontsize=8, + loc="left", + color="darkgray", + fontfamily="sans-serif", + ) + ax.set_xlabel("Value") + ax.set_ylabel("Density") ax.legend() ax.grid(True) + return fig diff --git a/etc/tests/test_fitter.py b/etc/tests/test_fitter.py new file mode 100644 index 0000000..e7f6be1 --- /dev/null +++ b/etc/tests/test_fitter.py @@ -0,0 +1,267 @@ +import numpy as np +import pytest +from scipy.stats import norm, gamma, expon, weibull_min +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from fitting.fitter import DistributionSummary, Fitter + + +# ── Data ────────────────────────────────────────────────────────────────── + +RNG = np.random.default_rng(42) +GAMMA_DATA = RNG.gamma(shape=2.0, scale=1.5, size=200) +NORM_DATA = RNG.normal(loc=5.0, scale=1.0, size=200) + + +# ── DistributionSummary ─────────────────────────────────────────────────────── + + +class TestDistributionSummary: + def _make_summary(self, dist=gamma, fit_params=(2.0, 0.0, 1.5)): + return DistributionSummary( + distribution_object=dist, + distribution_name=dist.name, + fit_result_params=fit_params, + ) + + def test_mean_std_var_are_floats(self): + s = self._make_summary() + assert isinstance(s.mean, float) + assert isinstance(s.std, float) + assert isinstance(s.var, float) + + def test_var_equals_std_squared(self): + s = self._make_summary() + assert pytest.approx(s.var, rel=1e-6) == s.std**2 + + def test_pvalue_none_before_validate(self): + s = self._make_summary() + assert s.pvalue is None + + def test_gof_statistic_none_before_validate(self): + s = self._make_summary() + assert s.gof_statistic is None + + def test_repr_contains_name(self): + s = self._make_summary() + assert "gamma" in repr(s) + + def test_repr_shows_na_when_no_test(self): + s = self._make_summary() + assert "N/A" in repr(s) + + def test_default_statistic_method(self): + s = self._make_summary() + assert s.statistic_method == "ad" + + +# ── Fitter construction ─────────────────────────────────────────────────────── + + +class TestFitterConstruction: + def test_distributions_registered(self): + f = Fitter([gamma, expon]) + assert "gamma" in f + assert "expon" in f + + def test_getitem_by_name(self): + f = Fitter([gamma]) + s = f["gamma"] + assert isinstance(s, DistributionSummary) + + def test_getitem_by_object(self): + f = Fitter([gamma]) + s = f[gamma] + assert isinstance(s, DistributionSummary) + + def test_getitem_missing_raises(self): + f = Fitter([gamma]) + with pytest.raises(KeyError): + _ = f["norm"] + + def test_contains_true(self): + f = Fitter([gamma]) + assert gamma in f + assert "gamma" in f + + def test_contains_false(self): + f = Fitter([gamma]) + assert norm not in f + + def test_kwargs_args_stored(self): + f = Fitter([gamma], gamma_args=(2.0,), gamma_params={"floc": 0}) + assert f["gamma"].args_fit_params == (2.0,) + assert f["gamma"].kwds_fit_params == {"floc": 0} + + def test_setitem_valid(self): + f = Fitter([gamma]) + original = f["gamma"] + f["gamma"] = original + assert f["gamma"] is original + + def test_setitem_invalid_type(self): + f = Fitter([gamma]) + with pytest.raises(TypeError): + f["gamma"] = "not_a_summary" + + def test_setitem_missing_key_raises(self): + f = Fitter([gamma]) + dummy = DistributionSummary(distribution_object=norm, distribution_name="norm") + with pytest.raises(KeyError): + f["norm"] = dummy + + +# ── Fitter iteration ────────────────────────────────────────────────────────── + + +class TestFitterIteration: + def test_iterates_all_distributions(self): + f = Fitter([gamma, expon, weibull_min]) + names = [name for name, _ in f] + assert set(names) == {"gamma", "expon", "weibull_min"} + + def test_iterate_twice(self): + f = Fitter([gamma, expon]) + first = [name for name, _ in f] + second = [name for name, _ in f] + assert first == second + + def test_iterator_yields_summary_instances(self): + f = Fitter([gamma]) + for _, summary in f: + assert isinstance(summary, DistributionSummary) + + +# ── Fitter.fit ──────────────────────────────────────────────────────────────── + + +class TestFitterFit: + def test_fit_populates_params(self): + f = Fitter([gamma], gamma_params={"floc": 0}) + f.fit(GAMMA_DATA) + assert len(f["gamma"].fit_result_params) > 0 + + def test_fit_uses_abs_value(self): + f = Fitter([gamma], gamma_params={"floc": 0}) + neg_data = -np.abs(GAMMA_DATA) + f.fit(neg_data) + assert len(f["gamma"].fit_result_params) > 0 + + def test_fit_flattens_2d(self): + f = Fitter([gamma], gamma_params={"floc": 0}) + data_2d = GAMMA_DATA.reshape(10, 20) + f.fit(data_2d) + assert hasattr(f, "_last_data_flat") + assert f._last_data_flat.ndim == 1 + + def test_fit_multiple_dists(self): + f = Fitter([gamma, expon], gamma_params={"floc": 0}, expon_params={"floc": 0}) + f.fit(GAMMA_DATA) + assert len(f["gamma"].fit_result_params) > 0 + assert len(f["expon"].fit_result_params) > 0 + + +# ── Fitter.validate ─────────────────────────────────────────────────────────── + + +class TestFitterValidate: + def test_validate_without_fit_raises(self): + f = Fitter([gamma]) + with pytest.raises(RuntimeError, match="fit\\(\\)"): + f.validate() + + def test_validate_populates_test_result(self): + f = Fitter([gamma], gamma_params={"floc": 0}) + f.fit(GAMMA_DATA) + f.validate(n_mc_samples=99) + assert f["gamma"].test_result is not None + + def test_validate_pvalue_in_range(self): + f = Fitter([gamma], gamma_params={"floc": 0}) + f.fit(GAMMA_DATA) + f.validate(n_mc_samples=99) + pval = f["gamma"].pvalue + assert 0.0 <= pval <= 1.0 + + def test_validate_gof_statistic_positive(self): + f = Fitter([gamma], gamma_params={"floc": 0}) + f.fit(GAMMA_DATA) + f.validate(n_mc_samples=99) + assert f["gamma"].gof_statistic >= 0.0 + + def test_validate_custom_statistic(self): + f = Fitter([gamma], statistic_method="ks", gamma_params={"floc": 0}) + f.fit(GAMMA_DATA) + f.validate(n_mc_samples=99) + assert f["gamma"].test_result is not None + + +# ── Fitter.summary ──────────────────────────────────────────────────────────── + + +class TestFitterSummary: + def test_summary_runs_without_error(self, capsys): + f = Fitter([gamma], gamma_params={"floc": 0}) + f.fit(GAMMA_DATA) + f.validate(n_mc_samples=99) + f.summary() + captured = capsys.readouterr() + assert "gamma" in captured.out + + +# ── Fitter.plot_qq_plots ────────────────────────────────────────────────────── + + +class TestFitterPlotQQ: + def setup_method(self): + self.f = Fitter([gamma], gamma_params={"floc": 0}) + self.f.fit(GAMMA_DATA) + self.f.validate(n_mc_samples=99) + + def test_qq_without_fit_raises(self): + f = Fitter([gamma]) + with pytest.raises(RuntimeError, match="fit\\(\\)"): + f.plot_qq_plots() + + def test_qq_invalid_method_raises(self): + with pytest.raises(ValueError, match="method"): + self.f.plot_qq_plots(method="invalid") + + def test_qq_hazen_returns_figure(self): + import plotly.graph_objects as go + fig = self.f.plot_qq_plots(method="hazen") + assert isinstance(fig, go.Figure) + + def test_qq_filliben_returns_figure(self): + import plotly.graph_objects as go + fig = self.f.plot_qq_plots(method="filliben") + assert isinstance(fig, go.Figure) + + +# ── Fitter.histogram_with_fits ──────────────────────────────────────────────── + + +class TestFitterHistogram: + def setup_method(self): + self.f = Fitter([gamma], gamma_params={"floc": 0}) + self.f.fit(GAMMA_DATA) + self.f.validate(n_mc_samples=99) + + def test_histogram_without_fit_raises(self): + f = Fitter([gamma]) + with pytest.raises(RuntimeError, match="fit\\(\\)"): + f.histogram_with_fits() + + def test_histogram_returns_figure(self): + import plotly.graph_objects as go + fig = self.f.histogram_with_fits() + assert isinstance(fig, go.Figure) + + def test_histogram_seaborn_returns_figure(self): + import matplotlib.pyplot as plt + fig = self.f.histogram_with_fits_seaborn() + assert isinstance(fig, plt.Figure) + plt.close("all") diff --git a/etc/tools/plots.py b/etc/tools/plots.py index 460e974..298a595 100644 --- a/etc/tools/plots.py +++ b/etc/tools/plots.py @@ -21,13 +21,16 @@ def stacked_plot(data): data = np.squeeze(data) mean_dp = np.mean(np.abs(data), axis=1) - fig = make_subplots(rows=2, cols=1, row_heights=[0.3, 0.7], shared_xaxes=True, - vertical_spacing=0.01) + fig = make_subplots( + rows=2, cols=1, row_heights=[0.3, 0.7], shared_xaxes=True, vertical_spacing=0.01 + ) - fig.add_trace(go.Scatter(y=mean_dp, name='Mean Power'), row=1, col=1) - fig.add_trace(go.Heatmap(z=np.abs(data).T, showscale=False, name='Heat Map'), row=2, col=1) + fig.add_trace(go.Scatter(y=mean_dp, name="Mean Power"), row=1, col=1) + fig.add_trace( + go.Heatmap(z=np.abs(data).T, showscale=False, name="Heat Map"), row=2, col=1 + ) - fig.update_layout(title='Mean DP Power and 2D Map', autosize=True) + fig.update_layout(title="Mean DP Power and 2D Map", autosize=True) fig.update_xaxes(visible=False, row=2, col=1) fig.update_yaxes(visible=False, row=2, col=1) @@ -98,6 +101,8 @@ def plot_cdfs(data_list, labels): fig = go.Figure() for data, label in zip(data_list, labels): sorted_data, cdf = calculate_cdf(data) - fig.add_trace(go.Scatter(x=sorted_data, y=cdf, mode='lines', name=label)) - fig.update_layout(title='CDF of Data', xaxis_title='Value', yaxis_title='CDF', autosize=True) + fig.add_trace(go.Scatter(x=sorted_data, y=cdf, mode="lines", name=label)) + fig.update_layout( + title="CDF of Data", xaxis_title="Value", yaxis_title="CDF", autosize=True + ) return fig diff --git a/pyproject.toml b/pyproject.toml index c44b90a..e92d1ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "statsmodels", "matplotlib", "seaborn", + "ruff>=0.15.9", ] classifiers = ["Private :: Do Not Upload"] diff --git a/uv.lock b/uv.lock index bf4b40e..9b76a95 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.11" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", @@ -136,6 +136,7 @@ dependencies = [ { name = "numpy" }, { name = "pathlib2" }, { name = "plotly" }, + { name = "ruff" }, { name = "scipy" }, { name = "seaborn" }, { name = "statsmodels" }, @@ -151,6 +152,7 @@ requires-dist = [ { name = "numpy" }, { name = "pathlib2" }, { name = "plotly" }, + { name = "ruff", specifier = ">=0.15.9" }, { name = "scipy" }, { name = "seaborn" }, { name = "statsmodels" }, @@ -1562,6 +1564,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/b7/b95708304cd49b7b6f82fdd039f1748b66ec2b21d6a45180910802f1abf1/rpds_py-0.30.0-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:ac37f9f516c51e5753f27dfdef11a88330f04de2d564be3991384b2f3535d02e", size = 562191, upload-time = "2025-11-30T20:24:36.853Z" }, ] +[[package]] +name = "ruff" +version = "0.15.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e6/97/e9f1ca355108ef7194e38c812ef40ba98c7208f47b13ad78d023caa583da/ruff-0.15.9.tar.gz", hash = "sha256:29cbb1255a9797903f6dde5ba0188c707907ff44a9006eb273b5a17bfa0739a2", size = 4617361, upload-time = "2026-04-02T18:17:20.829Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/1f/9cdfd0ac4b9d1e5a6cf09bedabdf0b56306ab5e333c85c87281273e7b041/ruff-0.15.9-py3-none-linux_armv6l.whl", hash = "sha256:6efbe303983441c51975c243e26dff328aca11f94b70992f35b093c2e71801e1", size = 10511206, upload-time = "2026-04-02T18:16:41.574Z" }, + { url = "https://files.pythonhosted.org/packages/3d/f6/32bfe3e9c136b35f02e489778d94384118bb80fd92c6d92e7ccd97db12ce/ruff-0.15.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:4965bac6ac9ea86772f4e23587746f0b7a395eccabb823eb8bfacc3fa06069f7", size = 10923307, upload-time = "2026-04-02T18:17:08.645Z" }, + { url = "https://files.pythonhosted.org/packages/ca/25/de55f52ab5535d12e7aaba1de37a84be6179fb20bddcbe71ec091b4a3243/ruff-0.15.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eaf05aad70ca5b5a0a4b0e080df3a6b699803916d88f006efd1f5b46302daab8", size = 10316722, upload-time = "2026-04-02T18:16:44.206Z" }, + { url = "https://files.pythonhosted.org/packages/48/11/690d75f3fd6278fe55fff7c9eb429c92d207e14b25d1cae4064a32677029/ruff-0.15.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9439a342adb8725f32f92732e2bafb6d5246bd7a5021101166b223d312e8fc59", size = 10623674, upload-time = "2026-04-02T18:16:50.951Z" }, + { url = "https://files.pythonhosted.org/packages/bd/ec/176f6987be248fc5404199255522f57af1b4a5a1b57727e942479fec98ad/ruff-0.15.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9c5e6faf9d97c8edc43877c3f406f47446fc48c40e1442d58cfcdaba2acea745", size = 10351516, upload-time = "2026-04-02T18:16:57.206Z" }, + { url = "https://files.pythonhosted.org/packages/b2/fc/51cffbd2b3f240accc380171d51446a32aa2ea43a40d4a45ada67368fbd2/ruff-0.15.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b34a9766aeec27a222373d0b055722900fbc0582b24f39661aa96f3fe6ad901", size = 11150202, upload-time = "2026-04-02T18:17:06.452Z" }, + { url = "https://files.pythonhosted.org/packages/d6/d4/25292a6dfc125f6b6528fe6af31f5e996e19bf73ca8e3ce6eb7fa5b95885/ruff-0.15.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:89dd695bc72ae76ff484ae54b7e8b0f6b50f49046e198355e44ea656e521fef9", size = 11988891, upload-time = "2026-04-02T18:17:18.575Z" }, + { url = "https://files.pythonhosted.org/packages/13/e1/1eebcb885c10e19f969dcb93d8413dfee8172578709d7ee933640f5e7147/ruff-0.15.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ce187224ef1de1bd225bc9a152ac7102a6171107f026e81f317e4257052916d5", size = 11480576, upload-time = "2026-04-02T18:16:52.986Z" }, + { url = "https://files.pythonhosted.org/packages/ff/6b/a1548ac378a78332a4c3dcf4a134c2475a36d2a22ddfa272acd574140b50/ruff-0.15.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2b0c7c341f68adb01c488c3b7d4b49aa8ea97409eae6462d860a79cf55f431b6", size = 11254525, upload-time = "2026-04-02T18:17:02.041Z" }, + { url = "https://files.pythonhosted.org/packages/42/aa/4bb3af8e61acd9b1281db2ab77e8b2c3c5e5599bf2a29d4a942f1c62b8d6/ruff-0.15.9-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:55cc15eee27dc0eebdfcb0d185a6153420efbedc15eb1d38fe5e685657b0f840", size = 11204072, upload-time = "2026-04-02T18:17:13.581Z" }, + { url = "https://files.pythonhosted.org/packages/69/48/d550dc2aa6e423ea0bcc1d0ff0699325ffe8a811e2dba156bd80750b86dc/ruff-0.15.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:a6537f6eed5cda688c81073d46ffdfb962a5f29ecb6f7e770b2dc920598997ed", size = 10594998, upload-time = "2026-04-02T18:16:46.369Z" }, + { url = "https://files.pythonhosted.org/packages/63/47/321167e17f5344ed5ec6b0aa2cff64efef5f9e985af8f5622cfa6536043f/ruff-0.15.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:6d3fcbca7388b066139c523bda744c822258ebdcfbba7d24410c3f454cc9af71", size = 10359769, upload-time = "2026-04-02T18:17:10.994Z" }, + { url = "https://files.pythonhosted.org/packages/67/5e/074f00b9785d1d2c6f8c22a21e023d0c2c1817838cfca4c8243200a1fa87/ruff-0.15.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:058d8e99e1bfe79d8a0def0b481c56059ee6716214f7e425d8e737e412d69677", size = 10850236, upload-time = "2026-04-02T18:16:48.749Z" }, + { url = "https://files.pythonhosted.org/packages/76/37/804c4135a2a2caf042925d30d5f68181bdbd4461fd0d7739da28305df593/ruff-0.15.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:8e1ddb11dbd61d5983fa2d7d6370ef3eb210951e443cace19594c01c72abab4c", size = 11358343, upload-time = "2026-04-02T18:16:55.068Z" }, + { url = "https://files.pythonhosted.org/packages/88/3d/1364fcde8656962782aa9ea93c92d98682b1ecec2f184e625a965ad3b4a6/ruff-0.15.9-py3-none-win32.whl", hash = "sha256:bde6ff36eaf72b700f32b7196088970bf8fdb2b917b7accd8c371bfc0fd573ec", size = 10583382, upload-time = "2026-04-02T18:17:04.261Z" }, + { url = "https://files.pythonhosted.org/packages/4c/56/5c7084299bd2cacaa07ae63a91c6f4ba66edc08bf28f356b24f6b717c799/ruff-0.15.9-py3-none-win_amd64.whl", hash = "sha256:45a70921b80e1c10cf0b734ef09421f71b5aa11d27404edc89d7e8a69505e43d", size = 11744969, upload-time = "2026-04-02T18:16:59.611Z" }, + { url = "https://files.pythonhosted.org/packages/03/36/76704c4f312257d6dbaae3c959add2a622f63fcca9d864659ce6d8d97d3d/ruff-0.15.9-py3-none-win_arm64.whl", hash = "sha256:0694e601c028fd97dc5c6ee244675bc241aeefced7ef80cd9c6935a871078f53", size = 11005870, upload-time = "2026-04-02T18:17:15.773Z" }, +] + [[package]] name = "scipy" version = "1.17.1"