# -*- coding: utf-8 -*-
import torch
from modelcache.embedding.base import BaseEmbedding
from modelscope.utils.constant import Tasks
from modelscope.pipelines import pipeline
from modelscope.preprocessors.image import load_image


class ClipAudio(BaseEmbedding):
    """Text/image embedding backed by the ModelScope Chinese CLIP pipeline."""

    def __init__(self, model: str = "sentence-transformers/all-MiniLM-L6-v2"):
        # NOTE: the ``model`` argument is currently unused; the pipeline is pinned
        # to the Chinese CLIP checkpoint and revision below.
        self.clip_pipeline = pipeline(
            task=Tasks.multi_modal_embedding,
            model='damo/multi-modal_clip-vit-base-patch16_zh',
            model_revision='v1.0.1',
        )

        self.__dimension = 1024
    def to_embeddings(self, data_dict, **_):
        """Embed the text list and the image in ``data_dict`` with the CLIP pipeline."""
        text_list = data_dict['text']
        image_data = data_dict['image']

        if not image_data:
            raise ValueError('image_data is None, please check!')
        input_img = load_image(image_data)
        # 2D tensor, [num_images, feature_dim]; keep the first row
        img_embedding = self.clip_pipeline.forward({'img': input_img})['img_embedding'].tolist()[0]

        if not text_list:
            raise ValueError('text_list is None, please check!')
        # 2D tensor, [num_texts, feature_dim]; keep the first row
        text_embedding = self.clip_pipeline.forward({'text': text_list})['text_embedding'].tolist()[0]

        return {'image_embedding': img_embedding, 'text_embeddings': text_embedding}

    def post_proc(self, token_embeddings, inputs):
        # Attention-mask-weighted mean pooling over token embeddings.
        attention_mask = inputs["attention_mask"]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        sentence_embs = torch.sum(
            token_embeddings * input_mask_expanded, 1
        ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sentence_embs
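    # A minimal sketch (illustrative only; this class itself does not call
    # ``post_proc``, it relies on the CLIP pipeline above): ``post_proc`` expects
    # Hugging Face style model outputs, e.g. with the default ``model`` name above.
    #
    #   from transformers import AutoTokenizer, AutoModel
    #   tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    #   model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    #   inputs = tokenizer(["hello world"], padding=True, return_tensors="pt")
    #   token_embeddings = model(**inputs)[0]                       # [batch, seq_len, hidden]
    #   sentence_embs = clip_vec.post_proc(token_embeddings, inputs)  # [batch, hidden]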

    @property
    def dimension(self):
        """Embedding dimension.

        :return: embedding dimension
        """
        return self.__dimension


if __name__ == '__main__':
    clip_vec = ClipAudio()
    text_list = ['hello', '你好']
    text = ['###'.join(text_list)]
    image = 'https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg'
    data_dict = {'text': text, 'image': image}
    resp = clip_vec.to_embeddings(data_dict)
    print('resp: {}'.format(resp))
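    # Illustrative follow-up (not part of the original demo): cosine similarity
    # between the returned image and text embeddings. It is assumed here, not
    # guaranteed, that the two vectors are directly comparable; cosine_similarity
    # normalizes them explicitly.
    img_vec = torch.tensor(resp['image_embedding']).unsqueeze(0)
    txt_vec = torch.tensor(resp['text_embeddings']).unsqueeze(0)
    sim = torch.nn.functional.cosine_similarity(img_vec, txt_vec).item()
    print('image/text cosine similarity: {}'.format(sim))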