Newer
Older
import logging
import re
import shutil
from fnmatch import fnmatch
from glob import glob
from os.path import normpath, join, isabs

import boto3.session
import botocore.handlers
import botocore.session
from swiftclient.multithreading import OutputManager
from swiftclient.service import SwiftError, SwiftService
class RegistrationError(Exception):
    """Raised when a path cannot be registered, e.g. when no configured source matches it."""
class Source:
    """Base class for locations that files can be registered from.

    Subclasses map a path to a ``(container, path)`` pair, list files below a
    path, copy single files to the local disk and describe how to open a path
    through a ``/vsi...`` virtual file system path.
    """

    def __init__(self, name: str = None):
        # Optional identifier of this source.
        self.name = name

    def get_container_and_path(self, path):
        """Split ``path`` into a ``(container, path)`` tuple."""
        raise NotImplementedError

    def list_files(self, path, glob_patterns=None):
        """Return the files below ``path``.

        :param path: the path (prefix) to list
        :param glob_patterns: ``None`` for all files, or a single glob pattern
            or a list of glob patterns to filter the result by. (Named
            ``glob_patterns`` to match every implementation.)
        """
        raise NotImplementedError

    def get_file(self, path, target_path):
        """Copy the file at ``path`` to the local file ``target_path``."""
        raise NotImplementedError

    def get_vsi_env_and_path(self, path):
        """Return ``(env, vsi_path)``.

        ``env`` holds the environment/configuration options and ``vsi_path``
        is the virtual file system path needed to open ``path``.
        """
        raise NotImplementedError
class SwiftSource(Source):
    """Source backed by an OpenStack Swift object store.

    If ``container`` is configured, all paths are object names inside that
    container; otherwise the first path segment names the container.
    """

    def __init__(self, name=None, username=None, password=None, tenant_name=None,
                 tenant_id=None, region_name=None, user_domain_id=None,
                 user_domain_name=None, auth_url=None, auth_url_short=None,
                 auth_version=None, container=None):
        # Let the base class store ``name`` so ``self.name`` is always set.
        super().__init__(name)
        self.username = username
        self.password = password
        self.tenant_name = tenant_name
        self.tenant_id = tenant_id
        self.region_name = region_name
        self.user_domain_id = user_domain_id
        self.user_domain_name = user_domain_name
        self.auth_url = auth_url
        self.auth_url_short = auth_url_short  # stored, not used by this class
        self.auth_version = auth_version  # TODO: assume 3
        self.container = container

    def get_service(self):
        """Create a ``SwiftService`` configured with this source's credentials."""
        return SwiftService(options={
            "os_username": self.username,
            "os_password": self.password,
            "os_tenant_name": self.tenant_name,
            "os_tenant_id": self.tenant_id,
            "os_region_name": self.region_name,
            "os_auth_url": self.auth_url,
            "auth_version": self.auth_version,
            "os_user_domain_id": self.user_domain_id,
            "os_user_domain_name": self.user_domain_name,
        })

    def get_container_and_path(self, path: str):
        """Split ``path`` into ``(container, object_path)``.

        Without a configured container, the first segment of ``path`` (after
        one optional leading '/') is used as the container name.
        """
        container = self.container
        if container is None or container == '':
            parts = (path[1:] if path.startswith('/') else path).split('/')
            container, path = parts[0], '/'.join(parts[1:])
        return container, path

    def list_files(self, path, glob_patterns=None):
        """List the object names below ``path``.

        :param path: prefix to list (including the container name when no
            container is configured)
        :param glob_patterns: ``None`` to return every object, or a pattern /
            list of patterns; each is joined to the prefix and matched with
            :func:`fnmatch.fnmatch` against the object name
        :returns: object names; prefixed with the container name when no
            container is configured, so they can be passed back to
            :meth:`get_file`
        :raises: the error reported by swiftclient for a failed listing page
        """
        container, prefix = self.get_container_and_path(path)
        if glob_patterns and not isinstance(glob_patterns, list):
            glob_patterns = [glob_patterns]
        # The full patterns do not depend on the listed object: build them once.
        patterns = (
            None if glob_patterns is None
            else [join(prefix, glob_pattern) for glob_pattern in glob_patterns]
        )

        filenames = []
        with self.get_service() as swift:
            # Consume the listing while the service is still open.
            for page in swift.list(container=container, options={"prefix": prefix}):
                if not page["success"]:
                    raise page['error']
                for item in page["listing"]:
                    name = item['name']
                    if patterns is None or any(fnmatch(name, pattern) for pattern in patterns):
                        filenames.append(name if self.container else join(container, name))
        return filenames

    def get_file(self, path, target_path):
        """Download the object at ``path`` to the local file ``target_path``.

        :raises Exception: if swiftclient reports the download as failed; the
            reported error (if any) is attached as the cause
        """
        container, path = self.get_container_and_path(path)
        with self.get_service() as swift:
            results = swift.download(
                container,
                [path],
                options={
                    'out_file': target_path
                }
            )
            for result in results:
                if not result["success"]:
                    error = result.get('error')
                    cause = error if isinstance(error, BaseException) else None
                    raise Exception('Failed to download %s' % path) from cause

    def get_vsi_env_and_path(self, path):
        """Return the OpenStack credentials env and the ``/vsiswift/`` path for ``path``."""
        container, path = self.get_container_and_path(path)
        # TODO: OS_PROJECT_NAME / OS_PROJECT_DOMAIN_NAME are not provided yet.
        return {
            'OS_IDENTITY_API_VERSION': self.auth_version,
            'OS_AUTH_URL': self.auth_url,
            'OS_USERNAME': self.username,
            'OS_PASSWORD': self.password,
            'OS_USER_DOMAIN_NAME': self.user_domain_name,
            'OS_REGION_NAME': self.region_name,
        }, f'/vsiswift/{container}/{path}'
class S3Source(Source):
    """Source backed by an S3-compatible object store, accessed through boto3.

    If ``bucket_name`` is configured, paths are keys inside that bucket (a
    leading ``<bucket_name>/`` segment is removed when ``strip_bucket`` is
    true). Otherwise the first path segment names the bucket.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, name=None, bucket_name=None, secret_access_key=None, access_key_id=None, endpoint_url=None,
                 strip_bucket=True, validate_bucket_name=True, **client_kwargs):
        # see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html#boto3.session.Session.client
        # for client_kwargs
        super().__init__(name)
        self.bucket_name = bucket_name
        self.secret_access_key = secret_access_key
        self.access_key_id = access_key_id
        self.endpoint_url = endpoint_url
        # Read by get_container_and_path.
        self.strip_bucket = strip_bucket

        botocore_session = botocore.session.Session()
        if not validate_bucket_name:
            # Skip botocore's client-side bucket-name validation.
            botocore_session.unregister('before-parameter-build.s3', botocore.handlers.validate_bucket_name)
        session = boto3.session.Session(botocore_session=botocore_session)
        self.client = session.client(
            's3',
            aws_secret_access_key=secret_access_key,
            aws_access_key_id=access_key_id,
            endpoint_url=endpoint_url,
            **client_kwargs,
        )

    def get_container_and_path(self, path: str):
        """Split ``path`` into ``(bucket, key)``.

        Without a configured bucket (``None`` or empty), the first segment of
        ``path`` (after one optional leading '/') is the bucket. With a
        configured bucket and ``strip_bucket``, a leading segment equal to the
        bucket name is dropped; otherwise ``path`` is used as the key.
        """
        bucket = self.bucket_name
        relative = path[1:] if path.startswith('/') else path
        if not bucket:
            bucket, _, path = relative.partition('/')
        elif self.strip_bucket:
            head, _, rest = relative.partition('/')
            path = rest if head == bucket else relative
        return bucket, path

    def list_files(self, path, glob_patterns=None):
        """List the object keys below ``path``.

        :param path: prefix to list (including the bucket when no bucket is
            configured)
        :param glob_patterns: ``None`` to return every key, or a pattern /
            list of patterns; each is joined to the prefix and matched with
            :func:`fnmatch.fnmatch` against the key
        :returns: keys; prefixed with the bucket name when no bucket is
            configured, so they can be passed back to :meth:`get_file`
        """
        if glob_patterns and not isinstance(glob_patterns, list):
            glob_patterns = [glob_patterns]

        bucket, key = self.get_container_and_path(path)
        self._logger.info('Listing S3 files for bucket %s and prefix %s', bucket, key)

        patterns = (
            None if glob_patterns is None
            else [join(key, glob_pattern) for glob_pattern in glob_patterns]
        )
        # A single list_objects_v2 call returns at most one page of keys;
        # the paginator follows continuation tokens.
        paginator = self.client.get_paginator('list_objects_v2')
        filenames = []
        for page in paginator.paginate(Bucket=bucket, Prefix=key):
            # 'Contents' may be absent, e.g. when no key matches the prefix.
            for item in page.get('Contents', []):
                name = item['Key']
                if patterns is None or any(fnmatch(name, pattern) for pattern in patterns):
                    filenames.append(name if self.bucket_name else join(bucket, name))
        return filenames

    def get_file(self, path, target_path):
        """Download the object at ``path`` to the local file ``target_path``."""
        bucket, key = self.get_container_and_path(path)
        self._logger.info('Retrieving file from S3 %s/%s to be stored at %s', bucket, key, target_path)
        self.client.download_file(bucket, key, target_path)

    def get_vsi_env_and_path(self, path: str, streaming: bool = False):
        """Return the AWS credentials env and the ``/vsis3/`` (or
        ``/vsis3_streaming/`` when ``streaming``) path for ``path``."""
        bucket, key = self.get_container_and_path(path)
        return {
            'AWS_SECRET_ACCESS_KEY': self.secret_access_key,
            'AWS_ACCESS_KEY_ID': self.access_key_id,
            'AWS_S3_ENDPOINT': self.endpoint_url,
        }, f'/{"vsis3" if not streaming else "vsis3_streaming"}/{bucket}/{key}'
class LocalSource(Source):
    """Source reading files from a directory on the local file system.

    All paths are resolved below ``root_directory``.
    """

    def __init__(self, name, root_directory):
        super().__init__(name)
        self.root_directory = root_directory

    def get_container_and_path(self, path):
        """Return ``(root_directory, path)``; ``path`` is passed through unchanged."""
        return (self.root_directory, path)

    def _join_path(self, path):
        """Resolve ``path`` below ``root_directory``.

        The path is normalized and absolute paths are re-rooted at
        ``root_directory``.
        """
        path = normpath(path)
        # POSIX normpath keeps exactly two leading slashes ('//etc'), so strip
        # until the path is relative; stripping one character would leave an
        # absolute path that join() lets escape root_directory.
        while isabs(path):
            path = path[1:]
        # NOTE(review): '..' segments that survive normpath (e.g. '../x') still
        # resolve outside root_directory.
        return join(self.root_directory, path)

    def list_files(self, path, glob_patterns=None):
        """List the files directly below ``path``.

        :param glob_patterns: ``None`` for all entries, or a pattern / list of
            patterns; the result is the union of the matches of all patterns,
            without duplicates, in first-match order
        """
        if glob_patterns and not isinstance(glob_patterns, list):
            glob_patterns = [glob_patterns]
        directory = self._join_path(path)
        if glob_patterns is None:
            return glob(join(directory, '*'))
        # dict.fromkeys removes files matched by several patterns while keeping order.
        matches = dict.fromkeys(
            match
            for glob_pattern in glob_patterns
            for match in glob(join(directory, glob_pattern))
        )
        return list(matches)

    def get_file(self, path, target_path):
        """Copy the file at ``path`` to ``target_path``."""
        shutil.copy(self._join_path(path), target_path)

    def get_vsi_env_and_path(self, path):
        """Return no env and the plain local path for ``path``."""
        return {}, self._join_path(path)
# Maps the ``type`` value of a configured source to the class implementing it.
SOURCE_TYPES = dict(
    swift=SwiftSource,
    s3=S3Source,
    local=LocalSource,
)
def get_source(config: dict, path: str) -> Source:
    """Instantiate the first configured source whose filter matches ``path``.

    :param config: configuration dict; ``config['sources']`` is a list of
        entries, each with a ``filter`` regex (applied with :func:`re.match`,
        i.e. anchored at the start of ``path``), a ``type`` key into
        ``SOURCE_TYPES`` and optional ``args`` / ``kwargs`` for the constructor
    :param path: the path a source is needed for
    :returns: a new instance of the matching source class
    :raises RegistrationError: if no configured source matches ``path``
    """
    # Scan all entries in order; the first matching one wins.
    cfg_source = next(
        (
            candidate for candidate in config['sources']
            if re.match(candidate['filter'], path)
        ),
        None,
    )
    if cfg_source is None:
        # no source found
        raise RegistrationError(f'Could not find a suitable source for the path {path}')

    return SOURCE_TYPES[cfg_source['type']](
        *cfg_source.get('args', []),
        **cfg_source.get('kwargs', {}),
    )