Personalization¶
In general, to rewrite a function for a new class, one option is to bind replacement functions to a model instance with the following pattern:
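from types import MethodType

def _loss(self, X, y):
    xxx  # placeholder: user-defined loss

def _cvx_loss(self, X, y, theta, b):
    xxx  # placeholder: user-defined loss used in the convex program

model = XXDRO(...)
# Bind the replacement functions to this model instance.
model._loss = MethodType(_loss, model)
model._cvx_loss = MethodType(_cvx_loss, model)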
Use this pattern when we want to modify the _loss and _cvx_loss functions of the model class.
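As a self-contained illustration of this binding pattern, the sketch below uses a hypothetical stand-in class, `ToyModel`, rather than any class from the package. The squared-error and absolute-error losses, the `_predict` helper, and the data are likewise assumptions chosen only for illustration.

```python
from types import MethodType


class ToyModel:
    """Hypothetical stand-in for a DRO model class (not part of the package)."""

    def __init__(self, theta, b):
        self.theta = theta  # list of coefficients
        self.b = b          # intercept

    def _predict(self, x):
        return sum(t * xi for t, xi in zip(self.theta, x)) + self.b

    def _loss(self, X, y):
        # Default: per-sample squared error.
        return [(self._predict(x) - yi) ** 2 for x, yi in zip(X, y)]


def _abs_loss(self, X, y):
    # Replacement: per-sample absolute error.
    return [abs(self._predict(x) - yi) for x, yi in zip(X, y)]


model = ToyModel(theta=[1.0, 2.0], b=0.5)
X, y = [[1.0, 1.0], [0.0, 2.0]], [3.0, 7.5]

default_losses = model._loss(X, y)
model._loss = MethodType(_abs_loss, model)  # rebinds on this instance only
custom_losses = model._loss(X, y)
```

Because `MethodType` binds the function to a single instance, other instances of the same class keep the default loss.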
1. Linear (Exact) DRO Methods¶
In DRO models that are solved exactly, for each particular DRO type, we change _loss and _cvx_loss in each class.
\(f\)-DRO¶
In KLDRO, Chi2DRO, CVaRDRO, TVDRO (and corresponding BayesianDRO), we only need to rewrite _loss and _cvx_loss.
Wasserstein DRO¶
To adjust Wasserstein DRO, besides modifying the _cvx_loss (and _loss) functions, we also need to modify the _penalization function, which controls the regularization component. Here, the regularization component is the part of the reformulated Wasserstein DRO objective that is added on top of the empirical objective.
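To make the role of the regularization component concrete, here is a minimal sketch with a hypothetical stand-in class, `ToyWDRO`. It assumes the reformulated objective has the form (empirical loss) + eps * penalization(theta), and that `_penalization` takes the coefficient vector as its only argument. Both are assumptions made for illustration and may not match the package's actual signatures.

```python
from types import MethodType


class ToyWDRO:
    """Hypothetical stand-in for a Wasserstein DRO model (not the package's class)."""

    def __init__(self, eps):
        self.eps = eps  # ambiguity-set radius (assumed name)

    def _loss(self, X, y, theta, b):
        # Per-sample absolute error of a linear predictor.
        preds = [sum(t * xi for t, xi in zip(theta, x)) + b for x in X]
        return [abs(p - yi) for p, yi in zip(preds, y)]

    def _penalization(self, theta):
        # Default regularizer: Euclidean norm of theta.
        return sum(t * t for t in theta) ** 0.5

    def objective(self, X, y, theta, b):
        # Empirical objective plus the regularization component.
        losses = self._loss(X, y, theta, b)
        return sum(losses) / len(losses) + self.eps * self._penalization(theta)


def _l1_penalization(self, theta):
    # Replacement regularizer: L1 norm of theta.
    return sum(abs(t) for t in theta)


model = ToyWDRO(eps=0.5)
X, y, theta, b = [[1.0, 0.0], [0.0, 1.0]], [2.0, 6.0], [3.0, 4.0], 0.0

default_obj = model.objective(X, y, theta, b)
model._penalization = MethodType(_l1_penalization, model)
custom_obj = model.objective(X, y, theta, b)
```

In this toy setting, swapping the penalization changes only the added regularization term; the empirical part of the objective is untouched.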
Remark¶
We remark that for more complicated losses, e.g., losses with a mixture of distances, we have not implemented the personalized loss yet.
Note that we have not implemented the personalized constraint module yet (e.g., for \(\theta\)). Stay tuned for that.
2. NN-Based DRO Methods¶
2.1 Personalized Loss¶
For f-DRO and WDRO methods, our package supports personalized loss functions.
\(f\)-DRO¶
If the user would like to integrate their own loss functions, please:
- Create a new RobustLoss instance (from fdro.utils.py), and re-write the self._compute_individual_loss() function to user-specified forms.
- Create a new Chi2NNDRO or CVaRNNDRO instance (from fdro_nn.py), and re-write the self._criterion() function with the newly-modified RobustLoss instance above.
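The two steps above can be sketched as follows. The classes `ToyRobustLoss` and `ToyNNDRO` are hypothetical, dependency-free stand-ins that only mirror the method names mentioned above. Their constructors, arguments, and the CVaR-style aggregation are assumptions for illustration, not the package's implementation.

```python
class ToyRobustLoss:
    """Hypothetical stand-in for RobustLoss: aggregates per-sample losses."""

    def __init__(self, alpha):
        self.alpha = alpha  # fraction of worst-case samples to average over

    def _compute_individual_loss(self, preds, targets):
        # Default per-sample loss: squared error.
        return [(p - t) ** 2 for p, t in zip(preds, targets)]

    def __call__(self, preds, targets):
        # CVaR-style aggregation: mean of the largest alpha-fraction of losses.
        losses = sorted(self._compute_individual_loss(preds, targets), reverse=True)
        k = max(1, int(round(self.alpha * len(losses))))
        return sum(losses[:k]) / k


class AbsRobustLoss(ToyRobustLoss):
    # Step 1: override the per-sample loss with a user-specified form.
    def _compute_individual_loss(self, preds, targets):
        return [abs(p - t) for p, t in zip(preds, targets)]


class ToyNNDRO:
    """Hypothetical stand-in for a class such as Chi2NNDRO or CVaRNNDRO."""

    def __init__(self, alpha):
        self.alpha = alpha

    def _criterion(self, preds, targets):
        return ToyRobustLoss(self.alpha)(preds, targets)


class AbsNNDRO(ToyNNDRO):
    # Step 2: make the criterion use the modified robust loss.
    def _criterion(self, preds, targets):
        return AbsRobustLoss(self.alpha)(preds, targets)


preds, targets = [1.0, 2.0, 3.0, 4.0], [1.0, 0.0, 3.5, 8.0]
default_value = ToyNNDRO(alpha=0.5)._criterion(preds, targets)
custom_value = AbsNNDRO(alpha=0.5)._criterion(preds, targets)
```

Subclassing is used here for readability; for these toy classes, binding replacement functions to an instance with `MethodType` would work equally well.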
WDRO¶
When personalizing the loss function for WDRO, please:
- Create a new WNNDRO instance (from wdro_nn.py).
- Re-write the self._loss() function.
2.2 Personalized Model Architecture¶
Users can pass their own model via the self.update() function. Note that the personalized model must be written in PyTorch and must be a subclass of torch.nn.Module.
3. Tree-Based DRO Methods¶
For tree-based DRO methods, users can simply rewrite the self.loss() function to change the loss function. To change the DRO type, one needs to adjust self._kl_dro_loss() if the base model is KLDRO_XX, or self._cvar_dro_loss() if the base model is CVaRDRO_XX.