 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-T5 Model
+T5 Model.
 """

 from typing import Optional, Union
@@ -39,9 +39,8 @@ class T5EncoderDecoder(EncoderDecoderBase, PretrainedT5Mixin):

     This module basically stacks
     :class:`~texar.torch.modules.WordEmbedder`,
-    :class:`~texar.torch.modules.TransformerEncoder`,
-    :class:`~texar.torch.modules.TransformerDecoder` and a dense
-    pooler.
+    :class:`~texar.torch.modules.T5Encoder`, and
+    :class:`~texar.torch.modules.T5Decoder`.

     Args:
         pretrained_model_name (optional): a `str`, the name
@@ -100,10 +99,10 @@ def reset_parameters(self):
     def default_hparams():
         r"""Returns a dictionary of hyperparameters with default values.

-        * The encoder arch is determined by the constructor argument
+        * The model arch is determined by the constructor argument
           :attr:`pretrained_model_name` if it's specified. In this case,
           `hparams` are ignored.
-        * Otherwise, the encoder arch is determined by
+        * Otherwise, the model arch is determined by
           `hparams['pretrained_model_name']` if it's specified. All other
           configurations in `hparams` are ignored.
         * If the above two are `None`, the encoder arch is defined by the
@@ -112,7 +111,7 @@ def default_hparams():
         .. code-block:: python

             {
-                "pretrained_model_name": "bert-base-uncased",
+                "pretrained_model_name": "T5-Small",
                 "embed": {
                     "dim": 768,
                     "name": "word_embeddings"
@@ -128,18 +127,20 @@ def default_hparams():
                         "num_heads": 12,
                         "num_units": 768,
                         "output_dim": 768,
-                        "use_bias": True
+                        "use_bias": False,
+                        "is_decoder": False,
+                        "relative_attention_num_buckets": 32,
                     },
-                    "relative_attention_num_buckets": 32,
-                    "name": "t5encoder",
+                    "eps": 1e-6,
+                    "name": "encoder",
                     "num_blocks": 12,
                     "poswise_feedforward": {
                         "layers": [
                             {
                                 "kwargs": {
                                     "in_features": 768,
                                     "out_features": 3072,
-                                    "bias": True
+                                    "bias": False
                                 },
                                 "type": "Linear"
                             },
@@ -148,7 +149,7 @@ def default_hparams():
                                 "kwargs": {
                                     "in_features": 3072,
                                     "out_features": 768,
-                                    "bias": True
+                                    "bias": False
                                 },
                                 "type": "Linear"
                             }
@@ -158,6 +159,7 @@ def default_hparams():
                 },

                 "decoder": {
+                    "eps": 1e-6,
                     "dim": 768,
                     "embedding_dropout": 0.1,
                     "multihead_attention": {
@@ -166,19 +168,19 @@ def default_hparams():
                         "num_heads": 12,
                         "num_units": 768,
                         "output_dim": 768,
-                        "use_bias": True,
+                        "use_bias": False,
+                        "is_decoder": True,
                         "relative_attention_num_buckets": 32,
                     },
-
-                    "name": "t5coder",
+                    "name": "decoder",
                     "num_blocks": 12,
                     "poswise_feedforward": {
                         "layers": [
                             {
                                 "kwargs": {
                                     "in_features": 768,
                                     "out_features": 3072,
-                                    "bias": True
+                                    "bias": False
                                 },
                                 "type": "Linear"
                             },
@@ -187,7 +189,7 @@ def default_hparams():
                                 "kwargs": {
                                     "in_features": 3072,
                                     "out_features": 768,
-                                    "bias": True
+                                    "bias": False
                                 },
                                 "type": "Linear"
                             }
@@ -202,34 +204,30 @@ def default_hparams():

         Here:

-        The default parameters are values for uncased BERT-Base model.
+        The default parameters are values for the T5-Small model.

         `"pretrained_model_name"`: str or None
-            The name of the pre-trained BERT model. If None, the model
+            The name of the pre-trained T5 model. If None, the model
             will be randomly initialized.

         `"embed"`: dict
             Hyperparameters for word embedding layer.

         `"vocab_size"`: int
-            The vocabulary size of `inputs` in BERT model.
-
-        `"type_vocab_size"`: int
-            The vocabulary size of the `segment_ids` passed into `BertModel`.
-
-        `"position_embed"`: dict
-            Hyperparameters for position embedding layer.
-
-        `"position_size"`: int
-            The maximum sequence length that this model might ever be used with.
+            The vocabulary size of `inputs` in the T5 model.

         `"encoder"`: dict
-            Hyperparameters for the T5Encoder.
+            Hyperparameters for the `T5Encoder`.
             See :func:`~texar.torch.modules.T5Encoder.default_hparams`
             for details.

+        `"decoder"`: dict
+            Hyperparameters for the `T5Decoder`.
+            See :func:`~texar.torch.modules.T5Decoder.default_hparams`
+            for details.
+
         `"hidden_size"`: int
-            Size of the pooler dense layer.
+            Size of the hidden layer.

         `"initializer"`: dict, optional
             Hyperparameters of the default initializer that initializes
@@ -301,7 +299,7 @@ def default_hparams():
                     'is_decoder': True,
                     'relative_attention_num_buckets': 32
                 },
-                'name': 'encoder',
+                'name': 'decoder',
                 'num_blocks': 12,
                 'poswise_feedforward': {
                     'layers': [
@@ -335,10 +333,10 @@ def default_hparams():
     def forward(self,  # type: ignore
                 inputs: Union[torch.Tensor, torch.LongTensor],
                 sequence_length: Optional[torch.LongTensor] = None):
-        r"""
+        r"""Performs encoding and decoding.

         Args:
-            inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`,
+            inputs: Either a **2D Tensor** of shape ``[batch_size, max_time]``,
                 containing the ids of tokens in input sequences, or
                 a **3D Tensor** of shape `[batch_size, max_time, vocab_size]`,
                 containing soft token ids (i.e., weights or probabilities)
@@ -348,6 +346,14 @@ def forward(self, # type: ignore
                 lengths are masked out automatically.

         Returns:
+            A pair :attr:`(encoder_output, decoder_output)`
+
+            - :attr:`encoder_output`: A Tensor of shape
+              `[batch_size, max_time, dim]` containing the encoded vectors.
+
+            - :attr:`decoder_output`: An instance of
+              :class:`~texar.torch.modules.TransformerDecoderOutput` which
+              contains `sample_id` and `logits`.
         """
         if inputs.dim() == 2:
             word_embeds = self.word_embedder(ids=inputs)
@@ -373,7 +379,6 @@ def forward(self, # type: ignore

     @property
     def output_size(self):
-        r"""The feature size of :meth:`forward` output
-        :attr:`pooled_output`.
+        r"""The feature size of :meth:`forward` output of the encoder.
         """
         return self._hparams.hidden_size
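
A minimal usage sketch of the module after this change, assuming `T5EncoderDecoder` is exported under `texar.torch.modules`, the `T5-Small` weights named in `default_hparams()` can be fetched, and `model.hparams.vocab_size` is readable on the instance; the returned pair follows the revised `forward` docstring above:

import torch
import texar.torch as tx

# Build the module from the pre-trained checkpoint named in `default_hparams()`;
# with `pretrained_model_name=None` the weights would be randomly initialized.
model = tx.modules.T5EncoderDecoder(pretrained_model_name="T5-Small")

# A toy batch of token ids, shape `[batch_size, max_time]`.
inputs = torch.randint(0, model.hparams.vocab_size, (2, 16))
sequence_length = torch.tensor([16, 12])

# Per the docstring above, `forward` returns (encoder_output, decoder_output).
encoder_output, decoder_output = model(inputs, sequence_length)

print(encoder_output.shape)         # [batch_size, max_time, dim]
print(decoder_output.logits.shape)  # logits over the vocabulary
print(model.output_size)            # equals hparams.hidden_size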
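A second sketch for the randomly initialized path described in `default_hparams()`; the `hparams=` constructor keyword and the reduced `num_blocks` values are illustrative assumptions, not part of this change:

import texar.torch as tx

# Start from the documented defaults, drop the pre-trained checkpoint so the
# weights are randomly initialized, and (hypothetically) shrink both stacks.
hparams = tx.modules.T5EncoderDecoder.default_hparams()
hparams["pretrained_model_name"] = None
hparams["encoder"]["num_blocks"] = 6
hparams["decoder"]["num_blocks"] = 6

model = tx.modules.T5EncoderDecoder(hparams=hparams)

Setting `pretrained_model_name` to `None` matters here: per the docstring, when a checkpoint name is given (as a constructor argument or inside `hparams`), all other configuration keys are ignored.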