00001
00042 #include <ReClaM/InverseClassSeparability.h>
00043
00044
00045 InverseClassSeparability::InverseClassSeparability()
00046 {
00047 }
00048
00049 InverseClassSeparability::~InverseClassSeparability()
00050 {
00051 }
00052
00053
00054 double InverseClassSeparability::error(Model& model, const Array<double>& input, const Array<double>& target)
00055 {
00056
00057 KernelFunction* pKernel = dynamic_cast<KernelFunction*>(&model);
00058 if (pKernel == NULL) throw SHARKEXCEPTION("[InverseClassSeparability::error] model is not a valid KernelFunction.");
00059
00060 int lPlus = 0;
00061 int lMinus = 0;
00062 int i, j, t, T = input.dim(0);
00063 for (t = 0; t < T; t++) if (target(t, 0) > 0.0) lPlus++; else lMinus++;
00064 int l = lPlus + lMinus;
00065 double lPlusInverse = 1.0 / lPlus;
00066 double lMinusInverse = 1.0 / lMinus;
00067 double lInverse = 1.0 / l;
00068
00069 double B = 0.0;
00070 double W = 0.0;
00071 double k;
00072 double b, w;
00073 for (i = 0; i < T; i++)
00074 {
00075 for (j = 0; j < i; j++)
00076 {
00077 k = pKernel->eval(input[i], input[j]);
00078 if (target(i, 0) > 0.0)
00079 {
00080 if (target(j, 0) > 0.0)
00081 {
00082 b = lPlusInverse - lInverse;
00083 w = -lPlusInverse;
00084 }
00085 else
00086 {
00087 b = -lInverse;
00088 w = 0.0;
00089 }
00090 }
00091 else
00092 {
00093 if (target(j, 0) > 0.0)
00094 {
00095 b = -lInverse;
00096 w = 0.0;
00097 }
00098 else
00099 {
00100 b = lMinusInverse - lInverse;
00101 w = -lMinusInverse;
00102 }
00103 }
00104 B += 2.0 * b * k;
00105 W += 2.0 * w * k;
00106 }
00107
00108 k = pKernel->eval(input[i], input[i]);
00109 if (target(i, 0) > 0.0)
00110 {
00111 b = lPlusInverse - lInverse;
00112 w = 1.0 - lPlusInverse;
00113 }
00114 else
00115 {
00116 b = lMinusInverse - lInverse;
00117 w = 1.0 - lMinusInverse;
00118 }
00119 B += b * k;
00120 W += w * k;
00121 }
00122
00123 return W / B;
00124 }
00125
00126 double InverseClassSeparability::errorDerivative(Model& model, const Array<double>& input, const Array<double>& target, Array<double>& derivative)
00127 {
00128
00129 KernelFunction* pKernel = dynamic_cast<KernelFunction*>(&model);
00130 if (pKernel == NULL) throw SHARKEXCEPTION("[InverseClassSeparability::errorDerivative] model is not a valid KernelFunction.");
00131
00132 int p, pc = pKernel->getParameterDimension();
00133 derivative.resize(pc, false);
00134 Array<double> der(pc);
00135
00136 int lPlus = 0;
00137 int lMinus = 0;
00138 int i, j, t, T = input.dim(0);
00139 for (t = 0; t < T; t++) if (target(t, 0) > 0.0) lPlus++; else lMinus++;
00140 int l = lPlus + lMinus;
00141 double lPlusInverse = 1.0 / lPlus;
00142 double lMinusInverse = 1.0 / lMinus;
00143 double lInverse = 1.0 / l;
00144
00145 double B = 0.0;
00146 double W = 0.0;
00147 Array<double> gB(pc);
00148 Array<double> gW(pc);
00149 gB = 0.0;
00150 gW = 0.0;
00151 double k;
00152 double b, w;
00153 for (i = 0; i < T; i++)
00154 {
00155 for (j = 0; j < i; j++)
00156 {
00157 k = pKernel->evalDerivative(input[i], input[j], der);
00158 if (target(i, 0) > 0.0)
00159 {
00160 if (target(j, 0) > 0.0)
00161 {
00162 b = lPlusInverse - lInverse;
00163 w = -lPlusInverse;
00164 }
00165 else
00166 {
00167 b = -lInverse;
00168 w = 0.0;
00169 }
00170 }
00171 else
00172 {
00173 if (target(j, 0) > 0.0)
00174 {
00175 b = -lInverse;
00176 w = 0.0;
00177 }
00178 else
00179 {
00180 b = lMinusInverse - lInverse;
00181 w = -lMinusInverse;
00182 }
00183 }
00184 B += 2.0 * b * k;
00185 W += 2.0 * w * k;
00186 for (p = 0; p < pc; p++)
00187 {
00188 gB(p) += 2.0 * b * der(p);
00189 gW(p) += 2.0 * w * der(p);
00190 }
00191 }
00192
00193 k = pKernel->evalDerivative(input[i], input[i], der);
00194 if (target(i, 0) > 0.0)
00195 {
00196 b = lPlusInverse - lInverse;
00197 w = 1.0 - lPlusInverse;
00198 }
00199 else
00200 {
00201 b = lMinusInverse - lInverse;
00202 w = 1.0 - lMinusInverse;
00203 }
00204 B += b * k;
00205 W += w * k;
00206 for (p = 0; p < pc; p++)
00207 {
00208 gB(p) += b * der(p);
00209 gW(p) += w * der(p);
00210 }
00211 }
00212
00213 double ret = W / B;
00214 for (p = 0; p < pc; p++) derivative(p) = (gW(p) - ret * gB(p)) / B;
00215 return ret;
00216 }
00217