00001
00036
00037
00038
00039 #ifndef _Svm_H_
00040 #define _Svm_H_
00041
00042
00043 #include <SharkDefs.h>
00044 #include <ReClaM/Model.h>
00045 #include <ReClaM/Optimizer.h>
00046 #include <ReClaM/KernelFunction.h>
00047 #include <ReClaM/QuadraticProgram.h>
00048 #include <Rng/GlobalRng.h>
00049
00050 class SvmApproximation;
00051
00052
00088 class SVM : public Model
00089 {
00090 public:
00095 SVM(KernelFunction* pKernel, bool bSignOutput = false);
00096
00102 SVM(KernelFunction* pKernel, const Array<double>& input, bool bSignOutput = false);
00103
00105 ~SVM();
00106
00127 void SetTrainingData(const Array<double>& input, bool copy = false);
00128
00133 void model(const Array<double>& input, Array<double>& output);
00134
00139 double model(const Array<double>& input);
00140
00155 void modelDerivative(const Array<double>& input, Array<double>& derivative);
00156
00171 void modelDerivative(const Array<double>& input, Array<double>& output, Array<double>& derivative);
00172
00180 inline double getAlpha(int index)
00181 {
00182 return parameter(index);
00183 }
00184
00192 inline double getOffset()
00193 {
00194 return parameter(examples);
00195 }
00196
00198 inline KernelFunction* getKernel()
00199 {
00200 return kernel;
00201 }
00202
00204 inline const Array<double>& getPoints()
00205 {
00206 return *x;
00207 }
00208
00210 inline unsigned int getExamples()
00211 {
00212 return examples;
00213 }
00214
00216 inline unsigned int getDimension()
00217 {
00218 return inputDimension;
00219 }
00220
00228 void MakeSparse();
00229
00238 bool LoadSVMModel(std::istream& is);
00239
00248 bool SaveSVMModel(std::ostream& os);
00249
00259 static SVM* ImportLibsvmModel(std::istream& is);
00260
00270 static SVM* ImportSvmlightModel(std::istream& is);
00271
00272 friend class SvmApproximation;
00273
00274 protected:
00295 static int ReadToken(std::istream& is, char* buffer, int maxlength, const char* separators);
00296
00312 static int DiscardUntil(std::istream& is, const char* separators);
00313
00315 KernelFunction* kernel;
00316
00318 bool bOwnMemory;
00319
00321 const Array<double>* x;
00322
00325 bool signOutput;
00326
00328 unsigned int examples;
00329 };
00330
00331
00374 class MultiClassSVM : public Model
00375 {
00376 public:
00382 MultiClassSVM(KernelFunction* pKernel, unsigned int numberOfClasses, bool bOrthogonalVectors, bool bNumberOutput = true);
00383
00388 MultiClassSVM(KernelFunction* pKernel, Array<double> prototypes, bool bNumberOutput = true);
00389
00391 ~MultiClassSVM();
00392
00393
00403 void SetTrainingData(const Array<double>& input, bool copy = false);
00404
00406 void model(const Array<double>& input, Array<double>& output);
00407
00409 unsigned int model(const Array<double>& input);
00410
00412 void Normalize();
00413
00415 inline KernelFunction* getKernel()
00416 {
00417 return kernel;
00418 }
00419
00421 inline const Array<double>& getPoints()
00422 {
00423 return *x;
00424 }
00425
00429 inline double getAlpha(unsigned int index, unsigned int c) const
00430 {
00431 return parameter(classes * index + c);
00432 }
00433
00436 inline double getOffset(unsigned int c) const
00437 {
00438 return parameter(classes * examples + c);
00439 }
00440
00442 inline unsigned int getClasses() const
00443 {
00444 return classes;
00445 }
00446
00449 inline const ArrayReference<double> getClassPrototype(unsigned int c) const
00450 {
00451 return prototypes[c];
00452 }
00453
00455 unsigned int VectorToClass(const Array<double>& v);
00456
00457 protected:
00458 void Predict(const Array<double>& input, Array<double>& output);
00459 void Predict(const Array<double>& input, ArrayReference<double> output);
00460
00462 KernelFunction* kernel;
00463
00465 bool bOwnMemory;
00466
00468 const Array<double>* x;
00469
00472 bool numberOutput;
00473
00475 unsigned int examples;
00476
00478 unsigned int classes;
00479
00481 Array<double> prototypes;
00482 };
00483
00484
00497 class MetaSVM : public Model
00498 {
00499 public:
00504 MetaSVM(SVM* pSVM, unsigned int numberOfHyperParameters);
00505
00510 MetaSVM(MultiClassSVM* pSVM, unsigned int numberOfHyperParameters);
00511
00513 ~MetaSVM();
00514
00515
00517 inline SVM* getSVM()
00518 {
00519 return dynamic_cast<SVM*>(svm);
00520 }
00521
00523 inline MultiClassSVM* getMultiClassSVM()
00524 {
00525 return dynamic_cast<MultiClassSVM*>(svm);
00526 }
00527
00529 void model(const Array<double>& input, Array<double>& output);
00530
00532 void setParameter(unsigned int index, double value);
00533
00535 bool isFeasible();
00536
00537 protected:
00539 Model* svm;
00540
00542 KernelFunction* kernel;
00543
00545 unsigned int hyperparameters;
00546 };
00547
00548
00600 class C_SVM : public MetaSVM
00601 {
00602 public:
00610 C_SVM(SVM* pSVM, double Cplus, double Cminus, bool norm2 = false, bool unconst = false);
00611
00613 ~C_SVM();
00614
00615
00621 void PrepareDerivative();
00622
00627 void modelDerivative(const Array<double>& input, Array<double>& derivative);
00628
00630 void setParameter(unsigned int index, double value);
00631
00633 inline double get_Cplus()
00634 {
00635 return C_plus;
00636 }
00637
00639 inline double get_Cminus()
00640 {
00641 return C_minus;
00642 }
00643
00645 inline bool is2norm()
00646 {
00647 return norm2penalty;
00648 }
00649
00651 inline bool isUnconstrained()
00652 {
00653 return exponential;
00654 }
00655
00657 inline double getCRatio()
00658 {
00659 return C_ratio;
00660 }
00661
00663 bool isFeasible();
00664
00665 protected:
00667 bool norm2penalty;
00668
00670 double C_plus;
00671
00673 double C_minus;
00674
00676 double C_ratio;
00677
00679 bool exponential;
00680
00681 Array<double> alpha_b_Derivative;
00682 };
00683
00684
00713 class Epsilon_SVM : public MetaSVM
00714 {
00715 public:
00723 Epsilon_SVM(SVM* pSVM, double C, double epsilon, bool unconst = false);
00724
00726 ~Epsilon_SVM();
00727
00728
00730 void setParameter(unsigned int index, double value);
00731
00733 inline double get_C()
00734 {
00735 return C;
00736 }
00737
00739 inline double get_epsilon()
00740 {
00741 return epsilon;
00742 }
00743
00745 bool isFeasible();
00746
00747 protected:
00749 double C;
00750
00752 double epsilon;
00753
00755 bool exponential;
00756 };
00757
00758
00759
00771 class OneClassSVM : public MetaSVM
00772 {
00773 public:
00774
00777 OneClassSVM(SVM* pSVM, double fractionNu);
00778
00780 ~OneClassSVM();
00781
00783 void setParameter(unsigned int index, double value);
00784
00786 inline double getNu(){
00787 return nu;
00788 }
00789
00791 bool isFeasible();
00792
00793 protected:
00795 double nu;
00796 };
00797
00798
00811 class RegularizationNetwork : public MetaSVM
00812 {
00813 public:
00815 RegularizationNetwork(SVM* pSVM, double gamma);
00816
00818 ~RegularizationNetwork();
00819
00820
00821 inline double get_gamma() { return parameter(0); }
00822 inline void set_gamma(double gamma) { setParameter(0, gamma); }
00823
00825 bool isFeasible();
00826 };
00827
00828
00839 class AllInOneMcSVM : public MetaSVM
00840 {
00841 public:
00843 AllInOneMcSVM(MultiClassSVM* pSVM, double C);
00844
00846 ~AllInOneMcSVM();
00847
00848
00849 inline double get_C() { return parameter(0); }
00850 inline void set_C(double C) { setParameter(0, C); }
00851
00853 bool isFeasible();
00854 };
00855
00856
00867 class CrammerSingerMcSVM : public MetaSVM
00868 {
00869 public:
00875 CrammerSingerMcSVM(MultiClassSVM* pSVM, double beta);
00876
00878 ~CrammerSingerMcSVM();
00879
00880
00881 inline double get_beta() { return parameter(0); }
00882 inline void set_beta(double beta) { setParameter(0, beta); }
00883
00885 bool isFeasible();
00886 };
00887
00888
00899 class OVAMcSVM : public MetaSVM
00900 {
00901 public:
00903 OVAMcSVM(MultiClassSVM* pSVM, double C);
00904
00906 ~OVAMcSVM();
00907
00908
00909 inline double get_C() { return parameter(0); }
00910 inline void set_C(double C) { setParameter(0, C); }
00911
00913 bool isFeasible();
00914 };
00915
00916
00931 class OCCMcSVM : public MetaSVM
00932 {
00933 public:
00935 OCCMcSVM(MultiClassSVM* pSVM, double C);
00936
00938 ~OCCMcSVM();
00939
00940
00941 inline double get_C() { return parameter(0); }
00942 inline void set_C(double C) { setParameter(0, C); }
00943
00945 bool isFeasible();
00946 };
00947
00948
00972 class SVM_Optimizer : public Optimizer
00973 {
00974 public:
00976 SVM_Optimizer();
00977
00979 ~SVM_Optimizer();
00980
00981
00990 void init(Model& model);
00991
01012 double optimize(Model& model, ErrorFunction& error, const Array<double>& input, const Array<double>& target);
01013
01028 double optimize(SVM& model, const Array<double>& input, const Array<double>& target, bool copy = false);
01029
01044 void optimize(MultiClassSVM& model, const Array<double>& input, const Array<double>& target, bool copy = false);
01045
01051 static ErrorFunction& dummyError;
01052
01058 inline QPSolver* get_Solver()
01059 {
01060 return solver;
01061 }
01062
01064 inline void setAccuracy(double accuracy = 0.001)
01065 {
01066 this->accuracy = accuracy;
01067 }
01068
01071 inline void setMaxIterations(SharkInt64 maxiter = -1)
01072 {
01073 this->maxIter = maxiter;
01074 }
01075
01078 inline void setMaxSeconds(int seconds = -1)
01079 {
01080 maxSeconds = seconds;
01081 }
01082
01084 inline void setVerbose(bool verbose = true)
01085 {
01086 printInfo = verbose;
01087 }
01088
01090 inline void setCacheSize(unsigned int cacheSize)
01091 {
01092 cacheMB = cacheSize;
01093 }
01094
01095 protected:
01096 enum eMode
01097 {
01098 eC1,
01099 eC2,
01100 eEpsilon,
01101 eNu,
01102 eRegularizationNetwork,
01103 e1Class,
01104 eAllInOne,
01105 eCrammerSinger,
01106 eOVA,
01107 eOCC,
01108 };
01109
01111 eMode mode;
01112
01114 QPMatrix* matrix;
01115
01117 CachedMatrix* cache;
01118
01120 QPSolver* solver;
01121
01123 double Cplus;
01124
01126 double Cminus;
01127
01129 double C;
01130
01132 double epsilon;
01133
01135 double gamma;
01136
01138 double beta;
01139
01141 double fractionOfOutliers;
01142
01144 double OneClassBoxUpper;
01145
01147 bool printInfo;
01148
01150 unsigned int cacheMB;
01151
01153 double accuracy;
01154
01156 SharkInt64 maxIter;
01157
01159 int maxSeconds;
01160 };
01161
01162
01163 typedef SVM KernelExpansion;
01164 typedef SVM LinearKernelModel;
01165
01166
01167 #endif
01168