
Training and Testing Sets in Java | Machine Learning

In this article, we will look at how to use cross-validation folds to split a given data set into training and testing sets in Java.
Submitted by Raunak Goswami, on September 11, 2018


If you haven't read my previous articles yet, note that for machine learning in Java I am using the weka.jar file to import the required machine learning classes into my Eclipse IDE. I also suggest you have a look at my article on data splitting using the Python programming language.

Let’s have a look at the basic definition of training and test sets before we proceed further.

Training Set

As the name suggests, the training set is used to train our model. We feed in the attributes together with their corresponding target values, and from these values the model identifies a pattern that it later uses to predict the test set values.
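As a rough illustration of what "identifying a pattern" can mean, the sketch below fits a straight line to a handful of (attribute, target) pairs using ordinary least squares. The class name LineFitSketch, the method fit, and the numbers are made up for illustration; they are not part of Weka or of the data set used later in this article.

```java
// Hypothetical sketch: "training" a one-attribute linear model by least squares.
class LineFitSketch {

    // Returns {slope, intercept} of the least-squares line through (x[i], y[i]).
    static double[] fit(double[] x, double[] y) {
        int n = x.length;
        double meanX = 0, meanY = 0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;

        double sxy = 0, sxx = 0;
        for (int i = 0; i < n; i++) {
            sxy += (x[i] - meanX) * (y[i] - meanY);
            sxx += (x[i] - meanX) * (x[i] - meanX);
        }
        double slope = sxy / sxx;
        return new double[] { slope, meanY - slope * meanX };
    }

    public static void main(String[] args) {
        double[] x = { 1, 2, 3, 4 };   // made-up attribute values
        double[] y = { 3, 5, 7, 9 };   // made-up targets lying on y = 2x + 1
        double[] model = fit(x, y);
        System.out.println("slope = " + model[0] + ", intercept = " + model[1]);
    }
}
```

The learned slope and intercept are the "pattern"; predicting a test value is then just slope * x + intercept.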

Test Set

This set is used to check the accuracy of our model; as the name suggests, we use it to test our results. The model is given only the independent attributes of the test set and predicts the dependent (target) value for each instance. We then compare the predicted target values with the known target values in the test set to compute evaluation measures such as RMSE, percentage accuracy, percentage error, and area under the curve. These measures tell us how well the model predicts the dependent values, and therefore how useful it is.
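Two of the measures mentioned above are easy to compute by hand. The following is a minimal sketch, assuming the predictions and the known test targets are already available as arrays; the class MetricsSketch and its method names are hypothetical, and the values in main are made up.

```java
// Hypothetical sketch: comparing predictions with the known test targets.
class MetricsSketch {

    // Root mean squared error for numeric targets.
    static double rmse(double[] actual, double[] predicted) {
        double sumSquared = 0;
        for (int i = 0; i < actual.length; i++) {
            double diff = actual[i] - predicted[i];
            sumSquared += diff * diff;
        }
        return Math.sqrt(sumSquared / actual.length);
    }

    // Percentage of nominal (class label) predictions that match the actual label.
    static double percentAccuracy(String[] actual, String[] predicted) {
        int correct = 0;
        for (int i = 0; i < actual.length; i++) {
            if (actual[i].equals(predicted[i])) {
                correct++;
            }
        }
        return 100.0 * correct / actual.length;
    }

    public static void main(String[] args) {
        // made-up values for illustration only
        System.out.println("RMSE = "
                + rmse(new double[] { 2, 4 }, new double[] { 5, 1 }));
        System.out.println("Accuracy % = "
                + percentAccuracy(new String[] { "a", "b", "a", "b" },
                                  new String[] { "a", "b", "b", "b" }));
    }
}
```

Percentage error is simply 100 minus the percentage accuracy.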

For detailed information about training and test set, you can refer to my article about data splitting.

Another important technique that we are going to talk about is cross-validation, which we use to get a more reliable estimate of our model's accuracy. Suppose we have 100 sets of values and we take the first 20 as the testing set and the rest as the training set. Since we need more data for training, this splitting ratio is completely fine, but it leaves some uncertainty: what if the first 20 sets of data have completely different values from the rest of the data? One way to sort out this issue is to use a random function that randomly selects the testing and training set values. That reduces the chance of getting a biased set of values in our training and test sets, but it does not fully solve the problem. A single random split can still produce a test set whose values are not representative of the training set, or one whose values are almost the same as the training set's, in which case the test score looks better than it should and hides overfitting. You can refer to this article if you want to know more about overfitting and underfitting of the data.
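The random selection described above can be sketched in plain Java as follows. This is only an illustration under my own assumptions: the class RandomSplitSketch, the method split, and the 100 integer "rows" are hypothetical stand-ins for real instances.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch: one random train/test split with a fixed seed.
class RandomSplitSketch {

    // Shuffles a copy of the rows and returns a list holding [train, test].
    static <T> List<List<T>> split(List<T> rows, double testFraction, long seed) {
        List<T> shuffled = new ArrayList<>(rows);
        Collections.shuffle(shuffled, new Random(seed));

        int testSize = (int) Math.round(shuffled.size() * testFraction);
        List<T> test = new ArrayList<>(shuffled.subList(0, testSize));
        List<T> train = new ArrayList<>(shuffled.subList(testSize, shuffled.size()));

        List<List<T>> result = new ArrayList<>();
        result.add(train);
        result.add(test);
        return result;
    }

    public static void main(String[] args) {
        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            rows.add(i);   // stand-in for 100 data rows
        }
        List<List<Integer>> parts = split(rows, 0.2, 1L);
        System.out.println("train size = " + parts.get(0).size());
        System.out.println("test size  = " + parts.get(1).size());
    }
}
```

Fixing the seed makes the split repeatable, but it is still a single split, which is exactly the limitation discussed above.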

Well, then how do we solve this issue? One way is to split the data n times into training and testing sets, evaluate the model on each split, and average the results, so that every part of the data gets used for both training and testing. But everything comes with a cost: since we are repeatedly splitting our data and training on each split, cross-validation takes more time. Still, it is worth the wait if we get a more trustworthy result.
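To make the "split several times and average" idea concrete, here is a small sketch in plain Java. It is not Weka code: the class KFoldAverageSketch is hypothetical, the "model" is just a baseline that predicts the mean of the training targets, and rows are assigned to folds by index after shuffling, which is only one of several possible schemes.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch: k-fold cross-validation of a mean-predicting baseline.
class KFoldAverageSketch {

    // "Training": the baseline model is simply the mean of the training targets.
    static double trainMean(List<Double> train) {
        double sum = 0;
        for (double v : train) {
            sum += v;
        }
        return sum / train.size();
    }

    // "Testing": RMSE of a constant prediction against the test targets.
    static double rmse(List<Double> test, double prediction) {
        double sumSquared = 0;
        for (double v : test) {
            sumSquared += (v - prediction) * (v - prediction);
        }
        return Math.sqrt(sumSquared / test.size());
    }

    // Returns the RMSE averaged over all folds.
    static double crossValidate(List<Double> targets, int folds, long seed) {
        if (folds < 2 || folds > targets.size()) {
            throw new IllegalArgumentException("need 2 <= folds <= number of rows");
        }
        List<Double> data = new ArrayList<>(targets);
        Collections.shuffle(data, new Random(seed));

        double total = 0;
        for (int f = 0; f < folds; f++) {
            List<Double> train = new ArrayList<>();
            List<Double> test = new ArrayList<>();
            for (int i = 0; i < data.size(); i++) {
                // row i is tested in fold (i % folds) and trained on in every other fold
                if (i % folds == f) {
                    test.add(data.get(i));
                } else {
                    train.add(data.get(i));
                }
            }
            total += rmse(test, trainMean(train));
        }
        return total / folds;
    }

    public static void main(String[] args) {
        List<Double> targets = new ArrayList<>();
        for (int i = 1; i <= 30; i++) {
            targets.add((double) i);   // made-up target values
        }
        System.out.println("average RMSE over 5 folds = " + crossValidate(targets, 5, 1L));
    }
}
```

Note that what gets averaged is the evaluation score of each fold, not the data sets themselves.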

Figure: K-fold cross-validation

Image source: https://upload.wikimedia.org/wikipedia/commons/1/1c/K-fold_cross_validation_EN.jpg

While writing the code I will be using a variable named folds (K in the figure above), which is the number of parts the data is divided into and therefore the number of train/test splits that cross-validation produces.
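The number of folds also fixes how large each testing and training set is. The sketch below shows one common convention, in which the first n % k folds receive one extra instance; I am assuming this convention for illustration, and Weka's own trainCV and testCV may distribute the remainder differently. The class FoldSizeSketch and the figure of 100 instances are hypothetical.

```java
// Hypothetical sketch: sizes of the test and training sets for each fold.
class FoldSizeSketch {

    // Test-set size for a 0-based fold when n instances are divided into k folds.
    static int testSize(int n, int k, int fold) {
        return n / k + (fold < n % k ? 1 : 0);
    }

    // Everything that is not in the test fold is used for training.
    static int trainSize(int n, int k, int fold) {
        return n - testSize(n, k, fold);
    }

    public static void main(String[] args) {
        int n = 100;   // assumed number of instances, for illustration only
        for (int k : new int[] { 5, 15 }) {
            System.out.println("folds = " + k);
            for (int fold = 0; fold < k; fold++) {
                System.out.println("  fold " + (fold + 1)
                        + ": test = " + testSize(n, k, fold)
                        + ", train = " + trainSize(n, k, fold));
            }
        }
    }
}
```

In general each testing set holds about 1/K of the data, so the test-to-train ratio is roughly 1:(K-1).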

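Below is the Java code for generating the training and testing sets. With K folds, each testing set holds roughly 1/K of the data: a test-to-train ratio of about 1:4, which is a commonly used split, corresponds to folds = 5, whereas the folds = 15 used in the code gives a ratio of about 1:14.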

The data set I have used can be downloaded from here: headbraina.arff


import weka.core.Instances;

import java.io.File;
import java.util.Random;

import weka.core.converters.ArffSaver;
import weka.core.converters.ConverterUtils.DataSource;

public class testtrainjaava {
	public static void main(String args[]) throws Exception {
		// load dataset
		DataSource source = new DataSource("headbraina.arff");
		Instances dataset = source.getDataSet();
		// set class index to the last attribute
		dataset.setClassIndex(dataset.numAttributes() - 1);

		int seed = 1;
		int folds = 15;
		// random generator with a fixed seed, so the split is repeatable
		Random rand = new Random(seed);
		// create a shuffled copy of the dataset
		Instances randData = new Instances(dataset);
		randData.randomize(rand);
		// for a nominal class, keep the class distribution similar in every fold
		if (randData.classAttribute().isNominal())
			randData.stratify(folds);

		// perform cross-validation
		for (int n = 0; n < folds; n++) {
			// get the training and testing data for fold n
			Instances train = randData.trainCV(folds, n);
			Instances test = randData.testCV(folds, n);

			System.out.println("No of folds done = " + (n + 1));

			// save the training set; the file is overwritten on every pass,
			// so it ends up holding the split of the final fold
			ArffSaver saver = new ArffSaver();
			saver.setInstances(train);
			saver.setFile(new File("trainheadbraina.arff"));
			saver.writeBatch();

			// save the matching testing set in the same way
			ArffSaver saver1 = new ArffSaver();
			saver1.setInstances(test);
			saver1.setFile(new File("testheadbraina1.arff"));
			saver1.writeBatch();
		}
	}
}

Training and Testing Sets in Java Output 1

After getting this output, go to the folder where the training and testing data sets were saved and you should see the following results.

Dataset generated for training the model

Training and Testing Sets in Java Output 2

Dataset generated for testing the model

Training and Testing Sets in Java Output 3
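If you would rather check the generated files from code than by opening them, the following sketch counts the data rows in an ARFF file using only the Java standard library. It relies on the ARFF layout, where the rows follow the @data line and lines starting with % are comments. The class ArffRowCountSketch is hypothetical, and the file names in main are the ones used in the program above.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// Hypothetical sketch: counting the data rows of an ARFF file without Weka.
class ArffRowCountSketch {

    static int countDataRows(Path arff) throws IOException {
        int rows = 0;
        boolean inData = false;
        try (BufferedReader reader = Files.newBufferedReader(arff)) {
            String line;
            while ((line = reader.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty() || trimmed.startsWith("%")) {
                    continue;   // skip blank lines and comments
                }
                if (!inData) {
                    // the header ends at the @data line (case-insensitive)
                    if (trimmed.regionMatches(true, 0, "@data", 0, 5)) {
                        inData = true;
                    }
                } else {
                    rows++;
                }
            }
        }
        return rows;
    }

    public static void main(String[] args) throws IOException {
        System.out.println("training rows = "
                + countDataRows(Paths.get("trainheadbraina.arff")));
        System.out.println("testing rows  = "
                + countDataRows(Paths.get("testheadbraina1.arff")));
    }
}
```

The two counts should add up to the number of rows in the original data set, with the testing file holding roughly 1/folds of them.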

That's all for today, guys. I hope you liked this; feel free to ask your queries, and have a great day ahead.


© https://www.includehelp.com some rights reserved.