simpleMSERNNet.cpp
#include <cstdlib>
#include <fstream>
#include <iostream>

#include <ReClaM/MSERNNet.h>
#include <ReClaM/Rprop.h>
#include <ReClaM/createConnectionMatrix.h>
#include <Array/ArrayIo.h>
using namespace std;
class myNet: public MSERNNet, public RpropPlus
{
public:
myNet(): MSERNNet()
{
Array <int> con;
createConnectionMatrixRNN(con,1,8,1,2);
setStructure(con);
initWeights(-0.5, 0.5);
init(*this);
};
};
int main(int argc, char *argv[])
{
unsigned iterations = 1000;
const char* datafile = "timeseries";
unsigned episode = 1000;
int forecast = 5;
bool checkGradient = false;
unsigned long t = 0, i;
ifstream infile;
ofstream outfile;
Array<double> data;
infile.open(datafile);
readArray(data, infile);
infile.close();
data /= 100.; data += .5;
unsigned dims[2] = {episode, 1};
ArrayReference<double> trainIn(dims, &data.elem(0) , 2, episode);
ArrayReference<double> trainTarget(dims, &data.elem(0) + forecast, 2, episode);
ArrayReference<double> evalIn(dims, &data.elem(1000) , 2, episode);
ArrayReference<double> evalTarget(dims, &data.elem(1000) + forecast, 2, episode);
Array<double> trainOut(episode, 1), evalOut(episode, 1);
outfile.open("input");
writeArray(evalIn, outfile);
outfile.close();
outfile.open("target");
writeArray(evalTarget, outfile);
outfile.close();
myNet net;
Array<double> exact, estimated;
double z;
while (t++ < iterations)
{
cout << t << "\t";
net.includeWarmUp(100);
if (checkGradient)
{
net.errorDerivative(net, trainIn, trainTarget, exact);
cout << "derror - gradient:" << exact;
net.ErrorFunction::errorDerivative(net, trainIn, trainTarget, estimated);
cout << "deltagrad - gradient:" << estimated();
for (z = 0, i = 0;i < exact.nelem();i++)
{
exact(i) = exact(i) - estimated(i); z += exact(i) * exact(i);
}
cout << "difference:" << exact;
cout << "square-sum of difference: " << z << "\n";
}
net.optimize(net, net, trainIn, trainTarget);
cout << net.error(net, evalIn, evalTarget) << endl;
if (!(t % 100))
{
net.model(evalIn, evalOut);
outfile.open("output");
writeArray(evalOut, outfile);
outfile.close();
}
}
if (net.error(net, evalIn, evalTarget) < 0.00135) exit(EXIT_SUCCESS);
else exit(EXIT_FAILURE);
}