pytorch-crf
Conditional random fields in PyTorch.
This package provides an implementation of a conditional random field (CRF) layer in PyTorch. The implementation borrows mostly from the AllenNLP CRF module, with some modifications.
Minimal requirements
- Python 3.6
- PyTorch 1.0.0
Installation
Install with pip:
pip install pytorch-crf
Or install from GitHub for the latest version:
pip install git+https://github.com/kmkurn/pytorch-crf#egg=pytorch_crf
Getting started
pytorch-crf exposes a single CRF class, which inherits from PyTorch’s nn.Module. This class provides an implementation of a CRF layer.
>>> import torch
>>> from torchcrf import CRF
>>> num_tags = 5 # number of tags is 5
>>> model = CRF(num_tags)
Computing log likelihood
Once the layer is created, you can compute the log likelihood of a sequence of tags given some emission scores.
>>> seq_length = 3 # maximum sequence length in a batch
>>> batch_size = 2 # number of samples in the batch
>>> emissions = torch.randn(seq_length, batch_size, num_tags)
>>> tags = torch.tensor([
... [0, 1], [2, 4], [3, 1]
... ], dtype=torch.long) # (seq_length, batch_size)
>>> model(emissions, tags)
tensor(-12.7431, grad_fn=<SumBackward0>)
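To make the returned quantity concrete, here is a minimal pure-Python sketch of a linear-chain CRF log likelihood for a single sequence: the score of the given tag sequence minus the log of the summed exponentiated scores of all tag sequences. It is an illustration only, not the package's implementation. The function names and toy numbers are hypothetical, and it assumes `trans[i][j]` scores a transition from tag `i` to tag `j`.

```python
import math
from itertools import product


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def sequence_score(emissions, tags, start, end, trans):
    """Unnormalized score of one tag sequence."""
    score = start[tags[0]] + emissions[0][tags[0]]
    for t in range(1, len(tags)):
        score += trans[tags[t - 1]][tags[t]] + emissions[t][tags[t]]
    return score + end[tags[-1]]


def log_partition(emissions, start, end, trans):
    """Log-sum-exp of the scores of all tag sequences (forward algorithm)."""
    num_tags = len(start)
    alpha = [start[j] + emissions[0][j] for j in range(num_tags)]
    for t in range(1, len(emissions)):
        alpha = [
            log_sum_exp([alpha[i] + trans[i][j] for i in range(num_tags)])
            + emissions[t][j]
            for j in range(num_tags)
        ]
    return log_sum_exp([alpha[j] + end[j] for j in range(num_tags)])


def log_likelihood(emissions, tags, start, end, trans):
    """Log probability of `tags` given the emission scores."""
    return (sequence_score(emissions, tags, start, end, trans)
            - log_partition(emissions, start, end, trans))


# Toy example: 3 time steps, 2 tags (all numbers are made up).
emissions = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.8]]
start = [0.1, -0.1]
end = [0.0, 0.2]
trans = [[0.3, -0.3], [-0.2, 0.4]]

ll = log_likelihood(emissions, [0, 1, 1], start, end, trans)

# Cross-check the forward algorithm against brute-force enumeration.
brute = log_sum_exp([
    sequence_score(emissions, list(p), start, end, trans)
    for p in product(range(2), repeat=3)
])
print(ll, abs(log_partition(emissions, start, end, trans) - brute) < 1e-9)
```

Because the log likelihood is a log probability, it is never positive, and the probabilities of all possible tag sequences sum to one.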
If you have some padding in your input tensors, you can pass a mask tensor.
>>> # mask size is (seq_length, batch_size)
>>> # the last sample has length of 2
>>> mask = torch.tensor([
... [1, 1], [1, 1], [1, 0]
... ], dtype=torch.uint8)
>>> model(emissions, tags, mask=mask)
tensor(-10.8390, grad_fn=<SumBackward0>)
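If you only know the length of each sample, a mask in the (seq_length, batch_size) layout can be derived from those lengths. The sketch below uses plain Python lists; `lengths_to_mask` is a hypothetical helper for illustration, not part of pytorch-crf.

```python
def lengths_to_mask(lengths, seq_length):
    """Return a (seq_length, batch_size) nested list of 0/1 flags.

    Entry [t][b] is 1 when time step t is a real token of sample b
    and 0 when it is padding.
    """
    return [[1 if t < n else 0 for n in lengths] for t in range(seq_length)]


# Two samples: the first has 3 tokens, the second has 2.
mask_rows = lengths_to_mask([3, 2], seq_length=3)
print(mask_rows)  # [[1, 1], [1, 1], [1, 0]]
```

The nested list can then be wrapped with `torch.tensor(mask_rows, dtype=torch.uint8)` to obtain a mask tensor like the one used above.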
Note that the returned value is the log likelihood, so you’ll need to negate it to use it as a loss. By default, the log likelihood is summed over the batch. For other options, consult the API documentation of CRF.forward.
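The reduction options can be pictured with a small pure-Python sketch that operates on per-sample log likelihoods. This is one reading of the documented options, not the package's code: `reduce_log_likelihood` is a hypothetical helper, and it assumes that token_mean divides the summed log likelihood by the number of unmasked tokens.

```python
def reduce_log_likelihood(per_sample, mask, reduction="sum"):
    """Reduce per-sample log likelihoods.

    per_sample: list with one log likelihood per sample in the batch.
    mask: (seq_length, batch_size) nested list of 0/1 flags.
    """
    if reduction == "none":
        return list(per_sample)
    total = sum(per_sample)
    if reduction == "sum":
        return total
    if reduction == "mean":
        return total / len(per_sample)
    if reduction == "token_mean":
        # Assumption: average over the unmasked (real) tokens only.
        num_tokens = sum(sum(row) for row in mask)
        return total / num_tokens
    raise ValueError(f"invalid reduction: {reduction}")


per_sample = [-6.0, -4.0]          # made-up log likelihoods for 2 samples
mask = [[1, 1], [1, 1], [1, 0]]    # 5 real tokens in total

print(reduce_log_likelihood(per_sample, mask, "sum"))         # -10.0
print(reduce_log_likelihood(per_sample, mask, "mean"))        # -5.0
print(reduce_log_likelihood(per_sample, mask, "token_mean"))  # -2.0

# The loss to minimize is the negated log likelihood.
loss = -reduce_log_likelihood(per_sample, mask, "sum")
print(loss)  # 10.0
```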
Decoding
To obtain the most probable sequence of tags, use the CRF.decode method.
>>> model.decode(emissions)
[[3, 1, 3], [0, 1, 0]]
This method also accepts a mask tensor; see CRF.decode for details.
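For intuition about what decoding computes, the following is a minimal pure-Python sketch of Viterbi decoding for a single unpadded sequence. It is not the package's implementation: the function names and toy scores are hypothetical, and it assumes `trans[i][j]` scores a transition from tag `i` to tag `j`.

```python
from itertools import product


def path_score(emissions, tags, start, end, trans):
    """Unnormalized score of one tag sequence."""
    score = start[tags[0]] + emissions[0][tags[0]]
    for t in range(1, len(tags)):
        score += trans[tags[t - 1]][tags[t]] + emissions[t][tags[t]]
    return score + end[tags[-1]]


def viterbi_decode(emissions, start, end, trans):
    """Return the highest-scoring tag sequence as a list of tag indices."""
    num_tags = len(start)
    # score[j]: best score of any partial path ending in tag j.
    score = [start[j] + emissions[0][j] for j in range(num_tags)]
    history = []  # history[t - 1][j]: best previous tag when step t has tag j
    for t in range(1, len(emissions)):
        best_prev, next_score = [], []
        for j in range(num_tags):
            i_best = max(range(num_tags), key=lambda i: score[i] + trans[i][j])
            best_prev.append(i_best)
            next_score.append(score[i_best] + trans[i_best][j] + emissions[t][j])
        history.append(best_prev)
        score = next_score
    # Pick the best final tag, then follow the back-pointers.
    last = max(range(num_tags), key=lambda j: score[j] + end[j])
    path = [last]
    for best_prev in reversed(history):
        path.append(best_prev[path[-1]])
    path.reverse()
    return path


# Toy example: 3 time steps, 2 tags (all numbers are made up).
emissions = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.8]]
start = [0.1, -0.1]
end = [0.0, 0.2]
trans = [[0.3, -0.3], [-0.2, 0.4]]

best_path = viterbi_decode(emissions, start, end, trans)
print(best_path)  # [0, 1, 1]

# Cross-check against brute-force enumeration of all 2**3 sequences.
brute_best = max(
    product(range(2), repeat=3),
    key=lambda p: path_score(emissions, p, start, end, trans),
)
print(list(brute_best) == best_path)  # True
```

Viterbi needs time proportional to seq_length × num_tags², whereas enumerating every sequence grows exponentially with the sequence length.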
API documentation
- class torchcrf.CRF(num_tags, batch_first=False)
  Conditional random field.
  This module implements a conditional random field [LMP01]. The forward computation of this class computes the log likelihood of the given sequence of tags and emission score tensor. This class also has a decode method, which finds the best tag sequence given an emission score tensor using the Viterbi algorithm.
  - start_transitions: Start transition score tensor of size (num_tags,). Type: Parameter
  - end_transitions: End transition score tensor of size (num_tags,). Type: Parameter
  - transitions: Transition score tensor of size (num_tags, num_tags). Type: Parameter
[LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001). “Conditional random fields: Probabilistic models for segmenting and labeling sequence data”. Proc. 18th International Conf. on Machine Learning. Morgan Kaufmann. pp. 282–289.
- decode(emissions, mask=None)
  Find the most likely tag sequence using the Viterbi algorithm.
  Parameters:
  - emissions (Tensor) – Emission score tensor of size (seq_length, batch_size, num_tags) if batch_first is False, (batch_size, seq_length, num_tags) otherwise.
  - mask (ByteTensor) – Mask tensor of size (seq_length, batch_size) if batch_first is False, (batch_size, seq_length) otherwise.
  Returns: A list of lists containing the best tag sequence for each sample in the batch.
- forward(emissions, tags, mask=None, reduction='sum')
  Compute the conditional log likelihood of a sequence of tags given emission scores.
  Parameters:
  - emissions (Tensor) – Emission score tensor of size (seq_length, batch_size, num_tags) if batch_first is False, (batch_size, seq_length, num_tags) otherwise.
  - tags (LongTensor) – Sequence of tags tensor of size (seq_length, batch_size) if batch_first is False, (batch_size, seq_length) otherwise.
  - mask (ByteTensor) – Mask tensor of size (seq_length, batch_size) if batch_first is False, (batch_size, seq_length) otherwise.
  - reduction (str) – Specifies the reduction to apply to the output: none|sum|mean|token_mean. none: no reduction will be applied. sum: the output will be summed over batches. mean: the output will be averaged over batches. token_mean: the output will be averaged over tokens.
  Returns: The log likelihood. This will have size (batch_size,) if reduction is none, () otherwise.