From 8207db8f4cb93819966a58ec632bb28fb5c8ddd0 Mon Sep 17 00:00:00 2001
From: Fabian Schindler <fabian.schindler.strauss@gmail.com>
Date: Tue, 25 Aug 2020 11:00:27 +0200
Subject: [PATCH] Fixing corner georef

---
 .../preprocessor/steps/georeference.py        | 107 ++++++++++++------
 1 file changed, 73 insertions(+), 34 deletions(-)

diff --git a/preprocessor/preprocessor/steps/georeference.py b/preprocessor/preprocessor/steps/georeference.py
index 0ed68761..e2612791 100644
--- a/preprocessor/preprocessor/steps/georeference.py
+++ b/preprocessor/preprocessor/steps/georeference.py
@@ -3,9 +3,9 @@ from os.path import join, basename, splitext
 import logging
 from glob import glob
 import shutil
-from typing import List
+from typing import List, Tuple
 
-from osgeo import gdal
+from osgeo import gdal, osr
 
 
 logger = logging.getLogger(__name__)
@@ -20,8 +20,12 @@ def georeference_step(source_dir: os.PathLike, target_dir: os.PathLike, type: st
         georef_func = rpc_georef
     elif type_name == 'world':
         georef_func = world_georef
+    elif type_name == 'corners':
+        georef_func = corner_georef
+    else:
+        raise Exception('Invalid georeference type %s' % type_name)
 
-    for filename in glob(join(source_dir, '*.TIF')):
+    for filename in glob(join(source_dir, '*.tif')):
         target_filename = join(target_dir, basename(filename))
         georef_func(filename, target_filename, **options)
 
@@ -77,25 +81,63 @@ def rpc_georef(input_filename: os.PathLike, target_filename: os.PathLike, rpc_fi
         **(warp_options or {})
     )
 
-def corner_georef(input_filename: os.PathLike, target_filename: os.PathLike, corner_names: List[str]=None):
+
+def corner_georef(input_filename: os.PathLike, target_filename: os.PathLike, corner_names: List[str]=None,
+                  orbit_direction_name: str=None, force_north_up: bool=False, gcp_srid: int=4326, warp: bool=False):
     corner_names = corner_names or ["bottom_left", "bottom_right", "top_left", "top_right"]
     ds = gdal.Open(input_filename, gdal.GA_Update)
-    for corner_name in corner_names:
-        ds.GetMetaData()
+
+    orbit_direction = ds.GetMetadata()[orbit_direction_name].lower()
+    metadata = ds.GetRasterBand(1).GetMetadata()
+
+    # from pprint import pprint
+
+    # pprint (metadata)
+    # pprint(ds.GetMetadata())
+    bl, br, tl, tr = [
+        [float(num) for num in metadata[corner_name].split()]
+        for corner_name in corner_names
+    ]
+
+    gcps = gcps_from_borders(
+        (ds.RasterXSize, ds.RasterYSize),
+        (bl, br, tl, tr),
+        orbit_direction,
+        force_north_up
+    )
+
+    sr = osr.SpatialReference()
+    sr.ImportFromEPSG(gcp_srid)
+
+    ds.SetGCPs(gcps, sr.ExportToWkt())
+
+    if warp:
+        gdal.Warp(
+            target_filename,
+            ds,
+        )
+        del ds
+    else:
+        ds.SetGeoTransform(gdal.GCPsToGeoTransform(ds.GetGCPs()))
+        driver = ds.GetDriver()
+        del ds
+        driver.Rename(target_filename, input_filename)
 
 
 def world_georef():
+    # TODO: implement
     pass
 
 
 
 
-def gcps_from_borders(dst, coords, orbit_direction, force_north_up=False):
+def gcps_from_borders(size: Tuple[float, float], coords: List[Tuple[float, float]], orbit_direction: str, force_north_up: bool=False):
+    x_size, y_size = size
     # expects coordinates in dict(.*border_left.*:[lat,lon],...)
     gcps = []
     if force_north_up and len(coords) == 4:
         # compute gcps assuming north-up, east-right image no matter, what is claimed by metadata
-        sorted_by_lats = sorted(coords.values(), key=lambda x: x[0], reverse=True)
+        sorted_by_lats = sorted(coords, key=lambda x: x[0], reverse=True)
         # compare longitudes
         if sorted_by_lats[0][1] > sorted_by_lats[1][1]:
             #                                                            /\
@@ -126,31 +168,28 @@ def gcps_from_borders(dst, coords, orbit_direction, force_north_up=False):
                 bottom_left = sorted_by_lats[0]
                 bottom_right = sorted_by_lats[1]
         gcps.append(gdal.GCP(bottom_left[1], bottom_left[0], 0, 0.5, 0.5))
-        gcps.append(gdal.GCP(bottom_right[1], bottom_right[0], 0, dst.RasterXSize - 0.5, 0.5))
-        gcps.append(gdal.GCP(top_left[1], top_left[0], 0, 0.5, dst.RasterYSize - 0.5))
-        gcps.append(gdal.GCP(top_right[1], top_right[0], 0, dst.RasterXSize - 0.5, dst.RasterYSize - 0.5))
+        gcps.append(gdal.GCP(bottom_right[1], bottom_right[0], 0, x_size - 0.5, 0.5))
+        gcps.append(gdal.GCP(top_left[1], top_left[0], 0, 0.5, y_size - 0.5))
+        gcps.append(gdal.GCP(top_right[1], top_right[0], 0, x_size - 0.5, y_size - 0.5))
+
     else:
-        for key, value in coords.items():
-            # assume points are labeled correctly and image not necessarily north up, access border points as mid-pixels and use them as GCPs
-            if "bottom_left" in key.lower() or "bottom_right" in key.lower():
-                if orbit_direction != "descending":
-                    y = 0.5
-                else:
-                    y = dst.RasterYSize - 0.5
-            else:
-                if orbit_direction != "descending":
-                    y = dst.RasterYSize - 0.5
-                else:
-                    y = 0.5
-            if "bottom_left" in key.lower() or "top_left" in key.lower():
-                if orbit_direction != "descending":
-                    x = 0.5
-                else:
-                    x = dst.RasterXSize - 0.5
-            else:
-                if orbit_direction != "descending":
-                    x = dst.RasterXSize - 0.5
-                else:
-                    x = 0.5
-            gcps.append(gdal.GCP(value[1], value[0], 0, x, y))
+        bl, br, tl, tr = coords
+
+        x_left = x_size - 0.5
+        x_right = 0.5
+
+        y_bottom = 0.5
+        y_top = y_size - 0.5
+
+        if orbit_direction == 'descending':
+            x_left, x_right = x_right, x_left
+            y_bottom, y_top = y_top, y_bottom
+
+        gcps.extend([
+            gdal.GCP(bl[1], bl[0], 0, x_left, y_bottom),
+            gdal.GCP(br[1], br[0], 0, x_right, y_bottom),
+            gdal.GCP(tl[1], tl[0], 0, x_left, y_top),
+            gdal.GCP(tr[1], tr[0], 0, x_right, y_top),
+        ])
+
     return gcps
-- 
GitLab