#'
#' H2O Grid Support
#'
#' Provides a set of functions to launch a grid search and get
#' its results.

#-------------------------------------
# Grid-related functions start here :)
#-------------------------------------

#'
#' Launch grid search with given algorithm and parameters.
#'
#' Submits a grid search job for the given algorithm, waits for the job to
#' finish and then fetches the resulting grid with \code{h2o.getGrid}.
#'
#' @param algorithm  name of algorithm to use in grid search (gbm, randomForest, kmeans, glm, deeplearning, naivebayes, pca)
#' @param grid_id  optional id for resulting grid search, if it is not specified then it is autogenerated
#' @param ...  arguments describing parameters to use with algorithm (i.e., x, y, training_frame).
#'        Look at the specific algorithm - h2o.gbm, h2o.glm, h2o.kmeans, h2o.deeplearning
#' @param hyper_params  named list of hyper parameters (i.e., \code{list(ntrees=c(1,2), max_depth=c(5,7))})
#' @param is_supervised  if specified, overrides the default heuristic which decides whether the given
#'        algorithm name and parameters specify a supervised or an unsupervised algorithm.
#' @param do_hyper_params_check  perform a client-side check of every combination of the specified hyper
#'        parameters. It can be time expensive for a large hyper space
#' @param conn connection to H2O cluster
#' @return the grid object returned by \code{h2o.getGrid} for the finished grid search
#' @importFrom jsonlite toJSON
#' @examples
#' library(h2o)
#' library(jsonlite)
#' localH2O <- h2o.init()
#' iris.hex <- as.h2o(iris)
#' grid <- h2o.grid("gbm", x = c(1:4), y = 5, training_frame = iris.hex,
#'                  hyper_params = list(ntrees = c(1,2,3)))
#' # Get grid summary
#' summary(grid)
#' # Fetch grid models
#' model_ids <- grid@@model_ids
#' models <- lapply(model_ids, function(id) { h2o.getModel(id)})
#' @export
h2o.grid <- function(algorithm,
                     grid_id,
                     ...,
                     hyper_params = list(),
                     is_supervised = NULL,
                     do_hyper_params_check = FALSE,
                     conn = h2o.getConnection())
{
  # Extract parameters
  dots <- list(...)
  algorithm <- .h2o.unifyAlgoName(algorithm)
  model_param_names <- names(dots)
  hyper_param_names <- names(hyper_params)

  # Reject parameters which are given both as a fixed model parameter and as a hyper parameter
  overlapping_params <- intersect(model_param_names, hyper_param_names)
  if (length(overlapping_params) > 0L) {
    stop(paste0("The following parameters are defined as common model parameters and also as hyper parameters: ",
                .collapse(overlapping_params), "! Please choose only one way!"))
  }

  # Get model builder parameters for this model
  all_params <- .h2o.getModelParameters(algo = algorithm)

  # Prepare model parameters
  params <- .h2o.prepareModelParameters(algo = algorithm, params = dots, is_supervised = is_supervised)
  # Validation of input key
  .key.validate(params$key_value)

  # Validate all hyper parameters against REST API end-point
  if (do_hyper_params_check) {
    # Validates a single point of the hyper space: the base model parameters
    # combined with one concrete value per hyper parameter.
    validate_point <- function(hparams) {
      # Integers are sent as plain numerics
      params_for_validation <- lapply(append(params, hparams), function(x) {
        if (is.integer(x)) as.numeric(x) else x
      })
      # We have to repeat part of work used by model builders
      params_for_validation <- .h2o.checkAndUnifyModelParameters(algo = algorithm, allParams = all_params,
                                                                 params = params_for_validation)
      .h2o.validateModelParameters(conn, algorithm, params_for_validation)
    }

    if (length(hyper_params) == 0L) {
      # A hyper space without dimensions still has one point: the base parameters alone
      validate_point(list())
    } else {
      # One row per combination; each cell is an index into the matching hyper parameter.
      # seq_along() (not 1:length()) keeps an empty hyper parameter from producing
      # the bogus indices 1 and 0 - it simply yields no combinations to check.
      index_grid <- expand.grid(lapply(hyper_params, seq_along))
      for (row in seq_len(nrow(index_grid))) {
        # Fill hyper parameters for this combination
        hparams <- lapply(hyper_param_names, function(name) {
          hyper_params[[name]][[index_grid[[name]][[row]]]]
        })
        names(hparams) <- hyper_param_names
        validate_point(hparams)
      }
    }
  }

  # Verify and unify the parameters
  params <- .h2o.checkAndUnifyModelParameters(algo = algorithm, allParams = all_params,
                                              params = params, hyper_params = hyper_params)
  # Validate and unify hyper parameters
  hyper_values <- .h2o.checkAndUnifyHyperParameters(algo = algorithm,
                                                    allParams = all_params, hyper_params = hyper_params,
                                                    do_hyper_params_check = do_hyper_params_check)
  # Append grid parameters in JSON form
  params$hyper_parameters <- toJSON(hyper_values, digits = 99)

  # Append grid_id if it is specified
  if (!missing(grid_id)) params$grid_id <- grid_id

  # Trigger grid search job
  res <- .h2o.__remoteSend(conn, .h2o.__GRID(algorithm), h2oRestApiVersion = 99, .params = params, method = "POST")
  grid_id <- res$job$dest$name
  job_key <- res$job$key$name
  # Wait for grid job to finish
  .h2o.__waitOnJob(conn, job_key)

  h2o.getGrid(grid_id = grid_id, conn = conn)
}

#' Get a grid object from H2O distributed K/V store.
#'
#' Fetches the grid description over the REST API and wraps it in an
#' \code{H2OGrid} object.
#'
#' @param grid_id  ID of existing grid object to fetch
#' @param conn  H2O connection
#' @return an \code{H2OGrid} object with slots \code{grid_id}, \code{model_ids},
#'         \code{hyper_names}, \code{failed_params}, \code{failure_details},
#'         \code{failure_stack_traces} and \code{failed_raw_params}
#' @examples
#' library(h2o)
#' library(jsonlite)
#' localH2O <- h2o.init()
#' iris.hex <- as.h2o(iris)
#' h2o.grid("gbm", grid_id = "gbm_grid", x = c(1:4), y = 5,
#'          training_frame = iris.hex, hyper_params = list(ntrees = c(1,2,3)))
#' grid <- h2o.getGrid("gbm_grid")
#' # Get grid summary
#' summary(grid)
#' # Fetch grid models
#' model_ids <- grid@@model_ids
#' models <- lapply(model_ids, function(id) { h2o.getModel(id)})
#' @export
h2o.getGrid <- function(grid_id, conn = h2o.getConnection()) {
  json <- .h2o.__remoteSend(conn, method = "GET", h2oRestApiVersion = 99, .h2o.__GRIDS(grid_id))

  model_ids <- lapply(json$model_ids, function(model_id) { model_id$name })
  hyper_names <- lapply(json$hyper_names, identity)
  failed_params <- lapply(json$failed_params, function(param) {
    # Only a NULL or a single NA marks a missing entry. The length guard keeps
    # `||` away from multi-element (or empty) is.na() results, which raise an
    # error instead of yielding a usable condition.
    is_scalar_na <- length(param) == 1L && isTRUE(is.na(param))
    if (is.null(param) || is_scalar_na) NULL else param
  })
  failure_details <- lapply(json$failure_details, identity)
  failure_stack_traces <- lapply(json$failure_stack_traces, identity)
  # A list here is replaced by an empty matrix; any other value is passed through as is
  failed_raw_params <- if (is.list(json$failed_raw_params)) matrix(nrow = 0, ncol = 0) else json$failed_raw_params

  new("H2OGrid",
      grid_id = json$grid_id$name,
      model_ids = model_ids,
      hyper_names = hyper_names,
      failed_params = failed_params,
      failure_details = failure_details,
      failure_stack_traces = failure_stack_traces,
      failed_raw_params = failed_raw_params)
}


