import rasterio as rio
import pathlib
import opensr_test
import matplotlib.pyplot as plt
from typing import Callable, Union
def create_geotiff(
    model: Callable,
    fn: Callable,
    datasets: Union[str, list],
    output_path: str,
    force: bool = False,
    **kwargs
) -> None:
    """Create the GeoTIFFs for one or several dataset snippets.

    Every selected snippet is processed by ``create_geotiff_batch``.

    Args:
        model (Callable): The model to use to run the fn function.
        fn (Callable): A function that return a dictionary with the following keys:
            - "lr": Low resolution image
            - "sr": Super resolution image
            - "hr": High resolution image
        datasets (Union[str, list]): Either the string "all" (every entry of
            ``opensr_test.datasets`` is processed), a single snippet name, or
            a list of snippet names.
        output_path (str): The output path to save the GeoTIFFs.
        force (bool, optional): If True, the dataset is redownloaded. Defaults
            to False.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``create_geotiff_batch``.
    """
    if isinstance(datasets, str):
        if datasets == "all":
            snippets = opensr_test.datasets
        else:
            # A bare snippet name must not be iterated character by character.
            snippets = [datasets]
    else:
        snippets = datasets

    for snippet in snippets:
        create_geotiff_batch(
            model=model,
            fn=fn,
            snippet=snippet,
            output_path=output_path,
            force=force,
            **kwargs
        )
    return None
def create_geotiff_batch(
    model: Callable,
    fn: Callable,
    snippet: str,
    output_path: str,
    force: bool = False,
    **kwargs
) -> pathlib.Path:
    """Create all the GeoTIFFs (and PNG previews) for a specific dataset snippet.

    For every sample of the snippet, ``fn`` is called and its "sr" output is
    written as a GeoTIFF under ``<output_path>/results/SR/<snippet>/geotiff``.
    A side-by-side LR / SR / HR figure is saved under
    ``<output_path>/results/SR/<snippet>/png``.

    Args:
        model (Callable): The model to use to run the fn function.
        fn (Callable): A function that return a dictionary with the following keys:
            - "lr": Low resolution image
            - "sr": Super resolution image
            - "hr": High resolution image
            It is called as ``fn(model=..., lr=..., hr=..., **kwargs)``.
        snippet (str): The dataset snippet to use to run the fn function.
        output_path (str): The output path to save the GeoTIFFs.
        force (bool, optional): If True, the dataset is redownloaded. Defaults
            to False.
        **kwargs: Extra keyword arguments forwarded unchanged to ``fn``.

    Returns:
        pathlib.Path: The output path where the GeoTIFFs are saved.
    """
    # Create folders to save results; parents=True also creates results/SR.
    snippet_dir = pathlib.Path(output_path) / "results" / "SR" / snippet
    output_path_dataset_geotiff = snippet_dir / "geotiff"
    output_path_dataset_png = snippet_dir / "png"
    for folder in (output_path_dataset_geotiff, output_path_dataset_png):
        folder.mkdir(parents=True, exist_ok=True)

    # Load the dataset
    dataset = opensr_test.load(snippet, force=force)
    lr_dataset, hr_dataset, metadata = dataset["L2A"], dataset["HRharm"], dataset["metadata"]
    total = len(lr_dataset)

    # (title, key in the fn result) for each panel of the preview figure
    panels = (("LR", "lr"), ("SR", "sr"), ("HR", "hr"))

    for index in range(total):
        print(f"Processing {index}/{total}")

        # Run the model
        results = fn(
            model=model,
            lr=lr_dataset[index],
            hr=hr_dataset[index],
            **kwargs
        )

        # Fetch the metadata row once instead of once per field
        row = metadata.iloc[index]
        image_name = row["hr_file"]
        crs = row["crs"]

        # "affine" is a comma-separated string of floats.
        # NOTE(review): assumes the order a, b, c, d, e, f of an affine
        # matrix (c/f = upper-left corner, a/-e = pixel size) - confirm
        # against the dataset metadata.
        transform_list = [float(x) for x in row["affine"].split(",")]
        transform_rio = rio.transform.from_origin(
            transform_list[2],
            transform_list[5],
            transform_list[0],
            transform_list[4] * -1
        )

        # Describe the raster from the array that is actually written, so the
        # profile always matches the data passed to dst.write().
        sr_image = results["sr"]
        band_count, height, width = sr_image.shape
        meta_img = {
            "driver": "GTiff",
            "count": band_count,
            "dtype": "uint16",
            "height": height,
            "width": width,
            "crs": crs,
            "transform": transform_rio,
            "compress": "deflate",
            "predictor": 2,
            "tiled": True
        }

        # Save the GeoTIFF
        with rio.open(output_path_dataset_geotiff / (image_name + ".tif"), "w", **meta_img) as dst:
            dst.write(sr_image)

        # Save the PNG: arrays are moved to channel-last, divided by 3000 and
        # clipped to [0, 1] for display.
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for axis, (title, key) in zip(axes, panels):
            axis.imshow((results[key].transpose(1, 2, 0) / 3000).clip(0, 1))
            axis.set_title(title)
            axis.axis("off")
        # remove whitespace around the image
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        fig.savefig(output_path_dataset_png / (image_name + ".png"))
        # Close exactly this figure. Calling plt.clf() after plt.close() would
        # make pyplot create a new empty figure that is never released.
        plt.close(fig)

    return output_path_dataset_geotiff
def run(model_path: str) -> pathlib.Path:
    """Run all the metrics for a specific model.

    Placeholder: the body is not implemented yet, so calling it has no
    effect and it currently returns None.

    Args:
        model_path (str): The path to the model folder.

    Returns:
        pathlib.Path: Intended to be the output path where the metrics are
        saved as a pickle file.
    """
    return None
def plot(model_path: str) -> pathlib.Path:
    """Generate the plots and tables for a specific model.

    Placeholder: the body is not implemented yet, so calling it has no
    effect and it currently returns None.

    Args:
        model_path (str): The path to the model folder.

    Returns:
        pathlib.Path: Intended to be the output path where the plots and
        tables are saved.
    """
    return None
 | 
