# Something went wrong on our end
# -
# Fabian Schindler authored
# source.py 7.50 KiB
import re
from os.path import normpath, join, isabs
import shutil
from glob import glob
from fnmatch import fnmatch
import logging
import boto3
from swiftclient.multithreading import OutputManager
from swiftclient.service import SwiftError, SwiftService
# Module-level logger named after this module; used by the source classes
# below to report remote storage operations.
logger = logging.getLogger(__name__)
class RegistrationError(Exception):
    """Error raised when a registration step cannot be carried out.

    Within this module it is raised by ``get_source`` when none of the
    configured sources is applicable to the requested path.
    """
class Source:
    """Base class describing the interface of a file source.

    Concrete sources override all three access methods; the implementations
    here only raise ``NotImplementedError``.
    """

    def __init__(self, name: str = None):
        """Store the (optional) name identifying this source."""
        self.name = name

    def list_files(self, path, glob_pattern=None):
        """List the files below ``path``, optionally filtered by a glob.

        Raises:
            NotImplementedError: always, must be overridden.
        """
        raise NotImplementedError()

    def get_file(self, path, target_path):
        """Fetch the file at ``path`` and store it at ``target_path``.

        Raises:
            NotImplementedError: always, must be overridden.
        """
        raise NotImplementedError()

    def get_vsi_env_and_path(self, path):
        """Return an environment mapping and a VSI path for ``path``.

        Raises:
            NotImplementedError: always, must be overridden.
        """
        raise NotImplementedError()
class SwiftSource(Source):
    """Source reading objects from an OpenStack Swift object storage.

    When no fixed ``container`` is configured, the first segment of every
    path handed to this source is interpreted as the container name.
    """

    def __init__(self, name=None, username=None, password=None, tenant_name=None,
                 tenant_id=None, region_name=None, user_domain_id=None,
                 user_domain_name=None, auth_url=None, auth_version=None,
                 container=None):
        """Remember the credentials and the optional fixed container."""
        super().__init__(name)

        # authentication endpoint
        self.auth_url = auth_url
        self.auth_version = auth_version  # TODO: assume 3

        # credentials
        self.username = username
        self.password = password
        self.user_domain_id = user_domain_id
        self.user_domain_name = user_domain_name

        # tenant / region selection
        self.tenant_name = tenant_name
        self.tenant_id = tenant_id
        self.region_name = region_name

        # optional fixed container, ``None`` means "take it from the path"
        self.container = container

    def get_service(self):
        """Create a new ``SwiftService`` configured with the stored settings."""
        service_options = {
            "os_username": self.username,
            "os_password": self.password,
            "os_tenant_name": self.tenant_name,
            "os_tenant_id": self.tenant_id,
            "os_region_name": self.region_name,
            "os_auth_url": self.auth_url,
            "auth_version": self.auth_version,
            "os_user_domain_id": self.user_domain_id,
            "os_user_domain_name": self.user_domain_name,
        }
        return SwiftService(options=service_options)

    def get_container_and_path(self, path: str):
        """Split ``path`` into a ``(container, object_path)`` tuple.

        With a configured container the path is returned untouched.
        Otherwise a single leading slash is dropped and the first path
        segment is used as the container name.
        """
        if self.container is not None:
            return self.container, path

        relative = path[1:] if path.startswith('/') else path
        container, _, remainder = relative.partition('/')
        return container, remainder

    def list_files(self, path, glob_pattern=None):
        """Return the names of all objects stored below the prefix ``path``.

        Args:
            path: prefix to list (including the container name when no
                fixed container is configured).
            glob_pattern: optional ``fnmatch`` pattern the object names
                must match.

        Raises:
            The ``error`` entry of the first listing page that was not
            successful.
        """
        container, prefix = self.get_container_and_path(path)
        names = []
        with self.get_service() as swift:
            listing_pages = swift.list(
                container=container,
                options={"prefix": prefix},
            )
            for page in listing_pages:
                if not page["success"]:
                    raise page['error']
                names.extend(
                    item['name']
                    for item in page["listing"]
                    if glob_pattern is None or fnmatch(item['name'], glob_pattern)
                )
        return names

    def get_file(self, path, target_path):
        """Download the object at ``path`` to the local ``target_path``.

        Raises:
            Exception: when a download result is reported as unsuccessful.
        """
        container, object_name = self.get_container_and_path(path)
        download_options = {
            'out_file': target_path
        }
        with self.get_service() as swift:
            outcomes = swift.download(
                container,
                [object_name],
                options=download_options,
            )
            for outcome in outcomes:
                if outcome["success"]:
                    continue
                raise Exception('Failed to download %s' % object_name)

    def get_vsi_env_and_path(self, path):
        """Return the environment mapping and ``/vsiswift/`` path for ``path``."""
        container, object_name = self.get_container_and_path(path)
        # OS_PROJECT_NAME and OS_PROJECT_DOMAIN_NAME are not exported (yet).
        environment = {
            'OS_IDENTITY_API_VERSION': self.auth_version,
            'OS_AUTH_URL': self.auth_url,
            'OS_USERNAME': self.username,
            'OS_PASSWORD': self.password,
            'OS_USER_DOMAIN_NAME': self.user_domain_name,
            'OS_REGION_NAME': self.region_name,
        }
        return environment, f'/vsiswift/{container}/{object_name}'
class S3Source(Source):
    """Source reading objects from an S3 (or S3 compatible) object storage.

    When no ``bucket_name`` is configured, the first segment of every path
    handed to this source is interpreted as the bucket name.
    """

    def __init__(self, name=None, bucket_name=None, secret_access_key=None,
                 access_key_id=None, endpoint_url=None, strip_bucket=True,
                 **client_kwargs):
        """Store the settings and create the boto3 S3 client.

        Args:
            name: name of the source.
            bucket_name: fixed bucket to use; ``None`` to take the bucket
                from the first segment of each path.
            secret_access_key: secret access key passed to the client.
            access_key_id: access key ID passed to the client.
            endpoint_url: endpoint URL passed to the client.
            strip_bucket: with a fixed bucket, whether a leading
                ``<bucket_name>/`` segment is removed from paths.
            **client_kwargs: additional arguments for ``boto3.client``, see
                https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html#boto3.session.Session.client
        """
        super().__init__(name)
        self.bucket_name = bucket_name
        self.secret_access_key = secret_access_key
        self.access_key_id = access_key_id
        self.endpoint_url = endpoint_url
        self.strip_bucket = strip_bucket

        self.client = boto3.client(
            's3',
            aws_secret_access_key=secret_access_key,
            aws_access_key_id=access_key_id,
            endpoint_url=endpoint_url,
            **client_kwargs,
        )

    def get_bucket_and_key(self, path: str):
        """Split ``path`` into a ``(bucket, key)`` tuple.

        - Without a fixed bucket, a single leading slash is dropped and the
          first segment becomes the bucket.
        - With a fixed bucket and ``strip_bucket`` enabled, a single leading
          slash and a leading segment equal to the bucket name are removed.
        - With a fixed bucket and ``strip_bucket`` disabled, the path is
          used as the key unchanged.
        """
        bucket = self.bucket_name
        if bucket is not None and not self.strip_bucket:
            return bucket, path

        parts = (path[1:] if path.startswith('/') else path).split('/')
        if bucket is None:
            return parts[0], '/'.join(parts[1:])

        if parts[0] == bucket:
            parts = parts[1:]
        return bucket, '/'.join(parts)

    def list_files(self, path, glob_pattern=None):
        """Return ``<bucket>/<key>`` for every object below the prefix ``path``.

        All result pages are fetched, and an empty list is returned when no
        object matches the prefix.

        Args:
            path: prefix to list, resolved via ``get_bucket_and_key``.
            glob_pattern: optional ``fnmatch`` pattern the object keys must
                match.
        """
        bucket, key = self.get_bucket_and_key(path)
        logger.info(f'Listing S3 files for bucket {bucket} and prefix {key}')

        request = {'Bucket': bucket, 'Prefix': key}
        filenames = []
        while True:
            response = self.client.list_objects_v2(**request)
            # 'Contents' is absent from the response when nothing matches
            # the prefix, so it must not be accessed unconditionally.
            filenames.extend(
                f"{bucket}/{item['Key']}"
                for item in response.get('Contents', [])
                if glob_pattern is None or fnmatch(item['Key'], glob_pattern)
            )

            # A single response is limited in size; follow the continuation
            # token until the listing is complete.
            token = response.get('NextContinuationToken')
            if not response.get('IsTruncated') or not token:
                return filenames
            request['ContinuationToken'] = token

    def get_file(self, path, target_path):
        """Download the object at ``path`` to the local ``target_path``."""
        bucket, key = self.get_bucket_and_key(path)
        logger.info(f'Retrieving file from S3 {bucket}/{key} to be stored at {target_path}')
        self.client.download_file(bucket, key, target_path)

    def get_vsi_env_and_path(self, path: str, streaming: bool=False):
        """Return the environment mapping and VSI path for ``path``.

        Args:
            path: path resolved via ``get_bucket_and_key``.
            streaming: use the ``/vsis3_streaming/`` prefix instead of
                ``/vsis3/``.
        """
        bucket, key = self.get_bucket_and_key(path)
        prefix = 'vsis3_streaming' if streaming else 'vsis3'
        # NOTE(review): the endpoint URL is forwarded exactly as given to
        # boto3 - confirm that the consumer of AWS_S3_ENDPOINT accepts the
        # same format.
        return {
            'AWS_SECRET_ACCESS_KEY': self.secret_access_key,
            'AWS_ACCESS_KEY_ID': self.access_key_id,
            'AWS_S3_ENDPOINT': self.endpoint_url,
        }, f'/{prefix}/{bucket}/{key}'
class LocalSource(Source):
    """Source reading files from a directory of the local file system."""

    def __init__(self, name, root_directory):
        """Store the source name and the directory all paths are relative to."""
        super().__init__(name)
        self.root_directory = root_directory

    def _join_path(self, path):
        """Return ``path`` normalized and placed below ``root_directory``.

        Absolute paths are made relative first, as ``os.path.join`` would
        otherwise discard the root directory. ``..`` segments that survive
        the normalization are not rejected.
        """
        path = normpath(path)
        # On POSIX, normpath() keeps exactly two leading slashes ('//a'), so
        # removing a single character may still leave an absolute path.
        # Strip until the path is relative.
        while isabs(path):
            path = path[1:]
        return join(self.root_directory, path)

    def list_files(self, path, glob_pattern=None):
        """Return the entries of the directory ``path`` matching the pattern.

        Args:
            path: directory, relative to the root directory.
            glob_pattern: glob pattern applied within the directory,
                defaults to ``*``.
        """
        pattern = '*' if glob_pattern is None else glob_pattern
        return glob(join(self._join_path(path), pattern))

    def get_file(self, path, target_path):
        """Copy the file at ``path`` to ``target_path``."""
        shutil.copy(self._join_path(path), target_path)

    def get_vsi_env_and_path(self, path):
        """Return an empty environment and the local file system path."""
        return {}, self._join_path(path)
# Registry mapping the ``type`` value of a source configuration entry to the
# class implementing that kind of source; used by ``get_source`` to
# instantiate the selected source.
SOURCE_TYPES = {
'swift': SwiftSource,
's3': S3Source,
'local': LocalSource,
}
def get_source(config: dict, path: str) -> Source:
    """Instantiate the first configured source applicable to ``path``.

    The entries of ``config['sources']`` are checked in order. An entry is
    applicable when it has no (or an empty) ``filter``, or when its
    ``filter`` regular expression matches the beginning of ``path``.

    Args:
        config: configuration with a ``sources`` list. Each entry provides
            ``type`` (a key of ``SOURCE_TYPES``) and ``name``, and
            optionally ``filter``, ``args`` and ``kwargs``.
        path: the path a source is required for.

    Returns:
        A new source object, constructed with the entry's name followed by
        its ``args`` and ``kwargs``.

    Raises:
        RegistrationError: when no configured source is applicable.
        KeyError: when the selected entry has an unknown ``type``.
    """
    for cfg_source in config['sources']:
        # The filter is optional: an entry without one is a catch-all.
        pattern = cfg_source.get('filter')
        if not pattern or re.match(pattern, path):
            break
    else:
        # no source found
        raise RegistrationError(f'Could not find a suitable source for the path {path}')

    source_class = SOURCE_TYPES[cfg_source['type']]
    return source_class(
        cfg_source['name'],
        *cfg_source.get('args', []),
        **cfg_source.get('kwargs', {})
    )