#' @title Function to calculate stage-wise test statistics, variances, and correlation for two-sample generalized-t statistics.
#' @description Computes the stage-wise generalized-t statistics and their variance estimates at a set of interim analysis
#' calendar times. At each analysis, administrative censoring is applied at the specified calendar time, event times are converted from the
#' calendar-time scale to the event-time scale (time since enrollment), and the generalized-t statistic is
#' evaluated over \code{[0, tau]}. When multiple analyses are requested, the function also estimates the correlation
#' matrix of the stage-wise statistics.
#' @param data A data.frame generated by \code{TwoSample.generate.sequential()}.
#'   Besides the columns required by \code{Apply.calendar.censoring.2()}, this
#'   function reads \code{id}, \code{group} (coded \code{1} and \code{2}),
#'   \code{status}, \code{event}, \code{death}, \code{event_time_cal} and
#'   \code{e}. Subject ids need not be contiguous or numeric; every enrolled
#'   subject must have exactly one record with \code{status} 0 or 1.
#' @param tau Positive numeric value specifying the upper bound of event time
#'   (time since enrollment) for integration of the statistic. Default is
#'   \code{3}.
#' @param calendars  Numeric vector of interim analysis calendar times (in years)
#'   at which to compute stage-wise statistics and variance estimates.
#'
#' @details An error is raised if \code{calendars} is empty, if \code{tau} is
#'   not a single positive number, or if at some calendar time a group has no
#'   enrolled subjects or no record times \code{<= tau}.
#'
#' @returns A list containing stage-wise estimates. If \code{length(calendars) > 1},
#' the returned list includes:
#' \itemize{
#'   \item \code{Qs}: Numeric vector of stage-wise generalized-t statistics
#'     evaluated at each calendar time in \code{calendars}.
#'   \item \code{vars}: Numeric vector of estimated variances corresponding to
#'     \code{Qs}.
#'   \item \code{total.ns}: Numeric vector giving the total enrolled sample size
#'     contributing data at each calendar time.
#'   \item \code{corr.matrix}: Estimated correlation matrix of the stage-wise
#'     statistics.
#'   \item \code{nss}: List of length \code{length(calendars)} giving the
#'     group-specific sample sizes at each analysis.
#' }
#' If \code{length(calendars) == 1}, the list contains \code{Qs}, \code{vars},
#' and \code{total.ns}.
#' @export
#' @importFrom dplyr %>% group_by filter mutate count select slice ungroup
#' @importFrom tibble as_tibble
#' @importFrom bdsmatrix bdsBlock
#' @importFrom stats stepfun
#'
#' @examples
#' \donttest{
#' df <- TwoSample.generate.sequential(sizevec = c(200, 200), beta.trt = 0,
#' calendar = 5, recruitment = 3, random.censor.rate = 0.05, seed = 2026)
#' TwoSample.Q.Cov.Estimator.Sequential.GT(data = df, calendars = c(2.5, 3.5, 4.5))
#' }
TwoSample.Q.Cov.Estimator.Sequential.GT <- function(data, tau = 3, calendars){
  if (!is.numeric(calendars) || length(calendars) == 0L) {
    stop("'calendars' must be a non-empty numeric vector.", call. = FALSE)
  }
  if (!is.numeric(tau) || length(tau) != 1L || is.na(tau) || tau <= 0) {
    stop("'tau' must be a single positive number.", call. = FALSE)
  }

  original.data <- data
  n.cal <- length(calendars)

  # All subjects ever enrolled in each arm. Each stage's per-subject influence
  # values are stored at the subject's position in these vectors (zero for
  # subjects not yet enrolled). This lets stages be paired in the correlation
  # no matter how ids are coded (contiguous, interleaved, character, ...).
  all.ids <- lapply(1:2, function(a) {
    unique(original.data$id[original.data$group == a])
  })

  Qs <- numeric(n.cal)
  vars <- numeric(n.cal)
  total.ns <- integer(n.cal)
  sigma.hats <- numeric(n.cal)
  nss <- vector("list", n.cal)
  influences <- list(vector("list", n.cal), vector("list", n.cal))
  corr.matrix <- diag(n.cal)

  for (j in seq_len(n.cal)) {
    # Step 1: apply the calendar time as an administrative censoring time.
    data.censored <- Apply.calendar.censoring.2(data = original.data,
                                                calendar = calendars[j])

    # Step 2: keep subjects already enrolled (non-missing status) and convert
    # event times from the calendar scale to time since enrollment.
    data.censored <- data.censored %>%
      dplyr::ungroup() %>%
      dplyr::filter(!is.na(.data$status)) %>%
      dplyr::mutate(true_event_time = .data$event_time_cal - .data$e)

    sorted.all.time <- sort(data.censored$true_event_time)

    # Step 3: per-group estimates (sample size, censoring KM, Ghosh-Lin mean
    # function and per-subject influence functions truncated at tau).
    fits <- lapply(1:2, function(a) {
      .GT.group.fit(
        data.censored[data.censored$group == a, , drop = FALSE],
        tau = tau,
        label = sprintf("group %d at calendar time %s", a, format(calendars[j]))
      )
    })
    ns <- vapply(fits, function(f) f$n, integer(1))
    # Double copies avoid integer overflow in products of sample sizes.
    n1 <- as.numeric(ns[1])
    n2 <- as.numeric(ns[2])

    # Evaluate both groups' step functions on the pooled time grid.
    Hhats <- muhats <- matrix(NA_real_, 2L, length(sorted.all.time))
    for (a in 1:2) {
      iHhat <- stepfun(fits[[a]]$sorted.time, c(1, fits[[a]]$Hhat))
      imuhat <- stepfun(fits[[a]]$sorted.time, c(0, fits[[a]]$muhat))
      Hhats[a, ] <- iHhat(sorted.all.time)
      muhats[a, ] <- imuhat(sorted.all.time)
    }

    # Weight function. It is shifted right by one grid position: the first
    # weight is 1 and the last value is dropped.
    Khat <- (n1 + n2) * Hhats[1, ] * Hhats[2, ] /
      (n1 * Hhats[1, ] + n2 * Hhats[2, ])
    Khat <- c(1, Khat[-length(Khat)])

    # Integrate the weighted difference in mean functions over [0, tau].
    keep <- seq_len(sum(sorted.all.time <= tau))
    dt <- diff(c(0, sorted.all.time))[keep]
    Qs[j] <- sum(Khat[keep] * (muhats[1, keep] - muhats[2, keep]) * dt)

    # Per-subject contributions: integral of Khat * Psihat over [0, tau] on
    # each group's own time grid (one value per enrolled subject).
    infl <- lapply(fits, function(f) {
      k.grp <- Khat[match(f$sorted.time, sorted.all.time)][seq_len(f$truncate.idx)]
      as.vector(f$Psihat %*% (k.grp * f$dt))
    })
    var.partI <- sum(infl[[1]]^2) * n2 / ((n1 + n2) * n1)
    var.partII <- sum(infl[[2]]^2) * n1 / ((n1 + n2) * n2)

    const <- 1 / sqrt(n1 * n2 / (n1 + n2))
    nss[[j]] <- ns
    total.ns[j] <- sum(ns)
    sigma.hats[j] <- var.partI + var.partII
    vars[j] <- sigma.hats[j] * const^2

    # Expand to full-length vectors aligned on all.ids (see above).
    for (a in 1:2) {
      full <- numeric(length(all.ids[[a]]))
      full[match(fits[[a]]$ids, all.ids[[a]])] <- infl[[a]]
      influences[[a]][[j]] <- full
    }
  }

  if (n.cal > 1L) {
    for (p in seq_len(n.cal - 1L)) {
      for (q in seq(p + 1L, n.cal)) {
        # Covariance between stages p and q uses the later stage's sizes.
        n1.q <- as.numeric(nss[[q]][1])
        n2.q <- as.numeric(nss[[q]][2])
        N.q <- n1.q + n2.q
        cov.pq <- n2.q / (N.q * n1.q) *
          sum(influences[[1]][[p]] * influences[[1]][[q]]) +
          n1.q / (N.q * n2.q) *
          sum(influences[[2]][[p]] * influences[[2]][[q]])
        corr.matrix[p, q] <- cov.pq / sqrt(sigma.hats[p] * sigma.hats[q])
        corr.matrix[q, p] <- corr.matrix[p, q]
      }
    }
    return(list(Qs = Qs,
                vars = vars,
                total.ns = total.ns,
                corr.matrix = corr.matrix,
                nss = nss))
  }

  list(Qs = Qs,
       vars = vars,
       total.ns = total.ns)
}

# Per-arm estimation at one analysis. Returns the sample size, the subject ids
# (row order of Psihat), the sorted record times, the censoring Kaplan-Meier
# (Hhat), the Ghosh-Lin mean function (muhat), the n x K matrix of per-subject
# influence functions (Psihat) truncated to the first K record times <= tau,
# K itself, and the matching time increments (dt).
.GT.group.fit <- function(data, tau, label = "this group") {
  data <- as.data.frame(data)
  id <- data[["id"]]
  time <- data[["true_event_time"]]
  event <- data[["event"]]
  death <- data[["death"]]
  terminal <- data[["status"]] %in% c(0, 1)

  ids <- unique(id)
  n <- length(ids)
  if (n == 0L) {
    stop(sprintf("No enrolled subjects in %s.", label), call. = FALSE)
  }

  # Each subject's last follow-up time and death indicator come from its
  # single terminal (status 0 or 1) record, stored in the order of `ids`.
  term.pos <- match(id[terminal], ids)
  if (length(term.pos) != n || anyDuplicated(term.pos) > 0L) {
    stop(sprintf(
      "Each subject in %s must have exactly one record with status 0 or 1.",
      label), call. = FALSE)
  }
  last.time <- numeric(n)
  last.time[term.pos] <- time[terminal]
  last.death <- numeric(n)
  last.death[term.pos] <- death[terminal]

  # All records (recurrent events, deaths, censorings) in time order; order()
  # is stable, so tied records keep their input order.
  ord <- order(time)
  sorted.time <- time[ord]
  sorted.event <- event[ord]
  sorted.death <- death[ord]
  L <- length(sorted.time)

  # At-risk indicator Y[i, k]: subject i still under follow-up at the k-th
  # sorted record time.
  Y <- 1 * outer(last.time, sorted.time, ">=")
  n.risk <- colSums(Y)

  # Kaplan-Meier for censoring: records with event == 0 count as censorings.
  Hhat <- cumprod(1 - (1 - sorted.event) / n.risk)
  # Kaplan-Meier for death.
  KMhat <- cumprod(1 - sorted.death / n.risk)
  # Nelson-Aalen increments for dN(t) (recurrent events and death).
  dRhat <- sorted.event / n.risk
  # Ghosh-Lin estimator of the mean function.
  muhat <- cumsum(KMhat * dRhat)

  ## ----- influence functions (rows in the order of `ids`) -----
  Ybar <- n.risk + (n.risk == 0)  # guard against division by zero
  w <- Ybar / n

  # Death counting process: jumps (of size last.death[i]) at the first sorted
  # time >= subject i's last follow-up time.
  ND <- 1 * outer(last.time, sorted.time, "<=") * last.death
  dND <- ND - cbind(0, ND)[, seq_len(L), drop = FALSE]
  dlambdaDhat <- colSums(dND) / Ybar
  dMDhat <- dND - sweep(Y, 2, dlambdaDhat, "*")

  # dN[i, k] is the event indicator of the k-th sorted record, placed in the
  # row of the subject that record belongs to (indexing, not float matching).
  dN <- matrix(0, n, L)
  dN[cbind(match(id[ord], ids), seq_len(L))] <- sorted.event
  dMhat <- dN - sweep(Y, 2, dRhat, "*")

  partI <- .GT.row.cumsum(sweep(dMhat, 2, KMhat / w, "*"))
  partII <- sweep(.GT.row.cumsum(sweep(dMDhat, 2, w, "/")), 2, muhat, "*")
  partIII <- .GT.row.cumsum(sweep(dMDhat, 2, muhat / w, "*"))
  Psihat <- partI - partII + partIII

  # Integrate from zero to tau: keep the record times <= tau.
  truncate.idx <- sum(sorted.time <= tau, na.rm = TRUE)
  if (truncate.idx < 1L) {
    stop(sprintf("No record times <= tau (%s) in %s.", format(tau), label),
         call. = FALSE)
  }
  keep <- seq_len(truncate.idx)

  list(n = n,
       ids = ids,
       sorted.time = sorted.time,
       Hhat = Hhat,
       muhat = muhat,
       Psihat = Psihat[, keep, drop = FALSE],
       truncate.idx = truncate.idx,
       dt = diff(c(0, sorted.time))[keep])
}

# Cumulative sums along each row of a matrix. Unlike t(apply(m, 1, cumsum)),
# this keeps the matrix shape when there is a single row or column.
.GT.row.cumsum <- function(m) {
  for (k in seq_len(ncol(m))[-1L]) {
    m[, k] <- m[, k] + m[, k - 1L]
  }
  m
}
