// LinearClassifierTest.cpp

//===========================================================================
#include <cstdlib>
#include <iostream>

#include <Rng/GlobalRng.h>
#include <ReClaM/Dataset.h>
#include <ReClaM/LinearModel.h>
#include <ReClaM/LDA.h>
#include <ReClaM/WTA.h>


// Simple data-generating distribution: three classes of correlated Gaussian
// noise, each shifted to its own center in the 2D input space
class MultiClassProblem : public DataSource
{
public:
    // Declares the problem size: 2 input components and 3 target components
    // (one target component per class).
    MultiClassProblem()
    {
        dataDim = 2;
        targetDim = 3;
    }

    // Samples `count` labelled examples.
    //
    //   data   - resized to (count x 2) and filled with the sampled inputs
    //   target - resized to (count x 3); each row is zero except for a 1.0
    //            in the column of the sampled class
    //   count  - number of examples to generate
    //
    // Always returns true.
    bool GetData(Array<double>& data, Array<double>& target, int count)
    {
        data.resize(count, 2, false);
        target.resize(count, 3, false);
        target = 0.0;

        for (int row = 0; row < count; ++row)
        {
            // Class index in {0, 1, 2}
            // (presumably drawn uniformly -- confirm against Rng::discrete).
            const int label = Rng::discrete(0, 2);

            // Noise: the second component is drawn first, and the first
            // component is negatively coupled to it.
            data(row, 1) = 0.7 * Rng::gauss();
            data(row, 0) = 0.7 * Rng::gauss() - 1.0 * data(row, 1);

            // Shift the noise to the center belonging to the sampled class:
            // class 0 -> (-1, 0), class 1 -> (0, +2), class 2 -> (+1, 0).
            switch (label)
            {
            case 0:
                data(row, 0) -= 1.0;
                break;
            case 1:
                data(row, 1) += 2.0;
                break;
            case 2:
                data(row, 0) += 1.0;
                break;
            default:
                break;
            }

            // One-hot encoding of the class label.
            target(row, label) = 1.0;
        }

        return true;
    }
};


// Trains an LDA classifier on a synthetic three-class problem, prints the
// fraction of errors measured on the test set, and reports via the process
// exit status whether that fraction stayed within the self-test limit.
//
// The command line arguments are not used.
//
// Returns EXIT_SUCCESS if the test error is at most 0.081, EXIT_FAILURE
// otherwise.
int main(int argc, char** argv)
{
    // generate multi class dataset with 100 training and 1000 test examples
    MultiClassProblem problem;
    Dataset dataset(problem, 100, 1000);

    // construct model and optimizer for LDA with 2 dimensions and 3 classes
    LinearClassifier model(2, 3);
    LDA optimizer;
    optimizer.init(model);

    // train the model
    std::cout << "LDA training ..." << std::flush;
    optimizer.optimize(model, dataset.getTrainingData(), dataset.getTrainingTarget());
    std::cout << " done." << std::endl;

    // count the errors on the test set
    WTA err;
    double e = err.error(model, dataset.getTestData(), dataset.getTestTarget());
    std::cout << "fraction of errors: " << e << std::endl;

    // lines below are for self-testing this example, please ignore
    //
    // Return from main instead of calling exit(): exit() terminates the
    // program without destroying the automatic objects above, whereas a
    // return yields the same exit status and performs normal stack cleanup.
    const double maxAcceptedError = 0.081;
    return (e <= maxAcceptedError) ? EXIT_SUCCESS : EXIT_FAILURE;
}