diff --git a/databusclient/api/download.py b/databusclient/api/download.py index 993dece..b42e0be 100644 --- a/databusclient/api/download.py +++ b/databusclient/api/download.py @@ -1,1106 +1,1093 @@ -import json -import os -import bz2 -import gzip -import lzma -from typing import List, Optional, Tuple -import re -from urllib.parse import urlparse - -import requests -from SPARQLWrapper import JSON, SPARQLWrapper -from tqdm import tqdm - -from databusclient.api.utils import ( - fetch_databus_jsonld, - get_databus_id_parts_from_file_url, - compute_sha256_and_length, -) - -# Compression format mappings -COMPRESSION_EXTENSIONS = { - "bz2": ".bz2", - "gz": ".gz", - "xz": ".xz", -} - -COMPRESSION_MODULES = { - "bz2": bz2, - "gz": gzip, - "xz": lzma, -} - - -def _detect_compression_format(filename: str) -> Optional[str]: - """Detect compression format from file extension. - - Args: - filename: Name of the file. - - Returns: - Compression format string ('bz2', 'gz', 'xz') or None if not compressed. - """ - filename_lower = filename.lower() - for fmt, ext in COMPRESSION_EXTENSIONS.items(): - if filename_lower.endswith(ext): - return fmt - return None - - -def _should_convert_file( - filename: str, convert_to: Optional[str], convert_from: Optional[str] -) -> Tuple[bool, Optional[str]]: - """Determine if a file should be converted and what the source format is. - - Args: - filename: Name of the file. - convert_to: Target compression format ('bz2', 'gz', 'xz'). - convert_from: Optional source compression format filter. - - Returns: - Tuple of (should_convert: bool, source_format: Optional[str]). - """ - if not convert_to: - return False, None - - source_format = _detect_compression_format(filename) - - # If file is not compressed, don't convert - if source_format is None: - return False, None - - # If source and target are the same, skip conversion - if source_format == convert_to: - return False, None - - # If convert_from is specified, only convert matching formats - if convert_from and source_format != convert_from: - return False, None - - return True, source_format - - -def _get_converted_filename(filename: str, source_format: str, target_format: str) -> str: - """Generate the new filename after compression format conversion. - - Args: - filename: Original filename. - source_format: Source compression format ('bz2', 'gz', 'xz'). - target_format: Target compression format ('bz2', 'gz', 'xz'). - - Returns: - New filename with updated extension. - """ - source_ext = COMPRESSION_EXTENSIONS[source_format] - target_ext = COMPRESSION_EXTENSIONS[target_format] - - # Handle case-insensitive extension matching - if filename.lower().endswith(source_ext): - return filename[:-len(source_ext)] + target_ext - return filename + target_ext - - -def _convert_compression_format( - source_file: str, target_file: str, source_format: str, target_format: str -) -> None: - """Convert a compressed file from one format to another. - - Args: - source_file: Path to source compressed file. - target_file: Path to target compressed file. - source_format: Source compression format ('bz2', 'gz', 'xz'). - target_format: Target compression format ('bz2', 'gz', 'xz'). - - Raises: - ValueError: If source_format or target_format is not supported. - RuntimeError: If compression conversion fails. - """ - # Validate compression formats - if source_format not in COMPRESSION_MODULES: - raise ValueError(f"Unsupported source compression format: {source_format}. Supported formats: {list(COMPRESSION_MODULES.keys())}") - if target_format not in COMPRESSION_MODULES: - raise ValueError(f"Unsupported target compression format: {target_format}. Supported formats: {list(COMPRESSION_MODULES.keys())}") - - source_module = COMPRESSION_MODULES[source_format] - target_module = COMPRESSION_MODULES[target_format] - - print(f"Converting {source_format} → {target_format}: {os.path.basename(source_file)}") - - # Decompress and recompress with progress indication - chunk_size = 8192 - - try: - with source_module.open(source_file, 'rb') as sf: - with target_module.open(target_file, 'wb') as tf: - while True: - chunk = sf.read(chunk_size) - if not chunk: - break - tf.write(chunk) - - # Remove the original file after successful conversion - os.remove(source_file) - print(f"Conversion complete: {os.path.basename(target_file)}") - except Exception as e: - # If conversion fails, ensure the partial target file is removed - if os.path.exists(target_file): - os.remove(target_file) - raise RuntimeError(f"Compression conversion failed: {e}") - -# compiled regex for SHA-256 hex strings -_SHA256_RE = re.compile(r"^[0-9a-fA-F]{64}$") - -def _extract_checksum_from_node(node) -> str | None: - """ - Try to extract a 64-char hex checksum from a JSON-LD file node. - Handles these common shapes: - - checksum or sha256sum fields as plain string - - checksum fields as dict with '@value' - - nested values under the allowed keys (lists or '@value' objects) - """ - def find_in_value(v): - if isinstance(v, str): - s = v.strip() - if _SHA256_RE.match(s): - return s - if isinstance(v, dict): - # common JSON-LD value object - if "@value" in v and isinstance(v["@value"], str): - res = find_in_value(v["@value"]) - if res: - return res - # try all nested dict values - for vv in v.values(): - res = find_in_value(vv) - if res: - return res - if isinstance(v, list): - for item in v: - res = find_in_value(item) - if res: - return res - return None - - # Only inspect the explicitly allowed keys to avoid false positives. - for key in ("checksum", "sha256sum", "sha256", "databus:checksum"): - if key in node: - res = find_in_value(node[key]) - if res: - return res - - return None - - - -# Hosts that require Vault token based authentication. Central source of truth. -VAULT_REQUIRED_HOSTS = { - "data.dbpedia.io", - "data.dev.dbpedia.link", -} - - -class DownloadAuthError(Exception): - """Raised when an authorization problem occurs during download.""" - - - -def _extract_checksums_from_jsonld(json_str: str) -> dict: - """ - Parse a JSON-LD string and return a mapping of file URI (and @id) -> checksum. - - Uses the existing _extract_checksum_from_node logic to extract checksums - from `Part` nodes. Both the node's `file` and `@id` (if present and a - string) are mapped to the checksum to preserve existing lookup behavior. - """ - try: - jd = json.loads(json_str) - except Exception: - return {} - if isinstance(jd, dict): - graph = jd.get("@graph", []) - elif isinstance(jd, list): - graph = jd - else: - return{} - - checksums: dict = {} - for node in graph: - if node.get("@type") == "Part": - expected = _extract_checksum_from_node(node) - if not expected: - continue - file_uri = node.get("file") - if isinstance(file_uri, str): - checksums[file_uri] = expected - node_id = node.get("@id") - if isinstance(node_id, str): - checksums[node_id] = expected - return checksums - - -def _resolve_checksums_for_urls(file_urls: List[str], databus_key: str | None) -> dict: - """ - Group file URLs by their Version URI, fetch each Version JSON-LD once, - and return a combined url->checksum mapping for the provided URLs. - - Best-effort: failures to fetch or parse individual versions are skipped. - """ - versions_map: dict = {} - for file_url in file_urls: - try: - host, accountId, groupId, artifactId, versionId, fileId = get_databus_id_parts_from_file_url(file_url) - except Exception: - continue - if versionId is None: - continue - if host is None or accountId is None or groupId is None or artifactId is None: - continue - version_uri = f"https://{host}/{accountId}/{groupId}/{artifactId}/{versionId}" - versions_map.setdefault(version_uri, []).append(file_url) - - checksums: dict = {} - for version_uri, urls_in_version in versions_map.items(): - try: - json_str = fetch_databus_jsonld(version_uri, databus_key=databus_key) - extracted_checksums = _extract_checksums_from_jsonld(json_str) - for url in urls_in_version: - if url in extracted_checksums: - checksums[url] = extracted_checksums[url] - except Exception: - # Best-effort: skip versions we cannot fetch or parse - continue - return checksums - -def _download_file( - url, - localDir, - vault_token_file=None, - databus_key=None, - auth_url=None, - client_id=None, - convert_to=None, - convert_from=None, - validate_checksum: bool = False, - expected_checksum: str | None = None, -) -> None: - """Download a file from the internet with a progress bar using tqdm. - - Args: - url: The URL of the file to download. - localDir: Local directory to download file to. If None, the databus folder structure is created in the current working directory. - vault_token_file: Path to Vault refresh token file. - databus_key: Databus API key for protected downloads. - auth_url: Keycloak token endpoint URL. - client_id: Client ID for token exchange. - convert_to: Target compression format for on-the-fly conversion. - convert_from: Optional source compression format filter. - validate_checksum: Whether to validate checksums after downloading. - expected_checksum: The expected checksum of the file. - """ - if localDir is None: - _host, account, group, artifact, version, file = ( - get_databus_id_parts_from_file_url(url) - ) - localDir = os.path.join( - os.getcwd(), - account, - group, - artifact, - version if version is not None else "latest", - ) - print(f"Local directory not given, using {localDir}") - - file = url.split("/")[-1] - filename = os.path.join(localDir, file) - print(f"Download file: {url}") - dirpath = os.path.dirname(filename) - if dirpath: - os.makedirs(dirpath, exist_ok=True) # Create the necessary directories - # --- 1. Get redirect URL by requesting HEAD --- - headers = {} - - # --- 1a. public databus --- - response = requests.head(url, timeout=30, allow_redirects=False) - - # Check for redirect and update URL if necessary - if response.headers.get("Location") and response.status_code in [ - 301, - 302, - 303, - 307, - 308, - ]: - url = response.headers.get("Location") - print("Redirects url: ", url) - # Re-do HEAD request on redirect URL - response = requests.head(url, timeout=30) - - # Extract hostname from final URL (after redirect) to check if vault token needed. - # This is the actual download location that may require authentication. - parsed = urlparse(url) - host = parsed.hostname - - # --- 1b. Handle 401 on HEAD request --- - if response.status_code == 401: - # Check if this is a vault-required host - if host in VAULT_REQUIRED_HOSTS: - # Vault-required host: need vault token - if not vault_token_file: - raise DownloadAuthError( - f"Vault token required for host '{host}', but no token was provided. Please use --vault-token." - ) - # Token provided; will handle in GET request below - else: - # Not a vault host; might need databus API key - if not databus_key: - raise DownloadAuthError("Databus API key not given for protected download") - headers = {"X-API-KEY": databus_key} - response = requests.head(url, headers=headers, timeout=30) - - # --- 2. Try direct GET to redirected URL --- - headers["Accept-Encoding"] = ( - "identity" # disable gzip to get correct content-length - ) - response = requests.get( - url, headers=headers, stream=True, allow_redirects=True, timeout=30 - ) - www = response.headers.get("WWW-Authenticate", "") # Check if authentication is required - - # --- 3. Handle authentication responses --- - # 3a. Server requests Bearer auth. Only attempt token exchange for hosts - # we explicitly consider Vault-protected (VAULT_REQUIRED_HOSTS). This avoids - # sending tokens to unrelated hosts and makes auth behavior predictable. - if response.status_code == 401 and "bearer" in www.lower(): - # If host is not configured for Vault, do not attempt token exchange. - if host not in VAULT_REQUIRED_HOSTS: - raise DownloadAuthError( - "Server requests Bearer authentication but this host is not configured for Vault token exchange." - " Try providing a databus API key with --databus-key or contact your administrator." - ) - - # Host requires Vault; ensure token file provided. - if not vault_token_file: - raise DownloadAuthError( - f"Vault token required for host '{host}', but no token was provided. Please use --vault-token." - ) - - # --- 3b. Fetch Vault token and retry --- - # Token exchange is potentially sensitive and should only be performed - # for known hosts. __get_vault_access__ handles reading the refresh - # token and exchanging it; errors are translated to DownloadAuthError - # for user-friendly CLI output. - vault_token = __get_vault_access__(url, vault_token_file, auth_url, client_id) - headers["Authorization"] = f"Bearer {vault_token}" - headers["Accept-Encoding"] = "identity" - - # Retry with token - response = requests.get(url, headers=headers, stream=True, timeout=30) - - # Map common auth failures to friendly messages - if response.status_code == 401: - raise DownloadAuthError("Vault token is invalid or expired. Please generate a new token.") - if response.status_code == 403: - raise DownloadAuthError("Vault token is valid but has insufficient permissions to access this file.") - - # 3c. Generic forbidden without Bearer challenge - if response.status_code == 403: - raise DownloadAuthError("Access forbidden: your token or API key does not have permission to download this file.") - - # 3d. Generic unauthorized without Bearer - if response.status_code == 401: - raise DownloadAuthError( - "Unauthorized: access denied. Check your --databus-key or --vault-token settings." - ) - - try: - response.raise_for_status() # Raise if still failing - except requests.exceptions.HTTPError as e: - if response.status_code == 404: - print(f"WARNING: Skipping file {url} because it was not found (404).") - return - else: - raise e - - # --- 4. Download with progress bar --- - total_size_in_bytes = int(response.headers.get("content-length", 0)) - block_size = 1024 # 1 KiB - - progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) - with open(filename, "wb") as f: - for data in response.iter_content(block_size): - progress_bar.update(len(data)) - f.write(data) - progress_bar.close() - - # --- 5. Verify download size --- - if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: - raise IOError("Downloaded size does not match Content-Length header") - - # --- 6. Validate checksum on original downloaded file (BEFORE conversion) --- - if validate_checksum: - # reuse compute_sha256_and_length from webdav extension - try: - actual, _ = compute_sha256_and_length(filename) - except (OSError, IOError) as e: - print(f"WARNING: error computing checksum for {filename}: {e}") - actual = None - - if expected_checksum is None: - print(f"WARNING: no expected checksum available for {filename}; skipping validation") - elif actual is None: - print(f"WARNING: could not compute checksum for {filename}; skipping validation") - else: - if actual.lower() != expected_checksum.lower(): - try: - os.remove(filename) # delete corrupted file - except OSError: - pass - raise IOError( - f"Checksum mismatch for {filename}: expected {expected_checksum}, got {actual}" - ) - - # --- 7. Convert compression format if requested (AFTER validation) --- - should_convert, source_format = _should_convert_file(file, convert_to, convert_from) - if should_convert and source_format: - target_filename = _get_converted_filename(file, source_format, convert_to) - target_filepath = os.path.join(localDir, target_filename) - _convert_compression_format(filename, target_filepath, source_format, convert_to) - - -def _download_files( - urls: List[str], - localDir: str, - vault_token_file: str = None, - databus_key: str = None, - auth_url: str = None, - client_id: str = None, - convert_to: str = None, - convert_from: str = None, - validate_checksum: bool = False, - checksums: dict | None = None, -) -> None: - """Download multiple files from the databus. - - Args: - urls: List of file download URLs. - localDir: Local directory to download files to. If None, the databus folder structure is created in the current working directory. - vault_token_file: Path to Vault refresh token file. - databus_key: Databus API key for protected downloads. - auth_url: Keycloak token endpoint URL. - client_id: Client ID for token exchange. - convert_to: Target compression format for on-the-fly conversion. - convert_from: Optional source compression format filter. - validate_checksum: Whether to validate checksums after downloading. - checksums: Dictionary mapping URLs to their expected checksums. - """ - for url in urls: - expected = None - if checksums and isinstance(checksums, dict): - expected = checksums.get(url) - _download_file( - url=url, - localDir=localDir, - vault_token_file=vault_token_file, - databus_key=databus_key, - auth_url=auth_url, - client_id=client_id, - convert_to=convert_to, - convert_from=convert_from, - validate_checksum=validate_checksum, - expected_checksum=expected, - ) - - -def _get_sparql_query_of_collection(uri: str, databus_key: str | None = None) -> str: - """Get SPARQL query of collection members from databus collection URI. - - Args: - uri: The full databus collection URI. - databus_key: Optional Databus API key for authentication on protected resources. - - Returns: - SPARQL query string to get download URLs of all files in the collection. - """ - headers = {"Accept": "text/sparql"} - if databus_key is not None: - headers["X-API-KEY"] = databus_key - - response = requests.get(uri, headers=headers, timeout=30) - response.raise_for_status() - return response.text - - -def _query_sparql_endpoint(endpoint_url, query, databus_key=None) -> dict: - """Query a SPARQL endpoint and return results in JSON format. - - Args: - endpoint_url: The URL of the SPARQL endpoint. - query: The SPARQL query string. - databus_key: Optional API key for authentication. - - Returns: - Dictionary containing the query results. - """ - sparql = SPARQLWrapper(endpoint_url) - sparql.method = "POST" - sparql.setQuery(query) - sparql.setReturnFormat(JSON) - if databus_key is not None: - sparql.setCustomHttpHeaders({"X-API-KEY": databus_key}) - results = sparql.query().convert() - return results - - -def _get_file_download_urls_from_sparql_query( - endpoint_url, query, databus_key=None -) -> List[str]: - """Execute a SPARQL query to get databus file download URLs. - - Args: - endpoint_url: The URL of the SPARQL endpoint. - query: The SPARQL query string. - databus_key: Optional API key for authentication. - - Returns: - List of file download URLs. - """ - result_dict = _query_sparql_endpoint(endpoint_url, query, databus_key=databus_key) - - bindings = result_dict.get("results", {}).get("bindings") - if not isinstance(bindings, list): - raise ValueError("Invalid SPARQL response: 'bindings' missing or not a list") - - urls: List[str] = [] - - for binding in bindings: - if not isinstance(binding, dict) or len(binding) != 1: - raise ValueError(f"Invalid SPARQL binding structure: {binding}") - - value_dict = next(iter(binding.values())) - value = value_dict.get("value") - - if not isinstance(value, str): - raise ValueError(f"Invalid SPARQL value field: {value_dict}") - - urls.append(value) - - return urls - - -def __get_vault_access__( - download_url: str, token_file: str, auth_url: str, client_id: str -) -> str: - """ - Get Vault access token for a protected databus download. - """ - # 1. Load refresh token - refresh_token = os.environ.get("REFRESH_TOKEN") - if not refresh_token: - if not os.path.exists(token_file): - raise FileNotFoundError(f"Vault token file not found: {token_file}") - with open(token_file, "r") as f: - refresh_token = f.read().strip() - if len(refresh_token) < 80: - print(f"Warning: token from {token_file} is short (<80 chars)") - - # 2. Refresh token -> access token - resp = requests.post( - auth_url, - data={ - "client_id": client_id, - "grant_type": "refresh_token", - "refresh_token": refresh_token, - }, - timeout=30, - ) - resp.raise_for_status() - access_token = resp.json()["access_token"] - - # 3. Extract host as audience - # Remove protocol prefix - if download_url.startswith("https://"): - host_part = download_url[len("https://") :] - elif download_url.startswith("http://"): - host_part = download_url[len("http://") :] - else: - host_part = download_url - audience = host_part.split("/")[0] # host is before first "/" - - # 4. Access token -> Vault token - resp = requests.post( - auth_url, - data={ - "client_id": client_id, - "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", - "subject_token": access_token, - "audience": audience, - }, - timeout=30, - ) - resp.raise_for_status() - vault_token = resp.json()["access_token"] - - print(f"Using Vault access token for {download_url}") - return vault_token - - -def _download_collection( - uri: str, - endpoint: str, - localDir: str, - vault_token: str = None, - databus_key: str = None, - auth_url: str = None, - client_id: str = None, - convert_to: str = None, - convert_from: str = None, - validate_checksum: bool = False, -) -> None: - """Download all files in a databus collection. - - Args: - uri: The full databus collection URI. - endpoint: The databus SPARQL endpoint URL. - localDir: Local directory to download files to. If None, the databus folder structure is created in the current working directory. - vault_token: Path to Vault refresh token file for protected downloads. - databus_key: Databus API key for protected downloads. - auth_url: Keycloak token endpoint URL. - client_id: Client ID for token exchange. - convert_to: Target compression format for on-the-fly conversion. - convert_from: Optional source compression format filter. - validate_checksum: Whether to validate checksums after downloading. - """ - query = _get_sparql_query_of_collection(uri, databus_key=databus_key) - file_urls = _get_file_download_urls_from_sparql_query( - endpoint, query, databus_key=databus_key - ) - - # If checksum validation requested, attempt to build url->checksum mapping - checksums: dict = {} - if validate_checksum: - checksums = _resolve_checksums_for_urls(list(file_urls), databus_key) - - _download_files( - list(file_urls), - localDir, - vault_token_file=vault_token, - databus_key=databus_key, - auth_url=auth_url, - client_id=client_id, - convert_to=convert_to, - convert_from=convert_from, - validate_checksum=validate_checksum, - checksums=checksums if checksums else None, - ) - - -def _download_version( - uri: str, - localDir: str, - vault_token_file: str = None, - databus_key: str = None, - auth_url: str = None, - client_id: str = None, - convert_to: str = None, - convert_from: str = None, - validate_checksum: bool = False, -) -> None: - """Download all files in a databus artifact version. - - Args: - uri: The full databus artifact version URI. - localDir: Local directory to download files to. If None, the databus folder structure is created in the current working directory. - vault_token_file: Path to Vault refresh token file for protected downloads. - databus_key: Databus API key for protected downloads. - auth_url: Keycloak token endpoint URL. - client_id: Client ID for token exchange. - convert_to: Target compression format for on-the-fly conversion. - convert_from: Optional source compression format filter. - validate_checksum: Whether to validate checksums after downloading. - """ - json_str = fetch_databus_jsonld(uri, databus_key=databus_key) - file_urls = _get_file_download_urls_from_artifact_jsonld(json_str) - # build url -> checksum mapping from JSON-LD when available - checksums: dict = {} - try: - checksums = _extract_checksums_from_jsonld(json_str) - except Exception: - checksums = {} - - _download_files( - file_urls, - localDir, - vault_token_file=vault_token_file, - databus_key=databus_key, - auth_url=auth_url, - client_id=client_id, - convert_to=convert_to, - convert_from=convert_from, - validate_checksum=validate_checksum, - checksums=checksums, - ) - - -def _download_artifact( - uri: str, - localDir: str, - all_versions: bool = False, - vault_token_file: str = None, - databus_key: str = None, - auth_url: str = None, - client_id: str = None, - convert_to: str = None, - convert_from: str = None, - validate_checksum: bool = False, -) -> None: - """Download files in a databus artifact. - - Args: - uri: The full databus artifact URI. - localDir: Local directory to download files to. If None, the databus folder structure is created in the current working directory. - all_versions: If True, download all versions of the artifact; otherwise, only download the latest version. - vault_token_file: Path to Vault refresh token file for protected downloads. - databus_key: Databus API key for protected downloads. - auth_url: Keycloak token endpoint URL. - client_id: Client ID for token exchange. - convert_to: Target compression format for on-the-fly conversion. - convert_from: Optional source compression format filter. - validate_checksum: Whether to validate checksums after downloading. - """ - json_str = fetch_databus_jsonld(uri, databus_key=databus_key) - versions = _get_databus_versions_of_artifact(json_str, all_versions=all_versions) - if isinstance(versions, str): - versions = [versions] - for version_uri in versions: - print(f"Downloading version: {version_uri}") - json_str = fetch_databus_jsonld(version_uri, databus_key=databus_key) - file_urls = _get_file_download_urls_from_artifact_jsonld(json_str) - # extract checksums for this version - checksums: dict = {} - try: - checksums = _extract_checksums_from_jsonld(json_str) - except Exception: - checksums = {} - - _download_files( - file_urls, - localDir, - vault_token_file=vault_token_file, - databus_key=databus_key, - auth_url=auth_url, - client_id=client_id, - convert_to=convert_to, - convert_from=convert_from, - validate_checksum=validate_checksum, - checksums=checksums, - ) - - -def _get_databus_versions_of_artifact( - json_str: str, all_versions: bool -) -> str | List[str]: - """Parse the JSON-LD of a databus artifact to extract URLs of its versions. - - Args: - json_str: JSON-LD string of the databus artifact. - all_versions: If True, return all version URLs; otherwise, return only the latest version URL. - - Returns: - If all_versions is True: List of all version URLs. - If all_versions is False: URL of the latest version. - """ - json_dict = json.loads(json_str) - versions = json_dict.get("databus:hasVersion") - - if versions is None: - raise ValueError("No 'databus:hasVersion' field in artifact JSON-LD") - - if isinstance(versions, dict): - versions = [versions] - elif not isinstance(versions, list): - raise ValueError( - f"Unexpected type for 'databus:hasVersion': {type(versions).__name__}" - ) - - version_urls = [v["@id"] for v in versions if isinstance(v, dict) and "@id" in v] - - if not version_urls: - raise ValueError("No versions found in artifact JSON-LD") - - version_urls.sort(reverse=True) # Sort versions in descending order - - if all_versions: - return version_urls - return version_urls[0] - - -def _get_file_download_urls_from_artifact_jsonld(json_str: str) -> List[str]: - """Parse the JSON-LD of a databus artifact version to extract download URLs. - - Args: - json_str: JSON-LD string of the databus artifact version. - - Returns: - List of all file download URLs in the artifact version. - """ - - databusIdUrl: List[str] = [] - - json_dict = json.loads(json_str) - graph = json_dict.get("@graph", []) - for node in graph: - if node.get("@type") == "Part": - file_uri = node.get("file") - if not isinstance(file_uri, str): - continue - databusIdUrl.append(file_uri) - return databusIdUrl - - -def _download_group( - uri: str, - localDir: str, - all_versions: bool = False, - vault_token_file: str = None, - databus_key: str = None, - auth_url: str = None, - client_id: str = None, - convert_to: str = None, - convert_from: str = None, - validate_checksum: bool = False, -) -> None: - """Download files in a databus group. - - Args: - uri: The full databus group URI. - localDir: Local directory to download files to. If None, the databus folder structure is created in the current working directory. - all_versions: If True, download all versions of each artifact in the group; otherwise, only download the latest version. - vault_token_file: Path to Vault refresh token file for protected downloads. - databus_key: Databus API key for protected downloads. - auth_url: Keycloak token endpoint URL. - client_id: Client ID for token exchange. - convert_to: Target compression format for on-the-fly conversion. - convert_from: Optional source compression format filter. - validate_checksum: Whether to validate checksums after downloading. - """ - json_str = fetch_databus_jsonld(uri, databus_key=databus_key) - artifacts = _get_databus_artifacts_of_group(json_str) - for artifact_uri in artifacts: - print(f"Download artifact: {artifact_uri}") - _download_artifact( - artifact_uri, - localDir, - all_versions=all_versions, - vault_token_file=vault_token_file, - databus_key=databus_key, - auth_url=auth_url, - client_id=client_id, - convert_to=convert_to, - convert_from=convert_from, - validate_checksum=validate_checksum, - ) - - -def _get_databus_artifacts_of_group(json_str: str) -> List[str]: - """ - Parse the JSON-LD of a databus group to extract URLs of all artifacts. - - Returns a list of artifact URLs. - """ - json_dict = json.loads(json_str) - artifacts = json_dict.get("databus:hasArtifact") - - if artifacts is None: - return [] - - if isinstance(artifacts, dict): - artifacts_iter = [artifacts] - elif isinstance(artifacts, list): - artifacts_iter = artifacts - else: - raise ValueError( - f"Unexpected type for 'databus:hasArtifact': {type(artifacts).__name__}" - ) - - result: List[str] = [] - for item in artifacts_iter: - if not isinstance(item, dict): - continue - uri = item.get("@id") - if not uri: - continue - _, _, _, _, version, _ = get_databus_id_parts_from_file_url(uri) - if version is None: - result.append(uri) - return result - - -def download( - localDir: str, - endpoint: str, - databusURIs: List[str], - token=None, - databus_key=None, - all_versions=None, - auth_url="https://auth.dbpedia.org/realms/dbpedia/protocol/openid-connect/token", - client_id="vault-token-exchange", - convert_to=None, - convert_from=None, - validate_checksum: bool = False -) -> None: - """Download datasets from databus. - - Download of files, versions, artifacts, groups or databus collections via their databus URIs or user-defined SPARQL queries that return file download URLs. - - Args: - localDir: Local directory to download datasets to. If None, the databus folder structure is created in the current working directory. - endpoint: The databus endpoint URL. If None, inferred from databusURI. Required for user-defined SPARQL queries. - databusURIs: Databus identifiers to specify datasets to download. - token: Path to Vault refresh token file for protected downloads. - databus_key: Databus API key for protected downloads. - auth_url: Keycloak token endpoint URL. Default is "https://auth.dbpedia.org/realms/dbpedia/protocol/openid-connect/token". - client_id: Client ID for token exchange. Default is "vault-token-exchange". - convert_to: Target compression format for on-the-fly conversion (supported: bz2, gz, xz). - convert_from: Optional source compression format filter. - validate_checksum: Whether to validate checksums after downloading. - """ - for databusURI in databusURIs: - host, account, group, artifact, version, file = ( - get_databus_id_parts_from_file_url(databusURI) - ) - - # Determine endpoint per-URI if not explicitly provided - uri_endpoint = endpoint - - # dataID or databus collection - if databusURI.startswith("http://") or databusURI.startswith("https://"): - # Auto-detect sparql endpoint from host if not given - if uri_endpoint is None: - uri_endpoint = f"https://{host}/sparql" - print(f"SPARQL endpoint {uri_endpoint}") - - if group == "collections" and artifact is not None: - print(f"Downloading collection: {databusURI}") - _download_collection( - databusURI, - uri_endpoint, - localDir, - token, - databus_key, - auth_url, - client_id, - convert_to, - convert_from, - validate_checksum=validate_checksum, - ) - elif file is not None: - print(f"Downloading file: {databusURI}") - # Try to fetch expected checksum from the parent Version metadata - expected = None - if validate_checksum: - try: - if version is not None: - version_uri = f"https://{host}/{account}/{group}/{artifact}/{version}" - json_str = fetch_databus_jsonld(version_uri, databus_key=databus_key) - checks = _extract_checksums_from_jsonld(json_str) - expected = checks.get(databusURI) or checks.get( - "https://" + databusURI.removeprefix("http://").removeprefix("https://") - ) - except Exception as e: - print(f"WARNING: Could not fetch checksum for single file: {e}") - - # Call the worker to download the single file (passes expected checksum) - _download_file( - databusURI, - localDir, - vault_token_file=token, - databus_key=databus_key, - auth_url=auth_url, - client_id=client_id, - convert_to=convert_to, - convert_from=convert_from, - validate_checksum=validate_checksum, - expected_checksum=expected, - ) - elif version is not None: - print(f"Downloading version: {databusURI}") - _download_version( - databusURI, - localDir, - vault_token_file=token, - databus_key=databus_key, - auth_url=auth_url, - client_id=client_id, - convert_to=convert_to, - convert_from=convert_from, - validate_checksum=validate_checksum, - ) - elif artifact is not None: - print( - f"Downloading {'all' if all_versions else 'latest'} version(s) of artifact: {databusURI}" - ) - _download_artifact( - databusURI, - localDir, - all_versions=all_versions, - vault_token_file=token, - databus_key=databus_key, - auth_url=auth_url, - client_id=client_id, - convert_to=convert_to, - convert_from=convert_from, - validate_checksum=validate_checksum, - ) - elif group is not None and group != "collections": - print( - f"Downloading group and all its artifacts and versions: {databusURI}" - ) - _download_group( - databusURI, - localDir, - all_versions=all_versions, - vault_token_file=token, - databus_key=databus_key, - auth_url=auth_url, - client_id=client_id, - convert_to=convert_to, - convert_from=convert_from, - validate_checksum=validate_checksum, - ) - elif account is not None: - print("accountId not supported yet") # TODO - else: - print( - "dataId not supported yet" - ) # TODO add support for other DatabusIds - # query in local file - elif databusURI.startswith("file://"): - print("query in file not supported yet") - # query as argument - else: - print("QUERY {}", databusURI.replace("\n", " ")) - if uri_endpoint is None: # endpoint is required for queries (--databus) - raise ValueError("No endpoint given for query") - res = _get_file_download_urls_from_sparql_query( - uri_endpoint, databusURI, databus_key=databus_key - ) - - # If checksum validation requested, try to build url->checksum mapping - checksums: dict = {} - if validate_checksum: - checksums = _resolve_checksums_for_urls(res, databus_key) - if not checksums: - print("WARNING: Checksum validation enabled but no checksums found for query results.") - - _download_files( - res, - localDir, - vault_token_file=token, - databus_key=databus_key, - auth_url=auth_url, - client_id=client_id, - convert_to=convert_to, - convert_from=convert_from, - validate_checksum=validate_checksum, - checksums=checksums if checksums else None, - ) +import hashlib +import json +import os +from typing import List, Optional, Tuple +import re +from urllib.parse import urlparse + +import requests +from SPARQLWrapper import JSON, SPARQLWrapper +from tqdm import tqdm + +from databusclient.api.utils import ( + fetch_databus_jsonld, + get_databus_id_parts_from_file_url, + compute_sha256_and_length, +) +from databusclient.extensions.file_converter import ( + FileConverter, + COMPRESSION_EXTENSIONS, + COMPRESSION_MODULES, +) + + +def _detect_compression_format(filename: str) -> Optional[str]: + """Detect compression format from file extension. + + Delegates to :meth:`FileConverter.detect_format`. Returns the format + string (``'bz2'``, ``'gz'``, ``'xz'``, ``'zstd'``) or ``'none'`` when + the file has no recognised compressed extension. + + .. note:: Prior versions returned ``None`` for uncompressed files; + callers should now compare against ``'none'``. + """ + return FileConverter.detect_format(filename) + + +def _should_convert_file( + filename: str, convert_to: Optional[str], convert_from: Optional[str] +) -> Tuple[bool, Optional[str]]: + """Determine if a file should be converted and what the source format is. + + Supports ``convert_to='none'`` (decompress to raw) and + ``source_format='none'`` with ``convert_from='none'`` (compress raw). + """ + if not convert_to: + return False, None + + source_format = _detect_compression_format(filename) + + # Decompress: convert_to='none', any compressed source is eligible + if convert_to == "none": + if source_format == "none": + return False, None # already uncompressed + if convert_from and source_format != convert_from: + return False, None + return True, source_format + + # Compress raw file: source is uncompressed + if source_format == "none": + # Only convert if caller explicitly asks for raw-file compression + if convert_from == "none": + return True, "none" + return False, None + + # Same format → skip + if source_format == convert_to: + return False, None + + # Filter by convert_from + if convert_from and source_format != convert_from: + return False, None + + return True, source_format + + +def _get_converted_filename( + filename: str, source_format: str, target_format: str +) -> str: + """Generate the new filename after compression format conversion.""" + return FileConverter.get_converted_filename(filename, source_format, target_format) + + +def _convert_compression_format( + source_file: str, target_file: str, source_format: str, target_format: str +) -> None: + """Convert a compressed file from one format to another. + + Delegates to :meth:`FileConverter.convert_file`. + """ + FileConverter.convert_file(source_file, target_file, source_format, target_format) + +# compiled regex for SHA-256 hex strings +_SHA256_RE = re.compile(r"^[0-9a-fA-F]{64}$") + +def _extract_checksum_from_node(node) -> str | None: + """ + Try to extract a 64-char hex checksum from a JSON-LD file node. + Handles these common shapes: + - checksum or sha256sum fields as plain string + - checksum fields as dict with '@value' + - nested values under the allowed keys (lists or '@value' objects) + """ + def find_in_value(v): + if isinstance(v, str): + s = v.strip() + if _SHA256_RE.match(s): + return s + if isinstance(v, dict): + # common JSON-LD value object + if "@value" in v and isinstance(v["@value"], str): + res = find_in_value(v["@value"]) + if res: + return res + # try all nested dict values + for vv in v.values(): + res = find_in_value(vv) + if res: + return res + if isinstance(v, list): + for item in v: + res = find_in_value(item) + if res: + return res + return None + + # Only inspect the explicitly allowed keys to avoid false positives. + for key in ("checksum", "sha256sum", "sha256", "databus:checksum"): + if key in node: + res = find_in_value(node[key]) + if res: + return res + + return None + + + +# Hosts that require Vault token based authentication. Central source of truth. +VAULT_REQUIRED_HOSTS = { + "data.dbpedia.io", + "data.dev.dbpedia.link", +} + + +class DownloadAuthError(Exception): + """Raised when an authorization problem occurs during download.""" + + + +def _extract_checksums_from_jsonld(json_str: str) -> dict: + """ + Parse a JSON-LD string and return a mapping of file URI (and @id) -> checksum. + + Uses the existing _extract_checksum_from_node logic to extract checksums + from `Part` nodes. Both the node's `file` and `@id` (if present and a + string) are mapped to the checksum to preserve existing lookup behavior. + """ + try: + jd = json.loads(json_str) + except Exception: + return {} + if isinstance(jd, dict): + graph = jd.get("@graph", []) + elif isinstance(jd, list): + graph = jd + else: + return{} + + checksums: dict = {} + for node in graph: + if node.get("@type") == "Part": + expected = _extract_checksum_from_node(node) + if not expected: + continue + file_uri = node.get("file") + if isinstance(file_uri, str): + checksums[file_uri] = expected + node_id = node.get("@id") + if isinstance(node_id, str): + checksums[node_id] = expected + return checksums + + +def _resolve_checksums_for_urls(file_urls: List[str], databus_key: str | None) -> dict: + """ + Group file URLs by their Version URI, fetch each Version JSON-LD once, + and return a combined url->checksum mapping for the provided URLs. + + Best-effort: failures to fetch or parse individual versions are skipped. + """ + versions_map: dict = {} + for file_url in file_urls: + try: + host, accountId, groupId, artifactId, versionId, fileId = get_databus_id_parts_from_file_url(file_url) + except Exception: + continue + if versionId is None: + continue + if host is None or accountId is None or groupId is None or artifactId is None: + continue + version_uri = f"https://{host}/{accountId}/{groupId}/{artifactId}/{versionId}" + versions_map.setdefault(version_uri, []).append(file_url) + + checksums: dict = {} + for version_uri, urls_in_version in versions_map.items(): + try: + json_str = fetch_databus_jsonld(version_uri, databus_key=databus_key) + extracted_checksums = _extract_checksums_from_jsonld(json_str) + for url in urls_in_version: + if url in extracted_checksums: + checksums[url] = extracted_checksums[url] + except Exception: + # Best-effort: skip versions we cannot fetch or parse + continue + return checksums + +def _download_file( + url, + localDir, + vault_token_file=None, + databus_key=None, + auth_url=None, + client_id=None, + convert_to=None, + convert_from=None, + validate_checksum: bool = False, + expected_checksum: str | None = None, +) -> None: + """Download a file from the internet with a progress bar using tqdm. + + Args: + url: The URL of the file to download. + localDir: Local directory to download file to. If None, the databus folder structure is created in the current working directory. + vault_token_file: Path to Vault refresh token file. + databus_key: Databus API key for protected downloads. + auth_url: Keycloak token endpoint URL. + client_id: Client ID for token exchange. + convert_to: Target compression format for on-the-fly conversion. + convert_from: Optional source compression format filter. + validate_checksum: Whether to validate checksums after downloading. + expected_checksum: The expected checksum of the file. + """ + if localDir is None: + _host, account, group, artifact, version, file = ( + get_databus_id_parts_from_file_url(url) + ) + localDir = os.path.join( + os.getcwd(), + account, + group, + artifact, + version if version is not None else "latest", + ) + print(f"Local directory not given, using {localDir}") + + file = url.split("/")[-1] + filename = os.path.join(localDir, file) + print(f"Download file: {url}") + dirpath = os.path.dirname(filename) + if dirpath: + os.makedirs(dirpath, exist_ok=True) # Create the necessary directories + # --- 1. Get redirect URL by requesting HEAD --- + headers = {} + + # --- 1a. public databus --- + response = requests.head(url, timeout=30, allow_redirects=False) + + # Check for redirect and update URL if necessary + if response.headers.get("Location") and response.status_code in [ + 301, + 302, + 303, + 307, + 308, + ]: + url = response.headers.get("Location") + print("Redirects url: ", url) + # Re-do HEAD request on redirect URL + response = requests.head(url, timeout=30) + + # Extract hostname from final URL (after redirect) to check if vault token needed. + # This is the actual download location that may require authentication. + parsed = urlparse(url) + host = parsed.hostname + + # --- 1b. Handle 401 on HEAD request --- + if response.status_code == 401: + # Check if this is a vault-required host + if host in VAULT_REQUIRED_HOSTS: + # Vault-required host: need vault token + if not vault_token_file: + raise DownloadAuthError( + f"Vault token required for host '{host}', but no token was provided. Please use --vault-token." + ) + # Token provided; will handle in GET request below + else: + # Not a vault host; might need databus API key + if not databus_key: + raise DownloadAuthError("Databus API key not given for protected download") + headers = {"X-API-KEY": databus_key} + response = requests.head(url, headers=headers, timeout=30) + + # --- 2. Try direct GET to redirected URL --- + headers["Accept-Encoding"] = ( + "identity" # disable gzip to get correct content-length + ) + response = requests.get( + url, headers=headers, stream=True, allow_redirects=True, timeout=30 + ) + www = response.headers.get("WWW-Authenticate", "") # Check if authentication is required + + # --- 3. Handle authentication responses --- + # 3a. Server requests Bearer auth. Only attempt token exchange for hosts + # we explicitly consider Vault-protected (VAULT_REQUIRED_HOSTS). This avoids + # sending tokens to unrelated hosts and makes auth behavior predictable. + if response.status_code == 401 and "bearer" in www.lower(): + # If host is not configured for Vault, do not attempt token exchange. + if host not in VAULT_REQUIRED_HOSTS: + raise DownloadAuthError( + "Server requests Bearer authentication but this host is not configured for Vault token exchange." + " Try providing a databus API key with --databus-key or contact your administrator." + ) + + # Host requires Vault; ensure token file provided. + if not vault_token_file: + raise DownloadAuthError( + f"Vault token required for host '{host}', but no token was provided. Please use --vault-token." + ) + + # --- 3b. Fetch Vault token and retry --- + # Token exchange is potentially sensitive and should only be performed + # for known hosts. __get_vault_access__ handles reading the refresh + # token and exchanging it; errors are translated to DownloadAuthError + # for user-friendly CLI output. + vault_token = __get_vault_access__(url, vault_token_file, auth_url, client_id) + headers["Authorization"] = f"Bearer {vault_token}" + headers["Accept-Encoding"] = "identity" + + # Retry with token + response = requests.get(url, headers=headers, stream=True, timeout=30) + + # Map common auth failures to friendly messages + if response.status_code == 401: + raise DownloadAuthError("Vault token is invalid or expired. Please generate a new token.") + if response.status_code == 403: + raise DownloadAuthError("Vault token is valid but has insufficient permissions to access this file.") + + # 3c. Generic forbidden without Bearer challenge + if response.status_code == 403: + raise DownloadAuthError("Access forbidden: your token or API key does not have permission to download this file.") + + # 3d. Generic unauthorized without Bearer + if response.status_code == 401: + raise DownloadAuthError( + "Unauthorized: access denied. Check your --databus-key or --vault-token settings." + ) + + try: + response.raise_for_status() # Raise if still failing + except requests.exceptions.HTTPError as e: + if response.status_code == 404: + print(f"WARNING: Skipping file {url} because it was not found (404).") + return + else: + raise e + + # --- 4. Determine if streaming conversion is possible --- + should_convert, source_format = _should_convert_file(file, convert_to, convert_from) + streaming = should_convert and source_format is not None + + if streaming: + target_filename = _get_converted_filename(file, source_format, convert_to) + target_filepath = os.path.join(localDir, target_filename) + else: + target_filepath = filename + + total_size_in_bytes = int(response.headers.get("content-length", 0)) + block_size = 1024 # 1 KiB + + if streaming: + # --- 4a. Streaming download + conversion in a single pass --- + print(f"Streaming conversion {source_format} → {convert_to}: {file}") + # Write raw bytes to a temp file first so we can validate checksum + # on the original compressed stream, then convert. + # (Streaming directly through decompression would lose ability to + # checksum the *compressed* bytes the server sent.) + progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + checksum_hasher = hashlib.sha256() if validate_checksum else None + + with open(filename, "wb") as f: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + f.write(data) + if checksum_hasher: + checksum_hasher.update(data) + progress_bar.close() + + # Verify download size + if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: + raise IOError("Downloaded size does not match Content-Length header") + + # Validate checksum of the original compressed file + if validate_checksum: + actual = checksum_hasher.hexdigest() if checksum_hasher else None + if expected_checksum is None: + print(f"WARNING: no expected checksum available for {filename}; skipping validation") + elif actual is None: + print(f"WARNING: could not compute checksum for {filename}; skipping validation") + else: + if actual.lower() != expected_checksum.lower(): + try: + os.remove(filename) + except OSError: + pass + raise IOError( + f"Checksum mismatch for {filename}: expected {expected_checksum}, got {actual}" + ) + + # Now convert the downloaded file + _convert_compression_format(filename, target_filepath, source_format, convert_to) + + else: + # --- 4b. Plain download (no conversion) --- + progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + with open(filename, "wb") as f: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + f.write(data) + progress_bar.close() + + # --- 5. Verify download size --- + if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: + raise IOError("Downloaded size does not match Content-Length header") + + # --- 6. Validate checksum on downloaded file --- + if validate_checksum: + try: + actual, _ = compute_sha256_and_length(filename) + except (OSError, IOError) as e: + print(f"WARNING: error computing checksum for {filename}: {e}") + actual = None + + if expected_checksum is None: + print(f"WARNING: no expected checksum available for {filename}; skipping validation") + elif actual is None: + print(f"WARNING: could not compute checksum for {filename}; skipping validation") + else: + if actual.lower() != expected_checksum.lower(): + try: + os.remove(filename) + except OSError: + pass + raise IOError( + f"Checksum mismatch for {filename}: expected {expected_checksum}, got {actual}" + ) + + +def _download_files( + urls: List[str], + localDir: str, + vault_token_file: str = None, + databus_key: str = None, + auth_url: str = None, + client_id: str = None, + convert_to: str = None, + convert_from: str = None, + validate_checksum: bool = False, + checksums: dict | None = None, +) -> None: + """Download multiple files from the databus. + + Args: + urls: List of file download URLs. + localDir: Local directory to download files to. If None, the databus folder structure is created in the current working directory. + vault_token_file: Path to Vault refresh token file. + databus_key: Databus API key for protected downloads. + auth_url: Keycloak token endpoint URL. + client_id: Client ID for token exchange. + convert_to: Target compression format for on-the-fly conversion. + convert_from: Optional source compression format filter. + validate_checksum: Whether to validate checksums after downloading. + checksums: Dictionary mapping URLs to their expected checksums. + """ + for url in urls: + expected = None + if checksums and isinstance(checksums, dict): + expected = checksums.get(url) + _download_file( + url=url, + localDir=localDir, + vault_token_file=vault_token_file, + databus_key=databus_key, + auth_url=auth_url, + client_id=client_id, + convert_to=convert_to, + convert_from=convert_from, + validate_checksum=validate_checksum, + expected_checksum=expected, + ) + + +def _get_sparql_query_of_collection(uri: str, databus_key: str | None = None) -> str: + """Get SPARQL query of collection members from databus collection URI. + + Args: + uri: The full databus collection URI. + databus_key: Optional Databus API key for authentication on protected resources. + + Returns: + SPARQL query string to get download URLs of all files in the collection. + """ + headers = {"Accept": "text/sparql"} + if databus_key is not None: + headers["X-API-KEY"] = databus_key + + response = requests.get(uri, headers=headers, timeout=30) + response.raise_for_status() + return response.text + + +def _query_sparql_endpoint(endpoint_url, query, databus_key=None) -> dict: + """Query a SPARQL endpoint and return results in JSON format. + + Args: + endpoint_url: The URL of the SPARQL endpoint. + query: The SPARQL query string. + databus_key: Optional API key for authentication. + + Returns: + Dictionary containing the query results. + """ + sparql = SPARQLWrapper(endpoint_url) + sparql.method = "POST" + sparql.setQuery(query) + sparql.setReturnFormat(JSON) + if databus_key is not None: + sparql.setCustomHttpHeaders({"X-API-KEY": databus_key}) + results = sparql.query().convert() + return results + + +def _get_file_download_urls_from_sparql_query( + endpoint_url, query, databus_key=None +) -> List[str]: + """Execute a SPARQL query to get databus file download URLs. + + Args: + endpoint_url: The URL of the SPARQL endpoint. + query: The SPARQL query string. + databus_key: Optional API key for authentication. + + Returns: + List of file download URLs. + """ + result_dict = _query_sparql_endpoint(endpoint_url, query, databus_key=databus_key) + + bindings = result_dict.get("results", {}).get("bindings") + if not isinstance(bindings, list): + raise ValueError("Invalid SPARQL response: 'bindings' missing or not a list") + + urls: List[str] = [] + + for binding in bindings: + if not isinstance(binding, dict) or len(binding) != 1: + raise ValueError(f"Invalid SPARQL binding structure: {binding}") + + value_dict = next(iter(binding.values())) + value = value_dict.get("value") + + if not isinstance(value, str): + raise ValueError(f"Invalid SPARQL value field: {value_dict}") + + urls.append(value) + + return urls + + +def __get_vault_access__( + download_url: str, token_file: str, auth_url: str, client_id: str +) -> str: + """ + Get Vault access token for a protected databus download. + """ + # 1. Load refresh token + refresh_token = os.environ.get("REFRESH_TOKEN") + if not refresh_token: + if not os.path.exists(token_file): + raise FileNotFoundError(f"Vault token file not found: {token_file}") + with open(token_file, "r") as f: + refresh_token = f.read().strip() + if len(refresh_token) < 80: + print(f"Warning: token from {token_file} is short (<80 chars)") + + # 2. Refresh token -> access token + resp = requests.post( + auth_url, + data={ + "client_id": client_id, + "grant_type": "refresh_token", + "refresh_token": refresh_token, + }, + timeout=30, + ) + resp.raise_for_status() + access_token = resp.json()["access_token"] + + # 3. Extract host as audience + # Remove protocol prefix + if download_url.startswith("https://"): + host_part = download_url[len("https://") :] + elif download_url.startswith("http://"): + host_part = download_url[len("http://") :] + else: + host_part = download_url + audience = host_part.split("/")[0] # host is before first "/" + + # 4. Access token -> Vault token + resp = requests.post( + auth_url, + data={ + "client_id": client_id, + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "subject_token": access_token, + "audience": audience, + }, + timeout=30, + ) + resp.raise_for_status() + vault_token = resp.json()["access_token"] + + print(f"Using Vault access token for {download_url}") + return vault_token + + +def _download_collection( + uri: str, + endpoint: str, + localDir: str, + vault_token: str = None, + databus_key: str = None, + auth_url: str = None, + client_id: str = None, + convert_to: str = None, + convert_from: str = None, + validate_checksum: bool = False, +) -> None: + """Download all files in a databus collection. + + Args: + uri: The full databus collection URI. + endpoint: The databus SPARQL endpoint URL. + localDir: Local directory to download files to. If None, the databus folder structure is created in the current working directory. + vault_token: Path to Vault refresh token file for protected downloads. + databus_key: Databus API key for protected downloads. + auth_url: Keycloak token endpoint URL. + client_id: Client ID for token exchange. + convert_to: Target compression format for on-the-fly conversion. + convert_from: Optional source compression format filter. + validate_checksum: Whether to validate checksums after downloading. + """ + query = _get_sparql_query_of_collection(uri, databus_key=databus_key) + file_urls = _get_file_download_urls_from_sparql_query( + endpoint, query, databus_key=databus_key + ) + + # If checksum validation requested, attempt to build url->checksum mapping + checksums: dict = {} + if validate_checksum: + checksums = _resolve_checksums_for_urls(list(file_urls), databus_key) + + _download_files( + list(file_urls), + localDir, + vault_token_file=vault_token, + databus_key=databus_key, + auth_url=auth_url, + client_id=client_id, + convert_to=convert_to, + convert_from=convert_from, + validate_checksum=validate_checksum, + checksums=checksums if checksums else None, + ) + + +def _download_version( + uri: str, + localDir: str, + vault_token_file: str = None, + databus_key: str = None, + auth_url: str = None, + client_id: str = None, + convert_to: str = None, + convert_from: str = None, + validate_checksum: bool = False, +) -> None: + """Download all files in a databus artifact version. + + Args: + uri: The full databus artifact version URI. + localDir: Local directory to download files to. If None, the databus folder structure is created in the current working directory. + vault_token_file: Path to Vault refresh token file for protected downloads. + databus_key: Databus API key for protected downloads. + auth_url: Keycloak token endpoint URL. + client_id: Client ID for token exchange. + convert_to: Target compression format for on-the-fly conversion. + convert_from: Optional source compression format filter. + validate_checksum: Whether to validate checksums after downloading. + """ + json_str = fetch_databus_jsonld(uri, databus_key=databus_key) + file_urls = _get_file_download_urls_from_artifact_jsonld(json_str) + # build url -> checksum mapping from JSON-LD when available + checksums: dict = {} + try: + checksums = _extract_checksums_from_jsonld(json_str) + except Exception: + checksums = {} + + _download_files( + file_urls, + localDir, + vault_token_file=vault_token_file, + databus_key=databus_key, + auth_url=auth_url, + client_id=client_id, + convert_to=convert_to, + convert_from=convert_from, + validate_checksum=validate_checksum, + checksums=checksums, + ) + + +def _download_artifact( + uri: str, + localDir: str, + all_versions: bool = False, + vault_token_file: str = None, + databus_key: str = None, + auth_url: str = None, + client_id: str = None, + convert_to: str = None, + convert_from: str = None, + validate_checksum: bool = False, +) -> None: + """Download files in a databus artifact. + + Args: + uri: The full databus artifact URI. + localDir: Local directory to download files to. If None, the databus folder structure is created in the current working directory. + all_versions: If True, download all versions of the artifact; otherwise, only download the latest version. + vault_token_file: Path to Vault refresh token file for protected downloads. + databus_key: Databus API key for protected downloads. + auth_url: Keycloak token endpoint URL. + client_id: Client ID for token exchange. + convert_to: Target compression format for on-the-fly conversion. + convert_from: Optional source compression format filter. + validate_checksum: Whether to validate checksums after downloading. + """ + json_str = fetch_databus_jsonld(uri, databus_key=databus_key) + versions = _get_databus_versions_of_artifact(json_str, all_versions=all_versions) + if isinstance(versions, str): + versions = [versions] + for version_uri in versions: + print(f"Downloading version: {version_uri}") + json_str = fetch_databus_jsonld(version_uri, databus_key=databus_key) + file_urls = _get_file_download_urls_from_artifact_jsonld(json_str) + # extract checksums for this version + checksums: dict = {} + try: + checksums = _extract_checksums_from_jsonld(json_str) + except Exception: + checksums = {} + + _download_files( + file_urls, + localDir, + vault_token_file=vault_token_file, + databus_key=databus_key, + auth_url=auth_url, + client_id=client_id, + convert_to=convert_to, + convert_from=convert_from, + validate_checksum=validate_checksum, + checksums=checksums, + ) + + +def _get_databus_versions_of_artifact( + json_str: str, all_versions: bool +) -> str | List[str]: + """Parse the JSON-LD of a databus artifact to extract URLs of its versions. + + Args: + json_str: JSON-LD string of the databus artifact. + all_versions: If True, return all version URLs; otherwise, return only the latest version URL. + + Returns: + If all_versions is True: List of all version URLs. + If all_versions is False: URL of the latest version. + """ + json_dict = json.loads(json_str) + versions = json_dict.get("databus:hasVersion") + + if versions is None: + raise ValueError("No 'databus:hasVersion' field in artifact JSON-LD") + + if isinstance(versions, dict): + versions = [versions] + elif not isinstance(versions, list): + raise ValueError( + f"Unexpected type for 'databus:hasVersion': {type(versions).__name__}" + ) + + version_urls = [v["@id"] for v in versions if isinstance(v, dict) and "@id" in v] + + if not version_urls: + raise ValueError("No versions found in artifact JSON-LD") + + version_urls.sort(reverse=True) # Sort versions in descending order + + if all_versions: + return version_urls + return version_urls[0] + + +def _get_file_download_urls_from_artifact_jsonld(json_str: str) -> List[str]: + """Parse the JSON-LD of a databus artifact version to extract download URLs. + + Args: + json_str: JSON-LD string of the databus artifact version. + + Returns: + List of all file download URLs in the artifact version. + """ + + databusIdUrl: List[str] = [] + + json_dict = json.loads(json_str) + graph = json_dict.get("@graph", []) + for node in graph: + if node.get("@type") == "Part": + file_uri = node.get("file") + if not isinstance(file_uri, str): + continue + databusIdUrl.append(file_uri) + return databusIdUrl + + +def _download_group( + uri: str, + localDir: str, + all_versions: bool = False, + vault_token_file: str = None, + databus_key: str = None, + auth_url: str = None, + client_id: str = None, + convert_to: str = None, + convert_from: str = None, + validate_checksum: bool = False, +) -> None: + """Download files in a databus group. + + Args: + uri: The full databus group URI. + localDir: Local directory to download files to. If None, the databus folder structure is created in the current working directory. + all_versions: If True, download all versions of each artifact in the group; otherwise, only download the latest version. + vault_token_file: Path to Vault refresh token file for protected downloads. + databus_key: Databus API key for protected downloads. + auth_url: Keycloak token endpoint URL. + client_id: Client ID for token exchange. + convert_to: Target compression format for on-the-fly conversion. + convert_from: Optional source compression format filter. + validate_checksum: Whether to validate checksums after downloading. + """ + json_str = fetch_databus_jsonld(uri, databus_key=databus_key) + artifacts = _get_databus_artifacts_of_group(json_str) + for artifact_uri in artifacts: + print(f"Download artifact: {artifact_uri}") + _download_artifact( + artifact_uri, + localDir, + all_versions=all_versions, + vault_token_file=vault_token_file, + databus_key=databus_key, + auth_url=auth_url, + client_id=client_id, + convert_to=convert_to, + convert_from=convert_from, + validate_checksum=validate_checksum, + ) + + +def _get_databus_artifacts_of_group(json_str: str) -> List[str]: + """ + Parse the JSON-LD of a databus group to extract URLs of all artifacts. + + Returns a list of artifact URLs. + """ + json_dict = json.loads(json_str) + artifacts = json_dict.get("databus:hasArtifact") + + if artifacts is None: + return [] + + if isinstance(artifacts, dict): + artifacts_iter = [artifacts] + elif isinstance(artifacts, list): + artifacts_iter = artifacts + else: + raise ValueError( + f"Unexpected type for 'databus:hasArtifact': {type(artifacts).__name__}" + ) + + result: List[str] = [] + for item in artifacts_iter: + if not isinstance(item, dict): + continue + uri = item.get("@id") + if not uri: + continue + _, _, _, _, version, _ = get_databus_id_parts_from_file_url(uri) + if version is None: + result.append(uri) + return result + + +def download( + localDir: str, + endpoint: str, + databusURIs: List[str], + token=None, + databus_key=None, + all_versions=None, + auth_url="https://auth.dbpedia.org/realms/dbpedia/protocol/openid-connect/token", + client_id="vault-token-exchange", + convert_to=None, + convert_from=None, + validate_checksum: bool = False +) -> None: + """Download datasets from databus. + + Download of files, versions, artifacts, groups or databus collections via their databus URIs or user-defined SPARQL queries that return file download URLs. + + Args: + localDir: Local directory to download datasets to. If None, the databus folder structure is created in the current working directory. + endpoint: The databus endpoint URL. If None, inferred from databusURI. Required for user-defined SPARQL queries. + databusURIs: Databus identifiers to specify datasets to download. + token: Path to Vault refresh token file for protected downloads. + databus_key: Databus API key for protected downloads. + auth_url: Keycloak token endpoint URL. Default is "https://auth.dbpedia.org/realms/dbpedia/protocol/openid-connect/token". + client_id: Client ID for token exchange. Default is "vault-token-exchange". + convert_to: Target compression format for on-the-fly conversion (supported: bz2, gz, xz). + convert_from: Optional source compression format filter. + validate_checksum: Whether to validate checksums after downloading. + """ + for databusURI in databusURIs: + host, account, group, artifact, version, file = ( + get_databus_id_parts_from_file_url(databusURI) + ) + + # Determine endpoint per-URI if not explicitly provided + uri_endpoint = endpoint + + # dataID or databus collection + if databusURI.startswith("http://") or databusURI.startswith("https://"): + # Auto-detect sparql endpoint from host if not given + if uri_endpoint is None: + uri_endpoint = f"https://{host}/sparql" + print(f"SPARQL endpoint {uri_endpoint}") + + if group == "collections" and artifact is not None: + print(f"Downloading collection: {databusURI}") + _download_collection( + databusURI, + uri_endpoint, + localDir, + token, + databus_key, + auth_url, + client_id, + convert_to, + convert_from, + validate_checksum=validate_checksum, + ) + elif file is not None: + print(f"Downloading file: {databusURI}") + # Try to fetch expected checksum from the parent Version metadata + expected = None + if validate_checksum: + try: + if version is not None: + version_uri = f"https://{host}/{account}/{group}/{artifact}/{version}" + json_str = fetch_databus_jsonld(version_uri, databus_key=databus_key) + checks = _extract_checksums_from_jsonld(json_str) + expected = checks.get(databusURI) or checks.get( + "https://" + databusURI.removeprefix("http://").removeprefix("https://") + ) + except Exception as e: + print(f"WARNING: Could not fetch checksum for single file: {e}") + + # Call the worker to download the single file (passes expected checksum) + _download_file( + databusURI, + localDir, + vault_token_file=token, + databus_key=databus_key, + auth_url=auth_url, + client_id=client_id, + convert_to=convert_to, + convert_from=convert_from, + validate_checksum=validate_checksum, + expected_checksum=expected, + ) + elif version is not None: + print(f"Downloading version: {databusURI}") + _download_version( + databusURI, + localDir, + vault_token_file=token, + databus_key=databus_key, + auth_url=auth_url, + client_id=client_id, + convert_to=convert_to, + convert_from=convert_from, + validate_checksum=validate_checksum, + ) + elif artifact is not None: + print( + f"Downloading {'all' if all_versions else 'latest'} version(s) of artifact: {databusURI}" + ) + _download_artifact( + databusURI, + localDir, + all_versions=all_versions, + vault_token_file=token, + databus_key=databus_key, + auth_url=auth_url, + client_id=client_id, + convert_to=convert_to, + convert_from=convert_from, + validate_checksum=validate_checksum, + ) + elif group is not None and group != "collections": + print( + f"Downloading group and all its artifacts and versions: {databusURI}" + ) + _download_group( + databusURI, + localDir, + all_versions=all_versions, + vault_token_file=token, + databus_key=databus_key, + auth_url=auth_url, + client_id=client_id, + convert_to=convert_to, + convert_from=convert_from, + validate_checksum=validate_checksum, + ) + elif account is not None: + print("accountId not supported yet") # TODO + else: + print( + "dataId not supported yet" + ) # TODO add support for other DatabusIds + # query in local file + elif databusURI.startswith("file://"): + print("query in file not supported yet") + # query as argument + else: + print("QUERY {}", databusURI.replace("\n", " ")) + if uri_endpoint is None: # endpoint is required for queries (--databus) + raise ValueError("No endpoint given for query") + res = _get_file_download_urls_from_sparql_query( + uri_endpoint, databusURI, databus_key=databus_key + ) + + # If checksum validation requested, try to build url->checksum mapping + checksums: dict = {} + if validate_checksum: + checksums = _resolve_checksums_for_urls(res, databus_key) + if not checksums: + print("WARNING: Checksum validation enabled but no checksums found for query results.") + + _download_files( + res, + localDir, + vault_token_file=token, + databus_key=databus_key, + auth_url=auth_url, + client_id=client_id, + convert_to=convert_to, + convert_from=convert_from, + validate_checksum=validate_checksum, + checksums=checksums if checksums else None, + ) diff --git a/databusclient/cli.py b/databusclient/cli.py index 1daa4bb..082b43c 100644 --- a/databusclient/cli.py +++ b/databusclient/cli.py @@ -1,248 +1,266 @@ -#!/usr/bin/env python3 -import json -import os -from typing import List - -import click - -import databusclient.api.deploy as api_deploy -from databusclient.api.delete import delete as api_delete -from databusclient.api.download import download as api_download, DownloadAuthError -from databusclient.extensions import webdav - - -@click.group() -def app(): - """Databus Client CLI. - - Provides `deploy`, `download`, and `delete` commands for interacting - with the DBpedia Databus. - """ - pass - - -@app.command() -@click.option( - "--version-id", - "version_id", - required=True, - help="Target databus version/dataset identifier of the form " - "", -) -@click.option("--title", required=True, help="Artifact & Version Title: used for BOTH artifact and version. Keep stable across releases; identifies the data series.") -@click.option("--abstract", required=True, help="Artifact & Version Abstract: used for BOTH artifact and version (max 200 chars). Updating it changes both artifact and version metadata.") -@click.option("--description", required=True, help="Artifact & Version Description: used for BOTH artifact and version. Supports Markdown. Updating it changes both artifact and version metadata.") -@click.option( - "--license", "license_url", required=True, help="License (see dalicc.net)" -) -@click.option("--apikey", required=True, help="API key") -@click.option( - "--metadata", - "metadata_file", - type=click.Path(exists=True), - help="Path to metadata JSON file (for metadata mode)", -) -@click.option( - "--webdav-url", - "webdav_url", - help="WebDAV URL (e.g., https://cloud.example.com/remote.php/webdav)", -) -@click.option("--remote", help="rclone remote name (e.g., 'nextcloud')") -@click.option("--path", help="Remote path on Nextcloud (e.g., 'datasets/mydataset')") -@click.argument("distributions", nargs=-1) -def deploy( - version_id, - title, - abstract, - description, - license_url, - apikey, - metadata_file, - webdav_url, - remote, - path, - distributions: List[str], -): - """ - Flexible deploy to Databus command supporting three modes:\n - - Classic deploy (distributions as arguments)\n - - Metadata-based deploy (--metadata )\n - - Upload & deploy via Nextcloud (--webdav-url, --remote, --path) - """ - - # Sanity checks for conflicting options - if metadata_file and any([distributions, webdav_url, remote, path]): - raise click.UsageError( - "Invalid combination: when using --metadata, do not provide --webdav-url, --remote, --path, or distributions." - ) - if any([webdav_url, remote, path]) and not all([webdav_url, remote, path]): - raise click.UsageError( - "Invalid combination: when using WebDAV/Nextcloud mode, please provide --webdav-url, --remote, and --path together." - ) - - # === Mode 1: Classic Deploy === - if distributions and not (metadata_file or webdav_url or remote or path): - click.echo("[MODE] Classic deploy with distributions") - click.echo(f"Deploying dataset version: {version_id}") - - dataid = api_deploy.create_dataset( - version_id=version_id, - artifact_version_title=title, - artifact_version_abstract=abstract, - artifact_version_description=description, - license_url=license_url, - distributions=distributions - ) - api_deploy.deploy(dataid=dataid, api_key=apikey) - return - - # === Mode 2: Metadata File === - if metadata_file: - click.echo(f"[MODE] Deploy from metadata file: {metadata_file}") - with open(metadata_file, "r") as f: - metadata = json.load(f) - api_deploy.deploy_from_metadata( - metadata, version_id, title, abstract, description, license_url, apikey - ) - return - - # === Mode 3: Upload & Deploy (Nextcloud) === - if webdav_url and remote and path: - if not distributions: - raise click.UsageError( - "Please provide files to upload when using WebDAV/Nextcloud mode." - ) - - # Check that all given paths exist and are files or directories. - invalid = [f for f in distributions if not os.path.exists(f)] - if invalid: - raise click.UsageError( - f"The following input files or folders do not exist: {', '.join(invalid)}" - ) - - click.echo("[MODE] Upload & Deploy to DBpedia Databus via Nextcloud") - click.echo(f"→ Uploading to: {remote}:{path}") - metadata = webdav.upload_to_webdav(distributions, remote, path, webdav_url) - api_deploy.deploy_from_metadata( - metadata, version_id, title, abstract, description, license_url, apikey - ) - return - - raise click.UsageError( - "No valid input provided. Please use one of the following modes:\n" - " - Classic deploy: pass distributions as arguments\n" - " - Metadata deploy: use --metadata \n" - " - Upload & deploy: use --webdav-url, --remote, --path, and file arguments" - ) - - -@app.command() -@click.argument("databusuris", nargs=-1, required=True) -@click.option( - "--localdir", - help="Local databus folder (if not given, databus folder structure is created in current working directory)", -) -@click.option( - "--databus", - help="Databus URL (if not given, inferred from databusuri, e.g. https://databus.dbpedia.org/sparql)", -) -@click.option("--vault-token", help="Path to Vault refresh token file") -@click.option( - "--databus-key", help="Databus API key to download from protected databus" -) -@click.option( - "--all-versions", - is_flag=True, - help="When downloading artifacts, download all versions instead of only the latest", -) -@click.option( - "--authurl", - default="https://auth.dbpedia.org/realms/dbpedia/protocol/openid-connect/token", - show_default=True, - help="Keycloak token endpoint URL", -) -@click.option( - "--clientid", - default="vault-token-exchange", - show_default=True, - help="Client ID for token exchange", -) -@click.option( - "--convert-to", - type=click.Choice(["bz2", "gz", "xz"], case_sensitive=False), - help="Target compression format for on-the-fly conversion during download (supported: bz2, gz, xz)", -) -@click.option( - "--convert-from", - type=click.Choice(["bz2", "gz", "xz"], case_sensitive=False), - help="Source compression format to convert from (optional filter). Only files with this compression will be converted.", -) -@click.option( - "--validate-checksum", - is_flag=True, - help="Validate checksums of downloaded files" -) -def download( - databusuris: List[str], - localdir, - databus, - vault_token, - databus_key, - all_versions, - authurl, - clientid, - convert_to, - convert_from, - validate_checksum, -): - """ - Download datasets from databus, optionally using vault access if vault options are provided. - Supports on-the-fly compression format conversion using --convert-to and --convert-from options. - """ - try: - api_download( - localDir=localdir, - endpoint=databus, - databusURIs=databusuris, - token=vault_token, - databus_key=databus_key, - all_versions=all_versions, - auth_url=authurl, - client_id=clientid, - convert_to=convert_to, - convert_from=convert_from, - validate_checksum=validate_checksum, - ) - except DownloadAuthError as e: - raise click.ClickException(str(e)) - - -@app.command() -@click.argument("databusuris", nargs=-1, required=True) -@click.option( - "--databus-key", help="Databus API key to access protected databus", required=True -) -@click.option( - "--dry-run", is_flag=True, help="Perform a dry run without actual deletion" -) -@click.option( - "--force", is_flag=True, help="Force deletion without confirmation prompt" -) -def delete(databusuris: List[str], databus_key: str, dry_run: bool, force: bool): - """ - Delete a dataset from the databus. - - Delete a group, artifact, or version identified by the given databus URI. - Will recursively delete all data associated with the dataset. - """ - - api_delete( - databusURIs=databusuris, - databus_key=databus_key, - dry_run=dry_run, - force=force, - ) - - -if __name__ == "__main__": - app() +#!/usr/bin/env python3 +import json +import os +from typing import List + +import click + +import databusclient.api.deploy as api_deploy +from databusclient.api.delete import delete as api_delete +from databusclient.api.download import download as api_download, DownloadAuthError +from databusclient.extensions import webdav + + +@click.group() +def app(): + """Databus Client CLI. + + Provides `deploy`, `download`, and `delete` commands for interacting + with the DBpedia Databus. + """ + pass + + +@app.command() +@click.option( + "--version-id", + "version_id", + required=True, + help="Target databus version/dataset identifier of the form " + "", +) +@click.option("--title", required=True, help="Artifact & Version Title: used for BOTH artifact and version. Keep stable across releases; identifies the data series.") +@click.option("--abstract", required=True, help="Artifact & Version Abstract: used for BOTH artifact and version (max 200 chars). Updating it changes both artifact and version metadata.") +@click.option("--description", required=True, help="Artifact & Version Description: used for BOTH artifact and version. Supports Markdown. Updating it changes both artifact and version metadata.") +@click.option( + "--license", "license_url", required=True, help="License (see dalicc.net)" +) +@click.option("--apikey", required=True, help="API key") +@click.option( + "--metadata", + "metadata_file", + type=click.Path(exists=True), + help="Path to metadata JSON file (for metadata mode)", +) +@click.option( + "--webdav-url", + "webdav_url", + help="WebDAV URL (e.g., https://cloud.example.com/remote.php/webdav)", +) +@click.option("--remote", help="rclone remote name (e.g., 'nextcloud')") +@click.option("--path", help="Remote path on Nextcloud (e.g., 'datasets/mydataset')") +@click.argument("distributions", nargs=-1) +def deploy( + version_id, + title, + abstract, + description, + license_url, + apikey, + metadata_file, + webdav_url, + remote, + path, + distributions: List[str], +): + """ + Flexible deploy to Databus command supporting three modes:\n + - Classic deploy (distributions as arguments)\n + - Metadata-based deploy (--metadata )\n + - Upload & deploy via Nextcloud (--webdav-url, --remote, --path) + """ + + # Sanity checks for conflicting options + if metadata_file and any([distributions, webdav_url, remote, path]): + raise click.UsageError( + "Invalid combination: when using --metadata, do not provide --webdav-url, --remote, --path, or distributions." + ) + if any([webdav_url, remote, path]) and not all([webdav_url, remote, path]): + raise click.UsageError( + "Invalid combination: when using WebDAV/Nextcloud mode, please provide --webdav-url, --remote, and --path together." + ) + + # === Mode 1: Classic Deploy === + if distributions and not (metadata_file or webdav_url or remote or path): + click.echo("[MODE] Classic deploy with distributions") + click.echo(f"Deploying dataset version: {version_id}") + + dataid = api_deploy.create_dataset( + version_id=version_id, + artifact_version_title=title, + artifact_version_abstract=abstract, + artifact_version_description=description, + license_url=license_url, + distributions=distributions + ) + api_deploy.deploy(dataid=dataid, api_key=apikey) + return + + # === Mode 2: Metadata File === + if metadata_file: + click.echo(f"[MODE] Deploy from metadata file: {metadata_file}") + with open(metadata_file, "r") as f: + metadata = json.load(f) + api_deploy.deploy_from_metadata( + metadata, version_id, title, abstract, description, license_url, apikey + ) + return + + # === Mode 3: Upload & Deploy (Nextcloud) === + if webdav_url and remote and path: + if not distributions: + raise click.UsageError( + "Please provide files to upload when using WebDAV/Nextcloud mode." + ) + + # Check that all given paths exist and are files or directories. + invalid = [f for f in distributions if not os.path.exists(f)] + if invalid: + raise click.UsageError( + f"The following input files or folders do not exist: {', '.join(invalid)}" + ) + + click.echo("[MODE] Upload & Deploy to DBpedia Databus via Nextcloud") + click.echo(f"→ Uploading to: {remote}:{path}") + metadata = webdav.upload_to_webdav(distributions, remote, path, webdav_url) + api_deploy.deploy_from_metadata( + metadata, version_id, title, abstract, description, license_url, apikey + ) + return + + raise click.UsageError( + "No valid input provided. Please use one of the following modes:\n" + " - Classic deploy: pass distributions as arguments\n" + " - Metadata deploy: use --metadata \n" + " - Upload & deploy: use --webdav-url, --remote, --path, and file arguments" + ) + + +@app.command() +@click.argument("databusuris", nargs=-1, required=True) +@click.option( + "--localdir", + help="Local databus folder (if not given, databus folder structure is created in current working directory)", +) +@click.option( + "--databus", + help="Databus URL (if not given, inferred from databusuri, e.g. https://databus.dbpedia.org/sparql)", +) +@click.option("--vault-token", help="Path to Vault refresh token file") +@click.option( + "--databus-key", help="Databus API key to download from protected databus" +) +@click.option( + "--all-versions", + is_flag=True, + help="When downloading artifacts, download all versions instead of only the latest", +) +@click.option( + "--authurl", + default="https://auth.dbpedia.org/realms/dbpedia/protocol/openid-connect/token", + show_default=True, + help="Keycloak token endpoint URL", +) +@click.option( + "--clientid", + default="vault-token-exchange", + show_default=True, + help="Client ID for token exchange", +) +@click.option( + "--convert-to", + type=click.Choice(["bz2", "gz", "xz", "none"], case_sensitive=False), + help="Target compression format for on-the-fly conversion during download. " + "Use 'none' to decompress files to raw format.", +) +@click.option( + "--convert-from", + type=click.Choice(["bz2", "gz", "xz", "none"], case_sensitive=False), + help="Source compression format to convert from (optional filter). " + "Use 'none' when compressing uncompressed files.", +) +@click.option( + "--decompress", + is_flag=True, + help="Decompress downloaded files to raw format. Shorthand for --convert-to none.", +) +@click.option( + "--validate-checksum", + is_flag=True, + help="Validate checksums of downloaded files" +) +def download( + databusuris: List[str], + localdir, + databus, + vault_token, + databus_key, + all_versions, + authurl, + clientid, + convert_to, + convert_from, + decompress, + validate_checksum, +): + """ + Download datasets from databus, optionally using vault access if vault options are provided. + Supports on-the-fly compression format conversion using --convert-to and --convert-from options. + Use --decompress (or --convert-to none) to download and decompress files to raw format. + """ + # --decompress is shorthand for --convert-to none + if decompress: + if convert_to is not None: + raise click.UsageError( + "Cannot use --decompress together with --convert-to. " + "Use one or the other." + ) + convert_to = "none" + + try: + api_download( + localDir=localdir, + endpoint=databus, + databusURIs=databusuris, + token=vault_token, + databus_key=databus_key, + all_versions=all_versions, + auth_url=authurl, + client_id=clientid, + convert_to=convert_to, + convert_from=convert_from, + validate_checksum=validate_checksum, + ) + except DownloadAuthError as e: + raise click.ClickException(str(e)) + + +@app.command() +@click.argument("databusuris", nargs=-1, required=True) +@click.option( + "--databus-key", help="Databus API key to access protected databus", required=True +) +@click.option( + "--dry-run", is_flag=True, help="Perform a dry run without actual deletion" +) +@click.option( + "--force", is_flag=True, help="Force deletion without confirmation prompt" +) +def delete(databusuris: List[str], databus_key: str, dry_run: bool, force: bool): + """ + Delete a dataset from the databus. + + Delete a group, artifact, or version identified by the given databus URI. + Will recursively delete all data associated with the dataset. + """ + + api_delete( + databusURIs=databusuris, + databus_key=databus_key, + dry_run=dry_run, + force=force, + ) + + +if __name__ == "__main__": + app() diff --git a/databusclient/extensions/__init__.py b/databusclient/extensions/__init__.py index 8b13789..6eec02e 100644 --- a/databusclient/extensions/__init__.py +++ b/databusclient/extensions/__init__.py @@ -1 +1,13 @@ - +from .file_converter import ( + FileConverter, + COMPRESSION_EXTENSIONS, + COMPRESSION_MODULES, + MAGIC_NUMBERS, +) + +__all__ = [ + "FileConverter", + "COMPRESSION_EXTENSIONS", + "COMPRESSION_MODULES", + "MAGIC_NUMBERS", +] diff --git a/databusclient/extensions/file_converter.py b/databusclient/extensions/file_converter.py new file mode 100644 index 0000000..09d55e5 --- /dev/null +++ b/databusclient/extensions/file_converter.py @@ -0,0 +1,472 @@ +"""File format conversion extension for databus-python-client. + +Provides streaming pipeline for file decompression, re-compression, +and checksum validation during download operations. + +Supports gzip, bz2, xz formats natively. Optional zstd support is +available when the ``zstandard`` package is installed. + +The special format name ``'none'`` represents an uncompressed / raw file. +""" + +import bz2 +import gzip +import hashlib +import lzma +import os +from typing import BinaryIO, Dict, Optional + +# --- Optional zstd support --------------------------------------------------- +try: + import zstandard as _zstd + + _HAS_ZSTD = True +except ImportError: # pragma: no cover + _zstd = None + _HAS_ZSTD = False + + +# --------------------------------------------------------------------------- +# Module-level constants +# --------------------------------------------------------------------------- + +COMPRESSION_EXTENSIONS: Dict[str, str] = { + "bz2": ".bz2", + "gz": ".gz", + "xz": ".xz", +} + +COMPRESSION_MODULES: Dict[str, object] = { + "bz2": bz2, + "gz": gzip, + "xz": lzma, +} + +if _HAS_ZSTD: + COMPRESSION_EXTENSIONS["zstd"] = ".zst" + # zstandard doesn't expose a module-level open(); handled specially. + COMPRESSION_MODULES["zstd"] = _zstd + +# Magic-number signatures (first N bytes -> format). +MAGIC_NUMBERS: Dict[bytes, str] = { + b"\x1f\x8b": "gz", # gzip + b"BZ": "bz2", # bzip2 (BZh...) + b"\xfd7zXZ\x00": "xz", # xz / LZMA +} +if _HAS_ZSTD: + MAGIC_NUMBERS[b"\x28\xb5\x2f\xfd"] = "zstd" + + +class FileConverter: + """Handles file format conversion with streaming support. + + All public methods are ``@staticmethod``; instantiation is not required. + """ + + CHUNK_SIZE = 8192 # 8 KiB chunks for streaming + + # ------------------------------------------------------------------ + # Format detection + # ------------------------------------------------------------------ + + @staticmethod + def detect_format(filename: str, header_bytes: Optional[bytes] = None) -> str: + """Detect the compression format of a file. + + Checks the file *extension* first. When *header_bytes* are provided + the magic-number signature is also inspected and takes precedence if + the extension is ambiguous. + + Args: + filename: Name (or path) of the file. + header_bytes: Optional first bytes of the file content for + magic-number detection. + + Returns: + ``'gz'``, ``'bz2'``, ``'xz'``, ``'zstd'`` or ``'none'``. + """ + # 1) Extension-based detection + filename_lower = filename.lower() + for fmt, ext in COMPRESSION_EXTENSIONS.items(): + if filename_lower.endswith(ext): + return fmt + + # 2) Magic-number fallback + if header_bytes: + detected = FileConverter.detect_format_by_magic(header_bytes) + if detected != "none": + return detected + + return "none" + + @staticmethod + def detect_format_by_magic(header_bytes: bytes) -> str: + """Detect compression format from raw magic bytes. + + Args: + header_bytes: The first bytes of file content (≥6 bytes + recommended). + + Returns: + ``'gz'``, ``'bz2'``, ``'xz'``, ``'zstd'`` or ``'none'``. + """ + for magic, fmt in MAGIC_NUMBERS.items(): + if header_bytes[: len(magic)] == magic: + return fmt + return "none" + + # ------------------------------------------------------------------ + # Individual stream helpers (gzip) + # ------------------------------------------------------------------ + + @staticmethod + def decompress_gzip_stream( + input_stream: BinaryIO, + output_stream: BinaryIO, + validate_checksum: bool = False, + ) -> Optional[str]: + """Decompress gzip stream with optional checksum computation. + + Args: + input_stream: Input gzip compressed stream. + output_stream: Output decompressed stream. + validate_checksum: Whether to compute a SHA-256 checksum of + the decompressed output. + + Returns: + Hex-encoded SHA-256 checksum when *validate_checksum* is + ``True``, otherwise ``None``. + """ + hasher = hashlib.sha256() if validate_checksum else None + + with gzip.open(input_stream, "rb") as gz: + while True: + chunk = gz.read(FileConverter.CHUNK_SIZE) + if not chunk: + break + output_stream.write(chunk) + if hasher: + hasher.update(chunk) + + return hasher.hexdigest() if hasher else None + + @staticmethod + def compress_gzip_stream( + input_stream: BinaryIO, output_stream: BinaryIO + ) -> None: + """Compress stream to gzip format.""" + with gzip.open(output_stream, "wb") as gz: + while True: + chunk = input_stream.read(FileConverter.CHUNK_SIZE) + if not chunk: + break + gz.write(chunk) + + # ------------------------------------------------------------------ + # Individual stream helpers (bz2) + # ------------------------------------------------------------------ + + @staticmethod + def decompress_bz2_stream( + input_stream: BinaryIO, + output_stream: BinaryIO, + validate_checksum: bool = False, + ) -> Optional[str]: + """Decompress bz2 stream with optional checksum computation.""" + hasher = hashlib.sha256() if validate_checksum else None + + with bz2.open(input_stream, "rb") as bf: + while True: + chunk = bf.read(FileConverter.CHUNK_SIZE) + if not chunk: + break + output_stream.write(chunk) + if hasher: + hasher.update(chunk) + + return hasher.hexdigest() if hasher else None + + @staticmethod + def compress_bz2_stream( + input_stream: BinaryIO, output_stream: BinaryIO + ) -> None: + """Compress stream to bz2 format.""" + with bz2.open(output_stream, "wb") as bf: + while True: + chunk = input_stream.read(FileConverter.CHUNK_SIZE) + if not chunk: + break + bf.write(chunk) + + # ------------------------------------------------------------------ + # Individual stream helpers (xz / LZMA) + # ------------------------------------------------------------------ + + @staticmethod + def decompress_xz_stream( + input_stream: BinaryIO, + output_stream: BinaryIO, + validate_checksum: bool = False, + ) -> Optional[str]: + """Decompress xz stream with optional checksum computation.""" + hasher = hashlib.sha256() if validate_checksum else None + + with lzma.open(input_stream, "rb") as xf: + while True: + chunk = xf.read(FileConverter.CHUNK_SIZE) + if not chunk: + break + output_stream.write(chunk) + if hasher: + hasher.update(chunk) + + return hasher.hexdigest() if hasher else None + + @staticmethod + def compress_xz_stream( + input_stream: BinaryIO, output_stream: BinaryIO + ) -> None: + """Compress stream to xz format.""" + with lzma.open(output_stream, "wb") as xf: + while True: + chunk = input_stream.read(FileConverter.CHUNK_SIZE) + if not chunk: + break + xf.write(chunk) + + # ------------------------------------------------------------------ + # Checksum validation + # ------------------------------------------------------------------ + + @staticmethod + def validate_checksum_stream( + input_stream: BinaryIO, expected_checksum: str + ) -> bool: + """Validate SHA-256 checksum of a stream. + + The stream is rewound to position 0 both before and after reading. + + Args: + input_stream: Seekable input stream. + expected_checksum: Expected SHA-256 hex digest. + + Returns: + ``True`` if checksum matches. + + Raises: + IOError: If the checksum does not match. + """ + hasher = hashlib.sha256() + input_stream.seek(0) + + while True: + chunk = input_stream.read(FileConverter.CHUNK_SIZE) + if not chunk: + break + hasher.update(chunk) + + computed = hasher.hexdigest() + input_stream.seek(0) + if computed.lower() != expected_checksum.lower(): + raise IOError( + f"Checksum mismatch: expected {expected_checksum}, got {computed}" + ) + return True + + # ------------------------------------------------------------------ + # High-level: convert on-disk files + # ------------------------------------------------------------------ + + @staticmethod + def convert_file( + source_path: str, + target_path: str, + source_format: str, + target_format: str, + ) -> None: + """Convert a file between compression formats. + + ``source_format`` / ``target_format`` may be ``'none'`` to mean + "raw / uncompressed". So ``convert_file(f, t, 'gz', 'none')`` + decompresses and ``convert_file(f, t, 'none', 'bz2')`` compresses. + + The *source_path* is removed on success. + + Raises: + ValueError: If a format is not recognised. + RuntimeError: If the conversion fails. + """ + _validate_format(source_format, "source") + _validate_format(target_format, "target") + + print( + f"Converting {source_format} → {target_format}: " + f"{os.path.basename(source_path)}" + ) + + try: + with _open_reader(source_path, source_format) as reader: + with _open_writer(target_path, target_format) as writer: + while True: + chunk = reader.read(FileConverter.CHUNK_SIZE) + if not chunk: + break + writer.write(chunk) + + os.remove(source_path) + print(f"Conversion complete: {os.path.basename(target_path)}") + except Exception as e: + if os.path.exists(target_path): + os.remove(target_path) + raise RuntimeError(f"Compression conversion failed: {e}") + + # ------------------------------------------------------------------ + # High-level: streaming conversion on file-like objects + # ------------------------------------------------------------------ + + @staticmethod + def convert_stream( + input_stream: BinaryIO, + output_stream: BinaryIO, + source_format: str, + target_format: str, + validate_checksum: bool = False, + ) -> Optional[str]: + """Stream conversion between two file-like objects. + + Data is read from *input_stream*, decompressed (if + ``source_format != 'none'``), recompressed (if + ``target_format != 'none'``), and written to *output_stream*. + + When *validate_checksum* is ``True`` the SHA-256 digest of the + **decompressed** (intermediate) bytes is returned. + + Args: + input_stream: Source file-like object (binary read). + output_stream: Target file-like object (binary write). + source_format: Compression format of *input_stream*. + target_format: Compression format for *output_stream*. + validate_checksum: Compute SHA-256 of decompressed data. + + Returns: + Hex SHA-256 digest when *validate_checksum* is ``True``, + otherwise ``None``. + """ + _validate_format(source_format, "source") + _validate_format(target_format, "target") + + hasher = hashlib.sha256() if validate_checksum else None + + # Build a reader wrapper that yields decompressed chunks + reader = _wrap_reader(input_stream, source_format) + writer_ctx = _wrap_writer(output_stream, target_format) + + with writer_ctx as writer: + while True: + chunk = reader.read(FileConverter.CHUNK_SIZE) + if not chunk: + break + if hasher: + hasher.update(chunk) + writer.write(chunk) + + # Close the reader wrapper if it supports it + if hasattr(reader, "close") and reader is not input_stream: + reader.close() + + return hasher.hexdigest() if hasher else None + + # ------------------------------------------------------------------ + # Filename helpers + # ------------------------------------------------------------------ + + @staticmethod + def get_converted_filename( + filename: str, source_format: str, target_format: str + ) -> str: + """Generate the new filename after format conversion. + + Handles ``'none'`` by stripping / adding extensions as needed. + """ + # Strip existing compression extension (if any) + if source_format != "none": + source_ext = COMPRESSION_EXTENSIONS[source_format] + if filename.lower().endswith(source_ext): + filename = filename[: -len(source_ext)] + + # Append new compression extension (if any) + if target_format != "none": + target_ext = COMPRESSION_EXTENSIONS[target_format] + filename = filename + target_ext + + return filename + + +# --------------------------------------------------------------------------- +# Private helpers +# --------------------------------------------------------------------------- + +_VALID_FORMATS = set(COMPRESSION_MODULES.keys()) | {"none"} + + +def _validate_format(fmt: str, label: str) -> None: + if fmt not in _VALID_FORMATS: + raise ValueError( + f"Unsupported {label} compression format: {fmt}. " + f"Supported formats: {sorted(_VALID_FORMATS)}" + ) + + +def _open_reader(path: str, fmt: str): + """Return a context-manager file object for *reading*.""" + if fmt == "none": + return open(path, "rb") + if fmt == "zstd" and _HAS_ZSTD: + fh = open(path, "rb") + dctx = _zstd.ZstdDecompressor() + return dctx.stream_reader(fh) + return COMPRESSION_MODULES[fmt].open(path, "rb") + + +def _open_writer(path: str, fmt: str): + """Return a context-manager file object for *writing*.""" + if fmt == "none": + return open(path, "wb") + if fmt == "zstd" and _HAS_ZSTD: + fh = open(path, "wb") + cctx = _zstd.ZstdCompressor() + return cctx.stream_writer(fh) + return COMPRESSION_MODULES[fmt].open(path, "wb") + + +class _NullCtx: + """Tiny wrapper to turn a plain file-like into a no-op context manager.""" + + def __init__(self, obj): + self._obj = obj + + def __enter__(self): + return self._obj + + def __exit__(self, *exc): + return False + + +def _wrap_reader(stream: BinaryIO, fmt: str): + """Wrap *stream* so that ``read()`` yields decompressed bytes.""" + if fmt == "none": + return stream + if fmt == "zstd" and _HAS_ZSTD: + dctx = _zstd.ZstdDecompressor() + return dctx.stream_reader(stream) + return COMPRESSION_MODULES[fmt].open(stream, "rb") + + +def _wrap_writer(stream: BinaryIO, fmt: str): + """Return a context-manager wrapping *stream* for compressed writing.""" + if fmt == "none": + return _NullCtx(stream) + if fmt == "zstd" and _HAS_ZSTD: + cctx = _zstd.ZstdCompressor() + return cctx.stream_writer(stream) + return COMPRESSION_MODULES[fmt].open(stream, "wb") diff --git a/tests/test_compression_conversion.py b/tests/test_compression_conversion.py index a8c7618..5a58431 100644 --- a/tests/test_compression_conversion.py +++ b/tests/test_compression_conversion.py @@ -1,198 +1,198 @@ -"""Tests for on-the-fly compression conversion feature""" - -import os -import gzip -import bz2 -import lzma -import tempfile -import pytest -from databusclient.api.download import ( - _detect_compression_format, - _should_convert_file, - _get_converted_filename, - _convert_compression_format, -) - - -def test_detect_compression_format(): - """Test compression format detection from filenames""" - assert _detect_compression_format("file.txt.bz2") == "bz2" - assert _detect_compression_format("file.txt.gz") == "gz" - assert _detect_compression_format("file.txt.xz") == "xz" - assert _detect_compression_format("file.txt") is None - assert _detect_compression_format("FILE.TXT.GZ") == "gz" # case insensitive - - -def test_should_convert_file(): - """Test file conversion decision logic""" - # No conversion target specified - should_convert, source = _should_convert_file("file.txt.bz2", None, None) - assert should_convert is False - assert source is None - - # Uncompressed file - should_convert, source = _should_convert_file("file.txt", "gz", None) - assert should_convert is False - assert source is None - - # Same source and target - should_convert, source = _should_convert_file("file.txt.gz", "gz", None) - assert should_convert is False - assert source is None - - # Valid conversion - should_convert, source = _should_convert_file("file.txt.bz2", "gz", None) - assert should_convert is True - assert source == "bz2" - - # With convert_from filter matching - should_convert, source = _should_convert_file("file.txt.bz2", "gz", "bz2") - assert should_convert is True - assert source == "bz2" - - # With convert_from filter not matching - should_convert, source = _should_convert_file("file.txt.bz2", "gz", "xz") - assert should_convert is False - assert source is None - - -def test_get_converted_filename(): - """Test filename conversion""" - assert _get_converted_filename("data.txt.bz2", "bz2", "gz") == "data.txt.gz" - assert _get_converted_filename("data.txt.gz", "gz", "xz") == "data.txt.xz" - assert _get_converted_filename("data.txt.xz", "xz", "bz2") == "data.txt.bz2" - - -def test_convert_compression_format(): - """Test actual compression format conversion""" - with tempfile.TemporaryDirectory() as tmpdir: - # Create test data - test_data = b"This is test data for compression conversion " * 100 - - # Create a bz2 file - bz2_file = os.path.join(tmpdir, "test.txt.bz2") - with bz2.open(bz2_file, 'wb') as f: - f.write(test_data) - - # Convert bz2 to gz - gz_file = os.path.join(tmpdir, "test.txt.gz") - _convert_compression_format(bz2_file, gz_file, "bz2", "gz") - - # Verify the original file was removed - assert not os.path.exists(bz2_file) - - # Verify the new file exists and contains the same data - assert os.path.exists(gz_file) - with gzip.open(gz_file, 'rb') as f: - decompressed = f.read() - assert decompressed == test_data - - -def test_convert_gz_to_xz(): - """Test conversion from gzip to xz""" - with tempfile.TemporaryDirectory() as tmpdir: - # Create test data - test_data = b"Conversion test: gz to xz format" * 50 - - # Create a gz file - gz_file = os.path.join(tmpdir, "test.txt.gz") - with gzip.open(gz_file, 'wb') as f: - f.write(test_data) - - # Convert gz to xz - xz_file = os.path.join(tmpdir, "test.txt.xz") - _convert_compression_format(gz_file, xz_file, "gz", "xz") - - # Verify conversion - assert not os.path.exists(gz_file) - assert os.path.exists(xz_file) - with lzma.open(xz_file, 'rb') as f: - decompressed = f.read() - assert decompressed == test_data - - -def test_convert_xz_to_bz2(): - """Test conversion from xz to bz2""" - with tempfile.TemporaryDirectory() as tmpdir: - # Create test data - test_data = b"XZ to BZ2 compression conversion test" * 75 - - # Create an xz file - xz_file = os.path.join(tmpdir, "test.txt.xz") - with lzma.open(xz_file, 'wb') as f: - f.write(test_data) - - # Convert xz to bz2 - bz2_file = os.path.join(tmpdir, "test.txt.bz2") - _convert_compression_format(xz_file, bz2_file, "xz", "bz2") - - # Verify conversion - assert not os.path.exists(xz_file) - assert os.path.exists(bz2_file) - with bz2.open(bz2_file, 'rb') as f: - decompressed = f.read() - assert decompressed == test_data - - -def test_case_insensitive_filename_conversion(): - """Test that uppercase extensions are handled correctly (addresses PR feedback)""" - # Test uppercase extension matching - assert _get_converted_filename("FILE.BZ2", "bz2", "gz") == "FILE.gz" - assert _get_converted_filename("data.GZ", "gz", "xz") == "data.xz" - assert _get_converted_filename("archive.XZ", "xz", "bz2") == "archive.bz2" - - # Test mixed case - assert _get_converted_filename("File.Bz2", "bz2", "gz") == "File.gz" - - -def test_invalid_source_format_validation(): - """Test that invalid source format raises ValueError (addresses PR feedback)""" - with tempfile.TemporaryDirectory() as tmpdir: - source_file = os.path.join(tmpdir, "test.zip") - target_file = os.path.join(tmpdir, "test.gz") - - # Create a dummy file - with open(source_file, 'wb') as f: - f.write(b"test data") - - # Should raise ValueError for unsupported format - with pytest.raises(ValueError, match="Unsupported source compression format"): - _convert_compression_format(source_file, target_file, "zip", "gz") - - -def test_invalid_target_format_validation(): - """Test that invalid target format raises ValueError (addresses PR feedback)""" - with tempfile.TemporaryDirectory() as tmpdir: - source_file = os.path.join(tmpdir, "test.gz") - target_file = os.path.join(tmpdir, "test.rar") - - # Create a valid gz file - test_data = b"test data" - with gzip.open(source_file, 'wb') as f: - f.write(test_data) - - # Should raise ValueError for unsupported format - with pytest.raises(ValueError, match="Unsupported target compression format"): - _convert_compression_format(source_file, target_file, "gz", "rar") - - -def test_corrupted_file_handling(): - """Test that corrupted files are handled gracefully and target file is cleaned up""" - with tempfile.TemporaryDirectory() as tmpdir: - source_file = os.path.join(tmpdir, "corrupted.bz2") - target_file = os.path.join(tmpdir, "target.gz") - - # Create a file with .bz2 extension but invalid content - with open(source_file, 'wb') as f: - f.write(b"This is not valid bz2 compressed data") - - # Should raise RuntimeError - with pytest.raises(RuntimeError, match="Compression conversion failed"): - _convert_compression_format(source_file, target_file, "bz2", "gz") - - # Verify target file was cleaned up - assert not os.path.exists(target_file) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) +"""Tests for on-the-fly compression conversion feature""" + +import os +import gzip +import bz2 +import lzma +import tempfile +import pytest +from databusclient.api.download import ( + _detect_compression_format, + _should_convert_file, + _get_converted_filename, + _convert_compression_format, +) + + +def test_detect_compression_format(): + """Test compression format detection from filenames""" + assert _detect_compression_format("file.txt.bz2") == "bz2" + assert _detect_compression_format("file.txt.gz") == "gz" + assert _detect_compression_format("file.txt.xz") == "xz" + assert _detect_compression_format("file.txt") == "none" + assert _detect_compression_format("FILE.TXT.GZ") == "gz" # case insensitive + + +def test_should_convert_file(): + """Test file conversion decision logic""" + # No conversion target specified + should_convert, source = _should_convert_file("file.txt.bz2", None, None) + assert should_convert is False + assert source is None + + # Uncompressed file, no convert_from='none' -> don't convert + should_convert, source = _should_convert_file("file.txt", "gz", None) + assert should_convert is False + assert source is None + + # Same source and target + should_convert, source = _should_convert_file("file.txt.gz", "gz", None) + assert should_convert is False + assert source is None + + # Valid conversion + should_convert, source = _should_convert_file("file.txt.bz2", "gz", None) + assert should_convert is True + assert source == "bz2" + + # With convert_from filter matching + should_convert, source = _should_convert_file("file.txt.bz2", "gz", "bz2") + assert should_convert is True + assert source == "bz2" + + # With convert_from filter not matching + should_convert, source = _should_convert_file("file.txt.bz2", "gz", "xz") + assert should_convert is False + assert source is None + + +def test_get_converted_filename(): + """Test filename conversion""" + assert _get_converted_filename("data.txt.bz2", "bz2", "gz") == "data.txt.gz" + assert _get_converted_filename("data.txt.gz", "gz", "xz") == "data.txt.xz" + assert _get_converted_filename("data.txt.xz", "xz", "bz2") == "data.txt.bz2" + + +def test_convert_compression_format(): + """Test actual compression format conversion""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test data + test_data = b"This is test data for compression conversion " * 100 + + # Create a bz2 file + bz2_file = os.path.join(tmpdir, "test.txt.bz2") + with bz2.open(bz2_file, 'wb') as f: + f.write(test_data) + + # Convert bz2 to gz + gz_file = os.path.join(tmpdir, "test.txt.gz") + _convert_compression_format(bz2_file, gz_file, "bz2", "gz") + + # Verify the original file was removed + assert not os.path.exists(bz2_file) + + # Verify the new file exists and contains the same data + assert os.path.exists(gz_file) + with gzip.open(gz_file, 'rb') as f: + decompressed = f.read() + assert decompressed == test_data + + +def test_convert_gz_to_xz(): + """Test conversion from gzip to xz""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test data + test_data = b"Conversion test: gz to xz format" * 50 + + # Create a gz file + gz_file = os.path.join(tmpdir, "test.txt.gz") + with gzip.open(gz_file, 'wb') as f: + f.write(test_data) + + # Convert gz to xz + xz_file = os.path.join(tmpdir, "test.txt.xz") + _convert_compression_format(gz_file, xz_file, "gz", "xz") + + # Verify conversion + assert not os.path.exists(gz_file) + assert os.path.exists(xz_file) + with lzma.open(xz_file, 'rb') as f: + decompressed = f.read() + assert decompressed == test_data + + +def test_convert_xz_to_bz2(): + """Test conversion from xz to bz2""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test data + test_data = b"XZ to BZ2 compression conversion test" * 75 + + # Create an xz file + xz_file = os.path.join(tmpdir, "test.txt.xz") + with lzma.open(xz_file, 'wb') as f: + f.write(test_data) + + # Convert xz to bz2 + bz2_file = os.path.join(tmpdir, "test.txt.bz2") + _convert_compression_format(xz_file, bz2_file, "xz", "bz2") + + # Verify conversion + assert not os.path.exists(xz_file) + assert os.path.exists(bz2_file) + with bz2.open(bz2_file, 'rb') as f: + decompressed = f.read() + assert decompressed == test_data + + +def test_case_insensitive_filename_conversion(): + """Test that uppercase extensions are handled correctly (addresses PR feedback)""" + # Test uppercase extension matching + assert _get_converted_filename("FILE.BZ2", "bz2", "gz") == "FILE.gz" + assert _get_converted_filename("data.GZ", "gz", "xz") == "data.xz" + assert _get_converted_filename("archive.XZ", "xz", "bz2") == "archive.bz2" + + # Test mixed case + assert _get_converted_filename("File.Bz2", "bz2", "gz") == "File.gz" + + +def test_invalid_source_format_validation(): + """Test that invalid source format raises ValueError (addresses PR feedback)""" + with tempfile.TemporaryDirectory() as tmpdir: + source_file = os.path.join(tmpdir, "test.zip") + target_file = os.path.join(tmpdir, "test.gz") + + # Create a dummy file + with open(source_file, 'wb') as f: + f.write(b"test data") + + # Should raise ValueError for unsupported format + with pytest.raises(ValueError, match="Unsupported source compression format"): + _convert_compression_format(source_file, target_file, "zip", "gz") + + +def test_invalid_target_format_validation(): + """Test that invalid target format raises ValueError (addresses PR feedback)""" + with tempfile.TemporaryDirectory() as tmpdir: + source_file = os.path.join(tmpdir, "test.gz") + target_file = os.path.join(tmpdir, "test.rar") + + # Create a valid gz file + test_data = b"test data" + with gzip.open(source_file, 'wb') as f: + f.write(test_data) + + # Should raise ValueError for unsupported format + with pytest.raises(ValueError, match="Unsupported target compression format"): + _convert_compression_format(source_file, target_file, "gz", "rar") + + +def test_corrupted_file_handling(): + """Test that corrupted files are handled gracefully and target file is cleaned up""" + with tempfile.TemporaryDirectory() as tmpdir: + source_file = os.path.join(tmpdir, "corrupted.bz2") + target_file = os.path.join(tmpdir, "target.gz") + + # Create a file with .bz2 extension but invalid content + with open(source_file, 'wb') as f: + f.write(b"This is not valid bz2 compressed data") + + # Should raise RuntimeError + with pytest.raises(RuntimeError, match="Compression conversion failed"): + _convert_compression_format(source_file, target_file, "bz2", "gz") + + # Verify target file was cleaned up + assert not os.path.exists(target_file) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_file_converter.py b/tests/test_file_converter.py new file mode 100644 index 0000000..052216c --- /dev/null +++ b/tests/test_file_converter.py @@ -0,0 +1,391 @@ +"""Tests for the FileConverter module and new conversion capabilities. + +Covers: +- Decompression to raw (convert_to='none') +- Compression of raw files (source_format='none') +- Streaming conversion via convert_stream +- Magic-number format detection +- Filename generation with 'none' format +- Checksum validation during streaming +- --decompress CLI flag +""" + +import bz2 +import gzip +import hashlib +import io +import lzma +import os +import tempfile + +import pytest +from click.testing import CliRunner + +from databusclient.extensions.file_converter import ( + FileConverter, + COMPRESSION_EXTENSIONS, + COMPRESSION_MODULES, +) +from databusclient.api.download import ( + _detect_compression_format, + _should_convert_file, + _get_converted_filename, + _convert_compression_format, +) +from databusclient.cli import download + +# Re-usable test payload +_TEST_DATA = b"Hello, streaming converter! " * 200 + + +# --------------------------------------------------------------------------- +# Format detection: detect_format and detect_format_by_magic +# --------------------------------------------------------------------------- + +class TestFormatDetection: + + def test_detect_format_by_extension(self): + assert FileConverter.detect_format("data.csv.gz") == "gz" + assert FileConverter.detect_format("dump.nt.bz2") == "bz2" + assert FileConverter.detect_format("file.xz") == "xz" + assert FileConverter.detect_format("readme.txt") == "none" + + def test_detect_format_case_insensitive(self): + assert FileConverter.detect_format("FILE.GZ") == "gz" + assert FileConverter.detect_format("dump.BZ2") == "bz2" + + def test_detect_format_by_magic_gzip(self): + header = b"\x1f\x8b\x08\x00" + assert FileConverter.detect_format_by_magic(header) == "gz" + + def test_detect_format_by_magic_bz2(self): + header = b"BZh91AY&SY" + assert FileConverter.detect_format_by_magic(header) == "bz2" + + def test_detect_format_by_magic_xz(self): + header = b"\xfd7zXZ\x00\x00" + assert FileConverter.detect_format_by_magic(header) == "xz" + + def test_detect_format_by_magic_unknown(self): + header = b"\x00\x00\x00\x00" + assert FileConverter.detect_format_by_magic(header) == "none" + + def test_detect_format_extension_wins_but_magic_fallback(self): + # Extension wins when present + assert FileConverter.detect_format("file.gz", b"\x1f\x8b") == "gz" + # Magic fallback when extension is absent + assert FileConverter.detect_format("file.dat", b"\x1f\x8b") == "gz" + + +# --------------------------------------------------------------------------- +# Decompression tests (convert_to='none') +# --------------------------------------------------------------------------- + +class TestDecompressionToNone: + + def test_decompress_gzip_to_none(self): + with tempfile.TemporaryDirectory() as tmpdir: + gz_path = os.path.join(tmpdir, "data.txt.gz") + raw_path = os.path.join(tmpdir, "data.txt") + with gzip.open(gz_path, "wb") as f: + f.write(_TEST_DATA) + FileConverter.convert_file(gz_path, raw_path, "gz", "none") + assert not os.path.exists(gz_path) + with open(raw_path, "rb") as f: + assert f.read() == _TEST_DATA + + def test_decompress_bz2_to_none(self): + with tempfile.TemporaryDirectory() as tmpdir: + bz2_path = os.path.join(tmpdir, "data.txt.bz2") + raw_path = os.path.join(tmpdir, "data.txt") + with bz2.open(bz2_path, "wb") as f: + f.write(_TEST_DATA) + FileConverter.convert_file(bz2_path, raw_path, "bz2", "none") + assert not os.path.exists(bz2_path) + with open(raw_path, "rb") as f: + assert f.read() == _TEST_DATA + + def test_decompress_xz_to_none(self): + with tempfile.TemporaryDirectory() as tmpdir: + xz_path = os.path.join(tmpdir, "data.txt.xz") + raw_path = os.path.join(tmpdir, "data.txt") + with lzma.open(xz_path, "wb") as f: + f.write(_TEST_DATA) + FileConverter.convert_file(xz_path, raw_path, "xz", "none") + assert not os.path.exists(xz_path) + with open(raw_path, "rb") as f: + assert f.read() == _TEST_DATA + + +# --------------------------------------------------------------------------- +# Compression of raw files (source_format='none') +# --------------------------------------------------------------------------- + +class TestCompressionFromNone: + + def test_compress_none_to_gzip(self): + with tempfile.TemporaryDirectory() as tmpdir: + raw_path = os.path.join(tmpdir, "data.txt") + gz_path = os.path.join(tmpdir, "data.txt.gz") + with open(raw_path, "wb") as f: + f.write(_TEST_DATA) + FileConverter.convert_file(raw_path, gz_path, "none", "gz") + assert not os.path.exists(raw_path) + with gzip.open(gz_path, "rb") as f: + assert f.read() == _TEST_DATA + + def test_compress_none_to_bz2(self): + with tempfile.TemporaryDirectory() as tmpdir: + raw_path = os.path.join(tmpdir, "data.txt") + bz2_path = os.path.join(tmpdir, "data.txt.bz2") + with open(raw_path, "wb") as f: + f.write(_TEST_DATA) + FileConverter.convert_file(raw_path, bz2_path, "none", "bz2") + assert not os.path.exists(raw_path) + with bz2.open(bz2_path, "rb") as f: + assert f.read() == _TEST_DATA + + def test_compress_none_to_xz(self): + with tempfile.TemporaryDirectory() as tmpdir: + raw_path = os.path.join(tmpdir, "data.txt") + xz_path = os.path.join(tmpdir, "data.txt.xz") + with open(raw_path, "wb") as f: + f.write(_TEST_DATA) + FileConverter.convert_file(raw_path, xz_path, "none", "xz") + assert not os.path.exists(raw_path) + with lzma.open(xz_path, "rb") as f: + assert f.read() == _TEST_DATA + + +# --------------------------------------------------------------------------- +# Streaming conversion (convert_stream) +# --------------------------------------------------------------------------- + +class TestStreamConversion: + + def test_convert_stream_gz_to_bz2(self): + # Compress test data into gzip in-memory + gz_buf = io.BytesIO() + with gzip.open(gz_buf, "wb") as gz: + gz.write(_TEST_DATA) + gz_buf.seek(0) + + bz2_buf = io.BytesIO() + FileConverter.convert_stream(gz_buf, bz2_buf, "gz", "bz2") + + bz2_buf.seek(0) + assert bz2.decompress(bz2_buf.read()) == _TEST_DATA + + def test_convert_stream_decompress(self): + gz_buf = io.BytesIO() + with gzip.open(gz_buf, "wb") as gz: + gz.write(_TEST_DATA) + gz_buf.seek(0) + + raw_buf = io.BytesIO() + FileConverter.convert_stream(gz_buf, raw_buf, "gz", "none") + + assert raw_buf.getvalue() == _TEST_DATA + + def test_convert_stream_compress(self): + raw_buf = io.BytesIO(_TEST_DATA) + gz_buf = io.BytesIO() + FileConverter.convert_stream(raw_buf, gz_buf, "none", "gz") + + gz_buf.seek(0) + with gzip.open(gz_buf, "rb") as gz: + assert gz.read() == _TEST_DATA + + def test_streaming_checksum_validation(self): + expected_hash = hashlib.sha256(_TEST_DATA).hexdigest() + + gz_buf = io.BytesIO() + with gzip.open(gz_buf, "wb") as gz: + gz.write(_TEST_DATA) + gz_buf.seek(0) + + raw_buf = io.BytesIO() + result_hash = FileConverter.convert_stream( + gz_buf, raw_buf, "gz", "none", validate_checksum=True + ) + assert result_hash == expected_hash + + +# --------------------------------------------------------------------------- +# Filename generation with 'none' +# --------------------------------------------------------------------------- + +class TestConvertedFilename: + + def test_strip_extension_for_none_target(self): + assert FileConverter.get_converted_filename( + "data.csv.gz", "gz", "none" + ) == "data.csv" + + def test_add_extension_for_none_source(self): + assert FileConverter.get_converted_filename( + "data.csv", "none", "bz2" + ) == "data.csv.bz2" + + def test_convert_between_formats(self): + assert FileConverter.get_converted_filename( + "dump.nt.bz2", "bz2", "xz" + ) == "dump.nt.xz" + + +# --------------------------------------------------------------------------- +# _should_convert_file with 'none' +# --------------------------------------------------------------------------- + +class TestShouldConvertFileNone: + + def test_decompress_compressed_file(self): + ok, src = _should_convert_file("file.txt.gz", "none", None) + assert ok is True + assert src == "gz" + + def test_decompress_already_uncompressed(self): + ok, src = _should_convert_file("file.txt", "none", None) + assert ok is False + + def test_compress_raw_file(self): + ok, src = _should_convert_file("file.txt", "gz", "none") + assert ok is True + assert src == "none" + + def test_decompress_with_filter_match(self): + ok, src = _should_convert_file("file.txt.bz2", "none", "bz2") + assert ok is True + assert src == "bz2" + + def test_decompress_with_filter_no_match(self): + ok, src = _should_convert_file("file.txt.bz2", "none", "xz") + assert ok is False + + +# --------------------------------------------------------------------------- +# Checksum validation +# --------------------------------------------------------------------------- + +class TestChecksumValidation: + + def test_valid_checksum(self): + data = b"checksum test data" + expected = hashlib.sha256(data).hexdigest() + stream = io.BytesIO(data) + assert FileConverter.validate_checksum_stream(stream, expected) is True + + def test_invalid_checksum_raises(self): + stream = io.BytesIO(b"some data") + with pytest.raises(IOError, match="Checksum mismatch"): + FileConverter.validate_checksum_stream(stream, "0" * 64) + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + +class TestErrorHandling: + + def test_unsupported_source_format(self): + with pytest.raises(ValueError, match="Unsupported source"): + FileConverter.convert_file("a", "b", "zip", "gz") + + def test_unsupported_target_format(self): + with pytest.raises(ValueError, match="Unsupported target"): + FileConverter.convert_file("a", "b", "gz", "rar") + + def test_corrupted_file_cleanup(self): + with tempfile.TemporaryDirectory() as tmpdir: + src = os.path.join(tmpdir, "bad.bz2") + tgt = os.path.join(tmpdir, "out.gz") + with open(src, "wb") as f: + f.write(b"not real bz2 data") + with pytest.raises(RuntimeError, match="Compression conversion failed"): + FileConverter.convert_file(src, tgt, "bz2", "gz") + assert not os.path.exists(tgt) + + +# --------------------------------------------------------------------------- +# CLI --decompress flag +# --------------------------------------------------------------------------- + +class TestDecompressCLI: + """Test the --decompress flag via Click's CliRunner. + + We don't actually download anything; we just verify the flag + handling logic by patching the download function. + """ + + def test_decompress_flag_sets_convert_to_none(self, monkeypatch): + """--decompress should result in convert_to='none' reaching api_download.""" + captured = {} + + def fake_download(**kwargs): + captured.update(kwargs) + + monkeypatch.setattr( + "databusclient.cli.api_download", fake_download + ) + + runner = CliRunner() + result = runner.invoke( + download, + ["--decompress", "https://example.org/test"], + ) + assert result.exit_code == 0, result.output + assert captured.get("convert_to") == "none" + + def test_decompress_and_convert_to_conflict(self): + runner = CliRunner() + result = runner.invoke( + download, + ["--decompress", "--convert-to", "gz", "https://example.org/test"], + ) + assert result.exit_code != 0 + assert "Cannot use --decompress together with --convert-to" in result.output + + +# --------------------------------------------------------------------------- +# zstd (optional) +# --------------------------------------------------------------------------- + +try: + import zstandard # noqa: F401 + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False + + +@pytest.mark.skipif(not _HAS_ZSTD, reason="zstandard package not installed") +class TestZstd: + + def test_compress_none_to_zstd(self): + with tempfile.TemporaryDirectory() as tmpdir: + raw = os.path.join(tmpdir, "data.txt") + zst = os.path.join(tmpdir, "data.txt.zst") + with open(raw, "wb") as f: + f.write(_TEST_DATA) + FileConverter.convert_file(raw, zst, "none", "zstd") + assert not os.path.exists(raw) + import zstandard as zd + dctx = zd.ZstdDecompressor() + with open(zst, "rb") as f: + assert dctx.decompress(f.read()) == _TEST_DATA + + def test_decompress_zstd_to_none(self): + with tempfile.TemporaryDirectory() as tmpdir: + raw = os.path.join(tmpdir, "data.txt") + zst = os.path.join(tmpdir, "data.txt.zst") + import zstandard as zd + cctx = zd.ZstdCompressor() + with open(zst, "wb") as f: + f.write(cctx.compress(_TEST_DATA)) + FileConverter.convert_file(zst, raw, "zstd", "none") + assert not os.path.exists(zst) + with open(raw, "rb") as f: + assert f.read() == _TEST_DATA + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])