diff --git a/pssecret_server/utils.py b/pssecret_server/utils.py index e57b710..ea10ef5 100644 --- a/pssecret_server/utils.py +++ b/pssecret_server/utils.py @@ -1,7 +1,10 @@ +from functools import lru_cache from uuid import uuid4 from cryptography.fernet import Fernet from redis.asyncio import Redis +from redis.exceptions import ResponseError +from redis.typing import ResponseT from pssecret_server.models import Secret @@ -30,3 +33,35 @@ async def save_secret(data: Secret, redis: Redis) -> str: await redis.setex(new_key, 60 * 60 * 24, data.data) return new_key + + +@lru_cache +async def _is_getdel_available(redis: Redis) -> bool: + """Checks the availability of GETDEL command on the Redis server instance + + GETDEL is not available in Redis prior to version 6.2 + """ + try: + await redis.getdel("test:getdel:availability") + except ResponseError: + return False + + return True + + +async def getdel(redis: Redis, key: str) -> ResponseT: + """Gets the value of key and deletes the key + + Depending on the capabilities of Redis server this function + will either call GETDEL command, either first call GETSET with empty string + and DEL right after that. + """ + result: ResponseT + + if await _is_getdel_available(redis): + result = await redis.getdel(key) + else: + result = await redis.getset(key, "") + await redis.delete(key) + + return result diff --git a/pyproject.toml b/pyproject.toml index 3e746fb..8dc6c71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,3 +53,6 @@ reportUnusedCallResult = "none" [tool.pytest.ini_options] asyncio_mode = "auto" + +[tool.isort] +profile = "black" diff --git a/tests/integration/test_utils.py b/tests/integration/test_utils.py index bef19c4..c7d54b0 100644 --- a/tests/integration/test_utils.py +++ b/tests/integration/test_utils.py @@ -1,8 +1,9 @@ -from unittest.mock import patch +from unittest.mock import AsyncMock, patch +import pytest from redis.asyncio import Redis -from pssecret_server.utils import get_new_key, save_secret +from pssecret_server.utils import get_new_key, getdel, save_secret from ..factories import SecretFactory @@ -33,3 +34,22 @@ async def test_save_secret_data(redis_server: Redis) -> None: assert redis_data is not None assert redis_data.decode() == secret.data + + +@pytest.mark.parametrize("getdel_available", [True, False]) +@patch("pssecret_server.utils._is_getdel_available", new_callable=AsyncMock) +async def test_getdel( + mock_is_getdel_available: AsyncMock, + getdel_available: bool, + redis_server: Redis, +) -> None: + mock_is_getdel_available.return_value = getdel_available + + test_value = "test_data" + test_key = "test_key" + await redis_server.set(test_key, test_value) + + result = await getdel(redis_server, test_key) + + assert result.decode() == test_value # pyright: ignore[reportAttributeAccessIssue] + assert not await redis_server.exists(test_key) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index de2984b..0fa6100 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,7 +1,10 @@ +from unittest.mock import AsyncMock + import pytest from cryptography.fernet import Fernet, InvalidToken +from redis.exceptions import ResponseError -from pssecret_server.utils import decrypt_secret, encrypt_secret +from pssecret_server.utils import _is_getdel_available, decrypt_secret, encrypt_secret from ..factories import SecretFactory @@ -29,3 +32,17 @@ def test_secret_is_not_decryptable_by_random_key(fernet: Fernet): with pytest.raises(InvalidToken): decrypt_secret(encrypted_secret.data.encode(), random_fernet) + + +@pytest.mark.parametrize( + ("getdel_effect", "expected_result"), [(None, True), (ResponseError, False)] +) +async def test_is_getdel_available( + getdel_effect: ResponseError | None, expected_result: bool +): + redis = AsyncMock() + redis.getdel.side_effect = getdel_effect # pyright: ignore[reportAny] + + result = await _is_getdel_available(redis) + + assert result is expected_result