/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.classification.document;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
import org.apache.lucene.classification.document.DocumentClassifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.util.BytesRef;

public class SimpleNaiveBayesDocumentClassifier
extends SimpleNaiveBayesClassifier
implements DocumentClassifier<BytesRef> {
    /**
     * Analyzer to use for each (boost-stripped) text field name when tokenizing the
     * seed document. A missing entry is passed to {@code IndexableField#tokenStream}
     * as {@code null}.
     */
    protected Map<String, Analyzer> field2analyzer;

    /**
     * Creates a document classifier.
     *
     * @param indexReader    reader over the training index
     * @param query          optional filter query restricting training documents (may be {@code null})
     * @param classFieldName name of the field holding the class labels
     * @param field2analyzer analyzer per text field name, used to tokenize input documents
     * @param textFieldNames text fields to use; each may carry a {@code ^boost} suffix
     *                       (e.g. {@code "title^2"}) that weights that field's log-likelihood
     */
    public SimpleNaiveBayesDocumentClassifier(IndexReader indexReader, Query query, String classFieldName, Map<String, Analyzer> field2analyzer, String ... textFieldNames) {
        super(indexReader, null, query, classFieldName, textFieldNames);
        this.field2analyzer = field2analyzer;
    }

    /**
     * Returns the class with the highest normalized score for the given document.
     *
     * @param document the document to classify
     * @return the best-scoring class, or {@code null} if no class could be scored
     * @throws IOException if reading the index fails
     */
    @Override
    public ClassificationResult<BytesRef> assignClass(Document document) throws IOException {
        ClassificationResult<BytesRef> best = null;
        double bestScore = -Double.MAX_VALUE;
        for (ClassificationResult<BytesRef> candidate : this.assignNormClasses(document)) {
            // strict '>' keeps the first of equal scores and skips NaN scores
            if (candidate.getScore() > bestScore) {
                best = candidate;
                bestScore = candidate.getScore();
            }
        }
        return best;
    }

    /**
     * Returns all classes with their normalized scores, sorted by the natural
     * ordering of {@link ClassificationResult}.
     *
     * @param document the document to classify
     * @return the sorted list of scored classes (empty if the index has no classes)
     * @throws IOException if reading the index fails
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException {
        List<ClassificationResult<BytesRef>> assignedClasses = this.assignNormClasses(document);
        Collections.sort(assignedClasses);
        return assignedClasses;
    }

    /**
     * Returns at most {@code max} classes with their normalized scores, sorted by the
     * natural ordering of {@link ClassificationResult}.
     *
     * @param document the document to classify
     * @param max      maximum number of results; larger values than the number of
     *                 available classes simply return all classes
     * @return the first {@code min(max, #classes)} sorted results
     * @throws IOException if reading the index fails
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException {
        List<ClassificationResult<BytesRef>> assignedClasses = this.assignNormClasses(document);
        Collections.sort(assignedClasses);
        // clamp: subList would throw if max exceeded the number of classes
        return assignedClasses.subList(0, Math.min(max, assignedClasses.size()));
    }

    /**
     * Scores every class value found in the class field and normalizes the scores
     * via the superclass's {@code normClassificationResults}.
     *
     * <p>For each class, each text field value contributes
     * {@code logPrior + meanLogLikelihood * fieldBoost}; contributions are summed.
     */
    private List<ClassificationResult<BytesRef>> assignNormClasses(Document inputDocument) throws IOException {
        Terms classes = MultiFields.getTerms(this.indexReader, this.classFieldName);
        if (classes == null) {
            // no indexed document has the class field: there is nothing to assign
            return new ArrayList<>();
        }
        Map<String, List<String[]>> fieldName2tokensArray = new LinkedHashMap<>();
        Map<String, Float> fieldName2boost = new LinkedHashMap<>();
        List<String> fieldNames = this.analyzeSeedDocument(inputDocument, fieldName2tokensArray, fieldName2boost);
        int docsWithClassSize = this.countDocsWithClass();

        List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
        TermsEnum classesEnum = classes.iterator();
        BytesRef next;
        while ((next = classesEnum.next()) != null) {
            // TermsEnum may reuse the returned BytesRef and Term does not copy it,
            // so take a private copy before storing it in a result
            Term term = new Term(this.classFieldName, BytesRef.deepCopyOf(next));
            double logPrior = this.calculateLogPrior(term, docsWithClassSize);
            double classScore = 0.0;
            for (String fieldName : fieldNames) {
                List<String[]> tokensArrays = fieldName2tokensArray.get(fieldName);
                double boost = fieldName2boost.get(fieldName);
                double fieldScore = 0.0;
                for (String[] fieldTokensArray : tokensArrays) {
                    fieldScore += logPrior + this.calculateLogLikelihood(fieldTokensArray, fieldName, term, docsWithClassSize) * boost;
                }
                classScore += fieldScore;
            }
            assignedClasses.add(new ClassificationResult<>(term.bytes(), classScore));
        }
        return this.normClassificationResults(assignedClasses);
    }

    /**
     * Tokenizes every configured text field of the seed document.
     *
     * <p>Field specs of the form {@code name^boost} are split into the field name and a
     * float boost (default 1.0). The configured {@code textFieldNames} array is NOT
     * modified, so boosts are honoured on every call.
     *
     * @param inputDocument         document to analyze
     * @param fieldName2tokensArray filled with field name -&gt; token arrays (one per field value)
     * @param fieldName2boost       filled with field name -&gt; boost
     * @return the boost-stripped field names, in the same order as {@code textFieldNames}
     * @throws IOException if analysis fails
     */
    private List<String> analyzeSeedDocument(Document inputDocument, Map<String, List<String[]>> fieldName2tokensArray, Map<String, Float> fieldName2boost) throws IOException {
        List<String> resolvedNames = new ArrayList<>(this.textFieldNames.length);
        for (String fieldSpec : this.textFieldNames) {
            String fieldName = fieldSpec;
            float boost = 1.0f;
            if (fieldSpec.contains("^")) {
                String[] field2boost = fieldSpec.split("\\^");
                fieldName = field2boost[0];
                boost = Float.parseFloat(field2boost[1]);
            }
            List<String[]> tokenizedValues = new ArrayList<>();
            for (IndexableField fieldValue : inputDocument.getFields(fieldName)) {
                TokenStream fieldTokens = fieldValue.tokenStream(this.field2analyzer.get(fieldName), null);
                tokenizedValues.add(this.getTokenArray(fieldTokens));
            }
            fieldName2tokensArray.put(fieldName, tokenizedValues);
            fieldName2boost.put(fieldName, boost);
            resolvedNames.add(fieldName);
        }
        return resolvedNames;
    }

    /**
     * Consumes a token stream and returns its terms in order.
     *
     * <p>The stream is reset, fully consumed, ended and closed, even when analysis throws.
     *
     * @param tokenizedText the stream to consume; closed on return
     * @return the term texts, possibly empty
     * @throws IOException if the stream fails
     */
    protected String[] getTokenArray(TokenStream tokenizedText) throws IOException {
        List<String> tokens = new ArrayList<>();
        try (TokenStream stream = tokenizedText) {
            CharTermAttribute charTermAttribute = stream.addAttribute(CharTermAttribute.class);
            stream.reset();
            while (stream.incrementToken()) {
                tokens.add(charTermAttribute.toString());
            }
            stream.end();
        }
        return tokens.toArray(new String[0]);
    }

    /**
     * Mean log-probability of the given tokens for a class, with add-one smoothing.
     *
     * @return the mean log-likelihood, or 0.0 for an empty token array (instead of the
     *         NaN a 0/0 average would produce, which would poison the class score)
     */
    private double calculateLogLikelihood(String[] tokenizedText, String fieldName, Term term, int docsWithClass) throws IOException {
        if (tokenizedText.length == 0) {
            return 0.0;
        }
        // the denominator depends only on (term, fieldName): compute it once, not per word
        double den = this.getTextTermFreqForClass(term, fieldName) + (double) docsWithClass;
        double result = 0.0;
        for (String word : tokenizedText) {
            double num = this.getWordFreqForClass(word, fieldName, term) + 1;
            result += Math.log(num / den);
        }
        return result / (double) tokenizedText.length;
    }

    /**
     * Estimates the number of term occurrences in {@code fieldName} for documents of the
     * given class: (sum of per-document unique terms / docs with the field) * docs in class.
     *
     * @return the estimate, or 0.0 if the field has no indexed terms
     */
    private double getTextTermFreqForClass(Term term, String fieldName) throws IOException {
        Terms terms = MultiFields.getTerms(this.indexReader, fieldName);
        if (terms == null) {
            // field absent from the index: no terms to average over
            return 0.0;
        }
        double avgNumberOfUniqueTerms = (double) terms.getSumDocFreq() / (double) terms.getDocCount();
        int docsWithC = this.indexReader.docFreq(term);
        return avgNumberOfUniqueTerms * (double) docsWithC;
    }

    /**
     * Counts documents that contain {@code word} in {@code fieldName}, belong to the class
     * {@code term}, and (if set) match the filter query.
     */
    private int getWordFreqForClass(String word, String fieldName, Term term) throws IOException {
        BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
        booleanQuery.add(new TermQuery(new Term(fieldName, word)), BooleanClause.Occur.MUST);
        booleanQuery.add(new TermQuery(term), BooleanClause.Occur.MUST);
        if (this.query != null) {
            booleanQuery.add(this.query, BooleanClause.Occur.MUST);
        }
        TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
        this.indexSearcher.search(booleanQuery.build(), totalHitCountCollector);
        return totalHitCountCollector.getTotalHits();
    }

    /** Log of the fraction of class-labelled documents that carry the class {@code term}. */
    private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException {
        return Math.log(this.docCount(term)) - Math.log(docsWithClassSize);
    }

    /** Number of documents containing {@code term}. */
    private int docCount(Term term) throws IOException {
        return this.indexReader.docFreq(term);
    }
}

