#!/usr/bin/env python3
"""
merge_gpx.py - Align and merge two GPX files of the same activity.

Designed for the case where one device records full telemetry (Garmin: HR,
temperature in TrackPointExtension) and another records lat/lon/ele only
(e.g., DJI Avinox e-bike system), and you want one unified time series.

Alignment strategy
------------------
- Time is the join key. Both devices emit UTC timestamps from GPS, so they
  are aligned to within <1 s in practice.
- The 'canonical' file (default: garmin) provides the time spine and primary
  track geometry. Each canonical point is matched to the nearest point in
  the other file within --tolerance seconds via binary search.
- Unmatched canonical points keep their own data; the other-source channels
  are written as empty in that row.

Outputs
-------
- <prefix>.gpx - Well-formed GPX 1.1 with the canonical track and HR/temp
  (HR/temp are preferred from canonical, falling back to other if missing).
- <prefix>.csv - Wide-form CSV with every channel from both sources keyed
  by canonical time, plus the per-row time offset between the two sources.

Stated assumptions (these affect correctness)
---------------------------------------------
1. Both files contain UTC timestamps (trailing 'Z'); checked for the supplied
   files, not enforced - times without a UTC offset are assumed to be UTC.
2. Sample rates are ~1 Hz. Verified for both supplied files.
3. The two files are the same activity (overlapping time windows).
4. Device clocks are GPS-synced (<1 s drift). Standard behavior.
5. Some files (e.g., Avinox) declare XML prefixes without binding them.
   lxml's recover parser handles this; the script does not "fix" the file
   on disk - it just reads what's parseable.

Usage
-----
  python merge_gpx.py garmin.gpx avinox.gpx --out merged
  python merge_gpx.py garmin.gpx avinox.gpx --out merged --canonical avinox
  python merge_gpx.py a.gpx b.gpx --tolerance 2.0 --name "Your Name here"
"""

import argparse
import csv
from bisect import bisect_left
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

from lxml import etree


# -----------------------------------------------------------------------------
# Data model
# -----------------------------------------------------------------------------

@dataclass
class TrackPoint:
    """A single GPX track sample.

    Time and position are mandatory; the remaining channels default to None
    so that sources lacking them (e.g. a lat/lon/ele-only device) fit the same
    record type.
    """

    time: datetime                  # parse_gpx always stores an aware datetime
    lat: float                      # value of the <trkpt lat="..."> attribute
    lon: float                      # value of the <trkpt lon="..."> attribute
    ele: Optional[float] = None     # from <ele>, when present and numeric
    hr: Optional[int] = None        # heart rate read from <extensions>
    temp_c: Optional[float] = None  # temperature from <extensions>; name implies °C, not converted


# -----------------------------------------------------------------------------
# Parsing
# -----------------------------------------------------------------------------

def _strip_ns(elem) -> None:
    """Rewrite every '{namespace}local' tag in the subtree of *elem* to 'local'.

    *elem* itself is included. Nodes whose tag is not a string (comments,
    processing instructions) are left untouched. Modifies the tree in place.
    """
    for node in elem.iter():
        tag = node.tag
        if not isinstance(tag, str):
            continue
        _, brace, local = tag.partition("}")
        if brace:
            node.tag = local


def _local_name(tag: str) -> str:
    """Return *tag* lower-cased with any '{namespace}' or 'prefix:' removed.

    Files that use an XML prefix without binding it (see the module docstring)
    presumably leave names such as 'ns3:hr' in the tree; dropping the prefix
    lets those match the same as namespaced or bare tags.
    """
    return tag.rsplit("}", 1)[-1].rsplit(":", 1)[-1].lower()


def _parse_time(text: Optional[str]) -> Optional[datetime]:
    """Parse a GPX <time> value into an aware datetime, or None if unusable.

    Surrounding whitespace is ignored, a trailing 'Z' means UTC, and a stamp
    with no UTC offset is assumed to be UTC.
    """
    if not text:
        return None
    try:
        t = datetime.fromisoformat(text.strip().replace("Z", "+00:00"))
    except ValueError:
        return None
    if t.tzinfo is None:
        t = t.replace(tzinfo=timezone.utc)
    return t


def _parse_float(text: Optional[str]) -> Optional[float]:
    """Return float(*text*), or None when *text* is empty or not numeric."""
    if not text:
        return None
    try:
        return float(text)
    except ValueError:
        return None


def _parse_int(text: str) -> Optional[int]:
    """Return int(float(*text*)), or None for non-numeric, NaN or infinite text."""
    try:
        return int(float(text))
    except (ValueError, OverflowError):  # OverflowError: int(float("inf"))
        return None


def _read_extensions(ext) -> tuple[Optional[int], Optional[float]]:
    """Extract (heart rate, temperature) from a trkpt <extensions> element.

    Tags are matched on their exact local name: 'hr' for heart rate and
    'atemp' or 'temp' for temperature. Exact matching keeps e.g. 'wtemp'
    (water temperature) from being read as air temperature. If a channel
    occurs more than once, the last value that parses wins; values that do
    not parse are ignored.
    """
    hr: Optional[int] = None
    temp: Optional[float] = None
    for child in ext.iter():
        if not isinstance(child.tag, str) or not child.text:
            continue  # comments/PIs, or elements without text
        name = _local_name(child.tag)
        if name == "hr":
            value = _parse_int(child.text)
            if value is not None:
                hr = value
        elif name in ("atemp", "temp"):
            reading = _parse_float(child.text)
            if reading is not None:
                temp = reading
    return hr, temp


def parse_gpx(path: Path) -> "list[TrackPoint]":
    """Read the track points of a GPX file, sorted by time.

    lxml's recover parser is used so files that declare XML prefixes without
    binding them can still be read. Points lacking a numeric lat/lon or a
    parseable <time> are skipped. Elevation, HR and temperature are optional
    and left as None when absent or unparseable. Timestamps without a UTC
    offset are assumed to be UTC.

    Returns an empty list when the parser recovers no root element at all.
    """
    parser = etree.XMLParser(recover=True)
    root = etree.parse(str(path), parser).getroot()
    if root is None:
        # Recover mode may hand back a document with no root for input it
        # could not salvage; treat that as "no track points".
        return []
    _strip_ns(root)

    points: list[TrackPoint] = []
    for trkpt in root.findall(".//trkpt"):
        try:
            lat = float(trkpt.attrib["lat"])
            lon = float(trkpt.attrib["lon"])
        except (KeyError, ValueError):
            continue

        time_el = trkpt.find("time")
        t = _parse_time(time_el.text if time_el is not None else None)
        if t is None:
            continue  # time is the join key; a point without it is unusable

        ele_el = trkpt.find("ele")
        ele = _parse_float(ele_el.text) if ele_el is not None else None

        ext = trkpt.find("extensions")
        hr, temp = _read_extensions(ext) if ext is not None else (None, None)

        points.append(TrackPoint(time=t, lat=lat, lon=lon, ele=ele, hr=hr, temp_c=temp))

    points.sort(key=lambda p: p.time)
    return points


# -----------------------------------------------------------------------------
# Alignment
# -----------------------------------------------------------------------------

def align(canonical: list[TrackPoint],
          other: list[TrackPoint],
          tolerance_s: float) -> list[tuple[TrackPoint, Optional[TrackPoint]]]:
    """Match every canonical point to the closest-in-time 'other' point.

    *other* must already be sorted by time (parse_gpx returns it sorted),
    because the lookup bisects on that order. Only the two neighbours around
    the insertion point are examined. A neighbour qualifies only if it lies
    within *tolerance_s* seconds. On an exact tie, the earlier neighbour wins.

    Returns one (canonical_point, match_or_None) tuple per canonical point, in
    canonical order.
    """
    stamps = [pt.time for pt in other]
    n_other = len(stamps)

    def nearest(target: datetime) -> Optional[TrackPoint]:
        pos = bisect_left(stamps, target)
        chosen: Optional[TrackPoint] = None
        chosen_gap: Optional[float] = None
        for j in (pos - 1, pos):
            if not 0 <= j < n_other:
                continue
            gap = abs((other[j].time - target).total_seconds())
            # Written as "not <=" so a NaN tolerance rejects every candidate.
            if not gap <= tolerance_s:
                continue
            if chosen_gap is None or gap < chosen_gap:
                chosen, chosen_gap = other[j], gap
        return chosen

    return [(c, nearest(c.time)) for c in canonical]


# -----------------------------------------------------------------------------
# Output
# -----------------------------------------------------------------------------

def _iso_z(t: datetime) -> str:
    """Format *t* as an ISO-8601 UTC timestamp ending in 'Z' instead of '+00:00'."""
    as_utc = t.astimezone(timezone.utc).isoformat()
    # A datetime converted to timezone.utc always serializes with '+00:00' last.
    return as_utc[: -len("+00:00")] + "Z"


def write_csv(pairs: list[tuple[TrackPoint, Optional[TrackPoint]]],
              path: Path,
              canon_label: str,
              other_label: str) -> None:
    """Write the aligned pairs as a wide CSV, one row per canonical point.

    Columns are time_utc, then lat/lon/ele/hr/temp_c for the canonical source,
    then the same five for the other source (each prefixed by its label), and
    finally time_offset_s. That offset is the other time minus the canonical
    time in seconds, rounded to 3 decimals. In rows without a match, every
    other-source column and the offset are None, which the csv module writes
    as an empty cell.
    """
    channels = ("lat", "lon", "ele", "hr", "temp_c")
    header = (
        ["time_utc"]
        + [f"{canon_label}_{ch}" for ch in channels]
        + [f"{other_label}_{ch}" for ch in channels]
        + ["time_offset_s"]
    )
    no_match = [None] * (len(channels) + 1)  # the five channels plus the offset

    with open(path, "w", newline="") as fh:
        writer = csv.writer(fh)
        writer.writerow(header)
        for canon_pt, other_pt in pairs:
            row = [_iso_z(canon_pt.time)]
            row.extend(getattr(canon_pt, ch) for ch in channels)
            if other_pt is None:
                row.extend(no_match)
            else:
                row.extend(getattr(other_pt, ch) for ch in channels)
                row.append(round((other_pt.time - canon_pt.time).total_seconds(), 3))
            writer.writerow(row)


def write_gpx(pairs: list[tuple[TrackPoint, Optional[TrackPoint]]],
              path: Path,
              name: str,
              activity_type: str = "e_bike_mountain") -> None:
    """Write a well-formed GPX 1.1 track with HR/temp where available.

    Each pair becomes one <trkpt>. Position, elevation and time come from the
    canonical point. Heart rate and temperature also come from the canonical
    point, but fall back to the matched 'other' point when the canonical value
    is None. HR/temp go into a Garmin TrackPointExtension (v1) element, with
    atemp written before hr to follow that schema's element order.

    Args:
        pairs: (canonical point, matched other point or None) tuples, as
            returned by align().
        path: output file path.
        name: text for <trk><name>.
        activity_type: text for <trk><type>; the default keeps the value this
            script has always written.
    """
    GPX_NS = "http://www.topografix.com/GPX/1/1"
    TPX_NS = "http://www.garmin.com/xmlschemas/TrackPointExtension/v1"
    XSI_NS = "http://www.w3.org/2001/XMLSchema-instance"

    def gpx_tag(local: str) -> str:
        # Core GPX elements must be qualified. With nsmap={None: GPX_NS}, lxml
        # puts an unqualified tag in *no* namespace. The serialized text only
        # looks namespaced, so in-memory lookups like find("{GPX_NS}trk") miss it.
        return "{%s}%s" % (GPX_NS, local)

    def tpx_tag(local: str) -> str:
        return "{%s}%s" % (TPX_NS, local)

    nsmap = {None: GPX_NS, "ns3": TPX_NS, "xsi": XSI_NS}
    root = etree.Element(gpx_tag("gpx"), nsmap=nsmap, version="1.1", creator="merge_gpx.py")
    root.set(
        "{%s}schemaLocation" % XSI_NS,
        "http://www.topografix.com/GPX/1/1 http://www.topografix.com/GPX/11.xsd",
    )

    meta = etree.SubElement(root, gpx_tag("metadata"))
    etree.SubElement(meta, gpx_tag("time")).text = _iso_z(datetime.now(timezone.utc))

    trk = etree.SubElement(root, gpx_tag("trk"))
    etree.SubElement(trk, gpx_tag("name")).text = name
    etree.SubElement(trk, gpx_tag("type")).text = activity_type
    trkseg = etree.SubElement(trk, gpx_tag("trkseg"))

    for c, o in pairs:
        # GPX lat/lon are unqualified attributes, so plain keyword names are correct.
        pt = etree.SubElement(trkseg, gpx_tag("trkpt"),
                              lat=f"{c.lat:.7f}", lon=f"{c.lon:.7f}")
        if c.ele is not None:
            etree.SubElement(pt, gpx_tag("ele")).text = f"{c.ele:.2f}"
        etree.SubElement(pt, gpx_tag("time")).text = _iso_z(c.time)

        # Prefer canonical channels; fall back to 'other' if canonical lacks them.
        hr = c.hr if c.hr is not None else (o.hr if o is not None else None)
        temp = c.temp_c if c.temp_c is not None else (o.temp_c if o is not None else None)
        if hr is not None or temp is not None:
            ext = etree.SubElement(pt, gpx_tag("extensions"))
            tpx = etree.SubElement(ext, tpx_tag("TrackPointExtension"))
            if temp is not None:
                etree.SubElement(tpx, tpx_tag("atemp")).text = f"{temp:.1f}"
            if hr is not None:
                etree.SubElement(tpx, tpx_tag("hr")).text = str(hr)

    etree.ElementTree(root).write(
        str(path), xml_declaration=True, encoding="UTF-8", pretty_print=True
    )


# -----------------------------------------------------------------------------
# CLI
# -----------------------------------------------------------------------------

def _output_paths(prefix: Path) -> tuple[Path, Path]:
    """Return the (.gpx, .csv) output paths for the --out *prefix*.

    The extension is appended rather than substituted, so a prefix containing
    dots such as 'ride.2024-06-01' stays intact. Path.with_suffix would turn
    it into 'ride.gpx'. A prefix already ending in .gpx or .csv has that
    extension dropped first, so '--out merged.gpx' still gives merged.gpx and
    merged.csv.
    """
    if prefix.suffix.lower() in (".gpx", ".csv"):
        prefix = prefix.with_suffix("")
    return (prefix.with_name(prefix.name + ".gpx"),
            prefix.with_name(prefix.name + ".csv"))


def main() -> None:
    """CLI entry point: parse both GPX files, align them, write GPX and CSV."""
    import sys  # only needed for the stderr warning below

    ap = argparse.ArgumentParser(
        description="Align and merge two GPX files of the same activity."
    )
    ap.add_argument("garmin", type=Path, help="GPX with full telemetry (HR/temp).")
    ap.add_argument("avinox", type=Path, help="Second GPX (lat/lon/ele only is fine).")
    ap.add_argument("--out", type=Path, default=Path("merged"),
                    help="Output prefix; produces <prefix>.gpx and <prefix>.csv "
                         "(a trailing .gpx/.csv on the prefix is dropped)")
    ap.add_argument("--canonical", choices=["garmin", "avinox"], default="garmin",
                    help="Which file provides the time spine and primary track geometry.")
    ap.add_argument("--tolerance", type=float, default=1.5,
                    help="Max time delta (s) for a point-to-point match.")
    ap.add_argument("--name", type=str, default="Merged activity",
                    help="Track name in the output GPX.")
    args = ap.parse_args()

    # Written as "not >=" so NaN is rejected too; align() could never match with it.
    if not (args.tolerance >= 0):
        ap.error("--tolerance must be a non-negative number of seconds")

    garmin_pts = parse_gpx(args.garmin)
    avinox_pts = parse_gpx(args.avinox)

    if not garmin_pts or not avinox_pts:
        raise SystemExit("One of the input files produced zero trackpoints.")

    print(f"Garmin: {len(garmin_pts):>5d} pts  "
          f"{_iso_z(garmin_pts[0].time)} -> {_iso_z(garmin_pts[-1].time)}")
    print(f"Avinox: {len(avinox_pts):>5d} pts  "
          f"{_iso_z(avinox_pts[0].time)} -> {_iso_z(avinox_pts[-1].time)}")

    if args.canonical == "garmin":
        canon, other = garmin_pts, avinox_pts
        canon_label, other_label = "garmin", "avinox"
    else:
        canon, other = avinox_pts, garmin_pts
        canon_label, other_label = "avinox", "garmin"

    pairs = align(canon, other, args.tolerance)
    matched = sum(1 for _, o in pairs if o is not None)
    pct = 100.0 * matched / len(canon)  # canon is non-empty (checked above)
    print(f"Canonical = {canon_label}: {len(canon)} pts, "
          f"{matched} matched ({pct:.1f}%) within +/-{args.tolerance}s")
    if matched == 0:
        print("Warning: no points matched. Check that both files are the same "
              "activity and that their time windows overlap.", file=sys.stderr)

    # If you wanted to compare GPS sources spatially, the CSV has both lat/lon
    # columns for every canonical timestamp.

    gpx_path, csv_path = _output_paths(args.out)
    write_gpx(pairs, gpx_path, args.name)
    write_csv(pairs, csv_path, canon_label, other_label)
    print(f"Wrote {gpx_path}")
    print(f"Wrote {csv_path}")


if __name__ == "__main__":
    # Script entry point. main() returns None, so the exit status is 0 unless
    # main() itself raises SystemExit.
    raise SystemExit(main())