"""Storage source abstractions (OpenStack Swift, S3 and local filesystem).

Each source can list files, fetch a file to a local path, and produce an
environment/path pair for "/vsi..."-style access (presumably consumed by
GDAL's virtual file systems -- confirm against callers).
"""

import re
from os.path import normpath, join, isabs
import shutil
from glob import glob
from fnmatch import fnmatch
import logging

import boto3
from swiftclient.multithreading import OutputManager
from swiftclient.service import SwiftError, SwiftService


logger = logging.getLogger(__name__)


class RegistrationError(Exception):
    """Raised when no configured source is suitable for a path."""


def _strip_leading_slash(path: str) -> str:
    """Return *path* without a single leading ``/`` (if there is one)."""
    return path[1:] if path.startswith('/') else path


class Source:
    """Base class of all sources; subclasses implement the three methods."""

    def __init__(self, name: str = None):
        self.name = name

    def list_files(self, path, glob_pattern=None):
        """List the files below *path*, optionally filtered by a glob."""
        raise NotImplementedError

    def get_file(self, path, target_path):
        """Copy/download the file at *path* to the local *target_path*."""
        raise NotImplementedError

    def get_vsi_env_and_path(self, path):
        """Return a ``(env_dict, vsi_path)`` tuple for *path*."""
        raise NotImplementedError


class SwiftSource(Source):
    """Source backed by an OpenStack Swift object storage.

    When ``container`` is not configured, the first segment of every path
    passed to the methods is interpreted as the container name.
    """

    def __init__(self, name=None, username=None, password=None,
                 tenant_name=None, tenant_id=None, region_name=None,
                 user_domain_id=None, user_domain_name=None, auth_url=None,
                 auth_version=None, container=None):
        super().__init__(name)

        self.username = username
        self.password = password
        self.tenant_name = tenant_name
        self.tenant_id = tenant_id
        self.region_name = region_name
        self.user_domain_id = user_domain_id
        self.user_domain_name = user_domain_name
        self.auth_url = auth_url
        self.auth_version = auth_version  # TODO: assume 3
        self.container = container

    def get_service(self):
        """Create a new ``SwiftService`` using the stored credentials."""
        return SwiftService(options={
            "os_username": self.username,
            "os_password": self.password,
            "os_tenant_name": self.tenant_name,
            "os_tenant_id": self.tenant_id,
            "os_region_name": self.region_name,
            "os_auth_url": self.auth_url,
            "auth_version": self.auth_version,
            "os_user_domain_id": self.user_domain_id,
            "os_user_domain_name": self.user_domain_name,
        })

    def get_container_and_path(self, path: str):
        """Split *path* into ``(container, object_path)``.

        With a configured container the path is returned untouched.
        Otherwise one leading slash is dropped and the first path segment
        is used as the container name.
        """
        container = self.container
        if container is None:
            container, _, path = _strip_leading_slash(path).partition('/')

        return container, path

    def list_files(self, path, glob_pattern=None):
        """Return the names of all objects whose name starts with *path*.

        Names are filtered with ``fnmatch`` when *glob_pattern* is given.
        The error reported by swift is raised for any failed listing page.
        """
        container, path = self.get_container_and_path(path)

        with self.get_service() as swift:
            pages = swift.list(
                container=container,
                options={"prefix": path},
            )

            filenames = []
            for page in pages:
                if not page["success"]:
                    raise page['error']

                for item in page["listing"]:
                    name = item['name']
                    if glob_pattern is None or fnmatch(name, glob_pattern):
                        filenames.append(name)

        return filenames

    def get_file(self, path, target_path):
        """Download the object at *path* to the local file *target_path*.

        Raises:
            Exception: when swift reports a failed download. The error
                reported by swift (if any) is attached as ``__cause__``.
        """
        container, path = self.get_container_and_path(path)

        with self.get_service() as swift:
            results = swift.download(
                container,
                [path],
                options={
                    'out_file': target_path
                }
            )

            for result in results:
                if not result["success"]:
                    # keep the underlying swift error for diagnostics
                    error = result.get('error')
                    cause = error if isinstance(error, BaseException) else None
                    raise Exception('Failed to download %s' % path) from cause

    def get_vsi_env_and_path(self, path):
        """Return the ``OS_*`` environment and the ``/vsiswift/`` path."""
        container, path = self.get_container_and_path(path)
        # TODO: OS_PROJECT_NAME and OS_PROJECT_DOMAIN_NAME are not exported
        return {
            'OS_IDENTITY_API_VERSION': self.auth_version,
            'OS_AUTH_URL': self.auth_url,
            'OS_USERNAME': self.username,
            'OS_PASSWORD': self.password,
            'OS_USER_DOMAIN_NAME': self.user_domain_name,
            'OS_REGION_NAME': self.region_name,
        }, f'/vsiswift/{container}/{path}'


class S3Source(Source):
    """Source backed by an S3 (compatible) object storage.

    Additional keyword arguments are passed on to ``boto3.client``, see
    https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html#boto3.session.Session.client
    """

    def __init__(self, name=None, bucket_name=None, secret_access_key=None,
                 access_key_id=None, endpoint_url=None, strip_bucket=True,
                 **client_kwargs):
        super().__init__(name)

        self.bucket_name = bucket_name
        self.secret_access_key = secret_access_key
        self.access_key_id = access_key_id
        self.endpoint_url = endpoint_url
        self.strip_bucket = strip_bucket

        self.client = boto3.client(
            's3',
            aws_secret_access_key=secret_access_key,
            aws_access_key_id=access_key_id,
            endpoint_url=endpoint_url,
            **client_kwargs,
        )

    def get_bucket_and_key(self, path: str):
        """Split *path* into ``(bucket, key)``.

        * No bucket configured: one leading slash is dropped and the first
          path segment is used as the bucket.
        * Bucket configured and ``strip_bucket`` set: one leading slash is
          dropped, as is a first segment that equals the bucket name.
        * Otherwise the path is used as the key untouched.
        """
        bucket = self.bucket_name
        if bucket is None:
            bucket, _, path = _strip_leading_slash(path).partition('/')
        elif self.strip_bucket:
            stripped = _strip_leading_slash(path)
            first, _, rest = stripped.partition('/')
            path = rest if first == bucket else stripped

        return bucket, path

    def list_files(self, path, glob_pattern=None):
        """Return ``"<bucket>/<key>"`` for all objects below the prefix.

        Keys are filtered with ``fnmatch`` when *glob_pattern* is given.
        All result pages are fetched; a prefix without any matching object
        yields an empty list.
        """
        bucket, key = self.get_bucket_and_key(path)
        logger.info(
            'Listing S3 files for bucket %s and prefix %s', bucket, key
        )

        filenames = []
        request = {'Bucket': bucket, 'Prefix': key}
        while True:
            response = self.client.list_objects_v2(**request)

            # 'Contents' is absent from the response when nothing matched
            for item in response.get('Contents', []):
                if glob_pattern is None or fnmatch(item['Key'], glob_pattern):
                    filenames.append(f"{bucket}/{item['Key']}")

            # a single response is limited in size, follow the continuation
            if not response.get('IsTruncated'):
                break
            request['ContinuationToken'] = response['NextContinuationToken']

        return filenames

    def get_file(self, path, target_path):
        """Download the object at *path* to the local file *target_path*."""
        bucket, key = self.get_bucket_and_key(path)
        logger.info(
            'Retrieving file from S3 %s/%s to be stored at %s',
            bucket, key, target_path,
        )
        self.client.download_file(bucket, key, target_path)

    def get_vsi_env_and_path(self, path: str, streaming: bool = False):
        """Return the ``AWS_*`` environment and the ``/vsis3/`` path.

        With *streaming* set, the ``/vsis3_streaming/`` prefix is used.
        """
        bucket, key = self.get_bucket_and_key(path)
        prefix = 'vsis3_streaming' if streaming else 'vsis3'
        return {
            'AWS_SECRET_ACCESS_KEY': self.secret_access_key,
            'AWS_ACCESS_KEY_ID': self.access_key_id,
            'AWS_S3_ENDPOINT': self.endpoint_url,
        }, f'/{prefix}/{bucket}/{key}'


class LocalSource(Source):
    """Source backed by a directory of the local filesystem."""

    def __init__(self, name, root_directory):
        super().__init__(name)

        self.root_directory = root_directory

    def _join_path(self, path):
        """Return the normalized *path* joined onto the root directory."""
        path = normpath(path)
        # ``join`` discards the root for absolute paths, so make the path
        # relative first. A loop is needed because ``normpath`` may keep
        # two leading slashes.
        while isabs(path):
            path = path[1:]

        return join(self.root_directory, path)

    def list_files(self, path, glob_pattern=None):
        """Return the entries of the directory *path* matching the glob.

        Without *glob_pattern*, ``*`` is used.
        """
        pattern = '*' if glob_pattern is None else glob_pattern
        return glob(join(self._join_path(path), pattern))

    def get_file(self, path, target_path):
        """Copy the file at *path* to *target_path*."""
        shutil.copy(self._join_path(path), target_path)

    def get_vsi_env_and_path(self, path):
        """Return an empty environment and the joined local path."""
        return {}, self._join_path(path)


# maps the ``type`` value of a source configuration to its implementation
SOURCE_TYPES = {
    'swift': SwiftSource,
    's3': S3Source,
    'local': LocalSource,
}


def get_source(config: dict, path: str) -> Source:
    """Instantiate the first configured source that is suitable for *path*.

    The entries of ``config['sources']`` are checked in order. An entry is
    suitable when it has no (or an empty) ``filter`` or when its ``filter``
    regular expression matches the start of *path*. The source is created
    from the entry's ``type``, ``name`` and optional ``args``/``kwargs``.

    Raises:
        RegistrationError: when no entry is suitable for *path*.
    """
    for cfg_source in config['sources']:
        pattern = cfg_source.get('filter')
        # a source without filter is a catch-all
        if not pattern or re.match(pattern, path):
            break
    else:
        # no source found
        raise RegistrationError(
            f'Could not find a suitable source for the path {path}'
        )

    source_class = SOURCE_TYPES[cfg_source['type']]
    return source_class(
        cfg_source['name'],
        *cfg_source.get('args', []),
        **cfg_source.get('kwargs', {})
    )