MMD DRO
- class dro.src.linear_model.mmd_dro.MMD_DRO(input_dim, model_type, fit_intercept=True, solver='MOSEK', sampling_method='bound')
Bases: BaseLinearDRO
MMD-DRO (Maximum Mean Discrepancy Distributionally Robust Optimization) implementation with flexible sampling methods and model types.
Reference: <https://arxiv.org/abs/2006.06981>
Initialize MMD-DRO with kernel-based ambiguity set.
- Parameters:
input_dim (int) – Dimension of input features. Must match training data.
model_type (str) – Base model type. Supported:
'svm': Support Vector Machine (hinge loss)
'logistic': Logistic Regression (log loss)
'ols': Ordinary Least Squares (L2 loss)
'lad': Least Absolute Deviation (L1 loss)
sampling_method (str) – Supported: 'bound', 'hull'
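fit_intercept (bool) – Whether to fit an intercept term. Defaults to True.
solver (str) – Optimization solver to use. Defaults to 'MOSEK'.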
- Raises:
If model_type not in supported list
If input_dim ≤ 0
If sampling_method is invalid
If n_certify_ratio < 1
- Example:
>>> model = MMD_DRO(input_dim=128, model_type='svm')
>>> model.sampling_method = 'hull'
>>> model.eta = 0.5
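The loss named for each `model_type` above can be written out per sample. The following is a minimal sketch using the standard textbook definitions; `sample_loss` is a hypothetical helper, not part of the library, and the library's own objective may differ in scaling or in how the intercept enters.

```python
import math


def sample_loss(model_type: str, score: float, y: float) -> float:
    """Standard per-sample loss for each documented model_type (illustrative only).

    score is the linear prediction for one sample; y is its target.
    """
    if model_type == "svm":       # hinge loss, y in {-1, +1}
        return max(0.0, 1.0 - y * score)
    if model_type == "logistic":  # log loss, y in {-1, +1}
        return math.log1p(math.exp(-y * score))
    if model_type == "ols":       # squared (L2) loss
        return (y - score) ** 2
    if model_type == "lad":       # absolute (L1) loss
        return abs(y - score)
    raise ValueError(f"Unsupported model_type: {model_type}")
```

The two classification losses depend only on the margin `y * score`, while the two regression losses depend only on the residual `y - score`.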
- update(config)
Update MMD-DRO model configuration.
- Parameters:
config (Dict[str, Any]) – Configuration dictionary containing optional keys:
eta (float): MMD radius controlling distributional robustness. Must satisfy \(\eta > 0\). Defaults to current value.
sampling_method (str): Ambiguity set sampling strategy. Valid options:
'bound': Sample on MMD ball boundary
'hull': Sample within convex hull
n_certify_ratio (float): Ratio of certification samples to training data size. Must satisfy \(0 < \text{ratio} \leq 1\). Defaults to current ratio.
- Raises:
If eta is non-positive
If sampling_method is not in {'bound', 'hull'}
If n_certify_ratio ∉ (0, 1]
If config contains unrecognized keys
- Return type:
- Example:
>>> model = MMD_DRO(input_dim=10, model_type='svm')
>>> model.update({
...     'eta': 0.5,
...     'sampling_method': 'hull'
... })
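The constraints documented for `update` can also be checked before the call. This is a minimal sketch, not the library's own validation: `validate_mmd_config` is a hypothetical helper, and raising `ValueError` is an assumption, since the exception types are not stated here.

```python
from typing import Any, Dict

# Keys and constraints taken from the update() description above.
ALLOWED_KEYS = {"eta", "sampling_method", "n_certify_ratio"}
SAMPLING_METHODS = {"bound", "hull"}


def validate_mmd_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical pre-check mirroring the documented constraints."""
    unknown = set(config) - ALLOWED_KEYS
    if unknown:
        raise ValueError(f"Unrecognized config keys: {sorted(unknown)}")
    if "eta" in config and not config["eta"] > 0:
        raise ValueError("eta must be positive")
    if ("sampling_method" in config
            and config["sampling_method"] not in SAMPLING_METHODS):
        raise ValueError("sampling_method must be 'bound' or 'hull'")
    if "n_certify_ratio" in config and not 0 < config["n_certify_ratio"] <= 1:
        raise ValueError("n_certify_ratio must lie in (0, 1]")
    return config


# A config that passes the checks could then be handed to model.update(...).
checked = validate_mmd_config({"eta": 0.5, "sampling_method": "hull"})
```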
- fit(X, y)
Fit the MMD-DRO model to the data.
- Parameters:
X (numpy.ndarray) – Training feature matrix of shape (n_samples, n_features). Must satisfy n_features == self.input_dim.
y (numpy.ndarray) – Target values of shape (n_samples,). Format requirements:
Classification: ±1 labels
Regression: Continuous values
- Returns:
Dictionary containing trained parameters:
theta: Weight vector of shape (n_features,)
- Return type:
Dict[str, Any]
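The input requirements for `fit` can be checked ahead of time. The sketch below assumes classification labels arrive encoded as 0/1 and must be mapped to ±1; `prepare_classification_data` is a hypothetical helper, not part of the library.

```python
import numpy as np


def prepare_classification_data(X, y01, input_dim):
    """Hypothetical helper: check shapes and map 0/1 labels to -1/+1."""
    X = np.asarray(X, dtype=float)
    y01 = np.asarray(y01)
    if X.ndim != 2 or X.shape[1] != input_dim:
        raise ValueError(f"X must have shape (n_samples, {input_dim})")
    if y01.shape != (X.shape[0],):
        raise ValueError("y must have shape (n_samples,)")
    if not np.isin(y01, (0, 1)).all():
        raise ValueError("labels must be 0 or 1")
    y = (2 * y01 - 1).astype(float)  # 0 -> -1, 1 -> +1
    return X, y


X, y = prepare_classification_data(
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], [0, 1, 1], input_dim=2
)
# X and y now match the documented shapes and label format for fit(X, y).
```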