MMD DRO
- class dro.linear_model.mmd_dro.MMD_DRO(input_dim, model_type='svm', fit_intercept=True, solver='MOSEK', sampling_method='bound')
Bases: BaseLinearDRO
MMD-DRO (Maximum Mean Discrepancy Distributionally Robust Optimization) implementation supporting several sampling methods and model types.
Reference: <https://arxiv.org/abs/2006.06981>
Initialize MMD-DRO with a kernel-based ambiguity set.
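Roughly, the kernel-based ambiguity set contains every distribution whose MMD from the empirical training distribution is at most the radius eta. The sketch below assumes a Gaussian (RBF) kernel with bandwidth sigma; this page does not say which kernel or bandwidth MMD_DRO uses, or whether eta bounds MMD or its square. Under that assumption, it computes the biased empirical estimate of MMD²: the mean within-sample kernel value of each sample, minus twice the mean cross-sample kernel value. The helpers gaussian_kernel and mmd_squared are hypothetical and not part of the dro package.

```python
import numpy as np


def gaussian_kernel(X, Y, sigma=1.0):
    """Hypothetical helper: RBF kernel matrix k(x, y) = exp(-||x - y||^2 / (2 sigma^2))."""
    # Pairwise squared Euclidean distances via broadcasting, shape (n, m)
    sq_dists = np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * sigma ** 2))


def mmd_squared(X, Y, sigma=1.0):
    """Hypothetical helper: biased (V-statistic) estimate of MMD^2 between samples X and Y."""
    k_xx = gaussian_kernel(X, X, sigma).mean()  # mean similarity within X
    k_yy = gaussian_kernel(Y, Y, sigma).mean()  # mean similarity within Y
    k_xy = gaussian_kernel(X, Y, sigma).mean()  # mean cross-sample similarity
    return k_xx + k_yy - 2.0 * k_xy


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))                 # reference ("training") sample
Y_same = rng.normal(size=(200, 2))            # same distribution as X
Y_shift = rng.normal(loc=1.5, size=(200, 2))  # mean-shifted distribution
print(mmd_squared(X, Y_same))   # close to zero
print(mmd_squared(X, Y_shift))  # noticeably larger
```

Two samples from the same distribution give a value near zero, and a shifted sample gives a clearly larger one. A ball of radius eta in this distance therefore admits distributions that stay close to the data in the kernel sense.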
- Parameters:
input_dim (int) – Dimension of input features. Must match the number of features in the training data.
model_type (str) –
Base model type. Supported:
'svm': Support Vector Machine (hinge loss)
'logistic': Logistic Regression (log loss)
'ols': Ordinary Least Squares (L2 loss)
'lad': Least Absolute Deviation (L1 loss)
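sampling_method (str) –
Ambiguity set sampling strategy. Supported:
'bound': Sample on the MMD ball boundary
'hull': Sample within the convex hull
fit_intercept (bool) – Whether to include an intercept term in the linear model. Defaults to True.
solver (str) – Optimization solver used to fit the model. Defaults to 'MOSEK'.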
- Raises:
If model_type is not in the supported list
If input_dim ≤ 0
If sampling_method is not 'bound' or 'hull'
- Example:
>>> model = MMD_DRO(input_dim=128, model_type='svm')
>>> model.sampling_method = 'hull'
>>> model.eta = 0.5
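Each model_type option fixes the per-sample loss that MMD-DRO robustifies, that is, the loss whose expectation is taken in the worst case over the ambiguity set. The sketch below writes out the four listed losses for a linear score θᵀx + b. It assumes labels in {-1, +1} for 'svm' and 'logistic' and real-valued targets for 'ols' and 'lad'. linear_loss is a hypothetical helper, not the library's internal code.

```python
import numpy as np


def linear_loss(model_type, theta, b, X, y):
    """Hypothetical helper: per-sample loss of the linear score X @ theta + b.

    Assumes y in {-1, +1} for 'svm' and 'logistic', real-valued y otherwise.
    """
    score = X @ theta + b
    if model_type == "svm":
        return np.maximum(0.0, 1.0 - y * score)  # hinge loss
    if model_type == "logistic":
        return np.logaddexp(0.0, -y * score)     # log(1 + exp(-y * score)), numerically stable
    if model_type == "ols":
        return (y - score) ** 2                  # squared (L2) loss
    if model_type == "lad":
        return np.abs(y - score)                 # absolute (L1) loss
    raise ValueError(f"unsupported model_type: {model_type!r}")


X = np.array([[1.0, 0.0], [0.0, 1.0]])
theta = np.array([2.0, -1.0])
y = np.array([1.0, 1.0])
print(linear_loss("svm", theta, 0.0, X, y))  # hinge losses: [0. 2.]
```

Only the loss changes between model types, so the same kernel-based ambiguity set can be paired with any of the four.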
- update(config)
Update MMD-DRO model configuration.
- Parameters:
config (Dict[str, Any]) –
Configuration dictionary containing optional keys:
eta (float): MMD radius controlling distributional robustness. Must satisfy \(\eta > 0\). Defaults to current value.
sampling_method (str): Ambiguity set sampling strategy. Valid options:
'bound': Sample on MMD ball boundary
'hull': Sample within convex hull
n_certify_ratio (float): Ratio of certification samples to training data size. Must satisfy \(0 < \text{ratio} \leq 1\). Defaults to current ratio.
- Raises:
If eta is non-positive
If sampling_method not in {'bound', 'hull'}
If n_certify_ratio ∉ (0, 1]
If config contains unrecognized keys
- Return type:
- Example:
>>> model = MMD_DRO(input_dim=10, model_type='svm')
>>> model.update({
...     'eta': 0.5,
...     'sampling_method': 'hull'
... })
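A small standalone check can mirror the constraints listed under Raises for update. This is a hypothetical sketch, not the library's implementation. validate_mmd_config is an illustrative name, and raising ValueError is an assumption because the exception type is not documented here.

```python
from typing import Any, Dict

# Hypothetical constants mirroring the documented update() options
VALID_SAMPLING_METHODS = {"bound", "hull"}
ALLOWED_KEYS = {"eta", "sampling_method", "n_certify_ratio"}


def validate_mmd_config(config: Dict[str, Any]) -> None:
    """Hypothetical check of an update() config; ValueError is an assumed exception type."""
    unknown = set(config) - ALLOWED_KEYS
    if unknown:
        raise ValueError(f"unrecognized config keys: {sorted(unknown)}")
    # `not x > 0` also rejects NaN, which fails every comparison
    if "eta" in config and not config["eta"] > 0:
        raise ValueError("eta must be positive")
    if "sampling_method" in config and config["sampling_method"] not in VALID_SAMPLING_METHODS:
        raise ValueError("sampling_method must be 'bound' or 'hull'")
    # Chained comparison checks 0 < ratio <= 1, i.e. ratio in (0, 1]
    if "n_certify_ratio" in config and not 0 < config["n_certify_ratio"] <= 1:
        raise ValueError("n_certify_ratio must lie in (0, 1]")


validate_mmd_config({"eta": 0.5, "sampling_method": "hull"})  # passes silently
```

Only keys present in config are checked. Omitted keys keep their current values, matching the "Defaults to current value" behavior described above. For example, {'eta': 0} or {'n_certify_ratio': 1.5} would be rejected.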