#! /usr/bin/env python3
"""correct_insar_with_gnss — correct InSAR unwrapped phase using GNSS LOS.

Python port of csh correct_insar_with_gnss.csh (DTS 2021; K. Gualandi 2021).
Reads the GNSS file gnsslos.rad (columns: range, azimuth, LOS in mm) and
subtracts a smoothed InSAR-minus-GNSS difference field from phase.grd.

Usage:  correct_insar_with_gnss master.PRM phase.grd gnsslos.rad filter_wavelength

Output: gnss_corrected_intf.grd (corrected phase in radians).
"""
import math
import os
import shlex
import subprocess
import sys

from gmtsar_lib import run, grep_value


def _capture(cmd):
    """Run *cmd* through the shell and return its stdout as stripped text.

    Only stdout is captured; stderr is left attached to the caller's
    stream. The exit status is not checked, so a failing command simply
    yields whatever it wrote to stdout (possibly an empty string).
    """
    completed = subprocess.run(
        cmd,
        shell=True,
        stdout=subprocess.PIPE,
        check=False,
    )
    raw_stdout = completed.stdout
    return str(raw_stdout, "utf-8").strip()


# Scratch files written by correct_insar_with_gnss(); removed when it ends.
_TEMP_FILES = ("tmp.grd", "tmp_filt.grd", "tmp.rad", "tmp1.grd", "tmp2.grd")


def _parse_filter_wavelength(text):
    """Return *text* as a positive, finite wavelength in meters, or exit.

    Exits with a readable message (rather than a traceback or a later
    division/overflow error) when the value is non-numeric, NaN, infinite,
    zero or negative.
    """
    try:
        value = float(text)
    except ValueError:
        value = math.nan
    if not (math.isfinite(value) and value > 0):
        sys.exit(
            "filter_wavelength must be a positive number of meters "
            f"(e.g. 40000 for 40 km), got {text!r}"
        )
    return value


def _grid_increment(grid, key):
    """Return the *key* (``x_inc`` or ``y_inc``) of *grid* as an int >= 1.

    The value is field 7 of the first ``gmt grdinfo`` output line that
    contains *key*, rounded to the nearest integer. Exits with a message
    when that field is missing, non-numeric or rounds to zero, because the
    result is used as a multiplier for pixel sizes that are divided by.
    """
    raw = _capture(
        f"gmt grdinfo {shlex.quote(grid)} | grep {key} | "
        "awk 'NR == 1 {print $7}'"
    )
    try:
        increment = int(float(raw) + 0.5)
    except (ValueError, OverflowError):
        increment = 0
    if increment < 1:
        sys.exit(
            f"Could not read a usable {key} from {grid} "
            f"(gmt grdinfo gave {raw!r})"
        )
    return increment


def _odd(n):
    """Truncate *n* to an int and bump even values up to the next odd one."""
    n = int(n)
    return n + 1 if n % 2 == 0 else n


def correct_insar_with_gnss():
    """Correct an InSAR phase grid with GNSS line-of-sight observations.

    Command-line arguments (read from ``sys.argv``):
        master.PRM         PRM file providing rng_samp_rate, SC_vel, PRF and
                           wavelength.
        phase.grd          InSAR phase grid to correct.
        gnsslos.rad        Text file with columns range, azimuth, LOS (mm).
        filter_wavelength  Smoothing wavelength in meters.

    Steps: low-pass the interferogram, sample it at the GNSS points, grid
    the InSAR-minus-GNSS differences on a coarse template, smooth them,
    resample to the interferogram's node count and subtract the result.

    Writes ``ins-gps_diff.rad``, ``correction.grd`` and
    ``gnss_corrected_intf.grd`` in the current directory; the scratch files
    listed in ``_TEMP_FILES`` are removed on exit, including on failure.

    Raises:
        SystemExit: on wrong argument count, an invalid filter wavelength,
            an unreadable grid increment, a wavelength shorter than 16 grid
            cells, or when the corrected grid was not produced.
    """
    if len(sys.argv) != 5:
        sys.exit(
            "Usage: correct_insar_with_gnss master.PRM phase.grd "
            "gnsslos.rad filter_wavelength\n"
            "  filter_wavelength: in meters (e.g. 40000 for 40 km)"
        )
    prm, insar, gnss = sys.argv[1], sys.argv[2], sys.argv[3]
    filterw = _parse_filter_wavelength(sys.argv[4])
    # Paths are spliced into shell command lines below; quote them so that
    # spaces or shell metacharacters in a file name cannot break the command.
    insar_q = shlex.quote(insar)
    gnss_q = shlex.quote(gnss)

    run("gmt set IO_NC4_CHUNK_SIZE classic")

    # Range / azimuth pixel sizes from PRM (40° look angle baked into the
    # range constant 1.556 = 1/sin(40°)).
    rng_samp_rate = float(grep_value(prm, "rng_samp_rate", 3))
    rng_px_size = 1.556 * 299792458.0 / rng_samp_rate / 2.0
    vel = float(grep_value(prm, "SC_vel", 3))
    prf = float(grep_value(prm, "PRF", 3))
    azi_px_size = vel / prf

    # Pixel decimation factors from the input grid.
    x_inc = _grid_increment(insar, "x_inc")
    y_inc = _grid_increment(insar, "y_inc")
    dx = x_inc * rng_px_size
    dy = y_inc * azi_px_size
    if dx <= 0 or dy <= 0:
        sys.exit(
            f"Non-positive pixel size derived from {prm} "
            f"(range {dx:g} m, azimuth {dy:g} m)"
        )

    print("Preparing grid characteristics...")
    # Template: wavelength / 16 pixel size
    ndx = int(filterw / dx / 16)
    ndy = int(filterw / dy / 16)
    if ndx < 1 or ndy < 1:
        # A zero factor would ask grdsample for a zero grid increment.
        sys.exit(
            f"filter_wavelength {filterw:g} m is too short for this grid: "
            f"it must be at least {16 * max(dx, dy):.0f} m (16 grid cells)"
        )
    ndx2 = x_inc * ndx
    ndy2 = y_inc * ndy

    try:
        run(f"gmt grdsample {insar_q} -I{ndx2}/{ndy2} -Gtmp.grd")

        print("Computing the correction grid...")
        # 1 km low-pass filter, kernel side = (1000/dx) rounded to odd
        fsx = _odd(1000 / dx)
        fsy = _odd(1000 / dy)
        run(f"gmt grdfilter {insar_q} -Dp -Fg{fsx}/{fsy} -Gtmp_filt.grd")

        wave = float(grep_value(prm, "wavelength", 3))
        # Convert GNSS LOS (mm) → phase (rad): φ = -4π·LOS / (λ·1000),
        # sample the filtered interferogram at each GNSS point (grdtrack
        # appends it as column 4) and keep InSAR minus GNSS.
        # NOTE(review): dropping NaN samples relies on awk printing a
        # lowercase "nan" — confirm with the awk implementation in use.
        run(f"awk '{{print $1, $2, $3*(4.0*3.141592653)/(-{wave}*1000.0) }}' < {gnss_q} | "
            f"gmt grdtrack -Gtmp_filt.grd | awk '{{print $1,$2,($4-$3)}}' | "
            f"grep -v 'nan' > ins-gps_diff.rad")

        # Grid the point differences on the coarse template.
        run("gmt blockmedian ins-gps_diff.rad -Rtmp.grd > tmp.rad")
        run("gmt surface tmp.rad -Rtmp.grd -T0.1 -Gtmp1.grd")

        print("Applying Gaussian filter to correction grid (~1 min)...")
        # 17-pixel (= 16+1) Gaussian filter — wavelength in template-pixel units.
        run("gmt grdfilter tmp1.grd -Dp -Fg17 -Gtmp2.grd")

        print("...using grdsample to upsample the correction grid...")
        # Fields 9 and 10 of the numeric grdinfo record are passed to
        # grdsample as node counts ("+n") so the correction grid gets the
        # same number of nodes as the interferogram.
        incre = _capture(f"gmt grdinfo -Cn {insar_q} | awk '{{printf \"%d+n/%d+n \\n\", $9, $10}}'")
        run(f"gmt grdsample tmp2.grd -I{incre} -Gcorrection.grd")

        print("Correcting the interferogram...")
        run(f"gmt grdmath {insar_q} correction.grd SUB = gnss_corrected_intf.grd")

        if os.path.isfile("gnss_corrected_intf.grd"):
            print("Interferogram corrected and stored as gnss_corrected_intf.grd!")
        else:
            sys.exit("Something went wrong — intf not corrected")
    finally:
        # Remove only the scratch files this function created (not every
        # tmp* in the directory), and do so on the failure paths as well.
        run("rm -f " + " ".join(_TEMP_FILES))


# Script entry point: run the correction only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    correct_insar_with_gnss()
