"""
Cross-implementation CRAM validation pipeline.

Converts input files to CRAM using both htsjdk and samtools, decodes with both tools,
and verifies all outputs match the original. Run with:

    pixi run snakemake --cores 4

Supports two sample sheet formats:
  - Local mode (samples.tsv): 3 columns (sample, input, reference) with local file paths.
  - Remote mode (test_datasets.tsv): 9 columns with URLs; files are downloaded automatically.

Set `dataset_file` in config.yaml to choose which sample sheet to use.
See config.yaml for pipeline configuration and README.md for full documentation.
"""

import csv
from pathlib import Path
from urllib.parse import urlparse

# Directory containing this Snakefile. All config, sample sheet, and JAR paths
# are resolved relative to this directory so that --directory can be used to
# place outputs (downloads, intermediates, results) on a separate volume.
PIPELINE_DIR = Path(workflow.basedir)

# Load config.yaml from the pipeline directory (not from the working
# directory); its values are read through the `config` mapping below.
configfile: str(PIPELINE_DIR / "config.yaml")


# ---------------------------------------------------------------------------
# Resolve the fat JAR, supporting glob patterns in the config
# ---------------------------------------------------------------------------

def resolve_jar(pattern):
    """Resolve the configured htsjdk fat-JAR path, expanding glob patterns.

    ``pattern`` is interpreted relative to ``PIPELINE_DIR`` and may contain
    glob wildcards. When several files match, the one with the highest
    version is chosen. Paths are compared with a natural sort so that numeric
    components are ordered numerically (``4.10.0`` ranks above ``4.9.0``,
    which a plain string sort gets wrong).

    Returns:
        The selected JAR path as a string.

    Raises:
        FileNotFoundError: if nothing matches the pattern and the literal
            path does not exist either.
    """
    import glob
    import re

    def version_key(path):
        # re.split with a capturing group yields text, digits, text, ... and
        # always starts with a (possibly empty) text chunk, so odd positions
        # are the digit runs. Ints are therefore only ever compared with
        # ints, and strings with strings.
        chunks = re.split(r"(\d+)", path)
        return [int(chunk) if index % 2 else chunk
                for index, chunk in enumerate(chunks)]

    # Resolve the pattern relative to the pipeline directory
    resolved = str(PIPELINE_DIR / pattern)
    matches = glob.glob(resolved)
    if matches:
        return max(matches, key=version_key)  # latest version if multiple
    # Fallback for a literal path that glob cannot match as written
    # (e.g. one containing glob metacharacters such as "[").
    if Path(resolved).exists():
        return resolved
    raise FileNotFoundError(f"htsjdk JAR not found: {resolved}. Run: cd .. && ./gradlew shadowJar")


# ---------------------------------------------------------------------------
# Load sample sheet (auto-detect local vs. remote format)
# ---------------------------------------------------------------------------

def ref_key_from_url(url):
    """Derive a short, human-readable reference key from a single URL.

    Resolution order:
      1. Any URL containing ``"eutils"`` maps to the fixed key ``MN908947_3``.
      2. A file name listed in the ``known`` table maps to its curated key.
      3. Otherwise the file name is used with its trailing FASTA/compression
         extensions (``.fasta``, ``.fa``, ``.fna``, ``.gz``) removed.

    Only *trailing* extensions are stripped. Removing the substrings anywhere
    in the name would mangle names such as ``sample.family.fasta`` (the
    ``.fa`` inside ``.family`` would be deleted as well).
    """
    known = {
        "GRCh38_full_analysis_set_plus_decoy_hla.fa":
            "GRCh38_full_decoy_hla",
        "GRCh38_GIABv3_no_alt_analysis_set_maskedGRC_decoys_MAP2K3_KMT2C_KCNJ18.fasta.gz":
            "GRCh38_GIABv3",
        "Homo_sapiens_assembly38.fasta":
            "Homo_sapiens_assembly38",
        "hs37d5.fa.gz":
            "hs37d5",
        "Homo_sapiens_assembly38_noALT_noHLA_noDecoy_ERCC92.fasta":
            "GRCh38_noALT_ERCC92",
        "GCA_000001405.15_GRCh38_no_alt_analysis_set.fna.gz":
            "GRCh38_no_alt",
    }
    if "eutils" in url:
        return "MN908947_3"
    filename = Path(urlparse(url).path).name
    if filename in known:
        return known[filename]

    suffixes = (".fasta", ".fa", ".fna", ".gz")
    key = filename
    stripped = True
    # Keep peeling extensions off the end until none is left, so stacked
    # extensions such as ".fa.gz" are removed completely.
    while stripped:
        stripped = False
        for suffix in suffixes:
            # Never strip the whole name away: a key must not become empty.
            if key.endswith(suffix) and len(key) > len(suffix):
                key = key[:-len(suffix)]
                stripped = True
    return key


def parse_reference_urls(url_string):
    """Split a reference URL field into its individual download entries.

    The field holds one URL or several separated by commas. An entry may end
    in ``#contig_name``; that label is returned separately so the pipeline
    can rename the first FASTA header to ``>contig_name`` after download
    (handy when the upstream contig name, e.g. an NCBI accession, differs
    from what the BAM header expects). Entries starting with ``s3://`` are
    never split, even when they contain ``#``.

    Returns a list of dicts: ``[{"url": str, "rename": str | None}, ...]``.
    """
    def split_entry(entry):
        entry = entry.strip()
        if entry.startswith("s3://") or "#" not in entry:
            return {"url": entry, "rename": None}
        # Split on the LAST "#" so earlier ones stay part of the URL.
        location, _, contig = entry.rpartition("#")
        return {"url": location, "rename": contig}

    return [split_entry(entry) for entry in url_string.split(",")]


def ref_key_from_urls(url_parts):
    """Build the reference key for a parsed list of reference URLs.

    The key starts with the key derived from the first (primary) URL. Every
    further entry contributes one more ``_``-separated label: its rename
    label when it has one, otherwise the key derived from its URL. A
    reference with extra contigs therefore gets a different key from the
    base reference on its own.
    """
    labels = [ref_key_from_url(url_parts[0]["url"])]
    labels += [
        extra["rename"] or ref_key_from_url(extra["url"])
        for extra in url_parts[1:]
    ]
    return "_".join(labels)


def input_extension(url):
    """Return ``"cram"`` if the URL's path ends in ``.cram``, else ``"bam"``.

    Only the path component is inspected, so a query string or fragment
    after the file name does not affect the result.
    """
    url_path = urlparse(url).path
    if url_path.endswith(".cram"):
        return "cram"
    return "bam"


# Sample sheet to load, relative to the pipeline directory.
DATASET_FILE = config.get("dataset_file", "samples.tsv")

SAMPLES = {}        # sample -> {"input": path, "reference": path}
REFERENCES = {}     # ref_key -> {"urls": [{"url": str, "rename": str|None}, ...]}
REMOTE_INPUTS = {}  # sample -> {"url": str, "ext": str} (remote mode only)
REMOTE_MODE = False

# newline="" hands line-ending handling to the csv module, as its
# documentation recommends for files passed to csv readers.
with open(PIPELINE_DIR / DATASET_FILE, newline="") as f:
    reader = csv.DictReader(f, delimiter="\t")

    # DictReader.fieldnames is None when the file has no header line at all;
    # fail with a readable message instead of a TypeError from `in None`.
    if not reader.fieldnames:
        raise ValueError(
            f"Sample sheet {PIPELINE_DIR / DATASET_FILE} is empty or has no header row"
        )

    # The presence of an "input_url" column is what selects remote mode.
    if "input_url" in reader.fieldnames:
        # Remote mode: test_datasets.tsv format.
        # Download paths are relative to the working directory (--directory).
        REMOTE_MODE = True
        for row in reader:
            sample = row["sample"]
            input_url = row["input_url"]
            ref_url = row["reference_url"]
            url_parts = parse_reference_urls(ref_url)
            ref_key = ref_key_from_urls(url_parts)
            ext = input_extension(input_url)

            SAMPLES[sample] = {
                "input": f"downloads/inputs/{sample}.{ext}",
                "reference": f"downloads/references/{ref_key}/ref.fasta",
            }
            # Samples that share a reference key share one download directory.
            REFERENCES[ref_key] = {"urls": url_parts}
            REMOTE_INPUTS[sample] = {"url": input_url, "ext": ext}
    else:
        # Local mode: samples.tsv format (backward compatible).
        # Paths in the TSV are relative to the pipeline directory, so resolve
        # them to absolute paths to work correctly with --directory.
        for row in reader:
            SAMPLES[row["sample"]] = {
                "input": str(PIPELINE_DIR / row["input"]),
                "reference": str(PIPELINE_DIR / row["reference"]),
            }

# Apply optional sample filter from config
selected = config.get("samples", None)
if selected:
    # Normalise the filter to a set of sample names before testing
    # membership. A plain string (e.g. `--config samples=NA12878`) must not be
    # used with `in` directly: that is a substring test, so "NA12" would be
    # selected by "NA12878". Strings are treated as comma-separated names.
    if isinstance(selected, str):
        selected = selected.split(",")
    elif not isinstance(selected, (list, tuple, set, frozenset, dict)):
        selected = [selected]  # lone scalar, e.g. a purely numeric sample name
    # Sample names read from the TSV are strings, so compare as strings.
    selected = {str(name).strip() for name in selected} - {""}
    SAMPLES = {k: v for k, v in SAMPLES.items() if k in selected}
    REMOTE_INPUTS = {k: v for k, v in REMOTE_INPUTS.items() if k in selected}


# ---------------------------------------------------------------------------
# Pipeline configuration
# ---------------------------------------------------------------------------

# Optional settings: fall back to a default when config.yaml omits the key.
MAX_DIFFS = config.get("max_diffs", 10)
COMPARISON_MODE = config.get("comparison_mode", "strict")

# Required settings: a missing key raises KeyError while the Snakefile loads.
PROFILES = config["profiles"]
# The configured JAR entry may be a glob pattern; resolve_jar picks the file.
HTSJDK_JAR = resolve_jar(config["htsjdk_jar"])


# Comparisons to run per sample x profile: every encoder/decoder pairing is
# decoded back to BAM and checked against the original.
#   1. htsjdk roundtrip:        htsjdk_encode   -> htsjdk_decode
#   2. samtools reads htsjdk:   htsjdk_encode   -> samtools_decode
#   3. htsjdk reads samtools:   samtools_encode -> htsjdk_decode
#   4. samtools roundtrip:      samtools_encode -> samtools_decode
# Each entry is (decoded BAM stem "<encoder>_via_<decoder>", baseline name).
COMPARISONS = [
    (f"{encoder}_via_{decoder}", "original")
    for encoder in ("htsjdk", "samtools")
    for decoder in ("htsjdk", "samtools")
]


# ---------------------------------------------------------------------------
# Helper: reference inputs for rules that need a FASTA + .fai
# ---------------------------------------------------------------------------

def ref_path(wc):
    """Return the reference FASTA path recorded in SAMPLES for ``wc.sample``."""
    sample_entry = SAMPLES[wc.sample]
    return sample_entry["reference"]

def ref_fai(wc):
    """Return the ``.fai`` index path: the sample's reference path plus ``.fai``."""
    fasta = SAMPLES[wc.sample]["reference"]
    return fasta + ".fai"

def ref_inputs(wc):
    """Return the named reference inputs for the sample in ``wc``.

    The dict always has ``ref`` (the FASTA path). In remote mode it also has
    ``fai`` (the index path), which makes the index a declared input of the
    rule that unpacks this dict.
    """
    if not REMOTE_MODE:
        return {"ref": ref_path(wc)}
    return {"ref": ref_path(wc), "fai": ref_fai(wc)}


# ---------------------------------------------------------------------------
# Target rule
# ---------------------------------------------------------------------------

rule all:
    input:
        # One result file per sample x profile x comparison. Each comparison
        # is taken as a (cmp_name, cmp_base) PAIR from COMPARISONS. Passing
        # the two columns to expand() separately would take their Cartesian
        # product, repeating every target once per COMPARISONS entry and
        # requesting mismatched name/base combinations if a second baseline
        # were ever added.
        [
            "output/{sample}/{profile}/{cmp_name}_vs_{cmp_base}.result".format(
                sample=sample, profile=profile,
                cmp_name=cmp_name, cmp_base=cmp_base,
            )
            for sample in SAMPLES
            for profile in PROFILES
            for cmp_name, cmp_base in COMPARISONS
        ],
        "output/summary.txt"


# ---------------------------------------------------------------------------
# Download rules (remote mode only)
# ---------------------------------------------------------------------------

if REMOTE_MODE:

    rule download_input:
        """Download a remote input BAM/CRAM file."""
        output:
            "downloads/inputs/{sample}.{ext}"
        log:
            "logs/download_input.{sample}.{ext}.log"
        params:
            # Source URL recorded for this sample when the sample sheet was loaded.
            url=lambda wc: REMOTE_INPUTS[wc.sample]["url"],
            # A Python bool; it is rendered as the text "True"/"False" inside
            # the shell command, which is what the [[ ... == "True" ]] test reads.
            is_s3=lambda wc: REMOTE_INPUTS[wc.sample]["url"].startswith("s3://"),
        # s3:// URLs are fetched with `aws s3 cp --no-sign-request`; all other
        # URLs go through curl, which writes to <output>.tmp and is then moved
        # into place. Both stdout and stderr of the whole group go to the log.
        shell:
            """
            (
                mkdir -p $(dirname {output})
                if [[ "{params.is_s3}" == "True" ]]; then
                    aws s3 cp --no-sign-request '{params.url}' {output}
                else
                    curl -fSL --retry 3 --retry-delay 5 -o {output}.tmp '{params.url}'
                    mv {output}.tmp {output}
                fi
            ) > {log} 2>&1
            """

    rule download_reference:
        """Download, decompress, and concatenate remote reference FASTA(s).

        Supports multiple URLs per reference (comma-separated in TSV).
        Each URL may carry a #contig_name suffix to rename the FASTA header.
        """
        output:
            "downloads/references/{ref_key}/ref.fasta"
        log:
            "logs/download_reference.{ref_key}.log"
        params:
            # Parsed list of {"url": ..., "rename": ...} dicts for this reference.
            url_parts=lambda wc: REFERENCES[wc.ref_key]["urls"],
        run:
            import os
            import re
            import shlex
            from urllib.parse import urlparse

            os.makedirs(os.path.dirname(output[0]), exist_ok=True)
            tmp = output[0] + ".tmp"
            logfile = log[0]

            # Start from an empty log and no stale partial download: every
            # part below APPENDS to both files.
            with open(logfile, "w"):
                pass
            if os.path.exists(tmp):
                os.remove(tmp)

            for part in params.url_parts:
                url = part["url"]
                rename = part.get("rename")

                # shlex.quote keeps URLs containing shell metacharacters
                # (&, ?, quotes, spaces) intact as a single argument.
                if url.startswith("s3://"):
                    dl_cmd = "aws s3 cp --no-sign-request " + shlex.quote(url) + " -"
                else:
                    dl_cmd = "curl -fSL --retry 3 --retry-delay 5 " + shlex.quote(url)

                # Detect gzip from the URL *path* as well as the raw URL, so a
                # query string after the file name does not hide the ".gz".
                if url.endswith(".gz") or urlparse(url).path.endswith(".gz"):
                    dl_cmd += " | gunzip"

                if rename:
                    # Escape the characters that are special in a sed
                    # replacement (backslash, the "/" delimiter, and "&").
                    sed_label = re.sub(r"([\\/&])", r"\\\1", rename)
                    # Replace only line 1, i.e. the first FASTA header.
                    dl_cmd += " | sed " + shlex.quote("1s/^>.*/>" + sed_label + "/")

                cmd = "(" + dl_cmd + ") >> " + shlex.quote(tmp) + " 2>> " + shlex.quote(logfile)
                # shell() applies Snakemake's {}-formatting to its argument;
                # double any literal braces so they reach the shell unchanged.
                shell(cmd.replace("{", "{{").replace("}", "}}"))

            # Publish the finished file under its final name in one step.
            os.rename(tmp, output[0])

    rule index_reference:
        """Create .fai index for a downloaded reference FASTA."""
        input:
            "downloads/references/{ref_key}/ref.fasta"
        output:
            # Not named in the command: the rule relies on `samtools faidx`
            # writing its index next to the input as <input>.fai.
            "downloads/references/{ref_key}/ref.fasta.fai"
        log:
            "logs/index_reference.{ref_key}.log"
        shell:
            "samtools faidx {input} > {log} 2>&1"


# ---------------------------------------------------------------------------
# Core pipeline rules
# ---------------------------------------------------------------------------

rule convert_to_bam:
    """Convert the original input to BAM as a common baseline for comparison."""
    input:
        # ref_inputs supplies the named input `ref` (plus `fai` in remote mode).
        unpack(ref_inputs),
        # Local path from the sample sheet, or the download target in remote mode.
        file=lambda wc: SAMPLES[wc.sample]["input"],
    output:
        # temp(): intermediate file, not kept as a final pipeline result.
        temp("output/{sample}/original.bam")
    log:
        "logs/convert_to_bam.{sample}.log"
    priority: 10
    shell:
        "samtools view -b -T {input.ref} -o {output} {input.file} > {log} 2>&1"


rule htsjdk_encode:
    """Encode to CRAM using htsjdk CramConverter."""
    input:
        # ref_inputs supplies the named input `ref` (plus `fai` in remote mode).
        unpack(ref_inputs),
        bam="output/{sample}/original.bam",
    output:
        temp("output/{sample}/{profile}/htsjdk.cram")
    log:
        "logs/htsjdk_encode.{sample}.{profile}.log"
    priority: 20
    retries: 3
    resources:
        # 4096 MB on the first attempt, plus 2048 MB for every further attempt;
        # the value is passed to the JVM as its -Xmx heap limit below.
        mem_mb=lambda wc, attempt: 4096 + (attempt - 1) * 2048,
    params:
        jar=HTSJDK_JAR,
    shell:
        "java -Xmx{resources.mem_mb}m -cp {params.jar} htsjdk.samtools.cram.CramConverter "
        "{input.bam} {output} --reference {input.ref} --profile {wildcards.profile} "
        "> {log} 2>&1"


rule samtools_encode:
    """Encode to CRAM using samtools."""
    input:
        # ref_inputs supplies the named input `ref` (plus `fai` in remote mode).
        unpack(ref_inputs),
        bam="output/{sample}/original.bam",
    output:
        temp("output/{sample}/{profile}/samtools.cram")
    log:
        "logs/samtools_encode.{sample}.{profile}.log"
    priority: 20
    # The profile wildcard is passed verbatim as an --output-fmt-option, so
    # every configured profile name must be an option samtools accepts there.
    shell:
        "samtools view -C -T {input.ref} "
        "--output-fmt-option version=3.1 --output-fmt-option {wildcards.profile} "
        "-o {output} {input.bam} > {log} 2>&1"


rule htsjdk_decode:
    """Decode CRAM to BAM using htsjdk CramConverter."""
    input:
        # ref_inputs supplies the named input `ref` (plus `fai` in remote mode).
        unpack(ref_inputs),
        # {encoder} selects which CRAM to read (the encode rules in this file
        # produce htsjdk.cram and samtools.cram).
        cram="output/{sample}/{profile}/{encoder}.cram",
    output:
        temp("output/{sample}/{profile}/{encoder}_via_htsjdk.bam")
    log:
        "logs/htsjdk_decode.{sample}.{profile}.{encoder}.log"
    priority: 30
    retries: 3
    resources:
        # 4096 MB on the first attempt, plus 2048 MB for every further attempt;
        # the value is passed to the JVM as its -Xmx heap limit below.
        mem_mb=lambda wc, attempt: 4096 + (attempt - 1) * 2048,
    params:
        jar=HTSJDK_JAR,
    shell:
        "java -Xmx{resources.mem_mb}m -cp {params.jar} htsjdk.samtools.cram.CramConverter "
        "{input.cram} {output} --reference {input.ref} "
        "> {log} 2>&1"


rule samtools_decode:
    """Decode CRAM to BAM using samtools."""
    input:
        # ref_inputs supplies the named input `ref` (plus `fai` in remote mode).
        unpack(ref_inputs),
        # {encoder} selects which CRAM to read (the encode rules in this file
        # produce htsjdk.cram and samtools.cram).
        cram="output/{sample}/{profile}/{encoder}.cram",
    output:
        temp("output/{sample}/{profile}/{encoder}_via_samtools.bam")
    log:
        "logs/samtools_decode.{sample}.{profile}.{encoder}.log"
    priority: 30
    shell:
        "samtools view -b -T {input.ref} -o {output} {input.cram} > {log} 2>&1"


rule compare:
    """Compare a decoded BAM against the original using CramComparison."""
    input:
        # ref_inputs supplies the named input `ref` (plus `fai` in remote mode).
        unpack(ref_inputs),
        # {cmp_name} is the decoded BAM stem, e.g. "htsjdk_via_samtools".
        test="output/{sample}/{profile}/{cmp_name}.bam",
        # The baseline is always original.bam; {cmp_base} only appears in the
        # output and log file names and does not select a different input.
        original="output/{sample}/original.bam",
    output:
        "output/{sample}/{profile}/{cmp_name}_vs_{cmp_base}.result"
    log:
        "logs/compare.{sample}.{profile}.{cmp_name}.{cmp_base}.log"
    priority: 40
    retries: 3
    resources:
        # 4096 MB on the first attempt, plus 2048 MB for every further attempt;
        # the value is passed to the JVM as its -Xmx heap limit below.
        mem_mb=lambda wc, attempt: 4096 + (attempt - 1) * 2048,
    params:
        jar=HTSJDK_JAR,
        # Only the exact value "lenient" adds a flag; any other
        # comparison_mode value runs the tool without --lenient.
        mode="--lenient" if COMPARISON_MODE == "lenient" else "",
        max_diffs=MAX_DIFFS,
    shell:
        "java -Xmx{resources.mem_mb}m -cp {params.jar} htsjdk.samtools.cram.CramComparison "
        "{input.test} {input.original} "
        "--reference {input.ref} --output {output} {params.mode} --max-diffs {params.max_diffs} "
        "> {log} 2>&1"


rule summary:
    """Aggregate all comparison results into a summary."""
    input:
        # One result file per sample x profile x comparison. Each comparison
        # is taken as a (cmp_name, cmp_base) PAIR from COMPARISONS. Passing
        # the two columns to expand() separately would take their Cartesian
        # product and list every result file several times.
        [
            "output/{sample}/{profile}/{cmp_name}_vs_{cmp_base}.result".format(
                sample=sample, profile=profile,
                cmp_name=cmp_name, cmp_base=cmp_base,
            )
            for sample in SAMPLES
            for profile in PROFILES
            for cmp_name, cmp_base in COMPARISONS
        ],
    output:
        "output/summary.txt"
    log:
        "logs/summary.log"
    priority: 50
    run:
        import shutil
        passed = 0
        failed = 0
        with open(output[0], "w") as out:
            # set() makes sure a result file is counted once even if it is
            # listed more than once among the inputs.
            for result_file in sorted(set(input)):
                with open(result_file) as result:
                    content = result.read().strip()
                # A result counts as passed when it contains the marker "OK:".
                status = "PASS" if "OK:" in content else "FAIL"
                if status == "PASS":
                    passed += 1
                else:
                    failed += 1
                # Report paths relative to the output/ directory.
                line = f"{status}  {Path(result_file).relative_to('output')}"
                out.write(line + "\n")
                print(line)
            summary_line = f"\n{passed} passed, {failed} failed out of {passed + failed} comparisons"
            out.write(summary_line + "\n")
            print(summary_line)
        # The rule's log is simply a copy of the finished summary.
        shutil.copy(output[0], log[0])
