Files
noteflow/tests/infrastructure/ner/test_gliner_backend.py

220 lines
8.0 KiB
Python

"""Tests for GLiNER NER backend."""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from noteflow.infrastructure.ner.backends.gliner_backend import (
DEFAULT_MODEL,
DEFAULT_THRESHOLD,
MEETING_LABELS,
GLiNERBackend,
)
from noteflow.infrastructure.ner.backends.types import RawEntity
def _create_backend_with_mock_model(
    model_name: str = DEFAULT_MODEL,
    labels: tuple[str, ...] = MEETING_LABELS,
    threshold: float = DEFAULT_THRESHOLD,
    mock_predictions: list[dict[str, object]] | None = None,
) -> GLiNERBackend:
    """Build a backend whose GLiNER model is replaced by a ``MagicMock``.

    Args:
        model_name: Model identifier forwarded to the ``GLiNERBackend`` constructor.
        labels: Entity labels forwarded to the constructor.
        threshold: Confidence threshold forwarded to the constructor.
        mock_predictions: Raw prediction dicts the mocked ``predict_entities``
            returns; ``None`` or an empty list means "no predictions".

    Returns:
        A backend with the mock already attached as its ``_model`` attribute.
    """
    canned = mock_predictions if mock_predictions else []
    fake_model = MagicMock()
    fake_model.predict_entities.return_value = canned

    backend = GLiNERBackend(model_name=model_name, labels=labels, threshold=threshold)
    # object.__setattr__ bypasses any custom __setattr__ defined on the backend.
    object.__setattr__(backend, "_model", fake_model)
    return backend
class TestGLiNERBackendInit:
    """Constructor defaults and keyword overrides of ``GLiNERBackend``."""

    def test_default_model_name(self) -> None:
        assert GLiNERBackend().model_name == DEFAULT_MODEL

    def test_default_labels(self) -> None:
        assert GLiNERBackend().labels == MEETING_LABELS

    def test_default_threshold(self) -> None:
        assert GLiNERBackend().threshold == DEFAULT_THRESHOLD

    def test_init_with_custom_model_name(self) -> None:
        requested = "custom/model"
        assert GLiNERBackend(model_name=requested).model_name == requested

    def test_custom_labels(self) -> None:
        requested = ("person", "location")
        assert GLiNERBackend(labels=requested).labels == requested

    def test_custom_threshold(self) -> None:
        assert GLiNERBackend(threshold=0.7).threshold == 0.7

    def test_not_loaded_initially(self) -> None:
        # A freshly constructed backend has no model attached yet.
        assert not GLiNERBackend().model_loaded()
class TestGLiNERBackendModelState:
    """``model_loaded`` before and after a (mocked) model is attached."""

    def test_model_loaded_returns_true_when_model_set(self) -> None:
        assert _create_backend_with_mock_model().model_loaded()

    def test_model_loaded_returns_false_initially(self) -> None:
        fresh = GLiNERBackend()
        assert not fresh.model_loaded()
_ENTITY_MAPPING_FIELDS = ("text", "label", "start", "end", "score")
# One shared table for the per-field mapping tests below: adding a case here
# exercises text, label, position, and confidence mapping at once, and the four
# tests can no longer drift apart.
_ENTITY_MAPPING_CASES = [
    ("John Smith", "person", 0, 10, 0.92),
    ("New York", "location", 5, 13, 0.88),
    ("decision to proceed", "decision", 0, 19, 0.75),
]


def _backend_with_single_prediction(
    text: str, label: str, start: int, end: int, score: float
) -> GLiNERBackend:
    """Return a mocked backend whose model yields exactly one prediction.

    The label is upper-cased before it reaches the backend, so every mapping
    test also depends on the backend lower-casing labels.
    """
    return _create_backend_with_mock_model(
        mock_predictions=[
            {"text": text, "label": label.upper(), "start": start, "end": end, "score": score},
        ]
    )


def _predict_call_kwargs(backend: GLiNERBackend) -> dict[str, object]:
    """Return the keyword arguments of the latest mocked ``predict_entities`` call.

    Fails with a clear mock assertion (instead of a ``TypeError`` from indexing
    ``None``) when ``extract`` never invoked the model.
    """
    mock_model = getattr(backend, "_model")
    predict_entities = getattr(mock_model, "predict_entities")
    predict_entities.assert_called()
    return dict(predict_entities.call_args.kwargs)


class TestGLiNERBackendExtraction:
    """Tests for ``GLiNERBackend.extract``, mostly against a mocked GLiNER model."""

    # These two tests use an unloaded backend; presumably extract short-circuits
    # on blank input before any model is needed — confirm against the backend.
    def test_extract_empty_string_returns_empty(self) -> None:
        backend = GLiNERBackend()
        assert backend.extract("") == []

    def test_extract_whitespace_only_returns_empty(self) -> None:
        backend = GLiNERBackend()
        assert backend.extract(" ") == []

    def test_extract_returns_correct_count(self) -> None:
        backend = _create_backend_with_mock_model(
            mock_predictions=[
                {"text": "John", "label": "PERSON", "start": 0, "end": 4, "score": 0.95},
                {"text": "NYC", "label": "location", "start": 12, "end": 15, "score": 0.88},
            ]
        )
        entities = backend.extract("John lives in NYC")
        assert len(entities) == 2

    def test_extract_returns_raw_entity_type(self) -> None:
        backend = _create_backend_with_mock_model(
            mock_predictions=[
                {"text": "John", "label": "PERSON", "start": 0, "end": 4, "score": 0.95},
            ]
        )
        entities = backend.extract("John")
        assert isinstance(entities[0], RawEntity)

    def test_extract_normalizes_labels_to_lowercase(self) -> None:
        backend = _create_backend_with_mock_model(
            mock_predictions=[
                {"text": "John", "label": "PERSON", "start": 0, "end": 4, "score": 0.9},
            ]
        )
        entities = backend.extract("John")
        assert entities[0].label == "person"

    def test_extract_includes_confidence_score(self) -> None:
        backend = _create_backend_with_mock_model(
            mock_predictions=[
                {"text": "Meeting", "label": "event", "start": 0, "end": 7, "score": 0.85},
            ]
        )
        entities = backend.extract("Meeting tomorrow")
        assert entities[0].confidence == 0.85

    def test_extract_passes_threshold_to_predict(self) -> None:
        backend = _create_backend_with_mock_model(threshold=0.75)
        backend.extract("Some text")
        assert _predict_call_kwargs(backend)["threshold"] == 0.75

    def test_extract_passes_labels_to_predict(self) -> None:
        backend = _create_backend_with_mock_model(labels=("person", "task"))
        backend.extract("Some text")
        # The backend stores labels as a tuple but is expected to pass a list.
        assert _predict_call_kwargs(backend)["labels"] == ["person", "task"]

    # The mocked model ignores the input text, so "input text" need not contain
    # the predicted spans; each test checks one field of the mapping.
    @pytest.mark.parametrize(_ENTITY_MAPPING_FIELDS, _ENTITY_MAPPING_CASES)
    def test_extract_entity_text_mapping(
        self, text: str, label: str, start: int, end: int, score: float
    ) -> None:
        backend = _backend_with_single_prediction(text, label, start, end, score)
        entities = backend.extract("input text")
        assert entities[0].text == text

    @pytest.mark.parametrize(_ENTITY_MAPPING_FIELDS, _ENTITY_MAPPING_CASES)
    def test_extract_entity_label_mapping(
        self, text: str, label: str, start: int, end: int, score: float
    ) -> None:
        backend = _backend_with_single_prediction(text, label, start, end, score)
        entities = backend.extract("input text")
        assert entities[0].label == label

    @pytest.mark.parametrize(_ENTITY_MAPPING_FIELDS, _ENTITY_MAPPING_CASES)
    def test_extract_entity_position_mapping(
        self, text: str, label: str, start: int, end: int, score: float
    ) -> None:
        backend = _backend_with_single_prediction(text, label, start, end, score)
        entities = backend.extract("input text")
        assert (entities[0].start, entities[0].end) == (start, end)

    @pytest.mark.parametrize(_ENTITY_MAPPING_FIELDS, _ENTITY_MAPPING_CASES)
    def test_extract_entity_confidence_mapping(
        self, text: str, label: str, start: int, end: int, score: float
    ) -> None:
        backend = _backend_with_single_prediction(text, label, start, end, score)
        entities = backend.extract("input text")
        assert entities[0].confidence == score
class TestGLiNERBackendMeetingLabels:
    """``MEETING_LABELS`` covers both generic and meeting-specific entity types."""

    def test_meeting_labels_include_core_categories(self) -> None:
        required = {"person", "org", "product", "app", "location", "time"}
        # An empty difference means every required label is present; on failure
        # pytest shows exactly which labels are missing.
        missing = required - set(MEETING_LABELS)
        assert not missing

    def test_meeting_labels_include_meeting_specific_categories(self) -> None:
        required = {"task", "decision", "topic", "event"}
        missing = required - set(MEETING_LABELS)
        assert not missing