import re
import logging
from dateutil.parser import isoparse
from typing import TYPE_CHECKING, Iterator, Optional, Tuple
from functools import cached_property
from urllib.parse import urlparse
import json
from os.path import dirname, join

if TYPE_CHECKING:
    from mypy_boto3_s3.client import S3Client

import boto3
import boto3.session
import botocore.session
import botocore.handlers
from botocore import UNSIGNED
from botocore.config import Config
import pystac
from pystac.utils import datetime_to_str

from ._source import Source

logger = logging.getLogger(__name__)


class S3Base:
    """Common S3 configuration and boto3 client construction for S3 based sources."""

    def __init__(
        self,
        secret_access_key: Optional[str] = None,
        access_key_id: Optional[str] = None,
        endpoint_url: str = "",
        strip_bucket: bool = True,
        validate_bucket_name: bool = True,
        region_name: Optional[str] = None,
        public: bool = False,
    ):
        self.secret_access_key = secret_access_key
        self.access_key_id = access_key_id
        self.endpoint_url = endpoint_url
        self.strip_bucket = strip_bucket
        self.region_name = region_name
        self.public = public
        self.validate_bucket_name = validate_bucket_name

    @cached_property
    def client(self) -> "S3Client":
        botocore_session = botocore.session.Session()
        session = boto3.session.Session(botocore_session=botocore_session)

        if not self.validate_bucket_name:
            # Skip client-side bucket name validation, e.g. for custom
            # S3-compatible endpoints.
            botocore_session.unregister(
                "before-parameter-build.s3",
                botocore.handlers.validate_bucket_name,
            )

        client = session.client(
            "s3",
            aws_access_key_id=self.access_key_id,
            aws_secret_access_key=self.secret_access_key,
            region_name=self.region_name,
            endpoint_url=self.endpoint_url,
            # Public buckets are accessed with unsigned (anonymous) requests.
            config=Config(signature_version=UNSIGNED) if self.public else None,
        )
        return client


class S3Source(Source):
    """Harvests objects from an S3 bucket and creates STAC items for keys
    matching a configured time regex."""

    type = "S3"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def client(self) -> "S3Client":
        botocore_session = botocore.session.Session()
        session = boto3.session.Session(botocore_session=botocore_session)
        client = session.client(
            "s3",
            aws_access_key_id=self.parameters["access_key_id"],
            aws_secret_access_key=self.parameters["secret_access_key"],
            region_name=self.parameters["region"],
        )
        return client

    @property
    def bucket(self) -> str:
        # The bucket name is the first label of the host name in URLs such as
        # https://<bucket>.s3.<region>.amazonaws.com/. Only remove a leading
        # scheme here: str.strip("https://") would also eat matching
        # characters from the bucket name itself.
        return re.sub(r"^https?://", "", self.parameters["url"]).split(".")[0]

    def harvest(self) -> Iterator[dict]:
        logger.info("Starting S3 harvesting")
        paginator = self.client.get_paginator("list_objects_v2")
        pages = paginator.paginate(
            Bucket=self.bucket, Prefix=self.parameters["prefix"]
        )
        time_regex: str = self.parameters["time_regex"]
        for page in pages:
            # Pages without matching objects have no "Contents" entry.
            for file in page.get("Contents", []):
                if match := re.search(time_regex, file["Key"]):
                    dt = isoparse(match[0])
                    item = self._create_item(file, dt, self.parameters["url"])
                    yield item.to_dict()

    def _create_item(self, data, dt, url):
        identifier = dt.strftime("%Y%m%d_%H%M%S")
        properties = {
            "datetime": dt,
            # Serialize to an ISO 8601 string so the item dict stays
            # JSON serializable.
            "updated": datetime_to_str(data["LastModified"]),
        }
        item = pystac.Item(
            id=identifier,
            geometry=None,
            bbox=None,
            datetime=dt,
            properties=properties,
        )
        item.add_asset(identifier, pystac.Asset(f"{url}{data['Key']}"))
        return item


class S3CatalogSource(S3Base):
    """Harvests items from a static STAC catalog stored on S3."""

    type = "S3Catalog"

    def __init__(self, parameters: dict, **kwargs):
        self.root_href = parameters.pop("root_href")
        self.default_catalog_name = parameters.pop("default_catalog_name", None)
        super().__init__(**parameters)

    def harvest(self) -> Iterator[dict]:
        logger.info("Starting S3 Catalog harvesting")
        parsed = urlparse(self.root_href)
        path = parsed.path
        if path.startswith("/"):
            path = path[1:]

        if path.endswith("/") and self.default_catalog_name:
            path = join(path, self.default_catalog_name)

        yield from self.harvest_catalog(parsed.netloc, path)

    def fetch_json(self, bucket: str, key: str) -> dict:
        """
        Fetches the object identified by the given bucket and key and parses
        it as JSON.
        """
        if key.startswith("/"):
            key = key[1:]
        response = self.client.get_object(Bucket=bucket, Key=key)
        return json.load(response["Body"])

    def join_href(self, bucket: str, key: str, href: str) -> Tuple[str, str]:
        """
        Joins the given href with a previous bucket/key. If the href is a
        fully qualified S3 URL, the bucket/key pair it contains is returned.
        If it is a relative path, it is joined with the previous key.
        """
        parsed = urlparse(href)
        if parsed.netloc:
            if parsed.scheme.lower() != "s3":
                # TODO: what about HTTP hrefs?
                raise ValueError("Can only join S3 URLs")
            return (parsed.netloc, parsed.path)
        else:
            return (
                bucket,
                join(dirname(key), parsed.path),
            )

    def harvest_catalog(self, bucket: str, key: str) -> Iterator[dict]:
        """
        Harvests the specified STAC catalog, recursing into child catalogs
        and yielding all contained STAC items.
        """
        logger.info(f"Harvesting from catalog {bucket}/{key}")
        catalog = self.fetch_json(bucket, key)
        for link in catalog["links"]:
            if link["rel"] == "item":
                item_bucket, item_key = self.join_href(bucket, key, link["href"])
                logger.info(f"Harvesting item {item_bucket}/{item_key}")
                yield self.fetch_json(item_bucket, item_key)
            elif link["rel"] == "child":
                cat_bucket, cat_key = self.join_href(bucket, key, link["href"])
                yield from self.harvest_catalog(cat_bucket, cat_key)
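

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the harvester API. The bucket name,
# catalog path, endpoint and region below are placeholder assumptions, not
# values defined anywhere in this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Walk a public, static STAC catalog: the root catalog.json is fetched,
    # child catalogs are followed recursively and every item is yielded as a
    # plain dict.
    catalog_source = S3CatalogSource(
        parameters={
            # All of the following values are placeholders for illustration.
            "root_href": "s3://example-bucket/catalog.json",
            "endpoint_url": "https://s3.eu-central-1.amazonaws.com",
            "region_name": "eu-central-1",
            "public": True,  # unsigned requests, no credentials needed
        }
    )
    for item in catalog_source.harvest():
        print(item["id"])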