#' Weighted confusion matrix
#'
#' This function calculates the weighted confusion matrix from a caret
#' ConfusionMatrix object or a simple matrix, according to one of several
#' weighting schemas and optionally prints the weighted accuracy score.
#'
#' @param m the caret confusion matrix object or simple matrix.
#'
#' @param weight.type the weighting schema to be used. Can be one of:
#' "arithmetic" - a decreasing arithmetic progression weighting scheme,
#' "geometric" - a decreasing geometric progression weighting scheme,
#' "normal" - weights drawn from the right tail of a normal distribution,
#' "interval" - weights contained on a user-defined interval,
#' "custom" - custom weight vector defined by the user.
#'
#' @param weight.penalty determines whether the weights associated with
#' non-diagonal elements generated by the "normal", "arithmetic" and "geometric"
#' weight types are positive or negative values. By default, the value is set to
#' FALSE, which means that generated weights will be positive values.
#'
#' @param standard.deviation standard deviation of the normal distribution, if
#' the normal distribution weighting schema is used.
#'
#' @param geometric.multiplier the multiplier used to construct the geometric
#' progression series, if the geometric progression weighting scheme is used.
#'
#' @param interval.high the upper bound of the weight interval, if the interval
#' weighting scheme is used.
#'
#' @param interval.low the lower bound of the weight interval, if the interval
#' weighting scheme is used.
#'
#' @param custom.weights the vector of custom weights to be applied, if the
#' custom weighting scheme was selected. The vector's length should be equal to
#' "n", but can be larger, with excess values being ignored.
#'
#' @param print.weighted.accuracy print the weighted accuracy metric, which
#' represents the sum of all weighted confusion matrix cells divided by the
#' total number of observations.
#'
#' @return an nxn weighted confusion matrix
#'
#' @details The number of categories "n" should be greater or equal to 2.
#'
#' @usage wconfusionmatrix(m, weight.type = "arithmetic",
#'                         weight.penalty = FALSE,
#'                         standard.deviation = 2,
#'                         geometric.multiplier = 2,
#'                         interval.high=1, interval.low = -1,
#'                         custom.weights = NA,
#'                         print.weighted.accuracy = FALSE)
#'
#' @keywords weighted confusion matrix accuracy score
#'
#' @seealso [weightmatrix()]
#'
#' @author Alexandru Monahov, <https://www.alexandrumonahov.eu.org/>
#'
#' @examples
#' m = matrix(c(70,0,0,10,10,0,5,3,2), ncol = 3, nrow=3)
#' wconfusionmatrix(m, weight.type="arithmetic", print.weighted.accuracy = TRUE)
#' wconfusionmatrix(m, weight.type="geometric", print.weighted.accuracy = TRUE)
#' wconfusionmatrix(m, weight.type="interval", print.weighted.accuracy = TRUE)
#' wconfusionmatrix(m, weight.type="normal", print.weighted.accuracy = TRUE)
#' wconfusionmatrix(m, weight.type= "custom", custom.weights = c(1,0.1,0),
#'                  print.weighted.accuracy = TRUE)
#'
#' @export

wconfusionmatrix <- function(m, weight.type = "arithmetic",
                             weight.penalty = FALSE,
                             standard.deviation = 2,
                             geometric.multiplier = 2,
                             interval.high = 1, interval.low = -1,
                             custom.weights = NA,
                             print.weighted.accuracy = FALSE) {

  # Anything that is not already a matrix is coerced through as.matrix().
  if (!is.matrix(m)) {
    m <- as.matrix(m)
  }
  n <- nrow(m)

  # The weight matrix is n x n, so the input has to be square and non-empty
  # for the element-wise product at the end to be defined.
  if (n < 1 || ncol(m) != n) {
    stop("'m' must be a non-empty square matrix.", call. = FALSE)
  }

  schemes <- c("normal", "arithmetic", "geometric", "interval", "custom")
  if (!is.character(weight.type) || length(weight.type) != 1 ||
        !(weight.type %in% schemes)) {
    # Previously an unrecognised scheme fell through every branch and the
    # function silently returned NULL.
    stop("'weight.type' must be one of: ",
         paste0("\"", schemes, "\"", collapse = ", "), ".",
         call. = FALSE)
  }

  # dist[i, j] = |i - j|: how many categories away from the diagonal a cell is.
  pos <- seq(0, n - 1, 1)
  dist <- abs(outer(pos, pos, `-`))

  # Expand a vector of per-distance weights (element k + 1 is the weight for
  # distance k) into the full n x n weight matrix. Indexing directly by
  # distance avoids the collision a sequential "replace cells equal to i"
  # loop suffers when a weight value coincides with a later index.
  spread <- function(per.distance) {
    matrix(per.distance[dist + 1], nrow = n, ncol = n)
  }

  mat <- switch(weight.type,
    normal = {
      # Normal density of the distance, rescaled so the diagonal equals 1.
      dnorm(dist, mean = 0, sd = standard.deviation) /
        dnorm(0, mean = 0, sd = standard.deviation)
    },
    arithmetic = {
      # Linearly decreasing from 1 on the diagonal to 0 at distance n - 1.
      ((n - 1) - dist) / (n - 1)
    },
    geometric = {
      mult <- geometric.multiplier
      if (mult <= 0) {
        stop("Please enter a multiplier value greater than zero.")
      }
      if (mult == 1) {
        # A multiplier of 1 gives a constant series; use the linear scheme.
        lookup <- 1 - pos / (n - 1)
      } else {
        # Min-max scale the geometric series to [0, 1]; flip it when the
        # series is increasing so that the diagonal always gets weight 1.
        series <- mult^pos
        scaled <- (series - min(series)) / (max(series) - min(series))
        lookup <- if (mult > 1) 1 - scaled else scaled
      }
      spread(lookup)
    },
    interval = {
      # Equally spaced weights from interval.high (diagonal) to interval.low.
      spread(seq(interval.high, interval.low, length.out = n))
    },
    custom = {
      if (length(custom.weights) < n) {
        stop("'custom.weights' must contain at least ", n,
             " values (one per distance from the diagonal).",
             call. = FALSE)
      }
      # Entries beyond the n-th are never indexed and are therefore ignored.
      spread(custom.weights)
    }
  )

  # The penalty option only applies to the generated schemes: off-diagonal
  # weights w become -(1 - w) while the diagonal keeps a weight of 1.
  if (weight.type %in% c("normal", "arithmetic", "geometric") &&
        weight.penalty == TRUE) {
    mat <- -(1 - mat)
    diag(mat) <- 1
  }

  weighted <- m * mat

  if (print.weighted.accuracy == TRUE) {
    cat("Weighted accuracy = ", sum(weighted) / sum(m), "\n", "\n")
  }

  weighted
}
