# ========================================
# final_prop_svyglm: Propensity-weighted survey GLM
# ========================================
#' Propensity-weighted survey GLM
#'
#' Calculates inverse-probability-of-treatment weights (IPTW) for the
#' exposure, multiplies them by the survey base weights, and fits a
#' survey-weighted logistic (quasibinomial) outcome model with
#' \code{survey::svyglm()}. Supports binary, multinomial, or continuous
#' exposures:
#' \itemize{
#'   \item \code{"binary"}: logistic propensity model (\code{glm}); the weight
#'     is 1/p for exposed and 1/(1 - p) for unexposed rows, with p clipped
#'     to lie between 1e-6 and 1 - 1e-6.
#'   \item \code{"multinomial"}: \code{nnet::multinom()} propensity model;
#'     the weight is 1 / P(observed level), rescaled to have mean 1.
#'   \item \code{"continuous"}: stabilized generalized propensity score; the
#'     weight is the marginal normal density of the exposure divided by its
#'     normal density given the covariates (from a linear model), each
#'     density floored at 1e-6.
#' }
#' Rows with missing values in any analysis or survey design variable are
#' dropped (with a warning) before weighting. Combined weights that are zero
#' or negative are set to 1e-6, and weights above 1e5 are truncated at 1e5.
#'
#' @param data Data frame.
#' @param dep_var Character; name of the binary outcome (0/1 or a two-level
#'   factor).
#' @param exposure Character; name of the treatment/exposure variable. For
#'   \code{exposure_type = "binary"} it must be 0/1 (numeric or logical) or a
#'   two-valued factor/character, whose second level is coded as 1.
#' @param covariates Character vector of propensity-model terms (column names
#'   or formula terms such as \code{"I(age^2)"}). May be empty, which gives an
#'   intercept-only propensity model.
#' @param id_var Character; PSU (cluster) variable, used as \code{~id_var} in
#'   \code{survey::svydesign()} (\code{"1"} means no clustering).
#' @param strata_var Character; strata variable, or \code{NULL} for no
#'   stratification.
#' @param weight_var Character; base survey weight variable.
#' @param exposure_type Character; one of \code{"binary"} (default),
#'   \code{"multinomial"}, or \code{"continuous"}.
#' @param outcome_covariates Character vector of additional covariates to
#'   include in the final outcome model after applying propensity weights
#'   (default \code{NULL}).
#' @param level Numeric; confidence level for the odds-ratio intervals
#'   (default 0.95).
#' @param ... Additional arguments passed to \code{survey::svyglm()}.
#' @examples
#' set.seed(123)
#' n <- 1500
#' dat <- data.frame(
#'   psu = sample(1:10, n, replace = TRUE),
#'   strata = sample(1:5, n, replace = TRUE),
#'   weight = runif(n, 0.5, 2),
#'   age = rnorm(n, 50, 10),
#'   sex = factor(sample(c("Male", "Female"), n, replace = TRUE)),
#'   exposure_bin = rbinom(n, 1, 0.5)
#' )
#' dat$outcome <- rbinom(n, 1, plogis(-2 + 0.03 * dat$age + 0.5 * dat$exposure_bin))
#'
#' ## ---- Example 1: Binary exposure ----
#' fit_bin_exp <- final_prop_svyglm(
#'   dat,
#'   dep_var = "outcome",
#'   covariates = c("age", "sex"),
#'   exposure = "exposure_bin",
#'   id_var = "psu",
#'   strata_var = "strata",
#'   weight_var = "weight",
#'   outcome_covariates = NULL
#' )
#' fit_bin_exp$OR_table
#'
#' ## ---- Example 2: Continuous exposure ----
#' fit_cont_exp <- final_prop_svyglm(
#'   dat,
#'   dep_var = "outcome",
#'   covariates = c("sex"),
#'   exposure = "age",
#'   id_var = "psu",
#'   strata_var = "strata",
#'   weight_var = "weight",
#'   exposure_type = "continuous",
#'   outcome_covariates = NULL
#' )
#' fit_cont_exp$OR_table
#'
#' ## ---- Example 3: Multinomial exposure ----
#' dat$exposure_3cat <- cut(
#'   dat$age,
#'   breaks = quantile(dat$age, probs = c(0, 1/3, 2/3, 1)), # tertiles
#'   labels = c("Low", "Medium", "High"),
#'   include.lowest = TRUE
#' )
#' # Numeric coding for exposure effect
#' exp_eff <- ifelse(dat$exposure_3cat == "Low", 0,
#'                   ifelse(dat$exposure_3cat == "Medium", 0.6, 1.2))
#' dat$outcome <- rbinom(
#'   n, 1,
#'   plogis(-3 + 0.02 * dat$age + 0.4 * (dat$sex == "Male") + exp_eff)
#' )
#' fit_multi_cat <- final_prop_svyglm(
#'   dat,
#'   dep_var = "outcome",
#'   covariates = c("age", "sex"),
#'   exposure = "exposure_3cat",
#'   id_var = "psu",
#'   strata_var = "strata",
#'   weight_var = "weight",
#'   exposure_type = "multinomial",
#'   outcome_covariates = NULL
#' )
#' fit_multi_cat$OR_table
#' @return An object of class \code{svyCausal} (returned invisibly): a list
#'   with
#' \itemize{
#'   \item \code{model}: the weighted outcome \code{svyglm} fit.
#'   \item \code{OR_table}: data frame with columns \code{Variable}, \code{OR},
#'     \code{CI_low}, \code{CI_high} and \code{p_value}.
#'   \item \code{outcome}: the outcome values used in the fit.
#'   \item \code{predictions}: fitted probabilities from the outcome model.
#'   \item \code{final_weights}: the combined (base times IPTW) weights taken
#'     from the survey design used in the fit.
#' }
#' @importFrom survey svydesign svyglm
#' @importFrom stats glm lm predict coef confint residuals binomial fitted quasibinomial as.formula sd dnorm complete.cases model.frame model.response weights
#' @importFrom nnet multinom
#' @export
final_prop_svyglm <- function(
    data,
    dep_var,
    covariates,
    exposure,
    id_var,
    strata_var,
    weight_var,
    exposure_type = "binary",
    outcome_covariates = NULL,
    level = 0.95,
    ...
) {
  # Reject unknown exposure types up front; previously they fell through
  # every branch and failed later with an unrelated error.
  exposure_type <- match.arg(
    exposure_type,
    c("binary", "multinomial", "continuous")
  )
  if (is.null(id_var) || id_var == "") stop("id_var is empty")

  # ---- Safety check: every referenced column must exist ----
  required_vars <- unique(c(
    dep_var,
    exposure,
    .term_vars(covariates),
    .term_vars(outcome_covariates),
    .term_vars(id_var),
    .term_vars(strata_var),
    weight_var
  ))
  absent <- setdiff(required_vars, names(data))
  if (length(absent) > 0L) {
    stop(
      sprintf("Variables not found in data: %s", paste(absent, collapse = ", ")),
      call. = FALSE
    )
  }

  # ---- Safety check: drop rows with missing analysis/design values ----
  # Design variables are included so an NA base weight cannot later turn the
  # weight checks below into `if (NA)` errors.
  missing_rows <- !stats::complete.cases(data[, required_vars, drop = FALSE])
  if (any(missing_rows)) {
    warning(sprintf(
      "Removing %d rows with missing values in outcome, exposure, covariates, or survey design variables",
      sum(missing_rows)
    ))
    data <- data[!missing_rows, , drop = FALSE]
  }

  # ---- Safety check: outcome must have both classes (before any fitting) ----
  outcome_table <- table(data[[dep_var]])
  if (length(outcome_table) < 2L) {
    stop(sprintf("Outcome '%s' does not contain both 0 and 1. Cannot fit logistic model.", dep_var))
  }
  if (any(outcome_table < 5)) {
    warning(sprintf(
      "Outcome '%s' has very few events (%d) or non-events (%d). Estimates may be unstable.",
      dep_var, outcome_table[1], outcome_table[2]
    ))
  }

  # ---- Propensity weights (helpers keep intermediates out of `data`) ----
  if (exposure_type == "binary") {
    # Recode in `data` so the outcome model also sees the 0/1 exposure.
    data[[exposure]] <- .code_binary_exposure(data[[exposure]], exposure)
    ps_weight <- .iptw_binary(data, exposure, covariates)
  } else if (exposure_type == "multinomial") {
    data[[exposure]] <- factor(data[[exposure]])
    ps_weight <- .iptw_multinomial(data, exposure, covariates)
  } else {
    ps_weight <- .iptw_continuous(data, exposure, covariates)
  }
  data$weight_sw <- data[[weight_var]] * ps_weight

  # ---- Safety check: ensure weights are not zero or extremely large ----
  nonpositive <- data$weight_sw <= 0
  if (any(nonpositive)) {
    warning("Some computed survey weights are zero or negative. Adjusting to small positive value (1e-6).")
    data$weight_sw[nonpositive] <- 1e-6
  }
  if (max(data$weight_sw) > 1e5) {
    warning("Some weights are very large (>1e5). Truncating them at 1e5 to improve stability.")
    data$weight_sw <- pmin(data$weight_sw, 1e5)
  }

  # ---- Survey design ----
  strata_formula <- if (is.null(strata_var)) {
    NULL
  } else {
    stats::as.formula(paste0("~", strata_var))
  }
  design_sw <- survey::svydesign(
    ids = stats::as.formula(paste0("~", id_var)),
    strata = strata_formula,
    weights = ~weight_sw,
    data = data,
    nest = TRUE
  )

  # ---- Weighted outcome model ----
  fit_formula <- stats::as.formula(
    paste(dep_var, "~", paste(c(exposure, outcome_covariates), collapse = " + "))
  )
  fit <- survey::svyglm(
    fit_formula,
    design = design_sw,
    family = stats::quasibinomial(),
    ...
  )

  # ---- OR table ----
  coef_fit <- stats::coef(fit)
  ci <- stats::confint(fit, level = level)
  OR_table <- data.frame(
    Variable = names(coef_fit),
    OR = exp(coef_fit),
    CI_low = exp(ci[, 1]),
    CI_high = exp(ci[, 2]),
    p_value = summary(fit)$coefficients[, "Pr(>|t|)"],
    row.names = NULL
  )

  # ---- Predictions, observed outcome and the weights actually used ----
  pred <- as.numeric(stats::predict(fit, type = "response"))
  final_outcome <- as.numeric(stats::model.response(stats::model.frame(fit)))
  final_wts <- as.numeric(stats::weights(fit$survey.design))

  out <- list(
    model = fit,
    OR_table = OR_table,
    outcome = final_outcome,
    predictions = pred,
    final_weights = final_wts
  )
  class(out) <- "svyCausal"
  invisible(out)
}

# Internal: data-column names referenced by a character vector of model terms,
# e.g. c("age", "I(age^2)", "sex:region") -> c("age", "sex", "region").
# "1"/"0" (no clustering) and NULL yield character(0).
.term_vars <- function(term_labels) {
  if (length(term_labels) == 0L) {
    return(character(0))
  }
  all.vars(stats::as.formula(paste("~", paste(term_labels, collapse = " + "))))
}

# Internal: propensity-model formula `exposure ~ covariates`, falling back to
# an intercept-only model when no covariates are given.
.ps_formula <- function(exposure, covariates) {
  rhs <- if (length(covariates) == 0L) "1" else paste(covariates, collapse = " + ")
  stats::as.formula(paste(exposure, "~", rhs))
}

# Internal: validate a binary exposure and return it coded 0/1 (factor and
# character inputs map their second level to 1; numeric/logical 0/1 kept).
.code_binary_exposure <- function(x, exposure) {
  if (is.factor(x) || is.character(x)) {
    input_kind <- if (is.factor(x)) "factor" else "character"
    x <- factor(x) # drops unused factor levels
    lvl <- levels(x)
    if (length(lvl) != 2L) {
      stop(sprintf(
        "Binary exposure '%s' must have exactly two levels; found %d.",
        exposure, length(lvl)
      ), call. = FALSE)
    }
    message(sprintf(
      "Binary exposure '%s' is %s. Converting '%s' -> 1, '%s' -> 0",
      exposure, input_kind, lvl[2], lvl[1]
    ))
    return(as.numeric(x == lvl[2]))
  }
  if (!all(x %in% c(0, 1))) {
    stop(sprintf(
      "Binary exposure '%s' must be coded 0/1 (or be a two-level factor).",
      exposure
    ), call. = FALSE)
  }
  x
}

# Internal: IPTW (1/p or 1/(1-p)) from a logistic propensity model, with
# p clipped to [eps, 1 - eps] to avoid Inf/NaN weights.
.iptw_binary <- function(data, exposure, covariates, eps = 1e-6) {
  ps_formula <- .ps_formula(exposure, covariates)
  ps_model <- stats::glm(ps_formula, data = data, family = stats::binomial())
  p1 <- as.numeric(stats::predict(ps_model, type = "response"))
  n_adjusted <- sum(p1 < eps | p1 > 1 - eps)
  if (n_adjusted > 0L) {
    warning(sprintf(
      "%d propensity scores were adjusted to be within [%g, %g] to avoid Inf/NaN weights",
      n_adjusted, eps, 1 - eps
    ), call. = FALSE)
  }
  p1 <- pmin(pmax(p1, eps), 1 - eps)
  ifelse(data[[exposure]] == 1, 1 / p1, 1 / (1 - p1))
}

# Internal: mean-normalized IPTW (1 / P(observed level)) from a multinomial
# propensity model; `data[[exposure]]` must already be a factor.
.iptw_multinomial <- function(data, exposure, covariates) {
  x <- data[[exposure]]
  counts <- table(x)
  rare_levels <- names(counts)[counts < 5] # threshold can be adjusted
  if (length(rare_levels) > 0L) {
    warning(sprintf(
      "Levels %s of exposure have very few observations (<5). Results may be unstable.",
      paste(rare_levels, collapse = ", ")
    ), call. = FALSE)
  }
  ps_formula <- .ps_formula(exposure, covariates)
  ps_model <- nnet::multinom(ps_formula, data = data, trace = FALSE)
  probs <- stats::predict(ps_model, type = "probs")
  # With exactly two levels multinom returns only P(second level) as a vector.
  if (is.null(dim(probs))) {
    probs <- cbind(1 - probs, probs)
  }
  # Columns follow levels(x), so pick each row's observed-level probability.
  p_obs <- probs[cbind(seq_along(x), as.integer(x))]
  iptw <- 1 / p_obs
  iptw / mean(iptw)
}

# Internal: stabilized generalized-propensity-score weights for a numeric
# exposure: marginal normal density / conditional normal density (linear
# model on the covariates), each floored at eps.
.iptw_continuous <- function(data, exposure, covariates, eps = 1e-6) {
  x <- data[[exposure]]
  if (!is.numeric(x)) {
    stop(sprintf("Continuous exposure '%s' must be numeric.", exposure), call. = FALSE)
  }
  ps_formula <- .ps_formula(exposure, covariates)
  lm_fit <- stats::lm(ps_formula, data = data)
  res <- as.numeric(stats::residuals(lm_fit))
  gps <- stats::dnorm(res, mean = 0, sd = stats::sd(res))
  gps0 <- stats::dnorm(x, mean = mean(x), sd = stats::sd(x))
  pmax(gps0, eps) / pmax(gps, eps)
}
#' Print method for svyCausal objects
#'
#' Prints a short header followed by the odds-ratio table only; the fitted
#' model, predictions and weights stored in the object are not printed.
#'
#' @param x An object of class \code{svyCausal}, as returned by
#'   \code{final_prop_svyglm()}.
#' @param ... Further arguments passed to \code{print()} for the OR table.
#' @return \code{x}, invisibly (the usual contract for print methods).
#' @exportS3Method
print.svyCausal <- function(x, ...) {
  cat("\n--- svyCausalGLM Results ---\n")
  print(x$OR_table, ...)
  invisible(x)
}



