220 lines
8.0 KiB
Python
220 lines
8.0 KiB
Python
"""Tests for GLiNER NER backend."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from noteflow.infrastructure.ner.backends.gliner_backend import (
|
|
DEFAULT_MODEL,
|
|
DEFAULT_THRESHOLD,
|
|
MEETING_LABELS,
|
|
GLiNERBackend,
|
|
)
|
|
from noteflow.infrastructure.ner.backends.types import RawEntity
|
|
|
|
|
|
def _create_backend_with_mock_model(
    model_name: str = DEFAULT_MODEL,
    labels: tuple[str, ...] = MEETING_LABELS,
    threshold: float = DEFAULT_THRESHOLD,
    mock_predictions: list[dict[str, object]] | None = None,
) -> GLiNERBackend:
    """Build a GLiNERBackend whose underlying model is a MagicMock.

    Args:
        model_name: Model identifier forwarded to the backend constructor.
        labels: Entity labels forwarded to the backend constructor.
        threshold: Confidence threshold forwarded to the backend constructor.
        mock_predictions: Raw prediction dicts the fake model's
            ``predict_entities`` returns; ``None`` (or an empty list) makes it
            return an empty list.

    Returns:
        A backend with ``_model`` already set, so no real model is loaded.
    """
    fake_model = MagicMock()
    fake_model.predict_entities = MagicMock(
        return_value=mock_predictions if mock_predictions else []
    )

    backend = GLiNERBackend(model_name=model_name, labels=labels, threshold=threshold)
    # Write the private attribute directly, bypassing any __setattr__ override
    # the backend class may define.
    object.__setattr__(backend, "_model", fake_model)
    return backend
|
|
|
|
|
|
class TestGLiNERBackendInit:
    """Constructor defaults and explicit overrides of GLiNERBackend."""

    def test_default_model_name(self) -> None:
        assert GLiNERBackend().model_name == DEFAULT_MODEL

    def test_default_labels(self) -> None:
        assert GLiNERBackend().labels == MEETING_LABELS

    def test_default_threshold(self) -> None:
        assert GLiNERBackend().threshold == DEFAULT_THRESHOLD

    def test_init_with_custom_model_name(self) -> None:
        name = "custom/model"
        assert GLiNERBackend(model_name=name).model_name == name

    def test_custom_labels(self) -> None:
        wanted = ("person", "location")
        assert GLiNERBackend(labels=wanted).labels == wanted

    def test_custom_threshold(self) -> None:
        cutoff = 0.7
        assert GLiNERBackend(threshold=cutoff).threshold == cutoff

    def test_not_loaded_initially(self) -> None:
        # Constructing the backend alone must not report a loaded model.
        assert not GLiNERBackend().model_loaded()
|
|
|
|
|
|
class TestGLiNERBackendModelState:
    """model_loaded() before and after a model object is attached."""

    def test_model_loaded_returns_true_when_model_set(self) -> None:
        loaded = _create_backend_with_mock_model().model_loaded()
        assert loaded

    def test_model_loaded_returns_false_initially(self) -> None:
        # NOTE(review): overlaps TestGLiNERBackendInit.test_not_loaded_initially.
        fresh = GLiNERBackend()
        assert not fresh.model_loaded()
|
|
|
|
|
|
class TestGLiNERBackendExtraction:
    """How GLiNERBackend.extract turns model predictions into RawEntity values."""

    # Argument names and cases shared by the per-field mapping tests below.
    _ENTITY_FIELDS = ("text", "label", "start", "end", "score")
    _ENTITY_CASES = [
        ("John Smith", "person", 0, 10, 0.92),
        ("New York", "location", 5, 13, 0.88),
        ("decision to proceed", "decision", 0, 19, 0.75),
    ]

    @staticmethod
    def _backend_predicting(
        text: str, label: str, start: int, end: int, score: float
    ) -> GLiNERBackend:
        """Return a mocked backend whose model predicts exactly one entity.

        The label is handed to the mock upper-cased, so the mapping tests see
        non-normalized label input.
        """
        return _create_backend_with_mock_model(
            mock_predictions=[
                {"text": text, "label": label.upper(), "start": start, "end": end, "score": score},
            ]
        )

    def test_extract_empty_string_returns_empty(self) -> None:
        # Deliberately unloaded backend: empty input presumably short-circuits
        # before any model is needed.
        backend = GLiNERBackend()
        assert backend.extract("") == []

    def test_extract_whitespace_only_returns_empty(self) -> None:
        backend = GLiNERBackend()
        assert backend.extract(" ") == []

    def test_extract_returns_correct_count(self) -> None:
        # Offsets match the input string: "NYC" occupies [14, 17) in it.
        backend = _create_backend_with_mock_model(
            mock_predictions=[
                {"text": "John", "label": "PERSON", "start": 0, "end": 4, "score": 0.95},
                {"text": "NYC", "label": "location", "start": 14, "end": 17, "score": 0.88},
            ]
        )
        entities = backend.extract("John lives in NYC")
        assert len(entities) == 2

    def test_extract_returns_raw_entity_type(self) -> None:
        backend = _create_backend_with_mock_model(
            mock_predictions=[
                {"text": "John", "label": "PERSON", "start": 0, "end": 4, "score": 0.95},
            ]
        )
        entities = backend.extract("John")
        assert isinstance(entities[0], RawEntity)

    def test_extract_normalizes_labels_to_lowercase(self) -> None:
        backend = _create_backend_with_mock_model(
            mock_predictions=[
                {"text": "John", "label": "PERSON", "start": 0, "end": 4, "score": 0.9},
            ]
        )
        entities = backend.extract("John")
        assert entities[0].label == "person"

    def test_extract_includes_confidence_score(self) -> None:
        backend = _create_backend_with_mock_model(
            mock_predictions=[
                {"text": "Meeting", "label": "event", "start": 0, "end": 7, "score": 0.85},
            ]
        )
        entities = backend.extract("Meeting tomorrow")
        assert entities[0].confidence == 0.85

    def test_extract_passes_threshold_to_predict(self) -> None:
        backend = _create_backend_with_mock_model(threshold=0.75)
        backend.extract("Some text")

        predict_entities = getattr(getattr(backend, "_model"), "predict_entities")
        # Fail with a clear mock message (not a TypeError on a None call_args)
        # if the model was never queried.
        predict_entities.assert_called()
        assert predict_entities.call_args.kwargs["threshold"] == 0.75

    def test_extract_passes_labels_to_predict(self) -> None:
        custom_labels = ("person", "task")
        backend = _create_backend_with_mock_model(labels=custom_labels)
        backend.extract("Some text")

        predict_entities = getattr(getattr(backend, "_model"), "predict_entities")
        predict_entities.assert_called()
        # The backend hands the labels to the model as a list, not the tuple.
        assert predict_entities.call_args.kwargs["labels"] == ["person", "task"]

    @pytest.mark.parametrize(_ENTITY_FIELDS, _ENTITY_CASES)
    def test_extract_entity_text_mapping(
        self, text: str, label: str, start: int, end: int, score: float
    ) -> None:
        backend = self._backend_predicting(text, label, start, end, score)
        entities = backend.extract("input text")
        assert entities[0].text == text

    @pytest.mark.parametrize(_ENTITY_FIELDS, _ENTITY_CASES)
    def test_extract_entity_label_mapping(
        self, text: str, label: str, start: int, end: int, score: float
    ) -> None:
        backend = self._backend_predicting(text, label, start, end, score)
        entities = backend.extract("input text")
        assert entities[0].label == label

    @pytest.mark.parametrize(_ENTITY_FIELDS, _ENTITY_CASES)
    def test_extract_entity_position_mapping(
        self, text: str, label: str, start: int, end: int, score: float
    ) -> None:
        backend = self._backend_predicting(text, label, start, end, score)
        entities = backend.extract("input text")
        assert (entities[0].start, entities[0].end) == (start, end)

    @pytest.mark.parametrize(_ENTITY_FIELDS, _ENTITY_CASES)
    def test_extract_entity_confidence_mapping(
        self, text: str, label: str, start: int, end: int, score: float
    ) -> None:
        backend = self._backend_predicting(text, label, start, end, score)
        entities = backend.extract("input text")
        assert entities[0].confidence == score
|
|
|
|
|
|
class TestGLiNERBackendMeetingLabels:
    """The default MEETING_LABELS vocabulary covers the expected categories."""

    def test_meeting_labels_include_core_categories(self) -> None:
        available = set(MEETING_LABELS)
        assert {"person", "org", "product", "app", "location", "time"} <= available

    def test_meeting_labels_include_meeting_specific_categories(self) -> None:
        available = set(MEETING_LABELS)
        assert {"task", "decision", "topic", "event"} <= available
|