Thus, by the end of this chapter, the complete source code for the application problem "property classification" is finished:
/***************************************************************************
* Copyright (c) 2020 Konduit K.K.
* Copyright (c) 2015-2019 Skymind, Inc.
* SPDX-License-Identifier: Apache-2.0
* https://www.apache.org/licenses/LICENSE-2.0
**************************************************************************/
package org.deeplearning4j.examples.quickstart.modeling.feedforward.classification;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Sgd;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
/**
* @author Adam Gibson (IrisClassifier), changes by Dr. Albrecht Ehlert 2020
*/
@SuppressWarnings("DuplicatedCode")
public class PropertyClassifier {
    public static void main(String[] args) throws Exception {
        // Get the dataset
        char delimiter = ',';
        RecordReader recordReader = new CSVRecordReader(delimiter);
        recordReader.initialize(new FileSplit(new File("WohnungsdatenAlsCSV.txt")));
        // Classify the dataset
        int labelIndex = 3; // label follows the 3 input features (column index 3)
        int numberClasses = 3; // 3 classes (types of properties)
        int batchSize = 1039; // 1039 examples total
        DataSetIterator iterator = new RecordReaderDataSetIterator(
                recordReader, batchSize, labelIndex, numberClasses);
        DataSet allData = iterator.next();
        allData.shuffle(); // Shuffle the order of the rows in the dataset
        // Split the data into training data (70%) and test data (30%)
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.70);
        DataSet trainingData = testAndTrain.getTrain();
        DataSet testData = testAndTrain.getTest();
        // Normalize the data
        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(trainingData); // Collect the statistics
        normalizer.transform(trainingData); // Apply normalization
        normalizer.transform(testData); // Apply normalization