Estimation Class
- class regmmd.estimation.MMDEstimator(model, par_v=None, par_c=None, kernel='Gaussian', bandwidth='auto', solver=None, random_state=None)
Bases: BaseEstimator
Estimator using the Maximum Mean Discrepancy criterion.
Maximum Mean Discrepancy (MMD) is a kernel-based measure of discrepancy between two probability distributions, and it also serves as the statistic of a kernel two-sample test. This estimator fits a parametric model by minimizing the MMD between the empirical distribution of the observed data and the model distribution.
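To make the criterion concrete, here is a minimal sketch of the biased (V-statistic) estimate of the squared MMD between two one-dimensional samples. It assumes a Gaussian kernel of the form exp(-(x - y)^2 / (2h^2)) with a fixed bandwidth h; the helper names are hypothetical and not part of regmmd, whose kernel scaling may differ.

```python
import math


def gaussian_kernel(x, y, h=1.0):
    """Gaussian kernel exp(-(x - y)^2 / (2 h^2)) for scalar inputs (assumed scaling)."""
    return math.exp(-((x - y) ** 2) / (2.0 * h ** 2))


def mmd2_biased(xs, ys, h=1.0):
    """Biased (V-statistic) estimate of the squared MMD between two 1-D samples."""
    def mean_k(a, b):
        # Average kernel value over all pairs (u, v) with u in a and v in b.
        return sum(gaussian_kernel(u, v, h) for u in a for v in b) / (len(a) * len(b))

    return mean_k(xs, xs) + mean_k(ys, ys) - 2.0 * mean_k(xs, ys)


# Identical samples have zero discrepancy; a shifted sample does not.
same = mmd2_biased([0.0, 1.0, 2.0], [0.0, 1.0, 2.0])
shifted = mmd2_biased([0.0, 1.0, 2.0], [3.0, 4.0, 5.0])
```

The estimator described here replaces the second sample with the parametric model and searches for the parameter that makes this quantity smallest.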
Depending on the type of model provided, the estimator uses either an exact gradient descent procedure (for GaussianLoc) or a stochastic gradient descent approach (for general models).
- Parameters:
model (EstimationModel) – The parametric estimation model to be fitted, provided as an instance of EstimationModel. This model defines the distributional form assumed for the data and exposes an _init_params method used to initialise parameters before optimisation.
par_v (float, optional) – Initial value for the variable parameters of the model. If None, it will be initialised automatically by the model's _init_params method when fit() is called.
par_c (float, optional) – Initial value for the constant parameters of the model. If None, it will be initialised automatically by the model's _init_params method when fit() is called.
kernel (str, default="Gaussian") – The kernel function used to compute the MMD. Currently supports "Gaussian", "Laplace" or "Cauchy".
bandwidth (str or float, default="auto") – The bandwidth of the kernel. If set to "auto", the bandwidth is selected automatically using a heuristic method such as the median heuristic.
solver (dict, optional) – A dictionary specifying the solver parameters for the optimisation procedure. Expected keys are:
"burnin" (int): Number of burn-in steps before recording results.
"n_step" (int): Total number of optimisation steps.
"stepsize" (float): Learning rate for gradient updates.
"epsilon" (float): Convergence tolerance or regularisation term.
If None, solver settings must be provided before calling fit().
random_state (int, optional) – Random seed to be passed to the model and to any sampler used in the SGD optimisers.
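The kernel and bandwidth options can be illustrated with a small sketch. It assumes one common parametrisation for each kernel name and takes the median of pairwise distances as the "auto" bandwidth; both are assumptions, and regmmd's exact scaling and heuristic may differ. KERNELS, median_heuristic and kernel_value are hypothetical helpers.

```python
import math
import statistics
from itertools import combinations

# Assumed parametrisations, written as functions of the difference d = x - y
# and the bandwidth h. Each equals 1.0 at d = 0 and decays with |d| / h.
KERNELS = {
    "Gaussian": lambda d, h: math.exp(-(d ** 2) / (2.0 * h ** 2)),
    "Laplace": lambda d, h: math.exp(-abs(d) / h),
    "Cauchy": lambda d, h: 1.0 / (1.0 + (d / h) ** 2),
}


def median_heuristic(xs):
    """Median of the pairwise absolute distances in a 1-D sample (one common variant)."""
    dists = [abs(a - b) for a, b in combinations(xs, 2)]
    return statistics.median(dists)


def kernel_value(x, y, kernel="Gaussian", bandwidth="auto", sample=None):
    """Evaluate k(x, y), resolving bandwidth="auto" from a reference sample."""
    if bandwidth == "auto":
        if sample is None:
            raise ValueError('bandwidth="auto" needs a reference sample')
        bandwidth = median_heuristic(sample)
    return KERNELS[kernel](x - y, bandwidth)


data = [0.0, 1.0, 3.0, 6.0]
h = median_heuristic(data)  # pairwise distances 1, 2, 3, 3, 5, 6
```

For this sample the heuristic gives h = 3.0, and every kernel evaluates to 1.0 at zero distance.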
- par_c
The constant parameter of the model. It is kept fixed during fitting and is not replaced by an optimised value.
- Type:
float
Notes
For GaussianLoc models, an exact gradient descent routine (_gd_gaussian_loc_exact_estimation) can be used during fitting when the kernel is also "Gaussian".
Otherwise, a stochastic gradient descent routine (_sgd_estimation) is applied instead.
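Why an exact routine is possible for GaussianLoc: with the kernel k(x, y) = exp(-(x - y)^2 / (2h^2)) and model N(θ, σ^2), the term E[k(Y, Y')] does not depend on θ, and E[k(x, Y)] = h / sqrt(h^2 + σ^2) · exp(-(x - θ)^2 / (2(h^2 + σ^2))). The gradient of the squared MMD is therefore available in closed form. The sketch below runs plain gradient descent on that expression for one-dimensional data with σ known. It is an illustration under these assumptions (kernel scaling, known σ, fixed step size), not the body of _gd_gaussian_loc_exact_estimation; gaussian_loc_exact_gd is a hypothetical name.

```python
import math


def gaussian_loc_exact_gd(xs, sigma=1.0, h=1.0, theta0=0.0, stepsize=1.0,
                          n_step=1000, tol=1e-10):
    """Gradient descent on the closed-form squared MMD between N(theta, sigma^2)
    and the empirical distribution of xs, with kernel exp(-(x - y)^2 / (2 h^2))."""
    s2 = h ** 2 + sigma ** 2      # variance of the kernel-smoothed model
    c = h / math.sqrt(s2)         # constant factor in E[k(x, Y)]
    n = len(xs)
    theta = theta0
    for _ in range(n_step):
        # d/dtheta MMD^2 = -(2c / (n s2)) * sum_i w_i (x_i - theta),
        # where w_i = exp(-(x_i - theta)^2 / (2 s2)).
        grad = -(2.0 * c / (n * s2)) * sum(
            math.exp(-((x - theta) ** 2) / (2.0 * s2)) * (x - theta) for x in xs
        )
        theta_new = theta - stepsize * grad
        if abs(theta_new - theta) < tol:  # stop once updates become negligible
            return theta_new
        theta = theta_new
    return theta


clean = [2.0, 2.5, 3.0, 3.5, 4.0]
contaminated = clean + [50.0]
est_clean = gaussian_loc_exact_gd(clean)
est_contaminated = gaussian_loc_exact_gd(contaminated)
```

The weights w_i shrink quickly for points far from θ, so the outlier at 50.0 contributes almost nothing: both estimates settle at 3.0, while the sample mean of the contaminated data is about 10.8.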
- fit(X, use_exact=True, use_fast=True)
Fit the MMD estimation model according to the given training data.
- Parameters:
X (np.ndarray, shape (n_samples, n_features)) – Training input samples.
use_exact (bool, default=True) – If True, use the model._exact_fit() method when it is available; otherwise fall back to SGD. This option is mainly used for performance comparisons.
use_fast (bool, default=True) – If True, try to build the CyModel version of the model through model._build_cy_model(). If this succeeds, a Cython version of the SGD loop is called, which often results in a 5-10x speed-up.
- Returns:
res – A dictionary containing the results of the optimization process, including the estimated parameters and the optimization trajectory.
- Return type:
MMDResult
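When no exact routine is used (use_exact=False, or the model provides none), fitting falls back to SGD. The sketch below shows one way the solver keys could drive such a loop for a one-dimensional Gaussian location model. It assumes reparameterised model draws Y = θ + σε, an AdaGrad-style step in which "epsilon" is the stabilising constant, and averaging of the iterates recorded after "burnin"; these readings of the keys are assumptions, and gaussian_loc_sgd is a hypothetical helper rather than regmmd's _sgd_estimation.

```python
import math
import random


def gaussian_loc_sgd(xs, solver, sigma=1.0, h=1.0, theta0=0.0, n_model=20, seed=0):
    """SGD on the squared MMD between N(theta, sigma^2) and the empirical
    distribution of xs, with AdaGrad steps and averaging after burn-in."""
    rng = random.Random(seed)
    n = len(xs)
    theta, grad_sq_sum = theta0, 0.0
    kept = []
    for t in range(solver["n_step"]):
        # Reparameterised model draws: Y_j = theta + sigma * eps_j.
        ys = [theta + sigma * rng.gauss(0.0, 1.0) for _ in range(n_model)]
        # For a location model, E[k(Y, Y')] does not depend on theta, so only the
        # cross term -2/n * sum_i E[k(x_i, Y)] contributes to the gradient.
        grad = -2.0 / (n * n_model) * sum(
            math.exp(-((x - y) ** 2) / (2.0 * h ** 2)) * (x - y) / h ** 2
            for x in xs for y in ys
        )
        # AdaGrad: scale the step by the root of the accumulated squared gradients.
        grad_sq_sum += grad ** 2
        theta -= solver["stepsize"] * grad / math.sqrt(solver["epsilon"] + grad_sq_sum)
        if t >= solver["burnin"]:
            kept.append(theta)  # record iterates only after burn-in
    return sum(kept) / len(kept)


data = [2.0, 2.5, 3.0, 3.5, 4.0]
solver = {"burnin": 500, "n_step": 3000, "stepsize": 1.0, "epsilon": 1e-8}
est = gaussian_loc_sgd(data, solver, seed=0)
```

Because each step estimates the gradient from only n_model draws, the result varies slightly with the seed; averaging the post-burn-in iterates damps that noise.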
- set_fit_request(*, use_exact='$UNCHANGED$', use_fast='$UNCHANGED$')
Configure whether metadata should be requested to be passed to the fit method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
use_exact (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for use_exact parameter in fit.
use_fast (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for use_fast parameter in fit.
self (MMDEstimator)
- Returns:
self – The updated object.
- Return type:
object