読者です 読者をやめる 読者になる 読者になる

CUBE SUGAR CONTAINER

技術系のこと書きます。

Golang でマージソートを書いてみる

以前、勉強がてら Golang でクイックソートを書いたので、ついでにマージソートも書いてみる。

blog.amedama.jp

サンプルコードは次の通り。 例によってナイーブな実装なので実用性は低い。

package main

import (
    "fmt"
    "time"
    "math/rand"
)


func merge(left, right []int) (ret []int) {
    ret = []int{}
    for len(left) > 0 && len(right) > 0 {
        var x int
        // ソート済みのふたつのスライスからより小さいものを選んで追加していく (これがソート処理)
        if right[0] > left[0] {
            x, left = left[0], left[1:]
        } else {
            x, right = right[0], right[1:]
        }
        ret = append(ret, x)
    }
    // 片方のスライスから追加する要素がなくなったら残りは単純に連結できる (各スライスは既にソートされているため)
    ret = append(ret, left...)
    ret = append(ret, right...)
    return
}

func sort(left, right []int) (ret []int) {
    // ふたつのスライスをそれぞれ再帰的にソートする
    if len(left) > 1 {
        l, r := split(left)
        left = sort(l, r)
    }
    if len(right) > 1 {
        l, r := split(right)
        right = sort(l, r)
    }
    
    // ソート済みのふたつのスライスをひとつにマージする
    ret = merge(left, right)
    return
}

func split(values []int) (left, right []int) {
    // スライスを真ん中でふたつに分割する
    left = values[:len(values) / 2]
    right = values[len(values) / 2:]
    return
}

func Sort(values []int) (ret []int) {
    left, right := split(values)
    ret = sort(left, right)
    return
}

func main() {
    // UNIX 時間をシードにして乱数生成器を用意する
    t := time.Now().Unix()
    s := rand.NewSource(t)
    r := rand.New(s)

    // ランダムな値の入った配列を作る
    N := 10
    values := make([]int, N)
    for i := 0; i < N; i++ {
        values[i] = r.Intn(N)
    }
    
    // ソートして結果を出力する
    sortedValues := Sort(values)
    fmt.Println(sortedValues)
}

実行結果は次の通り。

$ go run mergesort.go
[0 1 2 3 5 6 6 8 8 9]

ばっちり。

おまけ

Python 版

#!/usr/bin/env python
# -*- coding: utf-8 -*-


def _merge(left, right):
    ret = []
    while len(left) > 0 and len(right) > 0:
        if left[0] > right[0]:
            ret.append(right.pop(0))
        else:
            ret.append(left.pop(0))

    ret += left + right
    return ret


def _sort(left, right):
    if len(left) > 1:
        l, r = _split(left)
        left = _sort(l, r)
    if len(right) > 1:
        l, r = _split(right)
        right = _sort(l, r)

    ret = _merge(left, right)
    return ret


def _split(values):
    left = values[:len(values) // 2]
    right = values[len(values) // 2:]
    return left, right


def sort(values):
    left, right = _split(values)
    ret = _sort(left, right)
    return ret


def main():
    import random
    N = 10
    values = [random.randint(0, N) for _ in range(N)]
    sorted_list = sort(values)
    print(sorted_list)


if __name__ == '__main__':
    main()