MLX は Apple Silicon 専用のライブラリです(MLXではじめる機械学習 が参考になります)。MLX LM は MLS で LLM をするパッケージです。Vision Language Model 用の MLX-VLM というパッケージもあります。
MLX LM で Gemma 3 27B 8bit を使う例:
from mlx_lm import load, generate from mlx_lm.sample_utils import make_sampler model, tokenizer = load("mlx-community/gemma-3-27b-it-8bit") sampler = make_sampler(temp=0) # set temperature etc. prompt = ''' プロンプト ''' response = generate(model, tokenizer, prompt=prompt.strip(), max_tokens=1024, verbose=True, sampler=sampler)
上位3位までのトークンの確率分布を出力してみます:
import mlx.core as mx from mlx_lm import load, generate model, tokenizer = load("mlx-community/gemma-3-27b-it-8bit") def sampler(x): a = mx.argsort(-x, axis=-1)[0] s = mx.sum(mx.exp(x[0])) for t in a[:3]: i = int(t) print(i, float(mx.exp(x[0][i])/s), repr(tokenizer.decode(i))) return mx.argmax(x, axis=-1) response = generate(model, tokenizer, prompt="How many 'r's in strawberry? Answer with numbers only.", max_tokens=3, sampler=sampler, verbose=False) print(repr(response))
結果は次のようになり、「3」がほぼ確率1で出力されることがわかります。
107 0.73046875 '\n' 108 0.2373046875 '\n\n' 236743 0.0172119140625 ' ' 236800 1.0 '3' 236812 0.0006256103515625 '4' 236778 0.0002307891845703125 '2' 107 0.984375 '\n' 108 0.01239013671875 '\n\n' 236743 0.00017642974853515625 ' ' 106 0.8515625 '<end_of_turn>' 818 0.037353515625 'The' 34318 0.029052734375 'Correct' '\n3\n'