nlpatl.models.classification.sklearn_classification
scikit-learn classification wrapper
- class nlpatl.models.classification.sklearn_classification.SkLearnClassification(model_name='logistic_regression', model_config={}, name='sklearn_classification')
Bases:
nlpatl.models.classification.classification.Classification
A wrapper around scikit-learn classification models.
- Parameters
model_name (str) – scikit-learn classification model name. Possible values are logistic_regression, svc, linear_svc and random_forest.
model_config (dict) – Model parameters. Refer to https://scikit-learn.org/stable/modules/classes.html#module-sklearn.linear_model
name (str) – Name of this classification model
>>> import nlpatl.models.classification as nmcla
>>> model = nmcla.SkLearnClassification()
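As a rough illustration of how model_name and model_config might fit together, the sketch below resolves each documented model name to a scikit-learn estimator class and passes model_config as constructor keyword arguments. The MODEL_CLASSES table and the build_estimator helper are hypothetical names introduced for this example; they are not part of NLPatl, and the library's internal mapping may differ.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC, LinearSVC

# Hypothetical lookup table for the four documented model names.
# This is an assumption for illustration, not NLPatl's actual code.
MODEL_CLASSES = {
    "logistic_regression": LogisticRegression,
    "svc": SVC,
    "linear_svc": LinearSVC,
    "random_forest": RandomForestClassifier,
}


def build_estimator(model_name="logistic_regression", model_config=None):
    """Illustrative helper (not part of NLPatl): build a scikit-learn estimator."""
    if model_name not in MODEL_CLASSES:
        raise ValueError(f"Unknown model_name: {model_name!r}")
    # Entries of model_config become constructor keyword arguments.
    return MODEL_CLASSES[model_name](**(model_config or {}))


# Example: an SVC with probability estimates enabled and a custom C.
estimator = build_estimator("svc", {"probability": True, "C": 0.5})
```

Under this assumption, the keys of model_config would be whatever the chosen scikit-learn estimator accepts in its constructor, so the valid keys change with model_name.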
- predict_proba(x, predict_config={})
- Parameters
x (np.ndarray) – Raw features
predict_config (dict) – Model prediction parameters. Refer to https://scikit-learn.org/stable/modules/classes.html#module-sklearn.linear_model
- Returns
Features and probabilities
- Return type
nlpatl.dataset.Dataset
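To give a sense of the probabilities involved, the sketch below calls predict_proba directly on a scikit-learn LogisticRegression, which is presumably the kind of call the wrapper delegates to. It does not use NLPatl and does not reproduce the Dataset return type; the toy arrays and variable names are made up for illustration.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy training data: four samples, two features, two classes.
x_train = np.array([[0.0, 0.0], [0.1, 0.2], [1.0, 1.0], [0.9, 0.8]])
y_train = np.array([0, 0, 1, 1])

clf = LogisticRegression()
clf.fit(x_train, y_train)

# Raw features to score, as an np.ndarray of shape (n_samples, n_features).
x = np.array([[0.05, 0.1], [0.95, 0.9]])

# One row per sample, one column per class; each row sums to 1.
probs = clf.predict_proba(x)
```

In scikit-learn, the columns of the returned array follow the order of clf.classes_. How the wrapper packages these values together with the features inside a Dataset is not shown here.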