|
| 1 | +import base64 |
| 2 | +import random |
| 3 | +from typing import List |
| 4 | + |
| 5 | +try: |
| 6 | + from unstructured.documents.elements import CompositeElement, ElementMetadata, Image |
| 7 | +except ImportError: |
| 8 | + raise ImportError( |
| 9 | + "Could not import unstructured package. " |
| 10 | + "Please install it with `pip install 'unstructured[pdf] @ git+https://github.com/clarifai/unstructured.git@support_clarifai_model'`." |
| 11 | + ) |
| 12 | + |
| 13 | +from clarifai.client.input import Inputs |
| 14 | +from clarifai.client.model import Model |
| 15 | + |
| 16 | +from .basetransform import BaseTransform |
| 17 | + |
| 18 | +SUMMARY_PROMPT = """You are an assistant tasked with summarizing images for retrieval. \ |
| 19 | + These summaries will be embedded and used to retrieve the raw image. \ |
| 20 | + Give a concise summary of the image that is well optimized for retrieval.""" |
| 21 | + |
| 22 | + |
| 23 | +class ImageSummarizer(BaseTransform): |
| 24 | + """ Summarizes image elements. """ |
| 25 | + |
| 26 | + def __init__(self, |
| 27 | + model_url: str = "https://clarifai.com/qwen/qwen-VL/models/qwen-VL-Chat", |
| 28 | + pat: str = None, |
| 29 | + prompt: str = SUMMARY_PROMPT): |
| 30 | + """Initializes an ImageSummarizer object. |
| 31 | +
|
| 32 | + Args: |
| 33 | + pat (str): Clarifai PAT. |
| 34 | + model_url (str): Model URL to use for summarization. |
| 35 | + prompt (str): Prompt to use for summarization. |
| 36 | + """ |
| 37 | + self.pat = pat |
| 38 | + self.model_url = model_url |
| 39 | + self.model = Model(url=model_url, pat=pat) |
| 40 | + self.summary_prompt = prompt |
| 41 | + |
| 42 | + def __call__(self, elements: List) -> List: |
| 43 | + """Applies the transformation. |
| 44 | +
|
| 45 | + Args: |
| 46 | + elements (List[str]): List of all elements. |
| 47 | +
|
| 48 | + Returns: |
| 49 | + List of transformed elements along with added summarized elements. |
| 50 | +
|
| 51 | + """ |
| 52 | + img_elements = [] |
| 53 | + for _, element in enumerate(elements): |
| 54 | + element.metadata.update(ElementMetadata.from_dict({'is_original': True})) |
| 55 | + if isinstance(element, Image): |
| 56 | + element.metadata.update( |
| 57 | + ElementMetadata.from_dict({ |
| 58 | + 'input_id': f'{random.randint(1000000, 99999999)}' |
| 59 | + })) |
| 60 | + img_elements.append(element) |
| 61 | + new_elements = self._summarize_image(img_elements) |
| 62 | + elements.extend(new_elements) |
| 63 | + return elements |
| 64 | + |
| 65 | + def _summarize_image(self, image_elements: List[Image]) -> List[CompositeElement]: |
| 66 | + """Summarizes an image element. |
| 67 | +
|
| 68 | + Args: |
| 69 | + image_elements (List[Image]): Image elements to summarize. |
| 70 | +
|
| 71 | + Returns: |
| 72 | + Summarized image elements list. |
| 73 | +
|
| 74 | + """ |
| 75 | + img_inputs = [] |
| 76 | + for element in image_elements: |
| 77 | + if not isinstance(element, Image): |
| 78 | + continue |
| 79 | + new_input_id = "summarize_" + element.metadata.input_id |
| 80 | + input_proto = Inputs.get_multimodal_input( |
| 81 | + input_id=new_input_id, |
| 82 | + image_bytes=base64.b64decode(element.metadata.image_base64), |
| 83 | + raw_text=self.summary_prompt) |
| 84 | + img_inputs.append(input_proto) |
| 85 | + resp = self.model.predict(img_inputs) |
| 86 | + del img_inputs |
| 87 | + |
| 88 | + new_elements = [] |
| 89 | + for i, output in enumerate(resp.outputs): |
| 90 | + summary = "" |
| 91 | + if image_elements[i].text: |
| 92 | + summary = image_elements[i].text |
| 93 | + summary = summary + " \n " + output.data.text.raw |
| 94 | + eid = image_elements[i].metadata.input_id |
| 95 | + meta_dict = {'source_input_id': eid, 'is_original': False} |
| 96 | + comp_element = CompositeElement( |
| 97 | + text=summary, |
| 98 | + metadata=ElementMetadata.from_dict(meta_dict), |
| 99 | + element_id="summarized_" + eid) |
| 100 | + new_elements.append(comp_element) |
| 101 | + |
| 102 | + return new_elements |
0 commit comments