出現確率が一様な抽選は、インデックスを一様乱数で求めればいいので実装が簡単です。
public static T sample<T>(T[] values)
{
return values[UnityEngine.Random.Range(0, values.Length)];
}
public static T sample<T>(List<T> values)
{
return values[UnityEngine.Random.Range(0, values.Count)];
}
ちなみに、この抽選は引いたくじを除外せずに次回も有効となる独立試行です。このような抽選を復元抽出と呼ぶそうです。
各要素の出現確率(重み)を設定して復元抽出したいとき、上の手法ではうまくいかないので、ひと工夫する必要があります。例えば、レア度によって出現確率の異なるガチャを実装したいといったケースですね。
以下、Unity向けに実装しますが、乱数取得部分を変えればUnity以外でも使えます。
線形探索による手法
重みの合計値を算出して、その範囲内で乱数値を生成し、重みリストを順に見ていって対応する要素を引っ張ってきます。計算量はO(N)ですが、追加の記憶領域を必要としません。
public static int sampleByWeight(List<float> weights)
{
var total = calcTotalWeight(weights);
var random = UnityEngine.Random.value * total;
var index = resolveIndex(weights, random);
return index;
}
private static float calcTotalWeight(List<float> weights)
{
var length = weights.Count;
var total = 0.0f;
for(var i = 0; i < length; ++i) {
var weight = weights[i];
if(weight > 0.0f) {
total += weight;
}
}
return total;
}
public static int resolveIndex(List<float> weights, float random)
{
var length = weights.Count;
var index = -1;
var cumulative = 0.0f;
for(var i = 0; i < length; ++i) {
var weight = weights[i];
if(weight > 0.0f) {
cumulative += weight;
if(cumulative >= random) {
index = i;
break;
}
}
}
return index;
}
とりあえずList用の実装のみ示しましたが、同様に配列でも対応できます。
この手のものは境界条件に注意して実装しましょう。乱数で境界値が返るようにするとか小細工してテストします。
重みリストを降順に並べて、抽選時のループ回数が確率的に少なくなるようにする最適化もあります。準備処理として、少なくともソートにO(N log N)かかることにはなります。そして、インデックスとの対応付けを管理するデータを持たないといけないので、上の実装ではそのまま使えないですね。
二分探索による手法
重みの合計値を算出するのは必須なので、重みリストを使い捨てる前提なら、O(N)は必要経費ではあります。が、その後の抽選はどうにかならないだろうかと思いました。
重みの累積分布関数(つまり単調増加)を作成し、これを利用して二分探索する手法があります。
累積分布関数の作成にO(N)かかるのは仕方ないとして、その後の抽選は二分探索なのでO(log n)です。また、線形探索による単純な手法と異なり、累積分布関数作成のために追加の記憶領域を必要とします。メモリ使用量はO(N)です。
累積分布関数を使い回すなら、あとから重みを追加しても、追加分だけ累積分布関数に加えていけばいいので更新処理が安くて良いですね。元のサイズがNで追加するサイズがMなら、O(N)分は計算済みなので、O(M)で更新が完了する、ということになります。
但し、追加済みの重みを変更する場合、それ以降の累積分布関数をすべて直さないといけないので注意です。今回は追加済みの重みは変更できないものとしました。
重みリストと重み合計を管理する抽象クラスRandomSamplerBaseを継承して、BinarySamplerとして実装しました。
public abstract class RandomSamplerBase
{
public int count { get { return _weights.Count; } }
public float totalWeight { get { return _totalWeight; } }
public float getWeight(int index)
{
return _weights[index];
}
public virtual void addWeight(float weight, bool skipRefresh = false)
{
addWeightInner(weight);
if(!skipRefresh) {
refresh();
}
}
public virtual void addWeights(IEnumerable<float> weights, bool skipRefresh = false)
{
var collection = weights as ICollection<float>;
if(collection != null) {
// 容量更新
if(_weights.Capacity < _weights.Count + collection.Count) {
_weights.Capacity = _weights.Count + collection.Count;
}
}
foreach(var weight in weights) {
addWeightInner(weight);
}
if(!skipRefresh) {
refresh();
}
}
public virtual void clear()
{
_weights.Clear();
_totalWeight = 0.0f;
}
public virtual void refresh() { }
public abstract int sample();
protected virtual float nextRandom()
{
return Random.value;
}
private float addWeightInner(float weight)
{
if(weight > 0.0f) {
_weights.Add(weight);
_totalWeight += weight;
return weight;
}
_weights.Add(0.0f);
return 0.0f;
}
private List<float> _weights = new List<float>();
private float _totalWeight = 0.0f;
}
public class BinarySampler : RandomSamplerBase
{
public BinarySampler() { }
public BinarySampler(IEnumerable<float> weights)
{
addWeights(weights);
}
public override void clear()
{
base.clear();
_cdf.Clear();
}
// O(N)で累積分布関数(CDF)を構築
public override void refresh()
{
base.refresh();
var cumulative = 0.0f;
if(_cdf.Count > 0) {
cumulative = _cdf[_cdf.Count - 1];
}
var length = count;
for(var i = _cdf.Count; i < length; ++i) {
var weight = getWeight(i);
cumulative += weight;
_cdf.Add(cumulative);
}
}
// O(log n)で復元抽出
public override int sample()
{
var count = this.count;
if(count == 0) return -1;
var random = nextRandom() * totalWeight;
var index = -1;
// 二分探索
var begin = 0;
var end = count;
while(begin < end) {
var i = (end + begin) >> 1;
var cumulative = _cdf[i];
if(cumulative < random) {
begin = i + 1;
}
else {
index = i;
end = i;
}
}
return index;
}
private List<float> _cdf = new List<float>();
}
地味に重み付き二分探索のきれいな実装に悩みました。
Alias method
高速な復元抽出アルゴリズムがないか調べたところ、Walkerさんの考えたAlias methodという手法を知りました。
これは乱数からインデックスを決定するのに都合のいいデータを用意することで、復元抽出をO(1)という凄まじい処理速度で実現します。
但し、都合のいいデータを準備するにあたってO(N)処理が必要になるため、重みリストをそのまま使い回して何度も復元抽出するという場面で威力を発揮します。
一度しか抽選しないならば、先程の単純な線形探索による手法が良いでしょう。準備処理も込みでどちらもO(N)ですが、Alias methodではループ3回必要なのに対し、線形探索による手法であればループ2回で済みます。大した差ではないかもしれませんが、さらに線形探索による手法は追加の記憶領域を必要としないのが良いですね。うまく使い分けたいところです。
また、重みリストに変更が発生した場合は、準備処理のやり直しが必要です。元のサイズがNで追加するサイズがMなら、更新処理はO(N + M)です。更新が必要なら線形探索か二分探索による手法のほうが向いているでしょう。
ちなみに、メモリ使用量はO(N)です。
アルゴリズムの解説は上記の記事がわかりやすいと思います。ちょうどC#の実装例もありますが、無駄があったり、なんかちょっと怪しい気がします。
考え方を大まかに説明すると下記のようになります。
-
準備処理
- 重みリストを要素数Nで正規化(合計がNになるよう調整)する。
- Nで正規化された重みリストを、N個のブロック(つまり幅1)に分割し、各ブロックで1~2種類の重みが対応するよう重みを分配する。
-
抽選処理
[0, N)
範囲の乱数値を取得し、値に対応するブロックを取得する。- 乱数値の小数部分がブロック内における重みに相当するので、そのブロックに設定された重みと比較して、1~2種類設定されたインデックスから正しいものを選択する。
各インデックスに必ず1個はブロックが割り当てられて、ブロック内で見て重みが余っているようなら、重みがブロックから溢れているインデックスとスペースを共有します。1つのインデックスに対応するブロックについて、重み次第では別のインデックス(エイリアス)に差し替える、というのがこの手法の肝ですね。
言葉だとよくわからなければ、今の説明を踏まえた上で、先程の参考記事を見ていただければ図説で理解しやすいと思います。
- R言語のsample関数実装(※walker_ProbSampleReplace関数が該当)
- ランダム抽出アルゴリズムについて考える - Shogo’s Blog
さて、実装の参考としてはこのあたりが良さそうです。
R言語の実装を元に、再実装してみました。そこそこ読みやすくしたつもりですが、どうでしょう。
public class AliasSampler : RandomSamplerBase
{
public AliasSampler() { }
public AliasSampler(IEnumerable<float> weights)
{
addWeights(weights);
}
public override void clear()
{
base.clear();
_normalized = null;
_aliases = null;
}
// O(N)で内部リスト構築
public override void refresh()
{
base.refresh();
var length = count;
_normalized = new float[length];
_aliases = new int[length];
var indexes = new int[length];
// 重みの合計を算出
var _totalWeight = 0.0f;
for(var i = 0; i < length; ++i) {
var weight = getWeight(i);
if(weight > 0.0f) {
_totalWeight += weight;
}
}
var normalizeRatio = length / _totalWeight;
var left = -1;
var right = length;
for(var i = 0; i < length; ++i) {
// エイリアス初期化
_aliases[i] = i;
// 重みを要素数で正規化
var weight = getWeight(i);
if(weight > 0.0f) {
weight *= normalizeRatio;
}
else {
weight = 0.0f;
}
_normalized[i] = weight;
// 重みが1ブロックに収まるかどうかで前方・後方に振り分け
if(weight < 1.0f) {
indexes[++left] = i;
}
else {
indexes[--right] = i;
}
}
// 少なくとも1つは重みが1でなければエイリアス設定
if(left >= 0 && right < length) {
left = 0;
while(left < length && right < length) {
var leftIndex = indexes[left];
var rightIndex = indexes[right];
// エイリアス設定
_aliases[leftIndex] = rightIndex;
// rightIndexに紐づく重みからleftIndex分を減算
var leftWeight = _normalized[leftIndex];
var rightWeight = _normalized[rightIndex] + leftWeight - 1.0f;
_normalized[rightIndex] = rightWeight;
// 重みが1未満になったら後方リストの先頭を後ろにずらす
if(rightWeight < 1.0f) {
++right;
}
++left;
}
}
}
// O(1)で復元抽出
public override int sample()
{
var count = this.count;
if(count == 0) return -1;
var random = nextRandom() * count;
var index = (int)random;
var weight = random - index;
// 末尾に丸める
if(index >= count) {
index = count - 1;
weight = 1.0f;
}
// 重みを超えたらエイリアスに差し替える
if(_normalized[index] <= weight) {
index = _aliases[index];
}
return index;
}
private float[] _normalized = null;
private int[] _aliases = null;
}
余りスペースの分配方法がとてもスマートで、美しいアルゴリズムだなと思います。
非復元抽出に向けて
復元抽出は引いたくじが次回以降も有効な抽選でした。引いたくじを次から出さないようにする非復元抽出については、これだという方法は特に決まってなさそうです。
Alias methodで参考にした上記ブログ記事がかなり優秀で、今回取り扱った手法を網羅しつつ、それらを非復元抽出に対応させる方法もまとめています。考え方だけこちらにもまとめます。
復元抽出の手法を、非復元抽出に対応させるには、一度引いた要素の重みをゼロにして重み合計を更新するのが手っ取り早いです。
ただ、ほとんどの手法が、それに伴って抽選のための内部状態を更新する必要があり、悩ましいです。
- 線形探索による手法
- 他に管理する内部状態がないので問題なし
- 二分探索による手法
- 累積分布関数の再生成にO(N)かかる
- 累積分布関数ではなく二分木の構築にすれば、O(log N)で済む
- Alias method
- O(N)かけてエイリアスを更新するしかない
今のところ、私は非復元抽出を使いそうにないので、理解だけしていればいいかなぁ、と思いました。ので、今回はここまで。