# Source listing metadata (file-viewer residue, not part of the module): 272 lines, 9.2 KiB, Python
import numpy as np
|
|
import pytest
|
|
from scipy.stats import norm, gamma, expon, weibull_min
|
|
import sys
|
|
import os
|
|
|
|
# Put the parent of this file's directory at the front of sys.path before the
# `fitting` import that follows (presumably the project root — confirm layout).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
|
|
|
from fitting.fitter import DistributionSummary, Fitter
|
|
|
|
|
|
# ── Data ──────────────────────────────────────────────────────────────────
|
|
|
|
# One seeded generator feeds both samples, so the arrays are reproducible as
# long as they are drawn in this order (gamma first, then normal).
RNG = np.random.default_rng(seed=42)

# 200 draws from Gamma(shape=2.0, scale=1.5).
GAMMA_DATA = RNG.gamma(shape=2.0, scale=1.5, size=200)

# 200 draws from Normal(loc=5.0, scale=1.0), taken after the gamma draws.
NORM_DATA = RNG.normal(loc=5.0, scale=1.0, size=200)
|
|
|
|
|
|
# ── DistributionSummary ───────────────────────────────────────────────────────
|
|
|
|
|
|
class TestDistributionSummary:
    """Checks on a ``DistributionSummary`` constructed directly, without a ``Fitter``."""

    def _make_summary(self, dist=gamma, fit_params=(2.0, 0.0, 1.5)):
        """Build a summary for ``dist`` whose fit result is ``fit_params``.

        The name passed to the summary is taken from ``dist.name``.
        """
        return DistributionSummary(
            distribution_object=dist,
            distribution_name=dist.name,
            fit_result_params=fit_params,
        )

    def test_mean_std_var_are_floats(self):
        """``mean``, ``std`` and ``var`` are each ``float`` instances."""
        summary = self._make_summary()
        for attribute in ("mean", "std", "var"):
            assert isinstance(getattr(summary, attribute), float)

    def test_var_equals_std_squared(self):
        """``var`` agrees with ``std ** 2`` to a relative tolerance of 1e-6."""
        summary = self._make_summary()
        approx_var = pytest.approx(summary.var, rel=1e-6)
        assert approx_var == summary.std**2

    def test_pvalue_none_before_validate(self):
        """A freshly built summary reports no p-value."""
        summary = self._make_summary()
        assert summary.pvalue is None

    def test_gof_statistic_none_before_validate(self):
        """A freshly built summary reports no goodness-of-fit statistic."""
        summary = self._make_summary()
        assert summary.gof_statistic is None

    def test_repr_contains_name(self):
        """The repr mentions the distribution name."""
        text = repr(self._make_summary())
        assert "gamma" in text

    def test_repr_shows_na_when_no_test(self):
        """The repr contains ``N/A`` when no test result has been stored."""
        text = repr(self._make_summary())
        assert "N/A" in text

    def test_default_statistic_method(self):
        """``statistic_method`` defaults to ``"ad"``."""
        summary = self._make_summary()
        assert summary.statistic_method == "ad"
|
|
|
|
|
|
# ── Fitter construction ───────────────────────────────────────────────────────
|
|
|
|
|
|
class TestFitterConstruction:
    """Container-style behaviour of a ``Fitter`` that has not been fitted."""

    def test_distributions_registered(self):
        """Every distribution passed to the constructor is found by name."""
        fitter = Fitter([gamma, expon])
        for name in ("gamma", "expon"):
            assert name in fitter

    def test_getitem_by_name(self):
        """Indexing with the name string returns a ``DistributionSummary``."""
        fitter = Fitter([gamma])
        summary = fitter["gamma"]
        assert isinstance(summary, DistributionSummary)

    def test_getitem_by_object(self):
        """Indexing with the scipy distribution object returns a ``DistributionSummary``."""
        fitter = Fitter([gamma])
        summary = fitter[gamma]
        assert isinstance(summary, DistributionSummary)

    def test_getitem_missing_raises(self):
        """Looking up an unregistered name raises ``KeyError``."""
        fitter = Fitter([gamma])
        with pytest.raises(KeyError):
            _ = fitter["norm"]

    def test_contains_true(self):
        """Membership works for both the distribution object and its name."""
        fitter = Fitter([gamma])
        for key in (gamma, "gamma"):
            assert key in fitter

    def test_contains_false(self):
        """An unregistered distribution object is not a member."""
        fitter = Fitter([gamma])
        assert norm not in fitter

    def test_kwargs_args_stored(self):
        """``<name>_args`` / ``<name>_params`` keywords land on the summary."""
        fitter = Fitter([gamma], gamma_args=(2.0,), gamma_params={"floc": 0})
        summary = fitter["gamma"]
        assert summary.args_fit_params == (2.0,)
        assert summary.kwds_fit_params == {"floc": 0}

    def test_setitem_valid(self):
        """Assigning a ``DistributionSummary`` to a registered key is accepted.

        NOTE(review): the value written back is the object already stored, so
        this would also pass if ``__setitem__`` did nothing — consider
        assigning a fresh summary.
        """
        fitter = Fitter([gamma])
        existing = fitter["gamma"]
        fitter["gamma"] = existing
        assert fitter["gamma"] is existing

    def test_setitem_invalid_type(self):
        """Assigning something that is not a summary raises ``TypeError``."""
        fitter = Fitter([gamma])
        with pytest.raises(TypeError):
            fitter["gamma"] = "not_a_summary"

    def test_setitem_missing_key_raises(self):
        """Assigning to an unregistered key raises ``KeyError``."""
        fitter = Fitter([gamma])
        norm_summary = DistributionSummary(
            distribution_object=norm,
            distribution_name="norm",
        )
        with pytest.raises(KeyError):
            fitter["norm"] = norm_summary
|
|
|
|
|
|
# ── Fitter iteration ──────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestFitterIteration:
    """Iterating a ``Fitter`` yields ``(name, summary)`` pairs."""

    def test_iterates_all_distributions(self):
        """Iteration visits every registered distribution name."""
        fitter = Fitter([gamma, expon, weibull_min])
        names = [name for name, _ in fitter]
        assert set(names) == {"gamma", "expon", "weibull_min"}

    def test_iterate_twice(self):
        """A second pass over the same ``Fitter`` yields the same names in the same order."""
        fitter = Fitter([gamma, expon])
        first = [name for name, _ in fitter]
        second = [name for name, _ in fitter]
        # Guard against a vacuous pass: two empty iterations also compare equal.
        assert first, "first iteration over the Fitter yielded nothing"
        assert first == second

    def test_iterator_yields_summary_instances(self):
        """The second element of every yielded pair is a ``DistributionSummary``."""
        fitter = Fitter([gamma])
        summaries = [summary for _, summary in fitter]
        # Without this check an empty iteration would skip the loop and pass.
        assert summaries, "iterating a Fitter with one distribution yielded nothing"
        for summary in summaries:
            assert isinstance(summary, DistributionSummary)
|
|
|
|
|
|
# ── Fitter.fit ────────────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestFitterFit:
    """``Fitter.fit`` stores fit parameters on each registered summary."""

    def test_fit_populates_params(self):
        """After ``fit`` the gamma summary holds a non-empty parameter sequence."""
        fitter = Fitter([gamma], gamma_params={"floc": 0})
        fitter.fit(GAMMA_DATA)
        fitted = fitter["gamma"].fit_result_params
        assert len(fitted) > 0

    def test_fit_uses_abs_value(self):
        """Fitting all-negative data still yields parameters.

        The test name indicates ``fit`` is expected to take absolute values;
        only the presence of fitted parameters is asserted here.
        """
        fitter = Fitter([gamma], gamma_params={"floc": 0})
        negated = -np.abs(GAMMA_DATA)
        fitter.fit(negated)
        fitted = fitter["gamma"].fit_result_params
        assert len(fitted) > 0

    def test_fit_flattens_2d(self):
        """A 2-D input leaves a 1-D array in ``_last_data_flat``."""
        fitter = Fitter([gamma], gamma_params={"floc": 0})
        grid = GAMMA_DATA.reshape(10, 20)
        fitter.fit(grid)
        assert hasattr(fitter, "_last_data_flat")
        assert fitter._last_data_flat.ndim == 1

    def test_fit_multiple_dists(self):
        """Every registered distribution receives fit parameters."""
        fitter = Fitter(
            [gamma, expon],
            gamma_params={"floc": 0},
            expon_params={"floc": 0},
        )
        fitter.fit(GAMMA_DATA)
        for name in ("gamma", "expon"):
            assert len(fitter[name].fit_result_params) > 0
|
|
|
|
|
|
# ── Fitter.validate ───────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestFitterValidate:
    """``Fitter.validate`` requires a prior ``fit`` and fills in test results."""

    def _validated_fitter(self, **fitter_kwargs):
        """Return a gamma ``Fitter`` (``floc=0``) that has been fitted and validated.

        Extra keyword arguments are forwarded to the ``Fitter`` constructor.
        ``validate`` is called with ``n_mc_samples=99`` (presumably kept small
        so the test stays fast — confirm against ``Fitter.validate``).
        """
        fitter = Fitter([gamma], gamma_params={"floc": 0}, **fitter_kwargs)
        fitter.fit(GAMMA_DATA)
        fitter.validate(n_mc_samples=99)
        return fitter

    def test_validate_without_fit_raises(self):
        """Calling ``validate`` before ``fit`` raises a ``RuntimeError`` mentioning ``fit()``."""
        fitter = Fitter([gamma])
        # Raw string: the pattern is a regex and the parentheses must be escaped.
        with pytest.raises(RuntimeError, match=r"fit\(\)"):
            fitter.validate()

    def test_validate_populates_test_result(self):
        """``validate`` stores a test result on the summary."""
        fitter = self._validated_fitter()
        assert fitter["gamma"].test_result is not None

    def test_validate_pvalue_in_range(self):
        """The p-value after validation lies in ``[0, 1]``."""
        pval = self._validated_fitter()["gamma"].pvalue
        # Fail with a clear assertion instead of a TypeError in the comparison.
        assert pval is not None
        assert 0.0 <= pval <= 1.0

    def test_validate_gof_statistic_positive(self):
        """The goodness-of-fit statistic after validation is non-negative."""
        statistic = self._validated_fitter()["gamma"].gof_statistic
        # Fail with a clear assertion instead of a TypeError in the comparison.
        assert statistic is not None
        assert statistic >= 0.0

    def test_validate_custom_statistic(self):
        """Validation also produces a result with ``statistic_method="ks"``."""
        fitter = self._validated_fitter(statistic_method="ks")
        assert fitter["gamma"].test_result is not None
|
|
|
|
|
|
# ── Fitter.summary ────────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestFitterSummary:
    """``Fitter.summary`` prints a report to standard output."""

    def test_summary_runs_without_error(self, capsys):
        """After fit and validate, the printed summary mentions ``gamma``."""
        fitter = Fitter([gamma], gamma_params={"floc": 0})
        fitter.fit(GAMMA_DATA)
        fitter.validate(n_mc_samples=99)

        fitter.summary()

        printed = capsys.readouterr().out
        assert "gamma" in printed
|
|
|
|
|
|
# ── Fitter.plot_qq_plots ──────────────────────────────────────────────────────
|
|
|
|
|
|
class TestFitterPlotQQ:
    """``Fitter.plot_qq_plots`` argument checking and return type."""

    def setup_method(self):
        """Create ``self.f``: a gamma ``Fitter`` (``floc=0``), fitted and validated."""
        self.f = Fitter([gamma], gamma_params={"floc": 0})
        self.f.fit(GAMMA_DATA)
        self.f.validate(n_mc_samples=99)

    def _assert_returns_plotly_figure(self, method):
        """Assert that ``plot_qq_plots(method=method)`` returns a plotly ``Figure``."""
        # Imported locally, as in the original tests, so plotly is only
        # required by the tests that inspect the figure type.
        import plotly.graph_objects as go

        fig = self.f.plot_qq_plots(method=method)
        assert isinstance(fig, go.Figure)

    def test_qq_without_fit_raises(self):
        """Plotting before ``fit`` raises a ``RuntimeError`` mentioning ``fit()``."""
        unfitted = Fitter([gamma])
        # Raw string: the pattern is a regex and the parentheses must be escaped.
        with pytest.raises(RuntimeError, match=r"fit\(\)"):
            unfitted.plot_qq_plots()

    def test_qq_invalid_method_raises(self):
        """An unknown ``method`` raises a ``ValueError`` mentioning ``method``."""
        with pytest.raises(ValueError, match="method"):
            self.f.plot_qq_plots(method="invalid")

    def test_qq_hazen_returns_figure(self):
        """``method="hazen"`` returns a plotly ``Figure``."""
        self._assert_returns_plotly_figure("hazen")

    def test_qq_filliben_returns_figure(self):
        """``method="filliben"`` returns a plotly ``Figure``."""
        self._assert_returns_plotly_figure("filliben")
|
|
|
|
|
|
# ── Fitter.histogram_with_fits ────────────────────────────────────────────────
|
|
|
|
|
|
class TestFitterHistogram:
    """``Fitter.histogram_with_fits`` and its seaborn variant."""

    def setup_method(self):
        """Create ``self.f``: a gamma ``Fitter`` (``floc=0``), fitted and validated."""
        self.f = Fitter([gamma], gamma_params={"floc": 0})
        self.f.fit(GAMMA_DATA)
        self.f.validate(n_mc_samples=99)

    def test_histogram_without_fit_raises(self):
        """Plotting before ``fit`` raises a ``RuntimeError`` mentioning ``fit()``."""
        unfitted = Fitter([gamma])
        # Raw string: the pattern is a regex and the parentheses must be escaped.
        with pytest.raises(RuntimeError, match=r"fit\(\)"):
            unfitted.histogram_with_fits()

    def test_histogram_returns_figure(self):
        """``histogram_with_fits`` returns a plotly ``Figure``."""
        import plotly.graph_objects as go

        fig = self.f.histogram_with_fits()
        assert isinstance(fig, go.Figure)

    def test_histogram_seaborn_returns_figure(self):
        """``histogram_with_fits_seaborn`` returns a matplotlib ``Figure``."""
        import matplotlib.pyplot as plt

        try:
            fig = self.f.histogram_with_fits_seaborn()
            assert isinstance(fig, plt.Figure)
        finally:
            # Close figures even when the call or the assertion fails, so a
            # failing test does not leave open figures behind for later tests.
            plt.close("all")
|