Skip to content

Conversation

@wenyi-guo
Copy link
Collaborator

support jax2tf in JaxLayer for tf backend by convert jax functions to tf.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @wenyi-guo, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the interoperability of Keras by extending JaxLayer and FlaxLayer to fully support the TensorFlow backend. By integrating jax2tf conversion, it enables JAX-defined models and functions to be seamlessly executed within a TensorFlow environment, providing greater flexibility for developers to leverage the strengths of both frameworks. The changes involve robust conversion mechanisms, refined state management, and updated testing to ensure compatibility and correctness.

Highlights

  • TensorFlow Backend Support for JaxLayer: The JaxLayer now officially supports the TensorFlow backend, enabling JAX functions to be converted and run within a TensorFlow environment using jax2tf.
  • JAX to TensorFlow Conversion Integration: New internal methods (_get_jax2tf_input_shape, _jax2tf_convert, _partial_with_positional) have been added to JaxLayer to facilitate the conversion of JAX functions to TensorFlow operations, including handling polymorphic shapes and positional arguments.
  • Output Shape Computation Flexibility: A compute_output_shape_fn argument has been introduced to both JaxLayer and FlaxLayer constructors, allowing users to provide a custom function for determining the layer's output shape.
  • Improved Random Number Generation: The internal random number generation mechanism in JaxLayer has been updated to use jax.random.PRNGKey and a new _split_jax_rng method, replacing the previous seed_generator.
  • Enhanced State and Variable Handling: The _create_variables method has been refined to correctly handle jax.Array types and to return variables directly when the TensorFlow backend is active, ensuring proper state management across backends.
  • FlaxLayer TensorFlow Backend Compatibility: The FlaxLayer has also been updated to remove its JAX-only backend restriction and now passes the compute_output_shape_fn to its base JaxLayer for TensorFlow backend compatibility.
  • Updated Test Suite: The test suite for JaxLayer and FlaxLayer has been modified to run tests against both JAX and TensorFlow backends, use keras.src.ops and keras.src.random for data generation, and include run_eagerly=True for TensorFlow tests.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for jax2tf within JaxLayer, enabling JAX models to run efficiently with the TensorFlow backend. The changes involve adapting the layer's initialization, random number generation, and the call method to correctly convert JAX functions to TensorFlow graphs. The addition of compute_output_shape_fn provides greater flexibility in defining output shapes. Overall, the implementation appears to be a valuable enhancement for interoperability, with good attention to detailed error messages and integration with existing Keras mechanisms.

@codecov-commenter
Copy link

codecov-commenter commented Nov 7, 2025

Codecov Report

❌ Patch coverage is 94.17476% with 6 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.68%. Comparing base (bea37c5) to head (396a94e).
⚠️ Report is 1 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/utils/jax_layer.py 93.93% 1 Missing and 5 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21828      +/-   ##
==========================================
+ Coverage   82.66%   82.68%   +0.02%     
==========================================
  Files         577      577              
  Lines       59477    59558      +81     
  Branches     9329     9349      +20     
==========================================
+ Hits        49167    49246      +79     
+ Misses       7907     7906       -1     
- Partials     2403     2406       +3     
Flag Coverage Δ
keras 82.50% <89.32%> (+0.01%) ⬆️
keras-jax 63.26% <42.71%> (-0.05%) ⬇️
keras-numpy 57.49% <17.47%> (-0.06%) ⬇️
keras-openvino 34.32% <17.47%> (-0.02%) ⬇️
keras-tensorflow 64.40% <84.46%> (+0.27%) ⬆️
keras-torch 63.54% <17.47%> (-0.06%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants