#!/usr/bin/env python3
#
#Copyright (C) 2022 The University of Notre Dame
#This software is distributed under the GNU General Public License.
#See the file COPYING for details.
#
# This program implements a way to organize and manage a large number of
# concurrently running GATK instances

# Author: Nick Hazekamp
# Date: 09/03/2013


import optparse, os, sys, re, string, logging as log, stat

class PassThroughParser(optparse.OptionParser):
    """OptionParser that keeps options it does not recognise.

    Instead of aborting on an unknown or ambiguous option, the offending
    option string is stored with the positional arguments, so it comes
    back in the ``args`` list returned by ``parse_args()``.
    """

    def _process_args(self, largs, rargs, values):
        """Consume ``rargs``, diverting unrecognised options into ``largs``."""
        # The stock implementation stops at the first bad option, so keep
        # re-entering it until every remaining argument has been consumed.
        while rargs:
            try:
                super()._process_args(largs, rargs, values)
            except optparse.BadOptionError as unknown:
                # AmbiguousOptionError derives from BadOptionError, so both
                # cases are handled by this single clause.
                largs.append(unknown.opt_str)

def parse_req(parser):
    """Register the "Required Parameters" option group on ``parser``.

    Adds -T/--analysis_type, --query, --ref and --out as string options,
    plus --Xmx (Java heap size, default "4G").
    """
    required = optparse.OptionGroup(parser, "Required Parameters")

    # (flags, destination attribute, help text) for the plain string options.
    string_options = (
        (('-T', '--analysis_type'), 'type', "Type of analysis to run"),
        (('--query',), 'input', "SAM or BAM file(s)"),
        (('--ref',), 'ref', "Reference sequence file"),
        (('--out',), 'output', "Output name"),
    )
    for flags, dest, help_text in string_options:
        required.add_option(*flags, dest=dest, type="string", help=help_text)

    required.add_option('--Xmx', dest='xmx', help='Specify Java Heap size', default="4G")
    parser.add_option_group(required)

def parse_makeflow(parser):
    """Register the "Makeflow Preparation Options" group on ``parser``.

    Options added:
      --makeflow         destination file for the generated Makeflow
      --makeflow-config  configuration file copied to the top of the Makeflow
      --num-seq          sequences per partition (int, default 10000)
      --verbose          boolean flag, default False
    """
    group = optparse.OptionGroup(parser, "Makeflow Preparation Options")
    group.add_option('--makeflow', dest='makeflow', help='Makeflow destination file [stdout]')
    group.add_option('--makeflow-config', dest='config', help='Makeflow configurations File')
    group.add_option('--num-seq', dest='num_seq', type="int", default=10000,
                     help='Determines number of sequences per partition [10000]')
    # The default must be the boolean False: the string "False" is truthy and
    # would make any plain truthiness test on options.verbose succeed.
    group.add_option('--verbose', dest='verbose', action='store_true', default=False,
                     help='Show verbose level output')
    parser.add_option_group(group)


#Helper function for finding executables in path
def search_file(filename, search_path, pathsep=os.pathsep):
    """Return the absolute path of ``filename`` in ``search_path``, or None.

    ``search_path`` is a ``pathsep``-separated list of directories, searched
    in order; the first directory that contains ``filename`` wins. A missing
    (None) search path is treated as empty.
    """
    # str.split replaces string.split(), which was removed in Python 3.
    for directory in (search_path or "").split(pathsep):
        candidate = os.path.join(directory, filename)
        if os.path.exists(candidate):
            return os.path.abspath(candidate)
    return None

def find_file(filename, path):
    """Make sure ``filename`` is reachable from the current directory.

    A copy already present in the working directory is used as is.
    Otherwise the file is looked up in ``path`` (via search_file) and
    symlinked into the working directory. If it cannot be found at all,
    an error is logged and the process exits with status 3.
    """
    log.info("Searching PATH for : {0}".format(filename))
    located = search_file(filename, path)
    local_copy = "./{0}".format(filename)

    if os.path.exists(local_copy):
        log.info("{0} found locally".format(filename))
        return

    if not located:
        log.error("Unable to find {0}".format(filename))
        sys.exit(3)

    os.symlink(located, filename)
    log.info("{0} linked from : {1}".format(filename, located))

# Perl script written out by write_progeny_split(). It is a raw string so the
# Perl escapes (\n, \D, \d, \@) reach the file exactly as typed here.
_GATK_SPLIT_ALIGN_SCRIPT = r'''#!/usr/bin/perl
#Programmer: Nicholas Hazekamp
#Date: 6/10/2014

use integer;
use Symbol;
use Data::Dumper;

my $numargs = $#ARGV + 1;

if ($numargs != 3) {
	print STDERR "Usage: perl gatk_split_align <number of reads> <reference file> <sam file>
";
	exit 1;
}

my $num_reads = shift;
my $ref_file = shift;
my $sam_file = shift;
my $ref = substr $ref_file, 0, -3;
my $sam = substr $sam_file, 0, -4;

my %ref_defs = ();
my @loc_defs = ();
my @outputs = ();
my @files;

my $i = 0;
my $position = 0;
my $read_count = 0;
my $num_outputs = 0;

#Open input file
open(INPUT, $ref_file);
my $loc_def = "";
my $last_loc = "";

my $file = gensym;
open(OUTPUT,">$ref.$num_outputs.fa");
open($file, ">$sam.$num_outputs.sam");
push(@outputs, $file);

while (my $line = <INPUT>) {
	chomp $line;
#FASTQ files begin sequence with '@' character
#If line begins with '@' then it is a new sequence
	if ($line =~ /^>/){
		($loc_def, my $loc) = $line =~ m/>(\D+)(\d+)/;
		if($loc_def ne $last_loc){
			push (@loc_defs, $loc_def);
			$last_loc = $loc_def;
		}
#Check if the new sequence should be placed in a new file, otherwise place it in same file
		if ($read_count == $num_reads){
			close(OUTPUT);
			$num_outputs++;
			$read_count = 1;

			$file = gensym;
			open(OUTPUT,">$ref.$num_outputs.fa");
			open($file, ">$sam.$num_outputs.sam");
			push(@outputs, $file);
			$ref_defs{$loc_def+$loc} = $file;

			print OUTPUT "$line\n";
		}
		else{
			$ref_defs{$loc_def+$loc} = $file;
			print OUTPUT "$line\n";
			$read_count++;
		}
	}
#place all other lines in FASTQ file under same sequence
	else {
		print OUTPUT "$line\n";
	}
}
close(OUTPUT);
close(INPUT);

my $output;

#Open input file
open(INPUT, $sam_file);

my $in_header = 1;

my $loc_def_str = join("|", (@loc_defs));
$loc_def_str = "".$loc_def_str."";

while (my $line = <INPUT>) {
	chomp $line;

	($loc_def, my $loc) = $line =~ /($loc_def_str)(\d+)/;
	$output = $ref_defs{$loc_def+$loc};

	if ($line =~ /^\@SQ/ and tell($output) != -1){
		if($in_header eq 0){
			print "\@SQ found after end of header.\n";
			exit;
		}
		print $output "$line\n";
	}
	elsif($line =~ /^\@/){
		foreach $output (@outputs) {
			print $output "$line\n";
		}
	}
	#place all other lines in FASTQ file under same sequence
	elsif (tell($output) != -1){
		if($in_header == 1){
			$in_header = 0;
		}
		print $output "$line\n";
	}
}
close(INPUT);

foreach $output (@outputs) {
	close($output);
}'''


def write_progeny_split(destination):
    """Write the Perl partitioning script (gatk_split_align) to ``destination``.

    The generated script takes <number of reads> <reference file> <sam file>.
    It splits the reference into numbered ``.fa`` pieces and routes the SAM
    records into matching numbered ``.sam`` files.

    Compared with the previous text of the script, the end-of-header marker
    is now actually cleared (``$in_header = 0``). It used to be a comparison
    with no effect, so the "@SQ found after end of header" check never fired.
    """
    # The context manager closes the file even if the write fails.
    with open(destination, 'w') as prog_split:
        log.info("Writing Partitioning Script to : {0}".format(destination))
        prog_split.write(_GATK_SPLIT_ALIGN_SCRIPT)

def write_java(destination):
    """Write the Java environment wrapper script (java_run) to ``destination``.

    The generated Python script points PATH, JAVA_HOME, LD_LIBRARY_PATH and
    _JAVA_OPTIONS at a ``jre`` directory in the current working directory and
    then runs ``java`` with the string passed via ``--inputs``. With
    ``--output`` it redirects stdout and stderr to that file.

    Fixes relative to the previous text of the generated script:
      * ``-Djava.io.tmpdir`` is separated from any pre-existing
        _JAVA_OPTIONS by a space instead of being glued onto the last option.
      * stderr is redirected with ``2>&1``. Redirecting both streams to the
        same file name independently makes them overwrite each other.

    NOTE(review): the trailing "restore" section of the generated script is
    reproduced unchanged. Its ``== "NULL"`` tests look inverted, but the
    script exits right after, so it is left for a separate review.
    """
    # The context manager closes the file even if the write fails.
    with open(destination, 'w') as ref_split:
        log.info("Writing Java Environment Script to : {0}".format(destination))
        ref_split.write('''
# Author: Nicholas Hazekamp
# Date: 04/28/2014


import optparse, os, sys, tempfile, shutil, stat

#Parse Command Line
parser = optparse.OptionParser()
parser.add_option('', '--inputs', dest="inputs", type="string")
parser.add_option('', '--output', dest="output", type="string")
(options, args) = parser.parse_args()

output = ""
if options.output:
	output = " > " + options.output + " 2>&1"

path = os.environ.get('PATH', 'NULL')
java_home = os.environ.get('JAVA_HOME', 'NULL')
ld_lib = os.environ.get('LD_LIBRARY_PATH', 'NULL')
java_opt = os.environ.get('_JAVA_OPTIONS', 'NULL')

if not java_opt or java_opt == "NULL":
	java_opt=""

cwd = os.getcwd()

os.environ['PATH'] = cwd+"/jre/bin/:"+path
os.environ['JAVA_HOME'] = cwd+"/jre/"
os.environ['LD_LIBRARY_PATH'] = cwd+"/jre/lib"
os.environ['_JAVA_OPTIONS'] = (java_opt+" -Djava.io.tmpdir="+cwd).strip()

os.system('java ' + options.inputs + output)

if path == "NULL":
	os.environ['PATH'] = path
if java_home == "NULL":
	os.environ['JAVA_HOME'] = java_home
if ld_lib == "NULL":
	os.environ['LD_LIBRARY_PATH'] = ld_lib
if java_opt == "":
	os.environ['_JAVA_OPTIONS'] = "NULL"
''')


def count_splits(fastq, num_reads, log):
    """Return how many partitions gatk_split_align will create for ``fastq``.

    Parameters:
      fastq      path of the reference file; records are recognised by
                 lines starting with '>'
      num_reads  number of records per partition (int or numeric string)
      log        object providing ``info()`` for progress messages

    The count mirrors the generated Perl splitter: the record that overflows
    a full partition becomes the first record of the next one. A file with
    no records still yields 1, because the splitter always opens partition 0.
    """
    log.info("Counting Number of Partitions")
    num_reads = int(num_reads)
    num_outputs = 1
    read_count = 0

    with open(fastq, "r") as reference:
        for line in reference:
            if not line.startswith(">"):
                continue
            if read_count == num_reads:
                num_outputs += 1
                # The record that triggered the new partition belongs to it,
                # so the counter restarts at 1. Restarting at 0 under-counts
                # the partitions the splitter actually writes.
                read_count = 1
                if num_outputs % 10 == 0:
                    log.info("Current Number of Partitions : {0}".format(num_outputs))
            else:
                read_count += 1

    log.info("Total Number of Partitions : {0}".format(num_outputs))
    return num_outputs


def _write_reference_rules(makeflow, ref_fasta, ref_index, ref_dict):
    """Emit the rules that build the .fai index and .dict file for ``ref_fasta``."""
    makeflow.write("\n\n{0} : samtools {1}".format(ref_index, ref_fasta))
    makeflow.write("\n\t./samtools faidx {0}".format(ref_fasta))

    makeflow.write("\n\n{0} : java_run jre/ picard.jar {1}".format(ref_dict, ref_fasta))
    makeflow.write("\n\tpython java_run --inputs \"-jar picard.jar CreateSequenceDictionary R={0} O={1}\"".format(ref_fasta, ref_dict))


def _write_gatk_rule(makeflow, xmx, analysis_type, arguments, ref_files, ref_fasta, bam, bai, vcf):
    """Emit the rule running GATK on ``bam`` to produce ``vcf``; return the .idx name."""
    vcf_idx = "{0}.idx".format(vcf)
    query = "{0} {1}".format(bai, bam)
    makeflow.write("\n\n{0} {1} : java_run jre {2} {3} GenomeAnalysisTK.jar".format(vcf, vcf_idx, ref_files, query))
    makeflow.write("\n\tpython java_run --inputs \"{0} -jar GenomeAnalysisTK.jar -T {1} -R {2} -I {3} -o {4} {5}\"".format(xmx, analysis_type, ref_fasta, bam, vcf, ' '.join(arguments)))
    return vcf_idx


def write_makeflow(destination, configuration, arguments, ref, input, num_reads, log):
    """Generate the Makeflow describing the (optionally partitioned) GATK run.

    Parameters:
      destination    file to write the Makeflow to; stdout when empty/None
      configuration  optional file whose contents are copied to the top
      arguments      extra command-line arguments appended to every GATK call
      ref            reference sequence file
      input          SAM or BAM query file
      num_reads      records per partition (int or numeric string); a value
                     <= 0 produces a single unpartitioned GATK rule
      log            object providing ``info()``

    The Java heap size, analysis type and final output name are read from the
    module-level ``options`` object (``options.xmx``, ``options.type``,
    ``options.output``).
    """
    if destination:
        makeflow = open(destination, 'w')
        log.info("Writing Makeflow to : {0}".format(destination))
    else:
        makeflow = sys.stdout
        log.info("Writing Makeflow to : STDOUT")

    try:
        if configuration:
            log.info("Reading Configuration from : {0}".format(configuration))
            with open(configuration, 'r') as config:
                makeflow.write(config.read())

        # Callers may hand over the partition size as a string; comparing a
        # str with 0 raises TypeError on Python 3.
        num_reads = int(num_reads)

        xmx = ""
        if options.xmx:
            xmx = " -Xmx" + options.xmx + " "

        # Strip a FASTA-like extension; names without one are kept whole.
        ref_base, sep, file_ext = ref.rpartition(".")
        if not sep or not re.search("f[ast]?a", file_ext, re.IGNORECASE):
            ref_base = ref

        # Strip a SAM/BAM extension and remember which one it was.
        input_original = input
        input_base, sep, file_ext = input.rpartition(".")
        comp = ""
        if not sep or not re.search("[bs]am", file_ext, re.IGNORECASE):
            input_base = input_original
        elif re.search("bam", file_ext, re.IGNORECASE):
            comp = "bam"
        else:
            comp = "sam"

        input_full_bam = "{0}.bam".format(input_base)
        input_full_sam = "{0}.sam".format(input_base)

        log.info("Writing Partitioning Rules")

        if num_reads > 0:
            # Only derive the SAM from a BAM when the SAM is not the input
            # itself; otherwise the rule would demand a BAM that may not
            # exist in order to rebuild a file that is already there.
            if input_full_sam != input_original:
                makeflow.write("\n\n {0} : samtools {1}".format(input_full_sam, input_full_bam))
                makeflow.write("\n\tLOCAL ./samtools view -h {0} > {1}".format(input_full_bam, input_full_sam))

            count = count_splits(ref, num_reads, log)

            refs = ["{0}.{1}.fa".format(ref_base, i) for i in range(count)]
            inputs = ["{0}.{1}.sam".format(input_base, i) for i in range(count)]

            makeflow.write("\n\n{0} {1} : gatk_split_align {2} {3}".format(' '.join(refs), ' '.join(inputs), ref, input_full_sam))
            makeflow.write("\n\tLOCAL perl gatk_split_align {0} {1} {2}".format(num_reads, ref, input_full_sam))

            log.info("Writing Prep and GATK Rules")

            merge_files = []
            merge_inputs = []
            for i, (ref_sp, input_sam) in enumerate(zip(refs, inputs)):
                index = "{0}.fai".format(ref_sp)
                seq_dict = "{0}.{1}.dict".format(ref_base, i)
                _write_reference_rules(makeflow, ref_sp, index, seq_dict)

                ref_index = "{0} {1} {2}".format(ref_sp, index, seq_dict)
                input_bam = "{0}.{1}.bam".format(input_base, i)
                input_bai = "{0}.bai".format(input_bam)
                input_vcf = "{0}.{1}.vcf".format(input_base, i)

                makeflow.write("\n\n{0} {1} : {2} samtools".format(input_bam, input_bai, input_sam))
                makeflow.write("\n\t./samtools view -bS {0} > {1}".format(input_sam, input_bam))
                makeflow.write(";\t./samtools index {0}".format(input_bam))

                input_idx = _write_gatk_rule(makeflow, xmx, options.type, arguments,
                                             ref_index, ref_sp, input_bam, input_bai, input_vcf)

                merge_files.append("{0} {1}".format(input_vcf, input_idx))
                merge_inputs.append("--variant {0}".format(input_vcf))

            log.info("Writing Joining Rules")

            makeflow.write("\n\n{0} : GenomeAnalysisTK.jar {1} {2}".format(options.output, ref, ' '.join(merge_files)))
            makeflow.write("\n\tLOCAL java -jar GenomeAnalysisTK.jar -T CombineVariants -R {0} {1} -o {2}".format(ref, ' '.join(merge_inputs), options.output))

        else:
            ref_sp = "{0}.fa".format(ref_base)
            index = "{0}.fai".format(ref_sp)
            seq_dict = "{0}.dict".format(ref_base)
            _write_reference_rules(makeflow, ref_sp, index, seq_dict)

            ref_index = "{0} {1} {2}".format(ref_sp, index, seq_dict)
            # The rule below indexes input_full_bam, so the declared index
            # target is named after that BAM (same convention as the
            # partitioned branch), not after the original input file.
            input_bai = "{0}.bai".format(input_full_bam)

            if comp != "bam":
                makeflow.write("\n\n{0} {1} : {2} samtools".format(input_bai, input_full_bam, input_original))
                makeflow.write("\n\t./samtools view -bS {0} > {1};".format(input_original, input_full_bam))
            else:
                makeflow.write("\n\n{0} : {1} samtools\n".format(input_bai, input_full_bam))
            makeflow.write("\t./samtools index {0}".format(input_full_bam))

            _write_gatk_rule(makeflow, xmx, options.type, arguments,
                             ref_index, ref_sp, input_full_bam, input_bai, options.output)

        makeflow.write("\n")

    finally:
        # Close only what this function opened; never close stdout.
        if makeflow is not sys.stdout:
            makeflow.close()


#MAIN

# Parse the command line. Options this script does not declare are collected
# in `args` and later appended to every GATK invocation by write_makeflow().
parser = PassThroughParser()
parser.usage = "usage: %prog  [options]"

parse_req(parser)
parse_makeflow(parser)

(options, args) = parser.parse_args()

# Compare against True explicitly so that only the value stored by the
# --verbose flag itself switches on debug logging.
if options.verbose is True:
    log.basicConfig(format="%(levelname)s: %(message)s", level=log.DEBUG)
    log.info("Verbose output.")
else:
    log.basicConfig(format="%(levelname)s: %(message)s")

# Validate the required parameters; each message names the actual option.
if not options.ref:
    log.error("Reference File required : --ref")
    sys.exit(4)
log.info("Reference File : {0}".format(options.ref))

if not options.input:
    log.error("Input SAM required : --query")
    sys.exit(4)
log.info("Input SAM : {0}".format(options.input))

if not options.type:
    log.error("Analysis Type required : --analysis_type")
    sys.exit(4)
log.info("Analysis Type : {0}".format(options.type))

# Without --out the generated rules would target a file literally called "None".
if not options.output:
    log.error("Output name required : --out")
    sys.exit(4)
log.info("Output : {0}".format(options.output))

# Fall back to the default partition size as an integer; write_makeflow()
# compares this value numerically (and only partitions when it is > 0).
num_reads = options.num_seq or 10000

log.info("Partition Size : {0}".format(num_reads))

# Search PATH first and the current directory last; tolerate an unset PATH.
path = os.pathsep.join([os.getenv("PATH") or "", "."])

find_file("GenomeAnalysisTK.jar", path)

find_file("samtools", path)

find_file("picard.jar", path)

# A local ./jre takes precedence; JAVA_HOME is only needed to create the link.
java_home = os.environ.get('JAVA_HOME')
if os.path.exists("./jre"):
    log.info("Using local jre directory")
elif java_home:
    os.symlink(java_home, "jre")
    log.info("JAVA_HOME linked to : jre")
else:
    log.error("JAVA_HOME not initialized")
    sys.exit(6)

write_progeny_split("gatk_split_align")
write_java("java_run")


write_makeflow(options.makeflow, options.config, args, options.ref, options.input, num_reads, log)
