N-best解の探索
系列ラベリングなどで最適なパスを探索する方法はビタビアルゴリズムで効率的に求められる。 上位N個のパスを探索する方法はビタビアルゴリズムと、A*アルゴリズムで効率的に求められる。 日本語入力を支える技術 ~変わり続けるコンピュータと言葉の世界 (WEB+DB PRESS plus) の説明が分かりやすい。理解するために実装してみた。
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"container/heap" | |
"fmt" | |
) | |
type PriorityQueue []*Node | |
func (this PriorityQueue) Len() int { return len(this) } | |
func (this PriorityQueue) Less(i, j int) bool { | |
return this[i].Score < this[j].Score | |
} | |
func (this PriorityQueue) Swap(i, j int) { | |
this[i], this[j] = this[j], this[i] | |
this[i].index = i | |
this[j].index = j | |
} | |
func (this *PriorityQueue) Push(x interface{}) { | |
n := len(*this) | |
node := x.(*Node) | |
node.index = n | |
*this = append(*this, node) | |
} | |
func (this *PriorityQueue) Pop() interface{} { | |
old := *this | |
n := len(*this) | |
node := old[n-1] | |
node.index = -1 | |
*this = old[0 : n-1] | |
return node | |
} | |
func (this *PriorityQueue) IsEmpty() bool { | |
if len(*this) == 0 { | |
return true | |
} | |
return false | |
} | |
func (this *PriorityQueue) update(node *Node) { | |
heap.Fix(this, node.index) | |
} | |
type Node struct { | |
X int | |
Y int | |
// for priority queue | |
Score float64 | |
index int | |
LPath []*Path | |
RPath []*Path | |
BestScore float64 | |
Prev *Node | |
GoalScore float64 | |
Next *Node | |
} | |
type Path struct { | |
LNode *Node | |
RNode *Node | |
Score float64 | |
} | |
func (this *Path) Add(lnode, rnode *Node) { | |
this.LNode = lnode | |
this.RNode = rnode | |
rnode.LPath = append(rnode.LPath, this) | |
lnode.RPath = append(lnode.RPath, this) | |
} | |
func Viterbi(nodes [][]*Node, eos *Node) { | |
segment_size := len(nodes) | |
for i := 0; i < segment_size; i++ { | |
for j := 0; j < len((nodes)[i]); j++ { | |
n := nodes[i][j] | |
bestScore := 0. | |
var bestNode *Node | |
for _, p := range n.LPath { | |
s := p.Score + p.LNode.BestScore | |
// fmt.Println("path score", p.Score, "bestscore", p.RNode.BestScore) | |
if s > bestScore { | |
bestScore = s | |
bestNode = p.LNode | |
} | |
} | |
n.BestScore = bestScore | |
n.Prev = bestNode | |
// fmt.Println(fmt.Sprintf("(%d,%d)<-(%d,%d)", n.X, n.Y, n.Prev.X, n.Prev.Y), n.BestScore) | |
} | |
} | |
bestScore := 0. | |
var bestNode *Node | |
for _, p := range eos.LPath { | |
s := p.Score + p.LNode.BestScore | |
if s > bestScore { | |
bestScore = s | |
bestNode = p.LNode | |
} | |
} | |
eos.BestScore = bestScore | |
eos.Prev = bestNode | |
// fmt.Println("best score at eos", eos.BestScore) | |
} | |
func Backtrack(nodes [][]*Node, eos *Node) []int { | |
var result []int = make([]int, len(nodes), len(nodes)) | |
var n *Node | |
var t int = len(nodes) - 1 | |
n = eos.Prev | |
// fmt.Println(n.X, n.Y) | |
result[t] = n.Y | |
t -= 1 | |
for i := len(nodes) - 1; i > 0; i-- { | |
n = n.Prev | |
// fmt.Println(n.X, n.Y) | |
result[t] = n.Y | |
t -= 1 | |
} | |
fmt.Println("Backtrack") | |
for _, y := range result { | |
fmt.Println(fmt.Sprintf("%d", y)) | |
} | |
return result | |
} | |
func BackwardAstar(N int, nodes [][]*Node, eos *Node) [][]int { | |
pqueue := make(PriorityQueue, 0, 10) | |
heap.Init(&pqueue) | |
eos.Score = 0. | |
heap.Push(&pqueue, eos) | |
var result []*Node = make([]*Node, N, N) | |
var n int = 0 | |
for { | |
if pqueue.IsEmpty() { | |
break | |
} | |
var node *Node = pqueue.Pop().(*Node) | |
// fmt.Println("POP", node.X, node.Y) | |
if node.X == 0 { // is bos | |
result[n] = node | |
n += 1 | |
} else { | |
for _, p := range node.LPath { | |
prev := *p.LNode // copy | |
prev.GoalScore = node.GoalScore + p.Score | |
prev.Score = prev.BestScore + prev.GoalScore | |
prev.Next = node | |
// fmt.Println("PUSH", prev.X, prev.Y, prev.Score) | |
heap.Push(&pqueue, &prev) | |
} | |
} | |
if n >= N { | |
break | |
} | |
} | |
var results [][]int = make([][]int, 0, N) | |
for i, n := range result { | |
var result []int = make([]int, 0, len(nodes)) | |
for { | |
n = n.Next | |
if n.X == eos.X { | |
break | |
} | |
fmt.Println(fmt.Sprintf("%d-best: %d", i+1, n.Y)) | |
result = append(result, n.Y) | |
} | |
fmt.Println("") | |
} | |
return results | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
// "fmt" | |
"testing" | |
) | |
func TestNbest(t *testing.T) { | |
/* | |
n1 --- n3 | |
/ \ / \ | |
bos x eos | |
\ / \ / | |
n2 --- n4 | |
*/ | |
nodes := make([][]*Node, 0, 10) | |
bos := Node{X: 0, Y: 0, BestScore: 0.} | |
n1 := Node{X: 1, Y: 0} | |
n2 := Node{X: 1, Y: 1} | |
p1 := Path{Score: 1.} | |
p1.Add(&bos, &n1) | |
p2 := Path{Score: 2.} | |
p2.Add(&bos, &n2) | |
n3 := Node{X: 2, Y: 0} | |
n4 := Node{X: 2, Y: 1} | |
p3 := Path{Score: 1.} | |
p3.Add(&n1, &n3) | |
p4 := Path{Score: 2.} | |
p4.Add(&n1, &n4) | |
p5 := Path{Score: 2.} | |
p5.Add(&n2, &n3) | |
p6 := Path{Score: 4.} | |
p6.Add(&n2, &n4) | |
eos := Node{X: 3, Y: 0} | |
p7 := Path{Score: 4.} | |
p7.Add(&n3, &eos) | |
p8 := Path{Score: 1.} | |
p8.Add(&n4, &eos) | |
nodes = append(nodes, []*Node{&n1, &n2}) | |
nodes = append(nodes, []*Node{&n3, &n4}) | |
// fmt.Println(eos.LPath[0].LNode) | |
// eos.LPath[0].LNode.BestScore = 100. | |
// fmt.Println("eos", eos.LPath[0].LNode) | |
// fmt.Println("nodes", nodes[1][1]) | |
// fmt.Println("viterbi") | |
Viterbi(nodes, &eos) | |
Backtrack(nodes, &eos) | |
N := 4 | |
BackwardAstar(N, nodes, &eos) | |
} |
ラティスの構造は以下のとおり。パスの重みは上記コードを参照。
0 n1 --- n3
/ \ / \
bos x eos
\ / \ /
1 n2 --- n4
ビタビアルゴリズムを適用後、バックトラックで最適なパスを選んだ場合の解と、 ビタビアルゴリズムを適用後、A*アルゴリズムで後ろ向きに最適なパスを順に選んだ場合の1-best解が一致している。 2-best以降もラティスの情報から、あっていることを確認。
下記の###以降はこのエントリ用に付け足した文字列。
$go test
Backtrack ### ビタビアルゴリズムで求めたパス
1 ### n2を通って、
0 ### n3を通る
1-best: 1 ### n2を通って、
1-best: 0 ### n3を通る
2-best: 1 ### n2を通って、
2-best: 1 ### n4を通る
3-best: 0 ### n1を通って、
3-best: 0 ### n3を通る
4-best: 0 ### n1を通って、
4-best: 1 ### n4を通る
goで優先度付きキューを実装するには、heap - The Go Programming LanguageのExample (PriorityQueue) が参考になる。
少し話はそれるが、機械翻訳において、n-best解は似通ったものが選ばれてしまう問題があるので、多様性を考慮するモデルを提案している話があって、これも気になるのでメモ。
Kevin Gimpel, Dhruv Batra, Chris Dyer, and Gregory Shakhnarovich. “A Systematic Exploration of Diversity in Machine Translation”, EMNLP, 2013. pdf