Author: Lao Jiu – technology big millet
Socializing: Zhihu
Public account: Lao Jiu School (Surprise for newcomers)
Special statement: this is original work and may not be reproduced or copied without authorization. If you want to reproduce it, contact the author for permission.
Preface
While implementing a decision tree, I ran into a problem: after the text data was modified with Python, the Java program failed at runtime.
Let's start with the code that implements the decision tree.
1. Tree nodes
/**
 * A node of the ID3 decision tree.
 *
 * <p>An internal node is named after the attribute it splits on; a leaf is named after the class it
 * predicts and has no children. {@code label} and {@code node} are parallel lists:
 * {@code label.get(i)} is the attribute value on the edge that leads to child {@code node.get(i)}.
 */
public class treeNode {
    private final String sname;// The node name: split attribute for internal nodes, class for leaves

    /** Creates a node with the given name and no children. */
    public treeNode(String str) {
        sname = str;
    }

    /** Returns the name this node was created with. */
    public String getsname() {
        return sname;
    }

    ArrayList<String> label = new ArrayList<String>();// Edge labels (attribute values) between this node and its children
    ArrayList<treeNode> node = new ArrayList<treeNode>();// Child nodes, in the same order as label
}
Copy the code
2. Implement the decision tree
public class ID3 {
private ArrayList<String> label = new ArrayList<String>();// Feature tag
private ArrayList<ArrayList<String>> date = new ArrayList<ArrayList<String>>();/ / data set
private ArrayList<ArrayList<String>> test = new ArrayList<ArrayList<String>>();// Test the data set
private ArrayList<String> sum = new ArrayList<String>();// The number of categories
private String kind;
public static ArrayList<Output> outputs = new ArrayList<>();// An array to store output
public ID3(String path, String path0) throws FileNotFoundException {
getDate(path); // Initialize the training data and get the classification number
gettestDate(path0);// Get the test data set
init(date);
}
public void init(ArrayList<ArrayList<String>> date) {
sum.add(date.get(0).get(date.get(0).size() - 1));// Get the number of types
for (int i = 0; i < date.size(); i++) {
if (sum.contains(date.get(i).get(date.get(0).size() - 1)) = =false) {
sum.add(date.get(i).get(date.get(0).size() - 1)); }}}/* Get test data set */
public void gettestDate(String path) throws FileNotFoundException {
String str;
int i = 0;
try {
//BufferedReader in=new BufferedReader(new FileReader(path));
FileInputStream fis = new FileInputStream(path);
InputStreamReader isr = new InputStreamReader(fis, "UTF-8");
BufferedReader in = new BufferedReader(isr);
while((str = in.readLine()) ! =null) {
String[] strs = str.split(",");
ArrayList<String> line = new ArrayList<String>();
boolean isFinished=true;// Check whether the output array is added
for (int j = 0; j < strs.length; j++) {
line.add(strs[j]);
if(! isFinished){ Output output=new Output(strs[j/(label.size()-1)]."null");
outputs.add(output);
isFinished=true;
}
if(j%(label.size()-1) = =0){
isFinished=false;
}
}
test.add(line);
i++;
}
in.close();
} catch(Exception e) { e.printStackTrace(); }}// Get the training data set
public void getDate(String path) throws FileNotFoundException {
String str;
int i = 0;
try {
FileInputStream fis = new FileInputStream(path);
InputStreamReader isr = new InputStreamReader(fis, "UTF-8");
BufferedReader in = new BufferedReader(isr);
while((str = in.readLine()) ! =null) {
if (i == 0) {
String[] strs = str.split(",");
for (int j = 0; j < strs.length; j++) {
label.add(strs[j]);
}
i++;
continue;
}
String[] strs = str.split(",");
ArrayList<String> line = new ArrayList<String>();
for (int j = 0; j < strs.length; j++) {
line.add(strs[j]);
}
date.add(line);
i++;
}
in.close();
} catch(Exception e) { e.printStackTrace(); }}public double Ent(ArrayList<ArrayList<String>> dat) {
// Calculate the total information entropy
int all = 0;
double amount = 0.0;
for (int i = 0; i < sum.size(); i++) {
for (int j = 0; j < dat.size(); j++) {
if (sum.get(i).equals(dat.get(j).get(dat.get(0).size() - 1))) { all++; }}if ((double) all / dat.size() == 0.0) {
continue;
}
amount += ((double) all / dat.size()) * (Math.log(((double) all / dat.size())) / Math.log(2.0));
all = 0;
}
if (amount == 0.0) {
return 0.0;
}
return -amount;// Calculate information entropy
}
/* Calculates the conditional entropy and returns the information gain value */
public double condtion(int a, ArrayList<ArrayList<String>> dat) {
ArrayList<String> all = new ArrayList<String>();
double c = 0.0;
all.add(dat.get(0).get(a));
// Get the attribute type
for (int i = 0; i < dat.size(); i++) {
if (all.contains(dat.get(i).get(a)) == false) {
all.add(dat.get(i).get(a));
}
}
ArrayList<ArrayList<String>> plus = new ArrayList<ArrayList<String>>();
// Partial grouping
ArrayList<ArrayList<ArrayList<String>>> count = new ArrayList<ArrayList<ArrayList<String>>>();
// Grouping total
for (int i = 0; i < all.size(); i++) {
for (int j = 0; j < dat.size(); j++) {
if (true == all.get(i).equals(dat.get(j).get(a))) {
plus.add(dat.get(j));
}
}
count.add(plus);
c += ((double) count.get(i).size() / dat.size()) * Ent(count.get(i));
plus.removeAll(plus);
}
return (Ent(dat) - c);
// Return conditional entropy
}
/* Calculate the maximum attribute of information gain */
public int Gain(ArrayList<ArrayList<String>> dat) {
ArrayList<Double> num = new ArrayList<Double>();
// Save each information gain value
for (int i = 0; i < dat.get(0).size() - 1; i++) {
num.add(condtion(i, dat));
}
int index = 0;
double max = num.get(0);
for (int i = 1; i < num.size(); i++) {
if(max < num.get(i)) { max = num.get(i); index = i; }}return index;
}
// Build the decision tree
public treeNode creattree(ArrayList<ArrayList<String>> dat) {
int index = Gain(dat);
treeNode node = new treeNode(label.get(index));
ArrayList<String> s = new ArrayList<String>();// Attribute type
s.add(dat.get(0).get(index));
for (int i = 1; i < dat.size(); i++) {
if (s.contains(dat.get(i).get(index)) == false) {
s.add(dat.get(i).get(index));
}
}
ArrayList<ArrayList<String>> plus = new ArrayList<ArrayList<String>>();
// Partial grouping
ArrayList<ArrayList<ArrayList<String>>> count = new ArrayList<ArrayList<ArrayList<String>>>();
// Grouping total
// Get the edge labels under the nodes and group them
for (int i = 0; i < s.size(); i++) {
node.label.add(s.get(i));// Add edge labels
for (int j = 0; j < dat.size(); j++) {
if (true == s.get(i).equals(dat.get(j).get(index))) {
plus.add(dat.get(j));
}
}
count.add(plus);
// add nodes below
int k;
String str = count.get(i).get(0).get(count.get(i).get(0).size() - 1);
for (k = 1; k < count.get(i).size(); k++) {
if (false == str.equals(count.get(i).get(k).get(count.get(i).get(k).size() - 1))) {
break; }}if (k == count.get(i).size()) {
treeNode dd = new treeNode(str);
node.node.add(dd);
} else {
node.node.add(creattree(count.get(i)));
}
plus.removeAll(plus);
}
return node;
}
// Outputs the decision tree
public void print(ArrayList<ArrayList<String>> dat) {
System.out.println("The constructed decision tree is as follows:");
treeNode node = null;
node = creattree(dat);/ / class
put(node);// recursive call
}
// A recursive function
public void put(treeNode node) {
System.out.println("Node:" + node.getsname() + "\n");
for (int i = 0; i < node.label.size(); i++) {
System.out.println(node.getsname() + Tag attribute of ":" + node.label.get(i));
if (node.node.get(i).node.isEmpty() == true) {
System.out.println("Leaf node:" + node.node.get(i).getsname());
} else{ put(node.node.get(i)); }}}/* Is used to make predictions on decision data and save the results in the specified path */
public void testdate(ArrayList<ArrayList<String>> test, String path) throws IOException {
treeNode node = null;
int count = 0;
node = creattree(this.date);/ / class
try {
BufferedWriter out = new BufferedWriter(new FileWriter(path));
for (int i = 0; i < test.size(); i++) {
testput(node, test.get(i));// recursive call
for (int j = 0; j < test.get(i).size(); j++) {
out.write(test.get(i).get(j) + ",");
}
if (kind.equals(date.get(i).get(date.get(i).size() - 1)) = =true) {
count++;
}
out.write(kind);
outputs.get(i).kind=kind;
out.newLine();
}
out.flush();
out.close();
} catch(IOException e) { e.printStackTrace(); }}// Recursive calls for testing
public void testput(treeNode node, ArrayList<String> t) {
int index = 0;
for (int i = 0; i < this.label.size(); i++) {
if (this.label.get(i).equals(node.getsname()) == true) {
index = i;
break; }}for (int i = 0; i < node.label.size(); i++) {
if (t.get(index).equals(node.label.get(i)) == false) {
continue;
}
if (node.node.get(i).node.isEmpty() == true) {
this.kind = node.node.get(i).getsname();// retrieve the classification result
} else{ testput(node.node.get(i), t); }}}public static void main(String[] args) throws IOException {
String data = "src\\com\\xuetang9\\data.txt";// Train the data set
String test = "src\\com\\xuetang9\\test.txt";// Test the data set
String result = "src\\com\\xuetang9\\result.txt";// Predict the result set
ID3 id = new ID3(data, test);// Initialize the data
id.print(id.date);// Build and output the decision tree
id.testdate(id.test,result);// Predict the data and print the results
System.out.println("Please enter the location you wish to inquire.");
Scanner scanner= new Scanner(System.in);
String input=scanner.next();
for (Output output: outputs) {
if(output.countryName.equals(input)){
System.out.println("Query result:"+output.kind); }}}}Copy the code
3. Utility class for storing the output
/**
 * One prediction result: the location of a test row (taken from its first column) and the class
 * the decision tree assigned to it. Fields are public and mutable because the predicted class is
 * filled in after the object has been created.
 */
public class Output {
    public String countryName;// Location name of the test row
    public String kind;// Predicted class for that row

    /** Creates a result for {@code countryName} with the initial class value {@code kind}. */
    public Output(String countryName, String kind) {
        this.countryName = countryName;
        this.kind = kind;
    }

    @Override
    public String toString() {
        return "Output{" +
                "countryName='" + countryName + '\'' +
                ", kind='" + kind + '\'' +
                '}';
    }
}
4. Data text file
Location, sunshine, number of people, traffic, wind, air, rain, suitable for travel
The United Kingdom, bright, crowded, crowded, breeze, slightly cloudy, little, suitable for travel
The United States, slightly dark, crowded, crowded, breeze, slightly cloudy, little, suitable for travel
Japan, gray, crowded, slightly crowded, breeze, fresh, little, suitable for travel
South Korea, bright, crowded, not hug, breeze, fresh, small, suitable for travel
Britain, dark, crowded, slightly hug, breeze, fresh, small, suitable for travel
The United States, bright, crowded, slightly hug, breeze, slightly muddy, medium, suitable for travel
Japan, dark, slightly hug, slightly muddy, stroke, slightly muddy, small, suitable for travel
South Korea, dark, slightly warm, slightly warm, breeze, slightly muddy, little, not suitable for travel
Britain, dark, slightly warm, not warm, stroke, slightly muddy, little, not suitable for travel
The United States, bright, not warm, crowded, breeze, cloudy, medium, not suitable for travel
Japan, dark, not warm, crowded, stroke, cloudy, a lot, not suitable for travel
South Korea, dark, crowded, slightly cloudy, stroke, cloudy, a lot, not suitable for traveling
England, bright, slightly cloudy, stroke, fresh, a little, not suitable for traveling
America, dark, slightly cloudy, stroke, fresh, a little, not suitable for traveling
Japan, dark, slightly cloudy, breeze, slightly cloudy, a lot, not suitable for traveling
South Korea, dark, crowded, slightly cloudy, stroke, cloudy, little, not suitable for travel
England, bright, crowded, stroke, slightly cloudy, little, not suitable for travel
America, bright, not cloudy, crowded, breeze, slightly cloudy, little, suitable for travel
Copy the code
5. Test text file
Britain, gray, slightly crowded, slightly crowded, breeze, slightly muddy, lots
America, slightly dark, crowded, crowded, breeze, slightly muddy, little
Japan, bright, not crowded, breeze, cloudy, medium
Korea, bright, crowded, not crowded, breeze, fresh, little
Copy the code
With this data, the Java program runs correctly:
But here is the problem.
I changed the test data to:
England, bright, crowded, crowded, breezy, slightly muddy, light
Copy the code
Running it in Java now produces an error:
Debugging showed that the `kind` member variable is null when this check runs:
// The kind member variable in the loop is null, so calling kind.equals(...) here throws a NullPointerException
if (kind.equals(date.get(i).get(date.get(i).size() - 1)) == true) {
    count++;
}
Copy the code
Next, trace where `kind` is assigned:
// Excerpt from testput: look for the edge whose label equals the test row's value for this
// node's attribute. A leaf child sets kind; an internal child is searched recursively.
// If no edge label matches, the loop ends without assigning kind.
for (int i = 0; i < node.label.size(); i++) {
if (t.get(index).equals(node.label.get(i)) == false) {
continue;
}
if (node.node.get(i).node.isEmpty() == true) {
this.kind = node.node.get(i).getsname();// retrieve the classification result
} else{ testput(node.node.get(i), t); }}
Copy the code
It turned out that the values in the modified test data did not match any edge label in the decision tree, so `kind` was never assigned.
To confirm this, add a log statement inside the loop:
// Debug log: the test row's value for this node's attribute (T) next to the edge label it is compared with (L)
System.out.println("T:"+t.get(index)+",L:"+node.label.get(i));
Copy the code
The log showed garbled characters in the values read from the file, so the problem was caused by incorrectly decoded input data.
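Checking the file encoding confirmed this. Before being edited with Python, test.txt was encoded in UTF-8; after the Python edit it was saved as GBK. The Java code, however, still decodes the file as UTF-8, so the GBK bytes turn into garbled text. Converting the file back to UTF-8 solved the problem.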
Conclusion
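Text encoding is one of the most common sources of problems in programming. A file written in one encoding and read in another is silently garbled, so make sure the program that writes a file and the program that reads it agree on the encoding.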
Finally
Remember to give dashu ❤️ attention + like + collect + comment + forward ❤️
Author: Lao Jiu School – technology big millet
Copyright belongs to the author. Commercial reprint please contact the author for authorization, non-commercial reprint please indicate the source.