import re
from os.path import normpath, join, isabs
import shutil
from glob import glob
from fnmatch import fnmatch
import logging

import boto3
import boto3.session
import botocore.session
import botocore.handlers

from swiftclient.multithreading import OutputManager
from swiftclient.service import SwiftError, SwiftService


logger = logging.getLogger(__name__)


class RegistrationError(Exception):
    """Raised when no configured source can handle a given path."""


def _as_pattern_list(glob_patterns):
    """Normalize a ``glob_patterns`` argument to a list, or ``None``.

    ``None`` means "no filtering". A single string is wrapped in a list;
    any other iterable of pattern strings (list, tuple, ...) is copied into
    a new list.
    """
    if glob_patterns is None:
        return None
    if isinstance(glob_patterns, str):
        return [glob_patterns]
    return list(glob_patterns)


class Source:
    """Base class for a storage backend that files can be read from.

    Subclasses implement path resolution, listing, retrieval and the
    construction of ``/vsi...`` style paths plus the environment needed to
    access them.
    """

    def __init__(self, name: str = None):
        self.name = name

    def get_container_and_path(self, path):
        """Split *path* into a ``(container, path)`` pair for this backend."""
        raise NotImplementedError

    def list_files(self, path, glob_patterns=None):
        """Return file names under *path*, optionally filtered by glob patterns."""
        raise NotImplementedError

    def get_file(self, path, target_path):
        """Retrieve the file at *path* and store it at the local *target_path*."""
        raise NotImplementedError

    def get_vsi_env_and_path(self, path):
        """Return ``(env, vsi_path)`` for accessing *path* via a ``/vsi...`` path."""
        raise NotImplementedError


class SwiftSource(Source):
    """Source backed by an OpenStack Swift object store.

    If no ``container`` is configured, the first segment of every path is
    interpreted as the container name.
    """

    def __init__(self, name=None, username=None, password=None,
                 tenant_name=None, tenant_id=None, region_name=None,
                 user_domain_id=None, user_domain_name=None, auth_url=None,
                 auth_url_short=None, auth_version=None, container=None):
        super().__init__(name)

        self.username = username
        self.password = password
        self.tenant_name = tenant_name
        self.tenant_id = tenant_id
        self.region_name = region_name
        self.user_domain_id = user_domain_id
        self.user_domain_name = user_domain_name
        self.auth_url = auth_url
        self.auth_url_short = auth_url_short
        self.auth_version = auth_version  # TODO: assume 3
        self.container = container

    def get_service(self):
        """Create a ``SwiftService`` configured with this source's credentials.

        The methods below use the returned service as a context manager.
        """
        return SwiftService(options={
            "os_username": self.username,
            "os_password": self.password,
            "os_tenant_name": self.tenant_name,
            "os_tenant_id": self.tenant_id,
            "os_region_name": self.region_name,
            "os_auth_url": self.auth_url,
            "auth_version": self.auth_version,
            "os_user_domain_id": self.user_domain_id,
            "os_user_domain_name": self.user_domain_name,
        })

    def get_container_and_path(self, path: str):
        """Return ``(container, object_path)`` for *path*.

        With a configured container, *path* is returned unchanged. Otherwise
        the first segment (ignoring one leading ``/``) is the container and
        the remainder is the object path.
        """
        container = self.container
        if not container:
            parts = (path[1:] if path.startswith('/') else path).split('/')
            container, path = parts[0], '/'.join(parts[1:])
        return container, path

    def list_files(self, path, glob_patterns=None):
        """List object names under *path*.

        :param path: prefix to list; may start with the container name when
            no container is configured.
        :param glob_patterns: ``None`` (no filtering), a single pattern, or an
            iterable of patterns. Patterns are relative to *path* and matched
            with :func:`fnmatch.fnmatch`; an object is kept if any matches.
        :returns: object names. They are prefixed with the container name
            when no container is configured, so they can be passed back to
            the other methods.
        :raises: the error reported by Swift for a failed listing page.
        """
        container, path = self.get_container_and_path(path)
        glob_patterns = _as_pattern_list(glob_patterns)

        with self.get_service() as swift:
            pages = swift.list(
                container=container,
                options={"prefix": path},
            )

            # Consume the (lazy) listing while the service is still open.
            filenames = []
            for page in pages:
                if not page["success"]:
                    raise page['error']

                for item in page["listing"]:
                    name = item['name']
                    if glob_patterns is None or any(
                        fnmatch(name, join(path, glob_pattern))
                        for glob_pattern in glob_patterns
                    ):
                        filenames.append(
                            name if self.container else join(container, name)
                        )

        return filenames

    def get_file(self, path, target_path):
        """Download the object at *path* to the local file *target_path*.

        :raises Exception: if Swift reports an unsuccessful download. When an
            underlying exception is reported, it is chained as the cause.
        """
        container, path = self.get_container_and_path(path)
        with self.get_service() as swift:
            results = swift.download(
                container,
                [path],
                options={
                    'out_file': target_path
                }
            )

            # Consume the (lazy) results while the service is still open.
            for result in results:
                if not result["success"]:
                    error = result.get("error")
                    raise Exception('Failed to download %s' % path) from (
                        error if isinstance(error, BaseException) else None
                    )

    def get_vsi_env_and_path(self, path):
        """Return the OpenStack env variables and ``/vsiswift/`` path for *path*."""
        container, path = self.get_container_and_path(path)
        # TODO(review): OS_PROJECT_NAME / OS_PROJECT_DOMAIN_NAME are not set;
        # confirm whether the target deployment requires them.
        return {
            'OS_IDENTITY_API_VERSION': self.auth_version,
            'OS_AUTH_URL': self.auth_url,
            'OS_USERNAME': self.username,
            'OS_PASSWORD': self.password,
            'OS_USER_DOMAIN_NAME': self.user_domain_name,
            'OS_REGION_NAME': self.region_name,
        }, f'/vsiswift/{container}/{path}'


class S3Source(Source):
    """Source backed by an S3 (or S3-compatible) object store.

    If ``bucket_name`` is not configured, the first segment of every path is
    interpreted as the bucket name.
    """

    def __init__(self, name=None, bucket_name=None, secret_access_key=None,
                 access_key_id=None, endpoint_url=None, strip_bucket=True,
                 validate_bucket_name=True, **client_kwargs):
        super().__init__(name)

        # see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html#boto3.session.Session.client
        # for client_kwargs
        self.bucket_name = bucket_name
        self.secret_access_key = secret_access_key
        self.access_key_id = access_key_id
        self.endpoint_url = endpoint_url
        self.strip_bucket = strip_bucket

        botocore_session = botocore.session.Session()
        if not validate_bucket_name:
            # Allow bucket names that botocore's validator would reject.
            botocore_session.unregister(
                'before-parameter-build.s3',
                botocore.handlers.validate_bucket_name,
            )

        session = boto3.session.Session(botocore_session=botocore_session)
        self.client = session.client(
            's3',
            aws_secret_access_key=secret_access_key,
            aws_access_key_id=access_key_id,
            endpoint_url=endpoint_url,
            **client_kwargs,
        )

    def get_container_and_path(self, path: str):
        """Return ``(bucket, key)`` for *path*.

        Without a configured bucket, the first segment (ignoring one leading
        ``/``) is the bucket. With a configured bucket and ``strip_bucket``,
        one leading ``/`` is removed and a leading segment equal to the
        bucket name is dropped. Otherwise *path* is used as the key unchanged.
        """
        bucket = self.bucket_name
        if bucket is None:
            parts = (path[1:] if path.startswith('/') else path).split('/')
            bucket, path = parts[0], '/'.join(parts[1:])
        elif self.strip_bucket:
            parts = (path[1:] if path.startswith('/') else path).split('/')
            if parts[0] == bucket:
                parts.pop(0)
            path = '/'.join(parts)

        return bucket, path

    def list_files(self, path, glob_patterns=None):
        """List object keys under *path* as ``"bucket/key"`` strings.

        All result pages are fetched, so prefixes holding more than 1000
        objects are fully listed. An empty prefix yields ``[]``.

        :param glob_patterns: ``None`` (no filtering), a single pattern, or an
            iterable of patterns. Patterns are relative to the key prefix and
            matched with :func:`fnmatch.fnmatch`.
        """
        glob_patterns = _as_pattern_list(glob_patterns)

        bucket, key = self.get_container_and_path(path)
        logger.info('Listing S3 files for bucket %s and prefix %s', bucket, key)

        paginator = self.client.get_paginator('list_objects_v2')
        # 'Contents' is absent from a page when nothing matches the prefix.
        keys = (
            item['Key']
            for page in paginator.paginate(Bucket=bucket, Prefix=key)
            for item in page.get('Contents', [])
        )

        return [
            f"{bucket}/{object_key}"
            for object_key in keys
            if glob_patterns is None or any(
                fnmatch(object_key, join(key, glob_pattern))
                for glob_pattern in glob_patterns
            )
        ]

    def get_file(self, path, target_path):
        """Download the object at *path* to the local file *target_path*."""
        bucket, key = self.get_container_and_path(path)
        logger.info(
            'Retrieving file from S3 %s/%s to be stored at %s',
            bucket, key, target_path,
        )
        self.client.download_file(bucket, key, target_path)

    def get_vsi_env_and_path(self, path: str, streaming: bool = False):
        """Return the AWS env variables and ``/vsis3/`` path for *path*.

        :param streaming: use the ``/vsis3_streaming/`` prefix instead.
        """
        bucket, key = self.get_container_and_path(path)
        return {
            'AWS_SECRET_ACCESS_KEY': self.secret_access_key,
            'AWS_ACCESS_KEY_ID': self.access_key_id,
            'AWS_S3_ENDPOINT': self.endpoint_url,
        }, f'/{"vsis3" if not streaming else "vsis3_streaming"}/{bucket}/{key}'


class LocalSource(Source):
    """Source backed by a directory on the local file system."""

    def __init__(self, name, root_directory):
        super().__init__(name)

        self.root_directory = root_directory

    def get_container_and_path(self, path):
        """Return ``(root_directory, path)``; *path* is not modified."""
        return (self.root_directory, path)

    def _join_path(self, path):
        """Resolve *path* relative to ``root_directory``.

        Absolute paths are treated as relative to the root. All leading
        separators are stripped, including POSIX's preserved ``//``, so
        ``join`` cannot discard the root.
        """
        path = normpath(path)
        while isabs(path):
            path = path[1:]
        # NOTE(review): a path that normalizes to a leading '..' still
        # resolves outside root_directory; confirm whether callers can pass
        # untrusted paths here.
        return join(self.root_directory, path)

    def list_files(self, path, glob_patterns=None):
        """Return paths under *path* matching any of *glob_patterns*.

        :param glob_patterns: ``None`` (defaults to ``'*'``), a single pattern,
            or an iterable of patterns evaluated with :func:`glob.glob`
            relative to *path*. Matches are concatenated in pattern order,
            with duplicates removed.
        """
        glob_patterns = _as_pattern_list(glob_patterns)
        if glob_patterns is None:
            glob_patterns = ['*']

        base = self._join_path(path)
        matches = []
        for glob_pattern in glob_patterns:
            matches.extend(glob(join(base, glob_pattern)))
        # dict.fromkeys de-duplicates while keeping first-seen order.
        return list(dict.fromkeys(matches))

    def get_file(self, path, target_path):
        """Copy the file at *path* (relative to the root) to *target_path*."""
        shutil.copy(self._join_path(path), target_path)

    def get_vsi_env_and_path(self, path):
        """Return an empty env and the resolved local file path."""
        return {}, self._join_path(path)


SOURCE_TYPES = {
    'swift': SwiftSource,
    's3': S3Source,
    'local': LocalSource,
}


def get_source(config: dict, path: str) -> Source:
    """Instantiate the first configured source that accepts *path*.

    Sources in ``config['sources']`` are tried in order. An entry with a
    truthy ``filter`` is selected only if :func:`re.match` of that pattern
    against *path* succeeds. An entry without a filter always matches. The
    selected class is looked up in ``SOURCE_TYPES`` by its ``type`` and
    constructed with ``name`` plus optional ``args``/``kwargs``.

    :raises RegistrationError: if no source entry matches *path*.
    """
    for cfg_source in config['sources']:
        source_filter = cfg_source.get('filter')
        if not source_filter or re.match(source_filter, path):
            break
    else:
        # no source found
        raise RegistrationError(
            f'Could not find a suitable source for the path {path}'
        )

    return SOURCE_TYPES[cfg_source['type']](
        cfg_source['name'],
        *cfg_source.get('args', []),
        **cfg_source.get('kwargs', {})
    )