langtest.modelhandler.transformers_modelhandler.PretrainedModelForTextClassification
- class PretrainedModelForTextClassification(model: Pipeline)
Bases: ModelAPI
Transformers pretrained model for text classification tasks.
- model
Loaded text classification pipeline for predictions.
- Type:
transformers.pipelines.Pipeline
- __init__(model: Pipeline)
Constructor method.
- Parameters:
model (transformers.pipelines.Pipeline) – Pretrained HuggingFace text classification pipeline for predictions.
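The constructor only needs an already-built pipeline object, which it keeps as the `model` attribute. The sketch below illustrates that contract without downloading a model. `FakeTextClassificationPipeline` and `TextClassificationWrapper` are hypothetical stand-ins, not langtest or transformers classes, and the `callable` check is an assumption rather than documented behavior. The fake pipeline imitates the one-element list of `{"label", "score"}` dicts that a transformers text-classification pipeline returns for a single string.

```python
from typing import Any, Dict, List


class FakeTextClassificationPipeline:
    """Hypothetical stand-in for a transformers text-classification pipeline.

    Calling it with one string returns a one-element list of
    {"label": ..., "score": ...} dicts, the shape a real pipeline returns.
    """

    def __call__(self, text: str, **kwargs: Any) -> List[Dict[str, Any]]:
        label = "POSITIVE" if "good" in text.lower() else "NEGATIVE"  # toy rule
        return [{"label": label, "score": 0.9}]


class TextClassificationWrapper:
    """Hypothetical sketch of the documented constructor contract."""

    def __init__(self, model: Any) -> None:
        # Assumption: reject objects that cannot be called like a pipeline.
        if not callable(model):
            raise TypeError("model must be a callable text-classification pipeline")
        self.model = model  # stored as the `model` attribute used for predictions


wrapper = TextClassificationWrapper(FakeTextClassificationPipeline())
print(wrapper.model("This is good"))  # [{'label': 'POSITIVE', 'score': 0.9}]
```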
Methods
__init__(model) – Constructor method
load_model(path) – Load and return text classification transformers pipeline
predict(text[, return_all_scores, ...]) – Perform predictions on the input text.
predict_raw(text[, truncation_strategy]) – Perform predictions on the input text.
Attributes
labels – Return the classification labels of the pipeline model.
model_registry
- property labels: List[str]
Return the classification labels of the pipeline model.
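A minimal sketch of how such a `labels` property could be computed, assuming the labels come from the `id2label` mapping in the underlying model's config (transformers pipelines expose the wrapped model as `pipeline.model`, and model configs carry `id2label`). The source does not confirm that langtest reads the labels from there; `pipeline_labels` and `fake_pipeline` are hypothetical names, with `SimpleNamespace` standing in for a real pipeline.

```python
from types import SimpleNamespace
from typing import List

# Hypothetical stand-in for a pipeline: only the attributes the sketch reads.
fake_pipeline = SimpleNamespace(
    model=SimpleNamespace(config=SimpleNamespace(id2label={0: "NEGATIVE", 1: "POSITIVE"}))
)


def pipeline_labels(pipe) -> List[str]:
    """Return labels ordered by class index.

    Assumption: the labels are taken from pipe.model.config.id2label.
    """
    id2label = pipe.model.config.id2label
    return [id2label[i] for i in sorted(id2label)]


print(pipeline_labels(fake_pipeline))  # ['NEGATIVE', 'POSITIVE']
```

Sorting by class index keeps the list aligned with the model's output positions even if the mapping was built out of order.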
- classmethod load_model(path: str) → Pipeline
Load and return a text classification transformers pipeline.
- predict(text: str, return_all_scores: bool = False, truncation_strategy: str = 'longest_first', *args, **kwargs) → SequenceClassificationOutput
Perform predictions on the input text.
- Parameters:
text (str) – Input text to classify.
return_all_scores (bool) – Whether to return the scores for all labels instead of only the top-scoring label.
truncation_strategy (str) – Strategy used to truncate sequences that are too long for the model.
kwargs – Additional keyword arguments.
- Returns:
Text classification output for the input text.
- Return type:
SequenceClassificationOutput
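The sketch below shows one plausible way `return_all_scores` and `truncation_strategy` could be forwarded to the pipeline. It assumes `return_all_scores=True` is translated to `top_k=None`, the convention recent transformers versions use to request every label's score. That mapping, `FakePipeline`, `ClassificationResult` (a stand-in for `SequenceClassificationOutput`, whose real fields are not documented here) and the free-standing `predict` function are all hypothetical.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


class FakePipeline:
    """Hypothetical stand-in for a transformers text-classification pipeline."""

    scores = {"NEGATIVE": 0.2, "POSITIVE": 0.8}  # fixed scores for illustration

    def __call__(self, text: str, top_k: Optional[int] = 1, **kwargs: Any) -> List[Dict[str, Any]]:
        self.last_kwargs = kwargs  # record pass-through options such as truncation_strategy
        ranked = sorted(self.scores.items(), key=lambda kv: kv[1], reverse=True)
        if top_k is not None:
            ranked = ranked[:top_k]  # keep only the top_k best labels
        return [{"label": label, "score": score} for label, score in ranked]


@dataclass
class ClassificationResult:
    """Hypothetical stand-in for SequenceClassificationOutput."""

    text: str
    predictions: List[Dict[str, Any]] = field(default_factory=list)


def predict(pipe, text: str, return_all_scores: bool = False,
            truncation_strategy: str = "longest_first", **kwargs: Any) -> ClassificationResult:
    # Assumption: return_all_scores=True becomes top_k=None (every label);
    # otherwise only the single best label is requested.
    top_k = None if return_all_scores else 1
    preds = pipe(text, top_k=top_k, truncation_strategy=truncation_strategy, **kwargs)
    return ClassificationResult(text=text, predictions=preds)


pipe = FakePipeline()
print(predict(pipe, "great movie").predictions)
# [{'label': 'POSITIVE', 'score': 0.8}]
print(predict(pipe, "great movie", return_all_scores=True).predictions)
# [{'label': 'POSITIVE', 'score': 0.8}, {'label': 'NEGATIVE', 'score': 0.2}]
```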
- predict_raw(text: str, truncation_strategy: str = 'longest_first') → List[str]
Perform predictions on the input text.
- Parameters:
text (str) – Input text to classify.
truncation_strategy (str) – Strategy used to truncate sequences that are too long for the model.
- Returns:
Predictions as a list of strings.
- Return type:
List[str]
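Assuming the strings returned by `predict_raw` are the predicted label names, the conversion from raw pipeline output could be sketched as follows. `fake_pipeline` and this free-standing `predict_raw` function are hypothetical illustrations, not the langtest implementation.

```python
from typing import Any, Dict, List


def fake_pipeline(text: str, **kwargs: Any) -> List[Dict[str, Any]]:
    """Hypothetical stand-in: returns the top label as a one-element list."""
    label = "POSITIVE" if "love" in text.lower() else "NEGATIVE"  # toy rule
    return [{"label": label, "score": 0.95}]


def predict_raw(pipe, text: str, truncation_strategy: str = "longest_first") -> List[str]:
    # Assumption: keep only the label names and drop the scores.
    outputs = pipe(text, truncation_strategy=truncation_strategy)
    return [pred["label"] for pred in outputs]


print(predict_raw(fake_pipeline, "I love this"))  # ['POSITIVE']
print(predict_raw(fake_pipeline, "Not for me"))   # ['NEGATIVE']
```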