Gemma 3の埋め込み

埋め込みの続編。Gemma 3の埋め込み空間で king - man + woman をやってみよう。

MacなのでMLXを使うことにする。モデルとトークナイザを取り出すのは簡単。例えば queen という語についてやってみよう。

model, tokenizer = load("mlx-community/gemma-3-27b-it-8bit")

prompt = "queen"
token_ids = tokenizer.encode(prompt)  # [2, 91024]
tokens = tokenizer.convert_ids_to_tokens(token_ids)  # ['', 'queen']

トークン列を埋め込みベクトルに変換するための関数を探すために、Pythonに model と打ち込んでみる:

Model(
  (language_model): Model(
    (model): Gemma3Model(
      (embed_tokens): QuantizedEmbedding(262208, 5376, group_size=64, bits=8, mode=affine)
      (layers.0): TransformerBlock(
...以下略...

これから model.language_model.model.embed_tokens だと推量する:

embeddings = model.language_model.model.embed_tokens(token_ids)
queen = embeddings[1]  # 5376次元のベクトル
mx.linalg.norm(queen)  # array(1.01562, dtype=bfloat16)
# mx.save("queen.npy", queen)  # 念のため保存したいとき

同じことを man、woman、king についても行う。

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

X = np.vstack([man.astype(mx.float32),
               woman.astype(mx.float32),
               king.astype(mx.float32),
               queen.astype(mx.float32),
               (king - man + woman).astype(mx.float32)])
cosine_similarity(X)

結果は次の通り:

array([[ 1.000002  ,  0.27174592,  0.21292192,  0.10807453, -0.35068962],
       [ 0.27174592,  1.0000023 ,  0.10984613,  0.18116248,  0.55315906],
       [ 0.21292192,  0.10984613,  1.0000019 ,  0.32620484,  0.5962215 ],
       [ 0.10807453,  0.18116248,  0.32620484,  1.0000018 ,  0.26477924],
       [-0.35068962,  0.55315906,  0.5962215 ,  0.26477924,  1.000003  ]],
      dtype=float32)

やはり king - man + woman は queen より king に近い。