REFACTOR:

Fitter class refactored. Includes getter and setter.
ADD:
etc/tests/ dir with code tests
This commit is contained in:
2026-04-08 21:48:19 -03:00
parent bcd8f25a62
commit d053ebf02c
5 changed files with 529 additions and 291 deletions

267
etc/tests/test_fitter.py Normal file
View File

@@ -0,0 +1,267 @@
import numpy as np
import pytest
from scipy.stats import norm, gamma, expon, weibull_min
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from fitting.fitter import DistributionSummary, Fitter
# ── Data ──────────────────────────────────────────────────────────────────
# Sample data shared by the tests below, generated once at import time from a
# single generator with a fixed seed (42) so the samples are reproducible.
# Both arrays are drawn from RNG in sequence, so reordering these statements
# would change the values each one receives.
RNG = np.random.default_rng(42)
# 200 gamma-distributed samples (shape 2.0, scale 1.5).
GAMMA_DATA = RNG.gamma(shape=2.0, scale=1.5, size=200)
# 200 normally distributed samples (loc 5.0, scale 1.0).
# NOTE(review): not referenced by any test in the visible part of this file —
# confirm it is still needed before removing it.
NORM_DATA = RNG.normal(loc=5.0, scale=1.0, size=200)
# ── DistributionSummary ───────────────────────────────────────────────────────
class TestDistributionSummary:
    """Checks on a DistributionSummary built directly, without running a fit."""

    def _make_summary(self, dist=gamma, fit_params=(2.0, 0.0, 1.5)):
        """Build a summary for *dist* with *fit_params* as its fit result."""
        init_kwargs = {
            "distribution_object": dist,
            "distribution_name": dist.name,
            "fit_result_params": fit_params,
        }
        return DistributionSummary(**init_kwargs)

    def test_mean_std_var_are_floats(self):
        """mean, std and var must each be a float instance."""
        summary = self._make_summary()
        # getattr keeps the attribute reads lazy and in mean/std/var order.
        for moment_name in ("mean", "std", "var"):
            assert isinstance(getattr(summary, moment_name), float)

    def test_var_equals_std_squared(self):
        """var must match std squared within a 1e-6 relative tolerance."""
        summary = self._make_summary()
        expected_var = pytest.approx(summary.var, rel=1e-6)
        assert expected_var == summary.std**2

    def test_pvalue_none_before_validate(self):
        """pvalue is None on a summary that has not been validated."""
        summary = self._make_summary()
        assert summary.pvalue is None

    def test_gof_statistic_none_before_validate(self):
        """gof_statistic is None on a summary that has not been validated."""
        summary = self._make_summary()
        assert summary.gof_statistic is None

    def test_repr_contains_name(self):
        """repr() of the summary mentions the distribution name."""
        text = repr(self._make_summary())
        assert "gamma" in text

    def test_repr_shows_na_when_no_test(self):
        """repr() of an unvalidated summary contains the "N/A" placeholder."""
        text = repr(self._make_summary())
        assert "N/A" in text

    def test_default_statistic_method(self):
        """statistic_method is "ad" when none is passed to the constructor."""
        summary = self._make_summary()
        assert summary.statistic_method == "ad"
# ── Fitter construction ───────────────────────────────────────────────────────
class TestFitterConstruction:
    """Container-style access on a freshly built Fitter (no fit performed)."""

    def test_distributions_registered(self):
        """Every distribution passed to the constructor is found by name."""
        fitter = Fitter([gamma, expon])
        for dist_name in ("gamma", "expon"):
            assert dist_name in fitter

    def test_getitem_by_name(self):
        """Indexing with the distribution name returns a DistributionSummary."""
        fitter = Fitter([gamma])
        found = fitter["gamma"]
        assert isinstance(found, DistributionSummary)

    def test_getitem_by_object(self):
        """Indexing with the scipy distribution object also returns a summary."""
        fitter = Fitter([gamma])
        found = fitter[gamma]
        assert isinstance(found, DistributionSummary)

    def test_getitem_missing_raises(self):
        """Looking up a distribution that was not registered raises KeyError."""
        fitter = Fitter([gamma])
        with pytest.raises(KeyError):
            fitter["norm"]

    def test_contains_true(self):
        """Membership works with both the distribution object and its name."""
        fitter = Fitter([gamma])
        assert gamma in fitter
        assert "gamma" in fitter

    def test_contains_false(self):
        """A distribution that was not registered is reported as absent."""
        fitter = Fitter([gamma])
        assert norm not in fitter

    def test_kwargs_args_stored(self):
        """<name>_args / <name>_params keywords land on the matching summary."""
        fitter = Fitter([gamma], gamma_args=(2.0,), gamma_params={"floc": 0})
        assert fitter["gamma"].args_fit_params == (2.0,)
        assert fitter["gamma"].kwds_fit_params == {"floc": 0}

    def test_setitem_valid(self):
        """Assigning a DistributionSummary to an existing key stores that object."""
        fitter = Fitter([gamma])
        existing = fitter["gamma"]
        fitter["gamma"] = existing
        assert fitter["gamma"] is existing

    def test_setitem_invalid_type(self):
        """Assigning something that is not a summary raises TypeError."""
        fitter = Fitter([gamma])
        with pytest.raises(TypeError):
            fitter["gamma"] = "not_a_summary"

    def test_setitem_missing_key_raises(self):
        """Assigning to a key that was never registered raises KeyError."""
        fitter = Fitter([gamma])
        unregistered = DistributionSummary(
            distribution_object=norm,
            distribution_name="norm",
        )
        with pytest.raises(KeyError):
            fitter["norm"] = unregistered
# ── Fitter iteration ──────────────────────────────────────────────────────────
class TestFitterIteration:
    """Iterating a Fitter yields (name, summary) pairs for its distributions."""

    def test_iterates_all_distributions(self):
        """Iteration covers exactly the names passed to the constructor."""
        f = Fitter([gamma, expon, weibull_min])
        names = [name for name, _ in f]
        assert set(names) == {"gamma", "expon", "weibull_min"}

    def test_iterate_twice(self):
        """A second pass over the same Fitter yields the same names again."""
        f = Fitter([gamma, expon])
        first = [name for name, _ in f]
        second = [name for name, _ in f]
        # Two empty passes would compare equal, so require real content first.
        assert first, "iteration yielded no distributions"
        assert first == second

    def test_iterator_yields_summary_instances(self):
        """Every value produced by iteration is a DistributionSummary."""
        f = Fitter([gamma])
        summaries = [summary for _, summary in f]
        # Without this check an empty iterator would pass the test vacuously.
        assert summaries, "iteration yielded no distributions"
        for summary in summaries:
            assert isinstance(summary, DistributionSummary)
# ── Fitter.fit ────────────────────────────────────────────────────────────────
class TestFitterFit:
    """Fitter.fit on gamma-distributed sample data."""

    @staticmethod
    def _gamma_fitter():
        """Return a gamma-only Fitter with the location parameter fixed at 0."""
        return Fitter([gamma], gamma_params={"floc": 0})

    def test_fit_populates_params(self):
        """After fit(), the gamma summary holds a non-empty parameter set."""
        fitter = self._gamma_fitter()
        fitter.fit(GAMMA_DATA)
        assert len(fitter["gamma"].fit_result_params) > 0

    def test_fit_uses_abs_value(self):
        """Fitting all-negative data still produces parameters.

        NOTE(review): the test name implies fit() takes absolute values of its
        input; only the presence of fitted parameters is checked here.
        """
        fitter = self._gamma_fitter()
        negated = np.negative(np.abs(GAMMA_DATA))
        fitter.fit(negated)
        assert len(fitter["gamma"].fit_result_params) > 0

    def test_fit_flattens_2d(self):
        """After fitting a 10x20 array, _last_data_flat is one-dimensional."""
        fitter = self._gamma_fitter()
        grid = np.reshape(GAMMA_DATA, (10, 20))
        fitter.fit(grid)
        assert hasattr(fitter, "_last_data_flat")
        assert fitter._last_data_flat.ndim == 1

    def test_fit_multiple_dists(self):
        """Each registered distribution gets parameters from a single fit()."""
        fitter = Fitter(
            [gamma, expon],
            gamma_params={"floc": 0},
            expon_params={"floc": 0},
        )
        fitter.fit(GAMMA_DATA)
        for dist_name in ("gamma", "expon"):
            assert len(fitter[dist_name].fit_result_params) > 0
# ── Fitter.validate ───────────────────────────────────────────────────────────
class TestFitterValidate:
    """Fitter.validate: precondition and the results it stores."""

    @staticmethod
    def _validated_gamma(**fitter_kwargs):
        """Fit gamma (floc=0) to GAMMA_DATA, validate with 99 samples, return it."""
        fitter = Fitter([gamma], **fitter_kwargs, gamma_params={"floc": 0})
        fitter.fit(GAMMA_DATA)
        fitter.validate(n_mc_samples=99)
        return fitter

    def test_validate_without_fit_raises(self):
        """validate() before fit() raises RuntimeError mentioning fit()."""
        fitter = Fitter([gamma])
        with pytest.raises(RuntimeError, match="fit\\(\\)"):
            fitter.validate()

    def test_validate_populates_test_result(self):
        """validate() leaves a non-None test_result on the summary."""
        fitter = self._validated_gamma()
        assert fitter["gamma"].test_result is not None

    def test_validate_pvalue_in_range(self):
        """The p-value after validation lies in the closed interval [0, 1]."""
        fitter = self._validated_gamma()
        p_value = fitter["gamma"].pvalue
        assert 0.0 <= p_value <= 1.0

    def test_validate_gof_statistic_positive(self):
        """The goodness-of-fit statistic after validation is not negative."""
        fitter = self._validated_gamma()
        assert fitter["gamma"].gof_statistic >= 0.0

    def test_validate_custom_statistic(self):
        """validate() also stores a result when statistic_method is "ks"."""
        fitter = self._validated_gamma(statistic_method="ks")
        assert fitter["gamma"].test_result is not None
# ── Fitter.summary ────────────────────────────────────────────────────────────
class TestFitterSummary:
    """Fitter.summary output after a fit and a validation run."""

    def test_summary_runs_without_error(self, capsys):
        """summary() completes and its stdout output mentions "gamma"."""
        fitter = Fitter([gamma], gamma_params={"floc": 0})
        fitter.fit(GAMMA_DATA)
        fitter.validate(n_mc_samples=99)
        fitter.summary()
        printed = capsys.readouterr().out
        assert "gamma" in printed
# ── Fitter.plot_qq_plots ──────────────────────────────────────────────────────
class TestFitterPlotQQ:
    """Fitter.plot_qq_plots: preconditions, argument checks and return type."""

    def setup_method(self):
        """Give each test a gamma Fitter that is already fitted and validated."""
        fitted = Fitter([gamma], gamma_params={"floc": 0})
        fitted.fit(GAMMA_DATA)
        fitted.validate(n_mc_samples=99)
        self.f = fitted

    def _qq_returns_plotly_figure(self, method):
        """Tell whether plot_qq_plots(method=...) returns a plotly Figure."""
        # Imported lazily so that only the figure tests depend on plotly.
        import plotly.graph_objects as go

        fig = self.f.plot_qq_plots(method=method)
        return isinstance(fig, go.Figure)

    def test_qq_without_fit_raises(self):
        """Plotting on an unfitted Fitter raises RuntimeError mentioning fit()."""
        unfitted = Fitter([gamma])
        with pytest.raises(RuntimeError, match="fit\\(\\)"):
            unfitted.plot_qq_plots()

    def test_qq_invalid_method_raises(self):
        """An unrecognised method name raises ValueError mentioning "method"."""
        with pytest.raises(ValueError, match="method"):
            self.f.plot_qq_plots(method="invalid")

    def test_qq_hazen_returns_figure(self):
        """method="hazen" produces a plotly Figure."""
        assert self._qq_returns_plotly_figure("hazen")

    def test_qq_filliben_returns_figure(self):
        """method="filliben" produces a plotly Figure."""
        assert self._qq_returns_plotly_figure("filliben")
# ── Fitter.histogram_with_fits ────────────────────────────────────────────────
class TestFitterHistogram:
    """Fitter histogram plots: precondition and return types."""

    def setup_method(self):
        """Give each test a gamma Fitter that is already fitted and validated."""
        self.f = Fitter([gamma], gamma_params={"floc": 0})
        self.f.fit(GAMMA_DATA)
        self.f.validate(n_mc_samples=99)

    def test_histogram_without_fit_raises(self):
        """Plotting on an unfitted Fitter raises RuntimeError mentioning fit()."""
        f = Fitter([gamma])
        with pytest.raises(RuntimeError, match="fit\\(\\)"):
            f.histogram_with_fits()

    def test_histogram_returns_figure(self):
        """histogram_with_fits() returns a plotly Figure."""
        # Imported lazily so that only this test depends on plotly.
        import plotly.graph_objects as go

        fig = self.f.histogram_with_fits()
        assert isinstance(fig, go.Figure)

    def test_histogram_seaborn_returns_figure(self):
        """histogram_with_fits_seaborn() returns a matplotlib Figure."""
        import matplotlib.pyplot as plt

        try:
            fig = self.f.histogram_with_fits_seaborn()
            assert isinstance(fig, plt.Figure)
        finally:
            # Close figures even when the call or the assertion fails, so open
            # pyplot figures do not leak into later tests.
            plt.close("all")