Skip to content

Commit 17a6230

Browse files
authored
Merge pull request #31 from Clarifai/multimodal_ingest
[DEVX-828] Added image summarization in multimodal pipeline
2 parents 7d0d53f + e3a9f13 commit 17a6230

File tree

5 files changed

+567
-3
lines changed

5 files changed

+567
-3
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ coverage.xml
5151
.hypothesis/
5252
.pytest_cache/
5353

54+
testing/*
55+
!testing/test.ipynb
56+
5457
# Translations
5558
*.mo
5659
*.pot

clarifai_datautils/multimodal/pipeline/loaders.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __getitem__(self, index: int):
2727
meta.pop('coordinates', None)
2828
meta.pop('detection_class_prob', None)
2929
image_data = meta.pop('image_base64', None)
30+
id = meta.get('input_id', None)
3031
if image_data is not None:
3132
# Ensure image_data is already bytes before encoding
3233
image_data = base64.b64decode(image_data)
@@ -39,7 +40,7 @@ def __getitem__(self, index: int):
3940
meta['type'] = 'table'
4041

4142
return MultiModalFeatures(
42-
text=text, image_bytes=image_data, labels=[self.pipeline_name], metadata=meta)
43+
text=text, image_bytes=image_data, labels=[self.pipeline_name], metadata=meta, id=id)
4344

4445
def __len__(self):
    """Return how many elements this loader holds."""
    element_count = len(self.elements)
    return element_count
@@ -61,10 +62,13 @@ def task(self):
6162
return DATASET_UPLOAD_TASKS.TEXT_CLASSIFICATION #TODO: Better dataset name in SDK
6263

6364
def __getitem__(self, index: int):
    """Build the TextFeatures record for the element at ``index``.

    The record id is the element's ``element_id`` (as reported by
    ``to_dict()``) cut to at most 48 characters, or None when the element
    exposes no ``element_id``.
    """
    element = self.elements[index]
    element_id = element.to_dict().get('element_id', None)
    if element_id is not None:
        # NOTE(review): 48 is presumably the platform's input-id length limit - confirm.
        element_id = element_id[:48]
    return TextFeatures(
        text=element.text,
        labels=self.pipeline_name,
        metadata=element.metadata.to_dict(),
        id=element_id)
6872

6973
def __len__(self):
    """Return how many elements this loader holds."""
    element_count = len(self.elements)
    return element_count
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import base64
2+
import random
3+
from typing import List
4+
5+
try:
6+
from unstructured.documents.elements import CompositeElement, ElementMetadata, Image
7+
except ImportError:
8+
raise ImportError(
9+
"Could not import unstructured package. "
10+
"Please install it with `pip install 'unstructured[pdf] @ git+https://github.com/clarifai/unstructured.git@support_clarifai_model'`."
11+
)
12+
13+
from clarifai.client.input import Inputs
14+
from clarifai.client.model import Model
15+
16+
from .basetransform import BaseTransform
17+
18+
SUMMARY_PROMPT = """You are an assistant tasked with summarizing images for retrieval. \
19+
These summaries will be embedded and used to retrieve the raw image. \
20+
Give a concise summary of the image that is well optimized for retrieval."""
21+
22+
23+
class ImageSummarizer(BaseTransform):
    """Adds a model-generated text summary element for every image element.

    Each ``Image`` element is tagged with a unique ``input_id``; the image
    bytes and the summarization prompt are sent to the configured Clarifai
    model, and one ``CompositeElement`` per returned output is appended to the
    element list, pointing back at its image through ``source_input_id``.
    """

    def __init__(self,
                 model_url: str = "https://clarifai.com/qwen/qwen-VL/models/qwen-VL-Chat",
                 pat: str = None,
                 prompt: str = SUMMARY_PROMPT):
        """Initializes an ImageSummarizer object.

        Args:
          model_url (str): Model URL to use for summarization.
          pat (str): Clarifai PAT.
          prompt (str): Prompt to use for summarization.
        """
        self.pat = pat
        self.model_url = model_url
        self.model = Model(url=model_url, pat=pat)
        self.summary_prompt = prompt

    def __call__(self, elements: List) -> List:
        """Applies the transformation.

        Every incoming element is marked ``is_original=True``. Image elements
        additionally receive a fresh ``input_id`` and are summarized; the
        summary elements are appended to ``elements``.

        Args:
          elements (List): List of all elements. The list and the metadata of
            its elements are modified in place.

        Returns:
          The same list, extended with the summary elements.
        """
        # Local import: the rest of the file is not visible from this block.
        import uuid

        img_elements = []
        for element in elements:
            element.metadata.update(ElementMetadata.from_dict({'is_original': True}))
            if isinstance(element, Image):
                # uuid4 instead of an 8-digit random number: the id links a
                # summary back to its image, so collisions must not happen.
                element.metadata.update(
                    ElementMetadata.from_dict({
                        'input_id': uuid.uuid4().hex
                    }))
                img_elements.append(element)
        elements.extend(self._summarize_image(img_elements))
        return elements

    def _summarize_image(self, image_elements: List[Image]) -> List[CompositeElement]:
        """Summarizes image elements with the configured model.

        Args:
          image_elements (List[Image]): Image elements to summarize. Entries
            that are not ``Image`` instances, or that carry no
            ``image_base64`` payload, are skipped.

        Returns:
          One CompositeElement per model output. Its text is the image's own
          text (if any) followed by the model summary, its ``source_input_id``
          is the image's ``input_id`` and ``is_original`` is False. Empty when
          nothing could be sent to the model.
        """
        # `sources` is kept in lockstep with `img_inputs` so every model output
        # is paired with the element that produced it, even when entries are
        # skipped.
        sources = []
        img_inputs = []
        for element in image_elements:
            if not isinstance(element, Image):
                continue
            encoded_image = getattr(element.metadata, 'image_base64', None)
            if not encoded_image:
                # No image payload was extracted for this element.
                continue
            input_proto = Inputs.get_multimodal_input(
                input_id="summarize_" + element.metadata.input_id,
                image_bytes=base64.b64decode(encoded_image),
                raw_text=self.summary_prompt)
            sources.append(element)
            img_inputs.append(input_proto)

        if not img_inputs:
            # Nothing to summarize: do not issue an empty predict request.
            return []

        # NOTE(review): all images go out in a single predict call - confirm
        # the model's batch-size limit for documents with many images.
        resp = self.model.predict(img_inputs)
        del img_inputs

        new_elements = []
        for source, output in zip(sources, resp.outputs):
            summary = ""
            if source.text:
                summary = source.text
            summary = summary + " \n " + output.data.text.raw
            eid = source.metadata.input_id
            meta_dict = {'source_input_id': eid, 'is_original': False}
            comp_element = CompositeElement(
                text=summary,
                metadata=ElementMetadata.from_dict(meta_dict),
                element_id="summarized_" + eid)
            new_elements.append(comp_element)

        return new_elements

testing/test.ipynb

Lines changed: 425 additions & 0 deletions
Large diffs are not rendered by default.

tests/pipelines/test_multimodal_pipelines.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os.path as osp
2-
32
import pytest
43

54
PDF_FILE_PATH = osp.abspath(
@@ -66,3 +65,34 @@ def test_pipeline_run_loader(self,):
6665
assert elements.__class__.__name__ == 'MultiModalLoader'
6766
assert len(elements) == 14
6867
assert elements.elements[0].metadata.to_dict()['filename'] == 'Multimodal_sample_file.pdf'
68+
69+
def test_pipeline_summarize(self,):
    """Runs the PDF pipeline with ImageSummarizer and checks the added summary elements."""
    import os

    from clarifai_datautils.multimodal import Pipeline
    from clarifai_datautils.multimodal.pipeline.cleaners import Clean_extra_whitespace
    from clarifai_datautils.multimodal.pipeline.PDF import PDFPartitionMultimodal
    from clarifai_datautils.multimodal.pipeline.summarizer import ImageSummarizer

    transformations = [
        PDFPartitionMultimodal(chunking_strategy="by_title", max_characters=1024),
        Clean_extra_whitespace(),
        ImageSummarizer(pat=os.environ.get("CLARIFAI_PAT")),
    ]
    pipeline = Pipeline(name='pipeline-1', transformations=transformations)
    elements = pipeline.run(files=PDF_FILE_PATH, loader=False)

    assert len(elements) == 17
    assert isinstance(elements, list)

    first_meta = elements[0].metadata
    assert first_meta.to_dict()['filename'] == 'Multimodal_sample_file.pdf'
    assert first_meta.to_dict()['page_number'] == 1
    assert elements[6].__class__.__name__ == 'Table'

    # The third-from-last element is an original image carrying an input id.
    image_element = elements[-3]
    assert image_element.__class__.__name__ == 'Image'
    assert image_element.metadata.is_original is True
    assert image_element.metadata.input_id is not None
    source_id = image_element.metadata.input_id

    # The last element is the generated summary that points back at that image.
    summary_element = elements[-1]
    assert summary_element.__class__.__name__ == 'CompositeElement'
    assert summary_element.metadata.is_original is False
    assert summary_element.metadata.source_input_id == source_id

0 commit comments

Comments
 (0)