#!/usr/bin/env sh

# 'local' isn't POSIX-compliant, but in practice is available in
# implementations of sh that we care about.
# shellcheck disable=SC3043

# Exit when a command fails (-e, subject to the shell's usual exemptions such
# as commands tested by `if`) and treat expansion of an unset variable as an
# error (-u).
set -eu


# Final path component of the name this script was invoked as; shown in the
# usage text.
SCRIPT_NAME=$(basename "$0")

# Print the help text on stdout, one line per printf argument.
# Reads the global SCRIPT_NAME for the program name shown in the text.
usage ()
{
    printf '%s\n' \
        "${SCRIPT_NAME}: Find parallelism limit based on memory and CPUs available." \
        "Requires some upper bound (i.e. reasonable guesstimate) of how much RAM each job needs." \
        "Takes into account the physical memory available on the machine, and cgroup limits." \
        "Reports but doesn't take into account how much memory appears to be currently" \
        "available." \
        "" \
        "The result is printed as a single integer to stdout. Additional information is" \
        "printed to stderr." \
        "" \
        "Usage:" \
        "  ${SCRIPT_NAME} MAX_MB_RAM_per_job : Calculate max parallelism for given max-MB-RAM-per-job" \
        "" \
        "Options:" \
        "  -h: Print this message."
}

# Print all arguments, joined by single spaces (default IFS), as one line on
# stderr.
#
# printf is used instead of echo so the message is always emitted verbatim:
# echo would treat a leading "-n"/"-e" as an option, and some sh
# implementations interpret backslash sequences in echo's arguments (the
# cgroup paths logged by this script may contain backslashes).
eecho () {
    printf '%s\n' "$*" >&2
}

# Print the smaller of two integers.
# Arguments: $1, $2 - integers (evaluated as shell arithmetic)
# Outputs:   the smaller value on stdout
min () {
    local a="$1"
    local b="$2"
    local smaller
    # $(( a < b )) expands to 1 when the comparison holds, 0 otherwise.
    if [ "$(( a < b ))" -ne 0 ]
    then
        smaller=$(( a ))
    else
        smaller=$(( b ))
    fi
    echo "$smaller"
}

# Print the larger of two integers.
# Arguments: $1, $2 - integers (evaluated as shell arithmetic)
# Outputs:   the larger value on stdout
max () {
    local a="$1"
    local b="$2"
    local larger
    # $(( a > b )) expands to 1 when the comparison holds, 0 otherwise.
    if [ "$(( a > b ))" -ne 0 ]
    then
        larger=$(( a ))
    else
        larger=$(( b ))
    fi
    echo "$larger"
}

# Succeed (status 0) only if $1 is a non-empty string made up solely of the
# digits 0-9; fail (status 1) otherwise. Signs, spaces and decimal points all
# count as non-digits.
is_numeric () {
    local candidate="$1"
    # The empty string is not a number.
    if [ -z "$candidate" ]
    then
        return 1
    fi
    # Any single non-digit character disqualifies the value.
    case "$candidate" in
    *[!0-9]*) return 1 ;;
    esac
    return 0
}

# Parse command-line flags. Only -h is recognised.
while getopts "h" opt ; do
    case "$opt" in
        h)
            # Help was explicitly requested, so it goes to stdout.
            usage
            exit 0
            ;;
        *)
            # Diagnostics go to stderr: stdout is reserved for the single
            # integer result, which callers typically capture.
            eecho "Unknown option."
            exit 1
            ;;
    esac
done

# Remove the flags we parsed.
shift $((OPTIND-1))

# Exactly one positional argument is required, and it must consist only of
# decimal digits. On misuse the usage text is an error message, so it is sent
# to stderr rather than stdout.
if [ $# -ne 1 ] || ! is_numeric "$1"
then
    usage >&2
    exit 1
fi
max_mb_per_job="$1"

# Follow the cgroup hierarchy from this process's cgroup up to the cgroup root
# and print the effective memory limit in MB, which is the minimum of the
# memory.max limits found at each level of the hierarchy. Prints "unlimited"
# if no level has a numeric limit.
#
# Arguments (both optional; the defaults describe the running process):
#   $1 - cgroup filesystem root (default: /sys/fs/cgroup)
#   $2 - file listing the process's cgroup membership
#        (default: /proc/self/cgroup)
# Outputs:
#   stdout - the limit in MB, or "unlimited"
#   stderr - one line per numeric limit found (raw value as read from the file)
get_cgroup_memmax_MB () {
    local cgroup_root="${1:-/sys/fs/cgroup}"
    local cgroup_file="${2:-/proc/self/cgroup}"
    local cgroup
    # Only the "0::<path>" entry is used: memory.max is a cgroup v2 file, and
    # that is the form of the v2 entry. On v1/hybrid systems the file has
    # several other lines, which must not be glued into the path. Everything
    # after "0::" is taken, so a path containing ':' is kept intact.
    cgroup=$(sed -n 's/^0:://p' "$cgroup_file")
    local next_cgroup_dir="$cgroup_root$cgroup"
    local res="unlimited"
    local cgroup_dir f this_max this_max_MB
    while true
    do
        # Stop once we have walked above the cgroup root.
        case "$next_cgroup_dir" in
        "$cgroup_root"*) : ;;
        *) break;;
        esac

        # Advance to the parent before any `continue` so the loop always
        # makes progress.
        cgroup_dir="$next_cgroup_dir"
        next_cgroup_dir=$(dirname "$next_cgroup_dir")
        f="$cgroup_dir/memory.max"
        if [ ! -f "$f" ]
        then
            continue
        fi
        this_max=$(cat "$f")
        # A non-numeric value (e.g. "max") means no limit at this level.
        if ! is_numeric "$this_max"
        then
            continue
        fi
        eecho "found cgroup limit $cgroup_dir: $this_max"
        # Convert to MB *before* comparing, so that limits from different
        # levels are compared in the same unit.
        this_max_MB=$(( this_max / 1024 / 1024 ))
        if is_numeric "$res"
        then
            res=$(min "$this_max_MB" "$res")
        else
            res="$this_max_MB"
        fi
    done
    echo "$res"
}

# Print the MemAvailable figure from /proc/meminfo, converted to MB.
# The second whitespace-separated field of the matching line is divided by
# 1024 using integer (truncating) arithmetic.
get_meminfo_available_MB () {
    local available_kb
    available_kb=$(grep ^MemAvailable /proc/meminfo | awk '{ print $2; }')
    local available_mb=$(( available_kb / 1024 ))
    echo "$available_mb"
}

# Print the MemTotal figure from /proc/meminfo, converted to MB.
# The second whitespace-separated field of the matching line is divided by
# 1024 using integer (truncating) arithmetic.
get_meminfo_total_MB () {
    local total_kb
    total_kb=$(grep ^MemTotal /proc/meminfo | awk '{ print $2; }')
    local total_mb=$(( total_kb / 1024 ))
    echo "$total_mb"
}

# Print the amount of memory, in MB, to budget jobs against: the machine's
# total memory, capped by the cgroup limit when there is a numeric one.
# The three figures gathered along the way are reported on stderr.
get_max_mem_MB () {
    # MemAvailable is only reported, not used: it may give inconsistent
    # results, e.g. we don't want to severely limit parallelism just because
    # something happens to be briefly using a lot of memory right now.
    local available_MB
    available_MB=$(get_meminfo_available_MB)
    eecho "MemAvailable: $available_MB MB"

    local total_MB
    total_MB=$(get_meminfo_total_MB)
    eecho "MemTotal: $total_MB MB"

    local cgroup_MB
    cgroup_MB=$(get_cgroup_memmax_MB)
    eecho "cgroup limit: $cgroup_MB MB"

    # A numeric cgroup limit caps the total; report the smaller of the two.
    if is_numeric "$cgroup_MB"
    then
        min "${total_MB}" "${cgroup_MB}"
        return
    fi

    # No cgroup limit; fall back to total.
    echo "$total_MB"
}

# Print the number of jobs to run in parallel: the smaller of the memory-based
# limit (usable memory divided by the per-job requirement) and the CPU count
# reported by nproc, but never less than 1.
#
# Arguments: $1 - upper bound on the RAM each job needs, in MB; a positive
#                 decimal integer (leading zeros are tolerated)
# Outputs:   the limit on stdout; the intermediate limits on stderr
# Returns:   non-zero, with nothing on stdout, if $1 is zero or empty
get_parallelism_limit() {
    local mb_per_job="$1"
    # Strip leading zeros: shell arithmetic would otherwise read e.g. "0100"
    # as octal (and reject "08" outright). The inner expansion yields the
    # run of leading zeros, which the outer one removes.
    mb_per_job="${mb_per_job#"${mb_per_job%%[!0]*}"}"
    # Nothing left means the value was zero (or empty); dividing by it below
    # would be a division by zero.
    if [ -z "$mb_per_job" ]
    then
        eecho "max MB of RAM per job must be greater than zero"
        return 1
    fi
    local max_mem_MB
    max_mem_MB=$(get_max_mem_MB)
    local mem_based_limit=$(( max_mem_MB / mb_per_job ))
    eecho "memory based parallelism limit: $mem_based_limit"
    local cpu_based_limit
    cpu_based_limit=$(nproc)
    eecho "cpu based parallelism limit: $cpu_based_limit"
    # Always allow at least one job, even if a single job may not fit.
    max 1 "$(min "$mem_based_limit" "$cpu_based_limit")"
}

# Entry point: compute the limit for the validated per-job memory argument.
# The result goes to stdout; the functions called along the way report
# intermediate values on stderr.
get_parallelism_limit "$max_mb_per_job"
