Author: Lao Jiu – technology big millet

Socializing: Zhihu

Public account: Lao Jiu School (Surprise for newcomers)

Special statement: original content takes effort to produce. Do not reproduce or copy it without authorization; if you need to reprint it, contact the author for permission.

Preface

While implementing a decision tree, I ran into a problem: the Java program threw a runtime error after the test data file had been modified with Python.

Start with the code: implementing the decision tree

1. Tree nodes


import java.util.ArrayList;

public class treeNode {
    private String sname; // The node name

    public treeNode(String str) {
        sname = str;
    }

    public String getsname() {
        return sname;
    }

    ArrayList<String> label = new ArrayList<String>(); // The edge labels between this node and its children
    ArrayList<treeNode> node = new ArrayList<treeNode>(); // The child nodes, in the same order as the edge labels
}

2. Implement the decision tree


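import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Scanner;

public class ID3 {
    private ArrayList<String> label = new ArrayList<String>(); // Feature labels (header row)
    private ArrayList<ArrayList<String>> date = new ArrayList<ArrayList<String>>(); // Training data set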
    private ArrayList<ArrayList<String>> test = new ArrayList<ArrayList<String>>();// Test the data set
    private ArrayList<String> sum = new ArrayList<String>();// The number of categories
    private String kind;

    public static ArrayList<Output> outputs = new ArrayList<>();// An array to store output

    public ID3(String path, String path0) throws FileNotFoundException {     
        getDate(path); // Initialize the training data and get the classification number
        gettestDate(path0);// Get the test data set
        init(date);
    }

    public void init(ArrayList<ArrayList<String>> date) {    
        sum.add(date.get(0).get(date.get(0).size() - 1));// Get the number of types
        for (int i = 0; i < date.size(); i++) {
            if (!sum.contains(date.get(i).get(date.get(0).size() - 1))) {
                sum.add(date.get(i).get(date.get(0).size() - 1));
            }
        }
    }

    /* Get the test data set */
    public void gettestDate(String path) throws FileNotFoundException {
        String str;
        int i = 0;
        try {
            //BufferedReader in=new BufferedReader(new FileReader(path));
            FileInputStream fis = new FileInputStream(path);
            InputStreamReader isr = new InputStreamReader(fis, "UTF-8");
            BufferedReader in = new BufferedReader(isr);
            while ((str = in.readLine()) != null) {
                String[] strs = str.split(",");
                ArrayList<String> line = new ArrayList<String>();
                boolean isFinished=true;// Check whether the output array is added
                for (int j = 0; j < strs.length; j++) {
                    line.add(strs[j]);
                    if (!isFinished) {
                        Output output = new Output(strs[j / (label.size() - 1)], "null");
                        outputs.add(output);
                        isFinished=true;
                    }
                    if (j % (label.size() - 1) == 0) {
                        isFinished=false;
                    }
                }
                test.add(line);
                i++;
            }

            in.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    // Get the training data set
    public void getDate(String path) throws FileNotFoundException {
        String str;
        int i = 0;
        try {
            FileInputStream fis = new FileInputStream(path);
            InputStreamReader isr = new InputStreamReader(fis, "UTF-8");
            BufferedReader in = new BufferedReader(isr);
            while ((str = in.readLine()) != null) {
                if (i == 0) {
                    String[] strs = str.split(",");
                    for (int j = 0; j < strs.length; j++) {
                        label.add(strs[j]);
                    }
                    i++;
                    continue;
                }
                String[] strs = str.split(",");
                ArrayList<String> line = new ArrayList<String>();
                for (int j = 0; j < strs.length; j++) {
                    line.add(strs[j]);
                }
                date.add(line);
                i++;
            }
            in.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public double Ent(ArrayList<ArrayList<String>> dat) {
        // Calculate the total information entropy
        int all = 0;
        double amount = 0.0;
        for (int i = 0; i < sum.size(); i++) {
            for (int j = 0; j < dat.size(); j++) {
                if (sum.get(i).equals(dat.get(j).get(dat.get(0).size() - 1))) {
                    all++;
                }
            }
            if ((double) all / dat.size() == 0.0) {
                continue;
            }
            amount += ((double) all / dat.size()) * (Math.log(((double) all / dat.size())) / Math.log(2.0));
            all = 0;
        }
        if (amount == 0.0) {
            return 0.0;
        }
        return -amount;// Calculate information entropy
    }

    /* Calculates the conditional entropy and returns the information gain value */
    public double condtion(int a, ArrayList<ArrayList<String>> dat) {
        ArrayList<String> all = new ArrayList<String>();
        double c = 0.0;
        all.add(dat.get(0).get(a));
        // Get the attribute type
        for (int i = 0; i < dat.size(); i++) {
            if (all.contains(dat.get(i).get(a)) == false) {
                all.add(dat.get(i).get(a));
            }
        }
        ArrayList<ArrayList<String>> plus = new ArrayList<ArrayList<String>>();
        // Partial grouping
        ArrayList<ArrayList<ArrayList<String>>> count = new ArrayList<ArrayList<ArrayList<String>>>();
        // Grouping total
        for (int i = 0; i < all.size(); i++) {
            for (int j = 0; j < dat.size(); j++) {
                if (true == all.get(i).equals(dat.get(j).get(a))) {
                    plus.add(dat.get(j));
                }
            }
            count.add(plus);
            c += ((double) count.get(i).size() / dat.size()) * Ent(count.get(i));
            plus.removeAll(plus);
        }
        // Information gain = total entropy - conditional entropy
        return (Ent(dat) - c);
    }

    /* Calculate the maximum attribute of information gain */
    public int Gain(ArrayList<ArrayList<String>> dat) {
        ArrayList<Double> num = new ArrayList<Double>();
        // Save each information gain value
        for (int i = 0; i < dat.get(0).size() - 1; i++) {
            num.add(condtion(i, dat));
        }
        int index = 0;
        double max = num.get(0);
        for (int i = 1; i < num.size(); i++) {
            if (max < num.get(i)) {
                max = num.get(i);
                index = i;
            }
        }
        return index;
    }

    // Build the decision tree
    public treeNode creattree(ArrayList<ArrayList<String>> dat) {
        int index = Gain(dat);
        treeNode node = new treeNode(label.get(index));
        ArrayList<String> s = new ArrayList<String>();// Attribute type
        s.add(dat.get(0).get(index));
        for (int i = 1; i < dat.size(); i++) {
            if (s.contains(dat.get(i).get(index)) == false) {
                s.add(dat.get(i).get(index));
            }
        }
        ArrayList<ArrayList<String>> plus = new ArrayList<ArrayList<String>>();
        // Partial grouping
        ArrayList<ArrayList<ArrayList<String>>> count = new ArrayList<ArrayList<ArrayList<String>>>();
        // Grouping total
        // Get the edge labels under the nodes and group them
        for (int i = 0; i < s.size(); i++) {
            node.label.add(s.get(i));// Add edge labels
            for (int j = 0; j < dat.size(); j++) {
                if (true == s.get(i).equals(dat.get(j).get(index))) {
                    plus.add(dat.get(j));
                }
            }
            count.add(plus);

            // add nodes below
            int k;
            String str = count.get(i).get(0).get(count.get(i).get(0).size() - 1);
            for (k = 1; k < count.get(i).size(); k++) {
                if (false == str.equals(count.get(i).get(k).get(count.get(i).get(k).size() - 1))) {
                    break;
                }
            }
            if (k == count.get(i).size()) {
                treeNode dd = new treeNode(str);
                node.node.add(dd);
            } else {
                node.node.add(creattree(count.get(i)));
            }
            plus.removeAll(plus);
        }
        return node;
    }

    // Outputs the decision tree
    public void print(ArrayList<ArrayList<String>> dat) {
        System.out.println("The constructed decision tree is as follows:");
        treeNode node = null;
        node = creattree(dat); // build the tree
        put(node);// recursive call
    }

    // A recursive function
    public void put(treeNode node) {
        System.out.println("Node:" + node.getsname() + "\n");
        for (int i = 0; i < node.label.size(); i++) {
            System.out.println(node.getsname() + " label attribute:" + node.label.get(i));
            if (node.node.get(i).node.isEmpty() == true) {
                System.out.println("Leaf node:" + node.node.get(i).getsname());
            } else {
                put(node.node.get(i));
            }
        }
    }

    /* Predicts a class for each row of the test data and saves the results to the given path */
    public void testdate(ArrayList<ArrayList<String>> test, String path) throws IOException {
        treeNode node = null;
        int count = 0;
        node = creattree(this.date); // build the tree
        try {
            BufferedWriter out = new BufferedWriter(new FileWriter(path));
            for (int i = 0; i < test.size(); i++) {
                testput(node, test.get(i));// recursive call
                for (int j = 0; j < test.get(i).size(); j++) {
                    out.write(test.get(i).get(j) + ",");
                }
                if (kind.equals(date.get(i).get(date.get(i).size() - 1)) == true) {
                    count++;
                }
                out.write(kind);
                outputs.get(i).kind=kind;
                out.newLine();
            }
            out.flush();
            out.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    // Recursive call used for prediction
    public void testput(treeNode node, ArrayList<String> t) {
        int index = 0;
        for (int i = 0; i < this.label.size(); i++) {
            if (this.label.get(i).equals(node.getsname()) == true) {
                index = i;
                break;
            }
        }
        for (int i = 0; i < node.label.size(); i++) {
            if (t.get(index).equals(node.label.get(i)) == false) {
                continue;
            }	
            if (node.node.get(i).node.isEmpty() == true) {
                this.kind = node.node.get(i).getsname();// retrieve the classification result
            } else {
                testput(node.node.get(i), t);
            }
        }
    }

    public static void main(String[] args) throws IOException {
        String data = "src\\com\\xuetang9\\data.txt";// Train the data set
        String test = "src\\com\\xuetang9\\test.txt";// Test the data set
        String result = "src\\com\\xuetang9\\result.txt";// Predict the result set
        ID3 id = new ID3(data, test);// Initialize the data
        id.print(id.date);// Build and output the decision tree
        id.testdate(id.test,result);// Predict the data and print the results
        System.out.println("Please enter the location you wish to inquire.");
        Scanner scanner= new Scanner(System.in);
        String input=scanner.next();
        for (Output output: outputs) {
            if(output.countryName.equals(input)){
                System.out.println("Query result:" + output.kind);
            }
        }
    }
}
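
The Ent and condtion methods above compute entropy and information gain directly on the row lists. As a rough standalone sketch of the same arithmetic, the example below works on plain class counts instead of rows. The class name EntropySketch and the methods entropy and gain are illustrative names, not part of the original program.

```java
public class EntropySketch {
    // Shannon entropy (base 2) of a set described by its per-class counts.
    static double entropy(int... counts) {
        int total = 0;
        for (int c : counts) {
            total += c;
        }
        double h = 0.0;
        for (int c : counts) {
            if (c == 0) {
                continue; // an empty class contributes nothing
            }
            double p = (double) c / total;
            h -= p * (Math.log(p) / Math.log(2.0));
        }
        return h;
    }

    // Information gain: parent entropy minus the size-weighted entropy of the subsets.
    static double gain(int[] parent, int[][] subsets) {
        int total = 0;
        for (int c : parent) {
            total += c;
        }
        double conditional = 0.0;
        for (int[] subset : subsets) {
            int size = 0;
            for (int c : subset) {
                size += c;
            }
            conditional += ((double) size / total) * entropy(subset);
        }
        return entropy(parent) - conditional;
    }

    public static void main(String[] args) {
        // Ten rows, evenly split between two classes.
        int[] parent = {5, 5};
        // An attribute that separates the classes completely.
        int[][] perfectSplit = {{5, 0}, {0, 5}};
        // An attribute that tells us nothing about the class.
        int[][] uselessSplit = {{3, 3}, {2, 2}};

        System.out.println("Parent entropy: " + entropy(parent));
        System.out.println("Gain of perfect split: " + gain(parent, perfectSplit));
        System.out.println("Gain of useless split: " + gain(parent, uselessSplit));
    }
}
```

ID3 picks the attribute with the largest gain at each node, which is what the Gain method does by calling condtion for every column except the last.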

3. Utility class for storing the output

public class Output{
    public String countryName;
    public String kind;
    public Output(String countryName,String kind){
        this.countryName=countryName;
        this.kind=kind;
    }

    @Override
    public String toString() {
        return "Output{" +
                "countryName='" + countryName + '\'' +
                ", kind='" + kind + '\'' +
                '}';
    }
}

4. Data text file

Location, sunshine, number of people, traffic, wind, air, rain, suitable for travel
The United Kingdom, bright, crowded, crowded, breeze, slightly cloudy, little, suitable for travel
the United States, slightly dark, crowded, crowded, breeze, slightly cloudy, little, suitable for travel
Japan, gray, crowded, slightly crowded, breeze, fresh, little, suitable for travel
South Korea, bright, crowded, not hug, breeze, fresh, small, suitable for travel
Britain, dark, crowded, slightly hug, breeze, fresh, small, suitable for travel
the United States, bright, crowded, slightly hug, breeze, slightly muddy, medium, suitable for travel
Japan, dark, slightly hug, slightly muddy, stroke, slightly muddy, small, suitable for travel
South Korea, dark, slightly warm, slightly warm, breeze, slightly muddy, little, not suitable for travel
Britain, dark, slightly warm, not warm, stroke, slightly muddy, little, not suitable for travel
the United States, bright, not warm, crowded, breeze, cloudy, medium, not suitable for travel
Japan, dark, not warm, crowded, stroke, cloudy, a lot, not suitable for travel
South Korea, dark, crowded, slightly cloudy, stroke, cloudy, a lot, not suitable for travel
England, bright, slightly cloudy, stroke, fresh, a little, not suitable for travel
America, dark, slightly cloudy, stroke, fresh, a little, not suitable for travel
Japan, dark, slightly cloudy, breeze, slightly cloudy, a lot, not suitable for travel
South Korea, dark, crowded, slightly cloudy, stroke, cloudy, little, not suitable for travel
England, bright, crowded, stroke, slightly cloudy, little, not suitable for travel
America, bright, not cloudy, crowded, breeze, slightly cloudy, little, suitable for travel

5. Test text file

Britain, gray, slightly crowded, slightly crowded, breeze, slightly Muddy, lots
America, slightly dark, crowded, crowded, breeze, slightly muddy, little
Japan, bright, not crowded, breeze, cloudy, medium
Korea, bright, crowded, not crowded, breeze, fresh, little

Runs perfectly in Java:

But here’s the problem

Change the test data to

England, bright, crowded, crowded, breezy, slightly muddy, light

Running it in Java now fails with a runtime error.

Debugging shows:

// The kind member variable is null inside the loop
if (kind.equals(date.get(i).get(date.get(i).size() - 1)) == true) {
    count++;
}

Trace where kind gets assigned

for (int i = 0; i < node.label.size(); i++) {
    if (t.get(index).equals(node.label.get(i)) == false) {
        continue;
    }
    if (node.node.get(i).node.isEmpty() == true) {
        this.kind = node.node.get(i).getsname(); // retrieve the classification result
    } else {
        testput(node.node.get(i), t);
    }
}

It turns out that the values in the modified test data do not match any edge label in the decision tree, so every iteration hits continue and kind is never assigned.
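
One way to make this failure easier to spot is to have the lookup return a result instead of writing to a field, so that "no branch matched" becomes an explicit null the caller can check. The following is only a sketch under that assumption: SafeClassifySketch, Node, classify and sampleTree are hypothetical names, and Node is a simplified stand-in for treeNode rather than the original class.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class SafeClassifySketch {
    // Simplified stand-in for treeNode: a name, edge labels, and matching children.
    static class Node {
        final String name;
        final List<String> labels = new ArrayList<>();
        final List<Node> children = new ArrayList<>();

        Node(String name) {
            this.name = name;
        }
    }

    // Returns the predicted class, or null when no edge label matches the row.
    static String classify(Node node, List<String> featureNames, List<String> row) {
        if (node.children.isEmpty()) {
            return node.name; // leaf: the node name is the class
        }
        int index = featureNames.indexOf(node.name);
        if (index < 0 || index >= row.size()) {
            return null; // the row has no value for this feature
        }
        for (int i = 0; i < node.labels.size(); i++) {
            if (node.labels.get(i).equals(row.get(index))) {
                return classify(node.children.get(i), featureNames, row);
            }
        }
        return null; // value never seen during training, or garbled on read
    }

    // A one-level tree used only for this illustration.
    static Node sampleTree() {
        Node root = new Node("wind");
        root.labels.add("breeze");
        root.children.add(new Node("suitable for travel"));
        root.labels.add("gale");
        root.children.add(new Node("not suitable for travel"));
        return root;
    }

    public static void main(String[] args) {
        List<String> features = Arrays.asList("location", "wind");
        Node tree = sampleTree();

        String matched = classify(tree, features, Arrays.asList("Japan", "breeze"));
        String unmatched = classify(tree, features, Arrays.asList("Japan", "br\uFFFDeze"));

        System.out.println(matched);
        System.out.println(unmatched == null ? "no matching branch" : unmatched);
    }
}
```

With a check like this in place, a garbled value shows up as a clear "no matching branch" message rather than a NullPointerException later in the loop.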

Print a log line to compare the two values

System.out.println("T:"+t.get(index)+",L:"+node.label.get(i));

The log shows that the problem is caused by garbled characters in the data that was read.

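Check the encoding of the text file. Before it was modified with Python, test.txt was UTF-8; afterwards it was GBK. The Java code reads the file as UTF-8, so the GBK bytes were decoded into garbled strings that matched nothing in the tree. Converting the file back to UTF-8 solved the problem.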
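
The effect can be reproduced without any files. The sketch below encodes a string as GBK and then decodes the bytes both ways; EncodingMismatchDemo and decode are illustrative names, the sample value is a hypothetical attribute string, and it assumes the JDK in use ships the GBK charset (standard desktop and server JDK builds do).

```java
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

public class EncodingMismatchDemo {
    // Decodes raw bytes with the given charset, as InputStreamReader would.
    static String decode(byte[] raw, Charset charset) {
        return new String(raw, charset);
    }

    public static void main(String[] args) {
        // Hypothetical attribute value: the Chinese word for "crowded".
        String label = "\u62e5\u6324";

        Charset gbk = Charset.forName("GBK");
        byte[] savedAsGbk = label.getBytes(gbk); // what a GBK-saved file contains

        String readAsUtf8 = decode(savedAsGbk, StandardCharsets.UTF_8);
        String readAsGbk = decode(savedAsGbk, gbk);

        // Decoding with the wrong charset yields a different string,
        // so equals() against the tree's edge labels can never succeed.
        System.out.println("Read as UTF-8 matches: " + label.equals(readAsUtf8));
        System.out.println("Read as GBK matches:   " + label.equals(readAsGbk));
    }
}
```

Either saving the file as UTF-8 again or passing the file's real charset to InputStreamReader keeps the values read from disk equal to the ones stored in the tree.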

Conclusion

Text encoding is one of the most common sources of problems in programming.


Copyright belongs to the author. Commercial reprint please contact the author for authorization, non-commercial reprint please indicate the source.