00001
00042
00043
00044 #ifndef RBFNET_H
00045 #define RBFNET_H
00046
00047 #include <fstream>
00048 #include <iostream>
00049 #include <sstream>
00050 #include <SharkDefs.h>
00051 #include <Array/ArrayOp.h>
00052 #include <Array/Array.h>
00053 #include <Array/ArrayIo.h>
00054 #include <ReClaM/Model.h>
00055 #include <Mixture/RBFN.h>
00056
00057
00058
00096 class RBFNet : protected RBFN, public Model
00097 {
00098 public:
00099 void read(std::istream& is);
00100 void write(std::ostream& os) const;
00101
00102
00130 RBFNet(unsigned numInput, unsigned numOutput, unsigned numHidden)
00131 : RBFN(numInput, numOutput, numHidden)
00132 {
00133 Model::inputDimension = numInput;
00134 Model::outputDimension = numOutput;
00135 Model::parameter.resize(b.nelem() + A.nelem() + m.nelem() + v.nelem());
00136 getParams(Model::parameter);
00137 }
00138
00139
00140
00165 RBFNet(const std::string &filename)
00166 {
00167 std::ifstream input(filename.c_str());
00168 if (!input)
00169 {
00170 std::stringstream s;
00171 s << "cannot open net file " << filename << std::endl;
00172 throw SHARKEXCEPTION(s.str().c_str());
00173 }
00174 input >> *this;
00175 input.close();
00176 }
00177
00178
00179
00180
00215 RBFNet(unsigned numInput, unsigned numOutput, unsigned numHidden,
00216 const Array<double> &_m,
00217 const Array<double> &_A,
00218 const Array<double> &_b,
00219 const Array<double> &_v)
00220 : RBFN(numInput, numOutput, numHidden)
00221 {
00222 outputDimension = numOutput;
00223 inputDimension = numInput;
00224
00225 m = _m; v = _v; A = _A; b = _b;
00226
00227 Model::parameter.resize(b.nelem() +
00228 A.nelem() +
00229 m.nelem() +
00230 v.nelem());
00231
00232 getParams(Model::parameter);
00233 }
00234
00235
00236
00260 void initRBFNet(const Array< double >& input, const Array< double >& target)
00261 {
00262 RBFN::initialize(input, target);
00263 getParams(Model::parameter);
00264 }
00265
00266
00292 void model(const Array< double >& input, Array< double >& target)
00293 {
00294 setParams(Model::parameter);
00295 recall(input, target);
00296 }
00297
00299 void setParameter(unsigned int index, double value) {
00300 Model::setParameter(index, value);
00301 setParams(Model::parameter);
00302 }
00303
00304
00305
00323 void modelDerivative(const Array<double>& input, Array<double>& derivative)
00324 {
00325 setParams(Model::parameter);
00326 gradientOut(input, derivative);
00327 }
00328
00329
00349 void modelDerivative(const Array<double>& input, Array<double>& output, Array<double>& derivative)
00350 {
00351 setParams(Model::parameter);
00352 recall(input, output);
00353 gradientOut(input, derivative);
00354 }
00355
00356
00378 const Array<double>& getWeights() const
00379 {
00380
00381 return A;
00382 }
00383
00384
00404 const Array<double>& getBias() const
00405 {
00406
00407 return b;
00408 }
00409
00410
00430 const Array<double>& getCenter() const
00431 {
00432
00433 return m;
00434 }
00435
00436
00456 const Array<double>& getVariance() const
00457 {
00458
00459 return v;
00460 }
00461
00462
00478 const unsigned getNHidden() const
00479 {
00480
00481 return v.dim(0);
00482 }
00483
00484
00507 void setWeights(const Array<double>& _A)
00508 {
00509 if ((A.ndim() == _A.ndim()) && (A.dim(0) == _A.dim(0)) && (A.dim(1) == _A.dim(1)))
00510 A = _A;
00511 getParams(Model::parameter);
00512 }
00513
00514
00535 void setBias(const Array<double>& _b)
00536 {
00537 if ((b.ndim() == _b.ndim()) && (b.dim(0) == _b.dim(0)))
00538 b = _b;
00539 getParams(Model::parameter);
00540 }
00541
00542
00564 void setCenter(const Array<double>& _m)
00565 {
00566 if ((m.ndim() == _m.ndim()) && (m.dim(0) == _m.dim(0)) && (m.dim(1) == _m.dim(1)))
00567 m = _m;
00568 getParams(Model::parameter);
00569 }
00570
00571
00593 void setVariance(const Array<double>& _v)
00594 {
00595 if ((v.ndim() == _v.ndim()) && (v.dim(0) == _v.dim(0)) && (v.dim(1) == _v.dim(1)))
00596 v = _v;
00597 getParams(Model::parameter);
00598 }
00599
00600
00601
00602 }
00603 ;
00604
00605
00606
00628 void RBFNet::write(std::ostream& os) const
00629 {
00630 Array<double> tmp;
00631 tmp = getCenter();
00632 os << tmp.dim(1) << " " ;
00633 tmp = getBias();
00634 os << tmp.dim(0) << " " ;
00635 os << getNHidden() << "\n\n";
00636 writeArray(getWeights(), os, "", "\n", ' ');
00637 writeArray(getBias(), os, "", "\n", ' ');
00638 os << "\n";
00639 writeArray(getCenter(), os, "", "\n", ' ');
00640 writeArray(getVariance(), os, "", "\n", ' ');
00641 }
00642
00643
00644
00668 void RBFNet::read(std::istream& is)
00669 {
00670 unsigned i, j;
00671 unsigned NumInput, NumOutput, NumHidden;
00672 is >> NumInput >> NumOutput >> NumHidden;
00673
00674 a.resize(NumHidden);
00675 for (i = 0; i < a.dim(0); i++)
00676 a(i) = 1.0 / NumHidden;
00677
00678 Array<double> Weights(NumOutput, NumHidden);
00679 Array<double> Bias(NumOutput);
00680 Array<double> Center(NumHidden, NumInput);
00681 Array<double> Variance(NumHidden, NumInput);
00682
00683
00684 for (i = 0; i < Weights.dim(0); i++)
00685 for (j = 0; j < Weights.dim(1); j++)
00686 is >> Weights(i, j);
00687
00688 for (i = 0; i < Bias.dim(0); i++)
00689 is >> Bias(i);
00690
00691 for (i = 0; i < Center.dim(0); i++)
00692 for (j = 0; j < Center.dim(1); j++)
00693 is >> Center(i, j);
00694
00695 for (i = 0; i < Variance.dim(0); i++)
00696 for (j = 0; j < Variance.dim(1); j++)
00697 is >> Variance(i, j);
00698
00699
00700
00701
00702 A.resize(NumOutput, NumHidden);
00703 b.resize(NumOutput);
00704 m.resize(NumHidden, NumInput);
00705 v.resize(NumHidden, NumInput);
00706
00707
00708 setWeights(Weights);
00709 setBias(Bias);
00710 setCenter(Center);
00711 setVariance(Variance);
00712
00713 Model::inputDimension = NumInput;
00714 Model::outputDimension = NumOutput;
00715 Model::parameter.resize(b.nelem() + A.nelem() + m.nelem() + v.nelem());
00716
00717 getParams(Model::parameter);
00718 }
00719
00720 #endif
00721