EOX GitLab Instance

Skip to content
Snippets Groups Projects
source.py 8.77 KiB
Newer Older
import re
from os.path import normpath, join, isabs
import shutil
from glob import glob
from fnmatch import fnmatch
import logging
import boto3.session
import botocore.session
import botocore.handlers

from swiftclient.multithreading import OutputManager
from swiftclient.service import SwiftError, SwiftService


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class RegistrationError(Exception):
    """Error raised when a registration step cannot proceed.

    In this module it is raised by ``get_source`` when no configured source
    accepts the requested path.
    """


class Source:
    """Base class describing the interface every storage source provides.

    The base implementation only stores the source name; every other method
    raises ``NotImplementedError`` and must be overridden by a subclass.
    """

    def __init__(self, name: str = None):
        """Remember the (optional) name of this source."""
        self.name = name

    def get_container_and_path(self, path):
        """Resolve ``path`` into a ``(container, path)`` pair.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def list_files(self, path, glob_pattern=None):
        """List the files found below ``path``, optionally filtered by a glob.

        Must be implemented by subclasses. Note that the subclasses in this
        module name the filter parameter ``glob_patterns``.
        """
        raise NotImplementedError

    def get_file(self, path, target_path):
        """Copy the file at ``path`` to the local ``target_path``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def get_vsi_env_and_path(self, path):
        """Return an ``(environment dict, path)`` pair for accessing ``path``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError


class SwiftSource(Source):
    """Source backed by an OpenStack Swift object storage.

    If ``container`` is given, all paths are interpreted relative to that
    container. Otherwise the first segment of each path is taken as the
    container name.
    """

    def __init__(self, name=None, username=None, password=None, tenant_name=None,
                 tenant_id=None, region_name=None, user_domain_id=None,
                 user_domain_name=None, auth_url=None, auth_url_short=None,
                 auth_version=None, container=None):
        """Store the Swift credentials and the optional fixed container.

        ``auth_url_short`` is stored but not used by the methods of this class.
        """
        super().__init__(name)

        self.username = username
        self.password = password
        self.tenant_name = tenant_name
        self.tenant_id = tenant_id
        self.region_name = region_name
        self.user_domain_id = user_domain_id
        self.user_domain_name = user_domain_name
        self.auth_url = auth_url
        self.auth_url_short = auth_url_short
        self.auth_version = auth_version  # TODO: assume 3
        self.container = container

    def get_service(self):
        """Create a ``SwiftService`` configured with the stored credentials."""
        return SwiftService(options={
            "os_username": self.username,
            "os_password": self.password,
            "os_tenant_name": self.tenant_name,
            "os_tenant_id": self.tenant_id,
            "os_region_name": self.region_name,
            "os_auth_url": self.auth_url,
            "auth_version": self.auth_version,
            "os_user_domain_id": self.user_domain_id,
            "os_user_domain_name": self.user_domain_name,
        })

    def get_container_and_path(self, path: str):
        """Return the ``(container, path)`` pair addressed by ``path``.

        With a configured container, ``path`` is returned unchanged. Without
        one, a single leading ``/`` is dropped and the first path segment is
        used as the container; the remaining segments form the object path.
        """
        container = self.container
        if container is None:
            parts = (path[1:] if path.startswith('/') else path).split('/')
            container, path = parts[0], '/'.join(parts[1:])

        return container, path

    def list_files(self, path, glob_patterns=None):
        """List object names below the prefix ``path``.

        ``glob_patterns`` may be a single pattern or a list of patterns. Each
        pattern is joined onto the prefix and matched with ``fnmatch``; an
        object is kept when any pattern matches. Without patterns every object
        is returned. Names are prefixed with the container when no fixed
        container is configured.

        Raises the error reported by swift for any unsuccessful listing page.
        """
        container, path = self.get_container_and_path(path)

        if glob_patterns and not isinstance(glob_patterns, list):
            glob_patterns = [glob_patterns]

        with self.get_service() as swift:
            pages = swift.list(
                container=container,
                options={"prefix": path},
            )

            filenames = []
            for page in pages:
                if not page["success"]:
                    raise page['error']

                for item in page["listing"]:
                    if glob_patterns is None or any(
                            fnmatch(item['name'], join(path, glob_pattern))
                            for glob_pattern in glob_patterns):
                        # only qualify with the container when it was derived
                        # from the path rather than configured on the source
                        filenames.append(
                            item['name'] if self.container else join(container, item['name'])
                        )

            return filenames

    def get_file(self, path, target_path):
        """Download the object at ``path`` to the local file ``target_path``.

        Raises ``Exception`` when swift reports an unsuccessful download.
        """
        container, path = self.get_container_and_path(path)

        with self.get_service() as swift:
            results = swift.download(
                container,
                [path],
                options={
                    'out_file': target_path
                }
            )

            for result in results:
                if not result["success"]:
                    raise Exception('Failed to download %s' % path)

    def get_vsi_env_and_path(self, path):
        """Return the environment variables and ``/vsiswift/`` path for ``path``."""
        container, path = self.get_container_and_path(path)
        return {
            'OS_IDENTITY_API_VERSION': self.auth_version,
            'OS_AUTH_URL': self.auth_url,
            'OS_USERNAME': self.username,
            'OS_PASSWORD': self.password,
            'OS_USER_DOMAIN_NAME': self.user_domain_name,
            # NOTE: OS_PROJECT_NAME and OS_PROJECT_DOMAIN_NAME are not set here
            'OS_REGION_NAME': self.region_name,
        }, f'/vsiswift/{container}/{path}'


class S3Source(Source):
    """Source backed by an S3 (or S3-compatible) object storage.

    If ``bucket_name`` is given, all paths are resolved in that bucket.
    Otherwise the first segment of each path is taken as the bucket name.
    """

    def __init__(self, name=None, bucket_name=None, secret_access_key=None, access_key_id=None, endpoint_url=None,
                 strip_bucket=True, validate_bucket_name=True, **client_kwargs):
        """Store the connection settings and create the boto3 S3 client.

        ``strip_bucket`` controls whether a leading bucket segment is removed
        from paths when ``bucket_name`` is configured. Passing
        ``validate_bucket_name=False`` unregisters botocore's bucket name
        validation handler. Any further keyword arguments are passed on to
        ``session.client``.
        """
        super().__init__(name)

        # see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html#boto3.session.Session.client
        # for client_kwargs
        self.bucket_name = bucket_name
        self.secret_access_key = secret_access_key
        self.access_key_id = access_key_id
        self.endpoint_url = endpoint_url
        self.strip_bucket = strip_bucket
        botocore_session = botocore.session.Session()
        if not validate_bucket_name:
            botocore_session.unregister('before-parameter-build.s3', botocore.handlers.validate_bucket_name)

        session = boto3.session.Session(botocore_session=botocore_session)

        self.client = session.client(
            's3',
            aws_secret_access_key=secret_access_key,
            aws_access_key_id=access_key_id,
            endpoint_url=endpoint_url,
            **client_kwargs,
        )

    def get_container_and_path(self, path: str):
        """Return the ``(bucket, key)`` pair addressed by ``path``.

        Without a configured bucket, a single leading ``/`` is dropped and the
        first path segment is used as the bucket. With a configured bucket and
        ``strip_bucket`` enabled, a single leading ``/`` is dropped and the
        first segment is removed if it equals the bucket name. Otherwise
        ``path`` is returned unchanged.
        """
        bucket = self.bucket_name
        if bucket is None:
            parts = (path[1:] if path.startswith('/') else path).split('/')
            bucket, path = parts[0], '/'.join(parts[1:])
        elif self.strip_bucket:
            parts = (path[1:] if path.startswith('/') else path).split('/')
            if parts[0] == bucket:
                parts.pop(0)
            path = '/'.join(parts)
        return bucket, path

    def list_files(self, path, glob_patterns=None):
        """List objects below the prefix ``path`` as ``bucket/key`` strings.

        ``glob_patterns`` may be a single pattern or a list of patterns. Each
        pattern is joined onto the key prefix and matched with ``fnmatch``; an
        object is kept when any pattern matches. Without patterns every listed
        object is returned. An empty list is returned when nothing is found.
        """
        if glob_patterns and not isinstance(glob_patterns, list):
            glob_patterns = [glob_patterns]

        bucket, key = self.get_container_and_path(path)
        logger.info(f'Listing S3 files for bucket {bucket} and prefix {key}')
        # NOTE: only a single list_objects_v2 response is read; S3 returns at
        # most 1000 keys per call, so longer listings are truncated here.
        response = self.client.list_objects_v2(
            Bucket=bucket,
            Prefix=key,
        )

        return [
            f"{bucket}/{item['Key']}"
            # 'Contents' is absent from the response when no key matches
            for item in response.get('Contents', [])
            if glob_patterns is None or any(
                fnmatch(item['Key'], join(key, glob_pattern))
                for glob_pattern in glob_patterns
            )
        ]

    def get_file(self, path, target_path):
        """Download the object at ``path`` to the local file ``target_path``."""
        bucket, key = self.get_container_and_path(path)
        logger.info(f'Retrieving file from S3 {bucket}/{key} to be stored at {target_path}')
        self.client.download_file(bucket, key, target_path)

    def get_vsi_env_and_path(self, path: str, streaming: bool=False):
        """Return the environment variables and VSI path for ``path``.

        The path uses the ``/vsis3_streaming/`` prefix when ``streaming`` is
        true and ``/vsis3/`` otherwise.
        """
        bucket, key = self.get_container_and_path(path)
        return {
            'AWS_SECRET_ACCESS_KEY': self.secret_access_key,
            'AWS_ACCESS_KEY_ID': self.access_key_id,
            'AWS_S3_ENDPOINT': self.endpoint_url,
        }, f'/{"vsis3" if not streaming else "vsis3_streaming"}/{bucket}/{key}'


class LocalSource(Source):
    """Source backed by a directory on the local file system."""

    def __init__(self, name, root_directory):
        """Store the source name and the directory all paths are relative to."""
        super().__init__(name)

        self.root_directory = root_directory

    def get_container_and_path(self, path):
        """Return ``(root_directory, path)``; ``path`` is passed through as is."""
        return (self.root_directory, path)

    def _join_path(self, path):
        """Normalize ``path`` and place it below ``root_directory``."""
        path = normpath(path)
        # ``normpath`` keeps a leading '//' on POSIX, so a single strip is not
        # enough: ``join`` would discard the root for a still-absolute path.
        while isabs(path):
            path = path[1:]

        return join(self.root_directory, path)

    def list_files(self, path, glob_patterns=None):
        """List the files below ``path`` that match any of ``glob_patterns``.

        ``glob_patterns`` may be a single pattern or a list of patterns; when
        it is ``None`` everything directly below ``path`` (``*``) is listed.
        Files matched by several patterns are returned once, in the order they
        were first found. An empty pattern list yields no files.
        """
        if glob_patterns and not isinstance(glob_patterns, list):
            glob_patterns = [glob_patterns]

        if glob_patterns is None:
            glob_patterns = ['*']

        base_path = self._join_path(path)
        # dict.fromkeys de-duplicates while keeping first-seen order
        return list(dict.fromkeys(
            filename
            for glob_pattern in glob_patterns
            for filename in glob(join(base_path, glob_pattern))
        ))

    def get_file(self, path, target_path):
        """Copy the file at ``path`` (below the root) to ``target_path``."""
        shutil.copy(self._join_path(path), target_path)

    def get_vsi_env_and_path(self, path):
        """Return an empty environment and the local file system path."""
        return {}, self._join_path(path)


# Registry mapping the 'type' value of a source configuration entry to the
# Source subclass implementing it; used by get_source to instantiate sources.
SOURCE_TYPES = {
    'swift': SwiftSource,
    's3': S3Source,
    'local': LocalSource,
}


def get_source(config: dict, path: str) -> Source:
    """Instantiate the first configured source that accepts ``path``.

    The entries of ``config['sources']`` are checked in order. An entry
    without a (truthy) ``filter`` accepts every path; otherwise its ``filter``
    is a regular expression that must match at the start of ``path``.

    The selected entry's ``type`` picks the class from ``SOURCE_TYPES``, which
    is called with the entry's ``name`` followed by its optional ``args`` and
    ``kwargs``.

    Raises ``RegistrationError`` when no entry accepts the path.
    """
    def accepts(candidate):
        pattern = candidate.get('filter')
        return not pattern or re.match(pattern, path)

    selected = next(
        (candidate for candidate in config['sources'] if accepts(candidate)),
        None,
    )
    if selected is None:
        # no source found
        raise RegistrationError(f'Could not find a suitable source for the path {path}')

    source_class = SOURCE_TYPES[selected['type']]
    return source_class(
        selected['name'],
        *selected.get('args', []),
        **selected.get('kwargs', {})
    )