Kernel Distributionally Robust Optimization
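This is not MMD-DRO. Instead, we still define the DRO ambiguity sets through the standard DRO models (Wasserstein, χ², KL, CVaR, TV, marginal CVaR, ...), while the model class is extended beyond linear models via kernel features (e.g. `rbf`, `poly`, `sigmoid` kernels). In other words, the kernel enters the model class, not the definition of the ambiguity set.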
[2]:
from dro.linear_model.wasserstein_dro import *
from dro.data.dataloader_regression import regression_basic
from dro.data.dataloader_classification import classification_basic
from dro.data.draw_utils import draw_classification
# Synthetic classification data with `feature_dim` features and 500 samples.
feature_dim = 2
X, y = classification_basic(d = feature_dim, num_samples = 500, radius = 3, visualize = False)

# Set to True to plot the raw data before fitting (replaces a commented-out call).
SHOW_RAW_DATA = False
if SHOW_RAW_DATA:
    draw_classification(X, y, title = 'Raw Data')

# Wasserstein-DRO SVM with a kernel feature map.
# NOTE(review): the constructor requests kernel='rbf', but update_kernel then sets
# metric='poly'. Confirm which kernel is intended; if update_kernel's 'metric'
# overrides the constructor argument, pass kernel='poly' here so the two agree.
kernel_clf_model = WassersteinDRO(input_dim = feature_dim, model_type = 'svm', kernel = 'rbf')
kernel_clf_model.update_kernel({'metric': 'poly', 'kernel_gamma': 1})
# 'eps' presumably sets the size of the Wasserstein ambiguity set; confirm in the dro docs.
kernel_clf_model.update({'eps': .1})
params = kernel_clf_model.fit(X, y)
[3]:
from dro.linear_model.chi2_dro import *
from dro.linear_model.kl_dro import *
from dro.linear_model.tv_dro import *
from dro.linear_model.cvar_dro import *
from dro.linear_model.marginal_dro import *
from dro.linear_model.conditional_dro import *
# Fit the same kernel classification task under several DRO ambiguity sets
# (chi-square, KL, CVaR, TV), in the same order as before.
# Each entry is (DRO class, model_type, kernel, hyperparameters passed to `update`).
kernel_dro_configs = [
    (Chi2DRO, 'svm', 'poly', {'eps': 10}),
    (KLDRO, 'svm', 'rbf', {'eps': 2}),
    (CVaRDRO, 'logistic', 'sigmoid', {'alpha': 0.9}),
    (TVDRO, 'logistic', 'poly', {'eps': 0.01}),
]

# Keep every fit's parameters instead of overwriting them; keyed by class name.
fitted_params = {}
for dro_cls, dro_model_type, dro_kernel, dro_config in kernel_dro_configs:
    kernel_clf_model = dro_cls(input_dim = feature_dim, model_type = dro_model_type, kernel = dro_kernel)
    kernel_clf_model.update(dro_config)
    params = kernel_clf_model.fit(X, y)
    fitted_params[dro_cls.__name__] = params
# After the loop, `kernel_clf_model` and `params` refer to the last fit (TVDRO),
# matching the state the original sequence of statements left behind.
[5]:
# Marginal CVaR-DRO with a polynomial-kernel SVM.
kernel_clf_model = MarginalCVaRDRO(
    input_dim=feature_dim,
    model_type='svm',
    kernel='poly',
)
# 'alpha' is forwarded to `update`; presumably the CVaR level. Confirm in the dro docs.
kernel_clf_model.update({'alpha': 0.9})
params = kernel_clf_model.fit(X, y)
# In-sample evaluation: scored on the same (X, y) used for fitting, so this is an
# optimistic estimate. The CVXPY "``*`` ... matrix multiplication" UserWarning in the
# recorded output is raised inside cvxpy itself, not by this cell.
kernel_clf_model.score(X, y)
/Users/jiashuo/anaconda3/envs/llm-ot/lib/python3.10/site-packages/cvxpy/expressions/expression.py:674: UserWarning:
This use of ``*`` has resulted in matrix multiplication.
Using ``*`` for matrix multiplication has been deprecated since CVXPY 1.1.
Use ``*`` for matrix-scalar and vector-scalar multiplication.
Use ``@`` for matrix-matrix and matrix-vector multiplication.
Use ``multiply`` for elementwise multiplication.
This code path has been hit 2 times so far.
warnings.warn(msg, UserWarning)
[5]:
(0.986, 0.9859994959818553)
[6]:
from sklearn.svm import SVC
from dro.data.dataloader_regression import regression_basic
from dro.data.dataloader_classification import classification_basic
from dro.data.draw_utils import draw_classification
# Baseline: a plain scikit-learn RBF-kernel SVC (no DRO) on a small 5-feature problem.
# svc_-prefixed names keep this cell from overwriting the `feature_dim`, `X`, `y`
# globals used by the DRO cells above. Without them, re-running those cells after this
# one would silently use this cell's data instead.
svc_feature_dim = 5
X_svc, y_svc = classification_basic(d = svc_feature_dim, num_samples = 25, radius = 3, visualize = False)
# A very large C (weak regularization) approximates a hard-margin SVM.
clf = SVC(kernel = 'rbf', C = 100_000_000)
clf.fit(X_svc, y_svc)
# Labelled rich display instead of print(). `dual_coef_` holds the *dual* coefficients;
# the primal `coef_` is not available for a non-linear (RBF) kernel.
svc_summary = {
    'dual_coef': clf.dual_coef_,
    'support_vectors': clf.support_vectors_,
    'intercept': clf.intercept_,
}
svc_summary
coef [[-0.25769266 -0.64158585 -3.57086163 -0.78030986 -0.43010956 -0.01490551
0.76823835 0.16369363 1.56160661 1.47227558 0.03362333 1.69602757]] [[ 0.38846926 -0.7028592 1.35277704 0.69878398 -2.12650133]
[ 2.31753572 -0.46290575 1.17834297 1.18731603 -0.94596622]
[ 0.25018034 1.61514873 1.09731754 1.55435329 0.42096142]
[-0.3689567 -0.02826586 -0.84885536 1.28387817 -0.20472226]
[ 0.13204275 -0.69776822 2.16793699 2.95568251 -2.16462365]
[ 1.78316707 -1.07634697 0.80160239 2.94332765 0.57396163]
[-0.00613887 4.01699229 1.16884802 0.82769028 -1.76806615]
[ 0.45436803 2.56600279 0.90566671 -0.6451926 -1.13588996]
[-2.28630569 2.45244185 1.26490427 1.72628874 0.65931799]
[-0.06598713 2.44424396 0.03599602 0.40586944 1.60361199]
[ 0.42349747 1.56956827 2.60746837 -2.13880444 1.43853605]
[ 1.82292115 1.4884194 0.63837633 -0.63730201 0.3482033 ]] [0.47099307]
[7]:
# Kernel Wasserstein-DRO regression with model_type='lad' (presumably a
# least-absolute-deviation loss; confirm in the dro docs), using an RBF kernel.
feature_dim = 5
X, y = regression_basic(
    num_samples=20,
    d=feature_dim,
    noise=1,
)
reg_model = WassersteinDRO(
    input_dim=feature_dim,
    model_type='lad',
    kernel='rbf',
)
reg_model.update_kernel({'metric': 'rbf'})
reg_model.update({'eps': 1})
# The fitted parameters are displayed as the cell output. In the recorded output,
# 'theta' has 20 entries for 20 samples (not feature_dim = 5), which suggests one
# coefficient per training point in a kernel expansion. Confirm against the dro docs.
reg_model.fit(X, y)
[7]:
{'theta': [28.183849164084634,
-40.587191522263076,
-25.771734799878164,
24.507949490371534,
-20.773407080679252,
-32.97201768430408,
3.720089369558843,
-71.2805334354158,
70.95668207216211,
-50.90241684443111,
41.142825176686436,
-57.152569382263984,
57.59235347448549,
48.59610303450425,
26.428015099762195,
-0.05635013918886693,
9.225598822381361,
-40.7058061185105,
-30.048264859949104,
19.95875895029569],
'b': array(-37.78937028)}