#! /usr/bin/env python3
"""pre_proc_batch — preprocess a stack of images against a common master.

Python port of csh pre_proc_batch.csh (X. Tong 2010 + D. Sandwell 2010/2011,
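A. Hogrelius ENVI_SLC 2017). Reads data.in (line 1 = master, lines 2+ =
aligned image names) and preprocesses every aligned scene with the master's
near_range / earth_radius / fd1 / num_patches wherever the SAT's
preprocessor accepts them. Writes baseline_table.dat, table.gmt and a
PostScript baseline plot (stacktable_all.ps).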

Usage:  pre_proc_batch SAT data.in batch.config

Supported SATs: ALOS, ENVI, ENVI_SLC, ERS, TSX
"""
import os
import subprocess
import sys

from gmtsar_lib import run, grep_value, check_file_report


# SAT identifiers accepted as the first command-line argument.
_SUPPORTED_SATS = set(
    "ALOS ENVI ENVI_SLC ERS TSX".split()
)


def _get_config(path, key, default=""):
    """Return the value of ``key`` from a batch.config-style file.

    Lines are expected in the form ``key = value``. The third
    whitespace-separated token of the first matching line is returned,
    as ``grep <key> <path> | awk '{print $3}'`` would. Unlike a bare grep:

    * comment lines (first token starting with ``#``) are skipped, and
    * the key must be the line's first token.

    As a result, a comment that mentions the key, or a longer parameter
    name that merely contains it, cannot supply the value.

    Args:
        path: config file to read.
        key: parameter name to look up.
        default: returned when no line carries a value for ``key``
            (including ``key =`` left blank).

    Returns:
        The value token as a string, or ``default``.
    """
    with open(path) as f:
        for line in f:
            parts = line.split()
            # Need at least "key", "=", "value"; blank values fall through.
            if len(parts) < 3 or parts[0].startswith("#"):
                continue
            if parts[0] == key:
                return parts[2]
    return default


def _ers_envi_commandline(near_range, earth_radius, num_patches, fd):
    """Assemble the positional argument string for ERS/ENVI ``*_pre_process``.

    Order is near_range, earth_radius, num_patches, then fd. Any of the first
    three that is empty becomes a "0" placeholder so later positions keep
    their slots. An empty fd is simply left off the end.
    """
    leading = [value if value else "0"
               for value in (near_range, earth_radius, num_patches)]
    trailing = [fd] if fd else []
    return " ".join(leading + trailing)


def _alos_commandline(near_range, earth_radius, num_patches, fd):
    """Assemble the ``-flag value`` argument string for ALOS_pre_process.

    Only non-empty values are emitted, always in the order
    -radius, -near, -npatch, -fd1.
    """
    flag_table = (
        ("-radius", earth_radius),
        ("-near", near_range),
        ("-npatch", num_patches),
        ("-fd1", fd),
    )
    tokens = []
    for flag, value in flag_table:
        if value:
            tokens.extend((flag, value))
    return " ".join(tokens)


def _baseline_pair(master_prm, aligned_prm, dat_file, gmt_file, append):
    """Run baseline_table on one PRM pair, once plain and once in GMT mode.

    The plain output goes to ``dat_file`` and the GMT-mode output goes to
    ``gmt_file``. Both are appended to when ``append`` is true and
    overwritten otherwise.
    """
    op = ">" if not append else ">>"
    pair = f"{master_prm} {aligned_prm}"
    for mode_suffix, target in (("", dat_file), (" GMT", gmt_file)):
        run(f"baseline_table {pair}{mode_suffix} {op} {target}")


def _preprocess_master(SAT, master, commandline, earth_radius):
    """Preprocess the master scene and return the values aligned scenes reuse.

    Runs the SAT-specific preprocessing, which is skipped when the checked
    output files already exist. It then seeds ``baseline_table.dat`` and
    ``table.gmt`` with the master-vs-master baseline (both files are
    overwritten, not appended) and reads the reference parameters back out
    of the master PRM.

    Args:
        SAT: one of ERS, ENVI, ENVI_SLC, ALOS, TSX.
        master: master scene name (for ALOS, without the ``IMG-HH-`` prefix).
        commandline: extra arguments for ERS/ENVI/ALOS ``*_pre_process``.
        earth_radius: config earth radius. Used only by ENVI_SLC, where an
            empty value is passed as "0".

    Returns:
        (NEAR, RAD, FD1, npatch, rsr, prm): near_range, earth_radius, fd1 and
        num_patch as read from the master PRM; the master rng_samp_rate
        (ALOS/TSX only, "" otherwise); and the master PRM path.

    Exits via ``sys.exit`` for an unsupported SAT.
    """
    rsr = ""
    if SAT in ("ERS", "ENVI"):
        prm = f"{master}.PRM"
        if not (check_file_report(f"{master}.raw") and check_file_report(f"{master}.LED")):
            run(f"{SAT}_pre_process {master} {commandline}")
    elif SAT == "ENVI_SLC":
        prm = f"{master}.PRM"
        er = earth_radius or "0"
        if not (check_file_report(f"{master}.SLC") and check_file_report(f"{master}.LED")):
            run(f"ENVI_SLC_pre_process {master} {er}")
    elif SAT == "ALOS":
        prm = f"IMG-HH-{master}.PRM"
        if not (check_file_report(f"IMG-HH-{master}.raw") and check_file_report(prm)):
            run(f"ALOS_pre_process IMG-HH-{master} LED-{master} {commandline}")
        rsr = grep_value(prm, "rng_samp_rate", 3)
    elif SAT == "TSX":
        prm = f"{master}.PRM"
        # make_slc_tsx writes <name>.SLC (not .raw). Test for that so an
        # already-processed master is reused; the aligned TSX scenes are
        # checked the same way.
        if not (check_file_report(f"{master}.SLC") and check_file_report(prm)):
            run(f"make_slc_tsx {master}.xml {master}.cos {master}")
            # Keep make_slc_tsx's PRM as PRM0, then rebuild PRM as
            # PRM0 + calc_dop_orb's .log + appended fdd1/fddd1 = 0 lines.
            run(f"cp {master}.PRM {master}.PRM0")
            run(f"calc_dop_orb {master}.PRM0 {master}.log 0 0")
            run(f"cat {master}.PRM0 {master}.log > {master}.PRM")
            with open(f"{master}.PRM", "a") as f:
                f.write("fdd1                    = 0\n")
                f.write("fddd1                   = 0\n")
        rsr = grep_value(prm, "rng_samp_rate", 3)
    else:
        sys.exit(f"unsupported SAT: {SAT}")

    # Master against itself; append=False starts both tables fresh.
    _baseline_pair(prm, prm, "baseline_table.dat", "table.gmt", append=False)

    NEAR = grep_value(prm, "near_range", 3)
    RAD = grep_value(prm, "earth_radius", 3)
    FD1 = grep_value(prm, "fd1", 3)
    npatch = grep_value(prm, "num_patch", 3)
    return NEAR, RAD, FD1, npatch, rsr, prm


def _preprocess_aligned(SAT, master, aligned_token, NEAR, RAD, FD1, npatch, rsr_m):
    """Preprocess one aligned scene against the master and append its baseline.

    Each branch runs the SAT-specific preprocessing, which is skipped when
    the checked output files already exist. It then appends the
    master/aligned baseline to baseline_table.dat and table.gmt
    (append=True, so rows written earlier are kept).

    Args:
        SAT: satellite identifier. Only ERS, ENVI, ENVI_SLC, TSX and ALOS
            are handled; any other value matches no branch and does nothing.
        master: master scene name (for ALOS, without the ``IMG-HH-`` prefix).
        aligned_token: the aligned scene name. The caller passes the first
            whitespace-separated token of a data.in line, not the whole line.
        NEAR, RAD, FD1, npatch: near_range, earth_radius, fd1 and num_patch
            strings read from the master PRM. ERS/ENVI and ALOS pass all four
            to their preprocessor; ENVI_SLC passes only RAD; TSX uses none.
        rsr_m: master rng_samp_rate string (ALOS only).

    Exits via ``sys.exit`` (ALOS only) when the sampling-rate ratio cannot
    be computed, when the master has the lower rate, or when the rates are
    otherwise incompatible.
    """
    if SAT in ("ERS", "ENVI"):
        aligned = aligned_token
        # Positional order: name near_range earth_radius num_patches fd1,
        # filled here with the master PRM values.
        if not (check_file_report(f"{aligned}.raw") and check_file_report(f"{aligned}.LED")):
            run(f"{SAT}_pre_process {aligned} {NEAR} {RAD} {npatch} {FD1}")
        _baseline_pair(f"{master}.PRM", f"{aligned}.PRM",
                       "baseline_table.dat", "table.gmt", append=True)
    elif SAT == "ENVI_SLC":
        aligned = aligned_token
        # Only the master earth_radius is forwarded.
        if not (check_file_report(f"{aligned}.SLC") and check_file_report(f"{aligned}.LED")):
            run(f"ENVI_SLC_pre_process {aligned} {RAD}")
        _baseline_pair(f"{master}.PRM", f"{aligned}.PRM",
                       "baseline_table.dat", "table.gmt", append=True)
    elif SAT == "TSX":
        aligned = aligned_token
        # Only the SLC conversion is guarded; the PRM0 / calc_dop_orb / cat
        # steps below run unconditionally.
        if not (check_file_report(f"{aligned}.SLC") and check_file_report(f"{aligned}.LED")):
            run(f"make_slc_tsx {aligned}.xml {aligned}.cos {aligned}")
        # Re-reads earth_radius from the master PRM (the same key the caller
        # used to produce RAD) and hands it to calc_dop_orb.
        rad_now = grep_value(f"{master}.PRM", "earth_radius", 3)
        # NOTE(review): on a re-run with the SLC already present, <aligned>.PRM
        # still holds the previous .log and fdd1/fddd1 lines, so copying it to
        # PRM0 and re-appending duplicates those keys. Confirm this is
        # harmless, or move these steps inside the if above.
        run(f"cp {aligned}.PRM {aligned}.PRM0")
        run(f"calc_dop_orb {aligned}.PRM0 {aligned}.log {rad_now} 0")
        run(f"cat {aligned}.PRM0 {aligned}.log > {aligned}.PRM")
        with open(f"{aligned}.PRM", "a") as f:
            f.write("fdd1                    = 0\n")
            f.write("fddd1                   = 0\n")
        _baseline_pair(f"{master}.PRM", f"{aligned}.PRM",
                       "baseline_table.dat", "table.gmt", append=True)
    elif SAT == "ALOS":
        # The data.in token may carry the full IMG-HH-... name; strip the
        # prefix to get the bare aligned tag.
        aligned = aligned_token
        if aligned.startswith("IMG-HH-"):
            aligned = aligned[7:]
        aligned_prm = f"IMG-HH-{aligned}.PRM"
        if not (check_file_report(f"IMG-HH-{aligned}.raw") and check_file_report(aligned_prm)):
            run(f"ALOS_pre_process IMG-HH-{aligned} LED-{aligned} "
                f"-fd1 {FD1} -near {NEAR} -radius {RAD} -npatch {npatch}")
        rsr_s = grep_value(aligned_prm, "rng_samp_rate", 3)
        # ratio = master rate / aligned rate; a missing, non-numeric or zero
        # aligned rate aborts.
        try:
            ratio = float(rsr_m) / float(rsr_s)
        except (ValueError, ZeroDivisionError):
            sys.exit(f"ALOS rng_samp_rate undefined for {aligned}")
        # Ratios are compared with an absolute tolerance of +/-0.05.
        if abs(ratio - 1.0) < 0.05:
            # Rates match: use the aligned PRM as is.
            print(f"Same range sampling rate ({rsr_m}) for master and aligned")
            _baseline_pair(f"IMG-HH-{master}.PRM", aligned_prm,
                           "baseline_table.dat", "table.gmt", append=True)
        elif abs(ratio - 2.0) < 0.05:
            # Aligned rate is half the master's (FBD aligned): convert to
            # FBS, compute the baseline from the converted PRM, then move the
            # _FBS PRM/raw over the originals and point the PRM's input_file
            # at the .raw name.
            print("Convert the aligned image from FBD to FBS mode")
            run(f"ALOS_fbd2fbs IMG-HH-{aligned}.PRM IMG-HH-{aligned}_FBS.PRM")
            _baseline_pair(f"IMG-HH-{master}.PRM", f"IMG-HH-{aligned}_FBS.PRM",
                           "baseline_table.dat", "table.gmt", append=True)
            run(f"mv IMG-HH-{aligned}_FBS.PRM IMG-HH-{aligned}.PRM")
            run(f"update_PRM IMG-HH-{aligned}.PRM input_file IMG-HH-{aligned}.raw")
            run(f"mv IMG-HH-{aligned}_FBS.raw IMG-HH-{aligned}.raw")
        elif abs(ratio - 0.5) < 0.05:
            # Master rate is half the aligned's: the stack must be rerun with
            # an FBS master.
            sys.exit("Use FBS mode image as master")
        else:
            sys.exit(f"Incompatible rng_samp_rate ratio: master={rsr_m} aligned={rsr_s}")


def _plot_baseline(SAT):
    """Plot the stack's baselines into ``stacktable_all.ps`` (PostScript).

    Steps:

    1. Column 1 of ``table.gmt`` is mapped to a decimal year with a per-SAT
       awk expression and written, together with columns 2 and 7, to
       ``text``.
    2. The plot region comes from ``gmt gmtinfo text -C``, padded by 0.5 on
       x and 500 on y. A fixed fallback region is used when gmtinfo does
       not yield four numeric values.
    3. ``gmt pstext`` draws the column-7 labels and ``gmt psxy`` the points.
    4. The scratch files ``text``/``text2`` and every ``*.PRM0`` are removed.

    No psconvert/PDF conversion is performed here.

    Args:
        SAT: satellite identifier; must be a key of the epoch table below
            (KeyError otherwise).
    """
    # awk expression turning column 1 into a decimal year for each SAT.
    epoch_expr = {
        "ERS":      "1992 + $1/365.25",
        "ENVI":     "1992 + $1/365.25",
        "ENVI_SLC": "1992 + $1/365.25",
        "TSX":      "2007 + $1/365.25",
        "ALOS":     "2006.5 + ($1-181)/365.25",
    }[SAT]
    run(f"awk '{{print {epoch_expr}, $2, $7}}' < table.gmt > text")

    extent = subprocess.run(
        ["gmt", "gmtinfo", "text", "-C"], check=False, stdout=subprocess.PIPE
    ).stdout.decode("utf-8", errors="replace").split()
    try:
        # Too few tokens or non-numeric output both raise ValueError here.
        xmin, xmax, ymin, ymax = (float(tok) for tok in extent[:4])
    except ValueError:
        R = "-R0/1/-1000/1000"
    else:
        R = f"-R{xmin - 0.5}/{xmax + 0.5}/{ymin - 500}/{ymax + 500}"

    run(f"gmt pstext text -JX8.8i/6.8i {R} -D0.2/0.2 -X1.5i -Y1i "
        f"-K -N -F+f8,Helvetica+j5 > stacktable_all.ps")
    run("awk '{print $1,$2}' < text > text2")
    run(f"gmt psxy text2 -Sp0.2c -G0 -R -JX "
        f"-Ba1:\"year\":/a200g00f100:\"baseline (m)\":WSen -O >> stacktable_all.ps")
    run("rm -f text text2 *.PRM0")


def pre_proc_batch():
    """Command-line entry point: ``pre_proc_batch SAT data.in batch.config``.

    Steps:

    1. Validate argv and SAT.
    2. Read the optional overrides from batch.config and build the
       preprocessor command line.
    3. Preprocess the master named on the first data.in line.
    4. Preprocess every remaining data.in line as an aligned scene.
    5. Plot the baselines.
    """
    argv = sys.argv
    if len(argv) != 4:
        sys.exit(
            "Usage: pre_proc_batch SAT data.in batch.config\n"
            "  SAT: ALOS / ENVI / ENVI_SLC / ERS / TSX\n"
            "  data.in: line 1 = master, line 2+ = aligned names."
        )
    SAT, data_in, config = argv[1:4]
    if SAT not in _SUPPORTED_SATS:
        allowed = " ".join(sorted(_SUPPORTED_SATS))
        sys.exit(f"SAT must be one of: {allowed}")

    # Empty string means "not set in batch.config".
    settings = {
        key: _get_config(config, key, "")
        for key in ("num_patches", "near_range", "earth_radius", "fd1")
    }

    # ENVI_SLC and TSX take no extra preprocessor command line.
    builder = {
        "ERS": _ers_envi_commandline,
        "ENVI": _ers_envi_commandline,
        "ALOS": _alos_commandline,
    }.get(SAT)
    if builder is None:
        commandline = ""
    else:
        commandline = builder(settings["near_range"], settings["earth_radius"],
                              settings["num_patches"], settings["fd1"])

    print(f"commandline: {commandline}\nSTART PREPROCESS A STACK OF IMAGES\n")
    print("preprocess master image")
    with open(data_in) as handle:
        entries = [text for text in (raw.strip() for raw in handle) if text]
    if not entries:
        sys.exit("pre_proc_batch: data.in is empty")

    head = entries[0].split()[0]
    # ALOS data.in may list IMG-HH-<tag>; the rest of the pipeline wants <tag>.
    master = head[7:] if SAT == "ALOS" and head.startswith("IMG-HH-") else head

    NEAR, RAD, FD1, npatch, rsr_m, _ = _preprocess_master(
        SAT, master, commandline, settings["earth_radius"]
    )

    for entry in entries[1:]:
        print("preprocess aligned image")
        _preprocess_aligned(SAT, master, entry.split()[0],
                            NEAR, RAD, FD1, npatch, rsr_m)

    _plot_baseline(SAT)
    print("\nEND PREPROCESS A STACK OF IMAGES")


if __name__ == "__main__":
    # Script entry point; all argument handling lives in pre_proc_batch().
    pre_proc_batch()
