// This file contains code from external sources.
// Attributions: https://github.com/wasmerio/wasmer/blob/main/docs/ATTRIBUTIONS.md

//! WebAssembly module and function translation state.
//!
//! The `ModuleTranslationState` struct defined in this module is used to keep track of data about
//! the whole WebAssembly module, such as the decoded type signatures.
//!
//! The `FuncTranslationState` struct defined in this module is used to keep track of the WebAssembly
//! value and control stacks during the translation of a single function.

use super::func_environ::{FuncEnvironment, GlobalVariable};
use crate::heap::Heap;
use crate::translator::code_translator::CatchClause;
use crate::{HashMap, Occupied, Vacant};
use cranelift_codegen::ir::{self, Block, Inst, Value};
use cranelift_frontend::FunctionBuilder;
use itertools::Itertools;
use std::vec::Vec;
use wasmer_types::{
    CATCH_ALL_TAG_VALUE, FunctionIndex, GlobalIndex, MemoryIndex, SignatureIndex, WasmResult,
};

/// Information about the presence of an associated `else` for an `if`, or the
/// lack thereof.
///
/// Stored in the `else_data` field of `ControlStackFrame::If`.
#[derive(Debug)]
pub enum ElseData {
    /// The `if` does not have an `else` block yet.
    ///
    /// This doesn't mean that it will never have an `else`, just that we
    /// haven't seen it yet.
    NoElse {
        /// If we discover that we need an `else` block, this is the jump
        /// instruction that needs to be fixed up to point to the new `else`
        /// block rather than the destination block after the `if...end`.
        branch_inst: Inst,

        /// The placeholder block to be replaced if an `else` block is created.
        placeholder: Block,
    },

    /// We have already allocated an `else` block.
    ///
    /// Usually we don't know whether we will hit an `if .. end` or an `if
    /// .. else .. end`, but sometimes we can tell based on the block's type
    /// signature that the signature is not valid if there isn't an `else`. In
    /// these cases, we pre-allocate the `else` block.
    WithElse {
        /// This is the pre-allocated `else` block.
        else_block: Block,
    },
}

/// A control stack frame can be an `if`, a `block` or a `loop`, each one having the following
/// fields:
///
/// - `destination`: reference to the `Block` that will hold the code after the control block;
/// - `num_param_values`: number of values the control block takes as parameters;
/// - `num_return_values`: number of values returned by the control block;
/// - `original_stack_size`: size of the value stack recorded when the frame was pushed. For
///   `block` and `loop` frames this excludes the frame's parameters; for `if` frames it includes
///   one copy of them, because `FuncTranslationState::push_if` duplicates the parameters.
///
/// Moreover, the `if` frame has the `else_data` field that records whether an `else` block has
/// been allocated yet (see `ElseData`). The `loop` frame has a `header` field that references
/// the `Block` that contains the beginning of the body of the loop. The `if` and `block` frames
/// carry an `exit_is_branched_to` flag, which starts out `false` and is set by
/// `set_branched_to_exit`.
#[derive(Debug)]
pub enum ControlStackFrame {
    /// Frame for an `if`, with or without an `else`.
    If {
        destination: Block,
        else_data: ElseData,
        num_param_values: usize,
        num_return_values: usize,
        original_stack_size: usize,
        exit_is_branched_to: bool,
        /// The block type that was passed to `FuncTranslationState::push_if`.
        blocktype: wasmer_compiler::wasmparser::BlockType,
        /// Was the head of the `if` reachable?
        head_is_reachable: bool,
        /// What was the reachability at the end of the consequent?
        ///
        /// This is `None` until we're finished translating the consequent, and
        /// is set to `Some` either by hitting an `else` when we will begin
        /// translating the alternative, or by hitting an `end` in which case
        /// there is no alternative.
        consequent_ends_reachable: Option<bool>,
        // Note: no need for `alternative_ends_reachable` because that is just
        // `state.reachable` when we hit the `end` in the `if .. else .. end`.
    },
    /// Frame for a `block`, which is also used for try-table scopes.
    Block {
        destination: Block,
        num_param_values: usize,
        num_return_values: usize,
        original_stack_size: usize,
        exit_is_branched_to: bool,
        /// When this block corresponds to a try-table, keep the handler state
        /// checkpoint and the list of catch blocks to seal once the scope ends.
        /// `None` for a plain block.
        try_table_info: Option<(HandlerStateCheckpoint, Vec<Block>)>,
    },
    /// Frame for a `loop`; branches to it target `header` rather than `destination`.
    Loop {
        destination: Block,
        header: Block,
        num_param_values: usize,
        num_return_values: usize,
        original_stack_size: usize,
    },
}

/// Accessors and value-stack helpers shared by every kind of control stack frame.
impl ControlStackFrame {
    /// Number of values the control block returns.
    pub fn num_return_values(&self) -> usize {
        match self {
            Self::If {
                num_return_values, ..
            }
            | Self::Block {
                num_return_values, ..
            }
            | Self::Loop {
                num_return_values, ..
            } => *num_return_values,
        }
    }

    /// Number of values the control block takes as parameters.
    pub fn num_param_values(&self) -> usize {
        match self {
            Self::If {
                num_param_values, ..
            }
            | Self::Block {
                num_param_values, ..
            }
            | Self::Loop {
                num_param_values, ..
            } => *num_param_values,
        }
    }

    /// The `Block` holding the code that follows this control block.
    pub fn following_code(&self) -> Block {
        match self {
            Self::If { destination, .. }
            | Self::Block { destination, .. }
            | Self::Loop { destination, .. } => *destination,
        }
    }

    /// The `Block` a branch to this frame jumps to: the loop header for a
    /// `loop`, the destination block for an `if` or a `block`.
    pub fn br_destination(&self) -> Block {
        match self {
            Self::Loop { header, .. } => *header,
            Self::If { destination, .. } | Self::Block { destination, .. } => *destination,
        }
    }

    /// Private helper. Use `truncate_value_stack_to_else_params()` or
    /// `truncate_value_stack_to_original_size()` to restore value-stack state.
    fn original_stack_size(&self) -> usize {
        match self {
            Self::If {
                original_stack_size,
                ..
            }
            | Self::Block {
                original_stack_size,
                ..
            }
            | Self::Loop {
                original_stack_size,
                ..
            } => *original_stack_size,
        }
    }

    /// True only for `loop` frames.
    pub fn is_loop(&self) -> bool {
        matches!(self, Self::Loop { .. })
    }

    /// Whether `set_branched_to_exit` has been called on this frame.
    ///
    /// `loop` frames do not track this and always report `false`.
    pub fn exit_is_branched_to(&self) -> bool {
        match self {
            Self::Loop { .. } => false,
            Self::If {
                exit_is_branched_to,
                ..
            }
            | Self::Block {
                exit_is_branched_to,
                ..
            } => *exit_is_branched_to,
        }
    }

    /// Record that the exit of this frame is branched to.
    ///
    /// This is a no-op for `loop` frames, which have no such flag.
    pub fn set_branched_to_exit(&mut self) {
        match self {
            Self::Loop { .. } => {}
            Self::If {
                exit_is_branched_to,
                ..
            }
            | Self::Block {
                exit_is_branched_to,
                ..
            } => *exit_is_branched_to = true,
        }
    }

    /// Pop values from the value stack so that it is left at the
    /// input-parameters to an else-block.
    ///
    /// Only meaningful for `if` frames; this is checked by a debug assertion.
    pub fn truncate_value_stack_to_else_params(&self, stack: &mut Vec<Value>) {
        debug_assert!(matches!(self, Self::If { .. }));
        let else_params_end = self.original_stack_size();
        stack.truncate(else_params_end);
    }

    /// Pop values from the value stack so that it is left at the state it was
    /// before this control-flow frame.
    pub fn truncate_value_stack_to_original_size(&self, stack: &mut Vec<Value>) {
        let recorded_size = self.original_stack_size();
        // The "If" frame pushes its parameters twice, so they're available to the else block
        // (see also `FuncTranslationState::push_if`).
        // Yet, the original_stack_size member accounts for them only once, so that the else
        // block can see the same number of parameters as the consequent block. As a matter of
        // fact, we need to subtract an extra number of parameter values for if blocks.
        let duplicated_params = match self {
            Self::If {
                num_param_values, ..
            } => {
                debug_assert!(*num_param_values <= recorded_size);
                *num_param_values
            }
            Self::Block { .. } | Self::Loop { .. } => 0,
        };
        stack.truncate(recorded_size - duplicated_params);
    }

    /// Restore exception handler state and seal catch blocks when exiting a
    /// try-table scope.
    ///
    /// Does nothing unless this is a `block` frame carrying `try_table_info`.
    pub fn restore_catch_handlers(
        &self,
        handlers: &mut HandlerState,
        builder: &mut FunctionBuilder,
    ) {
        let Self::Block {
            try_table_info: Some((checkpoint, catch_blocks)),
            ..
        } = self
        else {
            return;
        };

        handlers.restore_checkpoint(*checkpoint);
        for &catch_block in catch_blocks {
            builder.seal_block(catch_block);
        }
    }
}

/// Contains information passed along during a function's translation and that records:
///
/// - The current value and control stacks.
/// - The exception handler state for the enclosing try-table scopes.
/// - Whether the code currently being translated is reachable.
/// - Caches of the globals, heaps, signatures and function references already created for
///   this function through the `FuncEnvironment`.
pub struct FuncTranslationState {
    /// A stack of values corresponding to the active values in the input wasm function at this
    /// point.
    pub(crate) stack: Vec<Value>,
    /// A stack of active control flow operations at this point in the input wasm function.
    pub(crate) control_stack: Vec<ControlStackFrame>,
    /// Exception handler state used to attach catch blocks to try-calls.
    pub(crate) handlers: HandlerState,
    /// Is the current translation state still reachable? This is false when translating operators
    /// like End, Return, or Unreachable.
    pub(crate) reachable: bool,

    // Map of global variables that have already been created by `FuncEnvironment::make_global`.
    // Filled lazily by `get_global`.
    globals: HashMap<GlobalIndex, GlobalVariable>,

    // Map of heaps that have been created by `FuncEnvironment::make_heap`.
    // Filled lazily by `get_heap`.
    heaps: HashMap<MemoryIndex, Heap>,

    // Map of indirect call signatures that have been created by
    // `FuncEnvironment::make_indirect_sig()`.
    // Stores both the signature reference and the number of WebAssembly arguments.
    // Filled lazily by `get_indirect_sig`.
    signatures: HashMap<SignatureIndex, (ir::SigRef, usize)>,

    // Imported and local functions that have been created by
    // `FuncEnvironment::make_direct_func()`.
    // Stores both the function reference and the number of WebAssembly arguments.
    // Filled lazily by `get_direct_func`.
    functions: HashMap<FunctionIndex, (ir::FuncRef, usize)>,
}

// Public methods that are exposed to non-`cranelift_wasm` API consumers.
impl FuncTranslationState {
    /// True if the current translation state expresses reachable code, false if it is unreachable.
    ///
    /// This simply reports the current value of the `reachable` field.
    // NOTE(review): the `dead_code` allowance suggests nothing in this crate calls this
    // accessor (the field itself is `pub(crate)`) — confirm it is still needed.
    #[inline]
    #[allow(dead_code)]
    pub fn reachable(&self) -> bool {
        self.reachable
    }
}

/// Snapshot of a `HandlerState`, produced by `HandlerState::take_checkpoint` and consumed by
/// `HandlerState::restore_checkpoint`.
///
/// Field `0` is the number of handler blocks and field `1` the number of catch clauses that
/// were registered when the checkpoint was taken.
#[derive(Clone, Copy, Debug)]
pub(crate) struct HandlerStateCheckpoint(usize, usize);

/// Exception handler bookkeeping: the handler blocks and catch clauses registered so far.
///
/// Both lists grow through `add_handler` / `add_clause` and are cut back to an earlier length
/// by `restore_checkpoint`.
#[derive(Default)]
pub(crate) struct HandlerState {
    // Handler blocks in registration order; the last one is what `landing_pad` reports.
    handlers: Vec<Block>,
    // Catch clauses in registration order; `unique_clauses` walks them from the end.
    clauses: Vec<CatchClause>,
}

/// Result of `HandlerState::landing_pad`: a handler block together with the catch clauses
/// that apply to it.
#[derive(Debug)]
pub(crate) struct LandingPad {
    /// The most recently registered handler block.
    pub(crate) block: Block,
    /// The clauses returned by `HandlerState::unique_clauses` at the time of the query.
    pub(crate) clauses: Vec<CatchClause>,
}

impl HandlerState {
    pub fn add_handler(&mut self, block: Block) {
        self.handlers.push(block);
    }

    pub fn add_clause(&mut self, clause: CatchClause) {
        self.clauses.push(clause);
    }

    pub fn take_checkpoint(&self) -> HandlerStateCheckpoint {
        HandlerStateCheckpoint(self.handlers.len(), self.clauses.len())
    }

    pub fn restore_checkpoint(&mut self, checkpoint: HandlerStateCheckpoint) {
        debug_assert!(checkpoint.0 <= self.handlers.len());
        debug_assert!(checkpoint.1 <= self.clauses.len());
        self.handlers.truncate(checkpoint.0);
        self.clauses.truncate(checkpoint.1);
    }

    /// Get the latest landing pad block including all the tags covered by it.
    pub fn landing_pad(&self) -> Option<LandingPad> {
        self.handlers.last().copied().map(|block| LandingPad {
            block,
            clauses: self.unique_clauses(),
        })
    }

    /// Returns an iterator over the catch clauses in reverse order, with duplicates removed.
    pub fn unique_clauses(&self) -> Vec<CatchClause> {
        self.clauses
            .iter()
            // Starting with the inner-most try_table catch clauses.
            .rev()
            .unique_by(|c| c.tag_value)
            // We can ignore every tag followed by the CatchAll.
            .take_while_inclusive(|c| c.tag_value != CATCH_ALL_TAG_VALUE)
            .cloned()
            .collect()
    }

    pub fn is_empty(&self) -> bool {
        self.handlers.is_empty()
    }

    pub fn clear(&mut self) {
        self.handlers.clear();
        self.clauses.clear();
    }
}

impl FuncTranslationState {
    /// Construct a new, empty, `FuncTranslationState`.
    ///
    /// The state starts out reachable, with empty stacks and empty entity caches.
    pub(crate) fn new() -> Self {
        Self {
            stack: Vec::new(),
            // TODO(reftypes):
            //metadata_stack: Vec::new(),
            control_stack: Vec::new(),
            handlers: HandlerState::default(),
            reachable: true,
            globals: HashMap::new(),
            heaps: HashMap::new(),
            signatures: HashMap::new(),
            functions: HashMap::new(),
        }
    }

    /// Reset the per-function state so the object can be reused for another function.
    ///
    /// The value stack, the control stack and the handler state are expected to be empty
    /// already. That expectation is only checked by `debug_assert!`; the two stacks are not
    /// cleared here.
    fn clear(&mut self) {
        debug_assert!(self.stack.is_empty());
        debug_assert!(self.control_stack.is_empty());
        debug_assert!(self.handlers.is_empty());
        self.reachable = true;
        self.handlers.clear();
        self.globals.clear();
        self.heaps.clear();
        self.signatures.clear();
        self.functions.clear();
    }

    /// Initialize the state for compiling a function with the given signature.
    ///
    /// This resets the state to containing only a single block representing the whole function.
    /// The exit block is the last block in the function which will contain the return instruction.
    pub(crate) fn initialize(&mut self, sig: &ir::Signature, exit_block: Block) {
        self.clear();
        // Only return values with the `Normal` purpose count as results of the
        // function-level block.
        let num_results = sig
            .returns
            .iter()
            .filter(|ret| ret.purpose == ir::ArgumentPurpose::Normal)
            .count();
        self.push_block(exit_block, 0, num_results);
    }

    /// Push a value.
    pub(crate) fn push1(&mut self, val: Value) {
        self.stack.push(val);
    }

    /// Push multiple values.
    pub(crate) fn pushn(&mut self, vals: &[Value]) {
        self.stack.extend_from_slice(vals);
    }

    /// Pop one value.
    ///
    /// # Panics
    ///
    /// Panics if the value stack is empty.
    pub(crate) fn pop1(&mut self) -> Value {
        self.stack
            .pop()
            .expect("attempted to pop a value from an empty stack")
    }

    /// Peek at the top of the stack without popping it.
    ///
    /// # Panics
    ///
    /// Panics if the value stack is empty.
    pub(crate) fn peek1(&self) -> Value {
        *self
            .stack
            .last()
            .expect("attempted to peek at a value on an empty stack")
    }

    /// Pop two values. Return them in the order they were pushed.
    pub(crate) fn pop2(&mut self) -> (Value, Value) {
        let v2 = self.pop1();
        let v1 = self.pop1();
        (v1, v2)
    }

    /// Pop three values. Return them in the order they were pushed.
    pub(crate) fn pop3(&mut self) -> (Value, Value, Value) {
        let v3 = self.pop1();
        let v2 = self.pop1();
        let v1 = self.pop1();
        (v1, v2, v3)
    }

    /// Helper to ensure the stack size is at least as big as `n`.
    ///
    /// The check is a `debug_assert!`, so it only runs in builds with debug assertions
    /// enabled and is compiled out otherwise (e.g. in default release builds).
    #[inline]
    fn ensure_length_is_at_least(&self, n: usize) {
        debug_assert!(
            n <= self.stack.len(),
            "attempted to access {} values but stack only has {} values",
            n,
            self.stack.len()
        );
    }

    /// Pop the top `n` values on the stack.
    ///
    /// The popped values are not returned. Use `peekn` to look at them before popping.
    pub(crate) fn popn(&mut self, n: usize) {
        self.ensure_length_is_at_least(n);
        let new_len = self.stack.len() - n;
        self.stack.truncate(new_len);
    }

    /// Peek at the top `n` values on the stack in the order they were pushed.
    pub(crate) fn peekn(&self, n: usize) -> &[Value] {
        self.ensure_length_is_at_least(n);
        &self.stack[self.stack.len() - n..]
    }

    /// Mutably borrow the top `n` values on the stack, in the order they were pushed, so
    /// that they can be replaced in place.
    pub(crate) fn peekn_mut(&mut self, n: usize) -> &mut [Value] {
        self.ensure_length_is_at_least(n);
        let len = self.stack.len();
        &mut self.stack[len - n..]
    }

    /// Shared implementation of `push_block` and `push_try_table_block`.
    ///
    /// The frame's `original_stack_size` is the current stack size minus the block's
    /// parameters, which are expected to be on the stack already.
    fn push_block_impl(
        &mut self,
        following_code: Block,
        num_param_types: usize,
        num_result_types: usize,
        try_table_info: Option<(HandlerStateCheckpoint, Vec<Block>)>,
    ) {
        debug_assert!(num_param_types <= self.stack.len());
        let original_stack_size = self.stack.len() - num_param_types;
        self.control_stack.push(ControlStackFrame::Block {
            destination: following_code,
            original_stack_size,
            num_param_values: num_param_types,
            num_return_values: num_result_types,
            exit_is_branched_to: false,
            try_table_info,
        });
    }

    /// Push a block on the control stack.
    pub(crate) fn push_block(
        &mut self,
        following_code: Block,
        num_param_types: usize,
        num_result_types: usize,
    ) {
        self.push_block_impl(following_code, num_param_types, num_result_types, None);
    }

    /// Push a try-table block on the control stack.
    ///
    /// `checkpoint` and `catch_blocks` are stored in the frame so that
    /// `ControlStackFrame::restore_catch_handlers` can roll back the handler state and seal
    /// the catch blocks when the scope ends.
    pub(crate) fn push_try_table_block(
        &mut self,
        following_code: Block,
        catch_blocks: Vec<Block>,
        num_param_types: usize,
        num_result_types: usize,
        checkpoint: HandlerStateCheckpoint,
    ) {
        self.push_block_impl(
            following_code,
            num_param_types,
            num_result_types,
            Some((checkpoint, catch_blocks)),
        );
    }

    /// Push a loop on the control stack.
    ///
    /// The frame's `original_stack_size` is the current stack size minus the loop's
    /// parameters, which are expected to be on the stack already.
    pub(crate) fn push_loop(
        &mut self,
        header: Block,
        following_code: Block,
        num_param_types: usize,
        num_result_types: usize,
    ) {
        debug_assert!(num_param_types <= self.stack.len());
        let original_stack_size = self.stack.len() - num_param_types;
        self.control_stack.push(ControlStackFrame::Loop {
            header,
            destination: following_code,
            original_stack_size,
            num_param_values: num_param_types,
            num_return_values: num_result_types,
        });
    }

    /// Push an if on the control stack.
    ///
    /// The top `num_param_types` values of the value stack are duplicated, and the frame
    /// records the current `reachable` flag as `head_is_reachable`.
    ///
    /// # Panics
    ///
    /// Panics if the value stack holds fewer than `num_param_types` values.
    pub(crate) fn push_if(
        &mut self,
        destination: Block,
        else_data: ElseData,
        num_param_types: usize,
        num_result_types: usize,
        blocktype: wasmer_compiler::wasmparser::BlockType,
    ) {
        debug_assert!(num_param_types <= self.stack.len());

        // Push a second copy of our `if`'s parameters on the stack. This lets
        // us avoid saving them on the side in the `ControlStackFrame` for our
        // `else` block (if it exists), which would require a second heap
        // allocation. See also the comment in `translate_operator` for
        // `Operator::Else`.
        let stack_size_with_params = self.stack.len();
        let first_param = stack_size_with_params - num_param_types;
        self.stack.extend_from_within(first_param..);

        self.control_stack.push(ControlStackFrame::If {
            destination,
            else_data,
            // Counts the parameters exactly once: this is the stack size before the
            // duplicate copy was pushed, i.e. the new length minus `num_param_types`.
            original_stack_size: stack_size_with_params,
            num_param_values: num_param_types,
            num_return_values: num_result_types,
            exit_is_branched_to: false,
            head_is_reachable: self.reachable,
            consequent_ends_reachable: None,
            blocktype,
        });
    }
}

/// Lazily-created, cached entity references (globals, heaps, signatures, functions).
impl FuncTranslationState {
    /// Get the `GlobalVariable` reference that should be used to access the global variable
    /// `index`.
    ///
    /// The reference is created through `environ.make_global` on first use and served from
    /// the cache afterwards. Nothing is cached if creation fails.
    pub(crate) fn get_global<FE: FuncEnvironment + ?Sized>(
        &mut self,
        func: &mut ir::Function,
        index: u32,
        environ: &mut FE,
    ) -> WasmResult<GlobalVariable> {
        let index = GlobalIndex::from_u32(index);
        let slot = match self.globals.entry(index) {
            Occupied(cached) => return Ok(*cached.get()),
            Vacant(slot) => slot,
        };
        let global = environ.make_global(func, index)?;
        Ok(*slot.insert(global))
    }

    /// Get the `Heap` reference that should be used to access linear memory `index`.
    ///
    /// The reference is created through `environ.make_heap` on first use and served from
    /// the cache afterwards. Nothing is cached if creation fails.
    pub(crate) fn get_heap<FE: FuncEnvironment + ?Sized>(
        &mut self,
        func: &mut ir::Function,
        index: u32,
        environ: &mut FE,
    ) -> WasmResult<Heap> {
        let index = MemoryIndex::from_u32(index);
        let slot = match self.heaps.entry(index) {
            Occupied(cached) => return Ok(*cached.get()),
            Vacant(slot) => slot,
        };
        let heap = environ.make_heap(func, index)?;
        Ok(*slot.insert(heap))
    }

    /// Get the `SigRef` reference that should be used to make an indirect call with signature
    /// `index`, paired with the number of WebAssembly arguments in that signature.
    ///
    /// The signature is created through `environ.make_indirect_sig` on first use and served
    /// from the cache afterwards. Nothing is cached if creation fails.
    pub(crate) fn get_indirect_sig<FE: FuncEnvironment + ?Sized>(
        &mut self,
        func: &mut ir::Function,
        index: u32,
        environ: &mut FE,
    ) -> WasmResult<(ir::SigRef, usize)> {
        let index = SignatureIndex::from_u32(index);
        let slot = match self.signatures.entry(index) {
            Occupied(cached) => return Ok(*cached.get()),
            Vacant(slot) => slot,
        };
        let sig = environ.make_indirect_sig(func, index)?;
        let num_params = num_wasm_parameters(environ, &func.dfg.signatures[sig]);
        Ok(*slot.insert((sig, num_params)))
    }

    /// Get the `FuncRef` reference that should be used to make a direct call to function
    /// `index`, paired with the number of WebAssembly arguments in its signature.
    ///
    /// The function reference is created through `environ.make_direct_func` on first use and
    /// served from the cache afterwards. Nothing is cached if creation fails.
    pub(crate) fn get_direct_func<FE: FuncEnvironment + ?Sized>(
        &mut self,
        func: &mut ir::Function,
        index: u32,
        environ: &mut FE,
    ) -> WasmResult<(ir::FuncRef, usize)> {
        let index = FunctionIndex::from_u32(index);
        let slot = match self.functions.entry(index) {
            Occupied(cached) => return Ok(*cached.get()),
            Vacant(slot) => slot,
        };
        let fref = environ.make_direct_func(func, index)?;
        // The signature is looked up through the freshly created external function entry.
        let sig = func.dfg.ext_funcs[fref].signature;
        let num_params = num_wasm_parameters(environ, &func.dfg.signatures[sig]);
        Ok(*slot.insert((fref, num_params)))
    }
}

/// Count the parameters of `signature` that `environ.is_wasm_parameter` accepts.
///
/// Each parameter position from `0` up to `signature.params.len()` is tested individually.
fn num_wasm_parameters<FE: FuncEnvironment + ?Sized>(
    environ: &FE,
    signature: &ir::Signature,
) -> usize {
    let is_wasm_param = |position: &usize| environ.is_wasm_parameter(signature, *position);
    let positions = 0..signature.params.len();
    positions.filter(is_wasm_param).count()
}
