11# coding: utf-8
22# 2021/8/1 @ tongshiwei
33
4+ import torch
45import json
56import os .path
67from typing import List , Tuple
@@ -59,12 +60,12 @@ class I2V(object):
5960 """
6061
6162 def __init__ (self , tokenizer , t2v , * args , tokenizer_kwargs : dict = None ,
62- pretrained_t2v = False , model_dir = MODEL_DIR , ** kwargs ):
63+ pretrained_t2v = False , model_dir = MODEL_DIR , device = 'cpu' , ** kwargs ):
6364 if pretrained_t2v :
6465 logger .info ("Use pretrained t2v model %s" % t2v )
65- self .t2v = get_t2v_pretrained_model (t2v , model_dir )
66+ self .t2v = get_t2v_pretrained_model (t2v , model_dir , device )
6667 else :
67- self .t2v = T2V (t2v , * args , ** kwargs )
68+ self .t2v = T2V (t2v , device = device , * args , ** kwargs )
6869 if tokenizer == 'bert' :
6970 self .tokenizer = BertTokenizer .from_pretrained (
7071 ** tokenizer_kwargs if tokenizer_kwargs is not None else {})
@@ -82,31 +83,53 @@ def __init__(self, tokenizer, t2v, *args, tokenizer_kwargs: dict = None,
8283 ** tokenizer_kwargs if tokenizer_kwargs is not None else {})
8384 self .params = {
8485 "tokenizer" : tokenizer ,
85- "tokenizer_kwargs" : tokenizer_kwargs ,
8686 "t2v" : t2v ,
8787 "args" : args ,
88+ "tokenizer_kwargs" : tokenizer_kwargs ,
89+ "pretrained_t2v" : pretrained_t2v ,
90+ "model_dir" : model_dir ,
8891 "kwargs" : kwargs ,
89- "pretrained_t2v" : pretrained_t2v
9092 }
93+ self .device = torch .device (device )
9194
9295 def __call__ (self , items , * args , ** kwargs ):
9396 """transfer item to vector"""
9497 return self .infer_vector (items , * args , ** kwargs )
9598
9699 def tokenize (self , items , * args , key = lambda x : x , ** kwargs ) -> list :
97- # """tokenize item"""
100+ """
101+ tokenize item
102+ Parameter
103+ ----------
104+ items: a list of questions
105+ Return
106+ ----------
107+ tokens: list
108+ """
98109 return self .tokenizer (items , * args , key = key , ** kwargs )
99110
100111 def infer_vector (self , items , key = lambda x : x , ** kwargs ) -> tuple :
112+ """
113+ get question embedding
114+ NotImplemented
115+ """
101116 raise NotImplementedError
102117
103118 def infer_item_vector (self , tokens , * args , ** kwargs ) -> ...:
119+ """NotImplemented"""
104120 return self .infer_vector (tokens , * args , ** kwargs )[0 ]
105121
106122 def infer_token_vector (self , tokens , * args , ** kwargs ) -> ...:
123+ """NotImplemented"""
107124 return self .infer_vector (tokens , * args , ** kwargs )[1 ]
108125
109126 def save (self , config_path ):
127+ """
128+ save model weights in config_path
129+ Parameter:
130+ ----------
131+ config_path: str
132+ """
110133 with open (config_path , "w" , encoding = "utf-8" ) as wf :
111134 json .dump (self .params , wf , ensure_ascii = False , indent = 2 )
112135
@@ -123,6 +146,7 @@ def load(cls, config_path, *args, **kwargs):
123146
124147 @classmethod
125148 def from_pretrained (cls , name , model_dir = MODEL_DIR , * args , ** kwargs ):
149+ """NotImplemented"""
126150 raise NotImplementedError
127151
128152 @property
@@ -327,13 +351,13 @@ def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
327351 return self .t2v .infer_vector (inputs , * args , ** kwargs ), self .t2v .infer_tokens (inputs , * args , ** kwargs )
328352
329353 @classmethod
330- def from_pretrained (cls , name , model_dir = MODEL_DIR , * args , ** kwargs ):
354+ def from_pretrained (cls , name , model_dir = MODEL_DIR , device = 'cpu' , * args , ** kwargs ):
331355 model_path = path_append (model_dir , get_pretrained_model_info (name )[0 ].split ('/' )[- 1 ], to_str = True )
332356 for i in [".tar.gz" , ".tar.bz2" , ".tar.bz" , ".tar.tgz" , ".tar" , ".tgz" , ".zip" , ".rar" ]:
333357 model_path = model_path .replace (i , "" )
334358 logger .info ("model_path: %s" % model_path )
335359 tokenizer_kwargs = {"tokenizer_config_dir" : model_path }
336- return cls ("elmo" , name , pretrained_t2v = True , model_dir = model_dir ,
360+ return cls ("elmo" , name , pretrained_t2v = True , model_dir = model_dir , device = device ,
337361 tokenizer_kwargs = tokenizer_kwargs )
338362
339363
@@ -386,17 +410,19 @@ def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
386410 --------
387411 vector:list
388412 """
413+ is_batch = isinstance (items , list )
414+ items = items if is_batch else [items ]
389415 inputs = self .tokenize (items , key = key , return_tensors = return_tensors )
390416 return self .t2v .infer_vector (inputs , * args , ** kwargs ), self .t2v .infer_tokens (inputs , * args , ** kwargs )
391417
392418 @classmethod
393- def from_pretrained (cls , name , model_dir = MODEL_DIR , * args , ** kwargs ):
419+ def from_pretrained (cls , name , model_dir = MODEL_DIR , device = 'cpu' , * args , ** kwargs ):
394420 model_path = path_append (model_dir , get_pretrained_model_info (name )[0 ].split ('/' )[- 1 ], to_str = True )
395421 for i in [".tar.gz" , ".tar.bz2" , ".tar.bz" , ".tar.tgz" , ".tar" , ".tgz" , ".zip" , ".rar" ]:
396422 model_path = model_path .replace (i , "" )
397423 logger .info ("model_path: %s" % model_path )
398424 tokenizer_kwargs = {"tokenizer_config_dir" : model_path }
399- return cls ("bert" , name , pretrained_t2v = True , model_dir = model_dir ,
425+ return cls ("bert" , name , pretrained_t2v = True , model_dir = model_dir , device = device ,
400426 tokenizer_kwargs = tokenizer_kwargs )
401427
402428
@@ -452,7 +478,7 @@ def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
452478 return i_vec , t_vec
453479
454480 @classmethod
455- def from_pretrained (cls , name , model_dir = MODEL_DIR , ** kwargs ):
481+ def from_pretrained (cls , name , model_dir = MODEL_DIR , device = 'cpu' , ** kwargs ):
456482 model_path = path_append (model_dir , get_pretrained_model_info (name )[0 ].split ('/' )[- 1 ], to_str = True )
457483 for i in [".tar.gz" , ".tar.bz2" , ".tar.bz" , ".tar.tgz" , ".tar" , ".tgz" , ".zip" , ".rar" ]:
458484 model_path = model_path .replace (i , "" )
@@ -461,7 +487,7 @@ def from_pretrained(cls, name, model_dir=MODEL_DIR, **kwargs):
461487 tokenizer_kwargs = {
462488 "tokenizer_config_dir" : model_path ,
463489 }
464- return cls ("disenq" , name , pretrained_t2v = True , model_dir = model_dir ,
490+ return cls ("disenq" , name , pretrained_t2v = True , model_dir = model_dir , device = device ,
465491 tokenizer_kwargs = tokenizer_kwargs , ** kwargs )
466492
467493
@@ -495,18 +521,20 @@ def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
495521 token embeddings
496522 question embedding
497523 """
524+ is_batch = isinstance (items , list )
525+ items = items if is_batch else [items ]
498526 encodes = self .tokenize (items , key = key , meta = meta , * args , ** kwargs )
499527 return self .t2v .infer_vector (encodes ), self .t2v .infer_tokens (encodes )
500528
501529 @classmethod
502- def from_pretrained (cls , name , model_dir = MODEL_DIR , * args , ** kwargs ):
530+ def from_pretrained (cls , name , model_dir = MODEL_DIR , device = 'cpu' , * args , ** kwargs ):
503531 model_path = path_append (model_dir , get_pretrained_model_info (name )[0 ].split ('/' )[- 1 ], to_str = True )
504532 for i in [".tar.gz" , ".tar.bz2" , ".tar.bz" , ".tar.tgz" , ".tar" , ".tgz" , ".zip" , ".rar" ]:
505533 model_path = model_path .replace (i , "" )
506534 logger .info ("model_path: %s" % model_path )
507535 tokenizer_kwargs = {
508536 "tokenizer_config_dir" : model_path }
509- return cls ("quesnet" , name , pretrained_t2v = True , model_dir = model_dir ,
537+ return cls ("quesnet" , name , pretrained_t2v = True , model_dir = model_dir , device = device ,
510538 tokenizer_kwargs = tokenizer_kwargs )
511539
512540
@@ -520,7 +548,7 @@ def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
520548}
521549
522550
523- def get_pretrained_i2v (name , model_dir = MODEL_DIR ):
551+ def get_pretrained_i2v (name , model_dir = MODEL_DIR , device = 'cpu' ):
524552 """
525553 It is a good idea if you want to switch item to vector earily.
526554
@@ -560,4 +588,4 @@ def get_pretrained_i2v(name, model_dir=MODEL_DIR):
560588 )
561589 _ , t2v = get_pretrained_model_info (name )
562590 _class , * params = MODEL_MAP [t2v ], name
563- return _class .from_pretrained (* params , model_dir = model_dir )
591+ return _class .from_pretrained (* params , model_dir = model_dir , device = device )
0 commit comments