Importing jax fails on Mac with M1 chip

With Python 3.8.8 on the new MacBook Air (M1 chip), import jax raises the following error, both in Jupyter notebooks and in the Python terminal:

Python 3.8.8 (default, Apr 13 2021, 12:59:45)
[Clang 10.0.0 ] :: Anaconda, Inc. on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/steve/Documents/code/jax/jax/__init__.py", line 37, in <module>
    from . import config as _config_module
  File "/Users/steve/Documents/code/jax/jax/config.py", line 18, in <module>
    from jax._src.config import config
  File "/Users/steve/Documents/code/jax/jax/_src/config.py", line 26, in <module>
    from jax import lib
  File "/Users/steve/Documents/code/jax/jax/lib/__init__.py", line 63, in <module>
    cpu_feature_guard.check_cpu_features()
RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.

I suspect this is caused by the M1 chip.

I first installed jax with pip install jax. Then, as the error message suggests, I built jaxlib from source by cloning the repository and following the instructions given here, but the same error still appears.
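
One plausible explanation, offered as an assumption rather than a confirmed diagnosis: the banner above shows an Anaconda build of Python. If that interpreter is an x86_64 build, it runs on the M1 under Rosetta 2 translation, which (at the time) did not support AVX instructions, so jaxlib's CPU feature check fails regardless of the chip's own capabilities. A quick way to see which architecture the interpreter reports is platform.machine(): a native Apple-silicon Python reports arm64, while an x86_64 Python running under Rosetta reports x86_64. The helper name describe_arch is hypothetical.

```python
import platform
import sys


def describe_arch(system: str, machine: str) -> str:
    """Classify an interpreter from platform.system() and platform.machine()."""
    if system != "Darwin":
        return f"not macOS ({system}/{machine})"
    if machine == "arm64":
        return "native Apple silicon (arm64) interpreter"
    if machine == "x86_64":
        # On an M1 Mac this usually means the interpreter runs under Rosetta 2,
        # which is where the AVX check in jaxlib can fail.
        return "x86_64 interpreter (Intel Mac, or Rosetta 2 on Apple silicon)"
    return f"unrecognized macOS architecture: {machine}"


if __name__ == "__main__":
    print(sys.version)
    print(describe_arch(platform.system(), platform.machine()))
```

If this prints the x86_64 line on an M1 machine, installing a native arm64 Python is the more likely fix than rebuilding jaxlib under the same interpreter.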

Cunctation answered 10/7, 2021 at 12:50 Comment(0)

I had a similar problem. I already had Anaconda installed and didn't want to clutter my machine with separate Anaconda, Miniconda, and Homebrew installations of Python and their package managers, so I looked for a simple solution. What worked for me was uninstalling jax and jaxlib and then installing both directly from conda-forge:

pip uninstall jax jaxlib
conda install -c conda-forge jaxlib
conda install -c conda-forge jax
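
After reinstalling, a small smoke test can confirm that jaxlib actually loads. This is a minimal sketch: the AVX check shown in the question's traceback raises RuntimeError during import jax, so the function catches that (along with ImportError) and reports it instead of crashing. The function name jax_smoke_test is hypothetical.

```python
def jax_smoke_test() -> str:
    """Try importing JAX and running one tiny computation; return a status string."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError as exc:
        return f"jax is not installed: {exc}"
    except RuntimeError as exc:
        # e.g. the jaxlib AVX feature check from the question's traceback
        return f"jaxlib failed to load: {exc}"
    total = int(jnp.arange(4).sum())  # 0 + 1 + 2 + 3
    return f"jax {jax.__version__} OK on {jax.devices()[0].platform}, sum={total}"


if __name__ == "__main__":
    print(jax_smoke_test())
```

On a working install the printed line ends with sum=6.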
Caryloncaryn answered 8/2, 2022 at 17:50 Comment(4)
Not the point of the question, but are you able to use it with Trax, or were you only looking to use JAX? - Sybilla
I was only looking at JAX; I hadn't looked at Trax. - Caryloncaryn
@NicholasGReich, I was just looking for an answer to this exact question. Small world! - Inhibitory
Is there a non-conda solution? I want to install into an environment created with virtualenv (virtualenv.pypa.io/en/latest). - Blink

Thanks @jakevdp, I looked at the issue you linked and found a workaround:

Thanks to Noah, who mentioned in issue #5501 that you can simply use an earlier version of jax and jaxlib. For my purposes, jaxlib==0.1.60 and jax==0.2.10 work just fine!
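
Pinning exactly these versions would be pip install jaxlib==0.1.60 jax==0.2.10. Because a later upgrade can silently replace them, the sketch below checks whether the environment actually has the pinned versions. It assumes Python 3.8+ (for importlib.metadata); the helper names installed_version and check_pins are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions reported to work in this answer; adjust for your own setup.
PINS = {"jax": "0.2.10", "jaxlib": "0.1.60"}


def installed_version(name: str):
    """Return the installed distribution version, or None if it is absent."""
    try:
        return version(name)
    except PackageNotFoundError:
        return None


def check_pins(pins: dict) -> dict:
    """Map each package name to (installed, wanted, matches)."""
    report = {}
    for name, wanted in pins.items():
        found = installed_version(name)
        report[name] = (found, wanted, found == wanted)
    return report


if __name__ == "__main__":
    for name, (found, wanted, ok) in check_pins(PINS).items():
        print(f"{name}: installed={found} wanted={wanted} {'OK' if ok else 'MISMATCH'}")
```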

Cunctation answered 12/7, 2021 at 10:18 Comment(2)
JAX is now available for M1 Macs; it runs fine for me. - Aesop
True, @emil. Confirmation here: github.com/google/jax/issues/5501#issuecomment-955590288. It's unclear when the next release is, but at least it can now be installed somewhat officially. - Sybilla

JAX does not yet provide pre-built jaxlib wheels that are compatible with M1 chips. The best source of information I know of on building jaxlib for M1 is probably this GitHub issue: https://github.com/google/jax/issues/5501, which also tracks progress on improving this support.

Hopefully M1 support will be improved in the near future, but it's taking a while for the scientific computing infrastructure up and down the stack to catch up with the requirements of the new chips.
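
Whether a pre-built wheel applies depends on the platform of the interpreter pip runs under. As a rough sketch (a simplification: pip really accepts a list of compatible tags, and universal2 interpreters are not handled here), sysconfig.get_platform() returns a string such as macosx-11.0-arm64 from which a wheel platform tag like macosx_11_0_arm64 is derived. A native arm64 Python needs arm64 wheels, while an x86_64 Python selects x86_64 wheels, which then run under Rosetta. The helper names are hypothetical.

```python
import sysconfig


def wheel_platform_tag(platform_string: str) -> str:
    """Approximate a wheel platform tag from a sysconfig platform string."""
    # e.g. "macosx-11.0-arm64" -> "macosx_11_0_arm64"
    return platform_string.replace("-", "_").replace(".", "_")


def needs_arm64_wheels(tag: str) -> bool:
    """True if the tag names an Apple-silicon (arm64) macOS platform."""
    return tag.startswith("macosx_") and tag.endswith("_arm64")


if __name__ == "__main__":
    tag = wheel_platform_tag(sysconfig.get_platform())
    print("platform tag:", tag)
    print("needs arm64 wheels:", needs_arm64_wheels(tag))
```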

Delp answered 10/7, 2021 at 14:50 Comment(0)

As of January 2022, jax is available for M1 Macs. First uninstall jax and jaxlib, then install the new packages via pip:

pip install --upgrade jax jaxlib

Afterwards, you can use jax without problems.

Edit: I am running on a machine with the following specs:

ProductName:    macOS
ProductVersion: 12.1
BuildVersion:   21C52

and with Python 3.9.6 within a conda environment.
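
The same details (macOS version, architecture, Python version, conda or not) can be collected from inside Python, which helps when comparing setups as the comments on this answer do. A minimal standard-library sketch; platform.mac_ver() returns empty strings on non-macOS systems, and the conda check is a heuristic that looks for a conda-meta directory under sys.prefix. The function name environment_report is hypothetical.

```python
import os
import platform
import sys


def environment_report() -> dict:
    """Collect the setup details people usually ask for in JAX install threads."""
    mac_release, _version_info, mac_machine = platform.mac_ver()
    return {
        # e.g. "12.1" on macOS Monterey; empty off macOS, so fall back to "n/a"
        "macos": mac_release or "n/a",
        "machine": mac_machine or platform.machine(),
        "python": platform.python_version(),
        # conda environments contain a conda-meta directory under sys.prefix
        "in_conda_env": os.path.isdir(os.path.join(sys.prefix, "conda-meta")),
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```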

Aesop answered 12/1, 2022 at 8:7 Comment(5)
What macOS and Python versions are you using? - Urticaria
@Blade, I have updated the answer for clarity. - Aesop
Up to a point. I still get this error after uninstalling both jax and jaxlib, reinstalling with pip install --upgrade jax jaxlib, and then trying to use it as part of Trax: RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source. Same OS version and build, Python 3.9.7 in a conda environment. - Sybilla
@Sybilla, see the answer I just posted. I fixed this on my machine by installing directly from conda-forge. - Caryloncaryn
Thank you! I think I did something similar: stackoverflow.com/a/70815865 - Sybilla