00001
00044
00045
00046
00047 #ifndef SQUARED_ERROR_H
00048 #define SQUARED_ERROR_H
00049
00050
00051
00052
#include <cassert>

#include <ReClaM/ErrorFunction.h>
00054
00055
00056
00072 class SquaredError : public ErrorFunction
00073 {
00074 public:
00075
00076
00102 double error(Model& model, const Array<double>& in, const Array<double>& out)
00103 {
00104 double se = 0;
00105 if (in.ndim() == 1)
00106 {
00107 Array<double> output(out.dim(0));
00108 model.model(in, output);
00109 for (unsigned c = 0; c < out.dim(0); c++)
00110 {
00111 se += (out(c) - output(c)) * (out(c) - output(c));
00112 }
00113 }
00114 else
00115 {
00116 Array<double> output(out.dim(1));
00117 for (unsigned pattern = 0; pattern < in.dim(0); ++pattern)
00118 {
00119 model.model(in[pattern], output);
00120 for (unsigned c = 0; c < out.dim(1); c++)
00121 {
00122 se += (out(pattern, c) - output(c)) * (out(pattern, c) - output(c));
00123 }
00124 }
00125 }
00126 return se;
00127 }
00128
00129
00166 double errorDerivative(Model& model, const Array<double>& in, const Array<double>& out, Array<double>& derivative)
00167 {
00168 double se = 0;
00169 Array<double> dmdw;
00170 derivative.resize(model.getParameterDimension(), false);
00171 derivative = 0;
00172
00173 if (in.ndim() == 1)
00174 {
00175 Array<double> output(out.dim(0));
00176 model.modelDerivative(in, output, dmdw);
00177 for (unsigned c = 0; c < output.nelem(); c++)
00178 {
00179 se += (out(c) - output(c)) * (out(c) - output(c));
00180 for (unsigned i = 0; i < derivative.nelem(); i++)
00181 derivative(i) -= (out(c) - output(c)) * dmdw(c, i);
00182 }
00183 for (unsigned i = 0; i < derivative.nelem(); i++) derivative(i) *= 2;
00184 }
00185 else
00186 {
00187 Array<double> output(out.dim(1));
00188 for (unsigned pattern = 0; pattern < in.dim(0); ++pattern)
00189 {
00190 model.modelDerivative(in[pattern], output, dmdw);
00191
00192 for (unsigned c = 0; c < output.nelem(); c++)
00193 {
00194 se += (out(pattern, c) - output(c)) * (out(pattern, c) - output(c));
00195 for (unsigned i = 0; i < derivative.nelem(); i++)
00196 derivative(i) -= (out(pattern, c) - output(c)) * dmdw(c, i);
00197 }
00198 }
00199 for (unsigned i = 0; i < derivative.nelem(); i++) derivative(i) *= 2;
00200 }
00201 return se;
00202 }
00203 };
00204
00205
00206 #endif
00207