Keras and JAX: same neural net, same weights, but different outputs?

I am trying to replicate the simplest possible autoencoder in Keras and JAX, and I am having trouble understanding why I get different results after training. Outside of training, static evaluation of both networks with the same weights gives the same output for linear or relu activations. With leaky relu, for some reason, I cannot even get the same output given the same weights.
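For concreteness, the static check I mean looks roughly like this (a minimal sketch; kautoencoder, v_ae and kend_params are the objects built in the code below, and the tolerance is arbitrary):

# forward pass through both networks with the Keras-trained weights
ky = kautoencoder.predict(X)           # Keras reconstruction
jy = v_ae(kend_params, X)              # JAX reconstruction with the same weights
print(np.allclose(ky, jy, atol=1e-5))  # True for linear/relu; the leaky relu variant fails this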

import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
from keras.layers import Input, Dense
from keras.models import Model
from jax.experimental import optimizers  # in newer JAX: from jax.example_libraries import optimizers
from jax import jit, vmap, value_and_grad
import jax
import jax.numpy as jnp
from copy import deepcopy

seed_value = 0
os.environ['PYTHONHASHSEED']=str(seed_value)
np.random.seed(seed_value)
tf.random.set_seed(seed_value)


# toy data: 1000 samples, 10 features, only the last two columns are non-zero
_X = np.random.random((1000,2))
X = np.zeros((1000,10))
X[:,8:]=_X

#########################
### Keras autoencoder ###
#########################

bneck_size = 2
layer_sizes = [64,32,16,8]

#encoder
inputs = Input(shape=(X.shape[1],), name='encoder_input')
x = inputs
for size in layer_sizes:
    x = Dense(size, activation='relu', kernel_initializer='he_uniform')(x)
latent = Dense(bneck_size, activation=None, kernel_initializer='he_uniform', name='latent_vector')(x)
kencoder = Model(inputs, latent, name='encoder')

#decoder
latent_inputs = Input(shape=(bneck_size,), name='decoder_input')
x = latent_inputs
for size in layer_sizes[::-1]:
    x = Dense(size, activation='relu', kernel_initializer='he_uniform')(x)
outputs = Dense(X.shape[1], activation=None, kernel_initializer='he_uniform', name='decoder_output')(x)
kdecoder = Model(latent_inputs, outputs, name='decoder')

#autoencoder
kautoencoder = Model(inputs, kdecoder(kencoder(inputs)), name='autoencoder')
kinit_weights = deepcopy(kautoencoder.get_weights())
kautoencoder.compile(loss='mse', optimizer=tf.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-07))
kautoencoder.fit(x=X,y=X,batch_size=5,epochs=2,verbose=1,shuffle=False)




#########################
### Jax autoencoder ###
#########################
# layer helpers; lrelu_layer is the leaky relu variant mentioned above (unused in this relu run)
def lrelu_layer(X, w, b):  return jax.nn.leaky_relu(jnp.dot(X, w) + b)
def relu_layer(X, w, b):   return jax.nn.relu(jnp.dot(X, w) + b)
def linear_layer(X, w, b): return jnp.dot(X, w) + b

def encoder(encoder_params,activations):
    for w, b in encoder_params[:-1]: activations = relu_layer(activations,w,b)
    # bottleneck
    w, b = encoder_params[-1]
    activations = linear_layer(activations,w,b) 
    return activations

def decoder(decoder_params,activations):
    for w, b in decoder_params[:-1]: activations = relu_layer(activations,w,b)
    # linear output layer
    w, b = decoder_params[-1]
    outputs = linear_layer(activations,w,b) 
    return outputs

def ae(params, inputs):    
    # split params at the bottleneck, i.e. the layer with the smallest bias vector
    bneck_idx = np.argmin([len(p[1]) for p in params])
    activations = encoder(params[:bneck_idx+1],inputs)    
    outputs = decoder(params[bneck_idx+1:],activations)    
    return outputs 

def mse_loss(params, batch):
    preds = ae(params, batch)
    return jnp.mean(jnp.square(preds - batch))

@jit
def update(params, X, opt_state):
    """ Compute the gradient for a batch and update the parameters """
    value, grads = value_and_grad(mse_loss)(params, X)
    opt_state = opt_update(0, grads, opt_state)  # opt_update(step, grads, state)
    return get_params(opt_state), opt_state, value

v_encoder = jit(vmap(encoder, in_axes=(None, 0)))
v_ae = jit(vmap(ae, in_axes=(None, 0)))

# pair Keras' flat weight list into (kernel, bias) tuples
init_params = [(w,b) for w,b in zip(kinit_weights[0::2],kinit_weights[1::2])]
params = deepcopy(init_params)

opt_init, opt_update, get_params = optimizers.adam(step_size=0.001,b1=0.9, b2=0.999, eps=1e-07)
opt_state = opt_init(init_params)

# train
batch_size = 5
n_batches = len(X)//batch_size
n_epochs = 2

for e in range(n_epochs):
    for i in range(n_batches):
        batch_s, batch_e = i*batch_size, (i+1)*batch_size
        params, opt_state, loss = update(params, X[batch_s:batch_e], opt_state)
        
        


kend_weights = kautoencoder.get_weights()
kend_params = [(w,b) for w,b in zip(kend_weights[0::2],kend_weights[1::2])]

Y = kencoder.predict(X)             # Keras bottleneck, Keras-trained weights
Y2 = v_encoder(params[:5], X)       # JAX bottleneck, JAX-trained weights (first 5 (w,b) pairs = encoder)
Y3 = v_encoder(kend_params[:5], X)  # JAX bottleneck, Keras-trained weights

f, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, sharex=True)
axs[0].scatter(Y[:, 0], Y[:, 1], s=1);  axs[0].set_title('Keras bottleneck')
axs[1].scatter(Y2[:, 0], Y2[:, 1], s=1); axs[1].set_title('Jax bottleneck')
axs[2].scatter(Y3[:, 0], Y3[:, 1], s=1); axs[2].set_title('Jax bottleneck, given keras weights')
plt.show()
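To put a number on the divergence rather than only eyeballing the scatter plots, a direct comparison of the trained parameters can be added (a quick sketch using params and kend_params from above):

# per-layer maximum absolute difference between JAX-trained and Keras-trained weights
for i, ((jw, jb), (kw, kb)) in enumerate(zip(params, kend_params)):
    print(f"layer {i}: max|dW| = {np.max(np.abs(jw - kw)):.3e}, max|db| = {np.max(np.abs(jb - kb)):.3e}")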