REFACTOR:
Fitter class refactored. Include getter and setter ADD: test/ dir with code tests
This commit is contained in:
267
etc/tests/test_fitter.py
Normal file
267
etc/tests/test_fitter.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from scipy.stats import norm, gamma, expon, weibull_min
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
from fitting.fitter import DistributionSummary, Fitter
|
||||
|
||||
|
||||
# ── Data ──────────────────────────────────────────────────────────────────
|
||||
|
||||
RNG = np.random.default_rng(42)
|
||||
GAMMA_DATA = RNG.gamma(shape=2.0, scale=1.5, size=200)
|
||||
NORM_DATA = RNG.normal(loc=5.0, scale=1.0, size=200)
|
||||
|
||||
|
||||
# ── DistributionSummary ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDistributionSummary:
|
||||
def _make_summary(self, dist=gamma, fit_params=(2.0, 0.0, 1.5)):
|
||||
return DistributionSummary(
|
||||
distribution_object=dist,
|
||||
distribution_name=dist.name,
|
||||
fit_result_params=fit_params,
|
||||
)
|
||||
|
||||
def test_mean_std_var_are_floats(self):
|
||||
s = self._make_summary()
|
||||
assert isinstance(s.mean, float)
|
||||
assert isinstance(s.std, float)
|
||||
assert isinstance(s.var, float)
|
||||
|
||||
def test_var_equals_std_squared(self):
|
||||
s = self._make_summary()
|
||||
assert pytest.approx(s.var, rel=1e-6) == s.std**2
|
||||
|
||||
def test_pvalue_none_before_validate(self):
|
||||
s = self._make_summary()
|
||||
assert s.pvalue is None
|
||||
|
||||
def test_gof_statistic_none_before_validate(self):
|
||||
s = self._make_summary()
|
||||
assert s.gof_statistic is None
|
||||
|
||||
def test_repr_contains_name(self):
|
||||
s = self._make_summary()
|
||||
assert "gamma" in repr(s)
|
||||
|
||||
def test_repr_shows_na_when_no_test(self):
|
||||
s = self._make_summary()
|
||||
assert "N/A" in repr(s)
|
||||
|
||||
def test_default_statistic_method(self):
|
||||
s = self._make_summary()
|
||||
assert s.statistic_method == "ad"
|
||||
|
||||
|
||||
# ── Fitter construction ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFitterConstruction:
|
||||
def test_distributions_registered(self):
|
||||
f = Fitter([gamma, expon])
|
||||
assert "gamma" in f
|
||||
assert "expon" in f
|
||||
|
||||
def test_getitem_by_name(self):
|
||||
f = Fitter([gamma])
|
||||
s = f["gamma"]
|
||||
assert isinstance(s, DistributionSummary)
|
||||
|
||||
def test_getitem_by_object(self):
|
||||
f = Fitter([gamma])
|
||||
s = f[gamma]
|
||||
assert isinstance(s, DistributionSummary)
|
||||
|
||||
def test_getitem_missing_raises(self):
|
||||
f = Fitter([gamma])
|
||||
with pytest.raises(KeyError):
|
||||
_ = f["norm"]
|
||||
|
||||
def test_contains_true(self):
|
||||
f = Fitter([gamma])
|
||||
assert gamma in f
|
||||
assert "gamma" in f
|
||||
|
||||
def test_contains_false(self):
|
||||
f = Fitter([gamma])
|
||||
assert norm not in f
|
||||
|
||||
def test_kwargs_args_stored(self):
|
||||
f = Fitter([gamma], gamma_args=(2.0,), gamma_params={"floc": 0})
|
||||
assert f["gamma"].args_fit_params == (2.0,)
|
||||
assert f["gamma"].kwds_fit_params == {"floc": 0}
|
||||
|
||||
def test_setitem_valid(self):
|
||||
f = Fitter([gamma])
|
||||
original = f["gamma"]
|
||||
f["gamma"] = original
|
||||
assert f["gamma"] is original
|
||||
|
||||
def test_setitem_invalid_type(self):
|
||||
f = Fitter([gamma])
|
||||
with pytest.raises(TypeError):
|
||||
f["gamma"] = "not_a_summary"
|
||||
|
||||
def test_setitem_missing_key_raises(self):
|
||||
f = Fitter([gamma])
|
||||
dummy = DistributionSummary(distribution_object=norm, distribution_name="norm")
|
||||
with pytest.raises(KeyError):
|
||||
f["norm"] = dummy
|
||||
|
||||
|
||||
# ── Fitter iteration ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFitterIteration:
|
||||
def test_iterates_all_distributions(self):
|
||||
f = Fitter([gamma, expon, weibull_min])
|
||||
names = [name for name, _ in f]
|
||||
assert set(names) == {"gamma", "expon", "weibull_min"}
|
||||
|
||||
def test_iterate_twice(self):
|
||||
f = Fitter([gamma, expon])
|
||||
first = [name for name, _ in f]
|
||||
second = [name for name, _ in f]
|
||||
assert first == second
|
||||
|
||||
def test_iterator_yields_summary_instances(self):
|
||||
f = Fitter([gamma])
|
||||
for _, summary in f:
|
||||
assert isinstance(summary, DistributionSummary)
|
||||
|
||||
|
||||
# ── Fitter.fit ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFitterFit:
|
||||
def test_fit_populates_params(self):
|
||||
f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
f.fit(GAMMA_DATA)
|
||||
assert len(f["gamma"].fit_result_params) > 0
|
||||
|
||||
def test_fit_uses_abs_value(self):
|
||||
f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
neg_data = -np.abs(GAMMA_DATA)
|
||||
f.fit(neg_data)
|
||||
assert len(f["gamma"].fit_result_params) > 0
|
||||
|
||||
def test_fit_flattens_2d(self):
|
||||
f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
data_2d = GAMMA_DATA.reshape(10, 20)
|
||||
f.fit(data_2d)
|
||||
assert hasattr(f, "_last_data_flat")
|
||||
assert f._last_data_flat.ndim == 1
|
||||
|
||||
def test_fit_multiple_dists(self):
|
||||
f = Fitter([gamma, expon], gamma_params={"floc": 0}, expon_params={"floc": 0})
|
||||
f.fit(GAMMA_DATA)
|
||||
assert len(f["gamma"].fit_result_params) > 0
|
||||
assert len(f["expon"].fit_result_params) > 0
|
||||
|
||||
|
||||
# ── Fitter.validate ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFitterValidate:
|
||||
def test_validate_without_fit_raises(self):
|
||||
f = Fitter([gamma])
|
||||
with pytest.raises(RuntimeError, match="fit\\(\\)"):
|
||||
f.validate()
|
||||
|
||||
def test_validate_populates_test_result(self):
|
||||
f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
f.fit(GAMMA_DATA)
|
||||
f.validate(n_mc_samples=99)
|
||||
assert f["gamma"].test_result is not None
|
||||
|
||||
def test_validate_pvalue_in_range(self):
|
||||
f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
f.fit(GAMMA_DATA)
|
||||
f.validate(n_mc_samples=99)
|
||||
pval = f["gamma"].pvalue
|
||||
assert 0.0 <= pval <= 1.0
|
||||
|
||||
def test_validate_gof_statistic_positive(self):
|
||||
f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
f.fit(GAMMA_DATA)
|
||||
f.validate(n_mc_samples=99)
|
||||
assert f["gamma"].gof_statistic >= 0.0
|
||||
|
||||
def test_validate_custom_statistic(self):
|
||||
f = Fitter([gamma], statistic_method="ks", gamma_params={"floc": 0})
|
||||
f.fit(GAMMA_DATA)
|
||||
f.validate(n_mc_samples=99)
|
||||
assert f["gamma"].test_result is not None
|
||||
|
||||
|
||||
# ── Fitter.summary ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFitterSummary:
|
||||
def test_summary_runs_without_error(self, capsys):
|
||||
f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
f.fit(GAMMA_DATA)
|
||||
f.validate(n_mc_samples=99)
|
||||
f.summary()
|
||||
captured = capsys.readouterr()
|
||||
assert "gamma" in captured.out
|
||||
|
||||
|
||||
# ── Fitter.plot_qq_plots ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFitterPlotQQ:
|
||||
def setup_method(self):
|
||||
self.f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
self.f.fit(GAMMA_DATA)
|
||||
self.f.validate(n_mc_samples=99)
|
||||
|
||||
def test_qq_without_fit_raises(self):
|
||||
f = Fitter([gamma])
|
||||
with pytest.raises(RuntimeError, match="fit\\(\\)"):
|
||||
f.plot_qq_plots()
|
||||
|
||||
def test_qq_invalid_method_raises(self):
|
||||
with pytest.raises(ValueError, match="method"):
|
||||
self.f.plot_qq_plots(method="invalid")
|
||||
|
||||
def test_qq_hazen_returns_figure(self):
|
||||
import plotly.graph_objects as go
|
||||
fig = self.f.plot_qq_plots(method="hazen")
|
||||
assert isinstance(fig, go.Figure)
|
||||
|
||||
def test_qq_filliben_returns_figure(self):
|
||||
import plotly.graph_objects as go
|
||||
fig = self.f.plot_qq_plots(method="filliben")
|
||||
assert isinstance(fig, go.Figure)
|
||||
|
||||
|
||||
# ── Fitter.histogram_with_fits ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFitterHistogram:
|
||||
def setup_method(self):
|
||||
self.f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
self.f.fit(GAMMA_DATA)
|
||||
self.f.validate(n_mc_samples=99)
|
||||
|
||||
def test_histogram_without_fit_raises(self):
|
||||
f = Fitter([gamma])
|
||||
with pytest.raises(RuntimeError, match="fit\\(\\)"):
|
||||
f.histogram_with_fits()
|
||||
|
||||
def test_histogram_returns_figure(self):
|
||||
import plotly.graph_objects as go
|
||||
fig = self.f.histogram_with_fits()
|
||||
assert isinstance(fig, go.Figure)
|
||||
|
||||
def test_histogram_seaborn_returns_figure(self):
|
||||
import matplotlib.pyplot as plt
|
||||
fig = self.f.histogram_with_fits_seaborn()
|
||||
assert isinstance(fig, plt.Figure)
|
||||
plt.close("all")
|
||||
Reference in New Issue
Block a user