"""Tests for NER post-processing pipeline."""
"""Tests for NER post-processing pipeline."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
|
|
from noteflow.infrastructure.ner.backends.types import RawEntity
|
|
from noteflow.infrastructure.ner.post_processing import (
|
|
dedupe_entities,
|
|
drop_low_signal_entities,
|
|
infer_duration,
|
|
merge_time_phrases,
|
|
normalize_text,
|
|
post_process,
|
|
)
|
|
|
|
|
|
class TestNormalizeText:
    """Unit tests for the normalize_text helper."""

    @pytest.mark.parametrize(
        ("input_text", "expected"),
        [
            ("John Smith", "john smith"),
            (" UPPER CASE ", "upper case"),
            ("'quoted'", "quoted"),
            ('"double"', "double"),
            ("spaces between words", "spaces between words"),
        ],
    )
    def test_normalize_text(self, input_text: str, expected: str) -> None:
        # Normalization lowercases, trims whitespace, and strips quote chars.
        normalized = normalize_text(input_text)
        assert normalized == expected

class TestDedupeEntities:
    """Unit tests for dedupe_entities."""

    def test_deduplicates_by_normalized_text_and_label(self) -> None:
        # Same entity text modulo case, same label -> one survivor expected.
        first = RawEntity(text="John", label="person", start=0, end=4, confidence=0.8)
        second = RawEntity(text="john", label="person", start=10, end=14, confidence=0.9)
        deduped = dedupe_entities([first, second])
        assert len(deduped) == 1, "Should deduplicate same entity"

    def test_keeps_higher_confidence(self) -> None:
        # When duplicates collide, the higher-confidence record wins.
        low = RawEntity(text="John", label="person", start=0, end=4, confidence=0.7)
        high = RawEntity(text="john", label="person", start=10, end=14, confidence=0.9)
        deduped = dedupe_entities([low, high])
        assert deduped[0].confidence == 0.9, "Should keep higher confidence"

    def test_different_labels_not_deduped(self) -> None:
        # Identical text under distinct labels are distinct entities.
        as_product = RawEntity(text="Apple", label="product", start=0, end=5, confidence=0.8)
        as_company = RawEntity(text="Apple", label="company", start=10, end=15, confidence=0.9)
        deduped = dedupe_entities([as_product, as_company])
        assert len(deduped) == 2, "Different labels should not be deduped"

class TestDropLowSignalEntities:
    """Unit tests for drop_low_signal_entities filtering."""

    @pytest.mark.parametrize(
        ("text", "label"),
        [
            ("a", "person"),
            ("to", "location"),
            ("123", "product"),
            ("shit", "topic"),
        ],
    )
    def test_drops_low_signal_entities(self, text: str, label: str) -> None:
        # Stopword-like, numeric-only, and profane spans should all be removed.
        candidate = RawEntity(text=text, label=label, start=0, end=len(text), confidence=0.8)
        kept = drop_low_signal_entities([candidate], "some original text")
        assert len(kept) == 0, f"'{text}' should be dropped"

    @pytest.mark.parametrize(
        ("text", "label"),
        [
            ("John", "person"),
            ("New York", "location"),
        ],
    )
    def test_keeps_valid_entities(self, text: str, label: str) -> None:
        # Genuine names and places must survive the filter untouched.
        candidate = RawEntity(text=text, label=label, start=0, end=len(text), confidence=0.8)
        kept = drop_low_signal_entities([candidate], "some original text")
        assert len(kept) == 1, f"'{text}' should NOT be dropped"

class TestInferDuration:
    """Unit tests for duration inference over time entities."""

    @pytest.mark.parametrize(
        ("text", "expected_label"),
        [
            ("20 minutes", "duration"),
            ("two weeks", "duration"),
            ("3 hours", "duration"),
            ("next week", "time_relative"),
            ("tomorrow", "time_relative"),
        ],
    )
    def test_infer_duration(self, text: str, expected_label: str) -> None:
        # Every case starts as time_relative; only span-like phrases
        # ("20 minutes", "two weeks", ...) should be relabeled to duration.
        seed = RawEntity(text=text, label="time_relative", start=0, end=len(text), confidence=0.8)
        relabeled = infer_duration([seed])[0]
        assert relabeled.label == expected_label, f"'{text}' should be labeled as {expected_label}"

class TestMergeTimePhrases:
    """Unit tests for merging adjacent time-labeled entities."""

    def test_merges_adjacent_time_entities(self) -> None:
        source = "last night we met"
        # Two time tokens separated only by a single space in the source.
        fragments = [
            RawEntity(text="last", label="time", start=0, end=4, confidence=0.8),
            RawEntity(text="night", label="time", start=5, end=10, confidence=0.7),
        ]
        merged = merge_time_phrases(fragments, source)

        time_entities = [entity for entity in merged if entity.label == "time"]
        assert len(time_entities) == 1, "Should merge adjacent time entities"
        assert time_entities[0].text == "last night", "Merged text should be 'last night'"

    def test_does_not_merge_non_adjacent(self) -> None:
        source = "Monday came and then Friday arrived"
        # Time tokens with intervening words must stay separate.
        fragments = [
            RawEntity(text="Monday", label="time", start=0, end=6, confidence=0.8),
            RawEntity(text="Friday", label="time", start=21, end=27, confidence=0.8),
        ]
        merged = merge_time_phrases(fragments, source)

        time_entities = [entity for entity in merged if entity.label == "time"]
        assert len(time_entities) == 2, "Should not merge non-adjacent time entities"

class TestPostProcess:
    """End-to-end tests for the full post_process pipeline."""

    def test_full_pipeline(self) -> None:
        source = "John met shit in Paris for 20 minutes"
        raw = [
            RawEntity(text="John", label="person", start=0, end=4, confidence=0.9),
            RawEntity(text="shit", label="topic", start=9, end=13, confidence=0.5),
            RawEntity(text="Paris", label="location", start=17, end=22, confidence=0.95),
            RawEntity(text="20 minutes", label="time", start=27, end=37, confidence=0.8),
        ]
        processed = post_process(raw, source)

        # Filtering: profanity gone, legitimate person/location retained.
        surviving_texts = {entity.text for entity in processed}
        assert "shit" not in surviving_texts, "Profanity should be filtered"
        assert "John" in surviving_texts, "Valid person should remain"
        assert "Paris" in surviving_texts, "Valid location should remain"

        # Relabeling: the time span should be promoted to a duration.
        duration_entities = [entity for entity in processed if entity.label == "duration"]
        assert len(duration_entities) == 1, "20 minutes should be labeled as duration"