-from typing import List, Optional, Union
+from typing import List, Optional
 
-from torch.autograd import Variable
 import torch
 import torch.nn as nn
 
@@ -20,8 +19,6 @@ class CRF(nn.Module):
 
     Attributes
     ----------
-    num_tags : int
-        Number of tags passed to ``__init__``.
     start_transitions : :class:`~torch.nn.Parameter`
         Start transition score tensor of size ``(num_tags,)``.
     end_transitions : :class:`~torch.nn.Parameter`
@@ -43,9 +40,9 @@ def __init__(self, num_tags: int) -> None:
             raise ValueError(f'invalid number of tags: {num_tags}')
         super().__init__()
         self.num_tags = num_tags
-        self.start_transitions = nn.Parameter(torch.Tensor(num_tags))
-        self.end_transitions = nn.Parameter(torch.Tensor(num_tags))
-        self.transitions = nn.Parameter(torch.Tensor(num_tags, num_tags))
+        self.start_transitions = nn.Parameter(torch.empty(num_tags))
+        self.end_transitions = nn.Parameter(torch.empty(num_tags))
+        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
 
         self.reset_parameters()
 
@@ -55,35 +52,35 @@ def reset_parameters(self) -> None:
         The parameters will be initialized randomly from a uniform distribution
         between -0.1 and 0.1.
         """
-        nn.init.uniform(self.start_transitions, -0.1, 0.1)
-        nn.init.uniform(self.end_transitions, -0.1, 0.1)
-        nn.init.uniform(self.transitions, -0.1, 0.1)
+        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
+        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
+        nn.init.uniform_(self.transitions, -0.1, 0.1)
 
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}(num_tags={self.num_tags})'
 
     def forward(self,
-                emissions: Variable,
-                tags: Variable,
-                mask: Optional[Variable] = None,
+                emissions: torch.Tensor,
+                tags: torch.LongTensor,
+                mask: Optional[torch.ByteTensor] = None,
                 reduce: bool = True,
-                ) -> Variable:
+                ) -> torch.Tensor:
         """Compute the log likelihood of the given sequence of tags and emission score.
 
         Arguments
         ---------
-        emissions : :class:`~torch.autograd.Variable`
+        emissions : :class:`~torch.Tensor`
             Emission score tensor of size ``(seq_length, batch_size, num_tags)``.
-        tags : :class:`~torch.autograd.Variable`
-            Sequence of tags as ``LongTensor`` of size ``(seq_length, batch_size)``.
-        mask : :class:`~torch.autograd.Variable`, optional
-            Mask tensor as ``ByteTensor`` of size ``(seq_length, batch_size)``.
+        tags : :class:`~torch.LongTensor`
+            Sequence of tags of size ``(seq_length, batch_size)``.
+        mask : :class:`~torch.ByteTensor`, optional
+            Mask tensor of size ``(seq_length, batch_size)``.
         reduce : bool
             Whether to sum the log likelihood over the batch.
 
         Returns
         -------
-        :class:`~torch.autograd.Variable`
+        :class:`~torch.Tensor`
             The log likelihood. This will have size (1,) if ``reduce=True``, ``(batch_size,)``
             otherwise.
         """
@@ -107,32 +104,32 @@ def forward(self,
                     f'size of tags and mask must match, got {tuple(tags.size())} '
                     f'and {tuple(mask.size())}'
                 )
-            if not all(mask[0].data):
+            if not all(mask[0]):
                 raise ValueError('mask of the first timestep must all be on')
 
         if mask is None:
-            mask = Variable(self._new(tags.size()).fill_(1)).byte()
+            mask = torch.ones_like(tags, dtype=torch.uint8)
 
         numerator = self._compute_joint_llh(emissions, tags, mask)
         denominator = self._compute_log_partition_function(emissions, mask)
         llh = numerator - denominator
         return llh if not reduce else torch.sum(llh)
 
     def decode(self,
-               emissions: Union[Variable, torch.FloatTensor],
-               mask: Optional[Union[Variable, torch.ByteTensor]] = None) -> List[List[int]]:
+               emissions: torch.Tensor,
+               mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
         """Find the most likely tag sequence using Viterbi algorithm.
 
         Arguments
         ---------
-        emissions : :class:`~torch.autograd.Variable` or :class:`~torch.FloatTensor`
+        emissions : :class:`~torch.Tensor`
             Emission score tensor of size ``(seq_length, batch_size, num_tags)``.
-        mask : :class:`~torch.autograd.Variable` or :class:`torch.ByteTensor`
+        mask : :class:`~torch.ByteTensor`
             Mask tensor of size ``(seq_length, batch_size)``.
 
         Returns
         -------
-        list
+        List[List[int]]
             List of list containing the best tag sequence for each batch.
         """
         if emissions.dim() != 3:
@@ -148,27 +145,23 @@ def decode(self,
                     f'got {tuple(emissions.size()[:2])} and {tuple(mask.size())}'
                 )
 
-        if isinstance(emissions, Variable):
-            emissions = emissions.data
         if mask is None:
-            mask = self._new(emissions.size()[:2]).fill_(1).byte()
-        elif isinstance(mask, Variable):
-            mask = mask.data
+            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)
 
         return self._viterbi_decode(emissions, mask)
 
     def _compute_joint_llh(self,
-                           emissions: Variable,
-                           tags: Variable,
-                           mask: Variable) -> Variable:
+                           emissions: torch.Tensor,
+                           tags: torch.LongTensor,
+                           mask: torch.ByteTensor) -> torch.Tensor:
         # emissions: (seq_length, batch_size, num_tags)
         # tags: (seq_length, batch_size)
         # mask: (seq_length, batch_size)
         assert emissions.dim() == 3 and tags.dim() == 2
         assert emissions.size()[:2] == tags.size()
         assert emissions.size(2) == self.num_tags
         assert mask.size() == tags.size()
-        assert all(mask[0].data)
+        assert all(mask[0])
 
         seq_length = emissions.size(0)
         mask = mask.float()
@@ -197,14 +190,14 @@ def _compute_joint_llh(self,
         return llh
 
     def _compute_log_partition_function(self,
-                                        emissions: Variable,
-                                        mask: Variable) -> Variable:
+                                        emissions: torch.Tensor,
+                                        mask: torch.ByteTensor) -> torch.Tensor:
         # emissions: (seq_length, batch_size, num_tags)
         # mask: (seq_length, batch_size)
         assert emissions.dim() == 3 and mask.dim() == 2
         assert emissions.size()[:2] == mask.size()
         assert emissions.size(2) == self.num_tags
-        assert all(mask[0].data)
+        assert all(mask[0])
 
         seq_length = emissions.size(0)
         mask = mask.float()
@@ -226,15 +219,15 @@ def _compute_log_partition_function(self,
                 + broadcast_emissions  # (batch_size, num_tags, num_tags)
             # Sum over all possible current tags, but we're in log prob space, so a sum
             # becomes a log-sum-exp
-            score = self._log_sum_exp(score, 1)  # (batch_size, num_tags)
+            score = torch.logsumexp(score, 1)  # (batch_size, num_tags)
             # Set log_prob to the score if this timestep is valid (mask == 1), otherwise
             # leave it alone
             log_prob = score * mask[i].unsqueeze(1) + log_prob * (1. - mask[i]).unsqueeze(1)
 
         # End transition score
         log_prob += self.end_transitions.view(1, -1)
         # Sum (log-sum-exp) over all possible tags
-        return self._log_sum_exp(log_prob, 1)  # (batch_size,)
+        return torch.logsumexp(log_prob, 1)  # (batch_size,)
 
     def _viterbi_decode(self, emissions: torch.FloatTensor, mask: torch.ByteTensor) \
             -> List[List[int]]:
@@ -251,7 +244,7 @@ def _viterbi_decode(self, emissions: torch.FloatTensor, mask: torch.ByteTensor)
 
         # Start transition
         viterbi_score = []
-        viterbi_score.append(self.start_transitions.data + emissions[0])
+        viterbi_score.append(self.start_transitions + emissions[0])
        viterbi_path = []
 
         # Here, viterbi_score is a list of tensors of shapes of (num_tags,) where value at
@@ -269,7 +262,7 @@ def _viterbi_decode(self, emissions: torch.FloatTensor, mask: torch.ByteTensor)
             # Compute the score matrix of shape (batch_size, num_tags, num_tags) where
             # for each sample, each entry at row i and column j stores the score of
             # transitioning from tag i to tag j and emitting
-            score = broadcast_score + self.transitions.data + broadcast_emission
+            score = broadcast_score + self.transitions + broadcast_emission
             # Find the maximum score over all possible current tag
             best_score, best_path = score.max(1)  # (batch_size, num_tags)
             # Save the score and the path
@@ -280,32 +273,17 @@ def _viterbi_decode(self, emissions: torch.FloatTensor, mask: torch.ByteTensor)
         for idx in range(batch_size):
             # Find the tag which maximizes the score at the last timestep; this is our best tag
             # for the last timestep
-            seq_end = sequence_lengths[idx]-1
-            _, best_last_tag = (viterbi_score[seq_end][idx] + self.end_transitions.data).max(0)
-            best_tags = [best_last_tag[0]]
+            seq_end = sequence_lengths[idx] - 1
+            _, best_last_tag = (viterbi_score[seq_end][idx] + self.end_transitions).max(0)
+            best_tags = [best_last_tag.item()]
 
             # We trace back where the best last tag comes from, append that to our best tag
             # sequence, and trace it back again, and so on
             for path in reversed(viterbi_path[:sequence_lengths[idx] - 1]):
                 best_last_tag = path[idx][best_tags[-1]]
-                best_tags.append(best_last_tag)
+                best_tags.append(best_last_tag.item())
 
             # Reverse the order because we start from the last timestep
             best_tags.reverse()
             best_tags_list.append(best_tags)
         return best_tags_list
-
-    @staticmethod
-    def _log_sum_exp(tensor: Variable, dim: int) -> Variable:
-        # Find the max value along `dim`
-        offset, _ = tensor.max(dim)
-        # Make offset broadcastable
-        broadcast_offset = offset.unsqueeze(dim)
-        # Perform log-sum-exp safely
-        safe_log_sum_exp = torch.log(torch.sum(torch.exp(tensor - broadcast_offset), dim))
-        # Add offset back
-        return offset + safe_log_sum_exp
-
-    def _new(self, *args, **kwargs) -> Union[torch.FloatTensor, torch.cuda.FloatTensor]:
-        param = next(self.parameters())
-        return param.data.new(*args, **kwargs)
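
The two `self._log_sum_exp(...)` call sites are switched to `torch.logsumexp`, and the hand-rolled `@staticmethod` helper is deleted. A quick sketch (not part of the commit) of why that swap is safe: it re-creates the removed helper from the deleted lines above and checks numerical agreement with the built-in.

```python
import torch


def _log_sum_exp(tensor, dim):
    # Removed helper: shift by the max along `dim` for numerical stability.
    offset, _ = tensor.max(dim)
    broadcast_offset = offset.unsqueeze(dim)
    safe_log_sum_exp = torch.log(torch.sum(torch.exp(tensor - broadcast_offset), dim))
    return offset + safe_log_sum_exp


score = torch.randn(4, 6)
# The built-in matches the helper up to floating-point error.
assert torch.allclose(_log_sum_exp(score, 1), torch.logsumexp(score, 1))
```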
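For context, a minimal usage sketch of the module after this change, driven by plain tensors instead of `Variable`s. The import path, shapes, and tag values are illustrative assumptions, not something this commit defines.

```python
import torch
from torchcrf import CRF  # assumed import path; adjust to your package layout

seq_length, batch_size, num_tags = 3, 2, 5
model = CRF(num_tags)

emissions = torch.randn(seq_length, batch_size, num_tags)
tags = torch.randint(num_tags, (seq_length, batch_size), dtype=torch.long)
mask = torch.ones(seq_length, batch_size, dtype=torch.uint8)  # first timestep must be all on

# Training objective: negative log likelihood of the gold tag sequence.
loss = -model(emissions, tags, mask=mask)
loss.backward()

# Prediction: Viterbi decoding now works on tensors directly, no .data unwrapping.
with torch.no_grad():
    best_tags = model.decode(emissions, mask=mask)
print(best_tags)  # a list of tag-index lists, one per batch element
```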