# Source file: noteflow/tests/infrastructure/ner/test_post_processing.py
# Viewer metadata: 159 lines, 5.9 KiB, Python

"""Tests for NER post-processing pipeline."""
from __future__ import annotations
import pytest
from noteflow.infrastructure.ner.backends.types import RawEntity
from noteflow.infrastructure.ner.post_processing import (
dedupe_entities,
drop_low_signal_entities,
infer_duration,
merge_time_phrases,
normalize_text,
post_process,
)
class TestNormalizeText:
    """Table-driven checks for ``normalize_text``."""

    @pytest.mark.parametrize(
        ("input_text", "expected"),
        [
            ("John Smith", "john smith"),
            (" UPPER CASE ", "upper case"),
            ("'quoted'", "quoted"),
            ('"double"', "double"),
            ("spaces between words", "spaces between words"),
        ],
    )
    def test_normalize_text(self, input_text: str, expected: str) -> None:
        """Each raw input must normalize to exactly its paired expected string."""
        normalized = normalize_text(input_text)
        assert normalized == expected
class TestDedupeEntities:
    """Tests for entity deduplication."""

    def test_deduplicates_by_normalized_text_and_label(self) -> None:
        """Two mentions differing only in case, with the same label, collapse to one."""
        entities = [
            RawEntity(text="John", label="person", start=0, end=4, confidence=0.8),
            RawEntity(text="john", label="person", start=10, end=14, confidence=0.9),
        ]
        result = dedupe_entities(entities)
        assert len(result) == 1, "Should deduplicate same entity"

    def test_keeps_higher_confidence(self) -> None:
        """The surviving duplicate carries the higher confidence in either input order."""

        def build_duplicates() -> list[RawEntity]:
            # Fresh objects per call so one run cannot influence the next.
            return [
                RawEntity(text="John", label="person", start=0, end=4, confidence=0.7),
                RawEntity(text="john", label="person", start=10, end=14, confidence=0.9),
            ]

        # Check both orders: with the higher-confidence entity only ever last,
        # a "keep the last seen" policy would pass without honoring confidence.
        for entities in (build_duplicates(), build_duplicates()[::-1]):
            result = dedupe_entities(entities)
            # Guard the index below so a failure reads as an assertion, not IndexError.
            assert len(result) == 1, "Should deduplicate same entity"
            assert result[0].confidence == 0.9, "Should keep higher confidence"

    def test_different_labels_not_deduped(self) -> None:
        """Identical text under different labels is kept as two separate entities."""
        entities = [
            RawEntity(text="Apple", label="product", start=0, end=5, confidence=0.8),
            RawEntity(text="Apple", label="company", start=10, end=15, confidence=0.9),
        ]
        result = dedupe_entities(entities)
        assert len(result) == 2, "Different labels should not be deduped"
class TestDropLowSignalEntities:
    """Checks which single entities ``drop_low_signal_entities`` removes or keeps."""

    @pytest.mark.parametrize(
        ("text", "label"),
        [
            ("a", "person"),
            ("to", "location"),
            ("123", "product"),
            ("shit", "topic"),
        ],
    )
    def test_drops_low_signal_entities(self, text: str, label: str) -> None:
        """Each listed text/label pair must be filtered out, leaving nothing."""
        survivors = drop_low_signal_entities(
            [RawEntity(text=text, label=label, start=0, end=len(text), confidence=0.8)],
            "some original text",
        )
        assert len(survivors) == 0, f"'{text}' should be dropped"

    @pytest.mark.parametrize(
        ("text", "label"),
        [
            ("John", "person"),
            ("New York", "location"),
        ],
    )
    def test_keeps_valid_entities(self, text: str, label: str) -> None:
        """Each listed text/label pair must pass through the filter untouched in count."""
        survivors = drop_low_signal_entities(
            [RawEntity(text=text, label=label, start=0, end=len(text), confidence=0.8)],
            "some original text",
        )
        assert len(survivors) == 1, f"'{text}' should NOT be dropped"
class TestInferDuration:
    """Checks the label ``infer_duration`` assigns to ``time_relative`` entities."""

    @pytest.mark.parametrize(
        ("text", "expected_label"),
        [
            ("20 minutes", "duration"),
            ("two weeks", "duration"),
            ("3 hours", "duration"),
            ("next week", "time_relative"),
            ("tomorrow", "time_relative"),
        ],
    )
    def test_infer_duration(self, text: str, expected_label: str) -> None:
        """An entity entering as ``time_relative`` must leave with the expected label."""
        candidate = RawEntity(
            text=text,
            label="time_relative",
            start=0,
            end=len(text),
            confidence=0.8,
        )
        relabeled = infer_duration([candidate])
        first = relabeled[0]
        assert first.label == expected_label, f"'{text}' should be labeled as {expected_label}"
class TestMergeTimePhrases:
    """Checks when ``merge_time_phrases`` joins neighbouring ``time`` entities."""

    def test_merges_adjacent_time_entities(self) -> None:
        """"last" and "night", separated by a single space, become one "last night"."""
        sentence = "last night we met"
        first_word = RawEntity(text="last", label="time", start=0, end=4, confidence=0.8)
        second_word = RawEntity(text="night", label="time", start=5, end=10, confidence=0.7)

        merged = merge_time_phrases([first_word, second_word], sentence)

        time_entities = [entity for entity in merged if entity.label == "time"]
        assert len(time_entities) == 1, "Should merge adjacent time entities"
        assert time_entities[0].text == "last night", "Merged text should be 'last night'"

    def test_does_not_merge_non_adjacent(self) -> None:
        """"Monday" and "Friday", with other words between them, stay separate."""
        sentence = "Monday came and then Friday arrived"
        monday = RawEntity(text="Monday", label="time", start=0, end=6, confidence=0.8)
        friday = RawEntity(text="Friday", label="time", start=21, end=27, confidence=0.8)

        merged = merge_time_phrases([monday, friday], sentence)

        time_entities = [entity for entity in merged if entity.label == "time"]
        assert len(time_entities) == 2, "Should not merge non-adjacent time entities"
class TestPostProcess:
    """End-to-end check of ``post_process`` on one annotated sentence."""

    def test_full_pipeline(self) -> None:
        """Run four raw entities through the pipeline and inspect what comes out."""
        sentence = "John met shit in Paris for 20 minutes"
        raw_entities = [
            RawEntity(text="John", label="person", start=0, end=4, confidence=0.9),
            RawEntity(text="shit", label="topic", start=9, end=13, confidence=0.5),
            RawEntity(text="Paris", label="location", start=17, end=22, confidence=0.95),
            RawEntity(text="20 minutes", label="time", start=27, end=37, confidence=0.8),
        ]

        processed = post_process(raw_entities, sentence)

        # Membership checks on the surviving entity texts.
        surviving_texts = {entity.text for entity in processed}
        assert "shit" not in surviving_texts, "Profanity should be filtered"
        assert "John" in surviving_texts, "Valid person should remain"
        assert "Paris" in surviving_texts, "Valid location should remain"

        # Exactly one entity is expected to end up labeled as a duration.
        durations = [entity for entity in processed if entity.label == "duration"]
        assert len(durations) == 1, "20 minutes should be labeled as duration"