MLXでローカルLLM

MLX は Apple Silicon 専用のライブラリです(MLXではじめる機械学習 が参考になります)。MLX LM は MLS で LLM をするパッケージです。Vision Language Model 用の MLX-VLM というパッケージもあります。

MLX LM で Gemma 3 27B 8bit を使う例:

from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler

model, tokenizer = load("mlx-community/gemma-3-27b-it-8bit")
sampler = make_sampler(temp=0) # set temperature etc.

prompt = '''
プロンプト
'''

response = generate(model, tokenizer, prompt=prompt.strip(),
                    max_tokens=1024, verbose=True, sampler=sampler)

上位3位までのトークンの確率分布を出力してみます:

import mlx.core as mx
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/gemma-3-27b-it-8bit")

def sampler(x):
    a = mx.argsort(-x, axis=-1)[0]
    s = mx.sum(mx.exp(x[0]))
    for t in a[:3]:
        i = int(t)
        print(i, float(mx.exp(x[0][i])/s), repr(tokenizer.decode(i)))
    return mx.argmax(x, axis=-1)

response = generate(model,
                    tokenizer,
                    prompt="How many 'r's in strawberry? Answer with numbers only.",
                    max_tokens=3,
                    sampler=sampler,
                    verbose=False)

print(repr(response))

結果は次のようになり、「3」がほぼ確率1で出力されることがわかります。

107 0.73046875 '\n'
108 0.2373046875 '\n\n'
236743 0.0172119140625 ' '
236800 1.0 '3'
236812 0.0006256103515625 '4'
236778 0.0002307891845703125 '2'
107 0.984375 '\n'
108 0.01239013671875 '\n\n'
236743 0.00017642974853515625 ' '
106 0.8515625 '<end_of_turn>'
818 0.037353515625 'The'
34318 0.029052734375 'Correct'
'\n3\n'