Monday, January 18, 2016

Naive Bayes Classifier for classifying Text Documents.

A Naive Bayes classifier can be used to classify text documents into categories efficiently.

Input Data:


The program has to classify new files into one of these 20 categories:

talk.religion.misc
misc.forsale
alt.atheism
sci.electronics
talk.politics.misc
comp.os.ms-windows.misc
sci.med 
comp.graphics
comp.sys.mac.hardware
comp.sys.ibm.pc.hardware
sci.crypt 
sci.space 
rec.sport.baseball 
rec.motorcycles 
comp.windows.x 
rec.autos 
talk.politics.guns 
rec.sport.hockey 
talk.politics.mideast 
soc.religion.christian 

Source (Java): 

/*
    Takes a directory as input.
    The names of its subdirectories are used as the attribute (category) names for classification.
    The files within each subdirectory are the examples for that attribute.

    Only the first "trainUpto" fraction of each attribute's files is used for training.
    The rest is used for testing.

    The input used contained 20 attributes and almost 20000 files.
    Input source : http://www.cs.cmu.edu/afs/cs.cmu.edu/user/mitchell/ftp/mlbook.html
    Testing used almost 6000 files.
    Accuracy : 83.4% (The tokens need to be selected more carefully for higher accuracy).
*/

import java.io.*;
import java.util.*;
import java.util.Map.Entry;
import java.util.regex.Pattern;
 
/**
 * Multinomial Naive Bayes text classifier with add-one (Laplace) smoothing.
 *
 * <p>The classifier is trained from a directory whose sub-directories are the
 * categories ("attributes") and whose files are the examples. For every
 * category the first {@link #trainUpto} fraction of the (name-sorted) files is
 * used for training; {@link #classifyRest()} evaluates on the remaining files.
 *
 * <p>Vocabulary, example count, stop words and the example lists are kept in
 * static fields, so only one trained classifier can be alive at a time.
 */
public class NaiveBayesClassifier {

    /** A word, optionally wrapped in light punctuation such as quotes, commas or brackets. */
    private static final Pattern WORD = Pattern.compile("[-, (\"]*\\w+[).,!\"]*");

    /** An e-mail address in one of a few common top-level domains, optionally in angle brackets. */
    private static final Pattern EMAIL =
            Pattern.compile("(<)?.*@.*\\.((com)|(org)|(net)|(edu))(>)?");

    /** A dollar amount such as {@code $12.50}. */
    private static final Pattern DOLLAR_AMOUNT = Pattern.compile("^\\$[0-9 ]*\\.?[0-9]*$");

    /** Stop-word file read by the single-argument constructor. */
    private static final String DEFAULT_STOP_WORD_FILE = "d:\\common.txt";

    /** Every accepted training token mapped to its occurrence count over all categories. */
    static HashMap<String, Integer> globalVocab = new HashMap<>();

    /** Total number of example files (training and test) over all categories. */
    static int nExamples;

    /** One entry per category, in the same order as {@link #input}. */
    ArrayList<Attribute> attributes = new ArrayList<>();

    /** Example files per category; {@code input.get(i)} belongs to {@code attributes.get(i)}. */
    static ArrayList<ArrayList<File>> input;

    /** Fraction of each category's files used for training; the rest is used for testing. */
    static double trainUpto = 0.7;

    /** Lower-cased stop words that are never added to the vocabulary. */
    static HashSet<String> unnecessary = new HashSet<>();

    /**
     * Number of leading files of a category that are used for training.
     *
     * <p>Training and testing both use this single definition so the two sets
     * can never overlap or leave a gap.
     *
     * @param fileCount number of example files in the category
     * @return index of the first test file
     */
    static int trainingCount(int fileCount) {
        return (int) Math.ceil(fileCount * trainUpto);
    }

    /** A single category together with its examples and learned word probabilities. */
    static class Attribute {
        String name;
        File[] files;

        /** Total number of accepted tokens seen in this category's training files. */
        int n;

        /** Prior probability of the category. */
        double p_vj;

        ArrayList<String> words = new ArrayList<>();

        /**
         * Token counts after {@link #processInput()}; replaced by the smoothed
         * conditional probabilities P(token | category) after
         * {@link #calculateProbabilities()}.
         */
        HashMap<String, Double> hm = new HashMap<>();

        Attribute(String name) {
            System.out.println("Attribute : " + name);
            this.name = name;
        }

        /**
         * Reads the training share of {@link #files} and counts every accepted token,
         * both per category ({@link #hm}, {@link #n}) and globally ({@code globalVocab}).
         *
         * @throws Exception if a training file cannot be read
         */
        public void processInput() throws Exception {
            int trainingFiles = trainingCount(files.length);
            int tokens = 0;
            for (int i = 0; i < trainingFiles; ++i) {
                // try-with-resources: thousands of files are read, so handles must not leak.
                try (BufferedReader br = new BufferedReader(new FileReader(files[i]))) {
                    String line;
                    while ((line = br.readLine()) != null) {
                        for (String token : line.split(" ")) {
                            if (!isVocabularyToken(token)) {
                                continue;
                            }
                            hm.merge(token, 1d, Double::sum);
                            // Accumulate across categories instead of overwriting with
                            // the count of whichever category was processed last.
                            globalVocab.merge(token, 1, Integer::sum);
                            ++tokens;
                        }
                    }
                }
            }
            // n is the number of word positions, not the number of distinct words;
            // only then do the smoothed probabilities of a category sum to one.
            n = tokens;
        }

        /**
         * Decides whether a raw whitespace-separated token enters the vocabulary:
         * it must not be a stop word and must look like a word, an e-mail address,
         * this category's name or a dollar amount.
         */
        private boolean isVocabularyToken(String token) {
            if (unnecessary.contains(token.toLowerCase(Locale.ROOT))) {
                return false;
            }
            return WORD.matcher(token).matches()
                    || EMAIL.matcher(token).matches()
                    || token.equals(name) // literal comparison; the name is not a regex
                    || DOLLAR_AMOUNT.matcher(token).matches();
        }

        /**
         * Computes the prior {@link #p_vj} and replaces the counts in {@link #hm} by
         * P(token | category) = (count + 1) / (n + |vocabulary|) for every vocabulary token.
         * Must be called once, after all categories have been processed.
         */
        public void calculateProbabilities() {
            System.out.println("No of words with attribute '" + name + "' : " + n
                    + " GlobalVocab.size() = " + globalVocab.size());
            p_vj = (double) files.length / nExamples;
            double denominator = n + globalVocab.size();
            for (String wk : globalVocab.keySet()) {
                double nk = hm.getOrDefault(wk, 0d);
                hm.put(wk, (nk + 1) / denominator);
            }
        }
    }

    /**
     * Trains a classifier using the default stop-word file {@code d:\common.txt}.
     *
     * @param directory directory containing one sub-directory per category
     * @throws Exception if the stop-word file or any training file cannot be read
     */
    NaiveBayesClassifier(String directory) throws Exception {
        this(directory, DEFAULT_STOP_WORD_FILE);
    }

    /**
     * Trains a classifier.
     *
     * @param directory    directory containing one sub-directory per category
     * @param stopWordFile text file with space-separated words to ignore
     * @throws Exception if the stop-word file or any training file cannot be read
     */
    NaiveBayesClassifier(String directory, String stopWordFile) throws Exception {
        try (BufferedReader br = new BufferedReader(new FileReader(stopWordFile))) {
            String line;
            while ((line = br.readLine()) != null) {
                for (String word : line.split(" ")) {
                    // Tokens are lower-cased before the lookup, so store lower case too;
                    // otherwise an entry such as "I" would never match anything.
                    unnecessary.add(word.toLowerCase(Locale.ROOT));
                }
            }
        }
        readInput(directory);
    }

    /**
     * Reads all categories below {@code directory}, trains on the training share
     * of each and computes the word probabilities.
     *
     * @throws IllegalArgumentException if {@code directory} is not a listable directory
     * @throws IOException              if a category directory cannot be listed
     */
    private void readInput(String directory) throws Exception {
        File dir = new File(directory);
        File[] files = dir.listFiles();
        if (files == null) {
            throw new IllegalArgumentException("Not a readable directory: " + directory);
        }
        // listFiles() guarantees no order; sort so the train/test split is reproducible.
        Arrays.sort(files);

        // The learned state is static: start from scratch so that constructing a
        // second classifier does not add to the counts of the first one.
        globalVocab.clear();
        nExamples = 0;
        input = new ArrayList<>();

        for (File file : files) {
            if (!file.isDirectory()) {
                continue;
            }
            Attribute att = new Attribute(file.getName());
            // Only regular files are examples; a nested directory cannot be read as text.
            att.files = file.listFiles(File::isFile);
            if (att.files == null) {
                throw new IOException("Cannot list examples in " + file);
            }
            Arrays.sort(att.files);
            attributes.add(att);
            input.add(new ArrayList<>(Arrays.asList(att.files)));
            nExamples += att.files.length;
            att.processInput();
        }
        filterVocabulary();
        for (Attribute att : attributes) {
            att.calculateProbabilities();
        }
    }

    /**
     * Hook for pruning the vocabulary; intentionally does nothing.
     *
     * <p>Two pruning strategies were tried and are disabled:
     * <ul>
     *   <li>removing the ~10 most frequent tokens (skipping tokens that are part of a
     *       category name) - very common words such as "I", "the", "for", "of" are
     *       already excluded through the {@link #unnecessary} stop-word set, and
     *       removing further frequent tokens can drop words that matter for
     *       classification;</li>
     *   <li>removing tokens that occur fewer than 3 times.</li>
     * </ul>
     * Both would operate on {@code globalVocab} sorted by descending count.
     */
    private void filterVocabulary() {
        // No pruning: the stop-word set already removes the uninformative words.
    }

    /**
     * Evaluates the classifier on the files that were not used for training
     * (everything after the first {@link #trainUpto} fraction of each category),
     * printing every verdict, the misclassification count per category in
     * descending order, and the overall accuracy.
     *
     * @throws Exception if a test file cannot be read
     */
    public void classifyRest() throws Exception {
        int correct = 0;
        int total = 0;
        HashMap<String, Integer> misclassifications = new HashMap<>();

        for (int i = 0; i < input.size(); ++i) {
            ArrayList<File> examples = input.get(i);
            int length = examples.size();
            int firstTest = trainingCount(length);
            total += length - firstTest;
            for (int j = firstTest; j < length; ++j) {
                int maxAttr = classify(examples.get(j));

                System.out.println();
                System.out.println("Verdict : " + attributes.get(maxAttr).name);
                String cls = attributes.get(i).name;
                System.out.println("Actual classification : " + cls);
                // Compare indices: category names are not regular expressions.
                if (maxAttr == i) {
                    correct++;
                } else {
                    misclassifications.merge(cls, 1, Integer::sum);
                }
            }
        }

        List<Entry<String, Integer>> entries = new ArrayList<>(misclassifications.entrySet());
        entries.sort(Entry.comparingByValue(Comparator.reverseOrder()));

        System.out.println("\n\n*** Failed classifications -> Highest to lowest: \n");
        for (Entry<String, Integer> entry : entries) {
            System.out.println(entry.getKey() + " : " + entry.getValue());
        }
        System.out.println("\nTest Evaluation finished...\n\nAccuracy : " + ((double) correct * 100 / total)
                + "%. Correct/Total = " + correct + "/" + total);
    }

    /**
     * Returns the index (into {@link #attributes}) of the category with the highest
     * log-posterior for {@code file}. Tokens outside the vocabulary are ignored;
     * on ties the category with the lowest index wins.
     */
    private int classify(File file) throws IOException {
        double[] scores = new double[attributes.size()];
        for (int k = 0; k < scores.length; ++k) {
            scores[k] = Math.log(attributes.get(k).p_vj);
        }
        try (BufferedReader br = new BufferedReader(new FileReader(file))) {
            String line;
            while ((line = br.readLine()) != null) {
                for (String token : line.split(" ")) {
                    if (!globalVocab.containsKey(token)) {
                        continue;
                    }
                    // Sum logs instead of multiplying probabilities to avoid underflow.
                    for (int k = 0; k < scores.length; ++k) {
                        scores[k] += Math.log(attributes.get(k).hm.get(token));
                    }
                }
            }
        }
        int best = 0;
        for (int k = 1; k < scores.length; ++k) {
            if (scores[k] > scores[best]) {
                best = k;
            }
        }
        return best;
    }

    public static void main(String[] args) throws Exception {
        NaiveBayesClassifier nb = new NaiveBayesClassifier("D:\\20_newsgroups\\");
        nb.classifyRest();
    }
}

Output:

Attribute : alt.atheism
Attribute : comp.graphics
Attribute : comp.os.ms-windows.misc
Attribute : comp.sys.ibm.pc.hardware
Attribute : comp.sys.mac.hardware
Attribute : comp.windows.x
Attribute : misc.forsale
Attribute : rec.autos
Attribute : rec.motorcycles
Attribute : rec.sport.baseball
Attribute : rec.sport.hockey
Attribute : sci.crypt
Attribute : sci.electronics
Attribute : sci.med
Attribute : sci.space
Attribute : soc.religion.christian
Attribute : talk.politics.guns
Attribute : talk.politics.mideast
Attribute : talk.politics.misc
Attribute : talk.religion.misc
No of words with attribute 'alt.atheism' : 21813 GlobalVocab.size() = 186600
No of words with attribute 'comp.graphics' : 21223 GlobalVocab.size() = 186600
No of words with attribute 'comp.os.ms-windows.misc' : 16947 GlobalVocab.size() = 186600
No of words with attribute 'comp.sys.ibm.pc.hardware' : 17437 GlobalVocab.size() = 186600
No of words with attribute 'comp.sys.mac.hardware' : 16351 GlobalVocab.size() = 186600
No of words with attribute 'comp.windows.x' : 20340 GlobalVocab.size() = 186600
No of words with attribute 'misc.forsale' : 17809 GlobalVocab.size() = 186600
No of words with attribute 'rec.autos' : 18653 GlobalVocab.size() = 186600
No of words with attribute 'rec.motorcycles' : 18013 GlobalVocab.size() = 186600
No of words with attribute 'rec.sport.baseball' : 17296 GlobalVocab.size() = 186600
No of words with attribute 'rec.sport.hockey' : 19310 GlobalVocab.size() = 186600
No of words with attribute 'sci.crypt' : 17595 GlobalVocab.size() = 186600
No of words with attribute 'sci.electronics' : 18347 GlobalVocab.size() = 186600
No of words with attribute 'sci.med' : 24897 GlobalVocab.size() = 186600
No of words with attribute 'sci.space' : 23250 GlobalVocab.size() = 186600
No of words with attribute 'soc.religion.christian' : 24635 GlobalVocab.size() = 186600
No of words with attribute 'talk.politics.guns' : 24724 GlobalVocab.size() = 186600
No of words with attribute 'talk.politics.mideast' : 28232 GlobalVocab.size() = 186600
No of words with attribute 'talk.politics.misc' : 27019 GlobalVocab.size() = 186600
No of words with attribute 'talk.religion.misc' : 22620 GlobalVocab.size() = 186600

Verdict : talk.religion.misc
Actual classification : alt.atheism

Verdict : alt.atheism
Actual classification : alt.atheism

Verdict : soc.religion.christian
Actual classification : alt.atheism

Verdict : talk.religion.misc
Actual classification : alt.atheism

Verdict : comp.windows.x
Actual classification : alt.atheism

Verdict : talk.politics.mideast
Actual classification : alt.atheism

Verdict : talk.religion.misc
Actual classification : alt.atheism

Verdict : alt.atheism
Actual classification : alt.atheism

Verdict : talk.politics.misc
Actual classification : alt.atheism

Verdict : alt.atheism
Actual classification : alt.atheism

Verdict : alt.atheism
Actual classification : alt.atheism

Verdict : talk.politics.mideast
Actual classification : alt.atheism

Verdict : talk.politics.mideast
Actual classification : alt.atheism

Verdict : alt.atheism
Actual classification : alt.atheism

Verdict : talk.religion.misc
Actual classification : alt.atheism

Verdict : alt.atheism
Actual classification : alt.atheism

Verdict : alt.atheism
Actual classification : alt.atheism

Verdict : alt.atheism
Actual classification : alt.atheism

(............................................................)

Verdict : talk.politics.misc
Actual classification : talk.religion.misc

Verdict : talk.religion.misc
Actual classification : talk.religion.misc

Verdict : talk.religion.misc
Actual classification : talk.religion.misc

Verdict : talk.religion.misc
Actual classification : talk.religion.misc

Verdict : talk.politics.misc
Actual classification : talk.religion.misc

Verdict : talk.religion.misc
Actual classification : talk.religion.misc


*** Failed classifications -> Highest to lowest: 

talk.religion.misc : 129
misc.forsale : 104
alt.atheism : 98
sci.electronics : 86
talk.politics.misc : 72
comp.os.ms-windows.misc : 68
sci.med : 52
comp.graphics : 48
comp.sys.mac.hardware : 45
comp.sys.ibm.pc.hardware : 39
sci.crypt : 38
sci.space : 37
rec.sport.baseball : 29
rec.motorcycles : 29
comp.windows.x : 28
rec.autos : 26
talk.politics.guns : 26
rec.sport.hockey : 8
talk.politics.mideast : 7
soc.religion.christian : 2

Test Evaluation finished...

Accuracy : 83.41305090536386%. Correct/Total = 4883/5854

No comments:

Post a Comment