from typing import Tuple
import numpy as np
from nlpatl.sampling import Sampling
class MarginSampling(Sampling):
    """
    Sampling data points according to the margin confidence. Pick the data
    points with the lowest probability difference between the highest class
    and the second highest class.

    :param name: Name of this sampling
    :type name: str
    """

    def __init__(self, name: str = 'margin_sampling'):
        super().__init__(name=name)

    def sample(self, data: np.ndarray,
               num_sample: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Select the rows of ``data`` whose highest and second highest values
        are closest together (smallest margin).

        :param data: 2-D array with one row per data point and one column per
            class; it needs at least two columns so a second highest value
            exists. Presumably per-class probabilities — TODO confirm with
            callers.
        :type data: np.ndarray
        :param num_sample: Maximum number of data points to pick. Values
            above ``len(data)`` are capped; zero or negative values select
            nothing.
        :type num_sample: int
        :return: Tuple of (row indices of the selected data points, margin of
            each selected point). The selection is not ordered by margin.
        :rtype: Tuple[np.ndarray, np.ndarray]
        """
        # Clamp into [0, len(data)]: a negative count would otherwise turn
        # into a negative kth / slice below and select almost every row.
        num_node = max(0, min(num_sample, len(data)))
        if num_node == 0:
            # Nothing to pick; also avoids argpartition(kth=-1) on an empty
            # array, which raises.
            return (np.empty(0, dtype=np.intp),
                    np.empty(0, dtype=data.dtype))

        # Partitioning the negated scores around kth=1 puts -highest in
        # column 0 and -second highest in column 1 of every row, so their
        # difference is (highest - second highest) without a full sort.
        partitioned = np.partition(-data, 1, axis=1)
        margin_diffs = partitioned[:, 1] - partitioned[:, 0]

        # argpartition places the num_node smallest margins first (in
        # arbitrary order) without sorting the whole array.
        indices = np.argpartition(margin_diffs, num_node - 1)[:num_node]
        return indices, margin_diffs[indices]