REFACTOR:
Fitter class refactored to include a flexible getter and setter. ADD: tests/ directory with unit tests for the code.
This commit is contained in:
@@ -5,6 +5,7 @@ import plotly.graph_objects as go
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from scipy.stats import rv_continuous, goodness_of_fit
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -44,79 +45,41 @@ class DistributionSummary:
|
||||
args_fit_params: tuple = field(default_factory=tuple)
|
||||
kwds_fit_params: dict = field(default_factory=dict)
|
||||
fit_result_params: tuple = field(default_factory=tuple)
|
||||
statistic_method: str = 'ad'
|
||||
statistic_method: str = "ad"
|
||||
test_result: object = None
|
||||
|
||||
# ── properties ────────────────────────────────────────────────
|
||||
# ── convenience properties ────────────────────────────────────────────────
|
||||
|
||||
@property
|
||||
def pvalue(self) -> float | None:
|
||||
"""p-value from the goodness-of-fit test.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float or None
|
||||
The p-value produced by the GoF test, or None if validate() has
|
||||
not been called yet.
|
||||
"""
|
||||
"""p-value from the goodness-of-fit test, or None if not yet run."""
|
||||
return self.test_result.pvalue if self.test_result is not None else None
|
||||
|
||||
@property
|
||||
def gof_statistic(self) -> float | None:
|
||||
"""Test statistic from the goodness-of-fit test.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float or None
|
||||
The GoF statistic value, or None if validate() has not been
|
||||
called yet.
|
||||
"""
|
||||
"""Test statistic from the goodness-of-fit test, or None if not yet run."""
|
||||
return self.test_result.statistic if self.test_result is not None else None
|
||||
|
||||
@property
|
||||
def mean(self) -> float:
|
||||
"""Mean of the fitted distribution.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
Mean computed from the fitted distribution parameters.
|
||||
"""
|
||||
return self.distribution_object.mean(*self.fit_result_params)
|
||||
"""Mean of the fitted distribution."""
|
||||
return float(self.distribution_object.mean(*self.fit_result_params))
|
||||
|
||||
@property
|
||||
def std(self) -> float:
|
||||
"""Standard deviation of the fitted distribution.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
Standard deviation computed from the fitted distribution parameters.
|
||||
"""
|
||||
return self.distribution_object.std(*self.fit_result_params)
|
||||
"""Standard deviation of the fitted distribution."""
|
||||
return float(self.distribution_object.std(*self.fit_result_params))
|
||||
|
||||
@property
|
||||
def var(self) -> float:
|
||||
"""Variance of the fitted distribution.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
Variance computed from the fitted distribution parameters.
|
||||
"""
|
||||
return self.distribution_object.var(*self.fit_result_params)
|
||||
"""Variance of the fitted distribution."""
|
||||
return float(self.distribution_object.var(*self.fit_result_params))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a concise string representation of the distribution summary.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Single-line string showing the distribution name, fitted
|
||||
parameters, mean, standard deviation, GoF statistic, and p-value.
|
||||
"""
|
||||
pval_str = f"{self.pvalue:.4f}" if self.pvalue is not None else "N/A"
|
||||
stat_str = f"{self.gof_statistic:.4f}" if self.gof_statistic is not None else "N/A"
|
||||
stat_str = (
|
||||
f"{self.gof_statistic:.4f}" if self.gof_statistic is not None else "N/A"
|
||||
)
|
||||
return (
|
||||
f"DistributionSummary({self.distribution_name})"
|
||||
f" params={self.fit_result_params}"
|
||||
@@ -125,13 +88,6 @@ class DistributionSummary:
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the same string representation as ``__repr__``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Delegates to :meth:`__repr__`.
|
||||
"""
|
||||
return self.__repr__()
|
||||
|
||||
|
||||
@@ -156,347 +112,329 @@ class Fitter:
|
||||
gamma_params={'floc': 0},
|
||||
weibull_min_args=(1.5, 0.0, 1.0),
|
||||
)
|
||||
|
||||
Access fitted distributions
|
||||
---------------------------
|
||||
Distributions can be accessed by name (str) or by the rv_continuous object:
|
||||
|
||||
fitter = Fitter([gamma, weibull_min])
|
||||
fitter.fit(data)
|
||||
|
||||
# Access by name
|
||||
summary = fitter['gamma']
|
||||
|
||||
# Access by rv_continuous object
|
||||
summary = fitter[gamma]
|
||||
|
||||
# Check if a distribution is present
|
||||
gamma in fitter # True
|
||||
'weibull_min' in fitter # True
|
||||
|
||||
# Override a DistributionSummary
|
||||
fitter['gamma'] = new_summary
|
||||
"""
|
||||
|
||||
def __init__(self, dist_list: list[rv_continuous], statistic_method: str = 'ad', **kwargs):
|
||||
"""Initialise the Fitter and build per-distribution summary objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dist_list : list[rv_continuous]
|
||||
Distributions to fit.
|
||||
statistic_method : str, optional
|
||||
Goodness-of-fit statistic passed to ``goodness_of_fit``
|
||||
(default ``'ad'`` for Anderson-Darling).
|
||||
**kwargs :
|
||||
Per-distribution initial guesses and fixed parameters, keyed as
|
||||
``<dist.name>_args`` (tuple) or ``<dist.name>_params`` (dict).
|
||||
"""
|
||||
def __init__(
|
||||
self, dist_list: list[rv_continuous], statistic_method: str = "ad", **kwargs
|
||||
):
|
||||
self._dist: dict[str, DistributionSummary] = {}
|
||||
self.dist_list = list(dist_list)
|
||||
self.dist_list = list(dist_list) # Ensure it's a list for multiple iterations
|
||||
for dist in dist_list:
|
||||
self._dist[dist.name] = DistributionSummary(
|
||||
distribution_object=dist,
|
||||
distribution_name=dist.name,
|
||||
args_fit_params=kwargs.get(f'{dist.name}_args', ()),
|
||||
kwds_fit_params=kwargs.get(f'{dist.name}_params', {}),
|
||||
args_fit_params=kwargs.get(f"{dist.name}_args", ()),
|
||||
kwds_fit_params=kwargs.get(f"{dist.name}_params", {}),
|
||||
statistic_method=statistic_method,
|
||||
test_result=None,
|
||||
)
|
||||
|
||||
#── Getter and setter with flexible keys (str or rv_continuous) ────────────────────────────────
|
||||
def _resolve_key(self, key: str | rv_continuous) -> str:
|
||||
"""Resolve a distribution name or ``rv_continuous`` object to its string key.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str or rv_continuous
|
||||
Either the distribution's string name or its ``rv_continuous``
|
||||
object (whose ``.name`` attribute is used).
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The string key used in the internal ``_dist`` dictionary.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If the resolved name is not found among the registered
|
||||
distributions.
|
||||
"""
|
||||
"""Resolve a distribution name or rv_continuous object to its string key."""
|
||||
name = key.name if isinstance(key, rv_continuous) else key
|
||||
if name not in self._dist:
|
||||
available = ', '.join(self._dist)
|
||||
available = ", ".join(self._dist)
|
||||
raise KeyError(f"Distribution '{name}' not found. Available: {available}")
|
||||
return name
|
||||
|
||||
def __contains__(self, key: str | rv_continuous) -> bool:
|
||||
"""Check whether a distribution is registered with this Fitter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str or rv_continuous
|
||||
Distribution name or ``rv_continuous`` object to look up.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the distribution is registered, False otherwise.
|
||||
"""
|
||||
name = key.name if isinstance(key, rv_continuous) else key
|
||||
return name in self._dist
|
||||
|
||||
def __getitem__(self, key: str | rv_continuous) -> DistributionSummary:
|
||||
"""Retrieve the :class:`DistributionSummary` for the given distribution.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str or rv_continuous
|
||||
Distribution name or ``rv_continuous`` object to retrieve.
|
||||
|
||||
Returns
|
||||
-------
|
||||
DistributionSummary
|
||||
The summary object associated with the requested distribution.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If the distribution is not registered.
|
||||
"""
|
||||
return self._dist[self._resolve_key(key)]
|
||||
|
||||
def __setitem__(self, key: str | rv_continuous, summary: DistributionSummary) -> None:
|
||||
"""Override the :class:`DistributionSummary` for an existing distribution.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str or rv_continuous
|
||||
Distribution name or ``rv_continuous`` object to update.
|
||||
summary : DistributionSummary
|
||||
Replacement summary object.
|
||||
|
||||
Raises
|
||||
------
|
||||
TypeError
|
||||
If ``summary`` is not a :class:`DistributionSummary` instance.
|
||||
KeyError
|
||||
If the distribution is not registered.
|
||||
"""
|
||||
def __setitem__(
|
||||
self, key: str | rv_continuous, summary: DistributionSummary
|
||||
) -> None:
|
||||
if not isinstance(summary, DistributionSummary):
|
||||
raise TypeError(f"Expected DistributionSummary, got {type(summary).__name__}.")
|
||||
raise TypeError(
|
||||
f"Expected DistributionSummary, got {type(summary).__name__}."
|
||||
)
|
||||
self._dist[self._resolve_key(key)] = summary
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if not hasattr(self, "_iter_index"):
|
||||
self._iter_index = 0
|
||||
if self._iter_index >= len(self._dist):
|
||||
del self._iter_index
|
||||
raise StopIteration
|
||||
key = list(self._dist.keys())[self._iter_index]
|
||||
summary = self._dist[key]
|
||||
self._iter_index += 1
|
||||
return key, summary
|
||||
|
||||
def fit(self, data: np.ndarray) -> None:
|
||||
"""Fit every distribution to *data* via MLE.
|
||||
def fit(self, data: np.ndarray) -> dict[str, DistributionSummary]:
|
||||
"""
|
||||
Fit every distribution to *data* via MLE.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array-like
|
||||
Input data. Only the absolute value is used; the array is
|
||||
flattened before fitting.
|
||||
|
||||
Input data. Only the absolute value is used.
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
Fitted parameters are stored in-place inside each
|
||||
:class:`DistributionSummary` held by this Fitter.
|
||||
dict[str, DistributionSummary]
|
||||
Mapping of distribution name → summary (test_result is None
|
||||
until validate() is called).
|
||||
"""
|
||||
data_flat = np.abs(data).flatten()
|
||||
self._last_data_flat = data_flat
|
||||
|
||||
for dist in self.dist_list:
|
||||
_summary = self._dist[dist.name]
|
||||
fit_params = dist.fit(data_flat, *_summary.args_fit_params, **_summary.kwds_fit_params)
|
||||
fit_params = dist.fit(
|
||||
data_flat, *_summary.args_fit_params, **_summary.kwds_fit_params
|
||||
)
|
||||
_summary.fit_result_params = fit_params
|
||||
self._dist[dist.name] = _summary
|
||||
|
||||
def validate(self, **kwargs) -> None:
|
||||
"""Run the goodness-of-fit test on every previously fitted distribution.
|
||||
def validate(self, **kwargs) -> dict[str, DistributionSummary]:
|
||||
"""
|
||||
Run the goodness-of-fit test on every previously fitted distribution.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
**kwargs :
|
||||
Extra keyword arguments forwarded to ``scipy.stats.goodness_of_fit``
|
||||
(e.g. ``n_mc_samples=100``).
|
||||
Extra keyword arguments forwarded to goodness_of_fit()
|
||||
(e.g. n_mc_samples=100).
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
Test results are stored in-place inside each
|
||||
:class:`DistributionSummary` held by this Fitter.
|
||||
|
||||
Raises
|
||||
------
|
||||
RuntimeError
|
||||
If :meth:`fit` has not been called before this method.
|
||||
dict[str, DistributionSummary]
|
||||
Same results dict, with test_result populated for each distribution.
|
||||
"""
|
||||
if not hasattr(self, '_last_data_flat'):
|
||||
if not hasattr(self, "_last_data_flat"):
|
||||
raise RuntimeError("No data available. Call fit() first.")
|
||||
|
||||
data_flat = self._last_data_flat
|
||||
for dist in self.dist_list:
|
||||
_summary = self._dist[dist.name]
|
||||
test_result = goodness_of_fit(
|
||||
dist,
|
||||
data_flat,
|
||||
statistic=_summary.statistic_method,
|
||||
**kwargs
|
||||
dist, data_flat, statistic=_summary.statistic_method, **kwargs
|
||||
)
|
||||
_summary.test_result = test_result
|
||||
self._dist[dist.name] = _summary
|
||||
|
||||
def summary(self) -> None:
|
||||
"""Print a one-line summary for each registered distribution.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
Output is written to stdout via ``print``.
|
||||
def summary(self) -> dict[str, DistributionSummary]:
|
||||
"""
|
||||
Print a summary of all fitted distributions, including parameters and GoF results.
|
||||
"""
|
||||
for dist_name, summary in self._dist.items():
|
||||
print(summary)
|
||||
|
||||
def plot_qq_plots(self) -> None:
|
||||
"""Generate QQ plots for each fitted distribution against the data.
|
||||
|
||||
A separate interactive Plotly figure is displayed for every
|
||||
distribution that has been both fitted and validated. Distributions
|
||||
that have not yet been validated are skipped with a printed warning.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
Figures are rendered inline / in a browser via ``fig.show()``.
|
||||
|
||||
Raises
|
||||
------
|
||||
RuntimeError
|
||||
If :meth:`fit` has not been called before this method.
|
||||
def plot_qq_plots(self, method: str = "hazen"):
|
||||
"""
|
||||
if not hasattr(self, '_last_data_flat'):
|
||||
Generate QQ plots for each fitted distribution against the data.
|
||||
Requires that fit() and validate() have been called to populate parameters and test results.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
method : str
|
||||
Plotting positions formula. Either 'hazen' (default) or 'filliben'.
|
||||
- 'hazen' : p_i = (i - 0.5) / n
|
||||
- 'filliben': p_1 = 1 - 0.5^(1/n), p_n = 0.5^(1/n),
|
||||
p_i = (i - 0.3175) / (n + 0.365) for 1 < i < n
|
||||
"""
|
||||
if not hasattr(self, "_last_data_flat"):
|
||||
raise RuntimeError("No data available. Call fit() first.")
|
||||
if method not in ("hazen", "filliben"):
|
||||
raise ValueError(f"method must be 'hazen' or 'filliben', got '{method}'.")
|
||||
|
||||
data_flat = self._last_data_flat
|
||||
for dist_name, summary in self._dist.items():
|
||||
# Generate subplots with 2 columns and as many rows as needed, capped at 3 rows. TODO: create multiple figures when there are more than 6 distributions (not implemented yet).
|
||||
num_dists = len(self._dist)
|
||||
num_cols = 2
|
||||
num_rows = min(3, (num_dists + 1) // 2)
|
||||
fig = make_subplots(
|
||||
rows=num_rows,
|
||||
cols=num_cols,
|
||||
subplot_titles=[dist_name for dist_name in self._dist.keys()],
|
||||
)
|
||||
|
||||
for dist_name, summary in self:
|
||||
if summary.test_result is None:
|
||||
print(f"Distribution '{dist_name}' has not been validated yet. Skipping QQ plot.")
|
||||
print(
|
||||
f"Distribution '{dist_name}' has not been validated yet. Skipping QQ plot."
|
||||
)
|
||||
continue
|
||||
|
||||
# Generate theoretical quantiles
|
||||
sorted_data = np.sort(data_flat)
|
||||
n = len(sorted_data)
|
||||
i = np.arange(1, n + 1)
|
||||
if method == "hazen":
|
||||
plotting_positions = (i - 0.5) / n
|
||||
else: # filliben
|
||||
plotting_positions = (i - 0.3175) / (n + 0.365)
|
||||
plotting_positions[0] = 1 - 0.5 ** (1 / n)
|
||||
plotting_positions[-1] = 0.5 ** (1 / n)
|
||||
theoretical_quantiles = summary.distribution_object.ppf(
|
||||
(np.arange(1, len(sorted_data) + 1) - 0.5) / len(sorted_data),
|
||||
*summary.fit_result_params
|
||||
plotting_positions, *summary.fit_result_params
|
||||
)
|
||||
|
||||
fig = go.Figure()
|
||||
fig.add_trace(go.Scatter(x=theoretical_quantiles, y=sorted_data, mode='markers', name='Data vs. Fit'))
|
||||
fig.add_trace(go.Scatter(x=theoretical_quantiles, y=theoretical_quantiles, mode='lines', name='Ideal Fit', line=dict(dash='dash')))
|
||||
fig.update_layout(
|
||||
title=f'QQ Plot for {summary.distribution_name}',
|
||||
xaxis_title='Theoretical Quantiles',
|
||||
yaxis_title='Empirical Quantiles',
|
||||
autosize=True,
|
||||
# Create QQ plot in each subplot
|
||||
row = (list(self._dist.keys()).index(dist_name) // num_cols) + 1
|
||||
col = (list(self._dist.keys()).index(dist_name) % num_cols) + 1
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=theoretical_quantiles,
|
||||
y=sorted_data,
|
||||
mode="markers",
|
||||
name=dist_name,
|
||||
),
|
||||
row=row,
|
||||
col=col,
|
||||
)
|
||||
fig.show()
|
||||
# Add a reference line y=x
|
||||
min_val = min(theoretical_quantiles.min(), sorted_data.min())
|
||||
max_val = max(theoretical_quantiles.max(), sorted_data.max())
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=[min_val, max_val],
|
||||
y=[min_val, max_val],
|
||||
mode="lines",
|
||||
name="y=x",
|
||||
line=dict(dash="dash"),
|
||||
),
|
||||
row=row,
|
||||
col=col,
|
||||
)
|
||||
fig.update_xaxes(title_text="Theoretical Quantiles", row=row, col=col)
|
||||
fig.update_yaxes(title_text="Empirical Quantiles", row=row, col=col)
|
||||
# Add the GoF statistic value in the bottom right of each subplot (summary.gof_statistic), colored by whether the p-value exceeds 0.05
|
||||
fig.add_annotation(
|
||||
x=0.95,
|
||||
y=0.05,
|
||||
xref="x domain",
|
||||
yref="y domain",
|
||||
text=f"{summary.statistic_method}={summary.gof_statistic:.4f}",
|
||||
showarrow=False,
|
||||
font=dict(size=10, color="green" if summary.pvalue > 0.05 else "red"),
|
||||
row=row,
|
||||
col=col,
|
||||
)
|
||||
fig.update_layout(
|
||||
title=f"QQ Plots of Fitted Distributions ({method})",
|
||||
showlegend=False,
|
||||
autosize=True,
|
||||
)
|
||||
return fig
|
||||
|
||||
def histogram_with_fits(self) -> go.Figure:
|
||||
"""Return an interactive histogram with overlaid fitted PDFs (Plotly).
|
||||
|
||||
Builds a probability-density histogram of the data and overlays a
|
||||
line trace for the PDF of each fitted distribution. Hover text shows
|
||||
the p-value and GoF statistic for each curve. Distributions that
|
||||
have not yet been fitted are skipped with a printed warning.
|
||||
|
||||
Returns
|
||||
-------
|
||||
plotly.graph_objects.Figure
|
||||
Interactive Plotly figure ready to display with ``fig.show()``.
|
||||
|
||||
Raises
|
||||
------
|
||||
RuntimeError
|
||||
If :meth:`fit` has not been called before this method.
|
||||
def histogram_with_fits(self):
|
||||
"""
|
||||
if not hasattr(self, '_last_data_flat'):
|
||||
Generate a histogram of the data with overlaid PDFs of each fitted distribution.
|
||||
Requires that fit() has been called to populate parameters.
|
||||
"""
|
||||
if not hasattr(self, "_last_data_flat"):
|
||||
raise RuntimeError("No data available. Call fit() first.")
|
||||
|
||||
data_flat = self._last_data_flat
|
||||
x = np.linspace(0, data_flat.max(), 1000)
|
||||
|
||||
fig = go.Figure(layout=go.Layout(hovermode='x unified'))
|
||||
fig.add_trace(go.Histogram(
|
||||
x=data_flat, name='Data', opacity=0.3,
|
||||
histnorm='probability density', hoverinfo='skip', marker_color='blue',
|
||||
))
|
||||
fig = go.Figure(layout=go.Layout(hovermode="x unified"))
|
||||
# Do not show the histogram data in the hover; show only the distribution name, p-value and GoF statistic for each distribution
|
||||
fig.add_trace(
|
||||
go.Histogram(
|
||||
x=data_flat,
|
||||
name="Data",
|
||||
opacity=0.3,
|
||||
histnorm="probability density",
|
||||
hoverinfo="skip",
|
||||
marker_color="blue",
|
||||
)
|
||||
)
|
||||
for dist_name, summary in self._dist.items():
|
||||
if not summary.fit_result_params:
|
||||
print(f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay.")
|
||||
print(
|
||||
f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay."
|
||||
)
|
||||
continue
|
||||
|
||||
pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params)
|
||||
fig.add_trace(go.Scatter(x=x, y=pdf_values, mode='lines', name=f'{summary.distribution_name} Fit'))
|
||||
# Add the PDF trace; with the unified x hover (as in stock-price charts), its hover text shows the p-value and GoF statistic of the distribution instead of the y value
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=x,
|
||||
y=pdf_values,
|
||||
mode="lines",
|
||||
name=f"{summary.distribution_name} Fit",
|
||||
)
|
||||
)
|
||||
hover_text = [
|
||||
f"<b>{summary.distribution_name}</b> p-value: {summary.pvalue:.4f} GoF: {summary.gof_statistic:.4f}"
|
||||
for _ in x
|
||||
]
|
||||
fig.data[-1].update(hovertext=hover_text, hoverinfo="text")
|
||||
|
||||
# add grid
|
||||
fig.update_layout(xaxis=dict(showgrid=True), yaxis=dict(showgrid=True))
|
||||
# Put the title in the top left, make it smaller, set its font to sans-serif and color it dark gray
|
||||
fig.update_layout(
|
||||
xaxis=dict(showgrid=True),
|
||||
yaxis=dict(showgrid=True),
|
||||
title=dict(
|
||||
text='Histogram of Data with Fitted Distributions',
|
||||
x=0.02, y=0.95, xanchor='left', yanchor='top',
|
||||
font=dict(size=20, color='darkgray', family='sans-serif'),
|
||||
text="Histogram of Data with Fitted Distributions",
|
||||
x=0.02,
|
||||
y=0.95,
|
||||
xanchor="left",
|
||||
yanchor="top",
|
||||
font=dict(size=20, color="darkgray", family="sans-serif"),
|
||||
),
|
||||
xaxis_title='Value',
|
||||
yaxis_title='Density',
|
||||
xaxis_title="Value",
|
||||
yaxis_title="Density",
|
||||
autosize=True,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def histogram_with_fits_seaborn(self) -> plt.Figure:
|
||||
"""Return a static histogram with overlaid fitted PDFs (Matplotlib/Seaborn).
|
||||
|
||||
Builds a probability-density histogram using Seaborn and overlays a
|
||||
line for the PDF of each fitted distribution. The legend entry for
|
||||
each distribution includes its p-value. Distributions that have not
|
||||
yet been fitted are skipped with a printed warning.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matplotlib.figure.Figure
|
||||
Matplotlib figure object.
|
||||
|
||||
Raises
|
||||
------
|
||||
RuntimeError
|
||||
If :meth:`fit` has not been called before this method.
|
||||
def histogram_with_fits_seaborn(self):
|
||||
"""
|
||||
if not hasattr(self, '_last_data_flat'):
|
||||
Generate a histogram of the data with overlaid PDFs of each fitted distribution using seaborn.
|
||||
Requires that fit() has been called to populate parameters.
|
||||
"""
|
||||
if not hasattr(self, "_last_data_flat"):
|
||||
raise RuntimeError("No data available. Call fit() first.")
|
||||
|
||||
data_flat = self._last_data_flat
|
||||
x = np.linspace(0, data_flat.max(), 1000)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
sns.histplot(data_flat, bins=int(np.sqrt(len(data_flat))), kde=False, stat='density', color='blue', alpha=0.2, ax=ax)
|
||||
sns.histplot(
|
||||
data_flat,
|
||||
bins=int(np.sqrt(len(data_flat))),
|
||||
kde=False,
|
||||
stat="density",
|
||||
color="blue",
|
||||
alpha=0.2,
|
||||
ax=ax,
|
||||
)
|
||||
|
||||
for dist_name, summary in self._dist.items():
|
||||
if not summary.fit_result_params:
|
||||
print(f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay.")
|
||||
print(
|
||||
f"Distribution '{dist_name}' has not been fitted yet. Skipping PDF overlay."
|
||||
)
|
||||
continue
|
||||
|
||||
pdf_values = summary.distribution_object.pdf(x, *summary.fit_result_params)
|
||||
ax.plot(x, pdf_values, label=f'{summary.distribution_name} --- p={summary.pvalue:.4f}')
|
||||
|
||||
ax.set_title('Histogram of Data with Fitted Distributions', fontsize=8, loc='left', color='darkgray', fontfamily='sans-serif')
|
||||
ax.set_xlabel('Value')
|
||||
ax.set_ylabel('Density')
|
||||
ax.plot(
|
||||
x,
|
||||
pdf_values,
|
||||
label=f"{summary.distribution_name} --- p={summary.pvalue:.4f}",
|
||||
)
|
||||
# Put the title in the top left, make it smaller, set its font to sans-serif and color it dark gray
|
||||
ax.set_title(
|
||||
"Histogram of Data with Fitted Distributions",
|
||||
fontsize=8,
|
||||
loc="left",
|
||||
color="darkgray",
|
||||
fontfamily="sans-serif",
|
||||
)
|
||||
ax.set_xlabel("Value")
|
||||
ax.set_ylabel("Density")
|
||||
ax.legend()
|
||||
ax.grid(True)
|
||||
|
||||
return fig
|
||||
|
||||
267
etc/tests/test_fitter.py
Normal file
267
etc/tests/test_fitter.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from scipy.stats import norm, gamma, expon, weibull_min
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
from fitting.fitter import DistributionSummary, Fitter
|
||||
|
||||
|
||||
# ── Data ──────────────────────────────────────────────────────────────────
|
||||
|
||||
RNG = np.random.default_rng(42)
|
||||
GAMMA_DATA = RNG.gamma(shape=2.0, scale=1.5, size=200)
|
||||
NORM_DATA = RNG.normal(loc=5.0, scale=1.0, size=200)
|
||||
|
||||
|
||||
# ── DistributionSummary ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDistributionSummary:
|
||||
def _make_summary(self, dist=gamma, fit_params=(2.0, 0.0, 1.5)):
|
||||
return DistributionSummary(
|
||||
distribution_object=dist,
|
||||
distribution_name=dist.name,
|
||||
fit_result_params=fit_params,
|
||||
)
|
||||
|
||||
def test_mean_std_var_are_floats(self):
|
||||
s = self._make_summary()
|
||||
assert isinstance(s.mean, float)
|
||||
assert isinstance(s.std, float)
|
||||
assert isinstance(s.var, float)
|
||||
|
||||
def test_var_equals_std_squared(self):
|
||||
s = self._make_summary()
|
||||
assert pytest.approx(s.var, rel=1e-6) == s.std**2
|
||||
|
||||
def test_pvalue_none_before_validate(self):
|
||||
s = self._make_summary()
|
||||
assert s.pvalue is None
|
||||
|
||||
def test_gof_statistic_none_before_validate(self):
|
||||
s = self._make_summary()
|
||||
assert s.gof_statistic is None
|
||||
|
||||
def test_repr_contains_name(self):
|
||||
s = self._make_summary()
|
||||
assert "gamma" in repr(s)
|
||||
|
||||
def test_repr_shows_na_when_no_test(self):
|
||||
s = self._make_summary()
|
||||
assert "N/A" in repr(s)
|
||||
|
||||
def test_default_statistic_method(self):
|
||||
s = self._make_summary()
|
||||
assert s.statistic_method == "ad"
|
||||
|
||||
|
||||
# ── Fitter construction ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFitterConstruction:
|
||||
def test_distributions_registered(self):
|
||||
f = Fitter([gamma, expon])
|
||||
assert "gamma" in f
|
||||
assert "expon" in f
|
||||
|
||||
def test_getitem_by_name(self):
|
||||
f = Fitter([gamma])
|
||||
s = f["gamma"]
|
||||
assert isinstance(s, DistributionSummary)
|
||||
|
||||
def test_getitem_by_object(self):
|
||||
f = Fitter([gamma])
|
||||
s = f[gamma]
|
||||
assert isinstance(s, DistributionSummary)
|
||||
|
||||
def test_getitem_missing_raises(self):
|
||||
f = Fitter([gamma])
|
||||
with pytest.raises(KeyError):
|
||||
_ = f["norm"]
|
||||
|
||||
def test_contains_true(self):
|
||||
f = Fitter([gamma])
|
||||
assert gamma in f
|
||||
assert "gamma" in f
|
||||
|
||||
def test_contains_false(self):
|
||||
f = Fitter([gamma])
|
||||
assert norm not in f
|
||||
|
||||
def test_kwargs_args_stored(self):
|
||||
f = Fitter([gamma], gamma_args=(2.0,), gamma_params={"floc": 0})
|
||||
assert f["gamma"].args_fit_params == (2.0,)
|
||||
assert f["gamma"].kwds_fit_params == {"floc": 0}
|
||||
|
||||
def test_setitem_valid(self):
|
||||
f = Fitter([gamma])
|
||||
original = f["gamma"]
|
||||
f["gamma"] = original
|
||||
assert f["gamma"] is original
|
||||
|
||||
def test_setitem_invalid_type(self):
|
||||
f = Fitter([gamma])
|
||||
with pytest.raises(TypeError):
|
||||
f["gamma"] = "not_a_summary"
|
||||
|
||||
def test_setitem_missing_key_raises(self):
|
||||
f = Fitter([gamma])
|
||||
dummy = DistributionSummary(distribution_object=norm, distribution_name="norm")
|
||||
with pytest.raises(KeyError):
|
||||
f["norm"] = dummy
|
||||
|
||||
|
||||
# ── Fitter iteration ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFitterIteration:
|
||||
def test_iterates_all_distributions(self):
|
||||
f = Fitter([gamma, expon, weibull_min])
|
||||
names = [name for name, _ in f]
|
||||
assert set(names) == {"gamma", "expon", "weibull_min"}
|
||||
|
||||
def test_iterate_twice(self):
|
||||
f = Fitter([gamma, expon])
|
||||
first = [name for name, _ in f]
|
||||
second = [name for name, _ in f]
|
||||
assert first == second
|
||||
|
||||
def test_iterator_yields_summary_instances(self):
|
||||
f = Fitter([gamma])
|
||||
for _, summary in f:
|
||||
assert isinstance(summary, DistributionSummary)
|
||||
|
||||
|
||||
# ── Fitter.fit ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFitterFit:
|
||||
def test_fit_populates_params(self):
|
||||
f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
f.fit(GAMMA_DATA)
|
||||
assert len(f["gamma"].fit_result_params) > 0
|
||||
|
||||
def test_fit_uses_abs_value(self):
|
||||
f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
neg_data = -np.abs(GAMMA_DATA)
|
||||
f.fit(neg_data)
|
||||
assert len(f["gamma"].fit_result_params) > 0
|
||||
|
||||
def test_fit_flattens_2d(self):
|
||||
f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
data_2d = GAMMA_DATA.reshape(10, 20)
|
||||
f.fit(data_2d)
|
||||
assert hasattr(f, "_last_data_flat")
|
||||
assert f._last_data_flat.ndim == 1
|
||||
|
||||
def test_fit_multiple_dists(self):
|
||||
f = Fitter([gamma, expon], gamma_params={"floc": 0}, expon_params={"floc": 0})
|
||||
f.fit(GAMMA_DATA)
|
||||
assert len(f["gamma"].fit_result_params) > 0
|
||||
assert len(f["expon"].fit_result_params) > 0
|
||||
|
||||
|
||||
# ── Fitter.validate ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFitterValidate:
|
||||
def test_validate_without_fit_raises(self):
|
||||
f = Fitter([gamma])
|
||||
with pytest.raises(RuntimeError, match="fit\\(\\)"):
|
||||
f.validate()
|
||||
|
||||
def test_validate_populates_test_result(self):
|
||||
f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
f.fit(GAMMA_DATA)
|
||||
f.validate(n_mc_samples=99)
|
||||
assert f["gamma"].test_result is not None
|
||||
|
||||
def test_validate_pvalue_in_range(self):
|
||||
f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
f.fit(GAMMA_DATA)
|
||||
f.validate(n_mc_samples=99)
|
||||
pval = f["gamma"].pvalue
|
||||
assert 0.0 <= pval <= 1.0
|
||||
|
||||
def test_validate_gof_statistic_positive(self):
|
||||
f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
f.fit(GAMMA_DATA)
|
||||
f.validate(n_mc_samples=99)
|
||||
assert f["gamma"].gof_statistic >= 0.0
|
||||
|
||||
def test_validate_custom_statistic(self):
|
||||
f = Fitter([gamma], statistic_method="ks", gamma_params={"floc": 0})
|
||||
f.fit(GAMMA_DATA)
|
||||
f.validate(n_mc_samples=99)
|
||||
assert f["gamma"].test_result is not None
|
||||
|
||||
|
||||
# ── Fitter.summary ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFitterSummary:
|
||||
def test_summary_runs_without_error(self, capsys):
|
||||
f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
f.fit(GAMMA_DATA)
|
||||
f.validate(n_mc_samples=99)
|
||||
f.summary()
|
||||
captured = capsys.readouterr()
|
||||
assert "gamma" in captured.out
|
||||
|
||||
|
||||
# ── Fitter.plot_qq_plots ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFitterPlotQQ:
|
||||
def setup_method(self):
|
||||
self.f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
self.f.fit(GAMMA_DATA)
|
||||
self.f.validate(n_mc_samples=99)
|
||||
|
||||
def test_qq_without_fit_raises(self):
|
||||
f = Fitter([gamma])
|
||||
with pytest.raises(RuntimeError, match="fit\\(\\)"):
|
||||
f.plot_qq_plots()
|
||||
|
||||
def test_qq_invalid_method_raises(self):
|
||||
with pytest.raises(ValueError, match="method"):
|
||||
self.f.plot_qq_plots(method="invalid")
|
||||
|
||||
def test_qq_hazen_returns_figure(self):
|
||||
import plotly.graph_objects as go
|
||||
fig = self.f.plot_qq_plots(method="hazen")
|
||||
assert isinstance(fig, go.Figure)
|
||||
|
||||
def test_qq_filliben_returns_figure(self):
|
||||
import plotly.graph_objects as go
|
||||
fig = self.f.plot_qq_plots(method="filliben")
|
||||
assert isinstance(fig, go.Figure)
|
||||
|
||||
|
||||
# ── Fitter.histogram_with_fits ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFitterHistogram:
|
||||
def setup_method(self):
|
||||
self.f = Fitter([gamma], gamma_params={"floc": 0})
|
||||
self.f.fit(GAMMA_DATA)
|
||||
self.f.validate(n_mc_samples=99)
|
||||
|
||||
def test_histogram_without_fit_raises(self):
|
||||
f = Fitter([gamma])
|
||||
with pytest.raises(RuntimeError, match="fit\\(\\)"):
|
||||
f.histogram_with_fits()
|
||||
|
||||
def test_histogram_returns_figure(self):
|
||||
import plotly.graph_objects as go
|
||||
fig = self.f.histogram_with_fits()
|
||||
assert isinstance(fig, go.Figure)
|
||||
|
||||
def test_histogram_seaborn_returns_figure(self):
|
||||
import matplotlib.pyplot as plt
|
||||
fig = self.f.histogram_with_fits_seaborn()
|
||||
assert isinstance(fig, plt.Figure)
|
||||
plt.close("all")
|
||||
@@ -21,13 +21,16 @@ def stacked_plot(data):
|
||||
data = np.squeeze(data)
|
||||
mean_dp = np.mean(np.abs(data), axis=1)
|
||||
|
||||
fig = make_subplots(rows=2, cols=1, row_heights=[0.3, 0.7], shared_xaxes=True,
|
||||
vertical_spacing=0.01)
|
||||
fig = make_subplots(
|
||||
rows=2, cols=1, row_heights=[0.3, 0.7], shared_xaxes=True, vertical_spacing=0.01
|
||||
)
|
||||
|
||||
fig.add_trace(go.Scatter(y=mean_dp, name='Mean Power'), row=1, col=1)
|
||||
fig.add_trace(go.Heatmap(z=np.abs(data).T, showscale=False, name='Heat Map'), row=2, col=1)
|
||||
fig.add_trace(go.Scatter(y=mean_dp, name="Mean Power"), row=1, col=1)
|
||||
fig.add_trace(
|
||||
go.Heatmap(z=np.abs(data).T, showscale=False, name="Heat Map"), row=2, col=1
|
||||
)
|
||||
|
||||
fig.update_layout(title='Mean DP Power and 2D Map', autosize=True)
|
||||
fig.update_layout(title="Mean DP Power and 2D Map", autosize=True)
|
||||
fig.update_xaxes(visible=False, row=2, col=1)
|
||||
fig.update_yaxes(visible=False, row=2, col=1)
|
||||
|
||||
@@ -98,6 +101,8 @@ def plot_cdfs(data_list, labels):
|
||||
fig = go.Figure()
|
||||
for data, label in zip(data_list, labels):
|
||||
sorted_data, cdf = calculate_cdf(data)
|
||||
fig.add_trace(go.Scatter(x=sorted_data, y=cdf, mode='lines', name=label))
|
||||
fig.update_layout(title='CDF of Data', xaxis_title='Value', yaxis_title='CDF', autosize=True)
|
||||
fig.add_trace(go.Scatter(x=sorted_data, y=cdf, mode="lines", name=label))
|
||||
fig.update_layout(
|
||||
title="CDF of Data", xaxis_title="Value", yaxis_title="CDF", autosize=True
|
||||
)
|
||||
return fig
|
||||
|
||||
Reference in New Issue
Block a user