#include <Rcpp.h>
#include <cmath>
using namespace Rcpp;

// For every matrix in `R_list`, scales the sum of all its entries by
// m(m-1) / (N(N-1)) and by n(n-1) / (N(N-1)), where N = m + n.
//
// The output has length 2 * length(R_list): element 2*i holds the m-based
// value and element 2*i + 1 the n-based value for the i-th matrix.
//
// Raises an R error when N * (N - 1) is exactly zero.
// [[Rcpp::export]]
Rcpp::NumericVector asymp_mean_rcpp(const Rcpp::List& R_list,
                                    double m,
                                    double n) {
  const double total = m + n;
  const double pair_count = total * (total - 1);
  if (pair_count == 0) {
    Rcpp::stop("Division by zero: N * (N - 1) cannot be zero.");
  }

  const double weight_m = m * (m - 1);
  const double weight_n = n * (n - 1);

  const int count = R_list.size();
  Rcpp::NumericVector means(2 * count);

  for (int idx = 0; idx < count; ++idx) {
    Rcpp::NumericMatrix mat = Rcpp::as<Rcpp::NumericMatrix>(R_list[idx]);
    const double entry_sum = Rcpp::sum(mat);

    // Multiply first, then divide, so rounding matches term by term.
    const int slot = 2 * idx;
    means[slot]     = weight_m * entry_sum / pair_count;
    means[slot + 1] = weight_n * entry_sum / pair_count;
  }

  return means;
}

// Returns sum_{i,j} A(i, j) * B(i, j), i.e. the sum of the element-wise
// product of two equally-sized matrices. Raises an R error when the shapes
// differ. Entries are visited column by column.
double accu_prod(const Rcpp::NumericMatrix& A, const Rcpp::NumericMatrix& B) {
  const bool same_shape = (A.nrow() == B.nrow()) && (A.ncol() == B.ncol());
  if (!same_shape) {
    Rcpp::stop("Matrices must have the same dimensions for element-wise product.");
  }

  const int rows = A.nrow();
  const int cols = A.ncol();

  double acc = 0;
  for (int c = 0; c < cols; ++c) {
    for (int r = 0; r < rows; ++r) {
      acc += A(r, c) * B(r, c);
    }
  }
  return acc;
}

// Returns sum_{i,j} M(i, j) * V[i]: every row i of M is weighted by V[i]
// and all weighted entries are added up. Raises an R error when V does not
// have one element per row of M. Entries are visited column by column.
double accu_prod_broadcast(const Rcpp::NumericMatrix& M, const Rcpp::NumericVector& V) {
  if (M.nrow() != V.size()) {
    Rcpp::stop("Matrix nrow must match vector size for broadcasting.");
  }

  const int rows = M.nrow();
  const int cols = M.ncol();

  double acc = 0;
  for (int c = 0; c < cols; ++c) {
    for (int r = 0; r < rows; ++r) {
      acc += M(r, c) * V[r];
    }
  }
  return acc;
}

// Builds a symmetric 2k x 2k matrix from the k matrices in `R_list` and the
// two sizes m and n (N = m + n).
//
// Layout: for matrices i and j, the 2 x 2 block at rows (2i, 2i+1) and
// columns (2j, 2j+1) is [[cov_xx, cov_xy], [cov_xy, cov_yy]]; the same block
// is written at the mirrored position (j, i). "x" terms are driven by m,
// "y" terms by n.
//
// Per pair (i, j) two matrix-dependent quantities are used:
//   r2 = (sum(R_i * R_j)            - sum(R_i) * sum(R_j) / (N(N-1))) / N
//   r3 = (sum(R_i * rowSums(R_j))   - sum(R_i) * sum(R_j) / N) / (N(N-2))
// where `*` is element-wise and rowSums(R_j) is broadcast along rows.
// They are combined with size-only terms (tQ2_*, tQ3_*) and the constants
// c1..c4 below.
//
// Raises an R error when N <= 3, or (via accu_prod / accu_prod_broadcast)
// when the matrices do not all share the same dimensions.
// [[Rcpp::export]]
Rcpp::NumericMatrix asymp_cov_rcpp(const Rcpp::List& R_list, double m, double n) {
  const int k = R_list.size();
  const double N = m + n;
  if (N <= 3.0) Rcpp::stop("N must be > 3 for denominator N-3.");

  // Convert each list element once and cache its total sum and row sums.
  std::vector<Rcpp::NumericMatrix> R(k);
  Rcpp::NumericVector R1_sums(k);
  std::vector<Rcpp::NumericVector> R_row_sums(k);
  for (int i = 0; i < k; ++i) {
    R[i] = Rcpp::as<Rcpp::NumericMatrix>(R_list[i]);
    R1_sums[i] = Rcpp::sum(R[i]);
    R_row_sums[i] = Rcpp::rowSums(R[i]);
  }

  // Size-only terms. The "xy" cross terms are fixed at zero.
  const double Q1_x = m * (m - 1.0);
  const double Q1_y = n * (n - 1.0);
  const double Q2_xx = Q1_x;
  const double Q2_yy = Q1_y;
  const double Q2_xy = 0.0;
  const double Q3_xx = m * (m - 1.0) * (m - 1.0);
  const double Q3_yy = n * (n - 1.0) * (n - 1.0);
  const double Q3_xy = 0.0;

  const double denom_r2 = N * (N - 1.0);
  const double denom_r3 = N * (N - 2.0);

  const double tQ2_xx_norm = (Q2_xx - Q1_x * Q1_x / denom_r2) / (N - 3.0);
  const double tQ2_yy_norm = (Q2_yy - Q1_y * Q1_y / denom_r2) / (N - 3.0);
  const double tQ2_xy_norm = (Q2_xy - Q1_x * Q1_y / denom_r2) / (N - 3.0);

  const double tQ3_xx_norm = (Q3_xx - Q1_x * Q1_x / N) / (N - 3.0);
  const double tQ3_yy_norm = (Q3_yy - Q1_y * Q1_y / N) / (N - 3.0);
  const double tQ3_xy_norm = (Q3_xy - Q1_x * Q1_y / N) / (N - 3.0);

  const double c1 = 2.0;
  const double c2 = -4.0 / (N - 2.0);
  const double c3 = -4.0;
  // NOTE(review): by operator precedence this evaluates as
  // 4 * (N + 1) / (N - 1), not 4 / ((N - 1) * (N + 1)). Kept as written;
  // confirm against the derivation this formula comes from.
  const double c4 = 4.0 / (N - 1.0) * (N + 1.0);

  Rcpp::NumericMatrix cov_matrix(2 * k, 2 * k);

  // Only pairs with j >= i are needed: each result is mirrored to (j, i)
  // below, so the matrix products are computed once per unordered pair
  // instead of for the full k x k grid.
  for (int i = 0; i < k; ++i) {
    for (int j = i; j < k; ++j) {
      const double r1_prod = R1_sums[i] * R1_sums[j];
      const double r2_ij = (accu_prod(R[i], R[j]) - r1_prod / denom_r2) / N;
      const double r3_ij =
          (accu_prod_broadcast(R[i], R_row_sums[j]) - r1_prod / N) / denom_r3;

      const double cov_xx = c1*r2_ij*tQ2_xx_norm + c2*r2_ij*tQ3_xx_norm + c3*r3_ij*tQ2_xx_norm + c4*r3_ij*tQ3_xx_norm;
      const double cov_yy = c1*r2_ij*tQ2_yy_norm + c2*r2_ij*tQ3_yy_norm + c3*r3_ij*tQ2_yy_norm + c4*r3_ij*tQ3_yy_norm;
      const double cov_xy = c1*r2_ij*tQ2_xy_norm + c2*r2_ij*tQ3_xy_norm + c3*r3_ij*tQ2_xy_norm + c4*r3_ij*tQ3_xy_norm;

      cov_matrix(2*i, 2*j)     = cov_xx;
      cov_matrix(2*i, 2*j+1)   = cov_xy;
      cov_matrix(2*i+1, 2*j)   = cov_xy;
      cov_matrix(2*i+1, 2*j+1) = cov_yy;

      if (i != j) {
        cov_matrix(2*j, 2*i)     = cov_xx;
        cov_matrix(2*j, 2*i+1)   = cov_xy;
        cov_matrix(2*j+1, 2*i)   = cov_xy;
        cov_matrix(2*j+1, 2*i+1) = cov_yy;
      }
    }
  }

  return cov_matrix;
}

// Compute pairwise Minkowski distances between rows of A (n x d).
// Returns an n x n symmetric matrix (diagonal zeros).
// p > 0; p==1 Manhattan, p==2 Euclidean, p==INFINITY Chebyshev.
// [[Rcpp::export]]
NumericMatrix minkowski_dist(const NumericMatrix& A, double p = 2.0) {
  int n = A.nrow();
  int d = A.ncol();
  if (n == 0) return NumericMatrix(0,0);
  if (p <= 0) stop("p must be > 0");

  NumericMatrix out(n, n);
  bool is_inf = std::isinf(p);

  for (int i = 0; i < n; ++i) {
    out(i,i) = 0.0;
    for (int j = i+1; j < n; ++j) {
      double acc = 0.0;
      double maxabs = 0.0;
      if (is_inf) {
        // Chebyshev: max_k |A[i,k] - A[j,k]|
        for (int k = 0; k < d; ++k) {
          const double* colk = &A(0, k);
          double ad = std::abs(colk[i] - colk[j]);
          if (ad > maxabs) maxabs = ad;
        }
        out(i,j) = out(j,i) = maxabs;
      } else if (p == 1.0) {
        // L1
        for (int k = 0; k < d; ++k) {
          const double* colk = &A(0, k);
          acc += std::abs(colk[i] - colk[j]);
        }
        out(i,j) = out(j,i) = acc;
      } else if (p == 2.0) {
        // L2
        for (int k = 0; k < d; ++k) {
          const double* colk = &A(0, k);
          double diff = colk[i] - colk[j];
          acc += diff * diff;
        }
        double val = std::sqrt(acc);
        out(i,j) = out(j,i) = val;
      } else if (p == 3.0) {
        // L3: ad^3
        for (int k = 0; k < d; ++k) {
          const double* colk = &A(0, k);
          double ad = std::abs(colk[i] - colk[j]);
          acc += ad * ad * ad;
        }
        out(i,j) = out(j,i) = std::pow(acc, 1.0 / 3.0);
      } else if (p == 4.0) {
        // L4: ad^4 via (ad^2)^2
        for (int k = 0; k < d; ++k) {
          const double* colk = &A(0, k);
          double ad = std::abs(colk[i] - colk[j]);
          double ad2 = ad * ad;
          acc += ad2 * ad2;
        }
        out(i,j) = out(j,i) = std::pow(acc, 1.0 / 4.0);
      } else {
        // general p > 0
        for (int k = 0; k < d; ++k) {
          const double* colk = &A(0, k);
          double ad = std::abs(colk[i] - colk[j]);
          acc += std::pow(ad, p);
        }
        out(i,j) = out(j,i) = std::pow(acc, 1.0 / p);
      }
    }
  }
  return out;
}

// Pairwise Manhattan (L1) distances between the rows of A (n x d):
// entry (a, b) is sum_k |A(a, k) - A(b, k)|.
// The result is an n x n symmetric matrix with a zero diagonal; a 0 x 0
// matrix is returned when A has no rows.
// [[Rcpp::export]]
NumericMatrix manhattan_dist(const NumericMatrix& A) {
  const int rows = A.nrow();
  const int cols = A.ncol();
  if (rows == 0) {
    return NumericMatrix(0,0);
  }

  NumericMatrix dist(rows, rows);
  for (int a = 0; a < rows; ++a) {
    dist(a, a) = 0.0;
    for (int b = a + 1; b < rows; ++b) {
      // Add up coordinate-wise absolute differences, first column to last.
      double total = 0.0;
      for (int c = 0; c < cols; ++c) {
        total += std::abs(A(a, c) - A(b, c));
      }
      dist(a, b) = total;
      dist(b, a) = total;
    }
  }
  return dist;
}
