/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureCounter;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.Labeling;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.util.Randoms;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.zip.GZIPOutputStream;

/**
 * Collapsed Gibbs sampler for LDA that stores word-topic counts in a packed ("coded")
 * sparse representation.
 *
 * <p>For each word type, {@code typeTopicCounts[type]} holds entries of the form
 * {@code (count << topicBits) + topic}: the topic is recovered with {@code & topicMask}
 * and the count with {@code >> topicBits}. Live entries occupy a prefix of the array and
 * unused slots are zero. Every increment or decrement re-sorts the touched entry against
 * its neighbours by packed value, so higher counts move toward the front.
 *
 * <p>Each token's unnormalized sampling mass is split into a smoothing-only part, a
 * document ("topic-beta") part and a topic-term part, each maintained incrementally.
 * NOTE(review): this looks like the SparseLDA decomposition (Yao, Mimno and McCallum).
 *
 * <p>No synchronization is used; all sampling state lives in mutable instance fields.
 */
public class CodingTest implements Serializable {
    /** One entry per training document: the instance plus its current topic assignments. */
    protected ArrayList<TopicAssignment> data = new ArrayList<TopicAssignment>();
    /** Word-type alphabet, taken from the training data in {@link #addInstances}. */
    protected Alphabet alphabet;
    /** Alphabet naming the topics; its size defines {@link #numTopics}. */
    protected LabelAlphabet topicAlphabet;
    protected int numTopics;
    /** Low-bit mask extracting the topic from a packed type-topic entry. */
    protected int topicMask;
    /** Number of low bits reserved for the topic in a packed type-topic entry. */
    protected int topicBits;
    protected int numTypes;
    /** Per-topic Dirichlet prior over document-topic proportions. */
    protected double[] alpha;
    /** Sum of {@link #alpha}. */
    protected double alphaSum;
    /** Symmetric Dirichlet prior over topic-word distributions. */
    protected double beta;
    /** {@code beta * numTypes}; set in {@link #addInstances}. */
    protected double betaSum;
    public static final double DEFAULT_BETA = 0.01;
    /** Running sum over topics of {@code alpha[t] * beta / (tokensPerTopic[t] + betaSum)}. */
    protected double smoothingOnlyMass = 0.0;
    /**
     * Per-topic coefficient {@code (alpha[t] + docCount[t]) / (tokensPerTopic[t] + betaSum)}.
     * Between documents the document count is treated as zero.
     */
    protected double[] cachedCoefficients;
    // The three counters below are reset every iteration but never incremented in this class.
    int topicTermCount = 0;
    int betaTopicCount = 0;
    int smoothingOnlyCount = 0;
    /** Held-out instances; stored by {@link #setTestingInstances} but not used in this class. */
    protected InstanceList testing = null;
    // Allocated but not otherwise used in this class.
    protected int[] oneDocTopicCounts;
    /** Packed (count, topic) entries per word type; see the class comment. */
    protected int[][] typeTopicCounts;
    /** Total number of tokens currently assigned to each topic. */
    protected int[] tokensPerTopic;
    /** {@code docLengthCounts[len]}: number of sampled documents of length {@code len}. */
    protected int[] docLengthCounts;
    /** {@code topicDocCounts[t][n]}: number of sampled documents with exactly {@code n} tokens in topic {@code t}. */
    protected int[][] topicDocCounts;
    protected int iterationsSoFar = 0;
    public int numIterations = 1000;
    public int burninPeriod = 20;
    public int saveSampleInterval = 5;
    public int optimizeInterval = 20;
    public int showTopicsInterval = 10;
    public int wordsPerTopic = 7;
    protected int outputModelInterval = 0;
    protected String outputModelFilename;
    protected int saveStateInterval = 0;
    protected String stateFilename = null;
    protected Randoms random;
    protected NumberFormat formatter;
    /** Set by {@link #main}; the log-likelihood is currently computed every iteration regardless. */
    protected boolean printLogLikelihood = false;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    /**
     * Creates a model with {@code alphaSum = numberOfTopics} and {@link #DEFAULT_BETA}.
     *
     * @param numberOfTopics number of topics
     */
    public CodingTest(int numberOfTopics) {
        this(numberOfTopics, numberOfTopics, DEFAULT_BETA);
    }

    /**
     * Creates a model with a fresh, unseeded random source.
     *
     * @param numberOfTopics number of topics
     * @param alphaSum total document-topic prior mass, split evenly across topics
     * @param beta topic-word prior
     */
    public CodingTest(int numberOfTopics, double alphaSum, double beta) {
        this(numberOfTopics, alphaSum, beta, new Randoms());
    }

    /** Builds a topic alphabet with labels {@code "topic0" .. "topic<numTopics-1>"}. */
    private static LabelAlphabet newLabelAlphabet(int numTopics) {
        LabelAlphabet ret = new LabelAlphabet();
        for (int i = 0; i < numTopics; ++i) {
            ret.lookupIndex("topic" + i);
        }
        return ret;
    }

    /**
     * Creates a model with generated topic labels.
     *
     * @param numberOfTopics number of topics
     * @param alphaSum total document-topic prior mass, split evenly across topics
     * @param beta topic-word prior
     * @param random random source used for initialization and sampling
     */
    public CodingTest(int numberOfTopics, double alphaSum, double beta, Randoms random) {
        this(CodingTest.newLabelAlphabet(numberOfTopics), alphaSum, beta, random);
    }

    /**
     * Creates a model whose topic count is the size of {@code topicAlphabet}.
     *
     * @param topicAlphabet alphabet naming the topics
     * @param alphaSum total document-topic prior mass, split evenly across topics
     * @param beta topic-word prior
     * @param random random source used for initialization and sampling
     */
    public CodingTest(LabelAlphabet topicAlphabet, double alphaSum, double beta, Randoms random) {
        this.topicAlphabet = topicAlphabet;
        this.numTopics = topicAlphabet.size();
        // Reserve enough low bits to hold every topic index 0..numTopics-1.
        if (Integer.bitCount(this.numTopics) == 1) {
            // Exact power of two: numTopics - 1 is already an all-ones mask.
            this.topicMask = this.numTopics - 1;
        } else {
            this.topicMask = Integer.highestOneBit(this.numTopics) * 2 - 1;
        }
        this.topicBits = Integer.bitCount(this.topicMask);
        this.alphaSum = alphaSum;
        this.alpha = new double[this.numTopics];
        Arrays.fill(this.alpha, alphaSum / (double)this.numTopics);
        this.beta = beta;
        this.random = random;
        this.oneDocTopicCounts = new int[this.numTopics];
        this.tokensPerTopic = new int[this.numTopics];
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(5);
        System.err.println("Coded LDA: " + this.numTopics + " topics, " + this.topicBits + " topic bits, " + Integer.toBinaryString(this.topicMask) + " topic mask");
    }

    public Alphabet getAlphabet() {
        return this.alphabet;
    }

    public LabelAlphabet getTopicAlphabet() {
        return this.topicAlphabet;
    }

    public int getNumTopics() {
        return this.numTopics;
    }

    /** Returns the live list of per-document topic assignments (not a copy). */
    public ArrayList<TopicAssignment> getData() {
        return this.data;
    }

    public void setTestingInstances(InstanceList testing) {
        this.testing = testing;
    }

    public void setNumIterations(int numIterations) {
        this.numIterations = numIterations;
    }

    public void setBurninPeriod(int burninPeriod) {
        this.burninPeriod = burninPeriod;
    }

    /**
     * @param interval stored in {@link #showTopicsInterval}
     * @param n stored in {@link #wordsPerTopic}
     */
    public void setTopicDisplay(int interval, int n) {
        this.showTopicsInterval = interval;
        this.wordsPerTopic = n;
    }

    /** Replaces the random source with one seeded by {@code seed}. */
    public void setRandomSeed(int seed) {
        this.random = new Randoms(seed);
    }

    public void setOptimizeInterval(int interval) {
        this.optimizeInterval = interval;
    }

    public void setModelOutput(int interval, String filename) {
        this.outputModelInterval = interval;
        this.outputModelFilename = filename;
    }

    /**
     * Enables periodic state dumps from {@link #estimate(int)}.
     *
     * @param interval save every {@code interval} iterations (0 disables)
     * @param filename prefix; the iteration number is appended after a '.'
     */
    public void setSaveState(int interval, String filename) {
        this.saveStateInterval = interval;
        this.stateFilename = filename;
    }

    /** Returns the live packed type-topic arrays (see the class comment for the encoding). */
    public int[][] getTypeTopicCounts() {
        return this.typeTopicCounts;
    }

    /** Returns the live per-topic token totals. */
    public int[] getTopicTotals() {
        return this.tokensPerTopic;
    }

    /**
     * Loads training documents and assigns every token a uniformly random topic.
     *
     * <p>Replaces {@link #alphabet}, {@link #numTypes}, {@link #betaSum} and
     * {@link #typeTopicCounts}. Documents are appended to {@link #data} and counts are
     * added to {@link #tokensPerTopic}.
     *
     * @param training instances whose data are {@code FeatureSequence}s
     */
    public void addInstances(InstanceList training) {
        this.alphabet = training.getDataAlphabet();
        this.numTypes = this.alphabet.size();
        this.betaSum = this.beta * (double)this.numTypes;
        this.typeTopicCounts = new int[this.numTypes][];

        // First pass: count tokens per type. A type can never hold more distinct topics
        // than it has tokens, nor more than numTopics, which bounds its packed array.
        int[] typeTotals = new int[this.numTypes];
        for (Instance instance : training) {
            FeatureSequence tokens = (FeatureSequence)instance.getData();
            for (int position = 0; position < tokens.getLength(); ++position) {
                ++typeTotals[tokens.getIndexAtPosition(position)];
            }
        }
        for (int type = 0; type < this.numTypes; ++type) {
            this.typeTopicCounts[type] = new int[Math.min(this.numTopics, typeTotals[type])];
        }

        // Second pass: random initial topic per token, recorded in all count structures.
        for (Instance instance : training) {
            FeatureSequence tokens = (FeatureSequence)instance.getData();
            LabelSequence topicSequence = new LabelSequence(this.topicAlphabet, new int[tokens.size()]);
            int[] topics = topicSequence.getFeatures();
            for (int position = 0; position < tokens.size(); ++position) {
                int topic = this.random.nextInt(this.numTopics);
                topics[position] = topic;
                ++this.tokensPerTopic[topic];
                int type = tokens.getIndexAtPosition(position);
                this.incrementTypeTopic(this.typeTopicCounts[type], topic);
            }
            this.data.add(new TopicAssignment(instance, this, topicSequence));
        }
        this.initializeHistogramsAndCachedValues();
    }

    /**
     * Adds one occurrence of {@code topic} to a packed type-topic array. If the topic is
     * absent, a count-one entry goes into the first empty slot. Otherwise the existing
     * entry is incremented and moved toward the front while it exceeds its predecessor.
     *
     * @throws IllegalStateException if the topic is absent and no empty slot remains
     */
    private void incrementTypeTopic(int[] counts, int topic) {
        int index = 0;
        while (index < counts.length && counts[index] > 0 && (counts[index] & this.topicMask) != topic) {
            ++index;
        }
        if (index == counts.length) {
            // The array is sized so that this cannot happen while counts are consistent.
            throw new IllegalStateException("no free slot for topic " + topic + " in a type-topic array of length " + counts.length);
        }
        if (counts[index] == 0) {
            counts[index] = (1 << this.topicBits) + topic;
            return;
        }
        int newCount = (counts[index] >> this.topicBits) + 1;
        counts[index] = (newCount << this.topicBits) + topic;
        bubbleUp(counts, index);
    }

    /** Moves {@code counts[index]} toward the front while it is larger than its predecessor. */
    private static void bubbleUp(int[] counts, int index) {
        while (index > 0 && counts[index] > counts[index - 1]) {
            int swap = counts[index];
            counts[index] = counts[index - 1];
            counts[index - 1] = swap;
            --index;
        }
    }

    /**
     * Recomputes {@link #smoothingOnlyMass} and the document-free
     * {@link #cachedCoefficients} from the current alpha and topic totals.
     * {@code cachedCoefficients} must already be allocated.
     */
    private void resetSmoothingMassAndCoefficients() {
        this.smoothingOnlyMass = 0.0;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            double denominator = (double)this.tokensPerTopic[topic] + this.betaSum;
            this.smoothingOnlyMass += this.alpha[topic] * this.beta / denominator;
            this.cachedCoefficients[topic] = this.alpha[topic] / denominator;
        }
    }

    /**
     * Initializes the cached sampling terms and allocates the document-length and
     * topic-document histograms, sized by the longest document in {@link #data}.
     */
    private void initializeHistogramsAndCachedValues() {
        int maxTokens = 0;
        int totalTokens = 0;
        for (TopicAssignment document : this.data) {
            int seqLen = ((FeatureSequence)document.instance.getData()).getLength();
            maxTokens = Math.max(maxTokens, seqLen);
            totalTokens += seqLen;
        }
        this.cachedCoefficients = new double[this.numTopics];
        this.resetSmoothingMassAndCoefficients();
        System.err.println("max tokens: " + maxTokens);
        System.err.println("total tokens: " + totalTokens);
        this.docLengthCounts = new int[maxTokens + 1];
        this.topicDocCounts = new int[this.numTopics][maxTokens + 1];
    }

    /** Runs {@link #estimate(int)} with {@link #numIterations}. */
    public void estimate() throws IOException {
        this.estimate(this.numIterations);
    }

    /**
     * Runs Gibbs sweeps over all documents.
     *
     * <p>Each sweep may do the following, in order:
     * <ul>
     *   <li>dump the sampler state (every {@link #saveStateInterval} iterations);</li>
     *   <li>re-estimate alpha after burn-in (every {@link #optimizeInterval} iterations);</li>
     *   <li>resample every token.</li>
     * </ul>
     * After each sweep it prints one tab-separated line to stdout: sweep milliseconds,
     * cumulative milliseconds, model log-likelihood, and used heap bytes.
     *
     * <p>NOTE(review): the loop bound is inclusive, so this runs
     * {@code iterationsThisRound + 1} sweeps; kept unchanged for compatibility.
     *
     * @param iterationsThisRound number of additional iterations to run
     * @throws IOException if writing a state file fails
     */
    public void estimate(int iterationsThisRound) throws IOException {
        int maxIteration = this.iterationsSoFar + iterationsThisRound;
        long totalTime = 0L;
        while (this.iterationsSoFar <= maxIteration) {
            long iterationStart = System.currentTimeMillis();
            if (this.saveStateInterval != 0 && this.iterationsSoFar % this.saveStateInterval == 0) {
                this.printState(new File(this.stateFilename + '.' + this.iterationsSoFar));
            }
            if (this.iterationsSoFar > this.burninPeriod && this.optimizeInterval != 0 && this.iterationsSoFar % this.optimizeInterval == 0) {
                this.optimizeAlpha();
            }
            this.smoothingOnlyCount = 0;
            this.betaTopicCount = 0;
            this.topicTermCount = 0;
            // Histograms for alpha optimization are only collected after burn-in.
            boolean saveSample = this.iterationsSoFar >= this.burninPeriod && this.saveSampleInterval != 0 && this.iterationsSoFar % this.saveSampleInterval == 0;
            for (TopicAssignment document : this.data) {
                FeatureSequence tokenSequence = (FeatureSequence)document.instance.getData();
                this.sampleTopicsForOneDoc(tokenSequence, document.topicSequence, saveSample, true);
            }
            long elapsedMillis = System.currentTimeMillis() - iterationStart;
            totalTime += elapsedMillis;
            double ll = this.modelLogLikelihood();
            Runtime runtime = Runtime.getRuntime();
            long usedMemory = runtime.totalMemory() - runtime.freeMemory();
            System.out.println(elapsedMillis + "\t" + totalTime + "\t" + ll + "\t" + usedMemory);
            ++this.iterationsSoFar;
        }
    }

    /**
     * Re-estimates {@link #alpha} from the collected histograms via
     * {@code Dirichlet.learnParameters}, refreshes the cached terms, and clears the histograms.
     * NOTE(review): assumes learnParameters updates {@code alpha} in place and returns its sum.
     */
    private void optimizeAlpha() {
        this.alphaSum = Dirichlet.learnParameters(this.alpha, this.topicDocCounts, this.docLengthCounts);
        this.resetSmoothingMassAndCoefficients();
        this.clearHistograms();
    }

    /** Zeroes the document-length and topic-document histograms. */
    private void clearHistograms() {
        Arrays.fill(this.docLengthCounts, 0);
        for (int topic = 0; topic < this.topicDocCounts.length; ++topic) {
            Arrays.fill(this.topicDocCounts[topic], 0);
        }
    }

    /**
     * Resamples the topic of every token in one document, updating all global counts and
     * cached sampling terms incrementally.
     *
     * @param tokenSequence word types of the document
     * @param topicSequence current topic per token; updated in place
     * @param shouldSaveState if true, adds this document to the alpha-optimization histograms
     * @param readjustTopicsAndStats currently unused
     */
    protected void sampleTopicsForOneDoc(FeatureSequence tokenSequence, FeatureSequence topicSequence, boolean shouldSaveState, boolean readjustTopicsAndStats) {
        int[] oneDocTopics = topicSequence.getFeatures();
        int docLength = tokenSequence.getLength();

        // Per-document topic counts, plus an ascending list of the topics that occur.
        int[] localTopicCounts = new int[this.numTopics];
        int[] localTopicIndex = new int[this.numTopics];
        for (int position = 0; position < docLength; ++position) {
            ++localTopicCounts[oneDocTopics[position]];
        }
        int nonZeroTopics = 0;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            if (localTopicCounts[topic] != 0) {
                localTopicIndex[nonZeroTopics++] = topic;
            }
        }

        // Document ("topic-beta") mass and document-specific coefficients.
        double topicBetaMass = 0.0;
        for (int denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
            int topic = localTopicIndex[denseIndex];
            int n = localTopicCounts[topic];
            topicBetaMass += this.beta * (double)n / ((double)this.tokensPerTopic[topic] + this.betaSum);
            this.cachedCoefficients[topic] = (this.alpha[topic] + (double)n) / ((double)this.tokensPerTopic[topic] + this.betaSum);
        }

        // Indexed like the packed type-topic array of the current token's type.
        double[] topicTermScores = new double[this.numTopics];
        for (int position = 0; position < docLength; ++position) {
            int type = tokenSequence.getIndexAtPosition(position);
            int oldTopic = oneDocTopics[position];
            int[] currentTypeTopicCounts = this.typeTopicCounts[type];

            // Remove the token's current assignment, keeping the cached masses in sync.
            this.smoothingOnlyMass -= this.alpha[oldTopic] * this.beta / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
            topicBetaMass -= this.beta * (double)localTopicCounts[oldTopic] / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
            --localTopicCounts[oldTopic];
            if (localTopicCounts[oldTopic] == 0) {
                // The topic no longer occurs in this document: drop it from the dense list.
                int denseIndex = 0;
                while (localTopicIndex[denseIndex] != oldTopic) {
                    ++denseIndex;
                }
                while (denseIndex < nonZeroTopics) {
                    if (denseIndex < localTopicIndex.length - 1) {
                        localTopicIndex[denseIndex] = localTopicIndex[denseIndex + 1];
                    }
                    ++denseIndex;
                }
                --nonZeroTopics;
            }
            --this.tokensPerTopic[oldTopic];
            assert (this.tokensPerTopic[oldTopic] >= 0) : "old Topic " + oldTopic + " below 0";
            this.smoothingOnlyMass += this.alpha[oldTopic] * this.beta / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
            topicBetaMass += this.beta * (double)localTopicCounts[oldTopic] / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
            this.cachedCoefficients[oldTopic] = (this.alpha[oldTopic] + (double)localTopicCounts[oldTopic]) / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);

            // One pass over the type's packed array: decrement oldTopic's entry and score
            // every live entry for the topic-term bucket.
            int index = 0;
            boolean alreadyDecremented = false;
            double topicTermMass = 0.0;
            while (index < currentTypeTopicCounts.length && currentTypeTopicCounts[index] > 0) {
                int currentTopic = currentTypeTopicCounts[index] & this.topicMask;
                int currentValue = currentTypeTopicCounts[index] >> this.topicBits;
                if (!alreadyDecremented && currentTopic == oldTopic) {
                    --currentValue;
                    currentTypeTopicCounts[index] = currentValue == 0 ? 0 : (currentValue << this.topicBits) + oldTopic;
                    // Push the shrunken (or emptied) entry toward the back.
                    for (int subIndex = index; subIndex < currentTypeTopicCounts.length - 1 && currentTypeTopicCounts[subIndex] < currentTypeTopicCounts[subIndex + 1]; ++subIndex) {
                        int swap = currentTypeTopicCounts[subIndex];
                        currentTypeTopicCounts[subIndex] = currentTypeTopicCounts[subIndex + 1];
                        currentTypeTopicCounts[subIndex + 1] = swap;
                    }
                    alreadyDecremented = true;
                    // Re-examine this slot: it may now hold a different entry.
                    continue;
                }
                double score = this.cachedCoefficients[currentTopic] * (double)currentValue;
                topicTermMass += score;
                topicTermScores[index] = score;
                ++index;
            }

            double sample = this.random.nextUniform() * (this.smoothingOnlyMass + topicBetaMass + topicTermMass);
            double origSample = sample;
            int newTopic = -1;
            if (sample < topicTermMass) {
                // Topic-term bucket. A do-while guarantees i >= 0 even when sample == 0.0.
                int i = -1;
                do {
                    sample -= topicTermScores[++i];
                } while (sample > 0.0);
                newTopic = currentTypeTopicCounts[i] & this.topicMask;
                int currentValue = currentTypeTopicCounts[i] >> this.topicBits;
                currentTypeTopicCounts[i] = (currentValue + 1 << this.topicBits) + newTopic;
                bubbleUp(currentTypeTopicCounts, i);
            } else {
                sample -= topicTermMass;
                if (sample < topicBetaMass) {
                    // Document bucket: only topics present in this document.
                    sample /= this.beta;
                    for (int denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
                        int topic = localTopicIndex[denseIndex];
                        sample -= (double)localTopicCounts[topic] / ((double)this.tokensPerTopic[topic] + this.betaSum);
                        if (sample <= 0.0) {
                            newTopic = topic;
                            break;
                        }
                    }
                } else {
                    // Smoothing-only bucket: all topics.
                    sample -= topicBetaMass;
                    sample /= this.beta;
                    newTopic = 0;
                    sample -= this.alpha[newTopic] / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
                    while (sample > 0.0) {
                        ++newTopic;
                        sample -= this.alpha[newTopic] / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
                    }
                }
                // Rounding can exhaust the document bucket without choosing a topic. Fall
                // back BEFORE touching the packed array so no entry is written for topic -1.
                if (newTopic == -1) {
                    System.err.println("CodingTest sampling error: " + origSample + " " + sample + " " + this.smoothingOnlyMass + " " + topicBetaMass + " " + topicTermMass);
                    newTopic = this.numTopics - 1;
                }
                this.incrementTypeTopic(currentTypeTopicCounts, newTopic);
            }

            // Record the new assignment and restore the cached masses.
            oneDocTopics[position] = newTopic;
            this.smoothingOnlyMass -= this.alpha[newTopic] * this.beta / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            topicBetaMass -= this.beta * (double)localTopicCounts[newTopic] / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            ++localTopicCounts[newTopic];
            if (localTopicCounts[newTopic] == 1) {
                // First occurrence in this document: insert into the ascending dense list.
                int denseIndex = nonZeroTopics;
                while (denseIndex > 0 && localTopicIndex[denseIndex - 1] > newTopic) {
                    localTopicIndex[denseIndex] = localTopicIndex[denseIndex - 1];
                    --denseIndex;
                }
                localTopicIndex[denseIndex] = newTopic;
                ++nonZeroTopics;
            }
            ++this.tokensPerTopic[newTopic];
            this.cachedCoefficients[newTopic] = (this.alpha[newTopic] + (double)localTopicCounts[newTopic]) / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            this.smoothingOnlyMass += this.alpha[newTopic] * this.beta / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            topicBetaMass += this.beta * (double)localTopicCounts[newTopic] / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
        }

        if (shouldSaveState) {
            ++this.docLengthCounts[docLength];
            for (int denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
                int topic = localTopicIndex[denseIndex];
                ++this.topicDocCounts[topic][localTopicCounts[topic]];
            }
        }
        // Drop the document-specific part of the coefficients before the next document.
        for (int denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
            int topic = localTopicIndex[denseIndex];
            this.cachedCoefficients[topic] = this.alpha[topic] / ((double)this.tokensPerTopic[topic] + this.betaSum);
        }
    }

    /**
     * Writes the top words of each topic to {@code file}, closing it afterwards.
     *
     * @see #printTopWords(PrintStream, int, boolean)
     */
    public void printTopWords(File file, int numWords, boolean useNewLines) throws IOException {
        PrintStream out = new PrintStream(file);
        try {
            this.printTopWords(out, numWords, useNewLines);
        } finally {
            out.close();
        }
    }

    /**
     * Prints up to {@code numWords} highest-count words per topic.
     *
     * @param out destination (not closed)
     * @param numWords maximum number of words per topic
     * @param usingNewLines if true, a "Topic t" header followed by one "word\tcount" line
     *     per word; otherwise one line per topic: "topic\talpha\tword word ..."
     */
    public void printTopWords(PrintStream out, int numWords, boolean usingNewLines) {
        FeatureCounter[] wordCountsPerTopic = new FeatureCounter[this.numTopics];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            wordCountsPerTopic[topic] = new FeatureCounter(this.alphabet);
        }
        for (int type = 0; type < this.numTypes; ++type) {
            int[] packed = this.typeTopicCounts[type];
            for (int index = 0; index < packed.length && packed[index] > 0; ++index) {
                int topic = packed[index] & this.topicMask;
                int count = packed[index] >> this.topicBits;
                wordCountsPerTopic[topic].increment(type, count);
            }
        }
        for (int topic = 0; topic < this.numTopics; ++topic) {
            RankedFeatureVector rfv = wordCountsPerTopic[topic].toRankedFeatureVector();
            // Topics can have fewer distinct words than requested.
            int max = Math.min(numWords, rfv.numLocations());
            if (usingNewLines) {
                out.println("Topic " + topic);
                for (int ri = 0; ri < max; ++ri) {
                    int type = rfv.getIndexAtRank(ri);
                    out.println(this.alphabet.lookupObject(type).toString() + "\t" + (int)rfv.getValueAtRank(ri));
                }
                continue;
            }
            out.print(topic + "\t" + this.formatter.format(this.alpha[topic]) + "\t");
            for (int ri = 0; ri < max; ++ri) {
                out.print(this.alphabet.lookupObject(rfv.getIndexAtRank(ri)).toString() + " ");
            }
            out.print("\n");
        }
    }

    /**
     * Writes per-document topic proportions to {@code f}. The writer is closed (and
     * therefore flushed) even if printing fails.
     */
    public void printDocumentTopics(File f) throws IOException {
        PrintWriter pw = new PrintWriter(new FileWriter(f));
        try {
            this.printDocumentTopics(pw);
        } finally {
            pw.close();
        }
    }

    /** Prints all topics of every document, without a threshold. */
    public void printDocumentTopics(PrintWriter pw) {
        this.printDocumentTopics(pw, 0.0, -1);
    }

    /**
     * Prints one line per document: index, source, then (topic, proportion) pairs in the
     * order produced by sorting {@code IDSorter}s.
     *
     * @param pw destination (not closed or flushed)
     * @param threshold stop listing a document's topics at the first proportion below this
     * @param max maximum topics per document; negative or too large means all topics
     */
    public void printDocumentTopics(PrintWriter pw, double threshold, int max) {
        pw.print("#doc source topic proportion ...\n");
        int[] topicCounts = new int[this.numTopics];
        IDSorter[] sortedTopics = new IDSorter[this.numTopics];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            sortedTopics[topic] = new IDSorter(topic, topic);
        }
        if (max < 0 || max > this.numTopics) {
            max = this.numTopics;
        }
        for (int di = 0; di < this.data.size(); ++di) {
            TopicAssignment document = this.data.get(di);
            int[] currentDocTopics = document.topicSequence.getFeatures();
            pw.print(di);
            pw.print(' ');
            Object source = document.instance.getSource();
            if (source != null) {
                pw.print(source);
            } else {
                pw.print("null-source");
            }
            pw.print(' ');
            int docLen = currentDocTopics.length;
            for (int token = 0; token < docLen; ++token) {
                ++topicCounts[currentDocTopics[token]];
            }
            for (int topic = 0; topic < this.numTopics; ++topic) {
                sortedTopics[topic].set(topic, (float)topicCounts[topic] / (float)docLen);
            }
            Arrays.sort(sortedTopics);
            for (int i = 0; i < max && !(sortedTopics[i].getWeight() < threshold); ++i) {
                pw.print(sortedTopics[i].getID() + " " + sortedTopics[i].getWeight() + " ");
            }
            pw.print(" \n");
            Arrays.fill(topicCounts, 0);
        }
    }

    /** Writes the full token-level state, gzip-compressed, to {@code f}, closing it afterwards. */
    public void printState(File f) throws IOException {
        PrintStream out = new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
        try {
            this.printState(out);
        } finally {
            out.close();
        }
    }

    /**
     * Prints one line per token: document index, source (or "NA"), position, type index,
     * type, and assigned topic, each separated by a space.
     *
     * @param out destination (not closed)
     */
    public void printState(PrintStream out) {
        out.println("#doc source pos typeindex type topic");
        for (int di = 0; di < this.data.size(); ++di) {
            TopicAssignment document = this.data.get(di);
            FeatureSequence tokenSequence = (FeatureSequence)document.instance.getData();
            LabelSequence topicSequence = document.topicSequence;
            String source = "NA";
            if (document.instance.getSource() != null) {
                source = document.instance.getSource().toString();
            }
            for (int pi = 0; pi < topicSequence.getLength(); ++pi) {
                int type = tokenSequence.getIndexAtPosition(pi);
                int topic = topicSequence.getIndexAtPosition(pi);
                out.print(di);
                out.print(' ');
                out.print(source);
                out.print(' ');
                out.print(pi);
                out.print(' ');
                out.print(type);
                out.print(' ');
                out.print(this.alphabet.lookupObject(type));
                out.print(' ');
                out.print(topic);
                out.println();
            }
        }
    }

    /**
     * Serializes this model to {@code f}. I/O errors are reported on stderr rather than
     * thrown, and the file handle is always released.
     */
    public void write(File f) {
        FileOutputStream fileOut = null;
        try {
            fileOut = new FileOutputStream(f);
            ObjectOutputStream oos = new ObjectOutputStream(fileOut);
            oos.writeObject(this);
            oos.close();
        }
        catch (IOException e) {
            System.err.println("Exception writing file " + f + ": " + e);
        }
        finally {
            if (fileOut != null) {
                try {
                    fileOut.close();
                }
                catch (IOException ignored) {
                    // Best effort: any write failure has already been reported above.
                }
            }
        }
    }

    /** Custom serialized form; the field order must match {@link #readObject}. */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(this.data);
        out.writeObject(this.alphabet);
        out.writeObject(this.topicAlphabet);
        out.writeInt(this.numTopics);
        out.writeObject(this.alpha);
        out.writeDouble(this.beta);
        out.writeDouble(this.betaSum);
        out.writeInt(this.topicMask);
        out.writeInt(this.topicBits);
        out.writeDouble(this.smoothingOnlyMass);
        out.writeObject(this.cachedCoefficients);
        out.writeInt(this.iterationsSoFar);
        out.writeInt(this.numIterations);
        out.writeInt(this.burninPeriod);
        out.writeInt(this.saveSampleInterval);
        out.writeInt(this.optimizeInterval);
        out.writeInt(this.showTopicsInterval);
        out.writeInt(this.wordsPerTopic);
        out.writeInt(this.outputModelInterval);
        out.writeObject(this.outputModelFilename);
        out.writeInt(this.saveStateInterval);
        out.writeObject(this.stateFilename);
        out.writeObject(this.random);
        out.writeObject(this.formatter);
        out.writeBoolean(this.printLogLikelihood);
        out.writeObject(this.docLengthCounts);
        out.writeObject(this.topicDocCounts);
        out.writeObject(this.typeTopicCounts);
        for (int ti = 0; ti < this.numTopics; ++ti) {
            out.writeInt(this.tokensPerTopic[ti]);
        }
    }

    /**
     * Reads the form written by {@link #writeObject}, then rebuilds state that is not
     * part of the stream.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.readInt(); // serial version; only version 0 is ever written
        @SuppressWarnings("unchecked")
        ArrayList<TopicAssignment> savedData = (ArrayList<TopicAssignment>)in.readObject();
        this.data = savedData;
        this.alphabet = (Alphabet)in.readObject();
        this.topicAlphabet = (LabelAlphabet)in.readObject();
        this.numTopics = in.readInt();
        this.alpha = (double[])in.readObject();
        this.beta = in.readDouble();
        this.betaSum = in.readDouble();
        this.topicMask = in.readInt();
        this.topicBits = in.readInt();
        this.smoothingOnlyMass = in.readDouble();
        this.cachedCoefficients = (double[])in.readObject();
        this.iterationsSoFar = in.readInt();
        this.numIterations = in.readInt();
        this.burninPeriod = in.readInt();
        this.saveSampleInterval = in.readInt();
        this.optimizeInterval = in.readInt();
        this.showTopicsInterval = in.readInt();
        this.wordsPerTopic = in.readInt();
        this.outputModelInterval = in.readInt();
        this.outputModelFilename = (String)in.readObject();
        this.saveStateInterval = in.readInt();
        this.stateFilename = (String)in.readObject();
        this.random = (Randoms)in.readObject();
        this.formatter = (NumberFormat)in.readObject();
        this.printLogLikelihood = in.readBoolean();
        this.docLengthCounts = (int[])in.readObject();
        this.topicDocCounts = (int[][])in.readObject();
        this.numTypes = this.alphabet.size();
        this.typeTopicCounts = (int[][])in.readObject();
        this.tokensPerTopic = new int[this.numTopics];
        for (int ti = 0; ti < this.numTopics; ++ti) {
            this.tokensPerTopic[ti] = in.readInt();
        }
        // alphaSum is not in the stream; without this it would be 0.0 after loading and
        // modelLogLikelihood() would be wrong. Recompute it from the restored alpha.
        this.alphaSum = 0.0;
        for (double a : this.alpha) {
            this.alphaSum += a;
        }
        this.oneDocTopicCounts = new int[this.numTopics];
    }

    /**
     * Computes log p(w, z): the log joint probability of the words and the current topic
     * assignments under the Dirichlet priors, using {@code Dirichlet.logGammaStirling}.
     *
     * <p>Exits the JVM with status 1 if the running total becomes NaN.
     *
     * @return the log-likelihood
     */
    public double modelLogLikelihood() {
        double logLikelihood = 0.0;

        // Document-topic part: log p(z | alpha).
        double[] topicLogGammas = new double[this.numTopics];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            topicLogGammas[topic] = Dirichlet.logGammaStirling(this.alpha[topic]);
        }
        int[] docTopicCounts = new int[this.numTopics];
        for (TopicAssignment document : this.data) {
            int[] docTopics = document.topicSequence.getFeatures();
            for (int token = 0; token < docTopics.length; ++token) {
                ++docTopicCounts[docTopics[token]];
            }
            for (int topic = 0; topic < this.numTopics; ++topic) {
                if (docTopicCounts[topic] > 0) {
                    logLikelihood += Dirichlet.logGammaStirling(this.alpha[topic] + (double)docTopicCounts[topic]) - topicLogGammas[topic];
                }
            }
            logLikelihood -= Dirichlet.logGammaStirling(this.alphaSum + (double)docTopics.length);
            Arrays.fill(docTopicCounts, 0);
        }
        logLikelihood += (double)this.data.size() * Dirichlet.logGammaStirling(this.alphaSum);

        // Topic-word part: log p(w | z, beta). A zero count contributes
        // logGamma(beta) - logGamma(beta) = 0, so only non-zero pairs are visited.
        int nonZeroTypeTopics = 0;
        for (int type = 0; type < this.numTypes; ++type) {
            int[] packed = this.typeTopicCounts[type];
            for (int index = 0; index < packed.length && packed[index] > 0; ++index) {
                int count = packed[index] >> this.topicBits;
                ++nonZeroTypeTopics;
                logLikelihood += Dirichlet.logGammaStirling(this.beta + (double)count);
                if (Double.isNaN(logLikelihood)) {
                    System.out.println(count);
                    System.exit(1);
                }
            }
        }
        // Each topic's Dirichlet normalizer uses the full prior mass beta * numTypes (betaSum).
        for (int topic = 0; topic < this.numTopics; ++topic) {
            logLikelihood -= Dirichlet.logGammaStirling(this.betaSum + (double)this.tokensPerTopic[topic]);
            if (Double.isNaN(logLikelihood)) {
                System.out.println("after topic " + topic + " " + this.tokensPerTopic[topic]);
                System.exit(1);
            }
        }
        // logGamma(betaSum) once per topic; logGamma(beta) once per non-zero (type, topic) pair.
        logLikelihood += Dirichlet.logGammaStirling(this.betaSum) * (double)this.numTopics - Dirichlet.logGammaStirling(this.beta) * (double)nonZeroTypeTopics;
        if (Double.isNaN(logLikelihood)) {
            System.out.println("at the end");
            System.exit(1);
        }
        return logLikelihood;
    }

    /**
     * Command-line entry point.
     *
     * @param args {@code args[0]}: serialized InstanceList file; optional {@code args[1]}:
     *     number of topics (default 200)
     */
    public static void main(String[] args) throws IOException {
        InstanceList training = InstanceList.load(new File(args[0]));
        int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;
        CodingTest lda = new CodingTest(numTopics, 50.0, 0.01);
        lda.printLogLikelihood = true;
        lda.setTopicDisplay(0, 7);
        lda.addInstances(training);
        lda.estimate();
    }

    /** A training document together with its per-token topic assignments. */
    public class TopicAssignment
    implements Serializable {
        public Instance instance;
        public CodingTest model;
        /** Current topic of each token; updated in place by the sampler. */
        public LabelSequence topicSequence;
        /** Not set or read in this class. */
        public Labeling topicDistribution;

        public TopicAssignment(Instance instance, CodingTest model, LabelSequence topicSequence) {
            this.instance = instance;
            this.model = model;
            this.topicSequence = topicSequence;
        }
    }
}

