#
# Author: Ben Stabley
# Date: 2023-08-26
#
# Requires pandas, opencv (cv2), and piexif packages; also requires Docker
# installed and in the system PATH.
#
# This script batch processes aerial photos ("single frames" acquired from the
# USGS Earth Explorer portal) with OpenDroneMap. Each group of photos, i.e.
# dataset, is a "project". Each project is in its own folder, and in each
# project folder is a JSON configuration file which specifies paths for project
# inputs and outputs, as well as other parameters.
#
# An example `config.json` is:
#
#    {
#        "name": "1950_pdx_photos",
#        "project_root": "C:\\projects\\old_photos\\1950",
#        "paths": {
#            "raw_img": "raw",
#            "processed_img": "processed",
#            "odm_workspace": "odm",
#            "products": "products",
#            "metadata": "metadata.csv",
#            "gcps": "gcp_list.txt"
#        },
#        "image_preprocess": {
#            "crop": [500, 400],
#            "resize": 0.5,
#            "jpeg_quality": 95
#        },
#        "odm": {
#            "feature_quality": "high",
#            "dsm": false
#        }
#    }
#
# This corresponds to the following directory structure before running the
# script (paths in `paths` may be absolute, or relative to `project_root`):
#
#    C:\projects\old_photos
#        \1950
#            \raw
#                0001.tif
#                0002.tif
#                ...
#            metadata.csv
#            gcp_list.txt
#            config.json
#        \1960
#            <similar to above>
#        \1970
#            ...
#        \1980
#            ...
#
# The folders `processed`, `odm`, and `products` will be created by the script
# and contain output and intermediate files.
#
# Run the script as in this example to process multiple projects as a batch:
#
#   python process_images.py C:\projects\old_photos\1950\config.json C:\projects\old_photos\1960\config.json
#
# For more command information, run `python process_images.py --help`.

import argparse
import json
import math
import pprint
import shutil
import subprocess
import time
from pathlib import Path
from typing import Any

import cv2
import pandas as pd
import piexif


def expand_crop_amount(amount: Any) -> list[int]:
    """Expand crop amount into its full 4-integer form. See `crop()` for details."""
    if isinstance(amount, int):
        amount = [amount]
    elif not isinstance(amount, list):
        raise ValueError("crop amount is not an int or list[int]")

    # expand fewer than 4 ints to the full [top, right, bottom, left] form
    if len(amount) == 1:
        amount = amount*4
    elif len(amount) == 2:
        amount = [amount[0], amount[1], amount[0], amount[1]]
    elif len(amount) > 4:
        amount = amount[:4]
    elif len(amount) != 4:
        raise ValueError("crop amount as list[int] should be length 1, 2, or 4.")

    return amount
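
# Example expansions to the full [top, right, bottom, left] form:
#   expand_crop_amount(10)           -> [10, 10, 10, 10]
#   expand_crop_amount([10, 20])     -> [10, 20, 10, 20]
#   expand_crop_amount([1, 2, 3, 4]) -> [1, 2, 3, 4]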


def crop(filepath: str, amount: Any = None) -> cv2.Mat:
    """
    Read the image at filepath, crop according to `amount`, and return an
    OpenCV image. `amount` is the number of pixels to remove **from the edge**,
    given as a single integer or a list of 1, 2, or 4 integers. A single value
    applies the crop to all sides. 2 values apply to the vertical and
    horizontal sides `[top+bottom, left+right]`. 4 values apply to each side
    `[top, right, bottom, left]`. This is the same ordering as CSS properties
    such as `margin`.
    """
    img = cv2.imread(filepath)
    if img is None:
        raise FileNotFoundError(f"could not read image at {filepath}")
    if not amount:
        return img

    d = expand_crop_amount(amount)  # convert amount to full 4 number form

    h = img.shape[0]
    w = img.shape[1]
    c = [0+d[3], w-d[1], 0+d[0], h-d[2]]  # CSS order style is goofy
    crop = img[c[2]: c[3], c[0]: c[1]]  # y is first, x is second
    return crop
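
# e.g. with the header example's "crop": [500, 400], crop("0001.tif", [500, 400])
# removes 500 px from the top and bottom and 400 px from the left and right.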


def resize(img: cv2.Mat, proportion: float = 1) -> cv2.Mat:
    """Resize/scale img by proprotion."""
    return cv2.resize(src=img, dsize=(0, 0), fx=proportion,
                      fy=proportion, interpolation=cv2.INTER_AREA)


def convert_to_jpg(img: cv2.Mat, dest_filepath: str, jpgquality: int = 95) -> str:
    """Write img to disk as a jpeg."""
    out = Path(dest_filepath)
    out = out.with_name(out.stem + ".jpg")
    cv2.imwrite(str(out), img, [cv2.IMWRITE_JPEG_QUALITY, jpgquality])
    return str(out)


def filename_from_photoid(id: str, ext: str = ".tif", prefix: str = "") -> str:
    return f"{prefix}{id}{ext}"


def rational(number: float, denominator: int = 1) -> tuple[int, int]:
    """Convert a float to a rational number with log10('denominator') digits of precision(?)."""
    return (round(number * denominator), denominator)
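
# e.g. rational(10.5, 10) -> (105, 10); rational(152.42, 100) -> (15242, 100)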


def dms(dd: float, unit: tuple[str, str]) -> tuple[tuple, tuple, tuple, str]:
    """Convert decimal degrees (DD) to degree min sec (DMS) with a NSEW reference."""
    ref = unit[0] if dd >= 0 else unit[1]
    dd = abs(dd)
    d = math.floor(dd)
    m = math.floor((dd - d) * 60)
    s = (((dd - d) * 60) - m) * 60
    return (rational(d), rational(m), rational(s, 100), ref)
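
# e.g. dms(45.503, ("N", "S")) -> ((45, 1), (30, 1), (1080, 100), "N"),
# i.e. 45 deg 30 min 10.80 sec North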


def ft_m(ft: float) -> float:
    """Convert feet to meters."""
    return ft*0.3048


def focal_len_35mm(f: float, film: float) -> int:
    """Supposedly calculate the 35mm equivalent focal length of the lens."""
    return round(43.3*(f/math.hypot(film, film)))
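
# e.g. focal_len_35mm(152.4, 229) -> 20, a very wide equivalent field of view,
# as expected for large-format aerial film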


def clean_film_size(raw: str, unit: str = "mm") -> float:
    """
    Parse 'human readable' film size to a single floating point.
    Assumes the film size is square."""
    # ex: 229mm x 229mm
    parts = raw.strip().split(maxsplit=1)
    clean = parts[0].lower().strip().removesuffix(unit).strip()
    return float(clean)


def clean_focal_len(raw: str, unit: str = "mm") -> float:
    """ Parse 'human readable' focal length to plain old floating point."""
    # ex: 152.42 mm
    clean = raw.strip().lower().removesuffix(unit).strip()
    return float(clean)


def load_metadata(filename: str) -> pd.DataFrame:
    """
    Read the metadata.csv at filename, convert it to a pandas dataframe,
    and perform some cleaning and conversions to the data.
    """
    df = pd.read_csv(filename, encoding="ISO-8859-1")  # government not using utf8...
    orig_cols = [
        "Photo ID",
        "Film Length and Width",
        "Focal Length",
        "Center Latitude dec",
        "Center Longitude dec",
        "Flying Height in Feet",
        "SW Corner Long dec",
        "SW Corner Lat dec",
        "NE Corner Long dec",
        "NE Corner Lat dec",
    ]
    new_cols = ["id", "film", "focal", "lat", "lon", "alt", "xmin", "ymin", "xmax", "ymax"]
    df = df[orig_cols]
    df.columns = new_cols
    df["film"] = df["film"].apply(clean_film_size)  # eg 120.1mm x 120.1mm -> 120.1
    df["focal"] = df["focal"].apply(clean_focal_len)  # eg 123.4mm -> 123.4
    df["alt"] = df["alt"].apply(ft_m)  # convert to meters
    df["file"] = df["id"].apply(filename_from_photoid)  # filename

    return df


def update_exif(filepath: str, data: dict):
    """
    Write the exif data to the image at filepath.
    """
    img_dims = cv2.imread(filepath).shape

    e = piexif.load(filepath)  # read any existing metadata
    lat = dms(data["lat"], ("N", "S"))
    lon = dms(data["lon"], ("E", "W"))
    # image dimensions
    e["0th"][piexif.ImageIFD.ImageWidth] = img_dims[1]
    e["0th"][piexif.ImageIFD.ImageLength] = img_dims[0]
    # gps
    e["GPS"][piexif.GPSIFD.GPSVersionID] = (2, 0, 0, 0)
    e["GPS"][piexif.GPSIFD.GPSLatitude] = lat[:3]
    e["GPS"][piexif.GPSIFD.GPSLatitudeRef] = lat[3]
    e["GPS"][piexif.GPSIFD.GPSLongitude] = lon[:3]
    e["GPS"][piexif.GPSIFD.GPSLongitudeRef] = lon[3]
    e["GPS"][piexif.GPSIFD.GPSAltitude] = rational(data["alt"], 100)
    e["GPS"][piexif.GPSIFD.GPSAltitudeRef] = 0
    # focal length
    e["Exif"][piexif.ExifIFD.FocalLength] = rational(data["focal"], 10)
    exif_bytes = piexif.dump(e)  # new metadata to bytes
    piexif.insert(exif_bytes, filepath)  # write to file


def fix_gcps(gcp_path: Path, crop: list[int], scale: float):
    """
    'Fix' the GCPs, which are assumed to have been created from the raw
    uncropped, unresized images. The GCPs reference the top-left corner as
    px(0,0), so each GCP pixel location is corrected as:
        `fixed_xy = (xy - (left_crop, top_crop)) * resize_scale`
    The fixed GCP file overwrites the file at gcp_path.
    """
    # odm gcp list format
    # https://docs.opendronemap.org/gcp/
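    # The first line of the file is the CRS; each subsequent line is
    # tab-separated: lon lat elev x y image [name]. Hypothetical example:
    #   EPSG:4326
    #   -122.67  45.52  30.0  1024  2048  0001.jpg  gcp01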

    crs = gcp_path.read_text().splitlines()[0]
    crop = expand_crop_amount(crop)

    gcp = pd.read_csv(str(gcp_path), sep="\t", skiprows=1,
                      names=["lon", "lat", "elev", "x", "y", "img", "name"])
    print("original GCPs:")
    print(gcp)

    gcp["x"] = (gcp["x"] - crop[3]) * scale
    gcp["y"] = (gcp["y"] - crop[0]) * scale
    print("fixed GCPs:")
    print(gcp)

    fixed = gcp.to_csv(sep="\t", header=False, index=False)
    gcp_path.write_text(f"{crs}\n{fixed}")


def run_preprocess(proj: dict):
    """
    Perform the complete raw image preprocess task. This includes cropping and
    resizing the images, and reading metadata from a separate file
    (e.g. metadata.csv) and writing it into each image's EXIF. All processed
    images are converted to JPEG and written to the 'processed_img' folder.
    The original raw images are not modified.
    """
    paths = proj["paths"]
    print(f"using metadata at {paths['metadata']}")
    meta = load_metadata(paths["metadata"])  # assumes .tif extension

    for row in meta.to_dict(orient="records"):
        raw = Path(paths["raw_img"], row["file"])
        if not raw.exists():
            continue
        print(f"updating {raw}")
        pre = proj["image_preprocess"]
        # crop and resize
        modified = crop(str(raw), amount=pre["crop"])
        modified = resize(modified, proportion=pre["resize"])
        # convert to jpeg, save to new path
        dst_path = Path(paths["processed_img"], raw.name)
        new_name = convert_to_jpg(modified, dest_filepath=str(
            dst_path), jpgquality=pre["jpeg_quality"])
        # write metadata
        update_exif(str(new_name), row)
        print(f"   wrote {new_name}")


def run_odm(proj: dict):
    """
    Perform the complete OpenDroneMap task. Setup includes copying processed
    images to the ODM workspace, checking for GCPs (and fixing them if a
    crop/resize was applied in the preprocess step), and preparing all the ODM
    flags. Cleanup copies the ortho, DEM, and report to the products folder and
    deletes temp files.
    """

    ### setup ###
    paths = proj["paths"]
    dataset_dir = str(Path(paths["odm_workspace"]).parent)
    odm_dir_name = str(Path(paths["odm_workspace"]).name)
    odm_images = Path(paths["odm_workspace"], "images")
    shutil.copytree(src=paths["processed_img"], dst=odm_images, dirs_exist_ok=True)  # temp copy

    print(f"Dataset directory: {dataset_dir}")
    print(f"ODM directory:     {odm_dir_name}")

    orig_gcps = Path(paths["gcps"]) if paths["gcps"] else None
    odm_gcps = Path(paths["odm_workspace"], "gcp_list.txt")
    if orig_gcps and orig_gcps.exists():
        print(f"GCPs found:        {orig_gcps.exists()}")
        print(f"Using GCPs at:     {orig_gcps}")
        shutil.copyfile(src=orig_gcps, dst=odm_gcps)  # temp copy
        # correct odm's temp copy of gcps for crop/resize images
        pre = proj["image_preprocess"]
        if pre["crop"] or pre["resize"] != 1:
            print("Must fix GCPS.")
            fix_gcps(odm_gcps, crop=pre["crop"], scale=pre["resize"])
        else:
            print("Not fixing GCPs. No crop or resize modification.")
        gcp_flag = f"--gcp /datasets/{odm_dir_name}/gcp_list.txt "
    else:
        gcp_flag = ""

    ### run ODM ###
    # flags to consider using (https://docs.opendronemap.org/arguments/):
    #   name - not a real ODM argument; the project folder name ("odm" here) is
    #          passed positionally instead
    #   optimize-disk-space - delete intermediate ODM outputs
    #   use-3dmesh - supposedly faster than a 2.5D mesh for orthorectification
    #   dsm - create a surface DEM
    #   feature-quality - defaults to high; maybe change to medium?
    try:
        opts = proj["odm"]
        feature_quality = opts.get("feature_quality")
        if feature_quality not in ["ultra", "high", "medium", "low", "lowest"]:
            feature_quality = "high"

        flags = [
            # required. keep in this order.
            "-ti", "--rm",
            f"-v {dataset_dir}:/datasets",
            "--gpus all",
            "opendronemap/odm:gpu",
            "--project-path /datasets",
            odm_dir_name,
            # not required. probably order doesn't matter.
            gcp_flag,
            "--sfm-algorithm planar",
            "--feature-quality", feature_quality,
            "--dsm" if opts.get("dsm") else "",
            # "--use-3dmesh",
            "--skip-3dmodel",
            "--optimize-disk-space",
            "--rerun-all"
        ]
        cmd = f"docker run {' '.join(flags)}"

        print(f"running '{cmd}'")
        subprocess.run(cmd)  # OMG do it!

        # copy outputs to products folder
        report = Path(paths["odm_workspace"], "odm_report", "report.pdf")
        if report.exists():
            cp_report = Path(paths["products"], f"{proj['name']}_report.pdf")
            shutil.copyfile(src=report, dst=cp_report)
        ortho = Path(paths["odm_workspace"], "odm_orthophoto", "odm_orthophoto.tif")
        if ortho.exists():
            cp_ortho = Path(paths["products"], f"{proj['name']}_ortho.tif")
            shutil.copyfile(src=ortho, dst=cp_ortho)
        dem = Path(paths["odm_workspace"], "odm_dem", "dsm.tif")
        if dem.exists():
            cp_dem = Path(paths["products"], f"{proj['name']}_dsm.tif")
            shutil.copyfile(src=dem, dst=cp_dem)

    ### cleanup ###
    finally:
        odm_gcps.unlink(missing_ok=True)
        shutil.rmtree(odm_images)


def read_config(config_path: str) -> dict:
    """
    Read the project config json and convert any relative paths to absolute.
    Directories will be created if they do not exist.
    """
    with Path(config_path).open() as file:
        config = json.load(file)

    def make_abs(p, rel):
        if not p:
            return None  # TODO: this kinda sucks. gotta think of something else
        p = Path(p)
        if not p.is_absolute():
            p = Path(rel, p).resolve()
        if not p.exists():
            p.mkdir(parents=True, exist_ok=True)
            print("made", p)
        return str(p)

    config["project_root"] = make_abs(config["project_root"], Path.cwd())
    for p in ["raw_img", "processed_img", "odm_workspace", "products", "metadata", "gcps"]:
        config["paths"][p] = make_abs(config["paths"].get(p), config["project_root"])

    return config
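
# e.g. given "project_root": "C:\projects\old_photos\1950", a relative "raw" in
# the config resolves to "C:\projects\old_photos\1950\raw" (created if missing).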


def run_project(proj: dict, skip_preprocess: bool, skip_odm: bool):
    """
    Process a single project. The first step is "preprocess": crop/resize each
    raw image and apply metadata (GPS, lens info, etc.) from a separate file.
    The second step is "odm": set up the configuration for OpenDroneMap and run
    photogrammetry on the images. If either step fails, the project stops
    processing and an error message is printed.
    """
    print(f"Project configuration for '{proj['name']}'")
    pprint.pprint(proj)

    try:
        t = {"preprocess": 0.0, "odm": 0.0}
        if not skip_preprocess:
            start = time.monotonic()
            run_preprocess(proj)
            t["preprocess"] = (time.monotonic() - start)/60.0  # mins
        if not skip_odm:
            start = time.monotonic()
            run_odm(proj)
            t["odm"] = (time.monotonic() - start)/60.0  # mins
        print(
            f"{proj['name']} completed Preprocess in {t['preprocess']:.1f} mins and ODM in {t['odm']:.1f} mins.")
        proj["time"] = t
    except Exception as e:
        print(f"project {proj['name']} encountered a fatal error:")
        print(e)


if __name__ == "__main__":
    # process command line arguments
    parser = argparse.ArgumentParser(
        description="This script batch processes historic aerial photography of the "
        "'single frames' variety available on USGS Earth Explorer. It can process "
        "multiple datasets. Each dataset must contain a `config.json` file, raw images, "
        "a metadata csv, and ground control points. Images are preprocessed to reshape "
        "them and apply metadata. Then an orthomosaic (etc.) is created using OpenDroneMap.")
    parser.add_argument("config", type=str, nargs="+",
                        help="Give one config file per dataset to process.")
    parser.add_argument("--skip-preprocess", dest="skip_preprocess", type=bool, default=False,
                        action=argparse.BooleanOptionalAction,
                        help="Skip the image preprocessing step. Useful if the image parameters didn't change but ODM parameters did.")
    parser.add_argument("--skip-odm", dest="skip_odm", type=bool, default=False,
                        action=argparse.BooleanOptionalAction,
                        help="Skip the ODM step. Useful if you just want to process the images.")
    args = parser.parse_args()

    # read each project's configuration file
    config = [read_config(c) for c in args.config]

    # run each project
    for proj in config:
        run_project(proj, args.skip_preprocess, args.skip_odm)

    # final time summary (which was written back to the project config)
    print("\nFinal batch summary:")
    for proj in config:
        print(f"{proj['name']} completed Preprocess in {proj['time']['preprocess']:.1f} mins and ODM in {proj['time']['odm']:.1f} mins.")