系列ラベリングなどで最適なパスを探索する方法はビタビアルゴリズムで効率的に求められる。 上位N個のパスを探索する方法はビタビアルゴリズムと、A*アルゴリズムで効率的に求められる。 日本語入力を支える技術 ~変わり続けるコンピュータと言葉の世界 (WEB+DB PRESS plus) の説明が分かりやすい。理解するために実装してみた。

package main
import (
"container/heap"
"fmt"
)
type PriorityQueue []*Node
func (this PriorityQueue) Len() int { return len(this) }
func (this PriorityQueue) Less(i, j int) bool {
return this[i].Score < this[j].Score
}
func (this PriorityQueue) Swap(i, j int) {
this[i], this[j] = this[j], this[i]
this[i].index = i
this[j].index = j
}
func (this *PriorityQueue) Push(x interface{}) {
n := len(*this)
node := x.(*Node)
node.index = n
*this = append(*this, node)
}
func (this *PriorityQueue) Pop() interface{} {
old := *this
n := len(*this)
node := old[n-1]
node.index = -1
*this = old[0 : n-1]
return node
}
func (this *PriorityQueue) IsEmpty() bool {
if len(*this) == 0 {
return true
}
return false
}
func (this *PriorityQueue) update(node *Node) {
heap.Fix(this, node.index)
}
type Node struct {
X int
Y int
// for priority queue
Score float64
index int
LPath []*Path
RPath []*Path
BestScore float64
Prev *Node
GoalScore float64
Next *Node
}
type Path struct {
LNode *Node
RNode *Node
Score float64
}
func (this *Path) Add(lnode, rnode *Node) {
this.LNode = lnode
this.RNode = rnode
rnode.LPath = append(rnode.LPath, this)
lnode.RPath = append(lnode.RPath, this)
}
func Viterbi(nodes [][]*Node, eos *Node) {
segment_size := len(nodes)
for i := 0; i < segment_size; i++ {
for j := 0; j < len((nodes)[i]); j++ {
n := nodes[i][j]
bestScore := 0.
var bestNode *Node
for _, p := range n.LPath {
s := p.Score + p.LNode.BestScore
// fmt.Println("path score", p.Score, "bestscore", p.RNode.BestScore)
if s > bestScore {
bestScore = s
bestNode = p.LNode
}
}
n.BestScore = bestScore
n.Prev = bestNode
// fmt.Println(fmt.Sprintf("(%d,%d)<-(%d,%d)", n.X, n.Y, n.Prev.X, n.Prev.Y), n.BestScore)
}
}
bestScore := 0.
var bestNode *Node
for _, p := range eos.LPath {
s := p.Score + p.LNode.BestScore
if s > bestScore {
bestScore = s
bestNode = p.LNode
}
}
eos.BestScore = bestScore
eos.Prev = bestNode
// fmt.Println("best score at eos", eos.BestScore)
}
func Backtrack(nodes [][]*Node, eos *Node) []int {
var result []int = make([]int, len(nodes), len(nodes))
var n *Node
var t int = len(nodes) - 1
n = eos.Prev
// fmt.Println(n.X, n.Y)
result[t] = n.Y
t -= 1
for i := len(nodes) - 1; i > 0; i-- {
n = n.Prev
// fmt.Println(n.X, n.Y)
result[t] = n.Y
t -= 1
}
fmt.Println("Backtrack")
for _, y := range result {
fmt.Println(fmt.Sprintf("%d", y))
}
return result
}
func BackwardAstar(N int, nodes [][]*Node, eos *Node) [][]int {
pqueue := make(PriorityQueue, 0, 10)
heap.Init(&pqueue)
eos.Score = 0.
heap.Push(&pqueue, eos)
var result []*Node = make([]*Node, N, N)
var n int = 0
for {
if pqueue.IsEmpty() {
break
}
var node *Node = pqueue.Pop().(*Node)
// fmt.Println("POP", node.X, node.Y)
if node.X == 0 { // is bos
result[n] = node
n += 1
} else {
for _, p := range node.LPath {
prev := *p.LNode // copy
prev.GoalScore = node.GoalScore + p.Score
prev.Score = prev.BestScore + prev.GoalScore
prev.Next = node
// fmt.Println("PUSH", prev.X, prev.Y, prev.Score)
heap.Push(&pqueue, &prev)
}
}
if n >= N {
break
}
}
var results [][]int = make([][]int, 0, N)
for i, n := range result {
var result []int = make([]int, 0, len(nodes))
for {
n = n.Next
if n.X == eos.X {
break
}
fmt.Println(fmt.Sprintf("%d-best: %d", i+1, n.Y))
result = append(result, n.Y)
}
fmt.Println("")
}
return results
}
package main
import (
// "fmt"
"testing"
)
func TestNbest(t *testing.T) {
/*
n1 --- n3
/ \ / \
bos x eos
\ / \ /
n2 --- n4
*/
nodes := make([][]*Node, 0, 10)
bos := Node{X: 0, Y: 0, BestScore: 0.}
n1 := Node{X: 1, Y: 0}
n2 := Node{X: 1, Y: 1}
p1 := Path{Score: 1.}
p1.Add(&bos, &n1)
p2 := Path{Score: 2.}
p2.Add(&bos, &n2)
n3 := Node{X: 2, Y: 0}
n4 := Node{X: 2, Y: 1}
p3 := Path{Score: 1.}
p3.Add(&n1, &n3)
p4 := Path{Score: 2.}
p4.Add(&n1, &n4)
p5 := Path{Score: 2.}
p5.Add(&n2, &n3)
p6 := Path{Score: 4.}
p6.Add(&n2, &n4)
eos := Node{X: 3, Y: 0}
p7 := Path{Score: 4.}
p7.Add(&n3, &eos)
p8 := Path{Score: 1.}
p8.Add(&n4, &eos)
nodes = append(nodes, []*Node{&n1, &n2})
nodes = append(nodes, []*Node{&n3, &n4})
// fmt.Println(eos.LPath[0].LNode)
// eos.LPath[0].LNode.BestScore = 100.
// fmt.Println("eos", eos.LPath[0].LNode)
// fmt.Println("nodes", nodes[1][1])
// fmt.Println("viterbi")
Viterbi(nodes, &eos)
Backtrack(nodes, &eos)
N := 4
BackwardAstar(N, nodes, &eos)
}

ラティスの構造は以下のとおり。パスの重みは上記コードを参照。

#0       n1 --- n3
        /   \ /   \
       bos   x     eos
        \   / \   /
#1       n2 --- n4

ビタビアルゴリズムを適用後、バックトラックで最適なパスを選んだ場合の解と、 ビタビアルゴリズムを適用後、A*アルゴリズムで後ろ向きに最適なパスを順に選んだ場合の1-best解が一致している。 2-best以降もラティスの情報から、あっていることを確認。

下記の###以降はこのエントリ用に付け足した文字列。

$go test
Backtrack ### ビタビアルゴリズムで求めたパス
1 ### n2を通って、
0 ### n3を通る
1-best: 1 ### n2を通って、
1-best: 0 ### n3を通る

2-best: 1 ### n2を通って、
2-best: 1 ### n4を通る

3-best: 0 ### n1を通って、
3-best: 0 ### n3を通る

4-best: 0 ### n1を通って、
4-best: 1 ### n4を通る

goで優先度付きキューを実装するには、heap - The Go Programming LanguageのExample (PriorityQueue) が参考になる。

少し話はそれるが、機械翻訳において、n-best解は似通ったものが選ばれてしまう問題があるので、多様性を考慮するモデルを提案している話があって、これも気になるのでメモ。

Kevin Gimpel, Dhruv Batra, Chris Dyer, and Gregory Shakhnarovich. “A Systematic Exploration of Diversity in Machine Translation”, EMNLP, 2013. pdf

参考


関連記事






最近の記事