
Commit 2c79f9a

PyTorch 0.4.1 support (#25)
* Closes #15
* Make tests pass for PyTorch 0.4
* Get rid of _new; use tensor.new_* and torch.*_like instead
* Get rid of obsolete tensor creation methods
* Get rid of warnings
* Use PyTorch's logsumexp
* Fix type annotations and docstrings
* Update README
1 parent 9c3a147 commit 2c79f9a
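
For orientation, here is a minimal standalone sketch (not part of the commit) of the PyTorch 0.4 idioms the bullet points above refer to; all variable names are illustrative only:

    import torch
    import torch.nn as nn

    num_tags = 5

    # torch.empty + in-place init replace torch.Tensor(...) and the deprecated nn.init.uniform
    transitions = nn.Parameter(torch.empty(num_tags, num_tags))
    nn.init.uniform_(transitions, -0.1, 0.1)

    # Plain tensors carry autograd state; no torch.autograd.Variable wrapper is needed
    tags = torch.tensor([[0, 1], [2, 4], [3, 1]], dtype=torch.long)

    # torch.*_like and tensor.new_* replace the removed _new helper
    mask = torch.ones_like(tags, dtype=torch.uint8)         # shape and device follow tags
    ones = transitions.new_ones((3, 2), dtype=torch.uint8)  # device follows transitions

    # 0-dim tensors are unwrapped with .item() rather than indexing with [0]
    scores = torch.randn(2, num_tags)
    best_score = scores.max().item()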

4 files changed: +106 −166 lines changed


.travis.yml

Lines changed: 1 addition & 6 deletions
@@ -1,17 +1,12 @@
 language: python
-env:
-  - PYTORCH_VERSION=0.3.0
-  - PYTORCH_VERSION=0.3.1
 before_install:
   - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
   - bash ~/miniconda.sh -b -p $HOME/miniconda
   - export PATH="$HOME/miniconda/bin:$PATH"
   - conda update -y conda
   - conda create -y -n pytorch-crf python=3.6
   - source activate pytorch-crf
-  - conda install -y -c pytorch pytorch=$PYTORCH_VERSION
-  # explicitly specify cudatoolkit version (see https://discuss.pytorch.org/t/libcudart-so-8-0-not-found-in-travis-ci/16071/3)
-  - if [[ $PYTORCH_VERSION == "0.3.0" ]]; then conda install -y -c pytorch cudatoolkit=8.0; fi
+  - conda install -y -c pytorch pytorch=0.4.1
   - export LD_LIBRARY_PATH="$HOME/miniconda/envs/pytorch-crf/lib:$LD_LIBRARY_PATH"
 install:
   - pip install --ignore-installed -r requirements.txt

README.rst

Lines changed: 6 additions & 10 deletions
@@ -24,7 +24,7 @@ Requirements
 ============

 - Python 3.6
-- PyTorch 0.3.0
+- PyTorch 0.4.1

 Installation
 ============
@@ -47,8 +47,8 @@ In the examples below, we will assume that these lines have been executed
    >>> import torch
    >>> from torchcrf import CRF
    >>> seq_length, batch_size, num_tags = 3, 2, 5
-   >>> emissions = torch.autograd.Variable(torch.randn(seq_length, batch_size, num_tags), requires_grad=True)
-   >>> tags = torch.autograd.Variable(torch.LongTensor([[0, 1], [2, 4], [3, 1]]))  # (seq_length, batch_size)
+   >>> emissions = torch.randn(seq_length, batch_size, num_tags)
+   >>> tags = torch.tensor([[0, 1], [2, 4], [3, 1]], dtype=torch.long)  # (seq_length, batch_size)
    >>> model = CRF(num_tags)

 Computing log likelihood
@@ -57,20 +57,16 @@ Computing log likelihood
 .. code-block:: python

    >>> model(emissions, tags)
-   Variable containing:
-   -10.0635
-   [torch.FloatTensor of size 1]
+   tensor(-12.7431, grad_fn=<SumBackward0>)

 Computing log likelihood with mask
 ----------------------------------

 .. code-block:: python

-   >>> mask = torch.autograd.Variable(torch.ByteTensor([[1, 1], [1, 1], [1, 0]]))  # (seq_length, batch_size)
+   >>> mask = torch.tensor([[1, 1], [1, 1], [1, 0]], dtype=torch.uint8)  # (seq_length, batch_size)
    >>> model(emissions, tags, mask=mask)
-   Variable containing:
-   -8.4981
-   [torch.FloatTensor of size 1]
+   tensor(-10.8390, grad_fn=<SumBackward0>)

 Decoding
 --------
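
The diff context stops at the ``Decoding`` heading. As a rough usage sketch (not part of this diff), based on the ``decode`` signature updated in ``src/torchcrf/__init__.py`` below, decoding looks like this; the exact output shown in the real README may differ:

   >>> model.decode(emissions)             # List[List[int]]: one best tag sequence per batch entry
   >>> model.decode(emissions, mask=mask)  # timesteps where mask == 0 are excluded from the decoded path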

src/torchcrf/__init__.py

Lines changed: 40 additions & 62 deletions
@@ -1,6 +1,5 @@
-from typing import List, Optional, Union
+from typing import List, Optional

-from torch.autograd import Variable
 import torch
 import torch.nn as nn

@@ -20,8 +19,6 @@ class CRF(nn.Module):

     Attributes
     ----------
-    num_tags : int
-        Number of tags passed to ``__init__``.
     start_transitions : :class:`~torch.nn.Parameter`
         Start transition score tensor of size ``(num_tags,)``.
     end_transitions : :class:`~torch.nn.Parameter`
@@ -43,9 +40,9 @@ def __init__(self, num_tags: int) -> None:
             raise ValueError(f'invalid number of tags: {num_tags}')
         super().__init__()
         self.num_tags = num_tags
-        self.start_transitions = nn.Parameter(torch.Tensor(num_tags))
-        self.end_transitions = nn.Parameter(torch.Tensor(num_tags))
-        self.transitions = nn.Parameter(torch.Tensor(num_tags, num_tags))
+        self.start_transitions = nn.Parameter(torch.empty(num_tags))
+        self.end_transitions = nn.Parameter(torch.empty(num_tags))
+        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))

         self.reset_parameters()

@@ -55,35 +52,35 @@ def reset_parameters(self) -> None:
         The parameters will be initialized randomly from a uniform distribution
         between -0.1 and 0.1.
         """
-        nn.init.uniform(self.start_transitions, -0.1, 0.1)
-        nn.init.uniform(self.end_transitions, -0.1, 0.1)
-        nn.init.uniform(self.transitions, -0.1, 0.1)
+        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
+        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
+        nn.init.uniform_(self.transitions, -0.1, 0.1)

     def __repr__(self) -> str:
         return f'{self.__class__.__name__}(num_tags={self.num_tags})'

     def forward(self,
-                emissions: Variable,
-                tags: Variable,
-                mask: Optional[Variable] = None,
+                emissions: torch.Tensor,
+                tags: torch.LongTensor,
+                mask: Optional[torch.ByteTensor] = None,
                 reduce: bool = True,
-                ) -> Variable:
+                ) -> torch.Tensor:
         """Compute the log likelihood of the given sequence of tags and emission score.

         Arguments
         ---------
-        emissions : :class:`~torch.autograd.Variable`
+        emissions : :class:`~torch.Tensor`
             Emission score tensor of size ``(seq_length, batch_size, num_tags)``.
-        tags : :class:`~torch.autograd.Variable`
-            Sequence of tags as ``LongTensor`` of size ``(seq_length, batch_size)``.
-        mask : :class:`~torch.autograd.Variable`, optional
-            Mask tensor as ``ByteTensor`` of size ``(seq_length, batch_size)``.
+        tags : :class:`~torch.LongTensor`
+            Sequence of tags of size ``(seq_length, batch_size)``.
+        mask : :class:`~torch.ByteTensor`, optional
+            Mask tensor of size ``(seq_length, batch_size)``.
         reduce : bool
             Whether to sum the log likelihood over the batch.

         Returns
         -------
-        :class:`~torch.autograd.Variable`
+        :class:`~torch.Tensor`
             The log likelihood. This will have size (1,) if ``reduce=True``, ``(batch_size,)``
             otherwise.
         """
@@ -107,32 +104,32 @@ def forward(self,
                     f'size of tags and mask must match, got {tuple(tags.size())} '
                     f'and {tuple(mask.size())}'
                 )
-            if not all(mask[0].data):
+            if not all(mask[0]):
                 raise ValueError('mask of the first timestep must all be on')

         if mask is None:
-            mask = Variable(self._new(tags.size()).fill_(1)).byte()
+            mask = torch.ones_like(tags, dtype=torch.uint8)

         numerator = self._compute_joint_llh(emissions, tags, mask)
         denominator = self._compute_log_partition_function(emissions, mask)
         llh = numerator - denominator
         return llh if not reduce else torch.sum(llh)

     def decode(self,
-               emissions: Union[Variable, torch.FloatTensor],
-               mask: Optional[Union[Variable, torch.ByteTensor]] = None) -> List[List[int]]:
+               emissions: torch.Tensor,
+               mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
         """Find the most likely tag sequence using Viterbi algorithm.

         Arguments
         ---------
-        emissions : :class:`~torch.autograd.Variable` or :class:`~torch.FloatTensor`
+        emissions : :class:`~torch.Tensor`
             Emission score tensor of size ``(seq_length, batch_size, num_tags)``.
-        mask : :class:`~torch.autograd.Variable` or :class:`torch.ByteTensor`
+        mask : :class:`~torch.ByteTensor`
             Mask tensor of size ``(seq_length, batch_size)``.

         Returns
         -------
-        list
+        List[List[int]]
             List of list containing the best tag sequence for each batch.
         """
         if emissions.dim() != 3:
@@ -148,27 +145,23 @@ def decode(self,
                 f'got {tuple(emissions.size()[:2])} and {tuple(mask.size())}'
             )

-        if isinstance(emissions, Variable):
-            emissions = emissions.data
         if mask is None:
-            mask = self._new(emissions.size()[:2]).fill_(1).byte()
-        elif isinstance(mask, Variable):
-            mask = mask.data
+            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)

         return self._viterbi_decode(emissions, mask)

     def _compute_joint_llh(self,
-                           emissions: Variable,
-                           tags: Variable,
-                           mask: Variable) -> Variable:
+                           emissions: torch.Tensor,
+                           tags: torch.LongTensor,
+                           mask: torch.ByteTensor) -> torch.Tensor:
         # emissions: (seq_length, batch_size, num_tags)
         # tags: (seq_length, batch_size)
         # mask: (seq_length, batch_size)
         assert emissions.dim() == 3 and tags.dim() == 2
         assert emissions.size()[:2] == tags.size()
         assert emissions.size(2) == self.num_tags
         assert mask.size() == tags.size()
-        assert all(mask[0].data)
+        assert all(mask[0])

         seq_length = emissions.size(0)
         mask = mask.float()
@@ -197,14 +190,14 @@ def _compute_joint_llh(self,
         return llh

     def _compute_log_partition_function(self,
-                                        emissions: Variable,
-                                        mask: Variable) -> Variable:
+                                        emissions: torch.Tensor,
+                                        mask: torch.ByteTensor) -> torch.Tensor:
         # emissions: (seq_length, batch_size, num_tags)
         # mask: (seq_length, batch_size)
         assert emissions.dim() == 3 and mask.dim() == 2
         assert emissions.size()[:2] == mask.size()
         assert emissions.size(2) == self.num_tags
-        assert all(mask[0].data)
+        assert all(mask[0])

         seq_length = emissions.size(0)
         mask = mask.float()
@@ -226,15 +219,15 @@ def _compute_log_partition_function(self,
                 + broadcast_emissions  # (batch_size, num_tags, num_tags)
             # Sum over all possible current tags, but we're in log prob space, so a sum
             # becomes a log-sum-exp
-            score = self._log_sum_exp(score, 1)  # (batch_size, num_tags)
+            score = torch.logsumexp(score, 1)  # (batch_size, num_tags)
             # Set log_prob to the score if this timestep is valid (mask == 1), otherwise
             # leave it alone
             log_prob = score * mask[i].unsqueeze(1) + log_prob * (1.-mask[i]).unsqueeze(1)

         # End transition score
         log_prob += self.end_transitions.view(1, -1)
         # Sum (log-sum-exp) over all possible tags
-        return self._log_sum_exp(log_prob, 1)  # (batch_size,)
+        return torch.logsumexp(log_prob, 1)  # (batch_size,)

     def _viterbi_decode(self, emissions: torch.FloatTensor, mask: torch.ByteTensor) \
             -> List[List[int]]:
@@ -251,7 +244,7 @@ def _viterbi_decode(self, emissions: torch.FloatTensor, mask: torch.ByteTensor)

         # Start transition
         viterbi_score = []
-        viterbi_score.append(self.start_transitions.data + emissions[0])
+        viterbi_score.append(self.start_transitions + emissions[0])
         viterbi_path = []

         # Here, viterbi_score is a list of tensors of shapes of (num_tags,) where value at
@@ -269,7 +262,7 @@ def _viterbi_decode(self, emissions: torch.FloatTensor, mask: torch.ByteTensor)
             # Compute the score matrix of shape (batch_size, num_tags, num_tags) where
             # for each sample, each entry at row i and column j stores the score of
             # transitioning from tag i to tag j and emitting
-            score = broadcast_score + self.transitions.data + broadcast_emission
+            score = broadcast_score + self.transitions + broadcast_emission
             # Find the maximum score over all possible current tag
             best_score, best_path = score.max(1)  # (batch_size,num_tags,)
             # Save the score and the path
@@ -280,32 +273,17 @@ def _viterbi_decode(self, emissions: torch.FloatTensor, mask: torch.ByteTensor)
         for idx in range(batch_size):
             # Find the tag which maximizes the score at the last timestep; this is our best tag
             # for the last timestep
-            seq_end = sequence_lengths[idx]-1
-            _, best_last_tag = (viterbi_score[seq_end][idx] + self.end_transitions.data).max(0)
-            best_tags = [best_last_tag[0]]
+            seq_end = sequence_lengths[idx] - 1
+            _, best_last_tag = (viterbi_score[seq_end][idx] + self.end_transitions).max(0)
+            best_tags = [best_last_tag.item()]

             # We trace back where the best last tag comes from, append that to our best tag
             # sequence, and trace it back again, and so on
             for path in reversed(viterbi_path[:sequence_lengths[idx] - 1]):
                 best_last_tag = path[idx][best_tags[-1]]
-                best_tags.append(best_last_tag)
+                best_tags.append(best_last_tag.item())

             # Reverse the order because we start from the last timestep
             best_tags.reverse()
             best_tags_list.append(best_tags)
         return best_tags_list
-
-    @staticmethod
-    def _log_sum_exp(tensor: Variable, dim: int) -> Variable:
-        # Find the max value along `dim`
-        offset, _ = tensor.max(dim)
-        # Make offset broadcastable
-        broadcast_offset = offset.unsqueeze(dim)
-        # Perform log-sum-exp safely
-        safe_log_sum_exp = torch.log(torch.sum(torch.exp(tensor - broadcast_offset), dim))
-        # Add offset back
-        return offset + safe_log_sum_exp
-
-    def _new(self, *args, **kwargs) -> Union[torch.FloatTensor, torch.cuda.FloatTensor]:
-        param = next(self.parameters())
-        return param.data.new(*args, **kwargs)
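
Since the hand-rolled ``_log_sum_exp`` helper is deleted above in favor of ``torch.logsumexp`` (available in PyTorch 0.4.1), here is a small standalone check (not part of the commit) that the two compute the same numerically stable quantity:

    import torch

    def manual_log_sum_exp(tensor: torch.Tensor, dim: int) -> torch.Tensor:
        # Same max-offset trick the removed _log_sum_exp staticmethod used
        offset, _ = tensor.max(dim)
        broadcast_offset = offset.unsqueeze(dim)
        return offset + torch.log(torch.sum(torch.exp(tensor - broadcast_offset), dim))

    score = torch.randn(2, 5, 5)  # e.g. (batch_size, num_tags, num_tags)
    assert torch.allclose(torch.logsumexp(score, 1), manual_log_sum_exp(score, 1))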
