/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.sphinx.decoder.search;

import edu.cmu.sphinx.decoder.pruner.Pruner;
import edu.cmu.sphinx.decoder.scorer.AcousticScorer;
import edu.cmu.sphinx.decoder.search.ActiveList;
import edu.cmu.sphinx.decoder.search.ActiveListFactory;
import edu.cmu.sphinx.decoder.search.ActiveListManager;
import edu.cmu.sphinx.decoder.search.Token;
import edu.cmu.sphinx.decoder.search.WordPruningBreadthFirstSearchManager;
import edu.cmu.sphinx.frontend.Data;
import edu.cmu.sphinx.linguist.Linguist;
import edu.cmu.sphinx.linguist.SearchState;
import edu.cmu.sphinx.linguist.SearchStateArc;
import edu.cmu.sphinx.linguist.WordSearchState;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Loader;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Sphinx3Loader;
import edu.cmu.sphinx.linguist.allphone.PhoneHmmSearchState;
import edu.cmu.sphinx.linguist.lextree.LexTreeLinguist;
import edu.cmu.sphinx.result.Result;
import edu.cmu.sphinx.util.props.PropertyException;
import edu.cmu.sphinx.util.props.PropertySheet;
import edu.cmu.sphinx.util.props.S4Component;
import edu.cmu.sphinx.util.props.S4Double;
import edu.cmu.sphinx.util.props.S4Integer;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;

/**
 * A {@link WordPruningBreadthFirstSearchManager} that additionally runs a "fast match" search over
 * a second linguist ({@code fastmatchLinguist}) ahead of the main search.
 *
 * <p>For every fast-match frame it records two things from the {@link PhoneHmmSearchState}
 * tokens that survived the fast-match beam:
 * <ul>
 *   <li>the best score per phone base id;</li>
 *   <li>the best such score overall.</li>
 * </ul>
 * The main search turns the queued frames into a per-phone lookahead penalty. It skips LexTree
 * HMM expansions whose token score plus the weighted penalty falls below the main beam threshold.
 */
public class WordPruningBreadthFirstLookaheadSearchManager
extends WordPruningBreadthFirstSearchManager {
    @S4Component(type=Loader.class)
    public static final String PROP_LOADER = "loader";
    @S4Component(type=Linguist.class)
    public static final String PROP_FASTMATCH_LINGUIST = "fastmatchLinguist";
    @S4Component(type=ActiveListFactory.class)
    public static final String PROP_FM_ACTIVE_LIST_FACTORY = "fastmatchActiveListFactory";
    @S4Double(defaultValue=1.0)
    public static final String PROP_LOOKAHEAD_PENALTY_WEIGHT = "lookaheadPenaltyWeight";
    @S4Integer(defaultValue=5)
    public static final String PROP_LOOKAHEAD_WINDOW = "lookaheadWindow";

    /** Smallest accepted lookahead window (inclusive). */
    private static final int MIN_LOOKAHEAD_WINDOW = 1;
    /** Largest accepted lookahead window (inclusive). */
    private static final int MAX_LOOKAHEAD_WINDOW = 10;
    /**
     * Initial length of a per-frame CI score array.
     * The array grows on demand if a larger phone base id is seen.
     */
    private static final int INITIAL_CI_SCORE_CAPACITY = 100;

    private Linguist fastmatchLinguist;
    private Loader loader;
    private ActiveListFactory fastmatchActiveListFactory;
    /** Number of frames the fast match is run ahead of the main search, in [1..10]. */
    private int lookaheadWindow;
    /** Multiplier applied to the lookahead penalty before comparing against the beam. */
    private float lookaheadWeight;
    /** Per-frame cache of computed lookahead penalties, keyed by phone base id. */
    private Map<Integer, Float> penalties;
    /** FIFO of per-frame CI scores produced by the fast match; the head is the oldest frame. */
    private Deque<FrameCiScores> ciScores;
    private int currentFastMatchFrameNumber;
    protected ActiveList fastmatchActiveList;
    protected Map<SearchState, Token> fastMatchBestTokenMap;
    /** Set once the fast-match scorer stops returning tokens for the current stream. */
    private boolean fastmatchStreamEnd;

    public WordPruningBreadthFirstLookaheadSearchManager(Linguist linguist, Linguist fastmatchLinguist, Loader loader, Pruner pruner, AcousticScorer scorer, ActiveListManager activeListManager, ActiveListFactory fastmatchActiveListFactory, boolean showTokenCount, double relativeWordBeamWidth, int growSkipInterval, boolean checkStateOrder, boolean buildWordLattice, int lookaheadWindow, float lookaheadWeight, int maxLatticeEdges, float acousticLookaheadFrames, boolean keepAllTokens) {
        super(linguist, pruner, scorer, activeListManager, showTokenCount, relativeWordBeamWidth, growSkipInterval, checkStateOrder, buildWordLattice, maxLatticeEdges, acousticLookaheadFrames, keepAllTokens);
        this.loader = loader;
        this.fastmatchLinguist = fastmatchLinguist;
        this.fastmatchActiveListFactory = fastmatchActiveListFactory;
        this.lookaheadWindow = lookaheadWindow;
        this.lookaheadWeight = lookaheadWeight;
        if (!isSupportedLookaheadWindow(lookaheadWindow)) {
            throw new IllegalArgumentException(unsupportedWindowMessage(lookaheadWindow));
        }
        initLookaheadState();
    }

    public WordPruningBreadthFirstLookaheadSearchManager() {
    }

    @Override
    public void newProperties(PropertySheet ps) throws PropertyException {
        super.newProperties(ps);
        this.fastmatchLinguist = (Linguist) ps.getComponent(PROP_FASTMATCH_LINGUIST);
        this.fastmatchActiveListFactory = (ActiveListFactory) ps.getComponent(PROP_FM_ACTIVE_LIST_FACTORY);
        this.loader = (Loader) ps.getComponent(PROP_LOADER);
        this.lookaheadWindow = ps.getInt(PROP_LOOKAHEAD_WINDOW);
        this.lookaheadWeight = ps.getFloat(PROP_LOOKAHEAD_PENALTY_WEIGHT);
        if (!isSupportedLookaheadWindow(this.lookaheadWindow)) {
            throw new PropertyException(WordPruningBreadthFirstLookaheadSearchManager.class.getName(), PROP_LOOKAHEAD_WINDOW, unsupportedWindowMessage(this.lookaheadWindow));
        }
        initLookaheadState();
    }

    /** Returns whether {@code window} lies in the supported range [1..10]. */
    private static boolean isSupportedLookaheadWindow(int window) {
        return window >= MIN_LOOKAHEAD_WINDOW && window <= MAX_LOOKAHEAD_WINDOW;
    }

    /** Builds the error message used when the configured lookahead window is out of range. */
    private static String unsupportedWindowMessage(int window) {
        return "Unsupported lookahead window size: " + window + ". Value in range [1..10] is expected";
    }

    /**
     * Creates empty lookahead containers.
     * When the loader uses tied mixtures, it also sizes its Gaussian-score queue to
     * {@code lookaheadWindow + 2}.
     */
    private void initLookaheadState() {
        this.ciScores = new ArrayDeque<>();
        this.penalties = new HashMap<>();
        Sphinx3Loader tiedLoader = tiedMixtureLoader();
        if (tiedLoader != null) {
            tiedLoader.setGauScoresQueueLength(this.lookaheadWindow + 2);
        }
    }

    /** Returns the loader as a {@link Sphinx3Loader} if it has tied mixtures, otherwise {@code null}. */
    private Sphinx3Loader tiedMixtureLoader() {
        if (this.loader instanceof Sphinx3Loader && ((Sphinx3Loader) this.loader).hasTiedMixtures()) {
            return (Sphinx3Loader) this.loader;
        }
        return null;
    }

    /**
     * Runs up to {@code nFrames} frames of recognition.
     *
     * <p>Each frame does the following, in order:
     * <ol>
     *   <li>advances the fast match by one frame, unless its stream has ended;</li>
     *   <li>clears the penalty cache and drops the oldest queued CI score frame;</li>
     *   <li>runs one frame of the main search.</li>
     * </ol>
     *
     * @param nFrames maximum number of frames to process
     * @return the current result, or {@code null} if the main search flagged {@code streamEnd}
     */
    @Override
    public Result recognize(int nFrames) {
        boolean done = false;
        Result result = null;
        this.streamEnd = false;
        for (int i = 0; i < nFrames && !done; ++i) {
            if (!this.fastmatchStreamEnd) {
                this.fastMatchRecognize();
            }
            // Cached penalties were computed against the previous window; the window shifts below.
            this.penalties.clear();
            this.ciScores.poll();
            done = this.recognize();
        }
        if (!this.streamEnd) {
            result = new Result(this.loserManager, this.activeList, this.resultList, this.currentCollectTime, done, this.linguist.getSearchGraph().getWordTokenFirst(), true);
        }
        if (this.showTokenCount) {
            this.showTokenCount();
        }
        return result;
    }

    /** Scores, prunes and grows one frame of the fast-match search, if the scorer produced data. */
    private void fastMatchRecognize() {
        boolean more = this.scoreFastMatchTokens();
        if (more) {
            this.pruneFastMatchBranches();
            ++this.currentFastMatchFrameNumber;
            this.createFastMatchBestTokenMap();
            this.growFastmatchBranches();
        }
    }

    /** Replaces the fast-match best-token map with a fresh one sized from the active list (at least 1). */
    protected void createFastMatchBestTokenMap() {
        int mapSize = Math.max(1, this.fastmatchActiveList.size() * 10);
        this.fastMatchBestTokenMap = new HashMap<SearchState, Token>(mapSize);
    }

    /**
     * Resets the fast match for a new utterance.
     * Seeds the fast match with the initial search state and runs it {@code lookaheadWindow - 1}
     * frames ahead (fewer if its stream ends first), then starts the main search.
     */
    @Override
    protected void localStart() {
        this.currentFastMatchFrameNumber = 0;
        Sphinx3Loader tiedLoader = tiedMixtureLoader();
        if (tiedLoader != null) {
            tiedLoader.clearGauScores();
        }
        // Lookahead frames left over from a previous utterance must not bleed into this one.
        this.ciScores.clear();
        this.penalties.clear();
        this.fastmatchActiveList = this.fastmatchActiveListFactory.newInstance();
        SearchState fmInitState = this.fastmatchLinguist.getSearchGraph().getInitialState();
        this.fastmatchActiveList.add(new Token(fmInitState, this.currentFastMatchFrameNumber));
        this.createFastMatchBestTokenMap();
        this.growFastmatchBranches();
        this.fastmatchStreamEnd = false;
        for (int i = 0; i < this.lookaheadWindow - 1 && !this.fastmatchStreamEnd; ++i) {
            this.fastMatchRecognize();
        }
        super.localStart();
    }

    /**
     * Expands every fast-match token at or above the beam threshold into a new active list.
     * Along the way it records this frame's best score per phone base id, and overall, for
     * {@link PhoneHmmSearchState} tokens, and appends them to the CI score queue.
     */
    protected void growFastmatchBranches() {
        this.growTimer.start();
        ActiveList oldActiveList = this.fastmatchActiveList;
        this.fastmatchActiveList = this.fastmatchActiveListFactory.newInstance();
        float fastmatchThreshold = oldActiveList.getBeamThreshold();
        float[] frameCiScores = new float[INITIAL_CI_SCORE_CAPACITY];
        Arrays.fill(frameCiScores, -Float.MAX_VALUE);
        float frameMaxCiScore = -Float.MAX_VALUE;
        for (Token token : oldActiveList) {
            float tokenScore = token.getScore();
            if (tokenScore < fastmatchThreshold) {
                continue;
            }
            SearchState searchState = token.getSearchState();
            if (searchState instanceof PhoneHmmSearchState) {
                int baseId = ((PhoneHmmSearchState) searchState).getBaseId();
                if (baseId >= frameCiScores.length) {
                    // Previously a fixed 100-slot array: larger base ids would throw AIOOBE.
                    frameCiScores = growCiScoreArray(frameCiScores, baseId + 1);
                }
                if (frameCiScores[baseId] < tokenScore) {
                    frameCiScores[baseId] = tokenScore;
                }
                if (frameMaxCiScore < tokenScore) {
                    frameMaxCiScore = tokenScore;
                }
            }
            this.collectFastMatchSuccessorTokens(token);
        }
        this.ciScores.add(new FrameCiScores(frameCiScores, frameMaxCiScore));
        this.growTimer.stop();
    }

    /**
     * Returns a copy of {@code scores} that is at least {@code minLength} long.
     * New slots are filled with {@code -Float.MAX_VALUE}.
     */
    private static float[] growCiScoreArray(float[] scores, int minLength) {
        int newLength = Math.max(minLength, scores.length * 2);
        float[] grown = Arrays.copyOf(scores, newLength);
        Arrays.fill(grown, scores.length, newLength, -Float.MAX_VALUE);
        return grown;
    }

    /**
     * Scores the fast-match active list.
     *
     * @return {@code true} if the scorer returned a best token. Otherwise the fast-match
     *         stream is marked as ended and {@code false} is returned.
     */
    protected boolean scoreFastMatchTokens() {
        this.scoreTimer.start();
        Data data = this.scorer.calculateScoresAndStoreData(this.fastmatchActiveList.getTokens());
        this.scoreTimer.stop();
        Token bestToken = null;
        if (data instanceof Token) {
            bestToken = (Token) data;
        } else {
            this.fastmatchStreamEnd = true;
        }
        boolean moreTokens = bestToken != null;
        this.fastmatchActiveList.setBestToken(bestToken);
        this.monitorStates(this.fastmatchActiveList);
        this.curTokensScored.value += (double) this.fastmatchActiveList.size();
        this.totalTokensScored.value += (double) this.fastmatchActiveList.size();
        return moreTokens;
    }

    /** Prunes the fast-match active list with the shared pruner. */
    protected void pruneFastMatchBranches() {
        this.pruneTimer.start();
        this.fastmatchActiveList = this.pruner.prune(this.fastmatchActiveList);
        this.pruneTimer.stop();
    }

    /** Returns the best fast-match token recorded for {@code state} this frame, or {@code null}. */
    protected Token getFastMatchBestToken(SearchState state) {
        return this.fastMatchBestTokenMap.get(state);
    }

    /** Records {@code token} as the best fast-match token for {@code state} this frame. */
    protected void setFastMatchBestToken(Token token, SearchState state) {
        this.fastMatchBestTokenMap.put(state, token);
    }

    /**
     * Expands {@code token} into the fast-match active list.
     *
     * <ul>
     *   <li>Non-emitting successors are followed recursively, unless already visited.</li>
     *   <li>Emitting successors become a new best token for their state, or update the existing
     *       one when the new entry score is greater than or equal to it.</li>
     * </ul>
     */
    protected void collectFastMatchSuccessorTokens(Token token) {
        SearchState state = token.getSearchState();
        SearchStateArc[] arcs = state.getSuccessors();
        for (SearchStateArc arc : arcs) {
            SearchState nextState = arc.getState();
            float logEntryScore = token.getScore() + arc.getProbability();
            Token predecessor = this.getResultListPredecessor(token);
            if (!nextState.isEmitting()) {
                Token newToken = new Token(predecessor, nextState, logEntryScore, arc.getInsertionProbability(), arc.getLanguageProbability(), this.currentFastMatchFrameNumber);
                this.tokensCreated.value += 1.0;
                // The visited check stops endless expansion through loops of non-emitting states.
                if (!this.isVisited(newToken)) {
                    this.collectFastMatchSuccessorTokens(newToken);
                }
                continue;
            }
            Token bestToken = this.getFastMatchBestToken(nextState);
            if (bestToken == null) {
                Token newToken = new Token(predecessor, nextState, logEntryScore, arc.getInsertionProbability(), arc.getLanguageProbability(), this.currentFastMatchFrameNumber);
                this.tokensCreated.value += 1.0;
                this.setFastMatchBestToken(newToken, nextState);
                this.fastmatchActiveList.add(newToken);
            } else if (bestToken.getScore() <= logEntryScore) {
                // Ties update here, unlike the main search, which requires a strictly better score.
                bestToken.update(predecessor, nextState, logEntryScore, arc.getInsertionProbability(), arc.getLanguageProbability(), this.currentFastMatchFrameNumber);
            }
        }
    }

    /**
     * Expands {@code token} in the main search.
     * Arcs from LexTree non-emitting, word or end-unit states into a {@code LexTreeHMMState} are
     * skipped when {@code tokenScore + lookaheadWeight * penalty} falls below the beam threshold.
     */
    @Override
    protected void collectSuccessorTokens(Token token) {
        if (token.isFinal()) {
            this.resultList.add(this.getResultListPredecessor(token));
            return;
        }
        if (!token.isEmitting() && this.keepAllTokens && this.isVisited(token)) {
            return;
        }
        SearchState state = token.getSearchState();
        SearchStateArc[] arcs = state.getSuccessors();
        Token predecessor = this.getResultListPredecessor(token);
        float tokenScore = token.getScore();
        float beamThreshold = this.activeList.getBeamThreshold();
        boolean stateProducesPhoneHmms = state instanceof LexTreeLinguist.LexTreeNonEmittingHMMState || state instanceof LexTreeLinguist.LexTreeWordState || state instanceof LexTreeLinguist.LexTreeEndUnitState;
        for (SearchStateArc arc : arcs) {
            SearchState nextState = arc.getState();
            if (stateProducesPhoneHmms && nextState instanceof LexTreeLinguist.LexTreeHMMState) {
                int baseId = ((LexTreeLinguist.LexTreeHMMState) nextState).getHMMState().getHMM().getBaseUnit().getBaseID();
                Float penalty = this.penalties.get(baseId);
                if (penalty == null) {
                    penalty = this.updateLookaheadPenalty(baseId);
                }
                if (tokenScore + this.lookaheadWeight * penalty.floatValue() < beamThreshold) {
                    continue;
                }
            }
            if (this.checkStateOrder) {
                this.checkStateOrder(state, nextState);
            }
            float logEntryScore = tokenScore + arc.getProbability();
            Token bestToken = this.getBestToken(nextState);
            if (bestToken == null) {
                Token newBestToken = new Token(predecessor, nextState, logEntryScore, arc.getInsertionProbability(), arc.getLanguageProbability(), this.currentCollectTime);
                this.tokensCreated.value += 1.0;
                this.setBestToken(newBestToken, nextState);
                this.activeListAdd(newBestToken);
            } else if (bestToken.getScore() < logEntryScore) {
                Token oldPredecessor = bestToken.getPredecessor();
                bestToken.update(predecessor, nextState, logEntryScore, arc.getInsertionProbability(), arc.getLanguageProbability(), this.currentCollectTime);
                if (this.buildWordLattice && nextState instanceof WordSearchState) {
                    this.loserManager.addAlternatePredecessor(bestToken, oldPredecessor);
                }
            } else if (this.buildWordLattice && nextState instanceof WordSearchState && predecessor != null) {
                this.loserManager.addAlternatePredecessor(bestToken, predecessor);
            }
        }
    }

    /**
     * Computes and caches the lookahead penalty for a phone base id.
     * The penalty is the maximum, over all queued fast-match frames, of (best score for
     * {@code baseId} - best phone score in that frame).
     *
     * @return {@code 0} when no frames are queued; in that case nothing is cached
     */
    private Float updateLookaheadPenalty(int baseId) {
        if (this.ciScores.isEmpty()) {
            return Float.valueOf(0.0f);
        }
        float penalty = -Float.MAX_VALUE;
        for (FrameCiScores frame : this.ciScores) {
            float diff = frame.scoreFor(baseId) - frame.maxScore;
            if (diff > penalty) {
                penalty = diff;
            }
        }
        this.penalties.put(baseId, Float.valueOf(penalty));
        return Float.valueOf(penalty);
    }

    /** Best CI phone scores of one fast-match frame. */
    private static final class FrameCiScores {
        /** Best score per phone base id; unseen ids hold {@code -Float.MAX_VALUE}. */
        final float[] scores;
        /** Best score over all phone tokens in the frame. */
        final float maxScore;

        FrameCiScores(float[] scores, float maxScore) {
            this.scores = scores;
            this.maxScore = maxScore;
        }

        /** Returns the recorded score for {@code baseId}, treating ids past the array end as unseen. */
        float scoreFor(int baseId) {
            return baseId < this.scores.length ? this.scores[baseId] : -Float.MAX_VALUE;
        }
    }
}

