import io import mimetypes import os import pytest import tarfile import zipfile from pathlib import Path from unittest.mock import Mock, call, patch from job.download import download_from_s3, unpack_archive_file from job.config import get_config from job.mr_client import validate_and_get_model_registry_client from job.models import S3StorageConfig DUMMY_FILE_DATA = { "file1.txt": b"test file 1", "dir/file2.txt": b"abc", "file3.log": b"123" } @pytest.fixture def minimal_update_artifact_env_source_dest_vars(): original_env = dict(os.environ) # Destination variables dest_vars = { "type": "oci", "oci_uri": "quay.io/example/oci", "oci_registry": "quay.io", "oci_username": "oci_username_env", "oci_password": "oci_password_env", } # Source variables - using the correct format from existing tests source_vars = { "type": "s3", "aws_bucket": "test-bucket", "aws_key": "test-key", "aws_access_key_id": "test-access-key-id", "aws_secret_access_key": "test-secret-access-key", "aws_endpoint": "http://localhost:9000", } # Set up test environment variables for key, value in dest_vars.items(): os.environ[f"MODEL_SYNC_DESTINATION_{key.upper()}"] = value for key, value in source_vars.items(): os.environ[f"MODEL_SYNC_SOURCE_{key.upper()}"] = value # Model and registry variables model_vars = { "model_upload_intent": "update_artifact", "model_artifact_id": "123", "registry_server_address": "http://localhost", "registry_port": "8080", "registry_author": "author", "storage_path": "/tmp/model-sync", } for key, value in model_vars.items(): os.environ[f"MODEL_SYNC_{key.upper()}"] = value yield model_vars # Restore original environment os.environ.clear() os.environ.update(original_env) @pytest.fixture def dummy_archive(request, tmp_path): match request.param: case "tar": return create_dummy_tar(tmp_path / "dummy.tar", "w") case "tar.gz": return create_dummy_tar(tmp_path / "dummy.tar.gz", "w:gz") case "zip": return create_dummy_zip(tmp_path / "dummy.zip", "w") case _: raise ValueError(f"Unsupported archive type: {request.param}") def create_dummy_tar(path, mode): with tarfile.open(path, mode) as f: for filename, content in DUMMY_FILE_DATA.items(): file_data = io.BytesIO(content) tarinfo = tarfile.TarInfo(name=filename) tarinfo.size = len(content) f.addfile(tarinfo, fileobj=file_data) return path def create_dummy_zip(path, mode): with zipfile.ZipFile(path, mode) as f: for filename, content in DUMMY_FILE_DATA.items(): f.writestr(filename, content) return path @pytest.mark.parametrize("dummy_archive", ["tar", "tar.gz", "zip"], indirect=True) def test_unpack_archive_file(dummy_archive, tmp_path): dest_dir = tmp_path / "unpacked_archive" mimetype = mimetypes.guess_type(dummy_archive)[0] unpack_archive_file(dummy_archive, mimetype, dest_dir) result = {} for dirpath, _, filenames in os.walk(dest_dir): for filename in filenames: filepath = Path(dirpath) / filename key = filepath.relative_to(dest_dir).as_posix() contents = filepath.read_bytes() result[key] = contents assert result == DUMMY_FILE_DATA def test_download_from_s3(minimal_update_artifact_env_source_dest_vars): """Test download_from_s3 now that it pages through prefixes.""" # load config from your fixture config = get_config([]) # sanity-check config assert isinstance(config.source, S3StorageConfig) assert config.source.bucket == "test-bucket" assert config.source.key == "test-key" # use whatever path came back in config storage_path = config.storage.path assert storage_path == "/tmp/model-sync" # mock out ModelRegistry so validate_and_get_model_registry_client returns a dummy client with patch("job.mr_client.ModelRegistry") as mock_registry_class: mock_registry_class.return_value = Mock() client = validate_and_get_model_registry_client(config.registry) # now patch _connect_to_s3 and os.makedirs with patch("job.download._connect_to_s3") as mock_connect, \ patch("os.makedirs") as mock_makedirs: # prepare our fake s3 client + transfer config mock_s3 = Mock() mock_transfer_cfg = Mock() mock_connect.return_value = (mock_s3, mock_transfer_cfg) # set up a paginator that yields a single page with two entries fake_page = { "Contents": [ {"Key": "test-key/file1.txt"}, {"Key": "test-key/dir/"}, # should be skipped {"Key": "test-key/dir/file2.bin"}, ] } mock_paginator = Mock() mock_paginator.paginate.return_value = [fake_page] mock_s3.get_paginator.return_value = mock_paginator # call under test download_from_s3(config.source, config.storage.path) # ensure _connect_to_s3 got all args including multipart settings mock_connect.assert_called_once_with( "http://localhost:9000", "test-access-key-id", "test-secret-access-key", None, # region multipart_threshold=1024 * 1024, multipart_chunksize=1024 * 1024, max_pool_connections=10, ) # ensure we asked for the right paginator and paginated correctly mock_s3.get_paginator.assert_called_once_with("list_objects_v2") mock_paginator.paginate.assert_called_once_with( Bucket="test-bucket", Prefix="test-key", ) # build expected download calls using the real storage_path expected = [ call( "test-bucket", "test-key/file1.txt", os.path.join(storage_path, "file1.txt"), ), call( "test-bucket", "test-key/dir/file2.bin", os.path.join(storage_path, "dir", "file2.bin"), ), ] mock_s3.download_file.assert_has_calls(expected, any_order=False) # directories should be created for each file mock_makedirs.assert_any_call( os.path.dirname(os.path.join(storage_path, "file1.txt")), exist_ok=True ) mock_makedirs.assert_any_call( os.path.dirname(os.path.join(storage_path, "dir", "file2.bin")), exist_ok=True ) def test_download_from_s3_with_region(minimal_update_artifact_env_source_dest_vars): """Test download_from_s3 function with region specified""" # Set region in environment os.environ["MODEL_SYNC_SOURCE_AWS_REGION"] = "us-west-2" config = get_config([]) # Create mock ModelRegistry client with patch("job.mr_client.ModelRegistry") as mock_registry_class: mock_client = Mock() mock_registry_class.return_value = mock_client client = validate_and_get_model_registry_client(config.registry) # Mock the S3 client and _connect_to_s3 function with patch("job.download._connect_to_s3") as mock_connect, \ patch("os.makedirs"): # silence dir creation mock_s3_client = Mock() mock_transfer_config = Mock() mock_connect.return_value = (mock_s3_client, mock_transfer_config) # set up a paginator that yields a single page with two entries fake_page = { "Contents": [ {"Key": "test-key/file1.txt"}, {"Key": "test-key/dir/"}, # should be skipped {"Key": "test-key/dir/file2.bin"}, ] } mock_paginator = Mock() mock_paginator.paginate.return_value = [fake_page] mock_s3_client.get_paginator.return_value = mock_paginator # Call the function under test download_from_s3(config.source, config.storage.path) # Verify _connect_to_s3 was called with correct parameters including region mock_connect.assert_called_once_with( "http://localhost:9000", "test-access-key-id", "test-secret-access-key", "us-west-2", multipart_threshold=1024 * 1024, multipart_chunksize=1024 * 1024, max_pool_connections=10, ) def test_download_from_s3_connection_error(minimal_update_artifact_env_source_dest_vars): """Test download_from_s3 function when S3 connection fails""" config = get_config([]) # Create mock ModelRegistry client with patch("job.mr_client.ModelRegistry") as mock_registry_class: mock_client = Mock() mock_registry_class.return_value = mock_client client = validate_and_get_model_registry_client(config.registry) # Mock _connect_to_s3 to raise an exception with patch("job.download._connect_to_s3") as mock_connect: mock_connect.side_effect = Exception("Connection failed") # Test that the exception is propagated with pytest.raises(Exception, match="Connection failed"): download_from_s3(config.source, config.storage.path) def test_download_from_s3_download_error(minimal_update_artifact_env_source_dest_vars): """Test download_from_s3 function when file download fails""" config = get_config([]) # Create mock ModelRegistry client with patch("job.mr_client.ModelRegistry") as mock_registry_class: mock_client = Mock() mock_registry_class.return_value = mock_client client = validate_and_get_model_registry_client(config.registry) # Mock the S3 client, paginator, and _connect_to_s3 function with patch("job.download._connect_to_s3") as mock_connect, \ patch("os.makedirs"): # silence dir creation mock_s3_client = Mock() mock_transfer_config = Mock() mock_connect.return_value = (mock_s3_client, mock_transfer_config) # Stub out pagination so we get one file to download fake_page = { "Contents": [ {"Key": "test-key/failing-file.txt"}, ] } mock_paginator = Mock() mock_paginator.paginate.return_value = [fake_page] mock_s3_client.get_paginator.return_value = mock_paginator # Have download_file raise mock_s3_client.download_file.side_effect = Exception("Download failed") # Now the loop will hit download_file and propagate our exception with pytest.raises(Exception, match="Download failed"): download_from_s3(config.source, config.storage.path)