Skip to content

Commit 0512fdb

Browse files
authored
Fix: Support 'jpg' format in keras.utils.save_img() (#21683)
* Fix save_img(): support 'jpg' format and handle RGBA * Fix save_img: support 'jpg' format (normalize to JPEG) and add tests * Fix save_img: normalize jpg→jpeg, handle RGBA→RGB, and improve tests * Regenerate API directory * Add save_img to image_utils.py and integration tests for JPG/RGBA handling * style: fix formatting with format.sh * Simplify save_img: remove _format, normalize jpg→jpeg, add RGBA→RGB handling and tests * Simplify save_img: remove _format, normalize jpg→jpeg, add RGBA→RGB handling and tests * Simplify save_img: remove _format, normalize jpg→jpeg, add RGBA→RGB handling and tests * fix: use save_format variable to avoid modifying file_format parameter * fix: handle .jpg format without renaming file, improve tests * Move save_img tests to unit tests and convert to Keras TestCase * Move save_img tests to unit tests and convert to Keras TestCase * Move save_img tests to unit tests and convert to Keras TestCase * Move save_img tests to unit tests and convert to Keras TestCase * fix: simplify file format handling and add inferred format test * fix: simplify file format handling and add inferred format test * fix: simplify file format handling and add inferred format test
1 parent 3882d1d commit 0512fdb

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

keras/src/utils/image_utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,24 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
175175
**kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
176176
"""
177177
data_format = backend.standardize_data_format(data_format)
178+
179+
# Infer format from path if not explicitly provided
180+
if file_format is None and isinstance(path, (str, pathlib.Path)):
181+
file_format = pathlib.Path(path).suffix[1:].lower()
182+
183+
# Normalize jpg → jpeg for Pillow compatibility
184+
if file_format and file_format.lower() == "jpg":
185+
file_format = "jpeg"
186+
178187
img = array_to_img(x, data_format=data_format, scale=scale)
179-
if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"):
188+
189+
# Handle RGBA → RGB conversion for JPEG
190+
if img.mode == "RGBA" and file_format == "jpeg":
180191
warnings.warn(
181-
"The JPG format does not support RGBA images, converting to RGB."
192+
"The JPEG format does not support RGBA images, converting to RGB."
182193
)
183194
img = img.convert("RGB")
195+
184196
img.save(path, format=file_format, **kwargs)
185197

186198

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import os
2+
3+
import numpy as np
4+
from absl.testing import parameterized
5+
6+
from keras.src import testing
7+
from keras.src.utils import img_to_array
8+
from keras.src.utils import load_img
9+
from keras.src.utils import save_img
10+
11+
12+
class SaveImgTest(testing.TestCase, parameterized.TestCase):
13+
@parameterized.named_parameters(
14+
("rgb_explicit_format", (50, 50, 3), "rgb.jpg", "jpg", True),
15+
("rgba_explicit_format", (50, 50, 4), "rgba.jpg", "jpg", True),
16+
("rgb_inferred_format", (50, 50, 3), "rgb_inferred.jpg", None, False),
17+
("rgba_inferred_format", (50, 50, 4), "rgba_inferred.jpg", None, False),
18+
)
19+
def test_save_jpg(self, shape, name, file_format, use_explicit_format):
20+
tmp_dir = self.get_temp_dir()
21+
path = os.path.join(tmp_dir, name)
22+
23+
img = np.random.randint(0, 256, size=shape, dtype=np.uint8)
24+
25+
# Test the actual inferred case - don't pass file_format at all
26+
if use_explicit_format:
27+
save_img(path, img, file_format=file_format)
28+
else:
29+
save_img(path, img) # Let it infer from path
30+
31+
self.assertTrue(os.path.exists(path))
32+
33+
# Verify saved image is correctly converted to RGB if needed
34+
loaded_img = load_img(path)
35+
loaded_array = img_to_array(loaded_img)
36+
self.assertEqual(loaded_array.shape, (50, 50, 3))

0 commit comments

Comments
 (0)