//
//  LinearExtensionGenerator.h
//  POSet
//
//  Created by Alessandro Avellone on 14/05/2019.
//  Copyright © 2019 Alessandro Avellone. All rights reserved.
//

#ifndef linearExtensionGenerator_hpp
#define linearExtensionGenerator_hpp


#include <algorithm>
#include <bitset>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <deque>
#include <fstream>
#include <limits>
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "paramType.h"
#include "utilita.h"
#include "linearExtension.h"
#include "poset.h"
#include "treeOfIdeals.h"
#include "latticeOfIdeals.h"

// ***********************************************
// ***********************************************
// ***********************************************

class LinearExtensionGenerator {
protected:
    std::uint_fast64_t current_number_le;
    std::shared_ptr<LinearExtension> currentLinearExtension;
    bool started;
public:
    std::shared_ptr<std::vector<std::shared_ptr<POSet>>> posets;
    LinearExtensionGenerator(std::shared_ptr<std::vector<std::shared_ptr<POSet>>> posets) : posets(posets) {
        this->current_number_le = 0;
    }
    virtual ~LinearExtensionGenerator() {}
    virtual void start(std::uint_fast64_t) = 0;
    virtual void next() = 0;
    virtual bool hasNext() = 0;
    virtual std::shared_ptr<std::vector<std::string>> keys() = 0;
    virtual std::uint_fast64_t numberOfLE() const = 0;
    virtual void to_file(std::fstream&, char DELIMETER = ';') = 0;

    virtual std::string to_string() const {return "";};
    std::uint_fast64_t currentNumberOfLE() const {
        return this->current_number_le;
    }
    std::shared_ptr<LinearExtension> get() {
        if (this->started == false) {
            std::string err_str = "LEG error: not started yet!";
            throw_line(err_str);
        }
        return this->currentLinearExtension;
    }
    
    std::uint_fast64_t LESize() const {
        return currentLinearExtension->size();
    }
    
};

// ***********************************************
// ***********************************************
// ***********************************************

class LEGBubleyDyer : public LinearExtensionGenerator {
private:
    std::shared_ptr<Random> rnd = nullptr;
    bool toUpdate;
    bool isSwitched;
    std::uint_fast64_t positionToUpdate;
    std::uint_fast64_t max_number_le;
    std::shared_ptr<POSet> poset;
public:
    LEGBubleyDyer(std::shared_ptr<std::vector<std::shared_ptr<POSet>>>, std::shared_ptr<Random>);
    virtual ~LEGBubleyDyer() {}
    virtual std::string to_string() const;
    virtual void start(std::uint_fast64_t);
    virtual void next();
    virtual bool hasNext() {
        return current_number_le < max_number_le;
    }
    virtual void to_file(std::fstream&, char DELIMETER = ';');
    
    virtual std::shared_ptr<std::vector<std::string>> keys() {
        return this->poset->Elements();
    }
    virtual std::uint_fast64_t numberOfLE() const {
        return max_number_le;
    }
    
    bool UpdateCounters(std::uint_fast64_t quante, std::shared_ptr<double> errore) {
        if (quante != 0 && quante != std::numeric_limits<std::uint_fast64_t>::max()) {
            max_number_le += quante;
            return true;
        } else if (quante == std::numeric_limits<std::uint_fast64_t>::max()) {
            auto n = this->evaluateNumberOfIteration(*errore);
            if (max_number_le < n) {
                max_number_le = n;
                return true;
            }
        }
        return false;
    }
    
    std::uint_fast64_t evaluateNumberOfIteration(double error) {
        std::uint_fast64_t risultato;
        size_t nelementi = this->poset->size();
        double add1 = std::pow(nelementi, 4) * std::pow(std::log(nelementi), 2);
        double add2 = std::pow(nelementi, 3) * std::log(nelementi) * std::log(1/error);
        risultato = ((std::uint_fast64_t)(std::max((add1 + add2), 1.0)));
        return risultato;
    }

private:
    std::uint_fast64_t getSetOneElement(std::set<std::uint_fast64_t>& setOne);
};

// ***********************************************
// ***********************************************
// ***********************************************

class LEGTreeOfIdeals : public LinearExtensionGenerator {
private:
    std::uint_fast64_t extension_stack_size;
    bool more_extensions;
    std::shared_ptr<LatticeOfIdeals>  latticeOfIdeals;
    std::shared_ptr<std::vector<std::uint_fast64_t>> latticeOfIdealsCrossing;
    std::shared_ptr<std::vector<bool>> moreCrossing;
    std::shared_ptr<POSet> poset;
public:
    LEGTreeOfIdeals(std::shared_ptr<std::vector<std::shared_ptr<POSet>>>);
    virtual ~LEGTreeOfIdeals() {}
    virtual std::string to_string() const;
    virtual void start(std::uint_fast64_t);
    virtual void next();
    virtual bool hasNext();
    virtual std::shared_ptr<std::vector<std::string>> keys() {
        return this->poset->Elements();
    }
    
    virtual void to_file(std::fstream&, char DELIMETER = ';');
    virtual std::uint_fast64_t numberOfLE() const {
        return 0;
    }
    
private:
    bool IsPossibleToSwitch(std::uint_fast64_t p1, std::uint_fast64_t p2);
};


// ***********************************************
// ***********************************************
// ***********************************************

class LEGFromLinearPosets : public LinearExtensionGenerator {
private:
    std::shared_ptr<std::vector<std::string>> product_keys;
public:
    LEGFromLinearPosets(std::shared_ptr<std::vector<std::shared_ptr<POSet>>> posets) : LinearExtensionGenerator(posets) {
        auto le_size = 1;
        for (auto p : *posets) {
            le_size *= p->size();
        }
        currentLinearExtension = std::make_shared<LinearExtension>(le_size);
        product_keys = std::make_shared<std::vector<std::string>>(le_size, "");
    }

    virtual ~LEGFromLinearPosets() {}
    
    virtual std::string to_string() const {
        std::string base_string = LinearExtensionGenerator::to_string();
        std::string risultato = "FromLinearPosets:";
        if (base_string != "")
            risultato += "\n\t" + FindAndReplaceAll(base_string, "\n", "\n\t");
        return risultato;
    }
    
    virtual void start(std::uint_fast64_t) {
        std::vector<std::uint_fast64_t> prod_val(posets->size(), 0);
        
        auto tostring = [this](std::vector<std::uint_fast64_t>& e)
        {
            std::string result = "";
            for (std::uint_fast64_t k = 0; k < e.size(); ++k) {
                result += this->posets->at(k)->Elements()->at(e.at(k));
                if (k < e.size() - 1) {
                    result += "_";
                }
            }
            return result;
        };
        
        auto modify = [this](std::vector<std::uint_fast64_t>& e)
        {
            std::uint_fast64_t k = e.size() - 1;
            while (true) {
                if (e.at(k) < this->posets->at(k)->Elements()->size() - 1) {
                    ++e.at(k);
                    return;
                }
                e.at(k) = 0;
                if (k == 0) {
                    break;
                }
                --k;
            }
        };
        
        for (std::uint_fast64_t k = 0; k < currentLinearExtension->size(); ++k) {
            currentLinearExtension->set(k, k);
            product_keys->at(k) = tostring(prod_val);
            modify(prod_val);
        }
        
        started = true;
        current_number_le = 1;
    }
    
    virtual void next() {
        std::string err_str = "LEGFromLinearPosets error: max number of linear extention reached!";
        throw_line(err_str);
    }
    
    virtual bool hasNext() {
        return false;
    }
    
    virtual std::shared_ptr<std::vector<std::string>> keys() {
        return product_keys;
    }
    
    
    void to_file(std::fstream& file_le, char DELIMETER) {
        if (file_le.is_open()) {
            std::string str_le = "";
            bool first = true;
            for (std::uint_fast64_t k = 0; k < this->currentLinearExtension->size(); ++k) {
                std::string nome_etichetta = product_keys->at(this->currentLinearExtension->getVal(k));
                if (first) {
                    str_le = "" + nome_etichetta;
                    first = false;
                }
                else {
                    str_le += DELIMETER + nome_etichetta;
                }
            }
            file_le  << str_le;
            file_le  << std::endl;
        }
    }
    virtual std::uint_fast64_t numberOfLE() const {
        return 1;
    }
private:
    bool IsPossibleToSwitch(std::uint_fast64_t p1, std::uint_fast64_t p2);
};

// ***********************************************
// ***********************************************
// ***********************************************

class LEGBinaryVariable : public LinearExtensionGenerator {
private:
    std::shared_ptr<BinaryVariablePOSet> binary_poset;
    std::shared_ptr<LinearExtension> le;
    std::shared_ptr<LinearExtension> le_inv;
    std::vector<std::uint_fast64_t> permutazione;
    std::vector<std::uint_fast64_t> permutazione_inv;
    std::vector<std::uint_fast64_t> permutazione_interna;
    std::uint_fast64_t numero_variabili;
    std::uint_fast64_t numero_profili;
    std::uint_fast64_t perm_var_1;
    std::uint_fast64_t perm_var_2;
    std::uint_fast64_t max_number_le;
    
public:
    std::shared_ptr<std::vector<std::shared_ptr<POSet>>> posets;
    LEGBinaryVariable(std::shared_ptr<std::vector<std::shared_ptr<POSet>>> posets) : LinearExtensionGenerator(posets) {
        if (posets->size() != 1) {
            std::string err_str = "LEGBinaryVariable error: poset";
            throw_line(err_str);
        }
        if (dynamic_pointer_cast<BinaryVariablePOSet>(posets->at(0)) == nullptr) {
            std::string err_str = "LEGBinaryVariable error: wrong poset";
            throw_line(err_str);
        }
        binary_poset = dynamic_pointer_cast<BinaryVariablePOSet>(posets->at(0));
        numero_variabili = binary_poset->NumberOfVariables();
        numero_profili = binary_poset->size();
        max_number_le = std::tgamma(numero_variabili + 1);

        currentLinearExtension = std::make_shared<LinearExtension>(numero_profili);
        
        
        permutazione.assign(numero_variabili, 0);
        permutazione_inv.assign(numero_variabili, 0);
        permutazione_interna.assign(numero_variabili - 2, 0);
        
        le = std::make_shared<LinearExtension>(numero_profili);
        le_inv = std::make_shared<LinearExtension>(numero_profili);
        
    }
    virtual ~LEGBinaryVariable() {}
    
    virtual std::string to_string() const {
        std::string base_string = LinearExtensionGenerator::to_string();
        std::string risultato = "LEGBinaryVariable:";
        if (base_string != "")
            risultato += "\n\t" + FindAndReplaceAll(base_string, "\n", "\n\t");
        return risultato;
    }
    
    virtual void start(std::uint_fast64_t) {
        perm_var_1 = 0;
        perm_var_2 = 1;
        
        for (std::uint_fast64_t i = 0, j = 0; i < numero_variabili; ++i) {
            if (i != perm_var_1 && i != perm_var_2) {
                permutazione_interna.at(j++) = i;
            }
        }
        // costruzione  due permutazioni - start
        permutazione.at(0) = perm_var_1;
        permutazione_inv.at(0) = perm_var_2;

        for (std::uint_fast64_t p = 0; p < permutazione_interna.size(); p++) {
            permutazione.at(p + 1) = permutazione_interna.at(p);
            permutazione_inv.at(p + 1) = permutazione_interna.at(permutazione_interna.size() - p - 1);
        }
        permutazione.at(numero_variabili - 1) = perm_var_2;
        permutazione_inv.at(numero_variabili - 1) = perm_var_1;
        
        // costruzione  due permutazioni - end
        BuildLEFromPermutation();
        
        for (std::uint_fast64_t k = 0; k < currentLinearExtension->size(); ++k) {
            auto val = le->getVal(k);
            currentLinearExtension->set(k, val);
        }
        
        started = true;
        current_number_le = 1;
        
        //std::cout << '\t' << vector_to_string(*permutazione) << std::endl;
        //std::cout << '\t' << vector_to_string(*permutazione_inv) << std::endl;
    }
    
    
    virtual void next() {
        if (current_number_le % 2 == 1) {
            for (std::uint_fast64_t k = 0; k < currentLinearExtension->size(); ++k) {
                auto val = le_inv->getVal(k);
                currentLinearExtension->set(k, val);
            }
        } else {
            if (!std::next_permutation(std::begin(permutazione_interna), std::end(permutazione_interna))) {
                if (perm_var_2 >= numero_variabili - 1) {
                    if (perm_var_1 >= numero_variabili - 1) {
                        std::string err_str = "LEGBinaryVariable error: next()";
                        throw_line(err_str);
                    }
                    ++perm_var_1;
                    perm_var_2 = perm_var_1;
                }
                ++perm_var_2;
                for (std::uint_fast64_t i = 0, j = 0; i < numero_variabili; ++i) {
                    if (i != perm_var_1 && i != perm_var_2) {
                        permutazione_interna.at(j++) = i;
                    }
                }
            }
            
            permutazione.at(0) = perm_var_1;
            permutazione_inv.at(0) = perm_var_2;

            for (std::uint_fast64_t p = 0; p < permutazione_interna.size(); p++) {
                permutazione.at(p + 1) = permutazione_interna.at(p);
                permutazione_inv.at(p + 1) = permutazione_interna.at(permutazione_interna.size() - p - 1);
            }
            permutazione.at(numero_variabili - 1) = perm_var_2;
            permutazione_inv.at(numero_variabili - 1) = perm_var_1;
            
            //std::cout << '\t' << vector_to_string(*permutazione) << std::endl;
            //std::cout << '\t' << vector_to_string(*permutazione_inv) << std::endl;

            
            // costruzione  due permutazioni - end
            BuildLEFromPermutation();
            
            for (std::uint_fast64_t k = 0; k < currentLinearExtension->size(); ++k) {
                auto val = le->getVal(k);
                currentLinearExtension->set(k, val);
            }
        }
        ++current_number_le;
    }
    
    virtual bool hasNext() {
        return current_number_le < max_number_le;
    }
    
    virtual std::shared_ptr<std::vector<std::string>> keys() {
        auto result = std::make_shared<std::vector<std::string>>(numero_profili);
        for (std::uint_fast64_t k = 0; k < numero_profili; ++k) {
            std::string s = std::bitset<64>(k).to_string();
            result->at(k) = s.substr(s.size() - numero_variabili);
        }
        return result;
    }

    void to_file(std::fstream& file_le, char DELIMETER) {
        if (file_le.is_open()) {
            std::string str_le = "";
            bool first = true;
            for (std::uint_fast64_t k = 0; k < this->currentLinearExtension->size(); ++k) {
                std::string nome_etichetta = binary_poset->GetEName(this->currentLinearExtension->getVal(k));
                if (first) {
                    str_le = "" + nome_etichetta;
                    first = false;
                }
                else {
                    str_le += DELIMETER + nome_etichetta;
                }
            }
            file_le  << str_le;
            file_le  << std::endl;
        }
    }
    virtual std::uint_fast64_t numberOfLE() const {
        return this->max_number_le;
    }
private:
    void BuildLEFromPermutation() {
        for (std::uint_fast64_t n = 0; n < le->size(); ++n) {
            std::uint_fast64_t value_lex = 0;
            std::uint_fast64_t value_lex_inv = 0;
            for (std::uint_fast64_t i = 0; i < numero_variabili; ++i) {
                std::uint_fast64_t old_val = (n >> (numero_variabili - i - 1) & ((std::uint_fast64_t) 1));
                
                std::uint_fast64_t new_val = old_val << (numero_variabili - permutazione.at(i) - 1);
                value_lex = value_lex | new_val;
                
                std::uint_fast64_t new_val_inv = old_val << (numero_variabili - permutazione_inv.at(i) - 1);
                value_lex_inv = value_lex_inv | new_val_inv;
            }
            le->set(n, value_lex);
            le_inv->set(n, value_lex_inv);
        }
    }
};


#endif /* linearExtensionGenerator_hpp */
