Optimization
The optimization procedure in RegMMD follows a two-tier strategy: it uses exact analytical gradient methods when available, and falls back to a general-purpose stochastic gradient descent (SGD) method otherwise.
Exact methods vs. SGD
Exact methods
For certain combinations of statistical model and kernel, the MMD objective and its gradient can be computed in closed form, without resorting to Monte Carlo sampling. This has two key advantages:
No sampling variance: the gradient is deterministic, leading to more stable optimization.
Efficiency: direct computation avoids the cost of drawing and evaluating random samples at each step.
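To make the closed-form idea concrete, here is a minimal sketch for one illustrative case that is not taken from the RegMMD source: a one-dimensional Gaussian location model \(N(\theta, \sigma^2)\) with the Gaussian kernel \(k(a, b) = \exp(-(a-b)^2 / (2\gamma^2))\). Under that assumed kernel parametrisation, \(E[k(Y, x)] = \frac{\gamma}{\sqrt{\gamma^2 + \sigma^2}} \exp(-(\theta - x)^2 / (2(\gamma^2 + \sigma^2)))\), so the \(\theta\)-dependent part of the squared MMD and its gradient can be evaluated without sampling. All function names below are hypothetical.

```python
import math


def expected_kernel(theta, x, sigma, gamma):
    """E[k(Y, x)] for Y ~ N(theta, sigma^2), k(a, b) = exp(-(a-b)^2 / (2 gamma^2))."""
    s2 = gamma ** 2 + sigma ** 2
    return gamma / math.sqrt(s2) * math.exp(-((theta - x) ** 2) / (2.0 * s2))


def mmd_objective(theta, data, sigma, gamma):
    """Theta-dependent part of the squared MMD between N(theta, sigma^2) and the sample.

    The model-model term does not depend on theta for a location model and the
    data-data term is constant, so both are dropped (assumption of this sketch).
    """
    return -2.0 * sum(expected_kernel(theta, x, sigma, gamma) for x in data) / len(data)


def mmd_gradient(theta, data, sigma, gamma):
    """Closed-form derivative of mmd_objective with respect to theta."""
    s2 = gamma ** 2 + sigma ** 2
    total = sum((theta - x) * expected_kernel(theta, x, sigma, gamma) for x in data)
    return 2.0 * total / (len(data) * s2)


data = [1.8, 2.1, 2.0, 2.3, 1.9]
grad = mmd_gradient(1.0, data, sigma=1.0, gamma=1.0)

# Sanity check against a central finite difference of the objective.
h = 1e-6
finite_diff = (mmd_objective(1.0 + h, data, 1.0, 1.0)
               - mmd_objective(1.0 - h, data, 1.0, 1.0)) / (2.0 * h)
```

Calling `mmd_gradient` twice with the same arguments returns the same value, which is the "no sampling variance" property in miniature.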
Each model class can optionally implement an _exact_fit() method. When
fitting starts, the optimizer calls this method first. If it returns a result,
that result is used directly. If it returns None (meaning the current
model/kernel combination has no exact implementation), the optimizer falls back
to SGD.
The decision logic in the estimator and regressor looks like this:
# 1. Try the exact method
res = model._exact_fit(X=X, ...)
# 2. Fall back to SGD if no exact method is available
if res is None:
    res = _sgd_estimation(X=X, ...)
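The same dispatch pattern can be tried out in isolation. The sketch below uses stand-in classes and a placeholder `_sgd_estimation`; none of them are the library's real implementations, and the `"method"` key in the returned dictionaries is a hypothetical addition used only to show which branch ran.

```python
class NoExactModel:
    """Stand-in model with no exact method: always defers to SGD."""

    def _exact_fit(self, X, **kwargs):
        return None


class MeanModel:
    """Stand-in model whose 'exact' fit is simply the sample mean."""

    def _exact_fit(self, X, **kwargs):
        return {"estimator": sum(X) / len(X), "method": "exact"}


def _sgd_estimation(X, **kwargs):
    # Placeholder for the general solver; it performs no optimization here.
    return {"estimator": None, "method": "sgd"}


def fit(model, X, **kwargs):
    # 1. Try the exact method
    res = model._exact_fit(X=X, **kwargs)
    # 2. Fall back to SGD if no exact method is available
    if res is None:
        res = _sgd_estimation(X=X, **kwargs)
    return res


exact_result = fit(MeanModel(), [1.0, 2.0, 3.0])
fallback_result = fit(NoExactModel(), [1.0, 2.0, 3.0])
```

Using None as the "not supported" signal keeps the caller free of any model-specific checks: the model alone decides whether its exact path applies.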
SGD fallback
The general SGD solver works with any model and kernel combination. It
approximates the MMD gradient by sampling from the model at each iteration and
uses the model's score() method to compute the gradient estimate.
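As a rough illustration of this kind of update, the sketch below runs SGD on a one-dimensional Gaussian location model. It assumes that the score is the derivative of the log-density with respect to the parameter, and it uses the standard score-function identity \(\nabla_\theta \mathrm{MMD}^2 = 2\,E[(k(Y, Y') - \frac{1}{n}\sum_i k(Y, x_i))\, \nabla_\theta \log p_\theta(Y)]\). The function names, step size, sample sizes, and the averaging of late iterates are choices made for this example, not the library's settings.

```python
import math
import random


def gaussian_kernel(a, b, gamma):
    return math.exp(-((a - b) ** 2) / (2.0 * gamma ** 2))


def score(y, theta, sigma):
    """d/dtheta of log N(y; theta, sigma^2)."""
    return (y - theta) / sigma ** 2


def mmd_grad_estimate(theta, data, sigma, gamma, m, rng):
    """Unbiased Monte Carlo estimate of the MMD^2 gradient from m model samples."""
    ys = [rng.gauss(theta, sigma) for _ in range(m)]
    total = 0.0
    for j, yj in enumerate(ys):
        # Average kernel between this sample and the other model samples.
        k_model = sum(gaussian_kernel(yj, yl, gamma)
                      for l, yl in enumerate(ys) if l != j) / (m - 1)
        # Average kernel between this sample and the observed data.
        k_data = sum(gaussian_kernel(yj, x, gamma) for x in data) / len(data)
        total += (k_model - k_data) * score(yj, theta, sigma)
    return 2.0 * total / m


def sgd_fit(data, theta0=0.0, sigma=1.0, gamma=1.0, m=10, n_iter=400, lr=0.5, seed=0):
    rng = random.Random(seed)
    theta = theta0
    tail = []
    for t in range(n_iter):
        theta -= lr * mmd_grad_estimate(theta, data, sigma, gamma, m, rng)
        if t >= n_iter // 2:
            tail.append(theta)  # keep late iterates to average out sampling noise
    return sum(tail) / len(tail)


data_rng = random.Random(1)
data = [data_rng.gauss(2.0, 1.0) for _ in range(100)]
theta_hat = sgd_fit(data)  # expected to land near the true location 2.0
```

Because each gradient is a noisy estimate, individual iterates keep fluctuating around the optimum; averaging the second half of the trajectory is one simple way to damp that noise.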
For the regression setting, two SGD variants are available, as described in section 3.2 of Universal robust regression via maximum mean discrepancy:
Tilde estimator (_sgd_tilde_regression): uses only a kernel on \(Y\). This is selected when no covariate kernel is specified (bandwidth_X = 0).
Hat estimator (_sgd_hat_regression): uses a product kernel on \((X, Y)\). This is selected when a covariate kernel is specified (bandwidth_X > 0).
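The selection rule and the product kernel can be sketched as follows. This is only an illustration: the helper names, the `bandwidth_Y` argument, the Gaussian form of both kernels, and the rejection of negative bandwidths are assumptions, and the solver names are returned as strings rather than called.

```python
import math


def gaussian_kernel(a, b, bandwidth):
    return math.exp(-((a - b) ** 2) / (2.0 * bandwidth ** 2))


def select_sgd_variant(bandwidth_X):
    """Return the name of the SGD variant implied by the covariate bandwidth."""
    if bandwidth_X < 0:
        raise ValueError("bandwidth_X must be non-negative")  # assumed validation
    if bandwidth_X > 0:
        return "_sgd_hat_regression"    # covariate kernel specified
    return "_sgd_tilde_regression"      # no covariate kernel


def product_kernel(x1, y1, x2, y2, bandwidth_X, bandwidth_Y):
    """Product kernel on (X, Y) pairs, as used conceptually by the hat estimator."""
    return gaussian_kernel(x1, x2, bandwidth_X) * gaussian_kernel(y1, y2, bandwidth_Y)


tilde_choice = select_sgd_variant(0)
hat_choice = select_sgd_variant(0.5)
same_point = product_kernel(0.0, 0.0, 0.0, 0.0, bandwidth_X=1.0, bandwidth_Y=1.0)
```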
Available exact methods
The table below summarises which model/kernel combinations currently have exact methods implemented.
Estimation
| Model | Kernel | Method |
|---|---|---|
| | Gaussian | Exact gradient descent |
All other estimation models (GaussianScale, Gaussian, Beta,
Poisson, Gamma, etc.) use the general SGD solver.
Regression
| Model | Kernel | Estimator | Method |
|---|---|---|---|
| | Gaussian | Tilde | Exact GD with backtracking line search |
| | Gaussian | Tilde | Exact GD with backtracking line search |
| | Any | Tilde | Exact GD with backtracking line search |
| | Any | Hat | Exact gradients of the expectations are used, but the diagonal and off-diagonal elements are still sub-sampled for efficiency considerations. |
All other regression models (GammaRegressionLoc,
PoissonRegressionLoc, etc.) use the general SGD solver.
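For readers unfamiliar with the "GD with backtracking line search" entries above, the following is a generic one-dimensional sketch using the Armijo sufficient-decrease condition. It is not the library's solver: the function name, default constants, and the toy kernel-shaped objective are all assumptions made for illustration.

```python
import math


def gd_backtracking(f, grad, x0, step0=1.0, shrink=0.5, c=1e-4, tol=1e-6, max_iter=200):
    """Gradient descent where each step is shrunk until the Armijo condition holds."""
    x = x0
    trajectory = [x]
    for _ in range(max_iter):
        g = grad(x)
        if abs(g) < tol:
            break
        step = step0
        # Shrink the step until f decreases by at least c * step * g^2.
        while f(x - step * g) > f(x) - c * step * g * g and step > 1e-12:
            step *= shrink
        x -= step * g
        trajectory.append(x)
    return x, trajectory


def objective(x):
    # Toy objective with a single minimum at x = 3.
    return -math.exp(-((x - 3.0) ** 2) / 0.6)


def objective_grad(x):
    return (x - 3.0) / 0.3 * math.exp(-((x - 3.0) ** 2) / 0.6)


x_opt, path = gd_backtracking(objective, objective_grad, x0=2.0)
```

With an exact objective and gradient, the line search can compare true function values before and after a trial step, which is not possible when only noisy estimates are available.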
Implementing a custom exact method
To add an exact method for a new model, override the _exact_fit() method in
your model class. The base class implementation returns None, which
triggers the SGD fallback. Your override should:
Check whether the kernel and other settings are supported by your exact implementation.
If supported, run the optimization and return the result dictionary.
If not supported, return None to fall back to SGD.
class MyModel(BaseModel):
    def _exact_fit(self, X, par_v, par_c, solver, kernel, bandwidth):
        if kernel != "Gaussian":
            return None  # fall back to SGD
        # ... compute exact gradients and optimize ...
        return {"estimator": par_v_opt, "trajectory": trajectory}
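A runnable toy version of this pattern might look like the sketch below. Several things are assumptions rather than documented behaviour: `BaseModel` is a local stand-in, `par_v` is read as the parameter being optimized and `par_c` as a fixed scale, `solver` is treated as a dictionary with hypothetical `"lr"` and `"n_iter"` keys, and the model is a one-dimensional Gaussian location model with a Gaussian kernel of the form \(\exp(-(a-b)^2 / (2\,\mathrm{bandwidth}^2))\).

```python
import math


class BaseModel:
    """Local stand-in for the library's base class."""

    def _exact_fit(self, X, par_v, par_c, solver, kernel, bandwidth):
        return None  # no exact method: the caller falls back to SGD


class ToyGaussianLoc(BaseModel):
    """Toy N(par_v, par_c^2) model with an exact fit for the Gaussian kernel."""

    def _exact_fit(self, X, par_v, par_c, solver, kernel, bandwidth):
        if kernel != "Gaussian":
            return None  # fall back to SGD
        s2 = bandwidth ** 2 + par_c ** 2
        scale = bandwidth / math.sqrt(s2)
        theta = par_v
        trajectory = [theta]
        for _ in range(solver["n_iter"]):
            # Closed-form gradient of the theta-dependent part of MMD^2.
            grad = 2.0 * scale / (len(X) * s2) * sum(
                (theta - x) * math.exp(-((theta - x) ** 2) / (2.0 * s2)) for x in X
            )
            theta -= solver["lr"] * grad
            trajectory.append(theta)
        return {"estimator": theta, "trajectory": trajectory}


X = [1.8, 2.1, 2.0, 2.3, 1.9]
solver = {"lr": 1.0, "n_iter": 100}
result = ToyGaussianLoc()._exact_fit(X, 0.0, 1.0, solver, "Gaussian", 1.0)
unsupported = ToyGaussianLoc()._exact_fit(X, 0.0, 1.0, solver, "Laplace", 1.0)
```

Here `unsupported` is None because the kernel check fails first, so a caller following the dispatch logic would switch to SGD for that configuration.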