## -----------------------------------------------------------------------------
#| include: false
# Packages the vignette code needs. Later chunks are evaluated only when all
# of them can be loaded; otherwise `eval = FALSE` becomes the chunk default so
# the document still knits without running the examples. `survival` is listed
# explicitly because Surv() and survfit() are called directly in later chunks.
vignette_pkgs <- c("TemporalHazard", "survival", "ggplot2")
has_pkg <- all(vapply(vignette_pkgs, requireNamespace, logical(1),
                      quietly = TRUE))
knitr::opts_chunk$set(
  collapse = TRUE,
  comment  = "#>",
  eval     = has_pkg
)


## -----------------------------------------------------------------------------
#| label: setup
library(TemporalHazard)
library(survival)
library(ggplot2)


## -----------------------------------------------------------------------------
#| label: km-baseline
# AVC repair data shipped with TemporalHazard, reduced to rows without any
# missing value; this cleaned `avc` is what every later AVC chunk uses.
data(avc)
avc <- na.omit(avc)

# Nonparametric reference curve: one unstratified Kaplan-Meier estimate for
# follow-up time `int_dead` with event indicator `dead`.
km <- survfit(formula = Surv(int_dead, dead) ~ 1, data = avc)


## -----------------------------------------------------------------------------
#| label: fig-km
#| fig-cap: "Kaplan-Meier survival estimate: death after AVC repair"
#| fig-width: 7
#| fig-height: 4
# Kaplan-Meier curve as a data frame with survival in percent. survfit() only
# reports the observed times, so the origin (time 0, 100%) is prepended:
# without it geom_step() starts at the first observed time and the first drop
# is never drawn. `km_df` is also reused by later overlay figures.
km_df <- data.frame(time     = c(0, km$time),
                    survival = c(100, km$surv * 100))

ggplot(km_df, aes(time, survival)) +
  geom_step(linewidth = 0.6) +
  scale_y_continuous(limits = c(0, 100)) +
  labs(x = "Months after repair", y = "Freedom from death (%)") +
  theme_minimal()


## -----------------------------------------------------------------------------
#| label: fit-mv
# Starting values: the two named Weibull parameters (mu, nu) followed by four
# zeros, presumably one per covariate in the formula (age, status, mal,
# com_iv). NOTE(review): keep the number of zeros in sync with the formula.
start_theta <- c(mu = 0.20, nu = 1.0, rep(0, 4))

# Multivariable Weibull hazard model for death after AVC repair, estimated
# immediately (`fit = TRUE`); `maxit = 500` presumably caps optimiser
# iterations -- confirm in ?hazard.
fit <- hazard(
  Surv(int_dead, dead) ~ age + status + mal + com_iv,
  data    = avc,
  dist    = "weibull",
  theta   = start_theta,
  fit     = TRUE,
  control = list(maxit = 500)
)


## -----------------------------------------------------------------------------
#| label: pred-types
# Evaluation grid: 200 times from 0.01 up to 95% of the longest follow-up.
t_grid <- seq(from = 0.01, to = 0.95 * max(avc$int_dead), length.out = 200)

# One reference covariate profile (median age, status = 2, mal = 0,
# com_iv = 0) repeated along the grid, with `time` as the first column.
profile <- data.frame(time = t_grid, age = median(avc$age),
                      status = 2, mal = 0, com_iv = 0)

# Both predictions are taken before any result column is attached, so each
# predict() call receives exactly the same set of columns.
surv   <- predict(fit, newdata = profile, type = "survival")
cumhaz <- predict(fit, newdata = profile, type = "cumulative_hazard")
profile[c("survival", "cumulative_hazard")] <- list(surv, cumhaz)

head(profile[, c("time", "survival", "cumulative_hazard")])


## -----------------------------------------------------------------------------
#| label: fig-surv-overlay
#| fig-cap: "Weibull parametric survival vs. Kaplan-Meier (AVC death)"
#| fig-width: 7
#| fig-height: 4.5
# Fitted Weibull survival for the reference profile drawn over the
# Kaplan-Meier step function, both on a 0-100% scale.
ggplot() +
  geom_step(
    mapping   = aes(x = time, y = survival, colour = "Kaplan-Meier"),
    data      = km_df,
    linewidth = 0.5
  ) +
  geom_line(
    mapping   = aes(x = time, y = 100 * survival,
                    colour = "Parametric (Weibull)"),
    data      = profile,
    linewidth = 1
  ) +
  scale_colour_manual(values = c(
    "Parametric (Weibull)" = "#0072B2",
    "Kaplan-Meier"         = "#D55E00"
  )) +
  scale_y_continuous(limits = c(0, 100)) +
  labs(
    x      = "Months after repair",
    y      = "Freedom from death (%)",
    colour = NULL
  ) +
  theme_minimal() +
  theme(legend.position = "bottom")


## -----------------------------------------------------------------------------
#| label: pred-with-se
# `profile` now also carries the result columns appended in the pred-types
# chunk, which would confuse predict()'s column count; an identical
# covariate-only frame is therefore rebuilt from scratch here.
profile_ci <- data.frame(time = t_grid, age = median(avc$age),
                         status = 2, mal = 0, com_iv = 0)

# Survival with standard errors and 95% confidence limits.
surv_ci <- predict(fit, newdata = profile_ci, type = "survival",
                   se.fit = TRUE, level = 0.95)
head(surv_ci)


## -----------------------------------------------------------------------------
#| label: fig-surv-ci
#| fig-cap: "Parametric survival with 95% delta-method confidence band"
#| fig-width: 7
#| fig-height: 4.5
# Point estimate and confidence limits from surv_ci, rescaled to percent.
ci_df <- data.frame(
  time     = profile_ci$time,
  survival = 100 * surv_ci$fit,
  lower    = 100 * surv_ci$lower,
  upper    = 100 * surv_ci$upper
)

# Kaplan-Meier steps, then the shaded band, then the fitted curve on top.
ggplot() +
  geom_step(
    mapping   = aes(x = time, y = survival, colour = "Kaplan-Meier"),
    data      = km_df,
    linewidth = 0.5
  ) +
  geom_ribbon(
    mapping = aes(x = time, ymin = lower, ymax = upper),
    data    = ci_df,
    fill    = "#0072B2",
    alpha   = 0.15
  ) +
  geom_line(
    mapping   = aes(x = time, y = survival, colour = "Parametric + 95% CI"),
    data      = ci_df,
    linewidth = 1
  ) +
  scale_colour_manual(values = c(
    "Parametric + 95% CI" = "#0072B2",
    "Kaplan-Meier"        = "#D55E00"
  )) +
  scale_y_continuous(limits = c(0, 100)) +
  labs(
    x      = "Months after repair",
    y      = "Freedom from death (%)",
    colour = NULL
  ) +
  theme_minimal() +
  theme(legend.position = "bottom")


## -----------------------------------------------------------------------------
#| label: fit-mp
data(cabgkul)

# Three additive hazard phases. The early ("cdf") and late ("g3") phases are
# given `fixed = "shapes"`, presumably holding their shape parameters at the
# values supplied here -- confirm in ?hzr_phase.
mp_phases <- list(
  early    = hzr_phase("cdf", t_half = 0.2, nu = 1, m = 1, fixed = "shapes"),
  constant = hzr_phase("constant"),
  late     = hzr_phase("g3", tau = 1, gamma = 3, alpha = 1, eta = 1,
                       fixed = "shapes")
)

# Intercept-only multiphase model for death after CABG, estimated
# immediately; `n_starts = 5` presumably requests several optimiser starts.
fit_mp <- hazard(
  Surv(int_dead, dead) ~ 1,
  data    = cabgkul,
  dist    = "multiphase",
  phases  = mp_phases,
  fit     = TRUE,
  control = list(n_starts = 5, maxit = 1000)
)


## -----------------------------------------------------------------------------
#| label: fig-decomposed-hazard
#| fig-cap: "Additive phase decomposition: total hazard rate (solid) = early + constant + late (dashed)"
#| fig-width: 7
#| fig-height: 4.5
# CABG evaluation grid (200 points up to 95% of the longest follow-up) and
# the multiphase cumulative hazard on it, split by phase (decompose = TRUE).
t_mp <- seq(from = 0.01, to = 0.95 * max(cabgkul$int_dead), length.out = 200)
nd   <- data.frame(time = t_mp)

decomp <- predict(fit_mp, newdata = nd, type = "cumulative_hazard",
                  decompose = TRUE)
# Numerical derivative of a cumulative hazard: h(t) ~= dH(t) / dt.
#
# Interior points use central differences,
#   (H[i+1] - H[i-1]) / (t[i+1] - t[i-1]),
# and the two end points use one-sided differences. Each estimate is
# therefore tied to its own time point instead of lagging half a step.
#
# cumhaz: numeric vector of cumulative hazard values evaluated at `time`.
# time:   numeric vector of (assumed increasing) times, same length as cumhaz.
# Returns a numeric vector the same length as `time`. With fewer than two
# points no difference can be formed, so NA is returned for each point.
num_hazard <- function(cumhaz, time) {
  stopifnot(length(cumhaz) == length(time))
  n <- length(time)
  if (n < 2L) {
    return(rep(NA_real_, n))
  }
  h <- numeric(n)
  h[1L] <- (cumhaz[2L] - cumhaz[1L]) / (time[2L] - time[1L])
  h[n]  <- (cumhaz[n] - cumhaz[n - 1L]) / (time[n] - time[n - 1L])
  if (n > 2L) {
    mid <- seq.int(2L, n - 1L)
    h[mid] <- (cumhaz[mid + 1L] - cumhaz[mid - 1L]) /
      (time[mid + 1L] - time[mid - 1L])
  }
  h
}

# Differentiate every cumulative-hazard component of `decomp`, stack the
# results in long format, and order the legend with the total first.
phase_cols <- c(Early = "early", Constant = "constant", Late = "late",
                Total = "total")
h_long <- do.call(rbind, lapply(names(phase_cols), function(lbl) {
  data.frame(time   = t_mp,
             hazard = num_hazard(decomp[[phase_cols[[lbl]]]], t_mp),
             Phase  = lbl)
}))
h_long$Phase <- factor(h_long$Phase,
                       levels = c("Total", "Early", "Constant", "Late"))

# Total drawn solid and thicker; the individual phases dashed and thinner.
ggplot(h_long,
       aes(x = time, y = hazard, colour = Phase, linetype = Phase)) +
  geom_line(mapping = aes(linewidth = Phase)) +
  scale_colour_manual(values = c(
    Total = "#222222", Early = "#E69F00",
    Constant = "#56B4E9", Late = "#CC79A7"
  )) +
  scale_linetype_manual(values = c(
    Total = "solid", Early = "dashed",
    Constant = "dashed", Late = "dashed"
  )) +
  scale_linewidth_manual(values = c(
    Total = 1.3, Early = 0.7,
    Constant = 0.7, Late = 0.7
  )) +
  labs(
    x = "Months after CABG", y = "Hazard rate",
    colour = "Phase", linetype = "Phase", linewidth = "Phase"
  ) +
  theme_minimal() +
  theme(legend.position = "bottom")


## -----------------------------------------------------------------------------
#| label: fig-mp-surv
#| fig-cap: "Multiphase parametric survival vs. Kaplan-Meier (CABG death)"
#| fig-width: 7
#| fig-height: 4.5
# fit_mp was estimated on `cabgkul`, so it is compared with the CABG
# Kaplan-Meier curve on the CABG prediction grid (`nd$time`). Using the AVC
# estimate (`km_df`) or the AVC grid (`t_grid`) here would be a mismatch.
km_cabg <- survfit(Surv(int_dead, dead) ~ 1, data = cabgkul)
# Prepend the origin (time 0, 100%) so the first step is drawn.
km_cabg_df <- data.frame(time     = c(0, km_cabg$time),
                         survival = c(100, km_cabg$surv * 100))

surv_mp <- predict(fit_mp, newdata = nd, type = "survival") * 100
mp_df   <- data.frame(time = nd$time, survival = surv_mp)

ggplot() +
  geom_step(data = km_cabg_df, aes(time, survival, colour = "Kaplan-Meier"),
            linewidth = 0.5) +
  geom_line(data = mp_df,
            aes(time, survival, colour = "Multiphase (3-phase)"),
            linewidth = 1) +
  scale_colour_manual(
    values = c("Multiphase (3-phase)" = "#0072B2",
               "Kaplan-Meier"         = "#D55E00")
  ) +
  scale_y_continuous(limits = c(0, 100)) +
  labs(x = "Months after CABG", y = "Freedom from death (%)",
       colour = NULL) +
  theme_minimal() +
  theme(legend.position = "bottom")


## -----------------------------------------------------------------------------
#| label: fig-risk-profiles
#| fig-cap: "Predicted survival by risk profile"
#| fig-width: 7
#| fig-height: 4.5
# One-row covariate profiles, listed in the order they should appear in the
# legend.
profiles <- list(
  "Low risk"  = data.frame(age = quantile(avc$age, 0.25),
                            status = 1, mal = 0, com_iv = 0),
  "Median"    = data.frame(age = median(avc$age),
                            status = 2, mal = 0, com_iv = 0),
  "High risk" = data.frame(age = quantile(avc$age, 0.90),
                            status = 4, mal = 1, com_iv = 1)
)

# For each profile: repeat its single row along t_grid, append the `time`
# column, predict survival (in percent), and stack all curves.
curves <- do.call(rbind, lapply(names(profiles), function(label) {
  grid_nd <- profiles[[label]][rep(1, length(t_grid)), ]
  grid_nd$time <- t_grid
  surv_pct <- 100 * predict(fit, newdata = grid_nd, type = "survival")
  data.frame(time = t_grid, survival = surv_pct, Profile = label)
}))
curves$Profile <- factor(curves$Profile, levels = names(profiles))

ggplot(curves, aes(x = time, y = survival, colour = Profile)) +
  geom_line(linewidth = 0.9) +
  scale_colour_manual(values = c(
    "Low risk"  = "#009E73",
    "Median"    = "#0072B2",
    "High risk" = "#D55E00"
  )) +
  scale_y_continuous(limits = c(0, 100)) +
  labs(
    x      = "Months after AVC repair",
    y      = "Freedom from death (%)",
    colour = NULL
  ) +
  theme_minimal() +
  theme(legend.position = "bottom")


## -----------------------------------------------------------------------------
#| label: fig-valves-endpoints
#| fig-cap: "Freedom from death and PVE after valve replacement"
#| fig-width: 7
#| fig-height: 4.5
# Valve replacement data, restricted to rows with no missing value.
data(valves)
valves <- na.omit(valves)

# Separate Kaplan-Meier estimates for the two endpoints.
km_death <- survfit(Surv(int_dead, dead) ~ 1, data = valves)
km_pve   <- survfit(Surv(int_pve, pve) ~ 1, data = valves)

# Long-format curves in percent. survfit() reports only observed times, so
# each curve gets the origin (time 0, 100%) prepended; otherwise geom_step()
# would start at the first observed time and omit the first drop.
ep_df <- rbind(
  data.frame(time = c(0, km_death$time),
             survival = c(100, km_death$surv * 100),
             Endpoint = "Death"),
  data.frame(time = c(0, km_pve$time),
             survival = c(100, km_pve$surv * 100),
             Endpoint = "PVE")
)

ggplot(ep_df, aes(time, survival, colour = Endpoint)) +
  geom_step(linewidth = 0.7) +
  scale_y_continuous(limits = c(0, 100)) +
  scale_colour_manual(values = c("Death" = "#D55E00", "PVE" = "#0072B2")) +
  labs(x = "Months after valve replacement",
       y = "Freedom from event (%)", colour = NULL) +
  theme_minimal() +
  theme(legend.position = "bottom")

