Shark Machine Learning Library
  • About Shark
  • SourceForge
    • Project Summary
    • Downloads
    • Subversion Repository
  • Getting Started
  • Tutorials
  • FAQ
  • Main Modules
    • ReClaM
    • EALib
    • MOO-EALib
    • Fuzzy
  • Tools
    • Mixture
    • Array
    • Rng
    • LinAlg
    • FileUtil
  • Main Page
  • Related Pages
  • Classes

RBFNet.h

Go to the documentation of this file.
00001 //===========================================================================
00042 //===========================================================================
00043 
00044 #ifndef RBFNET_H
00045 #define RBFNET_H
00046 
00047 #include <fstream>
00048 #include <iostream>
00049 #include <sstream>
00050 #include <SharkDefs.h>
00051 #include <Array/ArrayOp.h>
00052 #include <Array/Array.h>
00053 #include <Array/ArrayIo.h>
00054 #include <ReClaM/Model.h>
00055 #include <Mixture/RBFN.h>
00056 
00057 
00058 //===========================================================================
00096 class RBFNet : protected RBFN, public Model
00097 {
00098 public:
00099     void read(std::istream& is);
00100     void write(std::ostream& os) const;
00101 
00102 //===========================================================================
00130     RBFNet(unsigned numInput, unsigned numOutput, unsigned numHidden)
00131             : RBFN(numInput, numOutput, numHidden)
00132     {
00133         Model::inputDimension = numInput;
00134         Model::outputDimension = numOutput;
00135         Model::parameter.resize(b.nelem() + A.nelem() + m.nelem() + v.nelem());
00136         getParams(Model::parameter);
00137     }
00138 
00139 
00140 //===========================================================================
00165     RBFNet(const std::string &filename)
00166     {
00167         std::ifstream input(filename.c_str());
00168         if (!input)
00169         {
00170             std::stringstream s;
00171             s << "cannot open net file " << filename << std::endl;
00172             throw SHARKEXCEPTION(s.str().c_str());
00173         }
00174         input >> *this;
00175         input.close();
00176     }
00177 
00178 
00179 
00180 //===========================================================================
00215     RBFNet(unsigned numInput, unsigned numOutput, unsigned numHidden,
00216            const Array<double> &_m,
00217            const Array<double> &_A,
00218            const Array<double> &_b,
00219            const Array<double> &_v)
00220             : RBFN(numInput, numOutput, numHidden)
00221     {
00222         outputDimension = numOutput;
00223         inputDimension  = numInput;
00224 
00225         m = _m; v = _v; A = _A; b = _b;
00226 
00227         Model::parameter.resize(b.nelem() +    // bias
00228                                 A.nelem() +   // weights
00229                                 m.nelem() +   // centers
00230                                 v.nelem());   // variances
00231 
00232         getParams(Model::parameter);
00233     }
00234 
00235 
00236 //===========================================================================
00260     void initRBFNet(const Array< double >& input, const Array< double >& target)
00261     {
00262         RBFN::initialize(input, target);
00263         getParams(Model::parameter);
00264     }
00265 
00266 //===========================================================================
00292     void model(const Array< double >& input, Array< double >& target)
00293     {
00294         setParams(Model::parameter);
00295         recall(input, target);
00296     }
00297 
00299     void setParameter(unsigned int index, double value) {
00300         Model::setParameter(index,  value);
00301         setParams(Model::parameter);
00302     }
00303 
00304 
00305 //===========================================================================
00323     void modelDerivative(const Array<double>& input, Array<double>& derivative)
00324     {
00325         setParams(Model::parameter);
00326         gradientOut(input, derivative);
00327     }
00328 
00329 //===========================================================================
00349     void modelDerivative(const Array<double>& input, Array<double>& output, Array<double>& derivative)
00350     {
00351         setParams(Model::parameter);
00352         recall(input, output);
00353         gradientOut(input, derivative);
00354     }
00355 
00356 //===========================================================================
00378     const Array<double>& getWeights() const
00379     {
00380         //setParams(Model::parameter);
00381         return A;
00382     }
00383 
00384 //===========================================================================
00404     const Array<double>& getBias() const
00405     {
00406         //setParams(Model::parameter);
00407         return b;
00408     }
00409 
00410 //===========================================================================
00430     const Array<double>& getCenter() const
00431     {
00432         //setParams(Model::parameter);
00433         return m;
00434     }
00435 
00436 //===========================================================================
00456     const Array<double>& getVariance() const
00457     {
00458         //setParams(Model::parameter);
00459         return v;
00460     }
00461 
00462 //===========================================================================
00478     const unsigned getNHidden() const
00479     {
00480         //setParams(Model::parameter);
00481         return v.dim(0);
00482     }
00483 
00484 //===========================================================================
00507     void setWeights(const Array<double>& _A)
00508     {
00509         if ((A.ndim() == _A.ndim()) && (A.dim(0) == _A.dim(0)) && (A.dim(1) == _A.dim(1)))
00510             A = _A;
00511         getParams(Model::parameter);
00512     }
00513 
00514 //===========================================================================
00535     void setBias(const Array<double>& _b)
00536     {
00537         if ((b.ndim() == _b.ndim()) && (b.dim(0) == _b.dim(0)))
00538             b = _b;
00539         getParams(Model::parameter);
00540     }
00541 
00542 //===========================================================================
00564     void setCenter(const Array<double>& _m)
00565     {
00566         if ((m.ndim() == _m.ndim()) && (m.dim(0) == _m.dim(0)) && (m.dim(1) == _m.dim(1)))
00567             m = _m;
00568         getParams(Model::parameter);
00569     }
00570 
00571 //===========================================================================
00593     void setVariance(const Array<double>& _v)
00594     {
00595         if ((v.ndim() == _v.ndim()) && (v.dim(0) == _v.dim(0)) && (v.dim(1) == _v.dim(1)))
00596             v = _v;
00597         getParams(Model::parameter);
00598     }
00599 
00600 
00601 
00602 }
00603 ; // End of class RBFNet
00604 
00605 
00606 //===========================================================================
00628 void RBFNet::write(std::ostream& os) const
00629 {
00630     Array<double> tmp;
00631     tmp = getCenter();
00632     os << tmp.dim(1) << " " ;
00633     tmp = getBias();
00634     os << tmp.dim(0) << " " ;
00635     os << getNHidden() << "\n\n";
00636     writeArray(getWeights(), os, "", "\n", ' ');
00637     writeArray(getBias(), os, "", "\n", ' ');
00638     os << "\n";
00639     writeArray(getCenter(), os, "", "\n", ' ');
00640     writeArray(getVariance(), os, "", "\n", ' ');
00641 }
00642 
00643 
00644 //===========================================================================
00668 void RBFNet::read(std::istream& is)
00669 {
00670     unsigned i, j;
00671     unsigned NumInput, NumOutput, NumHidden;
00672     is >> NumInput >> NumOutput >> NumHidden;
00673 
00674     a.resize(NumHidden);
00675     for (i = 0; i < a.dim(0); i++)
00676         a(i) = 1.0 / NumHidden;
00677 
00678     Array<double> Weights(NumOutput, NumHidden);
00679     Array<double> Bias(NumOutput);
00680     Array<double> Center(NumHidden, NumInput);
00681     Array<double> Variance(NumHidden, NumInput);
00682 
00683 
00684     for (i = 0; i < Weights.dim(0); i++)
00685         for (j = 0; j < Weights.dim(1); j++)
00686             is >> Weights(i, j);
00687 
00688     for (i = 0; i < Bias.dim(0); i++)
00689         is >> Bias(i);
00690 
00691     for (i = 0; i < Center.dim(0); i++)
00692         for (j = 0; j < Center.dim(1); j++)
00693             is >> Center(i, j);
00694 
00695     for (i = 0; i < Variance.dim(0); i++)
00696         for (j = 0; j < Variance.dim(1); j++)
00697             is >> Variance(i, j);
00698 
00699 
00700 
00701 
00702     A.resize(NumOutput, NumHidden);
00703     b.resize(NumOutput);
00704     m.resize(NumHidden, NumInput);
00705     v.resize(NumHidden, NumInput);
00706 
00707 
00708     setWeights(Weights);
00709     setBias(Bias);
00710     setCenter(Center);
00711     setVariance(Variance);
00712 
00713     Model::inputDimension = NumInput;
00714     Model::outputDimension = NumOutput;
00715     Model::parameter.resize(b.nelem() + A.nelem() + m.nelem() + v.nelem());
00716 
00717     getParams(Model::parameter);
00718 }
00719 
00720 #endif
00721