#!/usr/bin/env python3
"""xcorr_py — vectorized Python re-implementation of gmtsar's xcorr (freq mode).

Drop-in alternative to the C `xcorr` binary for SAR image coregistration:
batches the search-window FFTs for all N x M grid locations into a few
scipy.fft calls (one fft2 per image plus one ifft2) with workers=-1, so the
FFT library threads across all CPU cores internally. No explicit threading
code on our side.

Usage matches the C binary's most common path:
    xcorr_py master.PRM aligned.PRM [-nx NX] [-ny NY] [-xsearch XS] [-ysearch YS]

Output: freq_xcorr.dat — same format as C xcorr:
    x_loc  y_loc  x_offset  y_offset  snr     (one row per location)

Currently supports: complex int16 SLC data (RS2, ERS, ENVI, TSX, CSK, ALOS, S1).
Not yet supported: -real mode, range pre-interpolation (-range_interp).
"""
import argparse, os, sys
import numpy as np
from scipy.fft import fft2, ifft2


# -------------------- PRM parsing --------------------

def parse_prm(path):
    """Read a gmtsar PRM file into a ``{key: value}`` dict of strings.

    Every line containing ``=`` is split at its first ``=``; both sides are
    stripped of surrounding whitespace. Lines without ``=`` are ignored, and a
    key that appears more than once keeps its last value. No type conversion
    happens here; callers cast on use.
    """
    entries = {}
    with open(path) as handle:
        for raw in handle:
            key, sep, value = raw.partition('=')
            if not sep:
                # partition() returns an empty separator when no '=' exists
                continue
            entries[key.strip()] = value.strip()
    return entries

def _int(prm, k, default=0):    return int(prm.get(k, default))
def _flt(prm, k, default=0.0):  return float(prm.get(k, default))


# -------------------- SLC reader --------------------

def read_slc(prm, prm_path):
    """Memory-map the SLC binary referenced by a PRM file.

    The file is read as int16 real + int16 imag, interleaved, row-major
    (azimuth-line major), i.e. exactly ``4 * num_rng_bins`` bytes per line.

    Args:
        prm: dict of PRM entries (string values); must contain 'SLC_file' and
            'num_rng_bins', may contain 'num_lines'.
        prm_path: path of the PRM file; a relative 'SLC_file' is resolved
            against its directory.

    Returns:
        (re, im, n_lines, rng_bins): `re` and `im` are read-only int16 views of
        shape (n_lines, rng_bins) onto the memory map; `n_lines` is the number
        of complete lines actually mapped (PRM 'num_lines' capped by what the
        file holds, or the file's line count when 'num_lines' is 0/absent).

    Raises:
        KeyError: 'SLC_file' is missing from `prm`.
        ValueError: 'num_rng_bins' is not positive, or no complete line is
            available to map.
    """
    slc_name = prm['SLC_file']
    if os.path.isabs(slc_name):
        slc_path = slc_name
    else:
        slc_path = os.path.join(os.path.dirname(prm_path), slc_name)

    rng_bins = int(prm.get('num_rng_bins', 0))
    if rng_bins <= 0:
        raise ValueError(f'num_rng_bins must be positive, got {rng_bins}')

    # The memmap below lays rows out as rng_bins complex-int16 samples, so the
    # line count has to be derived from that same stride. PRM 'bytes_per_line'
    # is intentionally not used: whenever it differs from 4 * rng_bins the
    # count and the mapped layout disagree (lines silently dropped, or a map
    # larger than the file).
    # NOTE(review): in gmtsar PRMs 'bytes_per_line' appears to describe the raw
    # input file rather than the SLC — confirm.
    line_bytes = rng_bins * 4                   # int16 complex = 4 B/sample
    file_lines = os.path.getsize(slc_path) // line_bytes

    num_lines = int(prm.get('num_lines', 0))    # may be larger than file
    n = min(num_lines, file_lines) if num_lines else file_lines
    if n <= 0:
        raise ValueError(f'{slc_path}: no complete SLC line available to map')

    mm = np.memmap(slc_path, dtype=np.int16, mode='r', shape=(n, rng_bins, 2))
    return mm[..., 0], mm[..., 1], n, rng_bins


# -------------------- core xcorr --------------------

def grid_locations(nx, ny, npx, npy, xmax, ymax, xsearch, ysearch):
    """Build the (n_loc, 2) grid of (x, y) window-centre coordinates.

    Centres are spread uniformly (np.linspace, truncated to int) over the
    region in which a full search window of size
    (npy + 2*ysearch, npx + 2*xsearch) fits inside an image of
    `ymax` lines by `xmax` range bins.

    Args:
        nx, ny: number of grid columns / rows.
        npx, npy: patch size in x / y before the search margin is added.
        xmax, ymax: usable image extent in x / y (samples).
        xsearch, ysearch: search margin added on each side of the patch.

    Returns:
        int array of shape (nx * ny, 2); each row is (x, y). Rows are ordered
        y-major: all x positions for the first y, then the next y, and so on.

    Raises:
        ValueError: the image is smaller than one search window.
    """
    # A window of size S centred at c covers [c - S//2, c - S//2 + S), which is
    # how the patches are cut out elsewhere in this file. The lower margin is
    # therefore S//2 and the upper margin S - S//2; the two differ by one when
    # S is odd, so using S//2 on both sides would push the last window one
    # sample past the image.
    x_lo = npx // 2 + xsearch
    y_lo = npy // 2 + ysearch
    x_hi = xmax - (npx + 2 * xsearch - x_lo)
    y_hi = ymax - (npy + 2 * ysearch - y_lo)
    if x_hi < x_lo or y_hi < y_lo:
        raise ValueError(
            f'image ({ymax} x {xmax}) is smaller than one search window '
            f'({npy + 2 * ysearch} x {npx + 2 * xsearch})')

    xs = np.linspace(x_lo, x_hi, nx, dtype=int)
    ys = np.linspace(y_lo, y_hi, ny, dtype=int)
    yy, xx = np.meshgrid(ys, xs, indexing='ij')
    return np.stack([xx.ravel(), yy.ravel()], axis=1)   # (n_loc, 2)

def extract_patches(re, im, locs, npy, npx, dy=0, dx=0):
    """Extract (n_loc, npy, npx) complex64 patches centred at `locs` (+ optional shift).

    Args:
        re, im: 2-D arrays (lines, range bins) holding the real and imaginary
            parts of the image.
        locs: iterable of (x, y) centre coordinates.
        npy, npx: patch height / width.
        dy, dx: integer shift added to every centre before cutting the patch.

    Returns:
        complex64 array of shape (len(locs), npy, npx). Patch k covers rows
        [y + dy - npy//2, ... + npy) and columns [x + dx - npx//2, ... + npx).
        Samples of a window that fall outside the image are returned as zeros
        instead of raising, so a shifted window near the image edge still
        yields a full-size patch.
    """
    n_rows, n_cols = re.shape
    half_y, half_x = npy // 2, npx // 2
    out = np.zeros((len(locs), npy, npx), dtype=np.complex64)
    for k, (x, y) in enumerate(locs):
        y0 = int(y) + dy - half_y
        x0 = int(x) + dx - half_x
        # Clip the window to the image. Plain slicing with a negative start
        # would wrap around, and a stop past the end would silently shorten
        # the slice — either way the patch shape would no longer match.
        ya, yb = max(y0, 0), min(y0 + npy, n_rows)
        xa, xb = max(x0, 0), min(x0 + npx, n_cols)
        if ya >= yb or xa >= xb:
            continue                            # window entirely off-image
        window = out[k, ya - y0:yb - y0, xa - x0:xb - x0]
        # Assigning through .real/.imag writes into `out` in place and avoids
        # a complex temporary per patch.
        window.real = re[ya:yb, xa:xb]
        window.imag = im[ya:yb, xa:xb]
    return out

def parabolic_subpixel(corr_neighbors):
    """3-point parabolic peak refinement along both axes.

    Args:
        corr_neighbors: (n_loc, 3, 3) magnitude samples around each peak,
            indexed [location, row (y), column (x)], peak sample at [:, 1, 1].

    Returns:
        (n_loc, 2) array of sub-pixel offsets (dx, dy) in pixels, always within
        [-0.5, 0.5]. A flat axis (curvature below 1e-12 in magnitude) gives 0.
    """
    nb = np.asarray(corr_neighbors)
    # Keep float32/float64 inputs as they are; promote anything else to float.
    nb = nb.astype(np.result_type(nb.dtype, np.float32), copy=False)

    def vertex(lo, mid, hi):
        """Offset of the vertex of the parabola through (-1, lo), (0, mid), (1, hi)."""
        curvature = 2 * (lo - 2 * mid + hi)
        shift = np.zeros_like(curvature)
        # Divide only where the curvature is usable; the masked-out entries keep
        # their initial 0 and no divide-by-zero warning is emitted.
        np.divide(lo - hi, curvature, out=shift, where=np.abs(curvature) > 1e-12)
        # When `mid` really is the largest of the three samples the vertex lies
        # within half a pixel, so this clip is a no-op; it only bounds the
        # result when the centre sample is not the maximum.
        return np.clip(shift, -0.5, 0.5)

    # along y: rows 0,1,2 at centre column (col 1)
    dy = vertex(nb[:, 0, 1], nb[:, 1, 1], nb[:, 2, 1])
    # along x: cols 0,1,2 at centre row (row 1)
    dx = vertex(nb[:, 1, 0], nb[:, 1, 1], nb[:, 1, 2])
    return np.stack([dx, dy], axis=1)

def cross_correlate_batch(master, aligned, locs, npx, npy, xsearch, ysearch,
                          rshift, ashift):
    """Vectorized batched FFT cross-correlation at all locations.

    Args:
        master, aligned: (re, im) pairs of 2-D arrays for the two images.
        locs: (n_loc, 2) window centres as (x, y).
        npx, npy: patch size before the search margin is added.
        xsearch, ysearch: search margin added on each side of the patch.
        rshift, ashift: integer shift (x, y) applied to the aligned image's
            window position.

    Returns:
        (offsets, snr): `offsets` is (n_loc, 2) in pixels (range, azimuth) and
        is the residual displacement of the aligned patch relative to the
        master patch, in the same sense as `rshift`/`ashift` (so callers can
        add the two); values lie in [-patch/2, patch/2). `snr` is (n_loc,):
        peak correlation magnitude divided by the mean magnitude, 0 where the
        mean is 0. With zero locations both arrays are empty.
    """
    n = len(locs)
    if n == 0:
        return np.empty((0, 2)), np.empty(0)
    re_m, im_m = master
    re_a, im_a = aligned

    # patch size for both: equal so we can batch.
    patch_y, patch_x = npy + 2 * ysearch, npx + 2 * xsearch

    # master: centered patches
    Cm = extract_patches(re_m, im_m, locs, patch_y, patch_x)
    # aligned: shifted by (rshift, ashift) per the PRM
    Ca = extract_patches(re_a, im_a, locs, patch_y, patch_x, dx=rshift, dy=ashift)

    # Detrend (subtract mean) — improves correlation peak conditioning
    Cm -= Cm.mean(axis=(-2, -1), keepdims=True)
    Ca -= Ca.mean(axis=(-2, -1), keepdims=True)

    # The big payoff: batched FFT calls across N patches, threaded internally.
    Fm = fft2(Cm, workers=-1)
    Fa = fft2(Ca, workers=-1)
    # ifft2(conj(Fm) * Fa)[s] = sum_t conj(m[t]) * a[t + s], which peaks at the
    # s for which a[t + s] == m[t], i.e. at the displacement of the aligned
    # patch relative to the master patch. (The opposite product,
    # Fm * conj(Fa), peaks at -s and would flip the sign of the residual
    # relative to rshift/ashift.)
    corr = np.abs(ifft2(np.conj(Fm) * Fa, workers=-1))   # (n, patch_y, patch_x)

    # Peak per location (vectorized argmax)
    flat = corr.reshape(n, patch_y * patch_x)
    flat_argmax = flat.argmax(axis=1)
    py, px = np.unravel_index(flat_argmax, (patch_y, patch_x))

    # SNR proxy: peak magnitude / mean magnitude. Divide only where the mean is
    # positive so all-zero patches give 0 without a divide warning.
    peak_vals = flat.max(axis=1)
    mean_vals = flat.mean(axis=1)
    snr = np.zeros_like(peak_vals)
    np.divide(peak_vals, mean_vals, out=snr, where=mean_vals > 0)

    # Sub-pixel refinement: gather the 3x3 neighbourhood around each peak.
    # The correlation is circular, so the neighbours of a peak on the array
    # border are the wrapped-around samples (index -1 == last index). Wrapping
    # keeps the peak at the centre of the 3x3 block, including the common case
    # of a peak at index 0 (zero residual offset).
    step = np.array([-1, 0, 1])
    rows = (py[:, None] + step) % patch_y                # (n, 3)
    cols = (px[:, None] + step) % patch_x                # (n, 3)
    n_idx = np.arange(n)
    neighborhood = corr[n_idx[:, None, None],
                        rows[:, :, None],
                        cols[:, None, :]]                # (n, 3, 3)
    sub = parabolic_subpixel(neighborhood)               # (n, 2) = (dx, dy)

    # Peak index -> signed offset: indices wrap mod patch size, so values
    # >= patch/2 represent negative offsets.
    off_x = ((px + sub[:, 0] + patch_x / 2) % patch_x) - patch_x / 2
    off_y = ((py + sub[:, 1] + patch_y / 2) % patch_y) - patch_y / 2
    return np.stack([off_x, off_y], axis=1), snr


# -------------------- driver --------------------

def main():
    """Command-line driver: parse options, map both SLCs, correlate, write offsets.

    Reads the master and aligned PRM files, builds the location grid over the
    extent common to both images, runs the batched cross-correlation and writes
    one text row per location to the output file (default 'freq_xcorr.dat').
    Progress messages go to stderr.
    """
    parser = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    parser.add_argument('master_prm')
    parser.add_argument('aligned_prm')
    # plain integer options, registered in the order they appear in --help
    for flag, default in (('-nx', 20), ('-ny', 50), ('-xsearch', 64), ('-ysearch', 64)):
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument('-npx', type=int, default=64, help='master patch x size (pre-FFT)')
    parser.add_argument('-npy', type=int, default=64, help='master patch y size (pre-FFT)')
    parser.add_argument('-noshift', action='store_true', help='ignore ashift/rshift in PRM')
    parser.add_argument('-out', default='freq_xcorr.dat')
    opts = parser.parse_args()

    master_prm = parse_prm(opts.master_prm)
    aligned_prm = parse_prm(opts.aligned_prm)

    master_re, master_im, master_lines, master_bins = read_slc(master_prm, opts.master_prm)
    aligned_re, aligned_im, aligned_lines, aligned_bins = read_slc(aligned_prm, opts.aligned_prm)

    # Initial integer shifts come from the aligned PRM unless -noshift is given.
    if opts.noshift:
        rshift = 0
        ashift = 0
    else:
        rshift = _int(aligned_prm, 'rshift')
        ashift = _int(aligned_prm, 'ashift')

    # The grid is laid out over the extent common to both images.
    common_bins = min(master_bins, aligned_bins)
    common_lines = min(master_lines, aligned_lines)
    locs = grid_locations(opts.nx, opts.ny, opts.npx, opts.npy,
                          common_bins, common_lines,
                          opts.xsearch, opts.ysearch)
    summary = (f'xcorr_py: {len(locs)} locations, patch=({opts.npy},{opts.npx}), '
               f'search=({opts.ysearch},{opts.xsearch}), rshift={rshift} ashift={ashift}')
    print(summary, file=sys.stderr)

    offsets, snr = cross_correlate_batch(
        (master_re, master_im), (aligned_re, aligned_im), locs,
        opts.npx, opts.npy, opts.xsearch, opts.ysearch, rshift, ashift,
    )

    # One row per location: x y xoff yoff snr, where the reported offsets are
    # the residual from the correlation plus the initial PRM shift.
    # NOTE(review): confirm this column order against what downstream consumers
    # (e.g. fitoffset) expect — the C xcorr may interleave location and offset
    # as "x xoff y yoff snr".
    rows = (
        f'{x} {y} {ox + rshift:.6f} {oy + ashift:.6f} {s:.4f}\n'
        for (x, y), (ox, oy), s in zip(locs, offsets, snr)
    )
    with open(opts.out, 'w') as sink:
        sink.writelines(rows)
    print(f'wrote {len(locs)} rows to {opts.out}', file=sys.stderr)


# Run the command-line driver only when this file is executed as a script;
# importing it as a module just defines the functions above.
if __name__ == '__main__':
    main()
