import re
import logging
from dateutil.parser import isoparse
from typing import TYPE_CHECKING, Iterator, Optional, Tuple
from functools import cached_property
from urllib.parse import urlparse
import json
from os.path import dirname, join

if TYPE_CHECKING:
    from mypy_boto3_s3.client import S3Client

import boto3
import boto3.session
import botocore.session
import botocore.handlers
from botocore import UNSIGNED
from botocore.config import Config

import pystac

from ._source import Source

logger = logging.getLogger(__name__)


class S3Base:
    def __init__(
        self,
        secret_access_key: Optional[str] = None,
        access_key_id: Optional[str] = None,
        endpoint_url: str = "",
        strip_bucket: bool = True,
        validate_bucket_name: bool = True,
        region_name: Optional[str] = None,
        public: bool = False,
    ):
        self.secret_access_key = secret_access_key
        self.access_key_id = access_key_id
        self.endpoint_url = endpoint_url
        self.strip_bucket = strip_bucket
        self.region_name = region_name
        self.public = public
        self.validate_bucket_name = validate_bucket_name

    @cached_property
    def client(self) -> "S3Client":
        botocore_session = botocore.session.Session()
        session = boto3.session.Session(botocore_session=botocore_session)

        if not self.validate_bucket_name:
            botocore_session.unregister(
                "before-parameter-build.s3", botocore.handlers.validate_bucket_name
            )

        client = session.client(
            "s3",
            aws_access_key_id=self.access_key_id,
            aws_secret_access_key=self.secret_access_key,
            region_name=self.region_name,
            endpoint_url=self.endpoint_url,
            config=Config(signature_version=UNSIGNED) if self.public else None,
        )
        return client
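
    # A minimal usage sketch (assumed endpoint and bucket names, not part of
    # this module): an anonymous, unsigned client for a public bucket.
    #
    #     base = S3Base(public=True, endpoint_url="https://s3.example.com")
    #     base.client.list_objects_v2(Bucket="open-data", Prefix="tiles/")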


class S3Source(Source):
    type = "S3"

    @property
    def client(self) -> "S3Client":
        botocore_session = botocore.session.Session()
        session = boto3.session.Session(botocore_session=botocore_session)

        client = session.client(
            "s3",
            aws_access_key_id=self.parameters["access_key_id"],
            aws_secret_access_key=self.parameters["secret_access_key"],
            region_name=self.parameters["region"],
        )
        return client

    @property
    def bucket(self) -> str:
        # str.strip() removes a set of characters rather than a prefix, so
        # strip("https://") would also eat leading "h", "t", "p" or "s"
        # letters from the bucket name; parse the URL instead.
        bucket = urlparse(self.parameters["url"]).netloc.split(".")[0]
        return bucket

    def harvest(self) -> Iterator[dict]:
        logger.info("Starting S3 harvesting")

        paginator = self.client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=self.bucket, Prefix=self.parameters["prefix"])

        time_regex: str = self.parameters["time_regex"]
        for page in pages:
            # pages without any matching keys carry no "Contents" entry
            for file in page.get("Contents", []):
                if match := re.search(time_regex, file["Key"]):
                    dt = isoparse(match[0])
                    item = self._create_item(file, dt, self.parameters["url"])
                    yield item.to_dict()
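
    # Expected ``parameters`` shape for this source (illustrative values; the
    # regex must match a timestamp that dateutil's isoparse understands):
    #
    #     {
    #         "url": "https://my-bucket.s3.eu-central-1.amazonaws.com/",
    #         "access_key_id": "...",
    #         "secret_access_key": "...",
    #         "region": "eu-central-1",
    #         "prefix": "data/",
    #         "time_regex": r"\d{8}T\d{6}",
    #     }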

    def _create_item(self, data: dict, dt, url: str) -> pystac.Item:
        identifier = dt.strftime("%Y%m%d_%H%M%S")
        properties = {
            "datetime": dt,
            "updated": data["LastModified"],
        }
        item = pystac.Item(
            id=identifier, geometry=None, bbox=None, datetime=dt, properties=properties
        )
        item.add_asset(identifier, pystac.Asset(f"{url}{data['Key']}"))
        return item
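
    # The yielded item serializes roughly as (illustrative values):
    #
    #     {
    #         "type": "Feature",
    #         "id": "20210101_120000",
    #         "geometry": None,
    #         "properties": {"datetime": "2021-01-01T12:00:00Z", ...},
    #         "assets": {"20210101_120000": {"href": "https://...<key>"}},
    #     }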


class S3CatalogSource(S3Base):
    type = "S3Catalog"

    def __init__(self, parameters: dict, **kwargs):
        self.root_href = parameters.pop("root_href")
        self.default_catalog_name = parameters.pop("default_catalog_name", None)
        super().__init__(**parameters)

    def harvest(self) -> Iterator[dict]:
        logger.info("Starting S3 Catalog harvesting")
        parsed = urlparse(self.root_href)
        path = parsed.path
        if path.startswith("/"):
            path = parsed.path[1:]
        if path.endswith("/") and self.default_catalog_name:
            path = join(path, self.default_catalog_name)
        yield from self.harvest_catalog(parsed.netloc, path)

    def fetch_json(self, bucket: str, key: str) -> dict:
        """
        Loads the given object identifier by bucket and key and loads it as
        JSON.
        """
        if key.startswith("/"):
            key = key[1:]
        response = self.client.get_object(Bucket=bucket, Key=key)
        return json.load(response["Body"])

    def join_href(self, bucket: str, key: str, href: str) -> Tuple[str, str]:
        """
        Joins the given href with a previous bucket/key. When we have a fully
        qualified S3 URL, the included bucket/key pair is returned.
        If href is a relative path, it is joined with the previous key.
        """
        parsed = urlparse(href)
        if parsed.netloc:
            if parsed.scheme.lower() != "s3":
                # TODO: what if HTTP hrefs?
                raise ValueError("Can only join S3 URLs")
            return (parsed.netloc, parsed.path)
        else:
            return (
                bucket,
                join(dirname(key), parsed.path),
            )
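
    # Illustrative behavior (hypothetical buckets and keys):
    #
    #     join_href("b", "cat/catalog.json", "s3://other/item.json")
    #         -> ("other", "/item.json")
    #     join_href("b", "cat/catalog.json", "items/a.json")
    #         -> ("b", "cat/items/a.json")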

    def harvest_catalog(self, bucket: str, key: str) -> Iterator[dict]:
        """
        Harvests a specified STAC catalog. Will recurse into child catalogs
        and yield all included STAC items.
        """
        logger.info(f"Harvesting from catalog {bucket}/{key}")
        catalog = self.fetch_json(bucket, key)
        for link in catalog["links"]:
            if link["rel"] == "item":
                item_bucket, item_key = self.join_href(bucket, key, link["href"])
                logger.info(f"Harvested item {item_bucket}/{item_key}")
                yield self.fetch_json(item_bucket, item_key)
            elif link["rel"] == "child":
                cat_bucket, cat_key = self.join_href(bucket, key, link["href"])
                yield from self.harvest_catalog(cat_bucket, cat_key)
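

# A minimal end-to-end sketch (assumed bucket layout and parameter values;
# real deployments supply credentials or point at a public bucket):
#
#     source = S3CatalogSource(
#         parameters={
#             "root_href": "s3://example-bucket/catalog/catalog.json",
#             "default_catalog_name": "catalog.json",
#             "public": True,
#         }
#     )
#     for item in source.harvest():
#         print(item["id"])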