JAX/Flaxで学ぶディープラーニングの仕組み

Compass Data Science

JAX/Flaxで学ぶディープラーニングの仕組み

1~2日で出荷、新刊の場合、発売日以降のお届けになります

出版社
マイナビ出版
著者名
中井悦司
価格
3,179円(本体2,890円+税)
発行年月
2023年2月
判型
B5変
ISBN
9784839982324

●JAX/Flax/Optaxの特徴

JAXとFlax、およびOptaxは、米Google社のAI研究チームと米DeepMind社のエンジニアが中心となって開発しているオープンソースソフトウェアです。Googleが開発したディープラーニングのライブラリーといえばTensorFlow/Kerasが有名ですが、最近は、JAXとその周辺ライブラリーにも注目が集まっています。

JAXは、機械学習で必要となる数値計算処理をPythonのコードから高速に実行するためのライブラリーです。表面的にはNumPyとほぼ同じ使い方ができて、GPUでの実行に対応しています。

TensorFlow/KerasとJAX/Flax/Optaxを比べると、後者では裏側の仕組みが適度なレベルで見えているという点が異なります。

TensorFlow/Kerasの場合、機械学習の「定型作業」を実施する上では簡単なコードで良いものの、応用的な作業を行おうとするとTensorFlow/Kearsに固有の機能を用いた特殊なコードを書く必要があります。一方、JAX/Flax/Optaxの場合は、定型作業にもある程度のコーディングが必要な一方で、応用的な作業も通常のPythonプログラミングの感覚で行えます。応用的な作業が中心となる、研究・開発目的での利用に適したライブラリーと言えます。


● 本書の概要

本書では、ディープラーニングの代表例とも言える畳み込みニューラルネットワーク(CNN)を例として、これをJAX/Flax/Optaxで実装しながら、モデルの各パーツの役割を数式レベルで丁寧に解説していきます。

この際、モデル内部の処理の様子を確認するために、モデルの中身を分析するコードもあわせて利用します。JAX/Flax/Optaxを利用すれば、モデルの構築だけでなく、このような分析作業も簡単に実施できることが実感できるでしょう。

導入となる第1章では、JAX/Flax/Optaxの基本的な機能とその使い方を学ぶために、機械学習の基礎とも言える「最小二乗法」による回帰問題を利用します。まずは、JAXの機能だけを利用して、勾配降下法のアルゴリズムを独自に実装して、回帰モデルの学習を行います。その後、これと同等の処理をFlax/Optaxを組み合わせて、再度、実装してみます。これにより、Flax/Optaxの使い方に加えて、JAXの微分機能など、その背後で行われる実際の処理内容をより明確に理解することができるでしょう。

第2章以降では、より本格的な畳み込みニューラルネットワークを構築し、さらに、転移学習やDCGANによる画像生成モデルなども実装します。付録として、本書で使用するJAX/Flax/Optaxの主な関数の一覧も用意。JAX/Flax/Optaxの使い方をリファレンス的に知っておきたい方にもおすすめです。

お気に入りカテゴリ

よく利用するジャンルを設定できます。

≫ 設定

カテゴリ

「+」ボタンからジャンル(検索条件)を絞って検索してください。
表示の並び替えができます。

page top