Estimation Class

class regmmd.estimation.MMDEstimator(model, par_v=None, par_c=None, kernel='Gaussian', bandwidth='auto', solver=None, random_state=None)

Bases: BaseEstimator

Estimator using the Maximum Mean Discrepancy criterion.

Maximum Mean Discrepancy (MMD) is a kernel-based distance between probability distributions; it is also the statistic behind the kernel two-sample test. This estimator fits a parametric model by minimizing the MMD between the empirical distribution of the observed data and the model distribution.
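
For intuition, the snippet below computes the standard unbiased estimate of the squared MMD between two samples with a Gaussian kernel. It is a NumPy sketch, not regmmd's internal code: the kernel scaling exp(-||x - y||^2 / (2 h^2)) and the helper names gaussian_gram and mmd2_unbiased are our own assumptions.

```python
import numpy as np

def gaussian_gram(X, Y, bandwidth):
    """Assumed Gaussian kernel matrix k(x, y) = exp(-||x - y||^2 / (2 * bandwidth^2))."""
    sq_dists = (np.sum(X**2, axis=1)[:, None]
                + np.sum(Y**2, axis=1)[None, :]
                - 2.0 * X @ Y.T)
    # Clip tiny negative values caused by floating-point cancellation
    return np.exp(-np.maximum(sq_dists, 0.0) / (2.0 * bandwidth**2))

def mmd2_unbiased(X, Y, bandwidth=1.0):
    """Unbiased (U-statistic) estimate of MMD^2 between samples X (n, d) and Y (m, d)."""
    n, m = X.shape[0], Y.shape[0]
    Kxx = gaussian_gram(X, X, bandwidth)
    Kyy = gaussian_gram(Y, Y, bandwidth)
    Kxy = gaussian_gram(X, Y, bandwidth)
    # Exclude the diagonal terms k(x_i, x_i) from the within-sample averages
    term_xx = (Kxx.sum() - np.trace(Kxx)) / (n * (n - 1))
    term_yy = (Kyy.sum() - np.trace(Kyy)) / (m * (m - 1))
    term_xy = Kxy.mean()
    return term_xx + term_yy - 2.0 * term_xy

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 1))
same = mmd2_unbiased(X, rng.normal(size=(500, 1)))               # close to 0
shifted = mmd2_unbiased(X, rng.normal(loc=2.0, size=(500, 1)))   # clearly positive
```

Two samples from the same distribution give an estimate close to zero, while a location shift of 2 gives a clearly positive value; the estimator fits the model by driving this kind of quantity down.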

Depending on the model and kernel, the estimator uses either an exact gradient descent procedure (for GaussianLoc models with a "Gaussian" kernel) or a stochastic gradient descent approach for general models.
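
As a rough illustration of the stochastic route, the sketch below estimates the location theta of a model Y = theta + noise by SGD on the squared MMD, drawing samples from the model and differentiating through them (the reparameterisation trick). This is a swapped-in textbook technique, not necessarily what _sgd_estimation does; the Gaussian-kernel scaling, the averaging of iterates after burn-in, and the function sgd_location_mmd itself (whose argument names merely echo the solver keys n_step, burnin and stepsize) are assumptions.

```python
import numpy as np

def sgd_location_mmd(X, sample_noise, bandwidth=1.0, n_step=2000, burnin=500,
                     stepsize=1.0, n_draws=50, seed=0):
    """Hypothetical SGD estimate of theta in the location model Y = theta + noise (1-D data).

    sample_noise(rng, size) must return `size` draws from the noise distribution.
    Assumed kernel: k(x, y) = exp(-(x - y)**2 / (2 * bandwidth**2)).
    """
    rng = np.random.default_rng(seed)
    theta = float(np.median(X))                      # robust starting point
    running_sum, n_kept = 0.0, 0
    for step in range(n_step):
        y = theta + sample_noise(rng, n_draws)       # reparameterised model draws
        diff = X[:, None] - y[None, :]               # shape (n, n_draws)
        k = np.exp(-diff**2 / (2.0 * bandwidth**2))
        # The model-vs-model term of MMD^2 does not depend on theta for a
        # translation-invariant kernel, so only the data-vs-model term is differentiated:
        # d/dtheta [-2 E k(x_i, theta + eps)] = -2 E[k * (x_i - y)] / bandwidth**2
        grad = -2.0 * np.mean(k * diff) / bandwidth**2
        theta -= stepsize * grad
        if step >= burnin:                           # average iterates after burn-in
            running_sum += theta
            n_kept += 1
    return running_sum / n_kept

# 95 evenly spread inliers plus 5 gross outliers at 10
X = np.concatenate([np.linspace(-2.0, 2.0, 95), np.full(5, 10.0)])
theta_hat = sgd_location_mmd(X, lambda rng, size: rng.standard_normal(size))
print(theta_hat, X.mean())   # theta_hat lands near 0; the sample mean is about 0.5
```

Because the kernel weights decay with distance, the outliers barely affect the gradient, which is the usual motivation for MMD estimation as a robust alternative to the sample mean.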

Parameters:
  • model (EstimationModel) – The parametric estimation model to be fitted, provided as an instance of EstimationModel. This model defines the distributional form assumed for the data and exposes an _init_params method used to initialise parameters before optimisation.

  • par_v (float, optional) – Initial value for the variable parameters of the model, i.e. those optimised during fitting. If None, it will be initialised automatically by the model’s _init_params method when fit() is called.

  • par_c (float, optional) – Value of the constant parameters of the model, which are held fixed during fitting. If None, it will be initialised automatically by the model’s _init_params method when fit() is called.

  • kernel (str, default="Gaussian") – The kernel function used to compute the MMD. Currently supports "Gaussian", "Laplace" or "Cauchy".

  • bandwidth (str or float, default="auto") – The bandwidth of the kernel. If set to "auto", the bandwidth is selected automatically using a heuristic method such as the median heuristic.

  • solver (dict, optional) –

    A dictionary specifying the solver parameters for the optimisation procedure. Expected keys are:

    • "burnin" (int): Number of burn-in steps before recording results.

    • "n_step" (int): Total number of optimisation steps.

    • "stepsize" (float): Learning rate for gradient updates.

    • "epsilon" (float): Convergence tolerance or regularisation term.

    If None, solver settings must be provided before calling fit().

  • random_state (int, optional) – Random seed passed to the model and to any sampler used in the SGD optimizers.
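
The three kernels and the automatic bandwidth can be sketched as below. The scalings (Gaussian exp(-d^2 / (2h^2)), Laplace exp(-d / h), Cauchy 1 / (1 + d^2 / h^2)) and the particular median-heuristic variant (median Euclidean distance over distinct pairs) are common choices that we assume here; regmmd's exact definitions may differ, and the helper names are hypothetical.

```python
import numpy as np

def pairwise_distances(X):
    """Euclidean distance matrix for the rows of X (shape (n, d))."""
    diff = X[:, None, :] - X[None, :, :]
    return np.sqrt(np.sum(diff**2, axis=-1))

def kernel_matrix(X, kind="Gaussian", bandwidth=1.0):
    """Kernel matrix under assumed scalings; d is the distance divided by the bandwidth."""
    d = pairwise_distances(X) / bandwidth
    if kind == "Gaussian":
        return np.exp(-0.5 * d**2)
    if kind == "Laplace":
        return np.exp(-d)
    if kind == "Cauchy":
        return 1.0 / (1.0 + d**2)
    raise ValueError(f"unknown kernel: {kind!r}")

def median_heuristic(X):
    """Bandwidth = median Euclidean distance over distinct pairs of rows of X."""
    dist = pairwise_distances(X)
    iu = np.triu_indices(X.shape[0], k=1)   # upper triangle skips the zero diagonal
    return float(np.median(dist[iu]))

rng = np.random.default_rng(0)
Z = rng.normal(size=(200, 2))
h = median_heuristic(Z)
K = kernel_matrix(Z, "Cauchy", bandwidth=h)   # (200, 200), ones on the diagonal
```

Scaling the bandwidth to the typical inter-point distance keeps the kernel values away from both 0 and 1, so the MMD stays sensitive to differences between distributions.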

par_v

The variable parameter, updated with the optimised value after fitting.

Type:

float

par_c

The constant parameter, held fixed during fitting and not updated.

Type:

float

Notes

  • For GaussianLoc models with a "Gaussian" kernel, an exact gradient descent routine (_gd_gaussian_loc_exact_estimation) can be used during fitting.

  • For all other models, a stochastic gradient descent routine (_sgd_estimation) is applied instead.
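
For the Gaussian location model P_theta = N(theta, sigma^2 I_d) with a Gaussian kernel of bandwidth h, the MMD objective is available in closed form, which is what makes exact gradient descent possible. Writing s^2 = h^2 + sigma^2, the model-vs-model term does not depend on theta and E_{Y ~ P_theta} k(x_i, Y) = (h^2 / s^2)^(d/2) exp(-||theta - x_i||^2 / (2 s^2)), so minimizing the MMD amounts to maximizing sum_i exp(-||theta - x_i||^2 / (2 s^2)). The sketch below runs gradient descent on that expression; it is our own derivation under the assumed kernel k(x, y) = exp(-||x - y||^2 / (2 h^2)) with sigma known, and the function name is hypothetical, not the code of _gd_gaussian_loc_exact_estimation.

```python
import numpy as np

def exact_gd_gaussian_loc(X, sigma=1.0, bandwidth=1.0, n_step=500, stepsize=1.0):
    """Gradient descent on the closed-form MMD^2 for the model N(theta, sigma^2 I).

    X has shape (n, d). Assumed kernel: k(x, y) = exp(-||x - y||^2 / (2 * bandwidth^2)).
    """
    n, d = X.shape
    s2 = bandwidth**2 + sigma**2                     # smoothed squared scale h^2 + sigma^2
    c = (bandwidth**2 / s2) ** (d / 2)               # constant from the Gaussian integral
    theta = np.median(X, axis=0)                     # robust starting point
    for _ in range(n_step):
        diff = theta - X                             # shape (n, d)
        w = np.exp(-np.sum(diff**2, axis=1) / (2.0 * s2))   # weights, tiny for outliers
        # Gradient of -(2c/n) * sum_i w_i(theta); the model-vs-model term is constant
        grad = (2.0 * c / (n * s2)) * (w[:, None] * diff).sum(axis=0)
        theta = theta - stepsize * grad
    return theta

X = np.concatenate([np.linspace(-2.0, 2.0, 95), np.full(5, 10.0)])[:, None]
theta = exact_gd_gaussian_loc(X)
print(theta, X.mean(axis=0))   # theta is essentially 0; the sample mean is about 0.5
```

The outliers at 10 receive weights of order exp(-25), so the estimate stays at the centre of the inliers while the sample mean is dragged to 0.5. No sampling is needed, which is why this route can be used instead of SGD when the model and kernel allow it.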

fit(X, use_exact=True, use_fast=True)

Fit the MMD estimation model according to the given training data.

Parameters:
  • X (np.ndarray, shape (n_samples, n_features)) – Training input samples.

  • use_exact (bool, default=True) – If True, use the model._exact_fit() method when it is available, falling back to SGD otherwise. Mainly useful for performance comparisons.

  • use_fast (bool, default=True) – If True, try to build the CyModel version of the model through model._build_cy_model(). If this succeeds, a Cython version of the SGD loop is called, which often results in a 5-10x speed-up.

Returns:

res – An MMDResult holding the results of the optimization process, including the estimated parameters and the optimization trajectory.

Return type:

MMDResult

set_fit_request(*, use_exact='$UNCHANGED$', use_fast='$UNCHANGED$')

Configure whether metadata should be requested to be passed to the fit method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in scikit-learn version 1.3.

Parameters:
  • use_exact (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for use_exact parameter in fit.

  • use_fast (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for use_fast parameter in fit.

Returns:

self – The updated object.

Return type:

object