File size: 5,269 Bytes
			
			| 61e0235 caa7010 61e0235 caa7010 307a330 caa7010 61e0235 caa7010 307a330 61e0235 307a330 61e0235 307a330 61e0235 307a330 61e0235 307a330 caa7010 61e0235 caa7010 61e0235 caa7010 61e0235 caa7010 61e0235 307a330 61e0235 307a330 61e0235 7c8cf1a 61e0235 7c8cf1a 61e0235 7c8cf1a 61e0235 caa7010 307a330 || import rasterio as rio
import pathlib
import opensr_test
import matplotlib.pyplot as plt
from typing import Callable, Union
def create_geotiff(
    model: Callable,
    fn: Callable,
    datasets: Union[str, list],
    output_path: str,
    force: bool = False,
    **kwargs
) -> None:
    """Create the GeoTIFFs for one or more dataset snippets.

    Every selected snippet is processed by ``create_geotiff_batch`` with the
    remaining arguments forwarded unchanged.

    Args:
        model (Callable): The model to use to run the fn function.
        fn (Callable): A function that return a dictionary with the following keys:
            - "lr": Low resolution image
            - "sr": Super resolution image
            - "hr": High resolution image
        datasets (Union[str, list]): Which dataset snippets to process.
            ``"all"`` selects ``opensr_test.datasets``; any other string is
            treated as a single snippet name; a list is used as given.
        output_path (str): The output path to save the GeoTIFFs.
        force (bool, optional): Forwarded to ``create_geotiff_batch``, which
            passes it on to ``opensr_test.load`` (documented upstream as
            "redownload the dataset"). Defaults to False.
        **kwargs: Extra keyword arguments forwarded to ``fn`` through
            ``create_geotiff_batch``.
    """
    if isinstance(datasets, str):
        # A bare string must not be iterated directly: that would yield one
        # "snippet" per character instead of one snippet per name.
        datasets = opensr_test.datasets if datasets == "all" else [datasets]

    for snippet in datasets:
        create_geotiff_batch(
            model=model,
            fn=fn,
            snippet=snippet,
            output_path=output_path,
            force=force,
            **kwargs
        )
    return None
def _save_comparison_png(results: dict, png_path: pathlib.Path) -> None:
    """Save a side-by-side LR / SR / HR preview of ``results`` to ``png_path``."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        for axis, key, title in zip(axes, ("lr", "sr", "hr"), ("LR", "SR", "HR")):
            # Move axis 0 to the end so imshow receives (rows, cols, bands);
            # dividing by 3000 and clipping to [0, 1] is the display stretch.
            # NOTE(review): assumes (bands, height, width) arrays - confirm.
            axis.imshow((results[key].transpose(1, 2, 0) / 3000).clip(0, 1))
            axis.set_title(title)
            axis.axis("off")
        # Remove whitespace around the image
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        fig.savefig(png_path)
    finally:
        # Close this exact figure (even if saving failed) so figures do not
        # pile up across loop iterations.
        plt.close(fig)


def create_geotiff_batch(
    model: Callable,
    fn: Callable,
    snippet: str,
    output_path: str,
    force: bool = False,
    **kwargs
) -> pathlib.Path:
    """Create all the GeoTIFFs (and PNG previews) for a specific dataset snippet.

    For every image of the snippet, ``fn`` is called with the model and the
    LR/HR pair; the returned "sr" array is written as a GeoTIFF and a
    LR / SR / HR comparison figure is written as a PNG. Files are stored in
    ``<output_path>/results/SR/<snippet>/geotiff`` and
    ``<output_path>/results/SR/<snippet>/png``.

    Args:
        model (Callable): The model to use to run the fn function.
        fn (Callable): A function that return a dictionary with the following keys:
            - "lr": Low resolution image
            - "sr": Super resolution image
            - "hr": High resolution image
        snippet (str): The dataset snippet to use to run the fn function.
        output_path (str): The output path to save the GeoTIFFs.
        force (bool, optional): Passed to ``opensr_test.load`` (documented
            upstream as "redownload the dataset"). Defaults to False.
        **kwargs: Extra keyword arguments forwarded to ``fn``.

    Returns:
        pathlib.Path: The output path where the GeoTIFFs are saved.
    """
    # Create folders to save results (parents=True also creates the base dir)
    base_dir = pathlib.Path(output_path) / "results" / "SR"
    geotiff_dir = base_dir / snippet / "geotiff"
    png_dir = base_dir / snippet / "png"
    geotiff_dir.mkdir(parents=True, exist_ok=True)
    png_dir.mkdir(parents=True, exist_ok=True)

    # Load the dataset
    dataset = opensr_test.load(snippet, force=force)
    lr_dataset, hr_dataset, metadata = dataset["L2A"], dataset["HRharm"], dataset["metadata"]

    total = len(lr_dataset)
    for index in range(total):
        print(f"Processing {index}/{total}")

        # Run the model
        results = fn(
            model=model,
            lr=lr_dataset[index],
            hr=hr_dataset[index],
            **kwargs
        )

        # One metadata lookup per image instead of one per field
        row = metadata.iloc[index]
        image_name = row["hr_file"]

        # Build the georeferencing transform from the comma-separated
        # coefficients: from_origin(west, north, xsize, ysize) receives
        # elements 2, 5, 0 and the negated element 4.
        # NOTE(review): presumably the string is in affine (a, b, c, d, e, f)
        # order with a negative e (north-up raster) - confirm.
        coeffs = [float(value) for value in row["affine"].split(",")]
        transform_rio = rio.transform.from_origin(
            coeffs[2],
            coeffs[5],
            coeffs[0],
            -coeffs[4]
        )

        # Describe the raster that is actually written (the SR array), so the
        # profile always agrees with the data passed to dst.write().
        sr_image = results["sr"]
        meta_img = {
            "driver": "GTiff",
            "count": sr_image.shape[0],
            "dtype": "uint16",
            "height": sr_image.shape[1],
            "width": sr_image.shape[2],
            "crs": row["crs"],
            "transform": transform_rio,
            "compress": "deflate",
            "predictor": 2,
            "tiled": True
        }

        # Save the GeoTIFF
        with rio.open(geotiff_dir / (image_name + ".tif"), "w", **meta_img) as dst:
            dst.write(sr_image)

        # Save the PNG
        _save_comparison_png(results, png_dir / (image_name + ".png"))

    return geotiff_dir
def run(model_path: str) -> pathlib.Path:
    """Placeholder for running all the metrics of a specific model.

    Not implemented yet: the function performs no work and returns ``None``.

    Args:
        model_path (str): The path to the model folder.

    Returns:
        pathlib.Path: Intended to be the output path where the metrics are
        saved as a pickle file; at present ``None`` is returned.
    """
    return None
def plot(model_path: str) -> pathlib.Path:
    """Placeholder for generating the plots and tables of a specific model.

    Not implemented yet: the function performs no work and returns ``None``.

    Args:
        model_path (str): The path to the model folder.

    Returns:
        pathlib.Path: Intended to be the output path where the plots and
        tables are saved; at present ``None`` is returned.
    """
    return None
 | 
