
Manoj Rao

Your Average Common Man


Hello Folks,

I am back after a short break, so let's dive in. I have spent some time building and playing with MLIR-based ML compilers, and in this post I will chronicle a bit of my experience with JAX as a framework. JAX is often regarded as the framework of AGI.

Below are the raw steps, copied from the Jupyter Notebook where I ran the experiments.

!pip install -qqq jaxtyping hypothesis pytest penzai

import jax.numpy as np
import numpy as onp
from penzai import pz
arange = pz.nx.arange
where = pz.nx.nmap(np.where)
wrap = pz.nx.wrap
pz.ts.register_as_default()

# Optional automatic array visualization extras:
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()
pz.ts.active_autovisualizer.set_interactive(
    pz.ts.ArrayAutovisualizer(
        force_continuous=True, around_zero=True,
        prefers_column=["j"], prefers_row=["i"],
    )
)
import os
# GPU flags
flags = (
    "--xla_gpu_enable_triton_softmax_fusion=true "
    "--xla_gpu_triton_gemm_any=false "
    "--xla_gpu_enable_async_collectives=true "
    "--xla_gpu_enable_latency_hiding_scheduler=true "
    "--xla_gpu_enable_highest_priority_async_stream=true "
)
os.environ["XLA_FLAGS"] = flags
!nvidia-smi
Mon May 27 13:31:30 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.29.06              Driver Version: 545.29.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        Off | 00000000:09:00.0  On |                  Off |
|  0%   47C    P2              58W / 450W |  19106MiB / 24564MiB |      1%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1195      G   /usr/lib/xorg/Xorg                          259MiB |
|    0   N/A  N/A      2130      G   /usr/lib/firefox/firefox                      0MiB |
|    0   N/A  N/A    477865      C   /home/mycpuorg/miniconda3/bin/python3.11     18404MiB |
+---------------------------------------------------------------------------------------+


/home/mycpuorg/miniconda3/lib/python3.11/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  pid, fd = os.forkpty()
import jax
import jax.numpy as jnp

# Define a simple function - test
def test(x):
  return jnp.sin(x) * x**2


# Pretty-print the function as a Jaxpr
print(jax.make_jaxpr(test)(1.0))
{ lambda ; a:f32[]. let
    b:f32[] = sin a
    c:f32[] = integer_pow[y=2] a
    d:f32[] = mul b c
  in (d,) }
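
Every JAX transformation is ultimately a rewrite of these jaxprs, so it is worth also tracing the derivative of the same function. A minimal sketch (output omitted; the resulting jaxpr spells out cos(x) * x**2 + 2 * x * sin(x) in primitives):

# Sketch: jaxpr of the derivative of `test`, derived automatically by jax.grad
print(jax.make_jaxpr(jax.grad(test))(1.0))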


# Lower the function
lowered = jax.jit(test).lower(1.0)

# Print the lowered IR as text
print(lowered.as_text())

# Get the StableHLO representation of the lowered IR 
print(lowered.compiler_ir(dialect="stablehlo"))

module @jit_test attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.sine %arg0 : tensor<f32>
    %1 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
    %2 = stablehlo.multiply %0, %1 : tensor<f32>
    return %2 : tensor<f32>
  }
}

module @jit_test attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.sine %arg0 : tensor<f32>
    %1 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
    %2 = stablehlo.multiply %0, %1 : tensor<f32>
    return %2 : tensor<f32>
  }
}

# Compile the function
compiled = lowered.compile()

# Print the compiled IR as text
print(compiled.as_text())

# Get a summary of execution costs
print(compiled.cost_analysis())

HloModule jit_test, is_scheduled=true, entry_computation_layout={(f32[])->f32[]}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="9aa3c44fdb740a754eb15d47d89c333e"}

%fused_multiply (param_0.2: f32[]) -> f32[] {
  %param_0.2 = f32[] parameter(0)
  %sine.1.1 = f32[] sine(f32[] %param_0.2), metadata={op_name="jit(test)/jit(main)/sin" source_file="/tmp/ipykernel_477865/1530176408.py" source_line=6}
  %multiply.2.1 = f32[] multiply(f32[] %param_0.2, f32[] %param_0.2), metadata={op_name="jit(test)/jit(main)/integer_pow[y=2]" source_file="/tmp/ipykernel_477865/1530176408.py" source_line=6}
  ROOT %multiply.5.1 = f32[] multiply(f32[] %sine.1.1, f32[] %multiply.2.1), metadata={op_name="jit(test)/jit(main)/mul" source_file="/tmp/ipykernel_477865/1530176408.py" source_line=6}
}

ENTRY %main.5 (Arg_0.1.0: f32[]) -> f32[] {
  %Arg_0.1.0 = f32[] parameter(0)
  ROOT %loop_multiply_fusion = f32[] fusion(f32[] %Arg_0.1.0), kind=kLoop, calls=%fused_multiply, metadata={op_name="jit(test)/jit(main)/mul" source_file="/tmp/ipykernel_477865/1530176408.py" source_line=6}
}


[{'utilization0{}': 1.0, 'bytes accessed1{}': 8.0, 'bytes accessedout{}': 4.0, 'bytes accessed': 8.0, 'transcendentals': 1.0, 'bytes accessed0{}': 4.0, 'utilization1{}': 2.0, 'flops': 2.0}]
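
cost_analysis() returns a list with one dict per computation, so the interesting ratios fall out directly. A small sketch (assuming the list-of-dicts layout shown above) that estimates arithmetic intensity for the fused kernel:

costs = compiled.cost_analysis()[0]
# FLOPs per byte moved: a rough arithmetic-intensity estimate
print(costs["flops"] / costs["bytes accessed"])
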
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
# A helper function to randomly initialize weights and biases
# for dense neural network layer

def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n,m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0))
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def predict(params, image):
    # per-example predictions
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)

    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

random_flattened_image = random.normal(random.key(1), (28*28,))
print(random_flattened_image.shape)
print(len(params))
preds = predict(params, random_flattened_image)
print(preds.shape)
(784,)
3
(10,)
batched_random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
    preds = predict(params, batched_random_flattened_images)
except TypeError:
    print("Invalid Shapes!")

Invalid Shapes!
batched_predict = vmap(predict, in_axes=(None, 0))
batched_preds = batched_predict(params, batched_random_flattened_images)
print(batched_preds.shape)
(10, 10)
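
vmap composes freely with the other transformations; wrapping the batched function in jit is a sketch of what effectively happens once `update` is jitted further down:

jit_batched_predict = jit(batched_predict)
# Same (10, 10) output shape, now compiled end to end
print(jit_batched_predict(params, batched_random_flattened_images).shape)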

Loss Function and other utils

def one_hot(x, k, dtype=jnp.float32):
    return jnp.array(x[:, None] == jnp.arange(k), dtype)
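
A tiny worked check of one_hot (a sketch): labels [0, 2] against k=3 compare each label to arange(3), so the rows pick out columns 0 and 2.

# Expected: [[1. 0. 0.]
#            [0. 0. 1.]]
print(one_hot(jnp.array([0, 2]), 3))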

def accuracy(params, images, targets):
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
    preds = batched_predict(params, images)
    return -(jnp.mean(preds * targets))

@jit
def update(params, x, y):
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
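
A common refinement, not used in the run below, is jax.value_and_grad, which returns the loss value together with the gradients from a single traced pass; a sketch:

@jit
def update_with_loss(params, x, y):
    # value_and_grad evaluates the loss and its gradients in one pass
    loss_value, grads = jax.value_and_grad(loss)(params, x, y)
    new_params = [(w - step_size * dw, b - step_size * db)
                  for (w, b), (dw, db) in zip(params, grads)]
    return new_params, loss_value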



import numpy as np
from jax.tree_util import tree_map
from torchvision.datasets import MNIST
from torch.utils import data
/home/mycpuorg/miniconda3/lib/python3.11/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/mycpuorg/miniconda3/lib/python3.11/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
/home/mycpuorg/miniconda3/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
def numpy_collate(batch):
    return tree_map(np.asarray, data.default_collate(batch))

class NumpyLoader(data.DataLoader):
    def __init__(self, dataset, batch_size=1,
                 shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0,
                 pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        super().__init__(dataset,
             batch_size=batch_size,
             shuffle=shuffle,
             sampler=sampler,
             batch_sampler=batch_sampler,
             num_workers=num_workers,
             collate_fn=numpy_collate,
             pin_memory=pin_memory,
             drop_last=drop_last,
             timeout=timeout,
             worker_init_fn=worker_init_fn)
    
class FlattenAndCast(object):
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=jnp.float32))
            

             
# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)
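
A quick smoke test of the loader (a sketch): one batch should come back as plain NumPy arrays with the flattened-image shape.

x0, y0 = next(iter(training_generator))
print(type(x0), x0.shape, y0.shape)  # expect numpy.ndarray, (128, 784), (128,)
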
a = np.array(np.random.randn(3,2,4))
a.reshape(-1, 24)

b = jnp.array(a)
print(b.devices())
b.sharding

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1 : i + 2], w))
    return jnp.array(output)
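
Before lowering, a quick eager sanity check of the window arithmetic (a sketch): with x = [0, 1, 2, 3, 4] and w = [2, 3, 4], the three sliding dot products are 11, 20, and 29.

print(convolve(x, w))  # expected: [11. 20. 29.]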

lowered_conv = jax.jit(convolve).lower(x, w)
print(lowered_conv.compiler_ir(dialect="stablehlo"))
# print(lowered_conv.compiler_ir(dialect="mhlo"))

# print(lowered_conv.compiler_ir(dialect="LMHLO"))

compiled_conv = lowered_conv.compile()
print(compiled_conv.as_text())
ca = compiled_conv.cost_analysis()
ca
{cuda(id=0)}
module @jit_convolve attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<5xi32> {mhlo.layout_mode = "default"}, %arg1: tensor<3xf32> {mhlo.layout_mode = "default"}) -> (tensor<3xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.slice %arg0 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %1 = stablehlo.convert %0 : (tensor<3xi32>) -> tensor<3xf32>
    %2 = stablehlo.convert %arg1 : tensor<3xf32>
    %3 = stablehlo.dot_general %1, %2, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %4 = stablehlo.slice %arg0 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %5 = stablehlo.convert %4 : (tensor<3xi32>) -> tensor<3xf32>
    %6 = stablehlo.convert %arg1 : tensor<3xf32>
    %7 = stablehlo.dot_general %5, %6, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %8 = stablehlo.slice %arg0 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %9 = stablehlo.convert %8 : (tensor<3xi32>) -> tensor<3xf32>
    %10 = stablehlo.convert %arg1 : tensor<3xf32>
    %11 = stablehlo.dot_general %9, %10, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %12 = stablehlo.broadcast_in_dim %3, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %13 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %14 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %15 = stablehlo.concatenate %12, %13, %14, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    return %15 : tensor<3xf32>
  }
}

HloModule jit_convolve, is_scheduled=true, entry_computation_layout={(s32[5]{0}, f32[3]{0})->f32[3]{0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="f99fa116847aaf794a773015b66e3c7a"}

%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
  %scalar_rhs = f32[] parameter(1)
  %scalar_lhs = f32[] parameter(0)
  ROOT %add.2 = f32[] add(f32[] %scalar_lhs, f32[] %scalar_rhs)
}

%fused_concatenate (param_0.15: f32[3], param_1.19: s32[5]) -> f32[3] {
  %param_1.19 = s32[5]{0} parameter(1)
  %convert.2.5 = f32[5]{0} convert(s32[5]{0} %param_1.19), metadata={op_name="jit(convolve)/jit(main)/dot_general[dimension_numbers=(((0,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=14}
  %slice.8.3 = f32[3]{0} slice(f32[5]{0} %convert.2.5), slice={[0:3]}, metadata={op_name="jit(convolve)/jit(main)/slice[start_indices=(0,) limit_indices=(3,) strides=None]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=14}
  %param_0.15 = f32[3]{0} parameter(0)
  %multiply.6.3 = f32[3]{0} multiply(f32[3]{0} %slice.8.3, f32[3]{0} %param_0.15)
  %constant_7 = f32[] constant(0)
  %reduce.4 = f32[] reduce(f32[3]{0} %multiply.6.3, f32[] %constant_7), dimensions={0}, to_apply=%scalar_add_computation, metadata={op_name="jit(convolve)/jit(main)/dot_general[dimension_numbers=(((0,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=14}
  %bitcast.63.1 = f32[1]{0} bitcast(f32[] %reduce.4)
  %slice.10.3 = f32[3]{0} slice(f32[5]{0} %convert.2.5), slice={[1:4]}, metadata={op_name="jit(convolve)/jit(main)/slice[start_indices=(1,) limit_indices=(4,) strides=None]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=14}
  %multiply.7.3 = f32[3]{0} multiply(f32[3]{0} %slice.10.3, f32[3]{0} %param_0.15)
  %reduce.1.1 = f32[] reduce(f32[3]{0} %multiply.7.3, f32[] %constant_7), dimensions={0}, to_apply=%scalar_add_computation, metadata={op_name="jit(convolve)/jit(main)/dot_general[dimension_numbers=(((0,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=14}
  %bitcast.75.1 = f32[1]{0} bitcast(f32[] %reduce.1.1)
  %slice.11.3 = f32[3]{0} slice(f32[5]{0} %convert.2.5), slice={[2:5]}, metadata={op_name="jit(convolve)/jit(main)/slice[start_indices=(2,) limit_indices=(5,) strides=None]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=14}
  %multiply.8.3 = f32[3]{0} multiply(f32[3]{0} %slice.11.3, f32[3]{0} %param_0.15)
  %reduce.2.1 = f32[] reduce(f32[3]{0} %multiply.8.3, f32[] %constant_7), dimensions={0}, to_apply=%scalar_add_computation, metadata={op_name="jit(convolve)/jit(main)/dot_general[dimension_numbers=(((0,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=14}
  %bitcast.87.1 = f32[1]{0} bitcast(f32[] %reduce.2.1)
  ROOT %concatenate.1.1 = f32[3]{0} concatenate(f32[1]{0} %bitcast.63.1, f32[1]{0} %bitcast.75.1, f32[1]{0} %bitcast.87.1), dimensions={0}, metadata={op_name="jit(convolve)/jit(main)/concatenate[dimension=0]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=15}
}

ENTRY %main.16 (Arg_0.1.0: s32[5], Arg_1.2.0: f32[3]) -> f32[3] {
  %Arg_1.2.0 = f32[3]{0} parameter(1)
  %Arg_0.1.0 = s32[5]{0} parameter(0)
  ROOT %input_concatenate_fusion = f32[3]{0} fusion(f32[3]{0} %Arg_1.2.0, s32[5]{0} %Arg_0.1.0), kind=kInput, calls=%fused_concatenate, metadata={op_name="jit(convolve)/jit(main)/concatenate[dimension=0]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=15}
}
xs = jnp.stack([x, x, x, x, x, x, x, x, x, x, x, x, x, x, x, x])
ws = jnp.stack([w, w, w, w, w, w, w, w, w, w, w, w, w, w, w, w])
def manual_batched_conv(xs, ws):
    output = []
    for i in range(xs.shape[0]):
        output.append(convolve(xs[i], ws[i]))
    return jnp.stack(output)

manual_batched_conv(xs, ws)
auto_batch_conv = jax.vmap(convolve)

auto_batch_conv(xs, ws)
lowered_batch_conv = jax.jit(auto_batch_conv, device=jax.devices()[0]).lower(xs, ws)
compiled_auto_batch_conv = lowered_batch_conv.compile()
print(compiled_auto_batch_conv.cost_analysis())
print(lowered_batch_conv.compiler_ir())
[{'bytes accessed': 704.0, 'bytes accessed0{}': 192.0, 'utilization1{}': 1.0, 'utilization2{}': 1.0, 'bytes accessed1{}': 320.0, 'bytes accessedout{}': 192.0, 'utilization0{}': 1.0, 'flops': 320.0, 'bytes accessed2{}': 64.0}]
module @jit_convolve attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<16x5xi32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<16x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<16x3xf32> {jax.result_info = "", mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) {
    %0 = stablehlo.slice %arg0 [0:16, 0:3] : (tensor<16x5xi32>) -> tensor<16x3xi32>
    %1 = stablehlo.convert %0 : (tensor<16x3xi32>) -> tensor<16x3xf32>
    %2 = stablehlo.convert %arg1 : tensor<16x3xf32>
    %3 = stablehlo.dot_general %1, %2, batching_dims = [0] x [0], contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<16x3xf32>, tensor<16x3xf32>) -> tensor<16xf32>
    %4 = stablehlo.slice %arg0 [0:16, 1:4] : (tensor<16x5xi32>) -> tensor<16x3xi32>
    %5 = stablehlo.convert %4 : (tensor<16x3xi32>) -> tensor<16x3xf32>
    %6 = stablehlo.convert %arg1 : tensor<16x3xf32>
    %7 = stablehlo.dot_general %5, %6, batching_dims = [0] x [0], contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<16x3xf32>, tensor<16x3xf32>) -> tensor<16xf32>
    %8 = stablehlo.slice %arg0 [0:16, 2:5] : (tensor<16x5xi32>) -> tensor<16x3xi32>
    %9 = stablehlo.convert %8 : (tensor<16x3xi32>) -> tensor<16x3xf32>
    %10 = stablehlo.convert %arg1 : tensor<16x3xf32>
    %11 = stablehlo.dot_general %9, %10, batching_dims = [0] x [0], contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<16x3xf32>, tensor<16x3xf32>) -> tensor<16xf32>
    %12 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<16xf32>) -> tensor<16x1xf32>
    %13 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<16xf32>) -> tensor<16x1xf32>
    %14 = stablehlo.broadcast_in_dim %11, dims = [0] : (tensor<16xf32>) -> tensor<16x1xf32>
    %15 = stablehlo.concatenate %12, %13, %14, dim = 1 : (tensor<16x1xf32>, tensor<16x1xf32>, tensor<16x1xf32>) -> tensor<16x3xf32>
    return %15 : tensor<16x3xf32>
  }
}
lowered_manual_batch_conv = jax.jit(manual_batched_conv, device=jax.devices()[0]).lower(xs, ws)
print(lowered_manual_batch_conv.compiler_ir())
module @jit_manual_batched_conv attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<16x5xi32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<16x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<16x3xf32> {jax.result_info = "", mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) {
    %0 = stablehlo.slice %arg0 [0:1, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
    %1 = stablehlo.reshape %0 : (tensor<1x5xi32>) -> tensor<5xi32>
    %2 = stablehlo.slice %arg1 [0:1, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
    %3 = stablehlo.reshape %2 : (tensor<1x3xf32>) -> tensor<3xf32>
    %4 = stablehlo.slice %1 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %5 = stablehlo.convert %4 : (tensor<3xi32>) -> tensor<3xf32>
    %6 = stablehlo.convert %3 : tensor<3xf32>
    %7 = stablehlo.dot_general %5, %6, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %8 = stablehlo.slice %1 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %9 = stablehlo.convert %8 : (tensor<3xi32>) -> tensor<3xf32>
    %10 = stablehlo.convert %3 : tensor<3xf32>
    %11 = stablehlo.dot_general %9, %10, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %12 = stablehlo.slice %1 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %13 = stablehlo.convert %12 : (tensor<3xi32>) -> tensor<3xf32>
    %14 = stablehlo.convert %3 : tensor<3xf32>
    %15 = stablehlo.dot_general %13, %14, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %16 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %17 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %18 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %19 = stablehlo.concatenate %16, %17, %18, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    %20 = stablehlo.slice %arg0 [1:2, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
    %21 = stablehlo.reshape %20 : (tensor<1x5xi32>) -> tensor<5xi32>
    %22 = stablehlo.slice %arg1 [1:2, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
    %23 = stablehlo.reshape %22 : (tensor<1x3xf32>) -> tensor<3xf32>
    %24 = stablehlo.slice %21 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %25 = stablehlo.convert %24 : (tensor<3xi32>) -> tensor<3xf32>
    %26 = stablehlo.convert %23 : tensor<3xf32>
    %27 = stablehlo.dot_general %25, %26, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %28 = stablehlo.slice %21 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %29 = stablehlo.convert %28 : (tensor<3xi32>) -> tensor<3xf32>
    %30 = stablehlo.convert %23 : tensor<3xf32>
    %31 = stablehlo.dot_general %29, %30, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %32 = stablehlo.slice %21 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %33 = stablehlo.convert %32 : (tensor<3xi32>) -> tensor<3xf32>
    %34 = stablehlo.convert %23 : tensor<3xf32>
    %35 = stablehlo.dot_general %33, %34, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %36 = stablehlo.broadcast_in_dim %27, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %37 = stablehlo.broadcast_in_dim %31, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %38 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %39 = stablehlo.concatenate %36, %37, %38, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    %40 = stablehlo.slice %arg0 [2:3, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
    %41 = stablehlo.reshape %40 : (tensor<1x5xi32>) -> tensor<5xi32>
    %42 = stablehlo.slice %arg1 [2:3, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
    %43 = stablehlo.reshape %42 : (tensor<1x3xf32>) -> tensor<3xf32>
    %44 = stablehlo.slice %41 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %45 = stablehlo.convert %44 : (tensor<3xi32>) -> tensor<3xf32>
    %46 = stablehlo.convert %43 : tensor<3xf32>
    %47 = stablehlo.dot_general %45, %46, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %48 = stablehlo.slice %41 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %49 = stablehlo.convert %48 : (tensor<3xi32>) -> tensor<3xf32>
    %50 = stablehlo.convert %43 : tensor<3xf32>
    %51 = stablehlo.dot_general %49, %50, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %52 = stablehlo.slice %41 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %53 = stablehlo.convert %52 : (tensor<3xi32>) -> tensor<3xf32>
    %54 = stablehlo.convert %43 : tensor<3xf32>
    %55 = stablehlo.dot_general %53, %54, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %56 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %57 = stablehlo.broadcast_in_dim %51, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %58 = stablehlo.broadcast_in_dim %55, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %59 = stablehlo.concatenate %56, %57, %58, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    %60 = stablehlo.slice %arg0 [3:4, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
    %61 = stablehlo.reshape %60 : (tensor<1x5xi32>) -> tensor<5xi32>
    %62 = stablehlo.slice %arg1 [3:4, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
    %63 = stablehlo.reshape %62 : (tensor<1x3xf32>) -> tensor<3xf32>
    %64 = stablehlo.slice %61 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %65 = stablehlo.convert %64 : (tensor<3xi32>) -> tensor<3xf32>
    %66 = stablehlo.convert %63 : tensor<3xf32>
    %67 = stablehlo.dot_general %65, %66, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %68 = stablehlo.slice %61 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %69 = stablehlo.convert %68 : (tensor<3xi32>) -> tensor<3xf32>
    %70 = stablehlo.convert %63 : tensor<3xf32>
    %71 = stablehlo.dot_general %69, %70, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %72 = stablehlo.slice %61 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %73 = stablehlo.convert %72 : (tensor<3xi32>) -> tensor<3xf32>
    %74 = stablehlo.convert %63 : tensor<3xf32>
    %75 = stablehlo.dot_general %73, %74, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %76 = stablehlo.broadcast_in_dim %67, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %77 = stablehlo.broadcast_in_dim %71, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %78 = stablehlo.broadcast_in_dim %75, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %79 = stablehlo.concatenate %76, %77, %78, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    %80 = stablehlo.slice %arg0 [4:5, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
    %81 = stablehlo.reshape %80 : (tensor<1x5xi32>) -> tensor<5xi32>
    %82 = stablehlo.slice %arg1 [4:5, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
    %83 = stablehlo.reshape %82 : (tensor<1x3xf32>) -> tensor<3xf32>
    %84 = stablehlo.slice %81 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %85 = stablehlo.convert %84 : (tensor<3xi32>) -> tensor<3xf32>
    %86 = stablehlo.convert %83 : tensor<3xf32>
    %87 = stablehlo.dot_general %85, %86, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %88 = stablehlo.slice %81 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %89 = stablehlo.convert %88 : (tensor<3xi32>) -> tensor<3xf32>
    %90 = stablehlo.convert %83 : tensor<3xf32>
    %91 = stablehlo.dot_general %89, %90, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %92 = stablehlo.slice %81 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %93 = stablehlo.convert %92 : (tensor<3xi32>) -> tensor<3xf32>
    %94 = stablehlo.convert %83 : tensor<3xf32>
    %95 = stablehlo.dot_general %93, %94, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %96 = stablehlo.broadcast_in_dim %87, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %97 = stablehlo.broadcast_in_dim %91, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %98 = stablehlo.broadcast_in_dim %95, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %99 = stablehlo.concatenate %96, %97, %98, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    %100 = stablehlo.slice %arg0 [5:6, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
    %101 = stablehlo.reshape %100 : (tensor<1x5xi32>) -> tensor<5xi32>
    %102 = stablehlo.slice %arg1 [5:6, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
    %103 = stablehlo.reshape %102 : (tensor<1x3xf32>) -> tensor<3xf32>
    %104 = stablehlo.slice %101 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %105 = stablehlo.convert %104 : (tensor<3xi32>) -> tensor<3xf32>
    %106 = stablehlo.convert %103 : tensor<3xf32>
    %107 = stablehlo.dot_general %105, %106, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %108 = stablehlo.slice %101 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %109 = stablehlo.convert %108 : (tensor<3xi32>) -> tensor<3xf32>
    %110 = stablehlo.convert %103 : tensor<3xf32>
    %111 = stablehlo.dot_general %109, %110, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %112 = stablehlo.slice %101 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %113 = stablehlo.convert %112 : (tensor<3xi32>) -> tensor<3xf32>
    %114 = stablehlo.convert %103 : tensor<3xf32>
    %115 = stablehlo.dot_general %113, %114, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %116 = stablehlo.broadcast_in_dim %107, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %117 = stablehlo.broadcast_in_dim %111, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %118 = stablehlo.broadcast_in_dim %115, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %119 = stablehlo.concatenate %116, %117, %118, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    %120 = stablehlo.slice %arg0 [6:7, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
    %121 = stablehlo.reshape %120 : (tensor<1x5xi32>) -> tensor<5xi32>
    %122 = stablehlo.slice %arg1 [6:7, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
    %123 = stablehlo.reshape %122 : (tensor<1x3xf32>) -> tensor<3xf32>
    %124 = stablehlo.slice %121 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %125 = stablehlo.convert %124 : (tensor<3xi32>) -> tensor<3xf32>
    %126 = stablehlo.convert %123 : tensor<3xf32>
    %127 = stablehlo.dot_general %125, %126, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %128 = stablehlo.slice %121 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %129 = stablehlo.convert %128 : (tensor<3xi32>) -> tensor<3xf32>
    %130 = stablehlo.convert %123 : tensor<3xf32>
    %131 = stablehlo.dot_general %129, %130, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %132 = stablehlo.slice %121 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %133 = stablehlo.convert %132 : (tensor<3xi32>) -> tensor<3xf32>
    %134 = stablehlo.convert %123 : tensor<3xf32>
    %135 = stablehlo.dot_general %133, %134, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %136 = stablehlo.broadcast_in_dim %127, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %137 = stablehlo.broadcast_in_dim %131, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %138 = stablehlo.broadcast_in_dim %135, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %139 = stablehlo.concatenate %136, %137, %138, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    %140 = stablehlo.slice %arg0 [7:8, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
    %141 = stablehlo.reshape %140 : (tensor<1x5xi32>) -> tensor<5xi32>
    %142 = stablehlo.slice %arg1 [7:8, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
    %143 = stablehlo.reshape %142 : (tensor<1x3xf32>) -> tensor<3xf32>
    %144 = stablehlo.slice %141 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %145 = stablehlo.convert %144 : (tensor<3xi32>) -> tensor<3xf32>
    %146 = stablehlo.convert %143 : tensor<3xf32>
    %147 = stablehlo.dot_general %145, %146, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %148 = stablehlo.slice %141 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %149 = stablehlo.convert %148 : (tensor<3xi32>) -> tensor<3xf32>
    %150 = stablehlo.convert %143 : tensor<3xf32>
    %151 = stablehlo.dot_general %149, %150, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %152 = stablehlo.slice %141 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %153 = stablehlo.convert %152 : (tensor<3xi32>) -> tensor<3xf32>
    %154 = stablehlo.convert %143 : tensor<3xf32>
    %155 = stablehlo.dot_general %153, %154, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %156 = stablehlo.broadcast_in_dim %147, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %157 = stablehlo.broadcast_in_dim %151, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %158 = stablehlo.broadcast_in_dim %155, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %159 = stablehlo.concatenate %156, %157, %158, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    %160 = stablehlo.slice %arg0 [8:9, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
    %161 = stablehlo.reshape %160 : (tensor<1x5xi32>) -> tensor<5xi32>
    %162 = stablehlo.slice %arg1 [8:9, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
    %163 = stablehlo.reshape %162 : (tensor<1x3xf32>) -> tensor<3xf32>
    %164 = stablehlo.slice %161 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %165 = stablehlo.convert %164 : (tensor<3xi32>) -> tensor<3xf32>
    %166 = stablehlo.convert %163 : tensor<3xf32>
    %167 = stablehlo.dot_general %165, %166, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %168 = stablehlo.slice %161 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %169 = stablehlo.convert %168 : (tensor<3xi32>) -> tensor<3xf32>
    %170 = stablehlo.convert %163 : tensor<3xf32>
    %171 = stablehlo.dot_general %169, %170, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %172 = stablehlo.slice %161 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %173 = stablehlo.convert %172 : (tensor<3xi32>) -> tensor<3xf32>
    %174 = stablehlo.convert %163 : tensor<3xf32>
    %175 = stablehlo.dot_general %173, %174, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %176 = stablehlo.broadcast_in_dim %167, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %177 = stablehlo.broadcast_in_dim %171, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %178 = stablehlo.broadcast_in_dim %175, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %179 = stablehlo.concatenate %176, %177, %178, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    %180 = stablehlo.slice %arg0 [9:10, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
    %181 = stablehlo.reshape %180 : (tensor<1x5xi32>) -> tensor<5xi32>
    %182 = stablehlo.slice %arg1 [9:10, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
    %183 = stablehlo.reshape %182 : (tensor<1x3xf32>) -> tensor<3xf32>
    %184 = stablehlo.slice %181 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %185 = stablehlo.convert %184 : (tensor<3xi32>) -> tensor<3xf32>
    %186 = stablehlo.convert %183 : tensor<3xf32>
    %187 = stablehlo.dot_general %185, %186, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %188 = stablehlo.slice %181 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %189 = stablehlo.convert %188 : (tensor<3xi32>) -> tensor<3xf32>
    %190 = stablehlo.convert %183 : tensor<3xf32>
    %191 = stablehlo.dot_general %189, %190, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %192 = stablehlo.slice %181 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %193 = stablehlo.convert %192 : (tensor<3xi32>) -> tensor<3xf32>
    %194 = stablehlo.convert %183 : tensor<3xf32>
    %195 = stablehlo.dot_general %193, %194, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %196 = stablehlo.broadcast_in_dim %187, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %197 = stablehlo.broadcast_in_dim %191, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %198 = stablehlo.broadcast_in_dim %195, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %199 = stablehlo.concatenate %196, %197, %198, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    %200 = stablehlo.slice %arg0 [10:11, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
    %201 = stablehlo.reshape %200 : (tensor<1x5xi32>) -> tensor<5xi32>
    %202 = stablehlo.slice %arg1 [10:11, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
    %203 = stablehlo.reshape %202 : (tensor<1x3xf32>) -> tensor<3xf32>
    %204 = stablehlo.slice %201 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %205 = stablehlo.convert %204 : (tensor<3xi32>) -> tensor<3xf32>
    %206 = stablehlo.convert %203 : tensor<3xf32>
    %207 = stablehlo.dot_general %205, %206, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %208 = stablehlo.slice %201 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %209 = stablehlo.convert %208 : (tensor<3xi32>) -> tensor<3xf32>
    %210 = stablehlo.convert %203 : tensor<3xf32>
    %211 = stablehlo.dot_general %209, %210, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %212 = stablehlo.slice %201 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %213 = stablehlo.convert %212 : (tensor<3xi32>) -> tensor<3xf32>
    %214 = stablehlo.convert %203 : tensor<3xf32>
    %215 = stablehlo.dot_general %213, %214, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %216 = stablehlo.broadcast_in_dim %207, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %217 = stablehlo.broadcast_in_dim %211, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %218 = stablehlo.broadcast_in_dim %215, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %219 = stablehlo.concatenate %216, %217, %218, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    %220 = stablehlo.slice %arg0 [11:12, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
    %221 = stablehlo.reshape %220 : (tensor<1x5xi32>) -> tensor<5xi32>
    %222 = stablehlo.slice %arg1 [11:12, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
    %223 = stablehlo.reshape %222 : (tensor<1x3xf32>) -> tensor<3xf32>
    %224 = stablehlo.slice %221 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %225 = stablehlo.convert %224 : (tensor<3xi32>) -> tensor<3xf32>
    %226 = stablehlo.convert %223 : tensor<3xf32>
    %227 = stablehlo.dot_general %225, %226, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %228 = stablehlo.slice %221 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %229 = stablehlo.convert %228 : (tensor<3xi32>) -> tensor<3xf32>
    %230 = stablehlo.convert %223 : tensor<3xf32>
    %231 = stablehlo.dot_general %229, %230, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %232 = stablehlo.slice %221 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %233 = stablehlo.convert %232 : (tensor<3xi32>) -> tensor<3xf32>
    %234 = stablehlo.convert %223 : tensor<3xf32>
    %235 = stablehlo.dot_general %233, %234, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %236 = stablehlo.broadcast_in_dim %227, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %237 = stablehlo.broadcast_in_dim %231, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %238 = stablehlo.broadcast_in_dim %235, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %239 = stablehlo.concatenate %236, %237, %238, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    %240 = stablehlo.slice %arg0 [12:13, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
    %241 = stablehlo.reshape %240 : (tensor<1x5xi32>) -> tensor<5xi32>
    %242 = stablehlo.slice %arg1 [12:13, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
    %243 = stablehlo.reshape %242 : (tensor<1x3xf32>) -> tensor<3xf32>
    %244 = stablehlo.slice %241 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %245 = stablehlo.convert %244 : (tensor<3xi32>) -> tensor<3xf32>
    %246 = stablehlo.convert %243 : tensor<3xf32>
    %247 = stablehlo.dot_general %245, %246, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %248 = stablehlo.slice %241 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %249 = stablehlo.convert %248 : (tensor<3xi32>) -> tensor<3xf32>
    %250 = stablehlo.convert %243 : tensor<3xf32>
    %251 = stablehlo.dot_general %249, %250, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %252 = stablehlo.slice %241 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %253 = stablehlo.convert %252 : (tensor<3xi32>) -> tensor<3xf32>
    %254 = stablehlo.convert %243 : tensor<3xf32>
    %255 = stablehlo.dot_general %253, %254, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %256 = stablehlo.broadcast_in_dim %247, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %257 = stablehlo.broadcast_in_dim %251, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %258 = stablehlo.broadcast_in_dim %255, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %259 = stablehlo.concatenate %256, %257, %258, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    %260 = stablehlo.slice %arg0 [13:14, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
    %261 = stablehlo.reshape %260 : (tensor<1x5xi32>) -> tensor<5xi32>
    %262 = stablehlo.slice %arg1 [13:14, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
    %263 = stablehlo.reshape %262 : (tensor<1x3xf32>) -> tensor<3xf32>
    %264 = stablehlo.slice %261 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %265 = stablehlo.convert %264 : (tensor<3xi32>) -> tensor<3xf32>
    %266 = stablehlo.convert %263 : tensor<3xf32>
    %267 = stablehlo.dot_general %265, %266, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %268 = stablehlo.slice %261 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %269 = stablehlo.convert %268 : (tensor<3xi32>) -> tensor<3xf32>
    %270 = stablehlo.convert %263 : tensor<3xf32>
    %271 = stablehlo.dot_general %269, %270, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %272 = stablehlo.slice %261 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %273 = stablehlo.convert %272 : (tensor<3xi32>) -> tensor<3xf32>
    %274 = stablehlo.convert %263 : tensor<3xf32>
    %275 = stablehlo.dot_general %273, %274, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %276 = stablehlo.broadcast_in_dim %267, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %277 = stablehlo.broadcast_in_dim %271, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %278 = stablehlo.broadcast_in_dim %275, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %279 = stablehlo.concatenate %276, %277, %278, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    %280 = stablehlo.slice %arg0 [14:15, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
    %281 = stablehlo.reshape %280 : (tensor<1x5xi32>) -> tensor<5xi32>
    %282 = stablehlo.slice %arg1 [14:15, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
    %283 = stablehlo.reshape %282 : (tensor<1x3xf32>) -> tensor<3xf32>
    %284 = stablehlo.slice %281 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %285 = stablehlo.convert %284 : (tensor<3xi32>) -> tensor<3xf32>
    %286 = stablehlo.convert %283 : tensor<3xf32>
    %287 = stablehlo.dot_general %285, %286, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %288 = stablehlo.slice %281 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %289 = stablehlo.convert %288 : (tensor<3xi32>) -> tensor<3xf32>
    %290 = stablehlo.convert %283 : tensor<3xf32>
    %291 = stablehlo.dot_general %289, %290, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %292 = stablehlo.slice %281 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %293 = stablehlo.convert %292 : (tensor<3xi32>) -> tensor<3xf32>
    %294 = stablehlo.convert %283 : tensor<3xf32>
    %295 = stablehlo.dot_general %293, %294, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %296 = stablehlo.broadcast_in_dim %287, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %297 = stablehlo.broadcast_in_dim %291, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %298 = stablehlo.broadcast_in_dim %295, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %299 = stablehlo.concatenate %296, %297, %298, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    %300 = stablehlo.slice %arg0 [15:16, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
    %301 = stablehlo.reshape %300 : (tensor<1x5xi32>) -> tensor<5xi32>
    %302 = stablehlo.slice %arg1 [15:16, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
    %303 = stablehlo.reshape %302 : (tensor<1x3xf32>) -> tensor<3xf32>
    %304 = stablehlo.slice %301 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
    %305 = stablehlo.convert %304 : (tensor<3xi32>) -> tensor<3xf32>
    %306 = stablehlo.convert %303 : tensor<3xf32>
    %307 = stablehlo.dot_general %305, %306, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %308 = stablehlo.slice %301 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
    %309 = stablehlo.convert %308 : (tensor<3xi32>) -> tensor<3xf32>
    %310 = stablehlo.convert %303 : tensor<3xf32>
    %311 = stablehlo.dot_general %309, %310, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %312 = stablehlo.slice %301 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
    %313 = stablehlo.convert %312 : (tensor<3xi32>) -> tensor<3xf32>
    %314 = stablehlo.convert %303 : tensor<3xf32>
    %315 = stablehlo.dot_general %313, %314, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
    %316 = stablehlo.broadcast_in_dim %307, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %317 = stablehlo.broadcast_in_dim %311, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %318 = stablehlo.broadcast_in_dim %315, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %319 = stablehlo.concatenate %316, %317, %318, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
    %320 = stablehlo.broadcast_in_dim %19, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %321 = stablehlo.broadcast_in_dim %39, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %322 = stablehlo.broadcast_in_dim %59, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %323 = stablehlo.broadcast_in_dim %79, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %324 = stablehlo.broadcast_in_dim %99, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %325 = stablehlo.broadcast_in_dim %119, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %326 = stablehlo.broadcast_in_dim %139, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %327 = stablehlo.broadcast_in_dim %159, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %328 = stablehlo.broadcast_in_dim %179, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %329 = stablehlo.broadcast_in_dim %199, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %330 = stablehlo.broadcast_in_dim %219, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %331 = stablehlo.broadcast_in_dim %239, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %332 = stablehlo.broadcast_in_dim %259, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %333 = stablehlo.broadcast_in_dim %279, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %334 = stablehlo.broadcast_in_dim %299, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %335 = stablehlo.broadcast_in_dim %319, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %336 = stablehlo.concatenate %320, %321, %322, %323, %324, %325, %326, %327, %328, %329, %330, %331, %332, %333, %334, %335, dim = 0 : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<16x3xf32>
    return %336 : tensor<16x3xf32>
  }
}
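
The contrast with the vmapped version above is the whole point: vmap traces a handful of batched ops, while the hand-written loop is unrolled into sixteen nearly identical blocks before XLA ever sees it. A crude way to quantify the difference (a sketch; line counts of the pretty-printed modules are only a proxy):

vmap_ir = str(lowered_batch_conv.compiler_ir())
manual_ir = str(lowered_manual_batch_conv.compiler_ir())
print(len(vmap_ir.splitlines()), len(manual_ir.splitlines()))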



train_images = np.array(mnist_dataset.data).reshape(len(mnist_dataset.data), -1)
train_labels = one_hot(np.array(mnist_dataset.targets), n_targets)
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets)
import time

first_time = True
print(num_epochs)
for epoch in range(num_epochs):
    start_time = time.time()
    for x, y in training_generator:
        y = one_hot(y, n_targets)
        # compiled_ = None
        if first_time:
            first_time = False
            saved_ = jit(update, device=jax.devices()[0]).lower(params, x, y)
            # Print the lowered IR as text
            print(saved_.as_text())
            
            # Get the StableHLO representation of the lowered IR 
            print(saved_.compiler_ir(dialect="stablehlo"))
            compiled_ = saved_.compile()
        
        # compiled_(params, x, y)
        params = update(params, x, y)
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
10
module @jit_update attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<512x784xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<512xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<512x512xf32> {mhlo.layout_mode = "default"}, %arg3: tensor<512xf32> {mhlo.layout_mode = "default"}, %arg4: tensor<10x512xf32> {mhlo.layout_mode = "default"}, %arg5: tensor<10xf32> {mhlo.layout_mode = "default"}, %arg6: tensor<128x784xf32> {mhlo.layout_mode = "default"}, %arg7: tensor<128x10xf32> {mhlo.layout_mode = "default"}) -> (tensor<512x784xf32> {jax.result_info = "[0][0]", mhlo.layout_mode = "default"}, tensor<512xf32> {jax.result_info = "[0][1]", mhlo.layout_mode = "default"}, tensor<512x512xf32> {jax.result_info = "[1][0]", mhlo.layout_mode = "default"}, tensor<512xf32> {jax.result_info = "[1][1]", mhlo.layout_mode = "default"}, tensor<10x512xf32> {jax.result_info = "[2][0]", mhlo.layout_mode = "default"}, tensor<10xf32> {jax.result_info = "[2][1]", mhlo.layout_mode = "default"}) {
    %0:6 = call @update(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (tensor<512x784xf32>, tensor<512xf32>, tensor<512x512xf32>, tensor<512xf32>, tensor<10x512xf32>, tensor<10xf32>, tensor<128x784xf32>, tensor<128x10xf32>) -> (tensor<512x784xf32>, tensor<512xf32>, tensor<512x512xf32>, tensor<512xf32>, tensor<10x512xf32>, tensor<10xf32>)
    return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<512x784xf32>, tensor<512xf32>, tensor<512x512xf32>, tensor<512xf32>, tensor<10x512xf32>, tensor<10xf32>
  }
  func.func private @update(%arg0: tensor<512x784xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<512xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<512x512xf32> {mhlo.layout_mode = "default"}, %arg3: tensor<512xf32> {mhlo.layout_mode = "default"}, %arg4: tensor<10x512xf32> {mhlo.layout_mode = "default"}, %arg5: tensor<10xf32> {mhlo.layout_mode = "default"}, %arg6: tensor<128x784xf32> {mhlo.layout_mode = "default"}, %arg7: tensor<128x10xf32> {mhlo.layout_mode = "default"}) -> (tensor<512x784xf32> {mhlo.layout_mode = "default"}, tensor<512xf32> {mhlo.layout_mode = "default"}, tensor<512x512xf32> {mhlo.layout_mode = "default"}, tensor<512xf32> {mhlo.layout_mode = "default"}, tensor<10x512xf32> {mhlo.layout_mode = "default"}, tensor<10xf32> {mhlo.layout_mode = "default"}) {
    %0 = stablehlo.dot_general %arg0, %arg6, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<512x784xf32>, tensor<128x784xf32>) -> tensor<512x128xf32>
    %1 = stablehlo.transpose %0, dims = [1, 0] : (tensor<512x128xf32>) -> tensor<128x512xf32>
    %2 = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<512xf32>) -> tensor<1x512xf32>
    %3 = stablehlo.broadcast_in_dim %2, dims = [0, 1] : (tensor<1x512xf32>) -> tensor<128x512xf32>
    %4 = stablehlo.add %1, %3 : tensor<128x512xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
    %6 = stablehlo.maximum %5, %4 : tensor<128x512xf32>
    %7 = stablehlo.compare  EQ, %4, %6,  FLOAT : (tensor<128x512xf32>, tensor<128x512xf32>) -> tensor<128x512xi1>
    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %8 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %9 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
    %10 = stablehlo.select %7, %8, %9 : tensor<128x512xi1>, tensor<128x512xf32>
    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %11 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
    %12 = stablehlo.compare  EQ, %11, %6,  FLOAT : (tensor<128x512xf32>, tensor<128x512xf32>) -> tensor<128x512xi1>
    %cst_3 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
    %13 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
    %cst_4 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %14 = stablehlo.broadcast_in_dim %cst_4, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
    %15 = stablehlo.select %12, %13, %14 : tensor<128x512xi1>, tensor<128x512xf32>
    %16 = stablehlo.divide %10, %15 : tensor<128x512xf32>
    %17 = stablehlo.dot_general %arg2, %6, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<512x512xf32>, tensor<128x512xf32>) -> tensor<512x128xf32>
    %18 = stablehlo.transpose %17, dims = [1, 0] : (tensor<512x128xf32>) -> tensor<128x512xf32>
    %19 = stablehlo.broadcast_in_dim %arg3, dims = [1] : (tensor<512xf32>) -> tensor<1x512xf32>
    %20 = stablehlo.broadcast_in_dim %19, dims = [0, 1] : (tensor<1x512xf32>) -> tensor<128x512xf32>
    %21 = stablehlo.add %18, %20 : tensor<128x512xf32>
    %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %22 = stablehlo.broadcast_in_dim %cst_5, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
    %23 = stablehlo.maximum %22, %21 : tensor<128x512xf32>
    %24 = stablehlo.compare  EQ, %21, %23,  FLOAT : (tensor<128x512xf32>, tensor<128x512xf32>) -> tensor<128x512xi1>
    %cst_6 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %25 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
    %cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %26 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
    %27 = stablehlo.select %24, %25, %26 : tensor<128x512xi1>, tensor<128x512xf32>
    %cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %28 = stablehlo.broadcast_in_dim %cst_8, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
    %29 = stablehlo.compare  EQ, %28, %23,  FLOAT : (tensor<128x512xf32>, tensor<128x512xf32>) -> tensor<128x512xi1>
    %cst_9 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
    %30 = stablehlo.broadcast_in_dim %cst_9, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
    %cst_10 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %31 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
    %32 = stablehlo.select %29, %30, %31 : tensor<128x512xi1>, tensor<128x512xf32>
    %33 = stablehlo.divide %27, %32 : tensor<128x512xf32>
    %34 = stablehlo.dot_general %arg4, %23, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<10x512xf32>, tensor<128x512xf32>) -> tensor<10x128xf32>
    %35 = stablehlo.transpose %34, dims = [1, 0] : (tensor<10x128xf32>) -> tensor<128x10xf32>
    %36 = stablehlo.broadcast_in_dim %arg5, dims = [1] : (tensor<10xf32>) -> tensor<1x10xf32>
    %37 = stablehlo.broadcast_in_dim %36, dims = [0, 1] : (tensor<1x10xf32>) -> tensor<128x10xf32>
    %38 = stablehlo.add %35, %37 : tensor<128x10xf32>
    %cst_11 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %39 = stablehlo.reduce(%38 init: %cst_11) applies stablehlo.maximum across dimensions = [1] : (tensor<128x10xf32>, tensor<f32>) -> tensor<128xf32>
    %cst_12 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %40 = stablehlo.broadcast_in_dim %cst_12, dims = [] : (tensor<f32>) -> tensor<128xf32>
    %41 = stablehlo.maximum %40, %39 : tensor<128xf32>
    %42 = stablehlo.is_finite %41 : (tensor<128xf32>) -> tensor<128xi1>
    %cst_13 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %43 = stablehlo.broadcast_in_dim %cst_13, dims = [] : (tensor<f32>) -> tensor<128xf32>
    %44 = stablehlo.select %42, %41, %43 : tensor<128xi1>, tensor<128xf32>
    %45 = stablehlo.broadcast_in_dim %44, dims = [0] : (tensor<128xf32>) -> tensor<128x1xf32>
    %46 = stablehlo.broadcast_in_dim %45, dims = [0, 1] : (tensor<128x1xf32>) -> tensor<128x10xf32>
    %47 = stablehlo.subtract %38, %46 : tensor<128x10xf32>
    %48 = stablehlo.exponential %47 : tensor<128x10xf32>
    %cst_14 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %49 = stablehlo.reduce(%48 init: %cst_14) applies stablehlo.add across dimensions = [1] : (tensor<128x10xf32>, tensor<f32>) -> tensor<128xf32>
    %50 = stablehlo.abs %49 : tensor<128xf32>
    %cst_15 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %51 = stablehlo.broadcast_in_dim %cst_15, dims = [] : (tensor<f32>) -> tensor<128xf32>
    %52 = stablehlo.compare  GE, %49, %51,  FLOAT : (tensor<128xf32>, tensor<128xf32>) -> tensor<128xi1>
    %cst_16 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %53 = stablehlo.negate %cst_16 : tensor<f32>
    %cst_17 = stablehlo.constant dense<1.280000e+03> : tensor<f32>
    %54 = stablehlo.divide %53, %cst_17 : tensor<f32>
    %55 = stablehlo.broadcast_in_dim %54, dims = [] : (tensor<f32>) -> tensor<128x10xf32>
    %56 = stablehlo.multiply %55, %arg7 : tensor<128x10xf32>
    %57 = stablehlo.negate %56 : tensor<128x10xf32>
    %cst_18 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %58 = stablehlo.reduce(%57 init: %cst_18) applies stablehlo.add across dimensions = [1] : (tensor<128x10xf32>, tensor<f32>) -> tensor<128xf32>
    %59 = stablehlo.reshape %58 : (tensor<128xf32>) -> tensor<128x1xf32>
    %cst_19 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %60 = stablehlo.reduce(%59 init: %cst_19) applies stablehlo.add across dimensions = [1] : (tensor<128x1xf32>, tensor<f32>) -> tensor<128xf32>
    %61 = stablehlo.divide %60, %50 : tensor<128xf32>
    %cst_20 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %62 = stablehlo.broadcast_in_dim %cst_20, dims = [] : (tensor<f32>) -> tensor<128xf32>
    %63 = stablehlo.select %52, %62, %61 : tensor<128xi1>, tensor<128xf32>
    %64 = stablehlo.select %52, %61, %62 : tensor<128xi1>, tensor<128xf32>
    %65 = stablehlo.negate %63 : tensor<128xf32>
    %66 = stablehlo.add %64, %65 : tensor<128xf32>
    %67 = stablehlo.broadcast_in_dim %66, dims = [0] : (tensor<128xf32>) -> tensor<128x10xf32>
    %68 = stablehlo.multiply %67, %48 : tensor<128x10xf32>
    %69 = stablehlo.add %56, %68 : tensor<128x10xf32>
    %cst_21 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %70 = stablehlo.reduce(%69 init: %cst_21) applies stablehlo.add across dimensions = [0] : (tensor<128x10xf32>, tensor<f32>) -> tensor<10xf32>
    %71 = stablehlo.reshape %70 : (tensor<10xf32>) -> tensor<1x10xf32>
    %cst_22 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %72 = stablehlo.reduce(%71 init: %cst_22) applies stablehlo.add across dimensions = [0] : (tensor<1x10xf32>, tensor<f32>) -> tensor<10xf32>
    %73 = stablehlo.transpose %69, dims = [1, 0] : (tensor<128x10xf32>) -> tensor<10x128xf32>
    %74 = stablehlo.dot_general %73, %arg4, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x128xf32>, tensor<10x512xf32>) -> tensor<128x512xf32>
    %75 = stablehlo.dot_general %73, %23, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x128xf32>, tensor<128x512xf32>) -> tensor<10x512xf32>
    %76 = stablehlo.multiply %74, %33 : tensor<128x512xf32>
    %cst_23 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %77 = stablehlo.reduce(%76 init: %cst_23) applies stablehlo.add across dimensions = [0] : (tensor<128x512xf32>, tensor<f32>) -> tensor<512xf32>
    %78 = stablehlo.reshape %77 : (tensor<512xf32>) -> tensor<1x512xf32>
    %cst_24 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %79 = stablehlo.reduce(%78 init: %cst_24) applies stablehlo.add across dimensions = [0] : (tensor<1x512xf32>, tensor<f32>) -> tensor<512xf32>
    %80 = stablehlo.transpose %76, dims = [1, 0] : (tensor<128x512xf32>) -> tensor<512x128xf32>
    %81 = stablehlo.dot_general %80, %arg2, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<512x128xf32>, tensor<512x512xf32>) -> tensor<128x512xf32>
    %82 = stablehlo.dot_general %80, %6, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<512x128xf32>, tensor<128x512xf32>) -> tensor<512x512xf32>
    %83 = stablehlo.multiply %81, %16 : tensor<128x512xf32>
    %cst_25 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %84 = stablehlo.reduce(%83 init: %cst_25) applies stablehlo.add across dimensions = [0] : (tensor<128x512xf32>, tensor<f32>) -> tensor<512xf32>
    %85 = stablehlo.reshape %84 : (tensor<512xf32>) -> tensor<1x512xf32>
    %cst_26 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %86 = stablehlo.reduce(%85 init: %cst_26) applies stablehlo.add across dimensions = [0] : (tensor<1x512xf32>, tensor<f32>) -> tensor<512xf32>
    %87 = stablehlo.transpose %83, dims = [1, 0] : (tensor<128x512xf32>) -> tensor<512x128xf32>
    %88 = stablehlo.dot_general %87, %arg6, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<512x128xf32>, tensor<128x784xf32>) -> tensor<512x784xf32>
    %cst_27 = stablehlo.constant dense<0.00999999977> : tensor<f32>
    %89 = stablehlo.broadcast_in_dim %cst_27, dims = [] : (tensor<f32>) -> tensor<512x784xf32>
    %90 = stablehlo.multiply %89, %88 : tensor<512x784xf32>
    %91 = stablehlo.subtract %arg0, %90 : tensor<512x784xf32>
    %cst_28 = stablehlo.constant dense<0.00999999977> : tensor<f32>
    %92 = stablehlo.broadcast_in_dim %cst_28, dims = [] : (tensor<f32>) -> tensor<512xf32>
    %93 = stablehlo.multiply %92, %86 : tensor<512xf32>
    %94 = stablehlo.subtract %arg1, %93 : tensor<512xf32>
    %cst_29 = stablehlo.constant dense<0.00999999977> : tensor<f32>
    %95 = stablehlo.broadcast_in_dim %cst_29, dims = [] : (tensor<f32>) -> tensor<512x512xf32>
    %96 = stablehlo.multiply %95, %82 : tensor<512x512xf32>
    %97 = stablehlo.subtract %arg2, %96 : tensor<512x512xf32>
    %cst_30 = stablehlo.constant dense<0.00999999977> : tensor<f32>
    %98 = stablehlo.broadcast_in_dim %cst_30, dims = [] : (tensor<f32>) -> tensor<512xf32>
    %99 = stablehlo.multiply %98, %79 : tensor<512xf32>
    %100 = stablehlo.subtract %arg3, %99 : tensor<512xf32>
    %cst_31 = stablehlo.constant dense<0.00999999977> : tensor<f32>
    %101 = stablehlo.broadcast_in_dim %cst_31, dims = [] : (tensor<f32>) -> tensor<10x512xf32>
    %102 = stablehlo.multiply %101, %75 : tensor<10x512xf32>
    %103 = stablehlo.subtract %arg4, %102 : tensor<10x512xf32>
    %cst_32 = stablehlo.constant dense<0.00999999977> : tensor<f32>
    %104 = stablehlo.broadcast_in_dim %cst_32, dims = [] : (tensor<f32>) -> tensor<10xf32>
    %105 = stablehlo.multiply %104, %72 : tensor<10xf32>
    %106 = stablehlo.subtract %arg5, %105 : tensor<10xf32>
    return %91, %94, %97, %100, %103, %106 : tensor<512x784xf32>, tensor<512xf32>, tensor<512x512xf32>, tensor<512xf32>, tensor<10x512xf32>, tensor<10xf32>
  }
}

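For reference, StableHLO module text like the dump above can be produced with JAX's ahead-of-time lowering API. Below is a minimal, self-contained sketch; the toy affine function and its inputs are placeholders chosen only for illustration, not the MLP update that was lowered above:

import jax
import jax.numpy as jnp

def affine(w, b, x):
    # Toy function standing in for a real training step.
    return jnp.dot(x, w) + b

w = jnp.ones((4, 3))
b = jnp.zeros((3,))
x = jnp.ones((2, 4))

# .lower() stages the computation out without executing it; .as_text()
# then prints the StableHLO module, in the same format as the dump above.
lowered = jax.jit(affine).lower(w, b, x)
print(lowered.as_text())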

Epoch 0 in 11.98 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 1 in 2.93 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 2 in 2.82 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 3 in 2.72 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 4 in 2.73 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 5 in 2.77 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 6 in 2.80 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 7 in 2.73 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 8 in 2.74 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 9 in 2.79 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104


from jax import grad, jit, vmap

def my_predict(params, inputs):
    # Forward pass through a stack of (W, b) layers with ReLU activations;
    # the final layer's pre-activation outputs are returned as predictions.
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.maximum(outputs, 0)
    return outputs

def loss(params, batch):
    # Sum-of-squared-errors loss over an (inputs, targets) batch.
    inputs, targets = batch
    preds = my_predict(params, inputs)
    return jnp.sum((preds - targets) ** 2)

# Compiled gradient of the loss over the whole batch.
grad_func = jit(grad(loss))

# vmap over axis 0 of the batch (params are not mapped), so the compiled
# function returns one gradient pytree per example.
per_example_grads = jit(vmap(grad(loss), in_axes=(None, 0)))
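
As a quick check of the shapes involved, here is a minimal usage sketch; the two-layer sizes, the batch size of 32, and the random data below are made up purely for illustration and are not the MNIST setup used earlier:

from jax import random

key = random.PRNGKey(0)
k1, k2, k3, kx = random.split(key, 4)
# Toy two-layer parameters and a random batch (shapes are illustrative only).
params = [
    (random.normal(k1, (784, 512)) * 0.01, jnp.zeros(512)),
    (random.normal(k2, (512, 10)) * 0.01, jnp.zeros(10)),
]
inputs = random.normal(kx, (32, 784))
targets = random.normal(k3, (32, 10))

# Gradient of the summed loss over the whole batch: same pytree structure as params.
grads = grad_func(params, (inputs, targets))
print(grads[0][0].shape)     # (784, 512)

# Per-example gradients: each leaf gains a leading batch axis of size 32.
ex_grads = per_example_grads(params, (inputs, targets))
print(ex_grads[0][0].shape)  # (32, 784, 512)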

My Podcast!

If you like topics such as this, please consider subscribing to my podcast, where I talk to some of the stalwarts in tech and ask them about their favorite productivity hacks:

Available on iTunes Podcast

Visit Void Star Podcast’s page on the iTunes Podcast Portal. Please click ‘Subscribe’ and leave a comment.

Get it on iTunes