
# The error below is uninterpretable when the network parameters (e.g. Kmax)
# haven't been defined: the network can't be simulated, but the resulting
# error message doesn't point to that cause:
  # Error in sA[[1:NULL]] : 
  #   Network must be defined (with simcausal::network(...)) prior to using Var[[netidx]] node syntax
  # Error : error while evaluating node net.mean.sA formula: 
  # mean(sA[[1:Kmax]]).
  # Check syntax specification.
  # Error in set.DAG(D) : 
  # ...attempt to simulate data from DAG failed...


# Infix shorthand for paste0(): `"a" %+% "b"` gives "ab" (vectorized, no separator).
`%+%` <- function(a, b) {
  paste0(a, b)
}
library(igraph)
library(simcausal)
library(tmlenet)

  # Kmax <- 10
  # trunc.c <- 10
  # shift <- 2


  # Network generator used as the simcausal `netfun` for the "NetInd_k" network.
  # Simulates an igraph graph on n vertices, then caps each vertex's in-degree
  # (its number of friends) at Kmax by deleting a random subset of the excess
  # incoming edges.
  #
  # NOTE: despite the name, the graph currently comes from igraph::sample_pa()
  # (Barabasi-Albert preferential attachment); the small-world (Watts-Strogatz)
  # generator is kept below, commented out.
  #
  # Args:
  #   n:    number of vertices to simulate.
  #   Kmax: maximum number of incoming edges (friends) kept per vertex; must be
  #         a single non-negative number.
  #   ...:  additional arguments; ignored.
  # Returns:
  #   The NetInd_k friend-index matrix from simcausal::sparseAdjMat.to.NetInd().
  generate.igraph.smallwld <- function(n, Kmax, ...) {
    # Fail early with a clear message instead of an obscure error further down
    # (e.g. when Kmax was never defined by the caller).
    if (!is.numeric(Kmax) || length(Kmax) != 1L || is.na(Kmax) || Kmax < 0) {
      stop("Kmax must be a single non-negative number", call. = FALSE)
    }

    g <- igraph::sample_pa(n)

    # # small world (Watts-Strogatz network) model:
    # g <- sample_smallworld(dim = 1, size = n, nei = 5, p = 0.05, loops = FALSE, multiple = FALSE)
    # g <- as.directed(g, mode = c("mutual"))

    degs <- igraph::degree(g, mode = "in") # number of incoming edges per vertex
    high_v <- which(degs > Kmax) # vertices with too many friends (deg > Kmax)
    if (length(high_v) > 0L) {
      # For each over-connected vertex, sample which of its incoming edges to
      # delete. Incoming edge IDs are sorted so the sampled positions index the
      # edges in increasing edge-ID order.
      del_ids <- unlist(lapply(high_v, function(v) {
        in_ids <- sort(as.integer(igraph::incident(g, v, mode = "in")))
        n_del <- length(in_ids) - Kmax
        in_ids[igraph::sample_seq(1, length(in_ids), n_del)]
      }), use.names = FALSE)
      # Delete all selected edges in one call so the edge IDs remain valid.
      g <- igraph::delete_edges(g, del_ids)
    }

    # igraph object -> sparse adjacency matrix -> simcausal/tmlenet input
    # (NetInd_k, nF, Kmax).
    sparse_AdjMat <- simcausal::igraph.to.sparseAdjMat(g)
    NetInd_out <- simcausal::sparseAdjMat.to.NetInd(sparse_AdjMat)
    if (Kmax < NetInd_out$Kmax) {
      message("new network has larger Kmax value than requested, new Kmax = " %+% NetInd_out$Kmax)
    }
    NetInd_out$NetInd_k
  }

  # Build the simulation DAG: first the network (friend lists of up to Kmax
  # units, generated by the netfun generate.igraph.smallwld), then the node
  # structural equations in order.
  # The formulas reference Kmax, shift and trunc.c, which must be defined for
  # set.DAG() to succeed (their definitions near the top are commented out).
  D <- DAG.empty() +
    network("NetInd_k", Kmax = Kmax, netfun = "generate.igraph.smallwld")

  # Baseline covariates.
  D <- D + node("W1", distr = "rbern", prob = 0.5)
  D <- D + node("W2", distr = "rbern", prob = 0.3)
  D <- D + node("W3", distr = "rbern", prob = 0.3)

  # Continuous exposure sA ~ N(sA.mu, 1), plus a summary of friends' sA values.
  D <- D + node("sA.mu", distr = "rconst", const = (0.98 * W1 + 0.58 * W2 + 0.33 * W3))
  D <- D + node("sA", distr = "rnorm", mean = sA.mu, sd = 1)
  D <- D + node("net.mean.sA", distr = "rconst", const = mean(sA[[1:Kmax]]))

  # Shift intervention sA + shift. The r.* nodes are the ratio of the
  # N(sA.mu + shift, 1) to the N(sA.mu, 1) density; where r.new.sA exceeds
  # trunc.c, the intervention falls back to the observed sA.
  D <- D + node("r.obs.sA",  distr = "rconst", const = exp(shift * (sA - sA.mu - shift / 2)))
  D <- D + node("untrunc.sA.gstar",  distr = "rconst", const = sA + shift)
  D <- D + node("r.new.sA",  distr = "rconst", const = exp(shift * (untrunc.sA.gstar - sA.mu - shift / 2)))
  D <- D + node("tr.sA.gstar",  distr = "rconst", const = ifelse(r.new.sA > trunc.c, sA, untrunc.sA.gstar))

  # Observed binary outcome.
  D <- D + node("probY", distr = "rconst", const = plogis(-0.35 * sA - 0.20 * mean(sA[[0:Kmax]]) - 0.5 * W1 - 0.58 * W2 - 0.33 * W3))
  # node("probY", distr = "rconst", const = plogis(-0.35 * sA - 0.20 * sum(sA[[1:Kmax]]) / nF - 0.5 * W1 - 0.58 * W2 - 0.33 * W3)) +
  D <- D + node("Y", distr = "rbern", prob = probY)

  # Outcome under the truncated shift intervention.
  D <- D + node("probY.gstar", distr = "rconst", const = plogis(-0.35 * tr.sA.gstar - 0.20 * mean(tr.sA.gstar[[0:Kmax]]) - 0.5 * W1 - 0.58 * W2 - 0.33 * W3))
  # node("probY.gstar", distr = "rconst", const = plogis(-0.35 * tr.sA.gstar - 0.20 * sum(tr.sA.gstar[[1:Kmax]]) / nF - 0.5 * W1 - 0.58 * W2 - 0.33 * W3)) +
  D <- D + node("Y.gstar", distr = "rbern", prob = probY.gstar)

  Dset <- set.DAG(D)
