관련 이것저것/IT Book 리뷰

JAX/Flax로 딥러닝 레벨업 - 제이펍 [도서리뷰]

agingcurve 2024. 9. 15. 11:51
반응형

최신 LLM 모델들을 공부하시다 보면

JAX, Flax 라이브러리를 많이 들어보셨을 거라 생각합니다.

JAX와 Flax는 아직 딥러닝 모델 시장에서 활성화 되지는 않았지만, 

시장을 잠식해가고 있는 메타플랫폼스의 파이토치에

대항마로 구글에서 최근에 밀어주고 있는 라이브러리 인데요,

JAX와 Flax는 각각 2018년, 2020년에 구글리서치에서 개발하여 사용하고 있습니다.

 

 

지은이들은 모두의 연구소 'JAX/ Flax LAB'로 구성된 멤버분들이라고 합니다.

모두의 연구소는 여러 LAB을 운영하면서,

AI에 관한 여러 LAB을 운영하는 걸로 유명한데요,

현업자들과 연구자분들이 뭉쳐서 책을 지으셨다고 합니다. 

모두의 연구소의 풀잎스쿨을 여러번 참여한 경험이 있었는데요,

다들 열정이 넘치고 적극적으로 활동해주셨던 분들이 많았다는

기억이 있어서 JAX/ Flax LAB분들이 반가웠습니다.

 

그렇다면 어떤 장점이 있길래, JAX/Flax를 앞세워 사용하고 있을까요?

 

XLA (Accelerated Linear Algebra):

고급 컴파일러로, 텐서 연산을 최적화하여 다양한 하드웨어에서 효율적으로 실행할 수 있게 합니다. XLA는 Google이 개발한 고급 컴파일러로, 텐서 연산을 위한 최적화된 코드를 생성합니다.

이 최적화 된 커널은 파이토치의 동적 그래프 방식보다 훨씬 빠른 속도로 추론과 학습이 가능합니다.

이 컴파일러를 통해 XLA는 CPU, GPU, TPU 등 다양한 하드웨어에서 실행 가능한 코드를 생성을 최적화해 하드웨어 연산을 크게 올립니다.

 

JIT 컴파일 :

데코레이터를 적용한 함수는 호출 시 XLA를 통해 최적화된 기계 코드로 컴파일됩니다. 이 과정에서 불필요한 연산이 제거되고, 연산 그래프가 단순화되어 실행 속도가 향상됩니다. 기본 파이썬의 경우, JIT을 지원하지 않기 때문에 복잡한 문제를 해결 시, 고성능의 작업을 수행하기 어렵지만, JAX에는 이것을 지원하여 빠른속도로 수행할 수 있도록 해주고 있습니다.

 

Wishper 모델을 보면, 성능표를 보면 추론에 있어서

월등한 성능을 보여줄 수 있음을 확인할 수 있습니다.

 

HuggingFace에서도 Flax community week를 만들어서 변환하고 있으며,

Google Research에서 나온 대부분의 논문 구현은 Flax로 구현되어 있습니다.

지속적인 발전을 이루고 있는, Jax 및 Flax는 충분히 매력있고,

계속 배울 필요성을 느끼게 하는 라이브러리라고 생각이 듭니다.

 

책의 내용은 간단하면서도 탄탄하게 모델에 대한 설명을 작성해주셨다고 생각이 듭니다.

 

 

 

코드 예제 또한 쉽고 가독성 있게 읽을 수 있도록

작성을 해주셨다고 느낄 수 있었습니다.

 

 

딥러닝 기초적인 모델 뿐 아니라, 최신 CLIP모델 까지

JAX/Flax로 다루는 방법에 대해 설명하고 있습니다.

 

https://github.com/JAX-KR/jax-flax-book

 

GitHub - JAX-KR/jax-flax-book

Contribute to JAX-KR/jax-flax-book development by creating an account on GitHub.

github.com

 

깃허브를 통해서 코드 예제를 살펴 볼 수도 있습니다.

 

기본적으로 실습환경은 코랩인데요,

https://github.com/JAX-KR/jax-flax-book/blob/main/requirements.txt 

 

jax-flax-book/requirements.txt at main · JAX-KR/jax-flax-book

Contribute to JAX-KR/jax-flax-book development by creating an account on GitHub.

github.com

 

실습에 필요한 라이브러리를 필수적으로 설치해 주셔야 합니다.

jax==0.4.26
flax==0.8.4
optax== 0.2.2
datasets
transformers
tokenizers

 

라이브러리가 업데이트 되면서 필수라이브러리가

바뀔수 있으니 꼭 깃허브 requirements를 참조하시기 바랍니다.

 

 

고성능 모델 개발을 원하는 딥러닝 개발자,

리서처라면

JAX/Flax로 딥러닝 레벨업으로

준비해보는 것을 권해드리고 싶습니다.