・文書dがP(d)で選ばれる
・潜在変数zがP(z|d)で選ばれる
・語wがP(w|z)で生成される
というプロセスを経て、結果として(d,w)のペアが観測されるという文書と語の生成モデル。
式で表すと
となる。P(d,w)の尤もらしい確率分布を見つけたい。対数尤度関数は(1)
となる。n(d,w)は語wが文書dに出現する回数。この式は訓練データn(d,w)(;どの語がどの文書に何回出現したか)が尤もらしい確率分布P(d,w)に従うとき最大になる。ベイズの定理を用いると(2)
となることを利用して、この尤度関数を最大化するためにEMアルゴリズムを用いて実装してみる。(過学習を回避するために文献ではTempered EM (TEM)を用いている。)尤度関数が収束するまで以下のE-stepとM-stepを繰り返す。(3)
E-step: 現在推定されているパラメータの分布に基づく潜在変数の条件付き確率の計算
M-step: 尤度の期待値を最大化するためのパラメータの計算(4)
細かい導出は下の方に。(5)
(6)
(7)
各確率はランダムな値で初期化した。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
import math | |
import random | |
class plsa(): | |
def __init__(self,n,nz=2): | |
""" n:DW行列 nz:潜在変数の数""" | |
self.n = n | |
self.nz = nz | |
# num of the docs | |
self.nd = len(n) | |
# num of the words | |
self.nw = len(n[0]) | |
# initialize the probability | |
self.pdwz = self.arr([ self.nd, self.nw, self.nz ]) # P(z|d,w) | |
self.pzw = self.arr([ self.nz, self.nw ]) # P(w|z) | |
self.pzd = self.arr([ self.nz, self.nd ]) # P(d|z) | |
self.pz = self.arr([ self.nz ]) # P(z) | |
self.pdw = self.arr([ self.nd, self.nw ]) # P(d,w) | |
def train(self,k=1000): | |
# 収束するまで繰り返す | |
tmp = 0 | |
for i in range(k): | |
self.e_step() | |
self.m_step() | |
L = self.likelihood() | |
# 収束したか | |
if abs(L - tmp) < 1.0e-10: | |
break | |
else: | |
tmp = L | |
@staticmethod | |
def normalized(list): | |
""" 合計が1になるように正規化する """ | |
total = sum(list) | |
return map(lambda a: a/total, list) | |
@staticmethod | |
def arr(list): | |
""" 多次元配列の生成する | |
list=[M,N...] としたら M x N x... で要素がランダムの配列を返す """ | |
if len(list) > 1: | |
return [ plsa.arr(list[1:]) for i in range(list[0]) ] | |
else: | |
return plsa.normalized([ random.random() for i in range(list[0]) ]) | |
def likelihood(self): | |
""" log-liklihood """ | |
# P(d,w) | |
for d in range(self.nd): | |
for w in range(self.nw): | |
self.pdw[d][w] = sum([ self.pz[z]*self.pzd[z][d]*self.pzw[z][w] for z in range(self.nz) ]) | |
# Σ n(d,w) log P(d,w) | |
return sum([ self.n[d][w]*math.log(self.pdw[d][w]) for d in range(self.nd) for w in range(self.nw) ]) | |
def e_step(self): | |
""" E-step """ | |
# P(z|d,w) | |
for d in range(self.nd): | |
for w in range(self.nw): | |
for z in range(self.nz): | |
self.pdwz[d][w][z] = self.pz[z]*self.pzd[z][d]*self.pzw[z][w] | |
self.pdwz[d][w] = self.normalized(self.pdwz[d][w]) | |
def m_step(self): | |
""" M-step """ | |
# P(w|z) | |
for z in range(self.nz): | |
for w in range(self.nw): | |
self.pzw[z][w] = sum([ self.n[d][w]*self.pdwz[d][w][z] for d in range(self.nd) ]) | |
self.pzw[z] = self.normalized(self.pzw[z]) | |
# P(d|z) | |
for z in range(self.nz): | |
for d in range(self.nd): | |
self.pzd[z][d] = sum([ self.n[d][w]*self.pdwz[d][w][z] for w in range(self.nw) ]) | |
self.pzd[z] = self.normalized(self.pzd[z]) | |
# P(z) | |
for z in range(self.nz): | |
self.pz[z] = sum([ self.n[d][w]*self.pdwz[d][w][z] for d in range(self.nd) for w in range(self.nw) ]) | |
self.pz = self.normalized(self.pz) | |
if __name__ == '__main__': | |
n = [[1,1,0,0], | |
[0,0,1,1], | |
[0,0,0,1]] | |
p = plsa(n) | |
p.train() | |
print p.pz | |
print p.pzd | |
print p.pzw | |
print p.pdwz |
桁落ちで math domain error が。要対策。
(4)~(7)の導出は以下のとおり。
(4)は単純にベイズによる変形で求められる。P(z),P(w|z),P(d|z)が入力でP(z|d,w)が出力。文書dと語wがどのクラスzに属しているか。
(5)~(7)はQ関数(対数尤度関数の期待値)を最大化するパラメータを求める。P(z|d,w)とn(d,w)が入力でP(z),P(w|z),P(d|z)が出力。
対数尤度関数(2)でlogP(d,w)は未知だが、E-stepで求めたP(z|d,w)を用いてlogP(d,w)のΣP(z|d,w)logP(d,w)と期待値を求める(ベイズの定理を用いればlogP(d,w)はzの関数)(これがEMアルゴリズムのキモ)。Q関数は、
となる。ここで(8)
という制約のもとにラグランジュ関数は(9)
(10)
(11)
となる。たとえばP(d|z)で偏微分して0になるようにパラメータを定めれば(12)
でもって、両辺にP(d|z)かけて(13)
そんで両辺dについて和を求めれば(14)
で(11)の条件から(15)
とかなってβzが求まって(14)に代入すれば(16)
で(5)になってめでたしめでたし。(17)
P(z),P(w|z)も同じように。
Thomas Hofmann, Probabilistic Latent Semantic Indexing, Proceedings of the Twenty-Second Annual International SIGIR Conference on Research and Development in Information Retrieval (SIGIR-99), 1999.
No comments:
Post a Comment