レコメンド結果にMMRを適用して多様性を考慮したかった。
以下のようにPythonで実装されているコードはあったが、Pysparkで実装れているサンプルがなかったので実装してみた。
from pyspark.sql import DataFrame as SDF
from typing import Set, Callable, List
import pyspark.sql.functions as F
def sim_func(df: SDF, item_id: int, rec_item_id: int) -> float:
min_score = df.filter(F.col('item_id') == item_id).groupBy().min('score').collect()[0][0]
try:
score = df.filter((F.col('item_id') == item_id) & (F.col('rec_item_id') == rec_item_id)).collect()[0]['score']
except:
score = min_score
return score
def mmr(df: SDF, items: Set[int], item_id: int, lambda_: float, sim_func1: Callable[[SDF, int, int], float], sim_func2: Callable[[SDF, int, int]) -> List[int]:
def _argmax(keys, f):
return max(keys, key=f)
selected = []
while set(selected) != items:
remaining = items - set(selected)
mmr_score = lambda x: labmda_ * sim_func1(df, item_id, x) - (1 - lambda_) * max([sim_func2(df, x, y) for y in set(selected)-{x}] or [0])
next_selected = _argmax(remaining, mmr_score)
selected.append(next_selected)
return selected
mmr(
df,
set(list(df.filter(F.col('item_id') == 12345678).select('rec_item_id').toPandas()['rec_item_id'])),
12345678,
0.8,
sim_func,
sim_func
)
Crieitは誰でも投稿できるサービスです。 是非記事の投稿をお願いします。どんな軽い内容でも投稿できます。
また、「こんな記事が読みたいけど見つからない!」という方は是非記事投稿リクエストボードへ!
こじんまりと作業ログやメモ、進捗を書き残しておきたい方はボード機能をご利用ください。
ボードとは?
コメント