"""Tests for HTTP Basic authentication."""
|
|
|
|
import argparse
|
|
import base64
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
from httpx import ASGITransport, AsyncClient
|
|
|
|
import main
|
|
|
|
|
|
def _basic_auth_header(username: str, password: str) -> str:
    """Create a Basic Auth header value.

    Args:
        username: User name placed before the colon.
        password: Password placed after the colon.

    Returns:
        The string ``"Basic "`` followed by the base64 encoding of
        ``"<username>:<password>"``.
    """
    creds = f"{username}:{password}"
    # Encode explicitly as UTF-8 (the default of str.encode()); the base64
    # alphabet itself is pure ASCII, so decoding as ASCII is always safe.
    token = base64.b64encode(creds.encode("utf-8")).decode("ascii")
    return f"Basic {token}"
|
|
|
|
|
|
def _make_args(tmp_path: Path) -> argparse.Namespace:
    """Create an argparse.Namespace for the given path.

    Args:
        tmp_path: Directory whose string form becomes the ``source`` value.

    Returns:
        A namespace with ``source``, ``host``, ``port``, ``salt`` and
        ``password`` attributes; ``password`` is ``None``.
    """
    return argparse.Namespace(
        source=str(tmp_path),
        host="127.0.0.1",
        port=0,
        salt="auth-salt",
        password=None,
    )
|
|
|
|
|
|
@pytest.fixture
def auth_setup(tmp_path: Path, request: pytest.FixtureRequest) -> tuple[str, str]:
    """Set up server with sample files and password protection.

    Writes one sample file into ``tmp_path``, initializes the server on that
    directory and sets the auth password. The password that was configured
    before the fixture ran is restored when the test finishes, so the
    module-level auth state does not leak into later tests.

    Returns:
        Tuple of (username, password).
    """
    # Capture the global state before anything below touches it.
    previous_password = main.expected_password

    (tmp_path / "test.txt").write_text("hello", encoding="utf-8")
    main.initialize_server(_make_args(tmp_path))
    main.set_auth_password("secret123")

    # Registered only after setup succeeded; runs at fixture teardown.
    request.addfinalizer(lambda: main.set_auth_password(previous_password))
    return ("user", "secret123")
|
|
|
|
|
|
class TestNoPasswordSet:
    """Tests when no password is configured.

    Note: HTTPBasic() always requires an Authorization header.
    When expected_password is None, any credentials pass.
    """

    async def test_health_always_open(self, client_dir: AsyncClient) -> None:
        """Health check has no auth dependency — always accessible."""
        response = await client_dir.get("/api/health")
        assert response.status_code == 200

    async def test_protected_endpoint_requires_auth_header(
        self, initialized_dir: None
    ) -> None:
        """Even with no password, HTTPBasic requires an auth header."""
        # Any mapped file will do; take the first key without copying them all.
        file_hash = next(iter(main.file_mapping))
        # No auth header → 401 from HTTPBasic
        transport = ASGITransport(app=main.app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            response = await ac.get(f"/api/{file_hash}/data")
        assert response.status_code == 401

    async def test_any_credentials_pass_when_no_password(
        self, client_dir: AsyncClient
    ) -> None:
        """Any credentials pass when no password is set."""
        file_hash = next(iter(main.file_mapping))
        # A fresh client is built here so the request carries exactly the
        # arbitrary credentials below.
        transport = ASGITransport(app=main.app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            ac.headers["Authorization"] = _basic_auth_header("any", "thing")
            response = await ac.get(f"/api/{file_hash}/data")
        assert response.status_code == 200

    async def test_root_requires_auth_header(self, initialized_dir: None) -> None:
        """Root endpoint requires auth header even with no password."""
        transport = ASGITransport(app=main.app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            # Redirects are not followed so the status asserted is the one
            # returned by "/" itself.
            response = await ac.get("/", follow_redirects=False)
        assert response.status_code == 401
|
|
|
|
|
|
class TestCorrectPassword:
    """Tests with correct password.

    Every test relies on the ``auth_setup`` fixture, which initializes the
    server and yields the (username, password) pair that must be accepted.
    """

    async def test_health_with_correct_password(
        self, auth_setup: tuple[str, str]
    ) -> None:
        """Health check works (it has no auth, always 200)."""
        transport = ASGITransport(app=main.app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            # Deliberately sent without credentials.
            response = await ac.get("/api/health")
        assert response.status_code == 200

    async def test_file_access_with_correct_password(
        self, auth_setup: tuple[str, str]
    ) -> None:
        """File access works with correct password."""
        username, password = auth_setup
        auth = {"Authorization": _basic_auth_header(username, password)}
        transport = ASGITransport(app=main.app)
        async with AsyncClient(
            transport=transport, base_url="http://test", headers=auth
        ) as ac:
            # Any mapped file will do; take the first key.
            file_hash = next(iter(main.file_mapping))
            response = await ac.get(f"/api/{file_hash}/data")
        assert response.status_code == 200

    async def test_root_with_correct_password(
        self, auth_setup: tuple[str, str]
    ) -> None:
        """Root redirect works with correct password."""
        username, password = auth_setup
        auth = {"Authorization": _basic_auth_header(username, password)}
        transport = ASGITransport(app=main.app)
        async with AsyncClient(
            transport=transport, base_url="http://test", headers=auth
        ) as ac:
            # Do not follow the redirect: the redirect status is the assertion.
            response = await ac.get("/", follow_redirects=False)
        assert response.status_code in (307, 302, 301)

    async def test_hash_page_with_correct_password(
        self, auth_setup: tuple[str, str]
    ) -> None:
        """Hash page works with correct password."""
        username, password = auth_setup
        auth = {"Authorization": _basic_auth_header(username, password)}
        transport = ASGITransport(app=main.app)
        async with AsyncClient(
            transport=transport, base_url="http://test", headers=auth
        ) as ac:
            file_hash = next(iter(main.file_mapping))
            response = await ac.get(f"/{file_hash}")
        assert response.status_code == 200
|
|
|
|
|
|
class TestWrongPassword:
    """Tests with incorrect password.

    The wrong password is derived from the one ``auth_setup`` configured, so
    it is guaranteed to differ from it even if the fixture's value changes.
    """

    async def test_file_access_with_wrong_password(
        self, auth_setup: tuple[str, str]
    ) -> None:
        """File access returns 401 with wrong password."""
        username, password = auth_setup
        auth = {"Authorization": _basic_auth_header(username, f"{password}-wrong")}
        transport = ASGITransport(app=main.app)
        async with AsyncClient(
            transport=transport, base_url="http://test", headers=auth
        ) as ac:
            file_hash = next(iter(main.file_mapping))
            response = await ac.get(f"/api/{file_hash}/data")
        assert response.status_code == 401

    async def test_root_with_wrong_password(self, auth_setup: tuple[str, str]) -> None:
        """Root redirect returns 401 with wrong password."""
        username, password = auth_setup
        auth = {"Authorization": _basic_auth_header(username, f"{password}-wrong")}
        transport = ASGITransport(app=main.app)
        async with AsyncClient(
            transport=transport, base_url="http://test", headers=auth
        ) as ac:
            response = await ac.get("/", follow_redirects=False)
        assert response.status_code == 401

    async def test_no_auth_header_returns_401(
        self, auth_setup: tuple[str, str]
    ) -> None:
        """Missing auth header returns 401."""
        transport = ASGITransport(app=main.app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            file_hash = next(iter(main.file_mapping))
            response = await ac.get(f"/api/{file_hash}/data")
        assert response.status_code == 401

    async def test_includes_www_authenticate_header(
        self, auth_setup: tuple[str, str]
    ) -> None:
        """401 response includes WWW-Authenticate header."""
        transport = ASGITransport(app=main.app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            file_hash = next(iter(main.file_mapping))
            # No credentials are sent, which is what triggers the 401.
            response = await ac.get(f"/api/{file_hash}/data")
        assert response.status_code == 401
        assert "www-authenticate" in response.headers
|
|
|
|
|
|
class TestSetAuthPassword:
    """Tests for set_auth_password function.

    ``set_auth_password`` changes module-level state on ``main``; each test
    restores the value it found so the change does not leak into other tests.
    """

    def test_sets_password(self) -> None:
        """Password is set correctly."""
        previous = main.expected_password
        try:
            main.set_auth_password("newpass")
            assert main.expected_password == "newpass"
        finally:
            # Runs even when the assertion fails.
            main.set_auth_password(previous)

    def test_clears_password_with_none(self) -> None:
        """Passing None clears the password."""
        previous = main.expected_password
        try:
            # Set a value first so clearing is an observable change.
            main.set_auth_password("something")
            main.set_auth_password(None)
            assert main.expected_password is None
        finally:
            main.set_auth_password(previous)
|