Hello Folks,
I am back after a break.
Below are the raw notebook cells, copied from the Jupyter notebook in which I ran the experiments, together with their outputs.
!pip install -qqq jaxtyping hypothesis pytest penzai
import jax.numpy as np
import numpy as onp
from penzai import pz
# Short aliases for penzai's named-axis helpers (`pz.nx`). `where` lifts
# `np.where` through `nmap` so it can operate on named arrays; note that at
# this point `np` is bound to `jax.numpy` (see the import above), so this is
# `jax.numpy.where`. None of these three aliases is used in the cells visible
# here.
arange = pz.nx.arange
where = pz.nx.nmap(np.where)
wrap = pz.nx.wrap
# NOTE(review): presumably installs treescope as the default notebook
# pretty-printer -- confirm against the penzai/treescope documentation.
pz.ts.register_as_default()
# Optional automatic array visualization extras:
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()
# Configure the interactive array autovisualizer. Judging by the argument
# names (verify in the penzai docs): continuous colour scale centred on zero,
# axis "j" preferred along columns and axis "i" along rows.
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer(force_continuous=True, around_zero=True, prefers_column=["j"], prefers_row=["i"]))
import os
# GPU flags
# XLA GPU tuning options, passed to XLA through the XLA_FLAGS env variable.
# NOTE(review): XLA reads XLA_FLAGS when its backend is first initialised, so
# this must run before the first JAX computation; `jax` has already been
# imported above (via `jax.numpy` / penzai) -- confirm nothing has executed
# on the device yet. XLA also rejects unknown flags, and some of these may
# not exist in newer jaxlib releases -- verify against the installed version.
flags = " ".join([
    "--xla_gpu_enable_triton_softmax_fusion=true",
    "--xla_gpu_triton_gemm_any=false",
    "--xla_gpu_enable_async_collectives=true",
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_highest_priority_async_stream=true",
]) + " "  # trailing space kept so the resulting string is unchanged
os.environ["XLA_FLAGS"] = flags
!nvidia-smi
Mon May 27 13:31:30 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.29.06 Driver Version: 545.29.06 CUDA Version: 12.3 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 4090 Off | 00000000:09:00.0 On | Off |
| 0% 47C P2 58W / 450W | 19106MiB / 24564MiB | 1% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 1195 G /usr/lib/xorg/Xorg 259MiB |
| 0 N/A N/A 2130 G /usr/lib/firefox/firefox 0MiB |
| 0 N/A N/A 477865 C /home/mycpuorg/miniconda3/bin/python3.11 18404MiB |
+---------------------------------------------------------------------------------------+
/home/mycpuorg/miniconda3/lib/python3.11/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
pid, fd = os.forkpty()
import jax
import jax.numpy as jnp
# Define a simple function - test
def test(x):
    """Return ``sin(x) * x**2``.

    A deliberately tiny function used in the following cells to inspect each
    stage of JAX's compilation pipeline: the jaxpr, the lowered StableHLO and
    the compiled HLO. The expression is kept exactly as written because the
    printed IR mirrors its operations (``sin``, ``integer_pow[y=2]``, ``mul``).
    """
    return jnp.sin(x) * x**2
# Pretty-print the function as a Jaxpr
print(jax.make_jaxpr(test)(1.0))
{ lambda ; a:f32[]. let
b:f32[] = sin a
c:f32[] = integer_pow[y=2] a
d:f32[] = mul b c
in (d,) }
# Lower the function
lowered = jax.jit(test).lower(1.0)
# Print the lowered IR as text
print(lowered.as_text())
# Get the StableHLO representation of the lowered IR
print(lowered.compiler_ir(dialect="stablehlo"))
module @jit_test attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = stablehlo.sine %arg0 : tensor<f32>
%1 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
%2 = stablehlo.multiply %0, %1 : tensor<f32>
return %2 : tensor<f32>
}
}
module @jit_test attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = stablehlo.sine %arg0 : tensor<f32>
%1 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
%2 = stablehlo.multiply %0, %1 : tensor<f32>
return %2 : tensor<f32>
}
}
# Compile the function
compiled = lowered.compile()
# Print the compiled IR as text
print(compiled.as_text())
# Get a summary of execution costs
print(compiled.cost_analysis())
HloModule jit_test, is_scheduled=true, entry_computation_layout={(f32[])->f32[]}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="9aa3c44fdb740a754eb15d47d89c333e"}
%fused_multiply (param_0.2: f32[]) -> f32[] {
%param_0.2 = f32[] parameter(0)
%sine.1.1 = f32[] sine(f32[] %param_0.2), metadata={op_name="jit(test)/jit(main)/sin" source_file="/tmp/ipykernel_477865/1530176408.py" source_line=6}
%multiply.2.1 = f32[] multiply(f32[] %param_0.2, f32[] %param_0.2), metadata={op_name="jit(test)/jit(main)/integer_pow[y=2]" source_file="/tmp/ipykernel_477865/1530176408.py" source_line=6}
ROOT %multiply.5.1 = f32[] multiply(f32[] %sine.1.1, f32[] %multiply.2.1), metadata={op_name="jit(test)/jit(main)/mul" source_file="/tmp/ipykernel_477865/1530176408.py" source_line=6}
}
ENTRY %main.5 (Arg_0.1.0: f32[]) -> f32[] {
%Arg_0.1.0 = f32[] parameter(0)
ROOT %loop_multiply_fusion = f32[] fusion(f32[] %Arg_0.1.0), kind=kLoop, calls=%fused_multiply, metadata={op_name="jit(test)/jit(main)/mul" source_file="/tmp/ipykernel_477865/1530176408.py" source_line=6}
}
[{'utilization0{}': 1.0, 'bytes accessed1{}': 8.0, 'bytes accessedout{}': 4.0, 'bytes accessed': 8.0, 'transcendentals': 1.0, 'bytes accessed0{}': 4.0, 'utilization1{}': 2.0, 'flops': 2.0}]
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
# A helper function to randomly initialize weights and biases
# for dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
    """Randomly initialise the weights and biases of one dense layer.

    Args:
      m: number of inputs to the layer.
      n: number of outputs of the layer.
      key: JAX PRNG key; split into one sub-key for the weights and one for
        the biases.
      scale: multiplier applied to the standard-normal samples (i.e. their
        resulting standard deviation).

    Returns:
      A ``(weights, biases)`` pair with shapes ``(n, m)`` and ``(n,)``.
    """
    weight_key, bias_key = random.split(key)
    weights = random.normal(weight_key, (n, m))
    biases = random.normal(bias_key, (n,))
    return scale * weights, scale * biases
# Intialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    """Initialise every layer of a fully-connected network.

    Args:
      sizes: layer widths, input first (e.g. ``[784, 512, 512, 10]``).
      key: JAX PRNG key; split into ``len(sizes)`` sub-keys, of which the
        first ``len(sizes) - 1`` are consumed (one per layer).

    Returns:
      A list of ``(weights, biases)`` tuples, one per consecutive pair of
      sizes, as produced by ``random_layer_params``.
    """
    layer_keys = random.split(key, len(sizes))
    fan_pairs = zip(sizes[:-1], sizes[1:])
    params = []
    for (fan_in, fan_out), layer_key in zip(fan_pairs, layer_keys):
        params.append(random_layer_params(fan_in, fan_out, layer_key))
    return params
# Network shape and training hyper-parameters.
layer_sizes = [784, 512, 512, 10]  # 28*28 input pixels, two hidden layers, 10 outputs
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = layer_sizes[-1]  # number of classes, kept in sync with the output layer
params = init_network_params(layer_sizes, random.key(0))
from jax.scipy.special import logsumexp
def relu(x):
    """Rectified linear unit: element-wise ``max(x, 0)``."""
    return jnp.maximum(x, 0)
def predict(params, image):
    """Compute log-probabilities for a single, unbatched example.

    Args:
      params: list of ``(w, b)`` layer parameters, as built by
        ``init_network_params``.
      image: one flattened input vector (e.g. 28*28 = 784 pixels).

    Returns:
      The final layer's logits normalised into log-probabilities.
    """
    hidden_layers = params[:-1]
    final_w, final_b = params[-1]

    activations = image
    for w, b in hidden_layers:
        activations = relu(jnp.dot(w, activations) + b)

    logits = jnp.dot(final_w, activations) + final_b
    # Subtracting logsumexp turns raw logits into log-softmax values.
    return logits - logsumexp(logits)
random_flattened_image = random.normal(random.key(1), (28*28,))
print(random_flattened_image.shape)
print(len(params))
preds = predict(params, random_flattened_image)
print(preds.shape)
(784,)
3
(10,)
batched_random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
preds = predict(params, batched_random_flattened_images)
except TypeError:
print("Invalid Shapes!")
Invalid Shapes!
batched_predict = vmap(predict, in_axes=(None, 0))
batched_preds = batched_predict(params, batched_random_flattened_images)
print(batched_preds.shape)
(10, 10)
Loss function and other utilities
def one_hot(x, k, dtype=jnp.float32):
    """Encode integer class labels as one-hot vectors.

    Args:
      x: integer class labels; any array-like of any rank (a 1-D batch of
        labels is the common case).
      k: number of classes. Labels outside ``[0, k)`` encode as all zeros.
      dtype: dtype of the returned array.

    Returns:
      Array of shape ``x.shape + (k,)`` with a 1 at each label's index.
    """
    labels = jnp.asarray(x)
    # Compare along a new trailing axis so labels of any rank broadcast
    # against the class indices; the previous ``x[:, None]`` only handled
    # 1-D input (and failed outright for Python lists).
    return jnp.array(labels[..., None] == jnp.arange(k), dtype)
def accuracy(params, images, targets):
    """Return the fraction of examples whose predicted class is correct.

    The true class of each example is the argmax of ``targets`` along axis 1,
    so ``targets`` is expected to be one-hot encoded (presumably via
    ``one_hot``); the predicted class is the argmax of ``batched_predict``.
    """
    predictions = batched_predict(params, images)
    hits = jnp.argmax(predictions, axis=1) == jnp.argmax(targets, axis=1)
    return jnp.mean(hits)
def loss(params, images, targets):
    """Negative mean of the log-probabilities weighted by ``targets``.

    The mean runs over every element of ``log_probs * targets``, not just the
    batch axis. Assuming ``targets`` is one-hot with shape (batch, classes),
    this equals the mean cross-entropy divided by the number of classes.
    """
    log_probs = batched_predict(params, images)
    weighted = log_probs * targets
    return -jnp.mean(weighted)
@jit
def update(params, x, y):
    """Take one plain gradient-descent step on ``loss``.

    Returns new parameters with the same structure as ``params`` (a list of
    ``(w, b)`` tuples), each leaf moved by ``-step_size * gradient``.

    NOTE: ``step_size`` is a module-level global captured when ``update`` is
    traced; reassigning it afterwards does not affect an already-compiled
    version of this function.
    """
    grads = grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)
import numpy as np
from jax.tree_util import tree_map
from torchvision.datasets import MNIST
from torch.utils import data
/home/mycpuorg/miniconda3/lib/python3.11/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/mycpuorg/miniconda3/lib/python3.11/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
warn(
/home/mycpuorg/miniconda3/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
def numpy_collate(batch):
    """Collate samples with PyTorch's default collate, then convert to NumPy.

    Every leaf of the collated structure (tensors produced by
    ``default_collate``) is converted with ``np.asarray``.
    """
    collated = data.default_collate(batch)
    return tree_map(np.asarray, collated)
class NumpyLoader(data.DataLoader):
    """A ``torch.utils.data.DataLoader`` whose batches are NumPy arrays.

    Identical to ``DataLoader`` except that ``collate_fn`` is fixed to
    ``numpy_collate``, so batches can be fed straight into JAX functions.
    """

    def __init__(self, dataset, batch_size=1,
                 shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0,
                 pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        # Zero-argument super() always binds to NumpyLoader. The previous
        # super(self.__class__, self) resolves to the *runtime* class, which
        # recurses forever as soon as NumpyLoader is subclassed.
        super().__init__(dataset,
                         batch_size=batch_size,
                         shuffle=shuffle,
                         sampler=sampler,
                         batch_sampler=batch_sampler,
                         num_workers=num_workers,
                         collate_fn=numpy_collate,
                         pin_memory=pin_memory,
                         drop_last=drop_last,
                         timeout=timeout,
                         worker_init_fn=worker_init_fn)
class FlattenAndCast:
    """Dataset transform turning an image into a flat float32 NumPy vector."""

    def __call__(self, pic):
        as_float = np.array(pic, dtype=jnp.float32)
        return as_float.ravel()
# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)
a = np.array(np.random.randn(3,2,4))
a.reshape(-1, 24)
b = jnp.array(a)
print(b.devices())
b.sharding
x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
def convolve(x, w):
    """'Valid' 1-D sliding dot product of ``x`` with a length-3 kernel ``w``.

    Output element ``j`` is ``dot(x[j:j+3], w)`` (no padding, kernel not
    flipped), so the result has ``len(x) - 2`` entries, or none when
    ``len(x) < 3``.
    """
    windows = (x[centre - 1 : centre + 2] for centre in range(1, len(x) - 1))
    return jnp.array([jnp.dot(window, w) for window in windows])
lowered_conv = jax.jit(convolve).lower(x, w)
print(lowered_conv.compiler_ir(dialect="stablehlo"))
# print(lowered_conv.compiler_ir(dialect="mhlo"))
# print(lowered_conv.compiler_ir(dialect="LMHLO"))
compiled_conv = lowered_conv.compile()
print(compiled_conv.as_text())
ca = compiled_conv.cost_analysis()
ca
{cuda(id=0)}
module @jit_convolve attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<5xi32> {mhlo.layout_mode = "default"}, %arg1: tensor<3xf32> {mhlo.layout_mode = "default"}) -> (tensor<3xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = stablehlo.slice %arg0 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%1 = stablehlo.convert %0 : (tensor<3xi32>) -> tensor<3xf32>
%2 = stablehlo.convert %arg1 : tensor<3xf32>
%3 = stablehlo.dot_general %1, %2, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%4 = stablehlo.slice %arg0 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%5 = stablehlo.convert %4 : (tensor<3xi32>) -> tensor<3xf32>
%6 = stablehlo.convert %arg1 : tensor<3xf32>
%7 = stablehlo.dot_general %5, %6, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%8 = stablehlo.slice %arg0 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%9 = stablehlo.convert %8 : (tensor<3xi32>) -> tensor<3xf32>
%10 = stablehlo.convert %arg1 : tensor<3xf32>
%11 = stablehlo.dot_general %9, %10, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%12 = stablehlo.broadcast_in_dim %3, dims = [] : (tensor<f32>) -> tensor<1xf32>
%13 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor<f32>) -> tensor<1xf32>
%14 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor<f32>) -> tensor<1xf32>
%15 = stablehlo.concatenate %12, %13, %14, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
return %15 : tensor<3xf32>
}
}
HloModule jit_convolve, is_scheduled=true, entry_computation_layout={(s32[5]{0}, f32[3]{0})->f32[3]{0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="f99fa116847aaf794a773015b66e3c7a"}
%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
%scalar_rhs = f32[] parameter(1)
%scalar_lhs = f32[] parameter(0)
ROOT %add.2 = f32[] add(f32[] %scalar_lhs, f32[] %scalar_rhs)
}
%fused_concatenate (param_0.15: f32[3], param_1.19: s32[5]) -> f32[3] {
%param_1.19 = s32[5]{0} parameter(1)
%convert.2.5 = f32[5]{0} convert(s32[5]{0} %param_1.19), metadata={op_name="jit(convolve)/jit(main)/dot_general[dimension_numbers=(((0,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=14}
%slice.8.3 = f32[3]{0} slice(f32[5]{0} %convert.2.5), slice={[0:3]}, metadata={op_name="jit(convolve)/jit(main)/slice[start_indices=(0,) limit_indices=(3,) strides=None]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=14}
%param_0.15 = f32[3]{0} parameter(0)
%multiply.6.3 = f32[3]{0} multiply(f32[3]{0} %slice.8.3, f32[3]{0} %param_0.15)
%constant_7 = f32[] constant(0)
%reduce.4 = f32[] reduce(f32[3]{0} %multiply.6.3, f32[] %constant_7), dimensions={0}, to_apply=%scalar_add_computation, metadata={op_name="jit(convolve)/jit(main)/dot_general[dimension_numbers=(((0,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=14}
%bitcast.63.1 = f32[1]{0} bitcast(f32[] %reduce.4)
%slice.10.3 = f32[3]{0} slice(f32[5]{0} %convert.2.5), slice={[1:4]}, metadata={op_name="jit(convolve)/jit(main)/slice[start_indices=(1,) limit_indices=(4,) strides=None]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=14}
%multiply.7.3 = f32[3]{0} multiply(f32[3]{0} %slice.10.3, f32[3]{0} %param_0.15)
%reduce.1.1 = f32[] reduce(f32[3]{0} %multiply.7.3, f32[] %constant_7), dimensions={0}, to_apply=%scalar_add_computation, metadata={op_name="jit(convolve)/jit(main)/dot_general[dimension_numbers=(((0,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=14}
%bitcast.75.1 = f32[1]{0} bitcast(f32[] %reduce.1.1)
%slice.11.3 = f32[3]{0} slice(f32[5]{0} %convert.2.5), slice={[2:5]}, metadata={op_name="jit(convolve)/jit(main)/slice[start_indices=(2,) limit_indices=(5,) strides=None]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=14}
%multiply.8.3 = f32[3]{0} multiply(f32[3]{0} %slice.11.3, f32[3]{0} %param_0.15)
%reduce.2.1 = f32[] reduce(f32[3]{0} %multiply.8.3, f32[] %constant_7), dimensions={0}, to_apply=%scalar_add_computation, metadata={op_name="jit(convolve)/jit(main)/dot_general[dimension_numbers=(((0,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=14}
%bitcast.87.1 = f32[1]{0} bitcast(f32[] %reduce.2.1)
ROOT %concatenate.1.1 = f32[3]{0} concatenate(f32[1]{0} %bitcast.63.1, f32[1]{0} %bitcast.75.1, f32[1]{0} %bitcast.87.1), dimensions={0}, metadata={op_name="jit(convolve)/jit(main)/concatenate[dimension=0]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=15}
}
ENTRY %main.16 (Arg_0.1.0: s32[5], Arg_1.2.0: f32[3]) -> f32[3] {
%Arg_1.2.0 = f32[3]{0} parameter(1)
%Arg_0.1.0 = s32[5]{0} parameter(0)
ROOT %input_concatenate_fusion = f32[3]{0} fusion(f32[3]{0} %Arg_1.2.0, s32[5]{0} %Arg_0.1.0), kind=kInput, calls=%fused_concatenate, metadata={op_name="jit(convolve)/jit(main)/concatenate[dimension=0]" source_file="/tmp/ipykernel_477865/1956054057.py" source_line=15}
}
(Loading...)
xs = jnp.stack([x, x, x, x, x, x, x, x, x, x, x, x, x, x, x, x])
ws = jnp.stack([w, w, w, w, w, w, w, w, w, w, w, w, w, w, w, w])
def manual_batched_conv(xs, ws):
    """Apply ``convolve`` to each row pair ``(xs[i], ws[i])`` with a Python loop.

    Results are stacked along a new leading axis. Unlike
    ``jax.vmap(convolve)``, tracing this unrolls one ``convolve`` per row, so
    the lowered IR grows linearly with the batch size.
    """
    return jnp.stack([convolve(xs[row], ws[row]) for row in range(xs.shape[0])])
manual_batched_conv(xs, ws)
(Loading...)
auto_batch_conv = jax.vmap(convolve)
auto_batch_conv(xs, ws)
(Loading...)
# The `device=` argument of `jax.jit` is deprecated. Commit the inputs to the
# target device instead; jit then places the computation where its
# arguments live.
target_device = jax.devices()[0]
xs_on_device, ws_on_device = jax.device_put((xs, ws), target_device)
lowered_batch_conv = jax.jit(auto_batch_conv).lower(xs_on_device, ws_on_device)
compiled_auto_batch_conv = lowered_batch_conv.compile()
print(compiled_auto_batch_conv.cost_analysis())
print(lowered_batch_conv.compiler_ir())
[{'bytes accessed': 704.0, 'bytes accessed0{}': 192.0, 'utilization1{}': 1.0, 'utilization2{}': 1.0, 'bytes accessed1{}': 320.0, 'bytes accessedout{}': 192.0, 'utilization0{}': 1.0, 'flops': 320.0, 'bytes accessed2{}': 64.0}]
module @jit_convolve attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<16x5xi32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<16x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<16x3xf32> {jax.result_info = "", mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) {
%0 = stablehlo.slice %arg0 [0:16, 0:3] : (tensor<16x5xi32>) -> tensor<16x3xi32>
%1 = stablehlo.convert %0 : (tensor<16x3xi32>) -> tensor<16x3xf32>
%2 = stablehlo.convert %arg1 : tensor<16x3xf32>
%3 = stablehlo.dot_general %1, %2, batching_dims = [0] x [0], contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<16x3xf32>, tensor<16x3xf32>) -> tensor<16xf32>
%4 = stablehlo.slice %arg0 [0:16, 1:4] : (tensor<16x5xi32>) -> tensor<16x3xi32>
%5 = stablehlo.convert %4 : (tensor<16x3xi32>) -> tensor<16x3xf32>
%6 = stablehlo.convert %arg1 : tensor<16x3xf32>
%7 = stablehlo.dot_general %5, %6, batching_dims = [0] x [0], contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<16x3xf32>, tensor<16x3xf32>) -> tensor<16xf32>
%8 = stablehlo.slice %arg0 [0:16, 2:5] : (tensor<16x5xi32>) -> tensor<16x3xi32>
%9 = stablehlo.convert %8 : (tensor<16x3xi32>) -> tensor<16x3xf32>
%10 = stablehlo.convert %arg1 : tensor<16x3xf32>
%11 = stablehlo.dot_general %9, %10, batching_dims = [0] x [0], contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<16x3xf32>, tensor<16x3xf32>) -> tensor<16xf32>
%12 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<16xf32>) -> tensor<16x1xf32>
%13 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<16xf32>) -> tensor<16x1xf32>
%14 = stablehlo.broadcast_in_dim %11, dims = [0] : (tensor<16xf32>) -> tensor<16x1xf32>
%15 = stablehlo.concatenate %12, %13, %14, dim = 1 : (tensor<16x1xf32>, tensor<16x1xf32>, tensor<16x1xf32>) -> tensor<16x3xf32>
return %15 : tensor<16x3xf32>
}
}
# The `device=` argument of `jax.jit` is deprecated. Commit the inputs to the
# target device instead; jit then places the computation where its
# arguments live.
target_device = jax.devices()[0]
xs_on_device, ws_on_device = jax.device_put((xs, ws), target_device)
lowered_manual_batch_conv = jax.jit(manual_batched_conv).lower(xs_on_device, ws_on_device)
print(lowered_manual_batch_conv.compiler_ir())
module @jit_manual_batched_conv attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<16x5xi32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<16x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<16x3xf32> {jax.result_info = "", mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) {
%0 = stablehlo.slice %arg0 [0:1, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
%1 = stablehlo.reshape %0 : (tensor<1x5xi32>) -> tensor<5xi32>
%2 = stablehlo.slice %arg1 [0:1, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
%3 = stablehlo.reshape %2 : (tensor<1x3xf32>) -> tensor<3xf32>
%4 = stablehlo.slice %1 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%5 = stablehlo.convert %4 : (tensor<3xi32>) -> tensor<3xf32>
%6 = stablehlo.convert %3 : tensor<3xf32>
%7 = stablehlo.dot_general %5, %6, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%8 = stablehlo.slice %1 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%9 = stablehlo.convert %8 : (tensor<3xi32>) -> tensor<3xf32>
%10 = stablehlo.convert %3 : tensor<3xf32>
%11 = stablehlo.dot_general %9, %10, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%12 = stablehlo.slice %1 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%13 = stablehlo.convert %12 : (tensor<3xi32>) -> tensor<3xf32>
%14 = stablehlo.convert %3 : tensor<3xf32>
%15 = stablehlo.dot_general %13, %14, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%16 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor<f32>) -> tensor<1xf32>
%17 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor<f32>) -> tensor<1xf32>
%18 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor<f32>) -> tensor<1xf32>
%19 = stablehlo.concatenate %16, %17, %18, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
%20 = stablehlo.slice %arg0 [1:2, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
%21 = stablehlo.reshape %20 : (tensor<1x5xi32>) -> tensor<5xi32>
%22 = stablehlo.slice %arg1 [1:2, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
%23 = stablehlo.reshape %22 : (tensor<1x3xf32>) -> tensor<3xf32>
%24 = stablehlo.slice %21 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%25 = stablehlo.convert %24 : (tensor<3xi32>) -> tensor<3xf32>
%26 = stablehlo.convert %23 : tensor<3xf32>
%27 = stablehlo.dot_general %25, %26, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%28 = stablehlo.slice %21 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%29 = stablehlo.convert %28 : (tensor<3xi32>) -> tensor<3xf32>
%30 = stablehlo.convert %23 : tensor<3xf32>
%31 = stablehlo.dot_general %29, %30, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%32 = stablehlo.slice %21 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%33 = stablehlo.convert %32 : (tensor<3xi32>) -> tensor<3xf32>
%34 = stablehlo.convert %23 : tensor<3xf32>
%35 = stablehlo.dot_general %33, %34, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%36 = stablehlo.broadcast_in_dim %27, dims = [] : (tensor<f32>) -> tensor<1xf32>
%37 = stablehlo.broadcast_in_dim %31, dims = [] : (tensor<f32>) -> tensor<1xf32>
%38 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor<f32>) -> tensor<1xf32>
%39 = stablehlo.concatenate %36, %37, %38, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
%40 = stablehlo.slice %arg0 [2:3, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
%41 = stablehlo.reshape %40 : (tensor<1x5xi32>) -> tensor<5xi32>
%42 = stablehlo.slice %arg1 [2:3, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
%43 = stablehlo.reshape %42 : (tensor<1x3xf32>) -> tensor<3xf32>
%44 = stablehlo.slice %41 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%45 = stablehlo.convert %44 : (tensor<3xi32>) -> tensor<3xf32>
%46 = stablehlo.convert %43 : tensor<3xf32>
%47 = stablehlo.dot_general %45, %46, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%48 = stablehlo.slice %41 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%49 = stablehlo.convert %48 : (tensor<3xi32>) -> tensor<3xf32>
%50 = stablehlo.convert %43 : tensor<3xf32>
%51 = stablehlo.dot_general %49, %50, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%52 = stablehlo.slice %41 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%53 = stablehlo.convert %52 : (tensor<3xi32>) -> tensor<3xf32>
%54 = stablehlo.convert %43 : tensor<3xf32>
%55 = stablehlo.dot_general %53, %54, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%56 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor<f32>) -> tensor<1xf32>
%57 = stablehlo.broadcast_in_dim %51, dims = [] : (tensor<f32>) -> tensor<1xf32>
%58 = stablehlo.broadcast_in_dim %55, dims = [] : (tensor<f32>) -> tensor<1xf32>
%59 = stablehlo.concatenate %56, %57, %58, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
%60 = stablehlo.slice %arg0 [3:4, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
%61 = stablehlo.reshape %60 : (tensor<1x5xi32>) -> tensor<5xi32>
%62 = stablehlo.slice %arg1 [3:4, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
%63 = stablehlo.reshape %62 : (tensor<1x3xf32>) -> tensor<3xf32>
%64 = stablehlo.slice %61 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%65 = stablehlo.convert %64 : (tensor<3xi32>) -> tensor<3xf32>
%66 = stablehlo.convert %63 : tensor<3xf32>
%67 = stablehlo.dot_general %65, %66, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%68 = stablehlo.slice %61 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%69 = stablehlo.convert %68 : (tensor<3xi32>) -> tensor<3xf32>
%70 = stablehlo.convert %63 : tensor<3xf32>
%71 = stablehlo.dot_general %69, %70, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%72 = stablehlo.slice %61 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%73 = stablehlo.convert %72 : (tensor<3xi32>) -> tensor<3xf32>
%74 = stablehlo.convert %63 : tensor<3xf32>
%75 = stablehlo.dot_general %73, %74, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%76 = stablehlo.broadcast_in_dim %67, dims = [] : (tensor<f32>) -> tensor<1xf32>
%77 = stablehlo.broadcast_in_dim %71, dims = [] : (tensor<f32>) -> tensor<1xf32>
%78 = stablehlo.broadcast_in_dim %75, dims = [] : (tensor<f32>) -> tensor<1xf32>
%79 = stablehlo.concatenate %76, %77, %78, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
%80 = stablehlo.slice %arg0 [4:5, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
%81 = stablehlo.reshape %80 : (tensor<1x5xi32>) -> tensor<5xi32>
%82 = stablehlo.slice %arg1 [4:5, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
%83 = stablehlo.reshape %82 : (tensor<1x3xf32>) -> tensor<3xf32>
%84 = stablehlo.slice %81 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%85 = stablehlo.convert %84 : (tensor<3xi32>) -> tensor<3xf32>
%86 = stablehlo.convert %83 : tensor<3xf32>
%87 = stablehlo.dot_general %85, %86, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%88 = stablehlo.slice %81 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%89 = stablehlo.convert %88 : (tensor<3xi32>) -> tensor<3xf32>
%90 = stablehlo.convert %83 : tensor<3xf32>
%91 = stablehlo.dot_general %89, %90, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%92 = stablehlo.slice %81 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%93 = stablehlo.convert %92 : (tensor<3xi32>) -> tensor<3xf32>
%94 = stablehlo.convert %83 : tensor<3xf32>
%95 = stablehlo.dot_general %93, %94, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%96 = stablehlo.broadcast_in_dim %87, dims = [] : (tensor<f32>) -> tensor<1xf32>
%97 = stablehlo.broadcast_in_dim %91, dims = [] : (tensor<f32>) -> tensor<1xf32>
%98 = stablehlo.broadcast_in_dim %95, dims = [] : (tensor<f32>) -> tensor<1xf32>
%99 = stablehlo.concatenate %96, %97, %98, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
%100 = stablehlo.slice %arg0 [5:6, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
%101 = stablehlo.reshape %100 : (tensor<1x5xi32>) -> tensor<5xi32>
%102 = stablehlo.slice %arg1 [5:6, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
%103 = stablehlo.reshape %102 : (tensor<1x3xf32>) -> tensor<3xf32>
%104 = stablehlo.slice %101 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%105 = stablehlo.convert %104 : (tensor<3xi32>) -> tensor<3xf32>
%106 = stablehlo.convert %103 : tensor<3xf32>
%107 = stablehlo.dot_general %105, %106, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%108 = stablehlo.slice %101 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%109 = stablehlo.convert %108 : (tensor<3xi32>) -> tensor<3xf32>
%110 = stablehlo.convert %103 : tensor<3xf32>
%111 = stablehlo.dot_general %109, %110, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%112 = stablehlo.slice %101 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%113 = stablehlo.convert %112 : (tensor<3xi32>) -> tensor<3xf32>
%114 = stablehlo.convert %103 : tensor<3xf32>
%115 = stablehlo.dot_general %113, %114, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%116 = stablehlo.broadcast_in_dim %107, dims = [] : (tensor<f32>) -> tensor<1xf32>
%117 = stablehlo.broadcast_in_dim %111, dims = [] : (tensor<f32>) -> tensor<1xf32>
%118 = stablehlo.broadcast_in_dim %115, dims = [] : (tensor<f32>) -> tensor<1xf32>
%119 = stablehlo.concatenate %116, %117, %118, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
%120 = stablehlo.slice %arg0 [6:7, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
%121 = stablehlo.reshape %120 : (tensor<1x5xi32>) -> tensor<5xi32>
%122 = stablehlo.slice %arg1 [6:7, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
%123 = stablehlo.reshape %122 : (tensor<1x3xf32>) -> tensor<3xf32>
%124 = stablehlo.slice %121 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%125 = stablehlo.convert %124 : (tensor<3xi32>) -> tensor<3xf32>
%126 = stablehlo.convert %123 : tensor<3xf32>
%127 = stablehlo.dot_general %125, %126, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%128 = stablehlo.slice %121 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%129 = stablehlo.convert %128 : (tensor<3xi32>) -> tensor<3xf32>
%130 = stablehlo.convert %123 : tensor<3xf32>
%131 = stablehlo.dot_general %129, %130, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%132 = stablehlo.slice %121 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%133 = stablehlo.convert %132 : (tensor<3xi32>) -> tensor<3xf32>
%134 = stablehlo.convert %123 : tensor<3xf32>
%135 = stablehlo.dot_general %133, %134, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%136 = stablehlo.broadcast_in_dim %127, dims = [] : (tensor<f32>) -> tensor<1xf32>
%137 = stablehlo.broadcast_in_dim %131, dims = [] : (tensor<f32>) -> tensor<1xf32>
%138 = stablehlo.broadcast_in_dim %135, dims = [] : (tensor<f32>) -> tensor<1xf32>
%139 = stablehlo.concatenate %136, %137, %138, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
%140 = stablehlo.slice %arg0 [7:8, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
%141 = stablehlo.reshape %140 : (tensor<1x5xi32>) -> tensor<5xi32>
%142 = stablehlo.slice %arg1 [7:8, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
%143 = stablehlo.reshape %142 : (tensor<1x3xf32>) -> tensor<3xf32>
%144 = stablehlo.slice %141 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%145 = stablehlo.convert %144 : (tensor<3xi32>) -> tensor<3xf32>
%146 = stablehlo.convert %143 : tensor<3xf32>
%147 = stablehlo.dot_general %145, %146, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%148 = stablehlo.slice %141 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%149 = stablehlo.convert %148 : (tensor<3xi32>) -> tensor<3xf32>
%150 = stablehlo.convert %143 : tensor<3xf32>
%151 = stablehlo.dot_general %149, %150, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%152 = stablehlo.slice %141 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%153 = stablehlo.convert %152 : (tensor<3xi32>) -> tensor<3xf32>
%154 = stablehlo.convert %143 : tensor<3xf32>
%155 = stablehlo.dot_general %153, %154, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%156 = stablehlo.broadcast_in_dim %147, dims = [] : (tensor<f32>) -> tensor<1xf32>
%157 = stablehlo.broadcast_in_dim %151, dims = [] : (tensor<f32>) -> tensor<1xf32>
%158 = stablehlo.broadcast_in_dim %155, dims = [] : (tensor<f32>) -> tensor<1xf32>
%159 = stablehlo.concatenate %156, %157, %158, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
%160 = stablehlo.slice %arg0 [8:9, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
%161 = stablehlo.reshape %160 : (tensor<1x5xi32>) -> tensor<5xi32>
%162 = stablehlo.slice %arg1 [8:9, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
%163 = stablehlo.reshape %162 : (tensor<1x3xf32>) -> tensor<3xf32>
%164 = stablehlo.slice %161 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%165 = stablehlo.convert %164 : (tensor<3xi32>) -> tensor<3xf32>
%166 = stablehlo.convert %163 : tensor<3xf32>
%167 = stablehlo.dot_general %165, %166, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%168 = stablehlo.slice %161 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%169 = stablehlo.convert %168 : (tensor<3xi32>) -> tensor<3xf32>
%170 = stablehlo.convert %163 : tensor<3xf32>
%171 = stablehlo.dot_general %169, %170, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%172 = stablehlo.slice %161 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%173 = stablehlo.convert %172 : (tensor<3xi32>) -> tensor<3xf32>
%174 = stablehlo.convert %163 : tensor<3xf32>
%175 = stablehlo.dot_general %173, %174, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%176 = stablehlo.broadcast_in_dim %167, dims = [] : (tensor<f32>) -> tensor<1xf32>
%177 = stablehlo.broadcast_in_dim %171, dims = [] : (tensor<f32>) -> tensor<1xf32>
%178 = stablehlo.broadcast_in_dim %175, dims = [] : (tensor<f32>) -> tensor<1xf32>
%179 = stablehlo.concatenate %176, %177, %178, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
%180 = stablehlo.slice %arg0 [9:10, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
%181 = stablehlo.reshape %180 : (tensor<1x5xi32>) -> tensor<5xi32>
%182 = stablehlo.slice %arg1 [9:10, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
%183 = stablehlo.reshape %182 : (tensor<1x3xf32>) -> tensor<3xf32>
%184 = stablehlo.slice %181 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%185 = stablehlo.convert %184 : (tensor<3xi32>) -> tensor<3xf32>
%186 = stablehlo.convert %183 : tensor<3xf32>
%187 = stablehlo.dot_general %185, %186, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%188 = stablehlo.slice %181 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%189 = stablehlo.convert %188 : (tensor<3xi32>) -> tensor<3xf32>
%190 = stablehlo.convert %183 : tensor<3xf32>
%191 = stablehlo.dot_general %189, %190, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%192 = stablehlo.slice %181 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%193 = stablehlo.convert %192 : (tensor<3xi32>) -> tensor<3xf32>
%194 = stablehlo.convert %183 : tensor<3xf32>
%195 = stablehlo.dot_general %193, %194, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%196 = stablehlo.broadcast_in_dim %187, dims = [] : (tensor<f32>) -> tensor<1xf32>
%197 = stablehlo.broadcast_in_dim %191, dims = [] : (tensor<f32>) -> tensor<1xf32>
%198 = stablehlo.broadcast_in_dim %195, dims = [] : (tensor<f32>) -> tensor<1xf32>
%199 = stablehlo.concatenate %196, %197, %198, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
%200 = stablehlo.slice %arg0 [10:11, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
%201 = stablehlo.reshape %200 : (tensor<1x5xi32>) -> tensor<5xi32>
%202 = stablehlo.slice %arg1 [10:11, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
%203 = stablehlo.reshape %202 : (tensor<1x3xf32>) -> tensor<3xf32>
%204 = stablehlo.slice %201 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%205 = stablehlo.convert %204 : (tensor<3xi32>) -> tensor<3xf32>
%206 = stablehlo.convert %203 : tensor<3xf32>
%207 = stablehlo.dot_general %205, %206, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%208 = stablehlo.slice %201 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%209 = stablehlo.convert %208 : (tensor<3xi32>) -> tensor<3xf32>
%210 = stablehlo.convert %203 : tensor<3xf32>
%211 = stablehlo.dot_general %209, %210, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%212 = stablehlo.slice %201 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%213 = stablehlo.convert %212 : (tensor<3xi32>) -> tensor<3xf32>
%214 = stablehlo.convert %203 : tensor<3xf32>
%215 = stablehlo.dot_general %213, %214, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%216 = stablehlo.broadcast_in_dim %207, dims = [] : (tensor<f32>) -> tensor<1xf32>
%217 = stablehlo.broadcast_in_dim %211, dims = [] : (tensor<f32>) -> tensor<1xf32>
%218 = stablehlo.broadcast_in_dim %215, dims = [] : (tensor<f32>) -> tensor<1xf32>
%219 = stablehlo.concatenate %216, %217, %218, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
%220 = stablehlo.slice %arg0 [11:12, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
%221 = stablehlo.reshape %220 : (tensor<1x5xi32>) -> tensor<5xi32>
%222 = stablehlo.slice %arg1 [11:12, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
%223 = stablehlo.reshape %222 : (tensor<1x3xf32>) -> tensor<3xf32>
%224 = stablehlo.slice %221 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%225 = stablehlo.convert %224 : (tensor<3xi32>) -> tensor<3xf32>
%226 = stablehlo.convert %223 : tensor<3xf32>
%227 = stablehlo.dot_general %225, %226, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%228 = stablehlo.slice %221 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%229 = stablehlo.convert %228 : (tensor<3xi32>) -> tensor<3xf32>
%230 = stablehlo.convert %223 : tensor<3xf32>
%231 = stablehlo.dot_general %229, %230, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%232 = stablehlo.slice %221 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%233 = stablehlo.convert %232 : (tensor<3xi32>) -> tensor<3xf32>
%234 = stablehlo.convert %223 : tensor<3xf32>
%235 = stablehlo.dot_general %233, %234, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%236 = stablehlo.broadcast_in_dim %227, dims = [] : (tensor<f32>) -> tensor<1xf32>
%237 = stablehlo.broadcast_in_dim %231, dims = [] : (tensor<f32>) -> tensor<1xf32>
%238 = stablehlo.broadcast_in_dim %235, dims = [] : (tensor<f32>) -> tensor<1xf32>
%239 = stablehlo.concatenate %236, %237, %238, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
%240 = stablehlo.slice %arg0 [12:13, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
%241 = stablehlo.reshape %240 : (tensor<1x5xi32>) -> tensor<5xi32>
%242 = stablehlo.slice %arg1 [12:13, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
%243 = stablehlo.reshape %242 : (tensor<1x3xf32>) -> tensor<3xf32>
%244 = stablehlo.slice %241 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%245 = stablehlo.convert %244 : (tensor<3xi32>) -> tensor<3xf32>
%246 = stablehlo.convert %243 : tensor<3xf32>
%247 = stablehlo.dot_general %245, %246, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%248 = stablehlo.slice %241 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%249 = stablehlo.convert %248 : (tensor<3xi32>) -> tensor<3xf32>
%250 = stablehlo.convert %243 : tensor<3xf32>
%251 = stablehlo.dot_general %249, %250, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%252 = stablehlo.slice %241 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%253 = stablehlo.convert %252 : (tensor<3xi32>) -> tensor<3xf32>
%254 = stablehlo.convert %243 : tensor<3xf32>
%255 = stablehlo.dot_general %253, %254, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%256 = stablehlo.broadcast_in_dim %247, dims = [] : (tensor<f32>) -> tensor<1xf32>
%257 = stablehlo.broadcast_in_dim %251, dims = [] : (tensor<f32>) -> tensor<1xf32>
%258 = stablehlo.broadcast_in_dim %255, dims = [] : (tensor<f32>) -> tensor<1xf32>
%259 = stablehlo.concatenate %256, %257, %258, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
%260 = stablehlo.slice %arg0 [13:14, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
%261 = stablehlo.reshape %260 : (tensor<1x5xi32>) -> tensor<5xi32>
%262 = stablehlo.slice %arg1 [13:14, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
%263 = stablehlo.reshape %262 : (tensor<1x3xf32>) -> tensor<3xf32>
%264 = stablehlo.slice %261 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%265 = stablehlo.convert %264 : (tensor<3xi32>) -> tensor<3xf32>
%266 = stablehlo.convert %263 : tensor<3xf32>
%267 = stablehlo.dot_general %265, %266, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%268 = stablehlo.slice %261 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%269 = stablehlo.convert %268 : (tensor<3xi32>) -> tensor<3xf32>
%270 = stablehlo.convert %263 : tensor<3xf32>
%271 = stablehlo.dot_general %269, %270, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%272 = stablehlo.slice %261 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%273 = stablehlo.convert %272 : (tensor<3xi32>) -> tensor<3xf32>
%274 = stablehlo.convert %263 : tensor<3xf32>
%275 = stablehlo.dot_general %273, %274, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%276 = stablehlo.broadcast_in_dim %267, dims = [] : (tensor<f32>) -> tensor<1xf32>
%277 = stablehlo.broadcast_in_dim %271, dims = [] : (tensor<f32>) -> tensor<1xf32>
%278 = stablehlo.broadcast_in_dim %275, dims = [] : (tensor<f32>) -> tensor<1xf32>
%279 = stablehlo.concatenate %276, %277, %278, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
%280 = stablehlo.slice %arg0 [14:15, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
%281 = stablehlo.reshape %280 : (tensor<1x5xi32>) -> tensor<5xi32>
%282 = stablehlo.slice %arg1 [14:15, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
%283 = stablehlo.reshape %282 : (tensor<1x3xf32>) -> tensor<3xf32>
%284 = stablehlo.slice %281 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%285 = stablehlo.convert %284 : (tensor<3xi32>) -> tensor<3xf32>
%286 = stablehlo.convert %283 : tensor<3xf32>
%287 = stablehlo.dot_general %285, %286, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%288 = stablehlo.slice %281 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%289 = stablehlo.convert %288 : (tensor<3xi32>) -> tensor<3xf32>
%290 = stablehlo.convert %283 : tensor<3xf32>
%291 = stablehlo.dot_general %289, %290, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%292 = stablehlo.slice %281 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%293 = stablehlo.convert %292 : (tensor<3xi32>) -> tensor<3xf32>
%294 = stablehlo.convert %283 : tensor<3xf32>
%295 = stablehlo.dot_general %293, %294, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%296 = stablehlo.broadcast_in_dim %287, dims = [] : (tensor<f32>) -> tensor<1xf32>
%297 = stablehlo.broadcast_in_dim %291, dims = [] : (tensor<f32>) -> tensor<1xf32>
%298 = stablehlo.broadcast_in_dim %295, dims = [] : (tensor<f32>) -> tensor<1xf32>
%299 = stablehlo.concatenate %296, %297, %298, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
%300 = stablehlo.slice %arg0 [15:16, 0:5] : (tensor<16x5xi32>) -> tensor<1x5xi32>
%301 = stablehlo.reshape %300 : (tensor<1x5xi32>) -> tensor<5xi32>
%302 = stablehlo.slice %arg1 [15:16, 0:3] : (tensor<16x3xf32>) -> tensor<1x3xf32>
%303 = stablehlo.reshape %302 : (tensor<1x3xf32>) -> tensor<3xf32>
%304 = stablehlo.slice %301 [0:3] : (tensor<5xi32>) -> tensor<3xi32>
%305 = stablehlo.convert %304 : (tensor<3xi32>) -> tensor<3xf32>
%306 = stablehlo.convert %303 : tensor<3xf32>
%307 = stablehlo.dot_general %305, %306, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%308 = stablehlo.slice %301 [1:4] : (tensor<5xi32>) -> tensor<3xi32>
%309 = stablehlo.convert %308 : (tensor<3xi32>) -> tensor<3xf32>
%310 = stablehlo.convert %303 : tensor<3xf32>
%311 = stablehlo.dot_general %309, %310, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%312 = stablehlo.slice %301 [2:5] : (tensor<5xi32>) -> tensor<3xi32>
%313 = stablehlo.convert %312 : (tensor<3xi32>) -> tensor<3xf32>
%314 = stablehlo.convert %303 : tensor<3xf32>
%315 = stablehlo.dot_general %313, %314, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3xf32>, tensor<3xf32>) -> tensor<f32>
%316 = stablehlo.broadcast_in_dim %307, dims = [] : (tensor<f32>) -> tensor<1xf32>
%317 = stablehlo.broadcast_in_dim %311, dims = [] : (tensor<f32>) -> tensor<1xf32>
%318 = stablehlo.broadcast_in_dim %315, dims = [] : (tensor<f32>) -> tensor<1xf32>
%319 = stablehlo.concatenate %316, %317, %318, dim = 0 : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3xf32>
%320 = stablehlo.broadcast_in_dim %19, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
%321 = stablehlo.broadcast_in_dim %39, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
%322 = stablehlo.broadcast_in_dim %59, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
%323 = stablehlo.broadcast_in_dim %79, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
%324 = stablehlo.broadcast_in_dim %99, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
%325 = stablehlo.broadcast_in_dim %119, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
%326 = stablehlo.broadcast_in_dim %139, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
%327 = stablehlo.broadcast_in_dim %159, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
%328 = stablehlo.broadcast_in_dim %179, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
%329 = stablehlo.broadcast_in_dim %199, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
%330 = stablehlo.broadcast_in_dim %219, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
%331 = stablehlo.broadcast_in_dim %239, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
%332 = stablehlo.broadcast_in_dim %259, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
%333 = stablehlo.broadcast_in_dim %279, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
%334 = stablehlo.broadcast_in_dim %299, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
%335 = stablehlo.broadcast_in_dim %319, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
%336 = stablehlo.concatenate %320, %321, %322, %323, %324, %325, %326, %327, %328, %329, %330, %331, %332, %333, %334, %335, dim = 0 : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<16x3xf32>
return %336 : tensor<16x3xf32>
}
}
# Build flattened image matrices (one row per sample, via reshape(n, -1)) and
# label arrays (passed through one_hot(..., n_targets)) for the MNIST training
# split and the test split.
_n_train_images = len(mnist_dataset.data)
train_images = np.array(mnist_dataset.data).reshape(_n_train_images, -1)
train_labels = one_hot(np.array(mnist_dataset.targets), n_targets)

mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
_n_test_images = len(mnist_dataset_test.data)
_flat_test_images = mnist_dataset_test.data.numpy().reshape(_n_test_images, -1)
# The test images are explicitly cast to float32; the training images keep
# whatever dtype np.array infers from the dataset tensor.
test_images = jnp.array(_flat_test_images, dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets)
# Drop the temporaries so the notebook namespace matches the original cell.
del _n_train_images, _n_test_images, _flat_test_images
import time
# Training loop: per epoch, run one `update` step per batch from
# `training_generator`, then report wall time and train/test accuracy.
# On the very first batch, the jitted `update` is also lowered and compiled so
# its IR can be printed for inspection.
first_time = True
print(num_epochs)
for epoch in range(num_epochs):
    start_time = time.perf_counter()
    for x, y in training_generator:
        y = one_hot(y, n_targets)
        if first_time:
            first_time = False
            # jit places the computation on jax.devices()[0] by default, so no
            # explicit device is passed (jax.device_get is a device->host
            # transfer helper, not a way to select a device).
            saved_ = jit(update).lower(params, x, y)
            # Print the lowered IR as text.
            print(saved_.as_text())
            # Print the StableHLO representation of the lowered IR.
            print(saved_.compiler_ir(dialect="stablehlo"))
            # Ahead-of-time executable, kept for inspection only: it is
            # specialized to this batch's shapes, so calling it on a batch of a
            # different size would fail. Training below goes through `update`.
            compiled_ = saved_.compile()
        # `update` is pure: it returns new parameters instead of mutating
        # `params`, so the result must be rebound or training makes no progress.
        params = update(params, x, y)
    # JAX dispatches work asynchronously; wait for the last step to finish so
    # the measured epoch time covers the actual computation.
    jax.block_until_ready(params)
    epoch_time = time.perf_counter() - start_time
    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
10
// StableHLO module printed by `print(saved_.as_text())`, where
// saved_ = jit(update).lower(params, x, y) was taken on the first training batch.
// Arguments: %arg0..%arg5 are the parameters, grouped by jax.result_info as
// [layer][weight|bias]: (512x784, 512), (512x512, 512), (10x512, 10).
// %arg6 is the 128x784 input batch and %arg7 the 128x10 label batch. All shapes
// are static, so this program is specialized to batches of exactly 128 rows.
// Results: the six updated parameter tensors, in the same order as %arg0..%arg5.
module @jit_update attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
// Public entry point: forwards all 8 arguments to @update and returns its 6 results unchanged.
func.func public @main(%arg0: tensor<512x784xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<512xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<512x512xf32> {mhlo.layout_mode = "default"}, %arg3: tensor<512xf32> {mhlo.layout_mode = "default"}, %arg4: tensor<10x512xf32> {mhlo.layout_mode = "default"}, %arg5: tensor<10xf32> {mhlo.layout_mode = "default"}, %arg6: tensor<128x784xf32> {mhlo.layout_mode = "default"}, %arg7: tensor<128x10xf32> {mhlo.layout_mode = "default"}) -> (tensor<512x784xf32> {jax.result_info = "[0][0]", mhlo.layout_mode = "default"}, tensor<512xf32> {jax.result_info = "[0][1]", mhlo.layout_mode = "default"}, tensor<512x512xf32> {jax.result_info = "[1][0]", mhlo.layout_mode = "default"}, tensor<512xf32> {jax.result_info = "[1][1]", mhlo.layout_mode = "default"}, tensor<10x512xf32> {jax.result_info = "[2][0]", mhlo.layout_mode = "default"}, tensor<10xf32> {jax.result_info = "[2][1]", mhlo.layout_mode = "default"}) {
%0:6 = call @update(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (tensor<512x784xf32>, tensor<512xf32>, tensor<512x512xf32>, tensor<512xf32>, tensor<10x512xf32>, tensor<10xf32>, tensor<128x784xf32>, tensor<128x10xf32>) -> (tensor<512x784xf32>, tensor<512xf32>, tensor<512x512xf32>, tensor<512xf32>, tensor<10x512xf32>, tensor<10xf32>)
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<512x784xf32>, tensor<512xf32>, tensor<512x512xf32>, tensor<512xf32>, tensor<10x512xf32>, tensor<10xf32>
}
// Body of `update`: forward pass, backward pass, and one gradient-descent step.
func.func private @update(%arg0: tensor<512x784xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<512xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<512x512xf32> {mhlo.layout_mode = "default"}, %arg3: tensor<512xf32> {mhlo.layout_mode = "default"}, %arg4: tensor<10x512xf32> {mhlo.layout_mode = "default"}, %arg5: tensor<10xf32> {mhlo.layout_mode = "default"}, %arg6: tensor<128x784xf32> {mhlo.layout_mode = "default"}, %arg7: tensor<128x10xf32> {mhlo.layout_mode = "default"}) -> (tensor<512x784xf32> {mhlo.layout_mode = "default"}, tensor<512xf32> {mhlo.layout_mode = "default"}, tensor<512x512xf32> {mhlo.layout_mode = "default"}, tensor<512xf32> {mhlo.layout_mode = "default"}, tensor<10x512xf32> {mhlo.layout_mode = "default"}, tensor<10xf32> {mhlo.layout_mode = "default"}) {
// Layer 1 forward: %0/%1 form the 128x512 product of the batch with W1 (%arg0),
// %4 adds the broadcast bias %arg1, and %6 = max(0, %4) (ReLU).
%0 = stablehlo.dot_general %arg0, %arg6, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<512x784xf32>, tensor<128x784xf32>) -> tensor<512x128xf32>
%1 = stablehlo.transpose %0, dims = [1, 0] : (tensor<512x128xf32>) -> tensor<128x512xf32>
%2 = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<512xf32>) -> tensor<1x512xf32>
%3 = stablehlo.broadcast_in_dim %2, dims = [0, 1] : (tensor<1x512xf32>) -> tensor<128x512xf32>
%4 = stablehlo.add %1, %3 : tensor<128x512xf32>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%6 = stablehlo.maximum %5, %4 : tensor<128x512xf32>
// %7-%16: ReLU derivative kept for the backward pass. It is 1 where %4 equals its
// ReLU output and 0 otherwise, then divided by 2 where the output is exactly 0,
// so the derivative at an input of exactly 0 comes out as 0.5.
%7 = stablehlo.compare EQ, %4, %6, FLOAT : (tensor<128x512xf32>, tensor<128x512xf32>) -> tensor<128x512xi1>
%cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%8 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%9 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%10 = stablehlo.select %7, %8, %9 : tensor<128x512xi1>, tensor<128x512xf32>
%cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%11 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%12 = stablehlo.compare EQ, %11, %6, FLOAT : (tensor<128x512xf32>, tensor<128x512xf32>) -> tensor<128x512xi1>
%cst_3 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
%13 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%cst_4 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%14 = stablehlo.broadcast_in_dim %cst_4, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%15 = stablehlo.select %12, %13, %14 : tensor<128x512xi1>, tensor<128x512xf32>
%16 = stablehlo.divide %10, %15 : tensor<128x512xf32>
// Layer 2 forward (W2 = %arg2, b2 = %arg3) on %6: %23 = ReLU(%21); %24-%33 build
// its derivative mask the same way as %7-%16.
%17 = stablehlo.dot_general %arg2, %6, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<512x512xf32>, tensor<128x512xf32>) -> tensor<512x128xf32>
%18 = stablehlo.transpose %17, dims = [1, 0] : (tensor<512x128xf32>) -> tensor<128x512xf32>
%19 = stablehlo.broadcast_in_dim %arg3, dims = [1] : (tensor<512xf32>) -> tensor<1x512xf32>
%20 = stablehlo.broadcast_in_dim %19, dims = [0, 1] : (tensor<1x512xf32>) -> tensor<128x512xf32>
%21 = stablehlo.add %18, %20 : tensor<128x512xf32>
%cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%22 = stablehlo.broadcast_in_dim %cst_5, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%23 = stablehlo.maximum %22, %21 : tensor<128x512xf32>
%24 = stablehlo.compare EQ, %21, %23, FLOAT : (tensor<128x512xf32>, tensor<128x512xf32>) -> tensor<128x512xi1>
%cst_6 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%25 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%26 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%27 = stablehlo.select %24, %25, %26 : tensor<128x512xi1>, tensor<128x512xf32>
%cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%28 = stablehlo.broadcast_in_dim %cst_8, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%29 = stablehlo.compare EQ, %28, %23, FLOAT : (tensor<128x512xf32>, tensor<128x512xf32>) -> tensor<128x512xi1>
%cst_9 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
%30 = stablehlo.broadcast_in_dim %cst_9, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%cst_10 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%31 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%32 = stablehlo.select %29, %30, %31 : tensor<128x512xi1>, tensor<128x512xf32>
%33 = stablehlo.divide %27, %32 : tensor<128x512xf32>
// Output layer: %38 = 128x10 logits from %23, W3 (%arg4) and b3 (%arg5); no activation.
%34 = stablehlo.dot_general %arg4, %23, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<10x512xf32>, tensor<128x512xf32>) -> tensor<10x128xf32>
%35 = stablehlo.transpose %34, dims = [1, 0] : (tensor<10x128xf32>) -> tensor<128x10xf32>
%36 = stablehlo.broadcast_in_dim %arg5, dims = [1] : (tensor<10xf32>) -> tensor<1x10xf32>
%37 = stablehlo.broadcast_in_dim %36, dims = [0, 1] : (tensor<1x10xf32>) -> tensor<128x10xf32>
%38 = stablehlo.add %35, %37 : tensor<128x10xf32>
// %39-%49: numerically stable log-sum-exp pieces. %39 is the row max (init 0xFF800000
// = -inf in f32); non-finite maxima are replaced by 0 (%44). The max is then
// subtracted before exp (%48), and %49 sums exp over the 10 columns.
%cst_11 = stablehlo.constant dense<0xFF800000> : tensor<f32>
%39 = stablehlo.reduce(%38 init: %cst_11) applies stablehlo.maximum across dimensions = [1] : (tensor<128x10xf32>, tensor<f32>) -> tensor<128xf32>
%cst_12 = stablehlo.constant dense<0xFF800000> : tensor<f32>
%40 = stablehlo.broadcast_in_dim %cst_12, dims = [] : (tensor<f32>) -> tensor<128xf32>
%41 = stablehlo.maximum %40, %39 : tensor<128xf32>
%42 = stablehlo.is_finite %41 : (tensor<128xf32>) -> tensor<128xi1>
%cst_13 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%43 = stablehlo.broadcast_in_dim %cst_13, dims = [] : (tensor<f32>) -> tensor<128xf32>
%44 = stablehlo.select %42, %41, %43 : tensor<128xi1>, tensor<128xf32>
%45 = stablehlo.broadcast_in_dim %44, dims = [0] : (tensor<128xf32>) -> tensor<128x1xf32>
%46 = stablehlo.broadcast_in_dim %45, dims = [0, 1] : (tensor<128x1xf32>) -> tensor<128x10xf32>
%47 = stablehlo.subtract %38, %46 : tensor<128x10xf32>
%48 = stablehlo.exponential %47 : tensor<128x10xf32>
%cst_14 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%49 = stablehlo.reduce(%48 init: %cst_14) applies stablehlo.add across dimensions = [1] : (tensor<128x10xf32>, tensor<f32>) -> tensor<128xf32>
// %50-%52: |sum| and its sign test, used below to differentiate through abs.
%50 = stablehlo.abs %49 : tensor<128xf32>
%cst_15 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%51 = stablehlo.broadcast_in_dim %cst_15, dims = [] : (tensor<f32>) -> tensor<128xf32>
%52 = stablehlo.compare GE, %49, %51, FLOAT : (tensor<128xf32>, tensor<128xf32>) -> tensor<128xi1>
// Backward pass. %54 = -1/1280 (1280 = 128*10 entries), so %56 = -%arg7/1280. This is
// consistent with a loss of the form -mean(preds * labels) over all 128x10 entries.
// The Python loss itself is not visible here; confirm against `loss`.
%cst_16 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%53 = stablehlo.negate %cst_16 : tensor<f32>
%cst_17 = stablehlo.constant dense<1.280000e+03> : tensor<f32>
%54 = stablehlo.divide %53, %cst_17 : tensor<f32>
%55 = stablehlo.broadcast_in_dim %54, dims = [] : (tensor<f32>) -> tensor<128x10xf32>
%56 = stablehlo.multiply %55, %arg7 : tensor<128x10xf32>
// %57-%68: contribution routed through the log-sum-exp term. The row sums of -%56
// (the reshape/reduce over a size-1 axis at %59/%60 is a no-op) are divided by |sum
// exp| and sign-corrected for abs (%63-%66), then scaled by the exp values (%68).
%57 = stablehlo.negate %56 : tensor<128x10xf32>
%cst_18 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%58 = stablehlo.reduce(%57 init: %cst_18) applies stablehlo.add across dimensions = [1] : (tensor<128x10xf32>, tensor<f32>) -> tensor<128xf32>
%59 = stablehlo.reshape %58 : (tensor<128xf32>) -> tensor<128x1xf32>
%cst_19 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%60 = stablehlo.reduce(%59 init: %cst_19) applies stablehlo.add across dimensions = [1] : (tensor<128x1xf32>, tensor<f32>) -> tensor<128xf32>
%61 = stablehlo.divide %60, %50 : tensor<128xf32>
%cst_20 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%62 = stablehlo.broadcast_in_dim %cst_20, dims = [] : (tensor<f32>) -> tensor<128xf32>
%63 = stablehlo.select %52, %62, %61 : tensor<128xi1>, tensor<128xf32>
%64 = stablehlo.select %52, %61, %62 : tensor<128xi1>, tensor<128xf32>
%65 = stablehlo.negate %63 : tensor<128xf32>
%66 = stablehlo.add %64, %65 : tensor<128xf32>
%67 = stablehlo.broadcast_in_dim %66, dims = [0] : (tensor<128xf32>) -> tensor<128x10xf32>
%68 = stablehlo.multiply %67, %48 : tensor<128x10xf32>
// %69: total gradient with respect to the logits %38.
%69 = stablehlo.add %56, %68 : tensor<128x10xf32>
// %70-%72: gradient for b3 (sum over the batch; the size-1 reshape/reduce is a no-op).
%cst_21 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%70 = stablehlo.reduce(%69 init: %cst_21) applies stablehlo.add across dimensions = [0] : (tensor<128x10xf32>, tensor<f32>) -> tensor<10xf32>
%71 = stablehlo.reshape %70 : (tensor<10xf32>) -> tensor<1x10xf32>
%cst_22 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%72 = stablehlo.reduce(%71 init: %cst_22) applies stablehlo.add across dimensions = [0] : (tensor<1x10xf32>, tensor<f32>) -> tensor<10xf32>
// %74: gradient flowing back into %23; %75: gradient for W3 (10x512).
%73 = stablehlo.transpose %69, dims = [1, 0] : (tensor<128x10xf32>) -> tensor<10x128xf32>
%74 = stablehlo.dot_general %73, %arg4, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x128xf32>, tensor<10x512xf32>) -> tensor<128x512xf32>
%75 = stablehlo.dot_general %73, %23, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x128xf32>, tensor<128x512xf32>) -> tensor<10x512xf32>
// Through the layer-2 ReLU (mask %33). %79: gradient for b2; %81: gradient into %6;
// %82: gradient for W2 (512x512).
%76 = stablehlo.multiply %74, %33 : tensor<128x512xf32>
%cst_23 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%77 = stablehlo.reduce(%76 init: %cst_23) applies stablehlo.add across dimensions = [0] : (tensor<128x512xf32>, tensor<f32>) -> tensor<512xf32>
%78 = stablehlo.reshape %77 : (tensor<512xf32>) -> tensor<1x512xf32>
%cst_24 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%79 = stablehlo.reduce(%78 init: %cst_24) applies stablehlo.add across dimensions = [0] : (tensor<1x512xf32>, tensor<f32>) -> tensor<512xf32>
%80 = stablehlo.transpose %76, dims = [1, 0] : (tensor<128x512xf32>) -> tensor<512x128xf32>
%81 = stablehlo.dot_general %80, %arg2, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<512x128xf32>, tensor<512x512xf32>) -> tensor<128x512xf32>
%82 = stablehlo.dot_general %80, %6, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<512x128xf32>, tensor<128x512xf32>) -> tensor<512x512xf32>
// Through the layer-1 ReLU (mask %16). %86: gradient for b1; %88: gradient for W1 (512x784).
%83 = stablehlo.multiply %81, %16 : tensor<128x512xf32>
%cst_25 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%84 = stablehlo.reduce(%83 init: %cst_25) applies stablehlo.add across dimensions = [0] : (tensor<128x512xf32>, tensor<f32>) -> tensor<512xf32>
%85 = stablehlo.reshape %84 : (tensor<512xf32>) -> tensor<1x512xf32>
%cst_26 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%86 = stablehlo.reduce(%85 init: %cst_26) applies stablehlo.add across dimensions = [0] : (tensor<1x512xf32>, tensor<f32>) -> tensor<512xf32>
%87 = stablehlo.transpose %83, dims = [1, 0] : (tensor<128x512xf32>) -> tensor<512x128xf32>
%88 = stablehlo.dot_general %87, %arg6, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<512x128xf32>, tensor<128x784xf32>) -> tensor<512x784xf32>
// %89-%106: plain gradient-descent step. Each parameter minus 0.00999999977 (the f32
// rounding of 0.01) times its gradient: W1<-%88, b1<-%86, W2<-%82, b2<-%79, W3<-%75, b3<-%72.
%cst_27 = stablehlo.constant dense<0.00999999977> : tensor<f32>
%89 = stablehlo.broadcast_in_dim %cst_27, dims = [] : (tensor<f32>) -> tensor<512x784xf32>
%90 = stablehlo.multiply %89, %88 : tensor<512x784xf32>
%91 = stablehlo.subtract %arg0, %90 : tensor<512x784xf32>
%cst_28 = stablehlo.constant dense<0.00999999977> : tensor<f32>
%92 = stablehlo.broadcast_in_dim %cst_28, dims = [] : (tensor<f32>) -> tensor<512xf32>
%93 = stablehlo.multiply %92, %86 : tensor<512xf32>
%94 = stablehlo.subtract %arg1, %93 : tensor<512xf32>
%cst_29 = stablehlo.constant dense<0.00999999977> : tensor<f32>
%95 = stablehlo.broadcast_in_dim %cst_29, dims = [] : (tensor<f32>) -> tensor<512x512xf32>
%96 = stablehlo.multiply %95, %82 : tensor<512x512xf32>
%97 = stablehlo.subtract %arg2, %96 : tensor<512x512xf32>
%cst_30 = stablehlo.constant dense<0.00999999977> : tensor<f32>
%98 = stablehlo.broadcast_in_dim %cst_30, dims = [] : (tensor<f32>) -> tensor<512xf32>
%99 = stablehlo.multiply %98, %79 : tensor<512xf32>
%100 = stablehlo.subtract %arg3, %99 : tensor<512xf32>
%cst_31 = stablehlo.constant dense<0.00999999977> : tensor<f32>
%101 = stablehlo.broadcast_in_dim %cst_31, dims = [] : (tensor<f32>) -> tensor<10x512xf32>
%102 = stablehlo.multiply %101, %75 : tensor<10x512xf32>
%103 = stablehlo.subtract %arg4, %102 : tensor<10x512xf32>
%cst_32 = stablehlo.constant dense<0.00999999977> : tensor<f32>
%104 = stablehlo.broadcast_in_dim %cst_32, dims = [] : (tensor<f32>) -> tensor<10xf32>
%105 = stablehlo.multiply %104, %72 : tensor<10xf32>
%106 = stablehlo.subtract %arg5, %105 : tensor<10xf32>
// Updated parameters, in argument order: W1, b1, W2, b2, W3, b3.
return %91, %94, %97, %100, %103, %106 : tensor<512x784xf32>, tensor<512xf32>, tensor<512x512xf32>, tensor<512xf32>, tensor<10x512xf32>, tensor<10xf32>
}
}
module @jit_update attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<512x784xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<512xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<512x512xf32> {mhlo.layout_mode = "default"}, %arg3: tensor<512xf32> {mhlo.layout_mode = "default"}, %arg4: tensor<10x512xf32> {mhlo.layout_mode = "default"}, %arg5: tensor<10xf32> {mhlo.layout_mode = "default"}, %arg6: tensor<128x784xf32> {mhlo.layout_mode = "default"}, %arg7: tensor<128x10xf32> {mhlo.layout_mode = "default"}) -> (tensor<512x784xf32> {jax.result_info = "[0][0]", mhlo.layout_mode = "default"}, tensor<512xf32> {jax.result_info = "[0][1]", mhlo.layout_mode = "default"}, tensor<512x512xf32> {jax.result_info = "[1][0]", mhlo.layout_mode = "default"}, tensor<512xf32> {jax.result_info = "[1][1]", mhlo.layout_mode = "default"}, tensor<10x512xf32> {jax.result_info = "[2][0]", mhlo.layout_mode = "default"}, tensor<10xf32> {jax.result_info = "[2][1]", mhlo.layout_mode = "default"}) {
%0:6 = call @update(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (tensor<512x784xf32>, tensor<512xf32>, tensor<512x512xf32>, tensor<512xf32>, tensor<10x512xf32>, tensor<10xf32>, tensor<128x784xf32>, tensor<128x10xf32>) -> (tensor<512x784xf32>, tensor<512xf32>, tensor<512x512xf32>, tensor<512xf32>, tensor<10x512xf32>, tensor<10xf32>)
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<512x784xf32>, tensor<512xf32>, tensor<512x512xf32>, tensor<512xf32>, tensor<10x512xf32>, tensor<10xf32>
}
func.func private @update(%arg0: tensor<512x784xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<512xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<512x512xf32> {mhlo.layout_mode = "default"}, %arg3: tensor<512xf32> {mhlo.layout_mode = "default"}, %arg4: tensor<10x512xf32> {mhlo.layout_mode = "default"}, %arg5: tensor<10xf32> {mhlo.layout_mode = "default"}, %arg6: tensor<128x784xf32> {mhlo.layout_mode = "default"}, %arg7: tensor<128x10xf32> {mhlo.layout_mode = "default"}) -> (tensor<512x784xf32> {mhlo.layout_mode = "default"}, tensor<512xf32> {mhlo.layout_mode = "default"}, tensor<512x512xf32> {mhlo.layout_mode = "default"}, tensor<512xf32> {mhlo.layout_mode = "default"}, tensor<10x512xf32> {mhlo.layout_mode = "default"}, tensor<10xf32> {mhlo.layout_mode = "default"}) {
%0 = stablehlo.dot_general %arg0, %arg6, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<512x784xf32>, tensor<128x784xf32>) -> tensor<512x128xf32>
%1 = stablehlo.transpose %0, dims = [1, 0] : (tensor<512x128xf32>) -> tensor<128x512xf32>
%2 = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<512xf32>) -> tensor<1x512xf32>
%3 = stablehlo.broadcast_in_dim %2, dims = [0, 1] : (tensor<1x512xf32>) -> tensor<128x512xf32>
%4 = stablehlo.add %1, %3 : tensor<128x512xf32>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%6 = stablehlo.maximum %5, %4 : tensor<128x512xf32>
%7 = stablehlo.compare EQ, %4, %6, FLOAT : (tensor<128x512xf32>, tensor<128x512xf32>) -> tensor<128x512xi1>
%cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%8 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%9 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%10 = stablehlo.select %7, %8, %9 : tensor<128x512xi1>, tensor<128x512xf32>
%cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%11 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%12 = stablehlo.compare EQ, %11, %6, FLOAT : (tensor<128x512xf32>, tensor<128x512xf32>) -> tensor<128x512xi1>
%cst_3 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
%13 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%cst_4 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%14 = stablehlo.broadcast_in_dim %cst_4, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%15 = stablehlo.select %12, %13, %14 : tensor<128x512xi1>, tensor<128x512xf32>
%16 = stablehlo.divide %10, %15 : tensor<128x512xf32>
%17 = stablehlo.dot_general %arg2, %6, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<512x512xf32>, tensor<128x512xf32>) -> tensor<512x128xf32>
%18 = stablehlo.transpose %17, dims = [1, 0] : (tensor<512x128xf32>) -> tensor<128x512xf32>
%19 = stablehlo.broadcast_in_dim %arg3, dims = [1] : (tensor<512xf32>) -> tensor<1x512xf32>
%20 = stablehlo.broadcast_in_dim %19, dims = [0, 1] : (tensor<1x512xf32>) -> tensor<128x512xf32>
%21 = stablehlo.add %18, %20 : tensor<128x512xf32>
%cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%22 = stablehlo.broadcast_in_dim %cst_5, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%23 = stablehlo.maximum %22, %21 : tensor<128x512xf32>
%24 = stablehlo.compare EQ, %21, %23, FLOAT : (tensor<128x512xf32>, tensor<128x512xf32>) -> tensor<128x512xi1>
%cst_6 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%25 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%26 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%27 = stablehlo.select %24, %25, %26 : tensor<128x512xi1>, tensor<128x512xf32>
%cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%28 = stablehlo.broadcast_in_dim %cst_8, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%29 = stablehlo.compare EQ, %28, %23, FLOAT : (tensor<128x512xf32>, tensor<128x512xf32>) -> tensor<128x512xi1>
%cst_9 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
%30 = stablehlo.broadcast_in_dim %cst_9, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%cst_10 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%31 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor<f32>) -> tensor<128x512xf32>
%32 = stablehlo.select %29, %30, %31 : tensor<128x512xi1>, tensor<128x512xf32>
%33 = stablehlo.divide %27, %32 : tensor<128x512xf32>
%34 = stablehlo.dot_general %arg4, %23, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<10x512xf32>, tensor<128x512xf32>) -> tensor<10x128xf32>
%35 = stablehlo.transpose %34, dims = [1, 0] : (tensor<10x128xf32>) -> tensor<128x10xf32>
%36 = stablehlo.broadcast_in_dim %arg5, dims = [1] : (tensor<10xf32>) -> tensor<1x10xf32>
%37 = stablehlo.broadcast_in_dim %36, dims = [0, 1] : (tensor<1x10xf32>) -> tensor<128x10xf32>
%38 = stablehlo.add %35, %37 : tensor<128x10xf32>
%cst_11 = stablehlo.constant dense<0xFF800000> : tensor<f32>
%39 = stablehlo.reduce(%38 init: %cst_11) applies stablehlo.maximum across dimensions = [1] : (tensor<128x10xf32>, tensor<f32>) -> tensor<128xf32>
%cst_12 = stablehlo.constant dense<0xFF800000> : tensor<f32>
%40 = stablehlo.broadcast_in_dim %cst_12, dims = [] : (tensor<f32>) -> tensor<128xf32>
%41 = stablehlo.maximum %40, %39 : tensor<128xf32>
%42 = stablehlo.is_finite %41 : (tensor<128xf32>) -> tensor<128xi1>
%cst_13 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%43 = stablehlo.broadcast_in_dim %cst_13, dims = [] : (tensor<f32>) -> tensor<128xf32>
%44 = stablehlo.select %42, %41, %43 : tensor<128xi1>, tensor<128xf32>
%45 = stablehlo.broadcast_in_dim %44, dims = [0] : (tensor<128xf32>) -> tensor<128x1xf32>
%46 = stablehlo.broadcast_in_dim %45, dims = [0, 1] : (tensor<128x1xf32>) -> tensor<128x10xf32>
%47 = stablehlo.subtract %38, %46 : tensor<128x10xf32>
%48 = stablehlo.exponential %47 : tensor<128x10xf32>
%cst_14 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%49 = stablehlo.reduce(%48 init: %cst_14) applies stablehlo.add across dimensions = [1] : (tensor<128x10xf32>, tensor<f32>) -> tensor<128xf32>
%50 = stablehlo.abs %49 : tensor<128xf32>
%cst_15 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%51 = stablehlo.broadcast_in_dim %cst_15, dims = [] : (tensor<f32>) -> tensor<128xf32>
%52 = stablehlo.compare GE, %49, %51, FLOAT : (tensor<128xf32>, tensor<128xf32>) -> tensor<128xi1>
%cst_16 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%53 = stablehlo.negate %cst_16 : tensor<f32>
%cst_17 = stablehlo.constant dense<1.280000e+03> : tensor<f32>
%54 = stablehlo.divide %53, %cst_17 : tensor<f32>
%55 = stablehlo.broadcast_in_dim %54, dims = [] : (tensor<f32>) -> tensor<128x10xf32>
%56 = stablehlo.multiply %55, %arg7 : tensor<128x10xf32>
%57 = stablehlo.negate %56 : tensor<128x10xf32>
%cst_18 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%58 = stablehlo.reduce(%57 init: %cst_18) applies stablehlo.add across dimensions = [1] : (tensor<128x10xf32>, tensor<f32>) -> tensor<128xf32>
%59 = stablehlo.reshape %58 : (tensor<128xf32>) -> tensor<128x1xf32>
%cst_19 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%60 = stablehlo.reduce(%59 init: %cst_19) applies stablehlo.add across dimensions = [1] : (tensor<128x1xf32>, tensor<f32>) -> tensor<128xf32>
%61 = stablehlo.divide %60, %50 : tensor<128xf32>
%cst_20 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%62 = stablehlo.broadcast_in_dim %cst_20, dims = [] : (tensor<f32>) -> tensor<128xf32>
%63 = stablehlo.select %52, %62, %61 : tensor<128xi1>, tensor<128xf32>
%64 = stablehlo.select %52, %61, %62 : tensor<128xi1>, tensor<128xf32>
%65 = stablehlo.negate %63 : tensor<128xf32>
%66 = stablehlo.add %64, %65 : tensor<128xf32>
%67 = stablehlo.broadcast_in_dim %66, dims = [0] : (tensor<128xf32>) -> tensor<128x10xf32>
%68 = stablehlo.multiply %67, %48 : tensor<128x10xf32>
%69 = stablehlo.add %56, %68 : tensor<128x10xf32>
%cst_21 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%70 = stablehlo.reduce(%69 init: %cst_21) applies stablehlo.add across dimensions = [0] : (tensor<128x10xf32>, tensor<f32>) -> tensor<10xf32>
%71 = stablehlo.reshape %70 : (tensor<10xf32>) -> tensor<1x10xf32>
%cst_22 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%72 = stablehlo.reduce(%71 init: %cst_22) applies stablehlo.add across dimensions = [0] : (tensor<1x10xf32>, tensor<f32>) -> tensor<10xf32>
%73 = stablehlo.transpose %69, dims = [1, 0] : (tensor<128x10xf32>) -> tensor<10x128xf32>
%74 = stablehlo.dot_general %73, %arg4, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x128xf32>, tensor<10x512xf32>) -> tensor<128x512xf32>
%75 = stablehlo.dot_general %73, %23, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x128xf32>, tensor<128x512xf32>) -> tensor<10x512xf32>
%76 = stablehlo.multiply %74, %33 : tensor<128x512xf32>
%cst_23 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%77 = stablehlo.reduce(%76 init: %cst_23) applies stablehlo.add across dimensions = [0] : (tensor<128x512xf32>, tensor<f32>) -> tensor<512xf32>
%78 = stablehlo.reshape %77 : (tensor<512xf32>) -> tensor<1x512xf32>
%cst_24 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%79 = stablehlo.reduce(%78 init: %cst_24) applies stablehlo.add across dimensions = [0] : (tensor<1x512xf32>, tensor<f32>) -> tensor<512xf32>
%80 = stablehlo.transpose %76, dims = [1, 0] : (tensor<128x512xf32>) -> tensor<512x128xf32>
%81 = stablehlo.dot_general %80, %arg2, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<512x128xf32>, tensor<512x512xf32>) -> tensor<128x512xf32>
%82 = stablehlo.dot_general %80, %6, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<512x128xf32>, tensor<128x512xf32>) -> tensor<512x512xf32>
%83 = stablehlo.multiply %81, %16 : tensor<128x512xf32>
%cst_25 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%84 = stablehlo.reduce(%83 init: %cst_25) applies stablehlo.add across dimensions = [0] : (tensor<128x512xf32>, tensor<f32>) -> tensor<512xf32>
%85 = stablehlo.reshape %84 : (tensor<512xf32>) -> tensor<1x512xf32>
%cst_26 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%86 = stablehlo.reduce(%85 init: %cst_26) applies stablehlo.add across dimensions = [0] : (tensor<1x512xf32>, tensor<f32>) -> tensor<512xf32>
%87 = stablehlo.transpose %83, dims = [1, 0] : (tensor<128x512xf32>) -> tensor<512x128xf32>
%88 = stablehlo.dot_general %87, %arg6, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<512x128xf32>, tensor<128x784xf32>) -> tensor<512x784xf32>
%cst_27 = stablehlo.constant dense<0.00999999977> : tensor<f32>
%89 = stablehlo.broadcast_in_dim %cst_27, dims = [] : (tensor<f32>) -> tensor<512x784xf32>
%90 = stablehlo.multiply %89, %88 : tensor<512x784xf32>
%91 = stablehlo.subtract %arg0, %90 : tensor<512x784xf32>
%cst_28 = stablehlo.constant dense<0.00999999977> : tensor<f32>
%92 = stablehlo.broadcast_in_dim %cst_28, dims = [] : (tensor<f32>) -> tensor<512xf32>
%93 = stablehlo.multiply %92, %86 : tensor<512xf32>
%94 = stablehlo.subtract %arg1, %93 : tensor<512xf32>
%cst_29 = stablehlo.constant dense<0.00999999977> : tensor<f32>
%95 = stablehlo.broadcast_in_dim %cst_29, dims = [] : (tensor<f32>) -> tensor<512x512xf32>
%96 = stablehlo.multiply %95, %82 : tensor<512x512xf32>
%97 = stablehlo.subtract %arg2, %96 : tensor<512x512xf32>
%cst_30 = stablehlo.constant dense<0.00999999977> : tensor<f32>
%98 = stablehlo.broadcast_in_dim %cst_30, dims = [] : (tensor<f32>) -> tensor<512xf32>
%99 = stablehlo.multiply %98, %79 : tensor<512xf32>
%100 = stablehlo.subtract %arg3, %99 : tensor<512xf32>
%cst_31 = stablehlo.constant dense<0.00999999977> : tensor<f32>
%101 = stablehlo.broadcast_in_dim %cst_31, dims = [] : (tensor<f32>) -> tensor<10x512xf32>
%102 = stablehlo.multiply %101, %75 : tensor<10x512xf32>
%103 = stablehlo.subtract %arg4, %102 : tensor<10x512xf32>
%cst_32 = stablehlo.constant dense<0.00999999977> : tensor<f32>
%104 = stablehlo.broadcast_in_dim %cst_32, dims = [] : (tensor<f32>) -> tensor<10xf32>
%105 = stablehlo.multiply %104, %72 : tensor<10xf32>
%106 = stablehlo.subtract %arg5, %105 : tensor<10xf32>
return %91, %94, %97, %100, %103, %106 : tensor<512x784xf32>, tensor<512xf32>, tensor<512x512xf32>, tensor<512xf32>, tensor<10x512xf32>, tensor<10xf32>
}
}
Epoch 0 in 11.98 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 1 in 2.93 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 2 in 2.82 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 3 in 2.72 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 4 in 2.73 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 5 in 2.77 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 6 in 2.80 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 7 in 2.73 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 8 in 2.74 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
Epoch 9 in 2.79 sec
Training set accuracy 0.08895000070333481
Test set accuracy 0.08959999680519104
from jax import grad, jit, vmap
def my_predict(params, inputs):
    """Run a multilayer-perceptron forward pass.

    Each ``(W, b)`` pair in *params* is applied as ``inputs @ W + b``.
    ReLU (``max(x, 0)``) is applied between layers. The last layer's
    pre-activation output is returned as-is, with no ReLU on the final layer.

    Args:
        params: Iterable of ``(W, b)`` pairs, one per layer, applied in order.
        inputs: Array fed to the first layer.

    Returns:
        The last layer's output. If *params* is empty, *inputs* is returned
        unchanged (identity network) instead of raising ``UnboundLocalError``.
    """
    # The notebook header binds jax.numpy as `np`; import `jnp` locally so
    # this function does not depend on a module-level alias being defined.
    import jax.numpy as jnp

    outputs = inputs  # identity result when there are no layers
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        # Activation feeds the next layer only; the returned value stays linear.
        inputs = jnp.maximum(outputs, 0)
    return outputs
def loss(params, batch):
    """Return the summed squared error of ``my_predict`` on a batch.

    Args:
        params: Layer parameters passed straight through to ``my_predict``.
        batch: ``(inputs, targets)`` pair; ``targets`` must broadcast against
            the prediction.

    Returns:
        ``sum((my_predict(params, inputs) - targets) ** 2)``. This is a sum,
        not a mean, so the gradient scales with the batch size.
    """
    # Local import so the function does not rely on a module-level `jnp`
    # alias (the notebook header only binds jax.numpy as `np`).
    import jax.numpy as jnp

    inputs, targets = batch
    preds = my_predict(params, inputs)
    return jnp.sum((preds - targets) ** 2)
# Compiled gradient of `loss` with respect to its first argument (params),
# evaluated over a whole batch at once.
grad_func = jit(grad(loss, argnums=0))
# Compiled per-example gradients: vmap broadcasts params (None) and maps over
# the leading axis of every array in `batch` (0), giving one gradient each.
per_example_grads = jit(vmap(grad(loss, argnums=0), in_axes=(None, 0)))
My Podcast!
If you like topics such as this, please consider subscribing to my podcast, where I talk to some of the stalwarts of tech and ask them about their favorite productivity hacks:
Available on iTunes Podcast
Visit the Void Star Podcast’s page on the iTunes Podcast Portal, click ‘Subscribe’, and leave a comment.