image_server/tests/test_auth.py
2026-04-23 22:11:51 -05:00

190 lines
7.5 KiB
Python

"""Tests for authentication."""
import argparse
import base64
from pathlib import Path
import pytest
from httpx import ASGITransport, AsyncClient
import main
def _basic_auth_header(username: str, password: str) -> str:
    """Build the value of an HTTP Basic ``Authorization`` header.

    Args:
        username: Text placed before the colon.
        password: Text placed after the colon.

    Returns:
        ``"Basic "`` followed by the base64 encoding of ``username:password``.
    """
    raw = ":".join((username, password)).encode()
    token = base64.b64encode(raw).decode()
    return "Basic " + token
def _make_args(tmp_path: Path) -> argparse.Namespace:
    """Build an ``argparse.Namespace`` of server options rooted at a directory.

    Args:
        tmp_path: Directory whose string form becomes the ``source`` value.

    Returns:
        Namespace carrying ``source``, ``host``, ``port``, ``salt`` and
        ``password``; ``password`` is ``None``.
    """
    options = {
        "source": str(tmp_path),
        "host": "127.0.0.1",
        "port": 0,
        "salt": "auth-salt",
        "password": None,
    }
    return argparse.Namespace(**options)
@pytest.fixture
def auth_setup(tmp_path: Path, request: pytest.FixtureRequest) -> tuple[str, str]:
    """Set up server with sample files and password protection.

    Writes ``test.txt`` into ``tmp_path``, initializes the server from
    ``_make_args(tmp_path)`` and sets the auth password to ``"secret123"``.
    The password that was configured beforehand is put back when the test
    finishes, so the module-level auth state does not leak into tests that
    run later.

    Args:
        tmp_path: Per-test temporary directory supplied by pytest.
        request: Pytest request object, used only to register the cleanup.

    Returns:
        Tuple of (username, password).
    """
    # getattr with a default: the attribute may not exist before the first
    # set_auth_password call, and None is the "no password" value.
    previous_password = getattr(main, "expected_password", None)
    # Registered before any state is changed so the password is restored even
    # if a later setup step raises.
    request.addfinalizer(lambda: main.set_auth_password(previous_password))

    (tmp_path / "test.txt").write_text("hello")
    main.initialize_server(_make_args(tmp_path))
    main.set_auth_password("secret123")
    return ("user", "secret123")
class TestNoPasswordSet:
    """Behaviour of the endpoints when no password is configured.

    Note: HTTPBasic() always requires an Authorization header.
    When expected_password is None, any credentials pass.
    """

    async def test_health_always_open(self, client_dir: AsyncClient) -> None:
        """The health endpoint has no auth dependency and answers 200."""
        health = await client_dir.get("/api/health")
        assert health.status_code == 200

    async def test_protected_endpoint_requires_auth_header(
        self, initialized_dir: None
    ) -> None:
        """A request without credentials gets 401 even with no password set."""
        first_hash = list(main.file_mapping)[0]
        # No Authorization header is sent, so HTTPBasic answers 401.
        async with AsyncClient(
            transport=ASGITransport(app=main.app), base_url="http://test"
        ) as client:
            reply = await client.get(f"/api/{first_hash}/data")
            assert reply.status_code == 401

    async def test_any_credentials_pass_when_no_password(
        self, client_dir: AsyncClient
    ) -> None:
        """Arbitrary credentials are accepted when no password is set."""
        first_hash = list(main.file_mapping)[0]
        async with AsyncClient(
            transport=ASGITransport(app=main.app),
            base_url="http://test",
            headers={"Authorization": _basic_auth_header("any", "thing")},
        ) as client:
            reply = await client.get(f"/api/{first_hash}/data")
            assert reply.status_code == 200

    async def test_root_requires_auth_header(self, initialized_dir: None) -> None:
        """The root endpoint answers 401 when no auth header is sent."""
        async with AsyncClient(
            transport=ASGITransport(app=main.app), base_url="http://test"
        ) as client:
            reply = await client.get("/", follow_redirects=False)
            assert reply.status_code == 401
class TestCorrectPassword:
    """Requests made while a password is configured and the right one is used."""

    async def test_health_with_correct_password(
        self, auth_setup: tuple[str, str]
    ) -> None:
        """Health check answers 200 with a password set; no credentials are sent."""
        async with AsyncClient(
            transport=ASGITransport(app=main.app), base_url="http://test"
        ) as client:
            health = await client.get("/api/health")
            assert health.status_code == 200

    async def test_file_access_with_correct_password(
        self, auth_setup: tuple[str, str]
    ) -> None:
        """The file data endpoint answers 200 for the configured credentials."""
        auth_value = _basic_auth_header(*auth_setup)
        async with AsyncClient(
            transport=ASGITransport(app=main.app),
            base_url="http://test",
            headers={"Authorization": auth_value},
        ) as client:
            first_hash = list(main.file_mapping)[0]
            reply = await client.get(f"/api/{first_hash}/data")
            assert reply.status_code == 200

    async def test_root_with_correct_password(
        self, auth_setup: tuple[str, str]
    ) -> None:
        """The root endpoint answers with a redirect for the configured credentials."""
        auth_value = _basic_auth_header(*auth_setup)
        async with AsyncClient(
            transport=ASGITransport(app=main.app),
            base_url="http://test",
            headers={"Authorization": auth_value},
        ) as client:
            reply = await client.get("/", follow_redirects=False)
            assert reply.status_code in {301, 302, 307}

    async def test_hash_page_with_correct_password(
        self, auth_setup: tuple[str, str]
    ) -> None:
        """The per-hash page answers 200 for the configured credentials."""
        auth_value = _basic_auth_header(*auth_setup)
        async with AsyncClient(
            transport=ASGITransport(app=main.app),
            base_url="http://test",
            headers={"Authorization": auth_value},
        ) as client:
            first_hash = list(main.file_mapping)[0]
            reply = await client.get(f"/{first_hash}")
            assert reply.status_code == 200
class TestWrongPassword:
    """Requests that carry a wrong password, or none, while one is configured."""

    async def test_file_access_with_wrong_password(
        self, auth_setup: tuple[str, str]
    ) -> None:
        """The file data endpoint answers 401 for a wrong password."""
        async with AsyncClient(
            transport=ASGITransport(app=main.app),
            base_url="http://test",
            headers={"Authorization": _basic_auth_header("user", "wrong")},
        ) as client:
            first_hash = list(main.file_mapping)[0]
            reply = await client.get(f"/api/{first_hash}/data")
            assert reply.status_code == 401

    async def test_root_with_wrong_password(self, auth_setup: tuple[str, str]) -> None:
        """The root endpoint answers 401 for a wrong password."""
        async with AsyncClient(
            transport=ASGITransport(app=main.app),
            base_url="http://test",
            headers={"Authorization": _basic_auth_header("user", "wrong")},
        ) as client:
            reply = await client.get("/", follow_redirects=False)
            assert reply.status_code == 401

    async def test_no_auth_header_returns_401(
        self, auth_setup: tuple[str, str]
    ) -> None:
        """A request without an auth header answers 401."""
        async with AsyncClient(
            transport=ASGITransport(app=main.app), base_url="http://test"
        ) as client:
            first_hash = list(main.file_mapping)[0]
            reply = await client.get(f"/api/{first_hash}/data")
            assert reply.status_code == 401

    async def test_includes_www_authenticate_header(
        self, auth_setup: tuple[str, str]
    ) -> None:
        """The 401 response carries a WWW-Authenticate header."""
        async with AsyncClient(
            transport=ASGITransport(app=main.app), base_url="http://test"
        ) as client:
            first_hash = list(main.file_mapping)[0]
            reply = await client.get(f"/api/{first_hash}/data")
            assert reply.status_code == 401
            assert "www-authenticate" in reply.headers
class TestSetAuthPassword:
    """Tests for set_auth_password function."""

    @pytest.fixture(autouse=True)
    def _restore_password(self, request: pytest.FixtureRequest) -> None:
        """Put back the password that was configured before each test.

        The tests in this class overwrite the module-level password; without
        this cleanup ``"newpass"`` would stay configured for whichever test
        runs next.
        """
        # getattr with a default: the attribute may not exist before the first
        # set_auth_password call, and None is the "no password" value.
        previous_password = getattr(main, "expected_password", None)
        request.addfinalizer(lambda: main.set_auth_password(previous_password))

    def test_sets_password(self) -> None:
        """Password is set correctly."""
        main.set_auth_password("newpass")
        assert main.expected_password == "newpass"

    def test_clears_password_with_none(self) -> None:
        """Passing None clears the password."""
        main.set_auth_password("something")
        main.set_auth_password(None)
        assert main.expected_password is None