#' @title Function to calculate stage-wise test statistics, variances, and correlation for two-sample generalized log-rank statistics.
#' @description Computes the stage-wise generalized log-rank statistics and their variance estimates at a set of interim analysis
#' calendar times. At each analysis, administrative censoring is applied at the specified calendar time, event times are converted from the
#' calendar-time scale to the event-time scale (time since enrollment), and the generalized log-rank statistic is
#' evaluated over \code{[0, tau]}. When multiple analyses are requested, the function also estimates the correlation
#' matrix of the stage-wise statistics.
#' @param data A data.frame generated by \code{TwoSample.generate.sequential()}.
#' @param tau Positive numeric value specifying the upper bound of event time
#'   (time since enrollment) for integration of the statistic. Default is
#'   \code{3}.
#' @param calendars  Numeric vector of interim analysis calendar times (in years)
#'   at which to compute stage-wise statistics and variance estimates.
#'
#' @returns A list containing stage-wise estimates. If \code{length(calendars) > 1},
#' the returned list includes:
#' \itemize{
#'   \item \code{Qs}: Numeric vector of stage-wise generalized log-rank statistics
#'     evaluated at each calendar time in \code{calendars}.
#'   \item \code{vars}: Numeric vector of estimated variances corresponding to
#'     \code{Qs}.
#'   \item \code{total.ns}: Numeric vector giving the total enrolled sample size
#'     contributing data at each calendar time.
#'   \item \code{corr.matrix}: Estimated correlation matrix of the stage-wise
#'     statistics.
#'   \item \code{nss}: List of length \code{length(calendars)} giving the
#'     group-specific sample sizes at each analysis.
#' }
#' If \code{length(calendars) == 1}, the list contains \code{Qs}, \code{vars},
#' and \code{total.ns}.
#' @export
#' @importFrom dplyr %>% group_by filter mutate count select slice ungroup
#' @importFrom tibble as_tibble
#' @importFrom bdsmatrix bdsBlock
#' @importFrom stats stepfun
#'
#' @examples
#' \donttest{
#' df <- TwoSample.generate.sequential(sizevec = c(200, 200), beta.trt = 0,
#' calendar = 5, recruitment = 3, random.censor.rate = 0.05, seed = 2026)
#' TwoSample.Q.Cov.Estimator.Sequential.LR(data = df, calendars = c(2.5, 3.5, 4.5))
#' }
TwoSample.Q.Cov.Estimator.Sequential.LR <- function(data, tau = 3, calendars){
  # Stage-wise two-sample generalized log-rank statistics.
  #
  # For every calendar time in `calendars` the data are administratively
  # censored, converted to the time-since-enrollment scale, and the statistic
  # and its variance estimate are computed over [0, tau]. With more than one
  # calendar time the correlation matrix of the stage-wise statistics is
  # estimated as well.
  #
  # Returns a list with
  #   Qs          test statistics at the calendar times,
  #   vars        estimated variances,
  #   total.ns    number of enrolled subjects at each calendar time,
  # and, only when length(calendars) > 1,
  #   corr.matrix estimated correlation matrix,
  #   nss         group-specific sample sizes at each analysis.
  #
  # The code below assumes that rows of `data` are stored contiguously per
  # subject in increasing id order, and that ids are 1..n1 in group 1 and
  # (n1 + 1)..(n1 + n2) in group 2 (see the block-diagonal construction and
  # the indexing of the full-length influence vectors).

  # Fail early with a clear message instead of iterating over 1:0.
  if (length(calendars) == 0) {
    stop("'calendars' must contain at least one analysis time.", call. = FALSE)
  }

  original.data <- data
  n.looks <- length(calendars)

  # Preallocated per-analysis results
  Qs <- numeric(n.looks)
  vars <- numeric(n.looks)
  consts <- numeric(n.looks)
  total.ns <- integer(n.looks)
  sigma.hats <- numeric(n.looks)
  nss <- vector(mode = "list", length = n.looks)
  var.partIs <- vector(mode = "list", length = n.looks)
  var.partIIs <- vector(mode = "list", length = n.looks)

  corr.matrix <- diag(n.looks)

  # Total group sizes over the whole data set (enrolled or not)
  max.n1 <- length(unique(original.data[original.data$group == 1,]$id))
  max.n2 <- length(unique(original.data[original.data$group == 2,]$id))

  # Patients may not have been fully enrolled at an interim analysis; the ids
  # of enrolled patients are saved so the per-subject influence terms can be
  # aligned across analyses in the covariance calculation.
  grp1.enrolled_id <- vector(mode = "list", length = n.looks)
  grp2.enrolled_id <- vector(mode = "list", length = n.looks)

  for (j in seq_along(calendars)){
    # Step 1: Apply the calendar time as effective censoring time
    data.censored <- Apply.calendar.censoring.2(data = original.data, calendar = calendars[j])

    # Step 2: Keep the patients who are already in the study and convert the
    # event times from calendar scale to the event scale
    data.censored <- data.censored %>%
      dplyr::group_by(.data$id) %>%
      dplyr::filter(!is.na(.data$status)) %>%
      dplyr::mutate(true_event_time = .data$event_time_cal - .data$e)

    ns <- c(NA, NA) # group sizes
    # all times (recurrent, death and censoring), pooled over both groups
    all.time <- data.censored$true_event_time
    sorted.all.time <- sort(all.time)
    Ybars <- dmuhats <- matrix(NA, 2, length(sorted.all.time))
    dPsihats <- vector(mode = "list", length = 2)
    time.idx <- vector(mode = "list", length = 2)
    truncate.idxs <- c(NA, NA)

    for (a in 1:2){
      # NOTE: `data` is deliberately reused for the group subset (the
      # untouched input is kept in `original.data`).
      data <- data.censored[data.censored$group == a,]
      ns[a] <- length(unique(data$id))
      data_new <- data[order(data$true_event_time),]

      # All event times, including recurrent, death and censoring
      sorted.time <- data_new$true_event_time
      sorted.event <- data_new$event
      n <- length(unique(data_new$id)) # sample size
      L <- length(sorted.event) # total number of all events (recurrent, death and censoring)

      # position of this group's times within the pooled sorted times
      time.idx[[a]] <- match(sorted.time, sorted.all.time)

      # last observation for each subject (death or censoring), in the
      # original subject order of `data`
      last.time.unsorted <- data$true_event_time[data$status == 1 | data$status == 0]

      # At-risk process: one row per subject (original order), one column per
      # sorted time. Building it directly in subject order keeps the rows
      # aligned with dND and dN below even when last-observation times are
      # tied (a match()/order() round trip is not a permutation under ties).
      Y <- 1*(matrix(rep(last.time.unsorted, L), n, L) >= matrix(rep(sorted.time, each = n), n, L))
      Ybar <- colSums(Y)

      # Kaplan Meier estimates for death time points
      death <- data_new$death
      KMhat <- cumprod(1 - death/Ybar)

      # Nelson-Aalen increments for all dN(t) = 1 (recurrent and death)
      dRhat <- sorted.event/Ybar

      # Mean frequency estimator, extended to the pooled time grid
      dmuhat <- KMhat*dRhat
      muhat <- cumsum(dmuhat)
      imuhat <- stepfun(sorted.time, c(0, muhat))
      muhat.extended <- imuhat(sorted.all.time)
      dmuhats[a,] <- diff(c(0,muhat.extended))

      ##### variance estimator #######
      # The number at risk is intentionally NOT floored at one: the two-sample
      # weight function depends on Ybar and has to vanish where nobody is at
      # risk.
      iYbar <- stepfun(sorted.time, c(n, Ybar))
      Ybars[a,] <- iYbar(sorted.all.time)

      # cumulative hazard for death
      delta <- data[data$status == 1|data$status == 0,]$death
      ND <- 1*(matrix(rep(last.time.unsorted, L), n, L) <= matrix(rep(sorted.time, each = n), n, L))*delta
      dND <- t(apply(ND, 1, function(x) diff(c(0,x))))
      dlambdaDhat <- colSums(t(t(dND)/Ybar))
      # intensity process for death
      dADhat <- t(apply(Y, 1, function(x) x*dlambdaDhat))
      dMDhat <- dND - dADhat

      # number of rows per subject, used to create a block diag matrix
      id.size <- data.frame(data %>% group_by(id) %>% count())$n
      grp <- bdsBlock(1:L, rep(1:n, id.size)) # block diag matrix
      grp <- as.matrix(grp)
      original.time <- data$true_event_time
      # put each subject's all event times on block diag, t1 is L x L
      t1 <- matrix(rep(original.time, each = L), L, L)*grp
      t2 <- unique(t1) # keep only one row per subject, t2 is n x L
      original.event <- data$event # 'event' is dN(t), sum of 'recurrent' and 'death'

      # put each subject's all event indicator on block diag, t3 is L x L
      t3 <- matrix(rep(original.event, each = L), L, L)*grp
      # keep only one row per subject, t4 is n x L.
      # dplyr::n() is namespace-qualified so the call does not depend on dplyr
      # being attached (a local numeric `n` exists in this scope as well).
      t4 <- as.matrix(cbind(id = data$id, t3) %>% as_tibble() %>% group_by(id) %>%
                        slice(dplyr::n()) %>% ungroup() %>% select(-id))

      t5 <- t2*t4 # make censoring times become zero, since dN(t) = 0 when censored
      # put each subject's dN(t) = 1 (recurrent and death) times in the sorted order
      t6 <- unname(t(apply(t5, 1, function(x) x[order(original.time)])))
      dN <- 1*(t6 == matrix(rep(sorted.time, each = n), n, L)) # at which time point did dN(t) jump

      dAhat <- t(apply(Y, 1, function(x) x*dRhat))
      dMhat <- dN - dAhat

      dpartI <- t(apply(dMhat, 1, function(x) x*KMhat/(Ybar/n)))

      dpartII <- t(apply(dMDhat, 1, function(x) x*muhat/(Ybar/n)))

      dpartIII.1 <- t(apply(dMDhat, 1, function(x) x/(Ybar/n)))
      dpartIII.2 <- t(apply(dpartIII.1, 1, cumsum))
      dpartIII <- t(apply(dpartIII.2, 1, function(x) x*dmuhat))

      dpartIV <- t(apply(dMDhat, 1, function(x) x*muhat/(Ybar/n)))

      # NOTE(review): dpartII and dpartIV are built from the same expression,
      # so they cancel here. Both are kept so the numerical result is
      # unchanged.
      dPsihat <- dpartI - dpartII -dpartIII + dpartIV

      # integrate from zero to tau
      observed.by.tau <- which(sorted.time <= tau)
      if (length(observed.by.tau) == 0) {
        stop("No observed times <= tau in group ", a, " at calendar time ",
             calendars[j], ".", call. = FALSE)
      }
      truncate.idxa <- max(observed.by.tau)
      # drop = FALSE keeps an n x 1 matrix when only one time point is <= tau
      dPsihats[[a]] <- dPsihat[, seq_len(truncate.idxa), drop = FALSE]
      truncate.idxs[a] <- truncate.idxa

      # Save the id of enrolled patients in each group
      if (a == 1){
        grp1.enrolled_id[[j]] <- unique(data$id)
      } else {
        grp2.enrolled_id[[j]] <- unique(data$id)
      }

      # Free up group-specific large objects
      rm(data, data_new, Y, KMhat, dRhat, dmuhat, muhat, muhat.extended,
         Ybar, delta, ND, dND, dlambdaDhat, dADhat, dMDhat,
         id.size, grp, t1, t2, t3, t4, t5, t6, original.time, original.event,
         dN, dAhat, dMhat, dpartI, dpartII, dpartIII.1, dpartIII.2, dpartIII,
         dpartIV, dPsihat)
      gc(verbose = FALSE)

    } # end of the 'a' loop

    # weight function; set to zero wherever nobody is at risk in either group
    # (0/0 would otherwise give NaN)
    Ybar.total <- Ybars[1,] + Ybars[2,]
    Khat <- (sum(ns)/prod(ns))*Ybars[1,]*Ybars[2,]/Ybar.total
    Khat[Ybar.total == 0] <- 0

    # integrate from 0 to tau
    truncate.idx <- max(which(sorted.all.time <= tau))
    keep.idx <- seq_len(truncate.idx)
    Qs[j] <- sum(Khat[keep.idx]*(dmuhats[1, keep.idx] - dmuhats[2, keep.idx]))

    # Asymptotic variance: per-subject weighted sums of dPsihat over [0, tau]
    Khat.grp1 <- Khat[time.idx[[1]]][seq_len(truncate.idxs[1])]
    var.partI.2 <- rowSums(sweep(dPsihats[[1]], 2, Khat.grp1, "*"))
    var.partI <- sum(var.partI.2^2)*ns[2]/(sum(ns)*ns[1])

    Khat.grp2 <- Khat[time.idx[[2]]][seq_len(truncate.idxs[2])]
    var.partII.2 <- rowSums(sweep(dPsihats[[2]], 2, Khat.grp2, "*"))
    var.partII <- sum(var.partII.2^2)*ns[1]/(sum(ns)*ns[2])

    consts[j] <- 1/sqrt(ns[1]*ns[2]/sum(ns))
    nss[[j]] <- ns
    vars[j] <- (var.partI + var.partII)*consts[j]^2
    total.ns[j] <- sum(ns)
    sigma.hats[j] <- var.partI + var.partII

    # Expand the per-subject terms to the full group size; subjects not yet
    # enrolled at this analysis contribute zero. Group 1 ids index the vector
    # directly, group 2 ids are shifted by the size of group 1.
    var.partIs_full <- numeric(max.n1)
    var.partIs_full[grp1.enrolled_id[[j]]] <- var.partI.2
    var.partIs[[j]] <- var.partIs_full

    var.partIIs_full <- numeric(max.n2)
    var.partIIs_full[grp2.enrolled_id[[j]] - max.n1] <- var.partII.2
    var.partIIs[[j]] <- var.partIIs_full

    # Clean up objects reused in the next j loop
    rm(data.censored, Ybars, dmuhats, dPsihats, time.idx, truncate.idxs)
    gc(verbose = FALSE)

  } # end of the 'j' loop

  if (n.looks > 1){
    # Pairwise correlations between analyses p < q, using the sample sizes of
    # the later analysis q as weights
    for (p in seq_len(n.looks - 1)){
      for (q in (p+1):n.looks){
        corr.matrix[p,q] <- (nss[[q]][2]/(total.ns[q]*nss[[q]][1])*sum(var.partIs[[p]]*var.partIs[[q]]) +
                               nss[[q]][1]/(total.ns[q]*nss[[q]][2])*sum(var.partIIs[[p]]*var.partIIs[[q]]))/(sqrt(sigma.hats[p]*sigma.hats[q]))
        corr.matrix[q,p] <- corr.matrix[p,q]
      }
    }
    return(list(Qs = Qs,
                vars = vars,
                total.ns = total.ns,
                corr.matrix = corr.matrix,
                nss = nss))
  } else {
    return(list(Qs = Qs,
                vars = vars,
                total.ns = total.ns))
  }
}
