Generating cross-validation folds (Java approach)


This article describes how to generate train/test splits for cross-validation using the Weka API directly.

The following variables are given:

Instances data = ...;   // contains the full dataset we want to create train/test sets from
int seed = ...;         // the seed for randomizing the data
int folds = ...;        // the number of folds to generate, >=2

 Randomize the data

First, randomize your data:

Random rand = new Random(seed);             // create seeded number generator
Instances randData = new Instances(data);   // create copy of original data
randData.randomize(rand);                   // randomize data with number generator
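Seeding the generator is what makes the split reproducible: the same seed always yields the same order. The sketch below illustrates this with a plain Fisher-Yates shuffle over row indices, using only java.util.Random and no Weka classes. The class name `SeededShuffle` is hypothetical. To my understanding `Instances.randomize(Random)` performs a comparable in-place shuffle, but this is not Weka's source code.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical helper: shuffles row indices with a seeded generator. */
final class SeededShuffle {

    /** Returns a permutation of 0..n-1 produced by a Fisher-Yates shuffle. */
    static int[] shuffledIndices(int n, long seed) {
        int[] idx = new int[n];
        for (int i = 0; i < n; i++) {
            idx[i] = i;
        }
        Random rand = new Random(seed);
        // walk backwards, swapping each position with a random position at or before it
        for (int j = n - 1; j > 0; j--) {
            int k = rand.nextInt(j + 1);
            int tmp = idx[j];
            idx[j] = idx[k];
            idx[k] = tmp;
        }
        return idx;
    }

    public static void main(String[] args) {
        // both lines print the same order because the seed is the same
        System.out.println(Arrays.toString(shuffledIndices(10, 42)));
        System.out.println(Arrays.toString(shuffledIndices(10, 42)));
    }
}
```

Changing the seed changes the order, which is exactly how the multiple-runs loop further down obtains a different but repeatable split per run.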

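If your data has a nominal class and you want to perform stratified cross-validation, also stratify the randomized data:

if (randData.classAttribute().isNominal())
  randData.stratify(folds);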
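Stratification matters because the test sets are later cut from the data as contiguous blocks; without it, a fold can by chance receive few or no instances of a rare class. The sketch below illustrates the idea on plain class labels: rows are grouped by label (a stable sort keeps the shuffled order within each class) and then dealt to the folds round-robin. It is a simplified stand-in written for illustration, not Weka's `stratify` implementation, and `StratifySketch` and `assignFolds` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical illustration of stratified fold assignment on class labels. */
final class StratifySketch {

    /**
     * Returns foldOf[row] in 0..folds-1. Rows are grouped by label and then
     * dealt to the folds round-robin, so each fold gets roughly the same
     * class proportions as the whole dataset.
     */
    static int[] assignFolds(String[] labels, int folds) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < labels.length; i++) {
            order.add(i);
        }
        // List.sort is stable, so rows of the same class keep their relative order
        order.sort((a, b) -> labels[a].compareTo(labels[b]));
        int[] foldOf = new int[labels.length];
        for (int pos = 0; pos < order.size(); pos++) {
            foldOf[order.get(pos)] = pos % folds;
        }
        return foldOf;
    }

    public static void main(String[] args) {
        // 6 rows of class "a", 4 rows of class "b"
        String[] labels = {"a", "b", "a", "a", "b", "a", "b", "a", "a", "b"};
        System.out.println(Arrays.toString(assignFolds(labels, 2))); // [0, 0, 1, 0, 1, 1, 0, 0, 1, 1]
    }
}
```

In the example each of the two folds ends up with three "a" rows and two "b" rows, mirroring the 60/40 split of the full data.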


 Generate the folds

 Single run

Next, create the train and test sets for each fold:

for (int n = 0; n < folds; n++) {
  Instances train = randData.trainCV(folds, n);
  Instances test = randData.testCV(folds, n);

  // further processing, classification, etc.
}
  • the above code is used by the weka.filters.supervised.instance.StratifiedRemoveFolds filter
  • the weka.classifiers.Evaluation class and the Explorer/Experimenter use this method instead to obtain the train set:

Instances train = randData.trainCV(folds, n, rand);
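`testCV(folds, n)` takes a contiguous block of the (randomized, possibly stratified) data, and `trainCV(folds, n)` returns all remaining instances. As I read the Weka source, when the number of instances is not divisible by `folds`, the first `numInstances % folds` folds receive one extra instance, and the three-argument `trainCV(folds, n, rand)` additionally shuffles the training set with the given generator. The sketch below reproduces that index arithmetic on plain row indices so you can check which rows land in which fold; `FoldRanges`, `testRange` and `trainIndices` are hypothetical names, not Weka API.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper mirroring the contiguous-block split of testCV/trainCV. */
final class FoldRanges {

    /** Returns {first, size} of the test block for fold n (0-based). */
    static int[] testRange(int numInstances, int folds, int n) {
        if (folds < 2 || folds > numInstances) {
            throw new IllegalArgumentException("need 2 <= folds <= numInstances");
        }
        int base = numInstances / folds;
        int remainder = numInstances % folds;
        int size = base;
        int offset;
        if (n < remainder) {
            size++;              // the first `remainder` folds get one extra row
            offset = n;
        } else {
            offset = remainder;
        }
        return new int[] {n * base + offset, size};
    }

    /** Training rows are all rows outside the test block; shuffled if rand != null. */
    static List<Integer> trainIndices(int numInstances, int folds, int n, Random rand) {
        int[] r = testRange(numInstances, folds, n);
        List<Integer> train = new ArrayList<>();
        for (int i = 0; i < numInstances; i++) {
            if (i < r[0] || i >= r[0] + r[1]) {
                train.add(i);
            }
        }
        if (rand != null) {
            Collections.shuffle(train, rand); // analogous to trainCV(folds, n, rand)
        }
        return train;
    }

    public static void main(String[] args) {
        for (int n = 0; n < 3; n++) {
            int[] r = testRange(10, 3, n);
            // prints rows 0..3, 4..6 and 7..9
            System.out.println("fold " + n + ": test rows " + r[0] + ".." + (r[0] + r[1] - 1));
        }
    }
}
```

Because the test block is contiguous, it is the earlier randomization (and stratification) that makes the folds random and class-balanced.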

 Multiple runs

The example above performs only a single run of cross-validation. To perform, for example, 10 runs of 10-fold cross-validation, wrap both steps in a loop:

Instances data = ...;  // our dataset again, obtained from somewhere
int runs = 10;
for (int i = 0; i < runs; i++) {
  seed = i+1;  // every run gets a new, but defined seed value

  // see: randomize the data

  // see: generate the folds
}
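Putting the pieces together, the sketch below runs several repetitions of k-fold cross-validation on row indices only: each run reseeds with `i + 1`, shuffles, and cuts contiguous test blocks, so within every run each row is used as test data exactly once. It is self-contained plain Java for illustration (`RepeatedCvSketch` is a hypothetical name); with Weka you would call `randomize`, `stratify`, `trainCV` and `testCV` on `Instances` instead.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical sketch: repeated k-fold CV on row indices with seeds i + 1. */
final class RepeatedCvSketch {

    /** Returns, for each run, its test folds (each fold a list of row indices). */
    static List<List<List<Integer>>> testFolds(int numInstances, int folds, int runs) {
        List<List<List<Integer>>> allRuns = new ArrayList<>();
        for (int i = 0; i < runs; i++) {
            int seed = i + 1;                        // new, but defined seed per run
            List<Integer> rows = new ArrayList<>();
            for (int r = 0; r < numInstances; r++) {
                rows.add(r);
            }
            Collections.shuffle(rows, new Random(seed)); // "randomize the data"
            List<List<Integer>> runFolds = new ArrayList<>();
            int base = numInstances / folds;
            int remainder = numInstances % folds;
            int first = 0;
            for (int n = 0; n < folds; n++) {            // "generate the folds"
                int size = base + (n < remainder ? 1 : 0);
                runFolds.add(new ArrayList<>(rows.subList(first, first + size)));
                first += size;
            }
            allRuns.add(runFolds);
        }
        return allRuns;
    }

    public static void main(String[] args) {
        List<List<List<Integer>>> result = testFolds(10, 3, 2);
        for (int i = 0; i < result.size(); i++) {
            System.out.println("run " + (i + 1) + ": " + result.get(i));
        }
    }
}
```

Fixing the seeds to `1..runs` keeps every run reproducible while still giving each run a different partition.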





The following complete program applies this procedure to J48 on a red/white wine-quality dataset (one run of 10-fold cross-validation):

package assignment2;

import java.io.FileReader;
import java.util.Random;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

public class cv_rw {

  // loads an ARFF file and removes its last attribute
  public static Instances getFileInstances(String filename) throws Exception {
    Instances data;
    try (FileReader frData = new FileReader(filename)) {
      data = new Instances(frData);
    }
    int length = data.numAttributes();
    String[] options = new String[2];
    options[0] = "-R";
    options[1] = Integer.toString(length);
    Remove remove = new Remove();
    remove.setOptions(options);
    remove.setInputFormat(data);
    Instances newData = Filter.useFilter(data, remove);
    return newData;
  }

  public static void main(String[] args) throws Exception {
    // load data and set class index
    Instances data = getFileInstances("D://Weka_tutorial//WineQuality//RedWhiteWine.arff");
    // System.out.println(data);
    data.setClassIndex(data.numAttributes() - 1);

    // classifier
    // String[] tmpOptions;
    // String classname;
    // tmpOptions     = Utils.splitOptions(Utils.getOption("W", args));
    // classname      = tmpOptions[0];
    // tmpOptions[0]  = "";
    // Classifier cls = (Classifier) Utils.forName(Classifier.class, classname, tmpOptions);
    //
    // // other options
    // int runs  = Integer.parseInt(Utils.getOption("r", args));  // number of repeated runs
    // int folds = Integer.parseInt(Utils.getOption("x", args));
    int runs = 1;
    int folds = 10;
    J48 j48 = new J48();

    // perform cross-validation
    for (int i = 0; i < runs; i++) {
      // randomize data
      int seed = i + 1;
      Random rand = new Random(seed);
      Instances randData = new Instances(data);
      randData.randomize(rand);
      // stratification for a nominal class: keeps the class distribution
      // roughly the same in every fold
      // if (randData.classAttribute().isNominal())
      //   randData.stratify(folds);

      Evaluation eval = new Evaluation(randData);
      for (int n = 0; n < folds; n++) {
        Instances train = randData.trainCV(folds, n);
        Instances test = randData.testCV(folds, n);
        // the above code is used by the StratifiedRemoveFolds filter, the
        // code below by the Explorer/Experimenter:
        // Instances train = randData.trainCV(folds, n, rand);

        // build and evaluate classifier
        // (Weka 3.6 and earlier used Classifier.makeCopy(j48) instead)
        Classifier j48Copy = AbstractClassifier.makeCopy(j48);
        j48Copy.buildClassifier(train);
        eval.evaluateModel(j48Copy, test);
      }

      // output evaluation
      System.out.println();
      System.out.println("=== Setup run " + (i + 1) + " ===");
      System.out.println("Classifier: " + j48.getClass().getName());
      System.out.println("Dataset: " + data.relationName());
      System.out.println("Folds: " + folds);
      System.out.println("Seed: " + seed);
      System.out.println();
      System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation run " + (i + 1) + "===", false));
    }
  }
}


Sample output:

=== Setup run 1 ===
Classifier: weka.classifiers.trees.J48
Dataset: RedWhiteWine-weka.filters.unsupervised.instance.Randomize-S42-weka.filters.unsupervised.instance.Randomize-S42-weka.filters.unsupervised.attribute.Remove-R13
Folds: 10
Seed: 1

=== 10-fold Cross-validation run 1===
Correctly Classified Instances        6415               98.7379 %
Incorrectly Classified Instances        82                1.2621 %
Kappa statistic                          0.9658
Mean absolute error                      0.0159
Root mean squared error                  0.1109
Relative absolute error                  4.2898 %
Root relative squared error             25.7448 %
Total Number of Instances             6497
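The two percentage lines in the summary follow directly from the counts: 6415 of 6497 instances classified correctly and 82 incorrectly. A quick check in Java (the class name `AccuracyCheck` is just for illustration):

```java
import java.util.Locale;

/** Recomputes the percentage lines of the summary above from the raw counts. */
final class AccuracyCheck {

    /** Formats count/total as a percentage with four decimals, e.g. "98.7379 %". */
    static String percent(int count, int total) {
        return String.format(Locale.ROOT, "%.4f %%", 100.0 * count / total);
    }

    public static void main(String[] args) {
        int correct = 6415;
        int incorrect = 82;
        int total = correct + incorrect;   // 6497, matching "Total Number of Instances"
        System.out.println("Correct:   " + percent(correct, total));    // Correct:   98.7379 %
        System.out.println("Incorrect: " + percent(incorrect, total));  // Incorrect: 1.2621 %
    }
}
```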



