How to install trax, jax, jaxlib on M1 Mac on macOS 12?

New to trax, I'm trying to run it locally (macOS 12.1, Apple Silicon ARM M1 processor, 8GB RAM, Anaconda), but I'm running into some issues.

In an environment with Python 3.8.5, I installed trax by running pip3 install trax==1.3.9 inside an (Anaconda) conda environment. Later, when I tried to import trax layers in my code with from trax import layers as tl, I got:

RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.

I rushed to start a new conda environment with Python 3.10, then ran into an issue just trying to install trax:

ERROR: Could not find a version that satisfies the requirement tensorflow-text (from trax) (from versions: none)
ERROR: No matching distribution found for tensorflow-text

I then created a new environment with Python 3.9. Installation went fine, but I then ran into the same error when importing layers:

RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.
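To see which exception class an import actually raises without crashing the calling script, a small probe like the sketch below may help. The helper name try_import is hypothetical. The point it illustrates is that the jaxlib failure above is a RuntimeError, so catching only ImportError would miss it.

```python
import importlib


def try_import(name):
    """Import `name` and return (True, None) on success, else (False, error_text).

    Hypothetical helper: it catches Exception rather than ImportError because
    the jaxlib failure quoted above is raised as a RuntimeError.
    """
    try:
        importlib.import_module(name)
    except Exception as exc:
        return False, f"{type(exc).__name__}: {exc}"
    return True, None


if __name__ == "__main__":
    # Report the import status of the packages discussed in this question.
    for package in ("jaxlib", "jax", "trax"):
        ok, error = try_import(package)
        print(package, "OK" if ok else error)
```

Running it inside each conda environment gives a quick side-by-side view of which package fails and with what message.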

installing jaxlib, jax before trax

I then tried building jaxlib from source following these instructions and got this error:

Building XLA and installing it in the jaxlib source tree...
./bazel-4.2.1-darwin-x86_64 run --verbose_failures=true --config=avx_posix --config=mkl_open_source_only :build_wheel -- --output_path=/my path/jax/dist --cpu=x86_64
ERROR: bazel does not currently work properly from paths containing spaces (/my path/jax).
b''
Traceback (most recent call last):
  File "/my path/jax/build/build.py", line 524, in <module>
    main()
  File "/my path/jax/build/build.py", line 519, in main
    shell(command)
  File "/my path/jax/build/build.py", line 53, in shell
    output = subprocess.check_output(cmd)
  File "/myuserpath/opt/anaconda3/envs/mytraxenv/lib/python3.9/subprocess.py", line 424, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/myuserpath/opt/anaconda3/envs/mytraxenv/lib/python3.9/subprocess.py", line 528, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['./bazel-4.2.1-darwin-x86_64', 'run', '--verbose_failures=true', '--config=avx_posix', '--config=mkl_open_source_only', ':build_wheel', '--', '--output_path=/my path/jax/dist', '--cpu=x86_64']' returned non-zero exit status 36.

Emphasis on the part that I initially missed that says: bazel does not currently work properly from paths containing spaces (/my path/jax).
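A pre-flight check along these lines could catch the problem before a long build starts. This is only a sketch: the function name has_whitespace is made up for illustration, and it checks the fully resolved path because a relative path may hide a space in a parent directory.

```python
from pathlib import Path


def has_whitespace(path):
    """Return True if the absolute form of `path` contains any whitespace.

    Hypothetical helper; bazel's own check may differ in detail.
    """
    resolved = str(Path(path).resolve())
    return any(ch.isspace() for ch in resolved)


if __name__ == "__main__":
    # Check the current directory, e.g. when run from the jax checkout.
    here = Path.cwd()
    if has_whitespace(here):
        print(f"Move the checkout: {here} contains spaces")
    else:
        print(f"{here} looks fine for bazel")
```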

I moved my /my path/ directory to a path without spaces, /mypath/, then deleted and redownloaded the jax directory. Still, the build (for CPU) with python build/build.py failed:

ERROR: /private/var/tmp/_bazel_a/2caf512c3c5e3f3f654bc58b48b8333a/external/llvm-project/llvm/BUILD.bazel:610:11: Generating code from table: include/llvm/IR/Intrinsics.td @llvm-project//llvm:intrinsic_XCore_gen__gen_intrinsic_enums__intrinsic_prefix_xcore_genrule failed: (Illegal instruction): bash failed: error executing command 
  (cd /private/var/tmp/_bazel_a/2caf512c3c5e3f3f654bc58b48b8333a/execroot/__main__ && \
  exec env - \
    PATH=/myuserpath/opt/anaconda3/envs/mytraxenv/bin:/myuserpath/opt/anaconda3/condabin:/opt/homebrew/bin:/opt/homebrew/sbin:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin:/Library/TeX/texbin:/Library/Apple/usr/bin \
  /bin/bash -c 'source external/bazel_tools/tools/genrule/genrule-setup.sh; bazel-out/darwin-opt/bin/external/llvm-project/llvm/llvm-tblgen -I external/llvm-project/llvm/include -I external/llvm-project/clang/include -I $(dirname external/llvm-project/llvm/include/llvm/IR/Intrinsics.td) -gen-intrinsic-enums -intrinsic-prefix=xcore external/llvm-project/llvm/include/llvm/IR/Intrinsics.td  -o bazel-out/darwin-opt/bin/external/llvm-project/llvm/include/llvm/IR/IntrinsicsXCore.h')
Execution platform: @local_execution_config_platform//:platform
/bin/bash: line 1: 11140 Illegal instruction: 4  bazel-out/darwin-opt/bin/external/llvm-project/llvm/llvm-tblgen -I external/llvm-project/llvm/include -I external/llvm-project/clang/include -I $(dirname external/llvm-project/llvm/include/llvm/IR/Intrinsics.td) -gen-intrinsic-enums -intrinsic-prefix=xcore external/llvm-project/llvm/include/llvm/IR/Intrinsics.td -o bazel-out/darwin-opt/bin/external/llvm-project/llvm/include/llvm/IR/IntrinsicsXCore.h
Target //build:build_wheel failed to build
INFO: Elapsed time: 620.950s, Critical Path: 45.35s
INFO: 589 processes: 132 internal, 457 local.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully
b''
Traceback (most recent call last):
  File "/mypath/jax/build/build.py", line 524, in <module>
    main()
  File "/mypath/jax/build/build.py", line 519, in main
    shell(command)
  File "/mypath/jax/build/build.py", line 53, in shell
    output = subprocess.check_output(cmd)
  File "/myuserpath/opt/anaconda3/envs/mytraxenv/lib/python3.9/subprocess.py", line 424, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/myuserpath/opt/anaconda3/envs/mytraxenv/lib/python3.9/subprocess.py", line 528, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['./bazel-4.2.1-darwin-x86_64', 'run', '--verbose_failures=true', '--config=avx_posix', '--config=mkl_open_source_only', ':build_wheel', '--', '--output_path=/mypath/jax/dist', '--cpu=x86_64']' returned non-zero exit status 1.

I tried again a couple of times (deleting and redownloading the jax directory each time). The same line 528 in .../lib/python3.9/subprocess.py appeared in the traceback, but the output, while mostly as above, was sometimes slightly different. That made me suspect a memory issue, given that I (admittedly) had not restarted my machine in weeks and it was becoming unresponsively slow.

I updated my Xcode command line tools to version 12.2 (pretty sure).

I restarted my (8GB) machine, then deleted and redownloaded the jax directory. I installed bazel version 5.0.0 with Homebrew in case that would help. I was a bit concerned that it kept downloading an x86 version for my ARM processor. The bazel installation went fine.

Started again from these instructions. The jaxlib build, though, made it clear that it wanted an earlier (4.2.1) version of bazel and downloaded it, as before:

b'\x1b[31mERROR: The project you\'re trying to build requires Bazel 4.2.1 (specified in /mypath/jax/.bazelversion), but it wasn\'t found in /opt/homebrew/Cellar/bazel/5.0.0/libexec/bin.\x1b[0m\n\nBazel binaries for all official releases can be downloaded from here:\n  https://github.com/bazelbuild/bazel/releases\n\nYou can download the required version directly using this command:\n  (cd "/opt/homebrew/Cellar/bazel/5.0.0/libexec/bin" && curl -fLO https://releases.bazel.build/4.2.1/release/bazel-4.2.1-darwin-x86_64 && chmod +x bazel-4.2.1-darwin-x86_64)\n'
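The message above says the required version is pinned in the repository's .bazelversion file. A sketch like the following reads that pin so it can be compared with whatever bazel is installed. The function names are hypothetical, and reading only the first non-empty line of the file is an assumption about its format.

```python
from pathlib import Path


def required_bazel_version(repo_dir):
    """Return the version pinned in <repo_dir>/.bazelversion, or None if there is no pin."""
    pin_file = Path(repo_dir) / ".bazelversion"
    if not pin_file.is_file():
        return None
    # Assumption: the first non-empty line holds the version string.
    for line in pin_file.read_text().splitlines():
        if line.strip():
            return line.strip()
    return None


def bazel_matches_pin(installed_version, repo_dir):
    """Return True when no version is pinned or the installed version equals the pin."""
    pinned = required_bazel_version(repo_dir)
    return pinned is None or installed_version == pinned
```

For the situation described here, bazel_matches_pin("5.0.0", "/mypath/jax") would be False, which matches the build falling back to downloading 4.2.1.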

Once again, a different error in the same line 528. Showing full run now:

b'\x1b[31mERROR: The project you\'re trying to build requires Bazel 4.2.1 (specified in /mypath/jax/.bazelversion), but it wasn\'t found in /opt/homebrew/Cellar/bazel/5.0.0/libexec/bin.\x1b[0m\n\nBazel binaries for all official releases can be downloaded from here:\n  https://github.com/bazelbuild/bazel/releases\n\nYou can download the required version directly using this command:\n  (cd "/opt/homebrew/Cellar/bazel/5.0.0/libexec/bin" && curl -fLO https://releases.bazel.build/4.2.1/release/bazel-4.2.1-darwin-x86_64 && chmod +x bazel-4.2.1-darwin-x86_64)\n'
Downloading bazel from: https://github.com/bazelbuild/bazel/releases/download/4.2.1/bazel-4.2.1-darwin-x86_64
bazel-4.2.1-darwin-x86_64 [########################################] 100%
Bazel binary path: ./bazel-4.2.1-darwin-x86_64
Bazel version: 4.2.1
Python binary path: /myuserpath/opt/anaconda3/envs/mytraxenv/bin/python
Python version: 3.9
NumPy version: 1.21.2
MKL-DNN enabled: yes
Target CPU: x86_64
Target CPU features: release
CUDA enabled: no
TPU enabled: no
ROCm enabled: no

Building XLA and installing it in the jaxlib source tree...
./bazel-4.2.1-darwin-x86_64 run --verbose_failures=true --config=avx_posix --config=mkl_open_source_only :build_wheel -- --output_path=/mypath/jax/dist --cpu=x86_64
INFO: Options provided by the client:
  Inherited 'common' options: --isatty=0 --terminal_columns=80
INFO: Reading rc options for 'run' from /mypath/jax/.bazelrc:
  Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'run' from /mypath/jax/.bazelrc:
  Inherited 'build' options: --apple_platform_type=macos --macos_minimum_os=10.9 --announce_rc --define open_source_build=true --spawn_strategy=standalone --enable_platform_specific_config --define=no_aws_support=true --define=no_gcp_support=true --define=no_hdfs_support=true --define=no_kafka_support=true --define=no_ignite_support=true --define=grpc_no_ares=true -c opt --config=short_logs --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
INFO: Reading rc options for 'run' from /mypath/jax/.jax_configure.bazelrc:
  Inherited 'build' options: --strategy=Genrule=standalone --repo_env PYTHON_BIN_PATH=/myuserpath/opt/anaconda3/envs/mytraxenv/bin/python --action_env=PYENV_ROOT --python_path=/myuserpath/opt/anaconda3/envs/mytraxenv/bin/python --distinct_host_configuration=false
INFO: Found applicable config definition build:short_logs in file /mypath/jax/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:avx_posix in file /mypath/jax/.bazelrc: --copt=-mavx --host_copt=-mavx
INFO: Found applicable config definition build:mkl_open_source_only in file /mypath/jax/.bazelrc: --define=tensorflow_mkldnn_contraction_kernel=1
INFO: Found applicable config definition build:macos in file /mypath/jax/.bazelrc: --config=posix
INFO: Found applicable config definition build:posix in file /mypath/jax/.bazelrc: --copt=-fvisibility=hidden --copt=-Wno-sign-compare --cxxopt=-std=c++14 --host_cxxopt=-std=c++14
Loading: 
Loading: 0 packages loaded
Analyzing: target //build:build_wheel (1 packages loaded, 0 targets configured)
Analyzing: target //build:build_wheel (8 packages loaded, 286 targets configured)
Analyzing: target //build:build_wheel (10 packages loaded, 4271 targets configured)
Analyzing: target //build:build_wheel (10 packages loaded, 4271 targets configured)
Analyzing: target //build:build_wheel (11 packages loaded, 4494 targets configured)
INFO: Analyzed target //build:build_wheel (11 packages loaded, 7417 targets configured).

INFO: Found 1 target...
[0 / 3] [Prepa] BazelWorkspaceStatusAction stable-status.txt
[21 / 233] Compiling src/google/protobuf/generated_enum_util.cc; 2s local ... (6 actions, 3 running)
[28 / 233] Compiling src/google/protobuf/extension_set.cc; 4s local ... (8 actions, 7 running)
[114 / 523] Compiling src/google/protobuf/generated_message_util.cc; 2s local ... (8 actions running)
[140 / 594] Compiling platform/c++11/src/nsync_semaphore_mutex.cc; 0s local ... (8 actions, 7 running)
[237 / 597] Compiling src/google/protobuf/util/message_differencer.cc; 0s local ... (8 actions running)
[252 / 597] Compiling src/google/protobuf/util/message_differencer.cc; 4s local ... (8 actions running)
[272 / 597] Compiling src/google/protobuf/descriptor.pb.cc; 8s local ... (8 actions running)
[296 / 597] Compiling src/google/protobuf/descriptor.cc; 12s local ... (8 actions running)
[315 / 597] Compiling src/google/protobuf/descriptor_database.cc; 2s local ... (5 actions running)
[487 / 1,963] Compiling src/compiler/python_generator.cc; 2s local ... (8 actions running)
[579 / 3,117] Compiling tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv3d.cc; 0s local ... (8 actions, 7 running)
[619 / 3,173] Compiling tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv3d.cc; 8s local ... (8 actions running)
[687 / 3,270] Compiling tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv3d.cc; 17s local ... (8 actions running)
[774 / 3,464] Compiling tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc; 27s local ... (8 actions running)
[1,204 / 4,860] Compiling tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc; 39s local ... (8 actions running)
[1,255 / 4,916] Compiling tensorflow/compiler/xla/service/cpu/runtime_matmul.cc; 52s local ... (8 actions running)
[1,340 / 5,042] Compiling tensorflow/compiler/xla/service/cpu/runtime_matmul.cc; 68s local ... (8 actions running)
[1,456 / 5,156] Compiling tensorflow/compiler/xla/service/cpu/runtime_matmul.cc; 86s local ... (8 actions, 7 running)
[1,661 / 5,704] Compiling tensorflow/compiler/xla/service/cpu/runtime_matmul.cc; 107s local ... (8 actions, 7 running)
[1,688 / 5,704] Compiling tensorflow/compiler/xla/service/cpu/runtime_matmul.cc; 132s local ... (8 actions, 7 running)
[1,721 / 5,704] Compiling tensorflow/compiler/xla/service/cpu/runtime_matmul.cc; 160s local ... (8 actions, 7 running)
[2,106 / 6,584] Compiling tensorflow/compiler/xla/service/cpu/runtime_matmul.cc; 192s local ... (8 actions, 7 running)
[2,342 / 7,067] Compiling tensorflow/compiler/xla/service/cpu/runtime_matmul.cc; 231s local ... (8 actions, 7 running)
[2,378 / 7,067] Compiling tensorflow/compiler/xla/service/cpu/runtime_matmul.cc; 274s local ... (8 actions, 7 running)
[2,422 / 7,067] Compiling src/cpu/rnn/ref_rnn.cpp; 75s local ... (8 actions, 7 running)
[2,452 / 7,067] Compiling src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_at_kern_part1_autogen.cpp; 54s local ... (8 actions, 7 running)
[2,500 / 7,067] Compiling src/cpu/x64/gemm/f32/jit_avx512_core_f32_copy_at_kern_part1_autogen.cpp; 119s local ... (8 actions, 7 running)
[2,577 / 7,067] Compiling src/common/memory_zero_pad.cpp; 74s local ... (8 actions, 7 running)
ERROR: /private/var/tmp/_bazel_a/f5e9a3325f07a1f02c52d821857db47c/external/org_tensorflow/tensorflow/compiler/xla/BUILD:69:17: ProtoCompile external/org_tensorflow/tensorflow/compiler/xla/xla.pb.h failed: (Illegal instruction): protoc failed: error executing command 
  (cd /private/var/tmp/_bazel_a/f5e9a3325f07a1f02c52d821857db47c/execroot/__main__ && \
  exec env - \
    PATH=/myuserpath/opt/anaconda3/envs/mytraxenv/bin:/myuserpath/opt/anaconda3/condabin:/opt/homebrew/bin:/opt/homebrew/sbin:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin:/Library/TeX/texbin:/Library/Apple/usr/bin \
  bazel-out/darwin-opt/bin/external/com_google_protobuf/protoc '--cpp_out=bazel-out/darwin-opt/bin/external/org_tensorflow' -Iexternal/org_tensorflow -Ibazel-out/darwin-opt/bin/external/org_tensorflow -Ibazel-out/darwin-opt/bin/external/com_google_protobuf/_virtual_imports/any_proto -Ibazel-out/darwin-opt/bin/external/com_google_protobuf/_virtual_imports/source_context_proto -Ibazel-out/darwin-opt/bin/external/com_google_protobuf/_virtual_imports/type_proto -Ibazel-out/darwin-opt/bin/external/com_google_protobuf/_virtual_imports/api_proto -Ibazel-out/darwin-opt/bin/external/com_google_protobuf/_virtual_imports/descriptor_proto -Ibazel-out/darwin-opt/bin/external/com_google_protobuf/_virtual_imports/compiler_plugin_proto -Ibazel-out/darwin-opt/bin/external/com_google_protobuf/_virtual_imports/duration_proto -Ibazel-out/darwin-opt/bin/external/com_google_protobuf/_virtual_imports/empty_proto -Ibazel-out/darwin-opt/bin/external/com_google_protobuf/_virtual_imports/field_mask_proto -Ibazel-out/darwin-opt/bin/external/com_google_protobuf/_virtual_imports/struct_proto -Ibazel-out/darwin-opt/bin/external/com_google_protobuf/_virtual_imports/timestamp_proto -Ibazel-out/darwin-opt/bin/external/com_google_protobuf/_virtual_imports/wrappers_proto external/org_tensorflow/tensorflow/compiler/xla/xla.proto)
Execution platform: @local_execution_config_platform//:platform
Target //build:build_wheel failed to build
INFO: Elapsed time: 631.189s, Critical Path: 283.35s
INFO: 2563 processes: 935 internal, 1628 local.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully
b''
Traceback (most recent call last):
  File "/mypath/jax/build/build.py", line 524, in <module>
    main()
  File "/mypath/jax/build/build.py", line 519, in main
    shell(command)
  File "/mypath/jax/build/build.py", line 53, in shell
    output = subprocess.check_output(cmd)
  File "/myuserpath/opt/anaconda3/envs/mytraxenv/lib/python3.9/subprocess.py", line 424, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/myuserpath/opt/anaconda3/envs/mytraxenv/lib/python3.9/subprocess.py", line 528, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['./bazel-4.2.1-darwin-x86_64', 'run', '--verbose_failures=true', '--config=avx_posix', '--config=mkl_open_source_only', ':build_wheel', '--', '--output_path=/mypath/jax/dist', '--cpu=x86_64']' returned non-zero exit status 1.

I switched away from trying to build jaxlib. In late October 2021, an M1-compatible jaxlib wheel was released, so I tried:

pip install -U pip
pip install -U https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl

but got

ERROR: jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl is not a supported wheel on this platform.

Tried upgrading python from 3.9 to 3.10, but got the same message.
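The wheel's filename encodes which interpreters it targets, so a rough check of the tags can show why pip rejects it. The sketch below is a deliberate simplification: pip's real compatibility logic is more involved (for example, it also accepts generic py3 and universal2 tags), and the helper names are made up for illustration.

```python
def parse_wheel_tags(filename):
    """Split a wheel filename into (python_tag, abi_tag, platform_tag)."""
    stem = filename[:-len(".whl")] if filename.endswith(".whl") else filename
    parts = stem.split("-")
    if len(parts) < 5:
        raise ValueError(f"not a wheel filename: {filename!r}")
    python_tag, abi_tag, platform_tag = parts[-3:]
    return python_tag, abi_tag, platform_tag


def wheel_mismatches(filename, python_version, machine):
    """Return a list of reasons the wheel does not fit this interpreter.

    Simplified check: compares only the CPython tag and the CPU architecture.
    """
    python_tag, _abi_tag, platform_tag = parse_wheel_tags(filename)
    expected_tag = f"cp{python_version[0]}{python_version[1]}"
    problems = []
    if python_tag != expected_tag:
        problems.append(f"wheel is for {python_tag}, interpreter is {expected_tag}")
    if not platform_tag.endswith(machine):
        problems.append(f"wheel is for {platform_tag}, interpreter runs as {machine}")
    return problems


if __name__ == "__main__":
    import platform
    import sys

    wheel = "jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl"
    for reason in wheel_mismatches(wheel, sys.version_info[:2], platform.machine()):
        print(reason)
```

With an x86_64 interpreter, the architecture mismatch is reported for both Python 3.9 and 3.10, which is consistent with the same error appearing after the upgrade.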

I debugged by running the following in Python:

import platform
print(platform.machine())

which showed that my Python is still running on x86 architecture, since "Anaconda doesn't yet provide packages for M1/ARM".
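A slightly fuller version of that check is sketched below. It relies on two assumptions not stated in the post: that an x86_64 Python on an M1 runs under Rosetta 2 translation, and that macOS reports this through the sysctl key sysctl.proc_translated (1 when translated, 0 when native, missing on Intel Macs). The function names are hypothetical.

```python
import platform
import subprocess
import sys


def running_under_rosetta():
    """Return True or False on macOS, or None when it cannot be determined."""
    if sys.platform != "darwin":
        return None
    try:
        result = subprocess.run(
            ["sysctl", "-n", "sysctl.proc_translated"],
            capture_output=True,
            text=True,
            check=False,
        )
    except OSError:
        return None
    value = result.stdout.strip()
    if value == "1":
        return True
    if value == "0":
        return False
    return None  # key not available, e.g. on an Intel Mac


def describe_interpreter():
    """Collect the facts that decide which wheels this interpreter can use."""
    return {
        "executable": sys.executable,
        "machine": platform.machine(),
        "rosetta": running_under_rosetta(),
    }


if __name__ == "__main__":
    for key, value in describe_interpreter().items():
        print(f"{key}: {value}")
```

Printing the executable path alongside the architecture makes it easier to tell which conda installation a given shell is actually using.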

[see successful (so far) jaxlib installation in own answer below]

back to trax

After successfully installing jaxlib and jax, when I try to install trax with (miniforge's) conda install trax, I get:

Collecting package metadata (current_repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Collecting package metadata (repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.

PackagesNotFoundError: The following packages are not available from current channels:

  - trax

Current channels:

  - https://conda.anaconda.org/conda-forge/osx-arm64
  - https://conda.anaconda.org/conda-forge/noarch

To search for alternate channels that may provide the conda package you're
looking for, navigate to

    https://anaconda.org

and use the search bar at the top of the page.

I tried installing as trax itself suggests but ran into an error building h5py:

ERROR: Failed building wheel for h5py
ERROR: Could not build wheels for h5py, which is required to install pyproject.toml-based projects

Looking now into this error and also looking into doing it with miniforge again but using Anaconda's channel to install trax. I'll update when I get to it.

Any pointers? Has anyone managed to install trax on an M1 Mac by any other means?

Batruk answered 22/1, 2022 at 18:29 Comment(1)
Long question, I know. I wanted to include some of the unsuccessful outputs for wider searchability. - Batruk

jaxlib + jax

After someone claimed success using miniforge, I read this and watched this to clarify using Anaconda and miniforge together.

I installed miniforge with Apple's arm64 : Apple Silicon method. For some reason when I ran miniforge's conda init it set up the initialization code in ~/.bash_profile even though I'm using the zsh shell. I tried putting the code manually instead in ~/.zprofile but it wouldn't load on interactive shells, so I just ended up putting it where Anaconda had put its initialization code, in ~/.zshrc.

This made miniforge the default manager. Following the very useful video above, I created a ~/.start_anaconda.sh script so I can use Anaconda as an alternative.

With miniforge I

  • created a new conda environment mytraxenv with conda create -n mytraxenv python=3 which has python 3.10.2 at the moment

  • activated the environment: conda activate mytraxenv

  • ran conda install numpy and conda install six to ensure numpy, six and wheel (installed by one of the previous two) were installed in my mytraxenv environment

  • tried again, with a slightly updated release (from here):

    pip install -U pip
    pip install -U https://storage.googleapis.com/jax-releases/mac/jaxlib-0.1.75-cp310-none-macosx_11_0_arm64.whl

This worked in installing jaxlib!
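The two wheel URLs quoted in this thread (0.1.74 for cp39 and 0.1.75 for cp310) follow the same naming pattern, so the URL for the running interpreter can be assembled as in the sketch below. The pattern is inferred only from those two URLs and is an assumption: other jaxlib versions or Python versions may not have a matching file. The function name is hypothetical.

```python
import sys

# Base location taken from the URLs quoted in this thread.
JAX_RELEASES_MAC = "https://storage.googleapis.com/jax-releases/mac"


def jaxlib_wheel_url(jaxlib_version, python_version=None):
    """Build the arm64 macOS wheel URL for a jaxlib version and a (major, minor) Python."""
    if python_version is None:
        python_version = sys.version_info[:2]
    major, minor = python_version[:2]
    filename = f"jaxlib-{jaxlib_version}-cp{major}{minor}-none-macosx_11_0_arm64.whl"
    return f"{JAX_RELEASES_MAC}/{filename}"


if __name__ == "__main__":
    # Print the URL that would match the interpreter running this script.
    print(jaxlib_wheel_url("0.1.75"))
```

The printed URL can then be passed to pip install -U, as in the step above.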

Then, I followed these instructions to install jax:

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

That worked as well. Note that when running import jax in python it currently warns:

/mytraxenv/lib/python3.10/site-packages/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "

trax

No success yet installing trax.

Batruk answered 22/1, 2022 at 18:29 Comment(4)
Based on my personal experience: try to avoid Anaconda as much as you can. Even when there is no ARM vs x86 situation going on, you might get mixed results with it. If you use a terminal that supports ARM, you can brew install Python and you get an ARM version. - Dre
Then you can create a venv and try to install trax. - Dre
Drinking my own coolaid got me this much: ERROR: Command errored out with exit status 1: /Users/venv/bin/python3.10 -u -c 'import io, os, sys, setuptools, tokenize; sys.argv[0] = '"'"'/pip-install-ssq_tw2w/numpy_3385396ca56546caac7bf645a4669622/setup.py'"'"'; __file__='"'"'/private/var/folders/b2/gm3cw9856h3d0jsm36_fldr40000gn/T/pip-install-ssq_tw2w/numpy_3385396ca56546caac7bf645a4669622/setup.py'"'"';f = getattr(tokenize, '"'"'open'"'"', open)(__file__) if os.path.exists(__file__) else io.StringIO('"'"'from setuptools import setup; setup()'"'"'); - Dre
A quick note to say that I succeeded in what seems to be a simpler way, for jax and jaxlib only (not trax), via direct installs from conda-forge. https://mcmap.net/q/1270069/-importing-jax-fails-on-mac-with-m1-chip - Acantho

[Update] As of Nov 15, 2023, a better solution to installing jax (not trax) is perhaps described here with:

python -m pip install jax-metal

Followed by verification using:

python -c 'import jax; print(jax.numpy.arange(10))'

Antipater answered 15/11, 2023 at 18:16 Comment(0)

Anaconda Source build

As of Monday, 7 February 2022, installing with an Anaconda installation of Python fails because Anaconda does not provide arm64 builds of Python.

When they do provide them, building from source should work automatically. If it does not, clone the repository, and use the following:

python build/build.py --target_cpu darwin_arm64  

If bazel does not build and you are using Homebrew, consider installing bazel with Homebrew and removing the downloaded bazel files:

brew install bazel
rm -rf bazel-*

Arm64 Python wheels

I recommend installing an arm64 build of Python from the official CPython website. The executable will be at /usr/local/bin/python3, and pip3 will be in the same folder.

This should allow you to install jax using

/usr/local/bin/pip3 install --upgrade 'jax[cpu]'

To make this Python the one found on the command line, add its folder to the PATH:

export PATH=/usr/local/bin:$PATH
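To confirm which python3 a shell would pick up after a PATH change, something like the following sketch can be used. It uses shutil.which from the standard library. The helper names are hypothetical, and /usr/local/bin is the installer location mentioned above.

```python
import os
import shutil


def resolve_on_path(command, path_value=None):
    """Return the full path `command` resolves to on `path_value` (default: $PATH), or None."""
    return shutil.which(command, path=path_value)


def prepend_to_path(directory, path_value):
    """Return `path_value` with `directory` moved to the front, without duplicates."""
    entries = [e for e in path_value.split(os.pathsep) if e and e != directory]
    return os.pathsep.join([directory] + entries)


if __name__ == "__main__":
    current = os.environ.get("PATH", "")
    updated = prepend_to_path("/usr/local/bin", current)
    print("before:", resolve_on_path("python3", current))
    print("after: ", resolve_on_path("python3", updated))
```

If the "after" line still points into a conda directory, the shell's startup files are probably overriding PATH later on.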

Virtual Env

I recommend against the above method. Instead, use virtualenv from https://virtualenv.pypa.io to create virtual environments as with Anaconda, and install jax in a virtual environment as above.

Poetry

Poetry is an even better option, since it uses virtualenv and allows you to build your projects and list dependencies just as Anaconda does.

One minor caveat at the time of writing: declaring scipy as a dependency with Poetry results in an odd error. Installing numpy and scipy manually gets around it:

poetry shell # to go to virtual environment for the project
pip install --upgrade 'jax[cpu]' jaxlib scipy numpy

Trax

Since Anaconda is not an option, installing hdf5 is a bit of a hassle, but it can be done with Homebrew. Assuming you have activated the correct venv:

brew install hdf5
export HDF5_DIR=/opt/homebrew/Cellar/hdf5/1.12.1 
pip install trax

Unfortunately, importing trax then fails with the following error:

import trax
/Users/quinn/Library/Caches/pypoetry/virtualenvs/no-xNvLksOy-py3.9/lib/python3.9/site-packages/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "


***************************************************************
Failed to import TensorFlow. Please note that TensorFlow is not installed by default when you install TFDS. This allow you to choose to install either `tf-nightly` or `tensorflow`. Please install the most recent version of TensorFlow, by following instructions at https://tensorflow.org/install.
***************************************************************


Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/quinn/Library/Caches/pypoetry/virtualenvs/no-xNvLksOy-py3.9/lib/python3.9/site-packages/trax/__init__.py", line 18, in <module>
    from trax import layers
  File "/Users/quinn/Library/Caches/pypoetry/virtualenvs/no-xNvLksOy-py3.9/lib/python3.9/site-packages/trax/layers/__init__.py", line 23, in <module>
    from trax.layers.activation_fns import *
  File "/Users/quinn/Library/Caches/pypoetry/virtualenvs/no-xNvLksOy-py3.9/lib/python3.9/site-packages/trax/layers/activation_fns.py", line 26, in <module>
    from trax import math
  File "/Users/quinn/Library/Caches/pypoetry/virtualenvs/no-xNvLksOy-py3.9/lib/python3.9/site-packages/trax/math/__init__.py", line 18, in <module>
    from trax.math import jax as jax_math
  File "/Users/quinn/Library/Caches/pypoetry/virtualenvs/no-xNvLksOy-py3.9/lib/python3.9/site-packages/trax/math/jax.py", line 28, in <module>
    import tensorflow_datasets as tfds
  File "/Users/quinn/Library/Caches/pypoetry/virtualenvs/no-xNvLksOy-py3.9/lib/python3.9/site-packages/tensorflow_datasets/__init__.py", line 43, in <module>
    from tensorflow_datasets.core import tf_compat
  File "/Users/quinn/Library/Caches/pypoetry/virtualenvs/no-xNvLksOy-py3.9/lib/python3.9/site-packages/tensorflow_datasets/core/__init__.py", line 22, in <module>
    tf_compat.ensure_tf_install()
  File "/Users/quinn/Library/Caches/pypoetry/virtualenvs/no-xNvLksOy-py3.9/lib/python3.9/site-packages/tensorflow_datasets/core/tf_compat.py", line 48, in ensure_tf_install
    import tensorflow as tf  # pylint: disable=import-outside-toplevel
ModuleNotFoundError: No module named 'tensorflow'
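The traceback shows trax importing tensorflow_datasets, which in turn requires tensorflow. A quick way to list which links in that chain are missing, without importing anything heavy, is sketched below using importlib.util.find_spec. The list of module names is taken from the traceback and the earlier steps; the helper name is hypothetical.

```python
import importlib.util

# Top-level modules that appear in the traceback above or in the earlier steps.
TRAX_IMPORT_CHAIN = ["jaxlib", "jax", "tensorflow", "tensorflow_datasets", "trax"]


def missing_modules(names):
    """Return the top-level module names from `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(TRAX_IMPORT_CHAIN)
    if missing:
        print("Not installed in this environment:", ", ".join(missing))
    else:
        print("All modules in the chain were found")
```

Note that find_spec only checks that a module can be located. It does not prove the module imports cleanly on this architecture.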
Metope answered 7/2, 2022 at 17:38 Comment(2)
Thank you. jax is not an issue now, it's installing trax that's still unsolved! - Batruk
If you can import tensorflow, trax should be installable. I updated the solution. - Metope
