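// LinearClassifierTest.cpp: trains a linear classifier with LDA on a synthetic
// three-class problem and checks the test-set error rate.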
#include <cstdlib>
#include <iostream>

#include <Rng/GlobalRng.h>
#include <ReClaM/Dataset.h>
#include <ReClaM/LinearModel.h>
#include <ReClaM/LDA.h>
#include <ReClaM/WTA.h>
//! Synthetic two-dimensional classification problem with three classes.
//!
//! Every sample gets a class label, a pair of correlated Gaussian
//! coordinates, and a class-dependent shift of its cluster center.
//! Targets are one-hot rows: all zeros except a 1.0 in the label's column.
class MultiClassProblem : public DataSource
{
public:
	//! Declares 2 input dimensions and 3 target (class) dimensions.
	MultiClassProblem()
	{
		dataDim = 2;
		targetDim = 3;
	}

	//! Samples `count` points into `data` (count x 2) and their one-hot
	//! labels into `target` (count x 3). Always reports success.
	bool GetData(Array<double>& data, Array<double>& target, int count)
	{
		data.resize(count, 2, false);
		target.resize(count, 3, false);
		target = 0.0;

		for (int row = 0; row < count; row++)
		{
			// Label drawn with Rng::discrete(0, 2); assumed inclusive, i.e. 0, 1 or 2.
			const int label = Rng::discrete(0, 2);

			// The second coordinate is drawn first. The first coordinate is then
			// made negatively correlated with it. Keep this order: it fixes the
			// sequence of random draws.
			const double second = 0.7 * Rng::gauss();
			data(row, 1) = second;
			data(row, 0) = 0.7 * Rng::gauss() - second;

			// Move the point towards its class-specific cluster center.
			switch (label)
			{
				case 0: data(row, 0) -= 1.0; break;
				case 1: data(row, 1) += 2.0; break;
				case 2: data(row, 0) += 1.0; break;
			}

			target(row, label) = 1.0;
		}
		return true;
	}
};
int main(int argc, char** argv)
{
MultiClassProblem problem;
Dataset dataset(problem, 100, 1000);
LinearClassifier model(2, 3);
LDA optimizer;
optimizer.init(model);
std::cout << "LDA training ..." << std::flush;
optimizer.optimize(model, dataset.getTrainingData(), dataset.getTrainingTarget());
std::cout << " done." << std::endl;
WTA err;
double e = err.error(model, dataset.getTestData(), dataset.getTestTarget());
std::cout << "fraction of errors: " << e << std::endl;
if (e <= 0.081) exit(EXIT_SUCCESS);
else exit(EXIT_FAILURE);
}