Kernel Distributionally Robust OptimizationΒΆ

This is not MMD-DRO. Instead, we still define DRO ambiguity sets through standard DRO models, while the model class is set beyond the kernel feature.

[2]:
# import os
# os.chdir('../../../')
from dro.src.linear_model.wasserstein_dro import *
from dro.src.data.dataloader_regression import regression_basic
from dro.src.data.dataloader_classification import classification_basic
from dro.src.data.draw_utils import draw_classification

feature_dim = 5
X, y = classification_basic(d = feature_dim, num_samples = 25, radius = 3, visualize = False)


#draw_classification(X, y, title = 'Raw Data')
kernel_clf_model = WassersteinDRO(input_dim = feature_dim, model_type = 'svm', kernel = 'rbf')
kernel_clf_model.update_kernel({'metric': 'poly', 'kernel_gamma': 1})
kernel_clf_model.update({'eps': 1})
kernel_clf_model.fit(X, y)

[2]:
{'theta': [-0.16104024032382355,
  0.048395359857878774,
  0.12293374271263681,
  -0.059218694026617606,
  -0.26396127900233757,
  -0.12747839301562838,
  -0.06610574612531378,
  0.01936335061582429,
  0.08109022066372668,
  -0.08014570552519816,
  -0.12661786535569905,
  -0.01620292119149252,
  0.0345609438234761,
  0.1146307847005102,
  0.036123565154558585,
  -0.01630122310620072,
  -0.055200009041263855,
  -0.013088731696068508,
  0.06569980417065936,
  -0.015705425237797853,
  0.07864622060163941,
  0.188215369431673,
  0.024233518060246394,
  0.08189350692147811],
 'b': array(-0.19451036)}
[3]:
from dro.src.linear_model.chi2_dro import *
from dro.src.linear_model.kl_dro import *
from dro.src.linear_model.tv_dro import *
from dro.src.linear_model.cvar_dro import *
from dro.src.linear_model.marginal_dro import *
from dro.src.linear_model.conditional_dro import *

kernel_clf_model = Chi2DRO(input_dim = feature_dim, model_type = 'svm', kernel = 'poly')
kernel_clf_model.update({'eps': 10})
kernel_clf_model.fit(X, y)

kernel_clf_model = KLDRO(input_dim = feature_dim, model_type = 'svm', kernel = 'rbf')
kernel_clf_model.update({'eps': 2})
kernel_clf_model.fit(X, y)

kernel_clf_model = CVaRDRO(input_dim = feature_dim, model_type = 'logistic', kernel = 'sigmoid')
kernel_clf_model.update({'alpha': 0.9})
kernel_clf_model.fit(X, y)

kernel_clf_model = TVDRO(input_dim = feature_dim, model_type = 'logistic', kernel = 'poly')
kernel_clf_model.update({'eps': 0.01})
kernel_clf_model.fit(X, y)



[3]:
{'theta': [12.962059776832325,
  0.0,
  -41.23333171520414,
  17.95164599844062,
  -51.54123399285766,
  -5.622340899150795,
  8.567083574439527,
  -14.25831621758392,
  -6.091080186152102,
  -8.659262706738705,
  61.50388502711999,
  18.705868594821755,
  2.3565374641324044,
  14.182027509645296,
  8.58043910589452,
  -5.375518824121545,
  -26.23442327858078,
  -17.329416394331346,
  28.335130689663124,
  1.0646708883222875,
  13.265331059031373,
  30.40784507557173,
  -4.905773590347232,
  37.604591748881894],
 'threshold': array(-1.8858299e-08),
 'b': array(-95.25475747)}
[4]:
kernel_clf_model = MarginalCVaRDRO(input_dim = feature_dim, model_type = 'svm', kernel = 'poly')
kernel_clf_model.update({'alpha': 0.9})
kernel_clf_model.fit(X, y)
[4]:
{'theta': [0.7152348833517868,
  -0.0,
  -2.279697334359294,
  0.9927846306961221,
  -2.849966228187915,
  -0.3119639511005327,
  0.47473699096434524,
  -0.7880069523395982,
  -0.33579964629235887,
  -0.47734041718469666,
  3.3910629260703984,
  1.033152512067522,
  0.1309952968329014,
  0.781890641607764,
  0.47589726587677433,
  -0.29493206782188996,
  -1.449101893657605,
  -0.9597117426648142,
  1.5644978373954748,
  0.057118748055863824,
  0.7339216228056991,
  1.6814093166469604,
  -0.2721559696436323,
  2.079190695288901],
 'B': [[0.0,
   1.0532347945085568e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [4.427104274080785e-13,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   3.717095554805546e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   2.0152511928671398e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   4.863519291439401e-12,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   1.7197460946213827e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   3.5965511110139845e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   3.40267261755231e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   9.742171130373734e-13,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   9.742171130013734e-13,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   1.719746094778384e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   3.634196181409275e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   4.970215592650567e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   4.735402904633421e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   4.29858790824743e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   2.5978316422507703e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   2.4156148069530403e-12,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   7.824307744474655e-13,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   1.2445735330430606e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   2.46207013018606e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   4.93908614296095e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   4.863519291426125e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   4.961563963234218e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0],
  [0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   4.992252642845671e-12,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0]],
 'b': array(-5.2539282),
 'threshold': array(-1.71600117e-10)}
[5]:
from sklearn.svm import SVC
from dro.src.data.dataloader_regression import regression_basic
from dro.src.data.dataloader_classification import classification_basic
from dro.src.data.draw_utils import draw_classification

feature_dim = 5
X, y = classification_basic(d = feature_dim, num_samples = 25, radius = 3, visualize = False)

clf = SVC(kernel = 'rbf', C = 100000000)
clf.fit(X, y)
print('coef', clf.dual_coef_, clf.support_vectors_, clf.intercept_)
coef [[-0.25769266 -0.64158585 -3.57086163 -0.78030986 -0.43010956 -0.01490551
   0.76823835  0.16369363  1.56160661  1.47227558  0.03362333  1.69602757]] [[ 0.38846926 -0.7028592   1.35277704  0.69878398 -2.12650133]
 [ 2.31753572 -0.46290575  1.17834297  1.18731603 -0.94596622]
 [ 0.25018034  1.61514873  1.09731754  1.55435329  0.42096142]
 [-0.3689567  -0.02826586 -0.84885536  1.28387817 -0.20472226]
 [ 0.13204275 -0.69776822  2.16793699  2.95568251 -2.16462365]
 [ 1.78316707 -1.07634697  0.80160239  2.94332765  0.57396163]
 [-0.00613887  4.01699229  1.16884802  0.82769028 -1.76806615]
 [ 0.45436803  2.56600279  0.90566671 -0.6451926  -1.13588996]
 [-2.28630569  2.45244185  1.26490427  1.72628874  0.65931799]
 [-0.06598713  2.44424396  0.03599602  0.40586944  1.60361199]
 [ 0.42349747  1.56956827  2.60746837 -2.13880444  1.43853605]
 [ 1.82292115  1.4884194   0.63837633 -0.63730201  0.3482033 ]] [0.47099307]
[6]:
feature_dim = 5
X, y = regression_basic(num_samples = 20, d = feature_dim, noise = 1)

reg_model = WassersteinDRO(input_dim = feature_dim, model_type = 'lad', kernel = 'rbf')

reg_model.update_kernel({'metric': 'rbf'})
reg_model.update({'eps': 1})
reg_model.fit(X, y)
[6]:
{'theta': [28.183849164084634,
  -40.587191522263076,
  -25.771734799878164,
  24.507949490371534,
  -20.773407080679252,
  -32.97201768430408,
  3.720089369558843,
  -71.2805334354158,
  70.95668207216211,
  -50.90241684443111,
  41.142825176686436,
  -57.152569382263984,
  57.59235347448549,
  48.59610303450425,
  26.428015099762195,
  -0.05635013918886693,
  9.225598822381361,
  -40.7058061185105,
  -30.048264859949104,
  19.95875895029569],
 'b': array(-37.78937028)}