/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer;

import edu.cmu.sphinx.frontend.FloatData;
import edu.cmu.sphinx.linguist.acoustic.HMM;
import edu.cmu.sphinx.linguist.acoustic.HMMState;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.GaussianMixture;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.GaussianWeights;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.HMMManager;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Loader;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.MixtureComponent;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Pool;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Senone;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.SenoneHMM;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.SenoneHMMState;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer.Buffer;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer.TrainerScore;
import edu.cmu.sphinx.util.LogMath;
import java.io.IOException;
import java.util.HashMap;
import java.util.logging.Logger;

class HMMPoolManager {
    private HMMManager hmmManager;
    private HashMap<Object, Integer> indexMap;
    private Pool<float[]> meansPool;
    private Pool<float[]> variancePool;
    private Pool<float[][]> matrixPool;
    private GaussianWeights mixtureWeights;
    private Pool<Buffer> meansBufferPool;
    private Pool<Buffer> varianceBufferPool;
    private Pool<Buffer[]> matrixBufferPool;
    private Pool<Buffer> mixtureWeightsBufferPool;
    private Pool<Senone> senonePool;
    private LogMath logMath;
    private float logMixtureWeightFloor;
    private float logTransitionProbabilityFloor;
    private float varianceFloor;
    private float logLikelihood;
    private float currentLogLikelihood;
    private static Logger logger = Logger.getLogger("edu.cmu.sphinx.linguist.acoustic.HMMPoolManager");

    protected HMMPoolManager(Loader loader) throws IOException {
        loader.load();
        this.hmmManager = loader.getHMMManager();
        this.indexMap = new HashMap();
        this.meansPool = loader.getMeansPool();
        this.variancePool = loader.getVariancePool();
        this.mixtureWeights = loader.getMixtureWeights();
        this.matrixPool = loader.getTransitionMatrixPool();
        this.senonePool = loader.getSenonePool();
        this.createBuffers();
        this.logLikelihood = 0.0f;
        this.logMath = LogMath.getLogMath();
    }

    protected void resetBuffers() {
        this.createBuffers();
        this.logLikelihood = 0.0f;
    }

    protected void createBuffers() {
        this.meansBufferPool = this.create1DPoolBuffer(this.meansPool, false);
        this.varianceBufferPool = this.create1DPoolBuffer(this.variancePool, false);
        this.matrixBufferPool = this.create2DPoolBuffer(this.matrixPool, true);
        this.mixtureWeightsBufferPool = this.createWeightsPoolBuffer(this.mixtureWeights);
    }

    private Pool<Buffer> create1DPoolBuffer(Pool<float[]> pool, boolean isLog) {
        Pool<Buffer> bufferPool = new Pool<Buffer>(pool.getName());
        for (int i = 0; i < pool.size(); ++i) {
            float[] element = pool.get(i);
            this.indexMap.put(element, i);
            Buffer buffer = new Buffer(element.length, isLog, i);
            bufferPool.put(i, buffer);
        }
        return bufferPool;
    }

    private Pool<Buffer> createWeightsPoolBuffer(GaussianWeights mixtureWeights) {
        Pool<Buffer> bufferPool = new Pool<Buffer>(mixtureWeights.getName());
        int statesNum = mixtureWeights.getStatesNum();
        int streamsNum = mixtureWeights.getStreamsNum();
        int gauPerState = mixtureWeights.getGauPerState();
        for (int i = 0; i < streamsNum; ++i) {
            for (int j = 0; j < statesNum; ++j) {
                int id = i * statesNum + j;
                Buffer buffer = new Buffer(gauPerState, true, id);
                bufferPool.put(id, buffer);
            }
        }
        return bufferPool;
    }

    private Pool<Buffer[]> create2DPoolBuffer(Pool<float[][]> pool, boolean isLog) {
        Pool<Buffer[]> bufferPool = new Pool<Buffer[]>(pool.getName());
        for (int i = 0; i < pool.size(); ++i) {
            float[][] element = pool.get(i);
            this.indexMap.put(element, i);
            int poolSize = element.length;
            Buffer[] bufferArray = new Buffer[poolSize];
            for (int j = 0; j < poolSize; ++j) {
                bufferArray[j] = new Buffer(element[j].length, isLog, j);
            }
            bufferPool.put(i, bufferArray);
        }
        return bufferPool;
    }

    protected void accumulate(int index, TrainerScore[] score) {
        this.accumulate(index, score, null);
    }

    protected void accumulate(int index, TrainerScore[] score, TrainerScore[] nextScore) {
        TrainerScore thisScore = score[index];
        this.currentLogLikelihood = 0.0f;
        this.logLikelihood -= score[0].getScalingFactor();
        SenoneHMMState state = (SenoneHMMState)thisScore.getState();
        if (state == null) {
            int senoneID = thisScore.getSenoneID();
            if (senoneID == -1) {
                this.accumulateMean(senoneID, score[index]);
                this.accumulateVariance(senoneID, score[index]);
                this.accumulateMixture(senoneID, score[index]);
                this.accumulateTransition(senoneID, index, score, nextScore);
            }
        } else if (state.isEmitting()) {
            int senoneID = this.senonePool.indexOf(state.getSenone());
            this.accumulateMixture(senoneID, score[index]);
            this.accumulateTransition(senoneID, index, score, nextScore);
        }
    }

    private void accumulateMean(int senone, TrainerScore score) {
        if (senone == -1) {
            for (int i = 0; i < this.senonePool.size(); ++i) {
                this.accumulateMean(i, score);
            }
        } else {
            GaussianMixture gaussian = (GaussianMixture)this.senonePool.get(senone);
            MixtureComponent[] mix = gaussian.getMixtureComponents();
            for (int i = 0; i < mix.length; ++i) {
                float[] mean = mix[i].getMean();
                int indexMean = this.indexMap.get(mean);
                assert (indexMean >= 0);
                assert (indexMean == senone);
                Buffer buffer = this.meansBufferPool.get(indexMean);
                float[] feature = ((FloatData)score.getData()).getValues();
                double[] data = new double[feature.length];
                float prob = score.getComponentGamma()[i];
                double dprob = this.logMath.logToLinear(prob -= this.currentLogLikelihood);
                for (int j = 0; j < data.length; ++j) {
                    data[j] = (double)feature[j] * dprob;
                }
                buffer.accumulate(data, dprob);
            }
        }
    }

    private void accumulateVariance(int senone, TrainerScore score) {
        if (senone == -1) {
            for (int i = 0; i < this.senonePool.size(); ++i) {
                this.accumulateVariance(i, score);
            }
        } else {
            GaussianMixture gaussian = (GaussianMixture)this.senonePool.get(senone);
            MixtureComponent[] mix = gaussian.getMixtureComponents();
            for (int i = 0; i < mix.length; ++i) {
                float[] mean = mix[i].getMean();
                float[] variance = mix[i].getVariance();
                int indexVariance = this.indexMap.get(variance);
                Buffer buffer = this.varianceBufferPool.get(indexVariance);
                float[] feature = ((FloatData)score.getData()).getValues();
                double[] data = new double[feature.length];
                float prob = score.getComponentGamma()[i];
                double dprob = this.logMath.logToLinear(prob -= this.currentLogLikelihood);
                for (int j = 0; j < data.length; ++j) {
                    data[j] = feature[j] - mean[j];
                    int n = j;
                    data[n] = data[n] * (data[j] * dprob);
                }
                buffer.accumulate(data, dprob);
            }
        }
    }

    private void accumulateMixture(int senone, TrainerScore score) {
        if (senone == -1) {
            for (int i = 0; i < this.senonePool.size(); ++i) {
                this.accumulateMixture(i, score);
            }
        } else {
            Buffer buffer = this.mixtureWeightsBufferPool.get(senone);
            for (int i = 0; i < this.mixtureWeights.getGauPerState(); ++i) {
                float prob = score.getComponentGamma()[i];
                buffer.logAccumulate(prob -= this.currentLogLikelihood, i, this.logMath);
            }
        }
    }

    private void accumulateStateTransition(int indexScore, TrainerScore[] score, TrainerScore[] nextScore) {
        HMMState state = score[indexScore].getState();
        if (state == null) {
            return;
        }
        int indexState = state.getState();
        SenoneHMM hmm = (SenoneHMM)state.getHMM();
        float[][] matrix = hmm.getTransitionMatrix();
        int indexMatrix = this.indexMap.get(matrix);
        Buffer[] bufferArray = this.matrixBufferPool.get(indexMatrix);
        float[] vector = matrix[indexState];
        for (int i = 0; i < vector.length; ++i) {
            if (vector[i] == -3.4028235E38f) continue;
            int dist = i - indexState;
            int indexNextScore = indexScore + dist;
            assert (nextScore[indexNextScore].getState() == null || nextScore[indexNextScore].getState().getHMM() == hmm);
            float alpha = score[indexScore].getAlpha();
            float beta = nextScore[indexNextScore].getBeta();
            float transitionProb = vector[i];
            float outputProb = nextScore[indexNextScore].getScore();
            float prob = alpha + beta + transitionProb + outputProb;
            bufferArray[indexState].logAccumulate(prob -= this.currentLogLikelihood, i, this.logMath);
        }
    }

    private void accumulateStateTransition(int indexState, SenoneHMM hmm, float value) {
        float[][] matrix = hmm.getTransitionMatrix();
        float[] stateVector = matrix[indexState];
        int indexMatrix = this.indexMap.get(matrix);
        Buffer[] bufferArray = this.matrixBufferPool.get(indexMatrix);
        for (int i = 0; i < stateVector.length; ++i) {
            if (stateVector[i] == -3.4028235E38f) continue;
            bufferArray[indexState].logAccumulate(value, i, this.logMath);
        }
    }

    private void accumulateTransition(int indexHmm, int indexScore, TrainerScore[] score, TrainerScore[] nextScore) {
        if (indexHmm == -1) {
            for (HMM hmm : this.hmmManager) {
                for (int j = 0; j < hmm.getOrder(); ++j) {
                    this.accumulateStateTransition(j, (SenoneHMM)hmm, score[indexScore].getScore());
                }
            }
        } else if (nextScore != null) {
            this.accumulateStateTransition(indexScore, score, nextScore);
        }
    }

    protected void updateLogLikelihood() {
    }

    protected float normalize() {
        this.normalizePool(this.meansBufferPool);
        this.normalizePool(this.varianceBufferPool);
        this.logNormalizePool(this.mixtureWeightsBufferPool);
        this.logNormalize2DPool(this.matrixBufferPool, this.matrixPool);
        return this.logLikelihood;
    }

    private void normalizePool(Pool<Buffer> pool) {
        assert (pool != null);
        for (int i = 0; i < pool.size(); ++i) {
            Buffer buffer = pool.get(i);
            if (!buffer.wasUsed()) continue;
            buffer.normalize();
        }
    }

    private void logNormalizePool(Pool<Buffer> pool) {
        assert (pool != null);
        for (int i = 0; i < pool.size(); ++i) {
            Buffer buffer = pool.get(i);
            if (!buffer.wasUsed()) continue;
            buffer.logNormalize();
        }
    }

    private void logNormalize2DPool(Pool<Buffer[]> pool, Pool<float[][]> maskPool) {
        assert (pool != null);
        for (int i = 0; i < pool.size(); ++i) {
            Buffer[] bufferArray = pool.get(i);
            float[][] mask = maskPool.get(i);
            for (int j = 0; j < bufferArray.length; ++j) {
                if (!bufferArray[j].wasUsed()) continue;
                bufferArray[j].logNormalizeNonZero(mask[j]);
            }
        }
    }

    protected void update() {
        this.updateMeans();
        this.updateVariances();
        this.recomputeMixtureComponents();
        this.updateMixtureWeights();
        this.updateTransitionMatrices();
    }

    private void copyVector(float[] in, float[] out) {
        assert (in.length == out.length);
        System.arraycopy(in, 0, out, 0, in.length);
    }

    private void updateMeans() {
        assert (this.meansPool.size() == this.meansBufferPool.size());
        for (int i = 0; i < this.meansPool.size(); ++i) {
            float[] means = this.meansPool.get(i);
            Buffer buffer = this.meansBufferPool.get(i);
            if (buffer.wasUsed()) {
                float[] meansBuffer = buffer.getValues();
                this.copyVector(meansBuffer, means);
                continue;
            }
            logger.info("Senone " + i + " not used.");
        }
    }

    private void updateVariances() {
        assert (this.variancePool.size() == this.varianceBufferPool.size());
        for (int i = 0; i < this.variancePool.size(); ++i) {
            float[] means = this.meansPool.get(i);
            float[] variance = this.variancePool.get(i);
            Buffer buffer = this.varianceBufferPool.get(i);
            if (!buffer.wasUsed()) continue;
            float[] varianceBuffer = buffer.getValues();
            assert (means.length == varianceBuffer.length);
            for (int j = 0; j < means.length; ++j) {
                int n = j;
                varianceBuffer[n] = varianceBuffer[n] - means[j] * means[j];
                if (!(varianceBuffer[j] < this.varianceFloor)) continue;
                varianceBuffer[j] = this.varianceFloor;
            }
            this.copyVector(varianceBuffer, variance);
        }
    }

    private void recomputeMixtureComponents() {
        for (int i = 0; i < this.senonePool.size(); ++i) {
            MixtureComponent[] mixComponent;
            GaussianMixture gMix = (GaussianMixture)this.senonePool.get(i);
            for (MixtureComponent component : mixComponent = gMix.getMixtureComponents()) {
                component.precomputeDistance();
            }
        }
    }

    private void updateMixtureWeights() {
        int statesNum = this.mixtureWeights.getStatesNum();
        int streamsNum = this.mixtureWeights.getStreamsNum();
        assert (statesNum * streamsNum == this.mixtureWeightsBufferPool.size());
        for (int i = 0; i < streamsNum; ++i) {
            for (int j = 0; j < statesNum; ++j) {
                int id = i * statesNum + j;
                Buffer buffer = this.mixtureWeightsBufferPool.get(id);
                if (!buffer.wasUsed()) continue;
                if (buffer.logFloor(this.logMixtureWeightFloor)) {
                    buffer.logNormalizeToSum(this.logMath);
                }
                float[] mixtureWeightsBuffer = buffer.getValues();
                this.mixtureWeights.put(j, i, mixtureWeightsBuffer);
            }
        }
    }

    private void updateTransitionMatrices() {
        assert (this.matrixPool.size() == this.matrixBufferPool.size());
        for (int i = 0; i < this.matrixPool.size(); ++i) {
            float[][] matrix = this.matrixPool.get(i);
            Buffer[] bufferArray = this.matrixBufferPool.get(i);
            for (int j = 0; j < matrix.length; ++j) {
                Buffer buffer = bufferArray[j];
                if (!buffer.wasUsed()) continue;
                for (int k = 0; k < matrix[j].length; ++k) {
                    float bufferValue = buffer.getValue(k);
                    if (bufferValue == -3.4028235E38f) continue;
                    assert (matrix[j][k] != -3.4028235E38f);
                    if (!(bufferValue < this.logTransitionProbabilityFloor)) continue;
                    buffer.setValue(k, this.logTransitionProbabilityFloor);
                }
                buffer.logNormalizeToSum(this.logMath);
                this.copyVector(buffer.getValues(), matrix[j]);
            }
        }
    }
}

