Künstliche Intelligenz mit Java
Maschinelles Lernen mit Neuronalen Netzwerken

Am Ende des Kapitels liegt somit der vollständige Quelltext für das Anwendungsproblem „Klassifizierung von Immobilien“ vor:

/***************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 * Copyright (c) 2015-2019 Skymind, Inc.
 * SPDX-License-Identifier: Apache-2.0
 * https://www.apache.org/licenses/LICENSE-2.0
 **************************************************************************/

package org.deeplearning4j.examples.quickstart.modeling.feedforward.classification;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Sgd;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;

/**
 * @author Adam Gibson (IrisClassifier), changes by Dr. Albrecht Ehlert 2020
 */

@SuppressWarnings("DuplicatedCode")

public class PropertyClassifier {

  public static void main(String[] args) throws  Exception {

    // Get the dataset
    char delimiter = ',';
    RecordReader recordReader = new CSVRecordReader(delimiter);
    recordReader.initialize(new FileSplit(new File("WohnungsdatenAlsCSV.txt")));

    // Classify the dataset
    int labelIndex = 3;    // 3 input features
    int numberClasses = 3; // 3 classes (types of properties)
    int batchSize = 1039;  // 1039 examples total
    DataSetIterator iterator = new RecordReaderDataSetIterator(
      recordReader, batchSize, labelIndex, numberClasses);
    DataSet allData = iterator.next();
    allData.shuffle();     // Shuffle the order of the rows in the dataset

    // Split the data into trainingdata (70%) and testdata (30%)
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.70);
    DataSet trainingData = testAndTrain.getTrain();
    DataSet testData = testAndTrain.getTest();

    // Normalize the data
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);       // Collect the statistics
    normalizer.transform(trainingData); // Apply normalization
    normalizer.transform(testData);     // Apply normalization 


- 73 -