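# source.py (7.50 KiB), copied from the EOX GitLab instance web view.
# Page navigation text from the export ("Skip to content",
# "Snippets Groups Projects") is kept here as a comment only.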
import re
from os.path import normpath, join, isabs
import shutil
from glob import glob
from fnmatch import fnmatch
import logging

import boto3
from swiftclient.multithreading import OutputManager
from swiftclient.service import SwiftError, SwiftService


# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

class RegistrationError(Exception):
    """Error raised by this module when no source can be selected for a path."""


class Source:
    """Common interface of all storage sources.

    Subclasses have to override every method below; the implementations
    here only raise ``NotImplementedError``.
    """

    def __init__(self, name: str = None):
        """Remember the (optional) name of this source."""
        self.name = name

    def list_files(self, path, glob_pattern=None):
        """List the files found under ``path``, optionally filtered by
        ``glob_pattern``. Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def get_file(self, path, target_path):
        """Fetch the file at ``path`` and store it at ``target_path``.
        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def get_vsi_env_and_path(self, path):
        """Return the environment settings and the path needed to access
        ``path`` directly. Must be implemented by subclasses.
        """
        raise NotImplementedError()


class SwiftSource(Source):
    """Source that accesses objects in a Swift object store via ``SwiftService``.

    If no ``container`` is configured, the first component of every path
    handed to this source is interpreted as the container name.
    """

    def __init__(self, name=None, username=None, password=None, tenant_name=None,
                 tenant_id=None, region_name=None, user_domain_id=None,
                 user_domain_name=None, auth_url=None, auth_version=None,
                 container=None):
        """Store the connection settings. No connection is opened here.

        All credential/endpoint values are later forwarded unchanged as
        ``SwiftService`` options (see ``get_service``).

        :param name: name of the source
        :param container: fixed container to use for all paths; when
            ``None`` the container is derived from each path
        """
        super().__init__(name)

        self.username = username
        self.password = password
        self.tenant_name = tenant_name
        self.tenant_id = tenant_id
        self.region_name = region_name
        self.user_domain_id = user_domain_id
        self.user_domain_name = user_domain_name
        self.auth_url = auth_url
        self.auth_version = auth_version  # TODO: assume 3
        self.container = container

    def get_service(self):
        """Create a new ``SwiftService`` configured with the stored settings."""
        return SwiftService(options={
            "os_username": self.username,
            "os_password": self.password,
            "os_tenant_name": self.tenant_name,
            "os_tenant_id": self.tenant_id,
            "os_region_name": self.region_name,
            "os_auth_url": self.auth_url,
            "auth_version": self.auth_version,
            "os_user_domain_id": self.user_domain_id,
            "os_user_domain_name": self.user_domain_name,
        })

    def get_container_and_path(self, path: str):
        """Split ``path`` into a ``(container, object path)`` tuple.

        With a configured container the path is returned untouched.
        Otherwise a single leading ``/`` is dropped and the first path
        component is used as the container.
        """
        if self.container is not None:
            return self.container, path

        relative = path[1:] if path.startswith('/') else path
        container, _, remainder = relative.partition('/')
        return container, remainder

    def list_files(self, path, glob_pattern=None):
        """Return the names of all objects whose name starts with ``path``.

        :param path: prefix to list (including the container when none is
            configured on the source)
        :param glob_pattern: optional ``fnmatch`` pattern the object name
            has to match
        :raises: the error reported by the service for a failed page
        """
        # NOTE(review): the names are returned exactly as listed by the
        # service, without a container prefix. When no container is
        # configured they presumably cannot be passed back to get_file()
        # unchanged -- confirm against the callers.
        container, prefix = self.get_container_and_path(path)

        filenames = []
        with self.get_service() as swift:
            pages = swift.list(
                container=container,
                options={"prefix": prefix},
            )

            for page in pages:
                if not page["success"]:
                    raise page['error']

                for item in page["listing"]:
                    if glob_pattern is None or fnmatch(item['name'], glob_pattern):
                        filenames.append(item['name'])

        return filenames

    def get_file(self, path, target_path):
        """Download the object at ``path`` to the local file ``target_path``.

        :raises Exception: when the download did not succeed; the error
            reported by the service (if any) is attached as the cause
        """
        container, path = self.get_container_and_path(path)

        with self.get_service() as swift:
            results = swift.download(
                container,
                [path],
                options={
                    'out_file': target_path
                }
            )

            # the results have to be iterated for the download to be checked
            for result in results:
                if not result["success"]:
                    # keep the underlying error instead of discarding it;
                    # only a real exception may be used as a cause
                    cause = result.get('error')
                    if not isinstance(cause, BaseException):
                        cause = None
                    raise Exception('Failed to download %s' % path) from cause

    def get_vsi_env_and_path(self, path):
        """Return ``(environment dict, '/vsiswift/<container>/<path>')``.

        The environment contains the ``OS_*`` settings derived from the
        stored credentials.
        """
        container, path = self.get_container_and_path(path)
        env = {
            'OS_IDENTITY_API_VERSION': self.auth_version,
            'OS_AUTH_URL': self.auth_url,
            'OS_USERNAME': self.username,
            'OS_PASSWORD': self.password,
            'OS_USER_DOMAIN_NAME': self.user_domain_name,
            # NOTE: OS_PROJECT_NAME and OS_PROJECT_DOMAIN_NAME are not
            # exported (yet)
            'OS_REGION_NAME': self.region_name,
        }
        return env, f'/vsiswift/{container}/{path}'


class S3Source(Source):
    """Source that accesses objects in an S3 (compatible) object store.

    If no ``bucket_name`` is configured, the first component of every path
    handed to this source is interpreted as the bucket name.
    """

    def __init__(self, name=None, bucket_name=None, secret_access_key=None,
                 access_key_id=None, endpoint_url=None, strip_bucket=True,
                 **client_kwargs):
        """Store the settings and create the boto3 S3 client.

        :param name: name of the source
        :param bucket_name: fixed bucket to use for all paths; when ``None``
            the bucket is derived from each path
        :param secret_access_key: secret key passed to the client
        :param access_key_id: access key ID passed to the client
        :param endpoint_url: endpoint URL passed to the client
        :param strip_bucket: with a fixed bucket, whether a leading bucket
            name is removed from paths to obtain the object key
        :param client_kwargs: additional keyword arguments for
            ``boto3.client``
        """
        super().__init__(name)

        # see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html#boto3.session.Session.client
        # for client_kwargs
        self.bucket_name = bucket_name
        self.secret_access_key = secret_access_key
        self.access_key_id = access_key_id
        self.endpoint_url = endpoint_url
        self.strip_bucket = strip_bucket

        self.client = boto3.client(
            's3',
            aws_secret_access_key=secret_access_key,
            aws_access_key_id=access_key_id,
            endpoint_url=endpoint_url,
            **client_kwargs,
        )

    def get_bucket_and_key(self, path: str):
        """Split ``path`` into a ``(bucket, key)`` tuple.

        Without a configured bucket, a single leading ``/`` is dropped and
        the first path component is the bucket. With a configured bucket
        and ``strip_bucket`` enabled, a single leading ``/`` and a leading
        bucket name component are removed from the path. Otherwise the path
        is used as the key unchanged.
        """
        bucket = self.bucket_name
        if bucket is None:
            parts = (path[1:] if path.startswith('/') else path).split('/')
            bucket, path = parts[0], '/'.join(parts[1:])
        elif self.strip_bucket:
            parts = (path[1:] if path.startswith('/') else path).split('/')
            if parts[0] == bucket:
                parts.pop(0)
            path = '/'.join(parts)

        return bucket, path

    def list_files(self, path, glob_pattern=None):
        """Return ``'<bucket>/<key>'`` for every object under the prefix.

        :param path: prefix to list, resolved via ``get_bucket_and_key``
        :param glob_pattern: optional ``fnmatch`` pattern the object key
            has to match
        :returns: list of matching files, empty when nothing was found
        """
        bucket, key = self.get_bucket_and_key(path)
        logger.info(f'Listing S3 files for bucket {bucket} and prefix {key}')

        # A single list_objects_v2 call returns a limited number of keys,
        # so go through all result pages instead of just the first one.
        paginator = self.client.get_paginator('list_objects_v2')
        pages = paginator.paginate(
            Bucket=bucket,
            Prefix=key,
        )

        filenames = []
        for page in pages:
            # 'Contents' is missing from the response when nothing matched
            for item in page.get('Contents', []):
                if glob_pattern is None or fnmatch(item['Key'], glob_pattern):
                    filenames.append(f"{bucket}/{item['Key']}")

        return filenames

    def get_file(self, path, target_path):
        """Download the object at ``path`` to the local file ``target_path``."""
        bucket, key = self.get_bucket_and_key(path)
        logger.info(f'Retrieving file from S3 {bucket}/{key} to be stored at {target_path}')
        self.client.download_file(bucket, key, target_path)

    def get_vsi_env_and_path(self, path: str, streaming: bool=False):
        """Return ``(environment dict, VSI path)`` for ``path``.

        :param streaming: use the ``/vsis3_streaming/`` prefix instead of
            ``/vsis3/``
        """
        bucket, key = self.get_bucket_and_key(path)
        env = {
            'AWS_SECRET_ACCESS_KEY': self.secret_access_key,
            'AWS_ACCESS_KEY_ID': self.access_key_id,
            'AWS_S3_ENDPOINT': self.endpoint_url,
        }
        handler = 'vsis3_streaming' if streaming else 'vsis3'
        return env, f'/{handler}/{bucket}/{key}'


class LocalSource(Source):
    """Source that accesses files below a directory of the local file system."""

    def __init__(self, name, root_directory):
        """
        :param name: name of the source
        :param root_directory: directory all paths are resolved against
        """
        super().__init__(name)

        self.root_directory = root_directory

    def _join_path(self, path):
        """Resolve ``path`` against the root directory.

        The path is normalized and made relative before joining, so that
        an absolute path is interpreted relative to the root directory.
        """
        path = normpath(path)
        # normpath() may keep more than one leading separator (e.g. '//x'
        # on POSIX); strip until the path is relative, as join() would
        # otherwise discard the root directory altogether.
        while isabs(path):
            path = path[1:]

        return join(self.root_directory, path)

    def list_files(self, path, glob_pattern=None):
        """Return the entries directly inside ``path`` that match
        ``glob_pattern`` (all entries matching ``*`` when no pattern is
        given).
        """
        pattern = glob_pattern if glob_pattern is not None else '*'
        return glob(join(self._join_path(path), pattern))

    def get_file(self, path, target_path):
        """Copy the file at ``path`` to ``target_path``."""
        shutil.copy(self._join_path(path), target_path)

    def get_vsi_env_and_path(self, path):
        """Return an empty environment and the resolved local path."""
        return {}, self._join_path(path)


# Maps the ``type`` value of a source configuration to the class implementing it.
SOURCE_TYPES = dict(
    swift=SwiftSource,
    s3=S3Source,
    local=LocalSource,
)


def get_source(config: dict, path: str) -> Source:
    """Instantiate the first configured source that is applicable to ``path``.

    The entries of ``config['sources']`` are checked in order. An entry is
    applicable when it has no filter (missing, ``None`` or empty) or when
    its ``filter`` regular expression matches the beginning of ``path``.

    :param config: configuration with a ``sources`` list; each entry has a
        ``type`` (key of ``SOURCE_TYPES``), a ``name`` and optionally
        ``filter``, ``args`` and ``kwargs``
    :param path: the path to find a source for
    :returns: a new source object, created with the name followed by the
        configured ``args`` and ``kwargs``
    :raises RegistrationError: when no configured source is applicable
    """
    for cfg_source in config['sources']:
        # a missing filter is treated like an empty one: accept any path
        pattern = cfg_source.get('filter')
        if not pattern or re.match(pattern, path):
            break
    else:
        # no source found
        raise RegistrationError(f'Could not find a suitable source for the path {path}')

    source_class = SOURCE_TYPES[cfg_source['type']]
    return source_class(
        cfg_source['name'],
        *cfg_source.get('args', []),
        **cfg_source.get('kwargs', {})
    )