網頁

2019/4/25

Java 依權重隨機取得對應的值

Java 依有權重的隨機(weighted random)取得對應的值。

// 設定字母出現的權重
Map<String, Integer> map = new HashMap<>();
map.put("C", 120);
map.put("A", 10);
map.put("B", 80);
map.put("E", 200);
map.put("D", 50);

System.out.println(map);

int totalWeight = 0; // 總權重
// 用有排序的Map將原本的Map裝入,但key為權重累加值,value為字母
NavigableMap<Integer, String> sortedMap = new TreeMap<>();
for (Entry<String, Integer> entry : map.entrySet()) {
    int weight = entry.getValue();
    String letter = entry.getKey();
    totalWeight += weight;
    
    sortedMap.put(totalWeight, letter);
}

// 各權重累加對應的字母
System.out.println(sortedMap); // {10=A, 90=B, 210=C, 260=D, 460=E}

Map<String, Integer> counter = new HashMap<>();
int runtimes = 100000000; // 隨機抽取次數
Random random = new Random()
for (int i = 0; i < runtimes ; i ++) {
    
    Double randomDouble = random.nextDouble() * totalWeight; // 從 0 ~ 460間取得一個隨機數
    String letter = sortedMap.ceilingEntry(randomDouble.intValue()).getValue(); // 依權重取得對應的字母
    
    // 累計隨機取得的字母的出現次數
    if (counter.get(letter) != null) {
        counter.put(letter, counter.get(letter) + 1);
    } else {
        counter.put(letter, 1);
    }
}

// 印出每個字母出現的權重
for (Entry<String, Integer> entry : counter.entrySet()) {
    String letter = entry.getKey();
    double showRate = (double) entry.getValue()/runtimes;
    int showWeight = (int) (showRate * totalWeight);
    System.out.println(letter + ":" + showWeight);
}

這做法是將每個值所對應的權重分配到一長條上,長條的數為1 ~ 460,然後從1 ~ 460中隨機取一個數,看落在哪個區間便取出該區間對應的值。

例如E的權重為200,占的區間最大,隨機取出的數也最容易落到E的區間,因此充分反映依權重能被取得的機會。

  10       90          210   260                  460 
+-+--------+------------+-----+--------------------+
|A|    B   |      C     |  D  |          E         |
+-+--------+------------+-----+--------------------+

整理一下

public static void main(String[] args) {
    Map<String, Integer> map = new HashMap<>();
    map.put("C", 120);
    map.put("A", 10);
    map.put("B", 80);
    map.put("E", 200);
    map.put("D", 50);

    int totalWeight = 0; // 總權重
    // 用有排序的Map將原本的Map裝入,但key為權重累加值,value為字母
    NavigableMap<Integer, String> sortedMap = new TreeMap<>();
    for (Entry<String, Integer> entry : map.entrySet()) {
        int weight = entry.getValue();
        String letter = entry.getKey();
        totalWeight += weight;
    
        sortedMap.put(totalWeight, letter);
    }

    List<String> randomList = genRandomListByWeight(10, sortedMap); 
}

/**
 * 依權重表取出指定數量的List
 * @param size List大小  
 * @param weightTable 權重表
 & @return
 */
public static final <T> List<T> genRandomListByWeight(int size, NavigableMap<Integer, T> weightTable) {
    List<T> randomList = new ArrayList<>();
    for (int i = 0; i < size; i++) {
        randomList.add(getRamdomElementByWeight(weightTable));
    }
    return randomList;
}

private static Random random = new Random();

/**
 * 依權重表取出元素
 * @param weightTable 權重表
 * @return
 */
public static final <V> V getRamdomElementByWeight(NavigableMap<Integer, V> weightTable) {
    Double randomDouble = random.nextDouble() * weightTable.lastKey();
    return weightTable.ceilingEntry(randomDouble.intValue()).getValue();
}


沒有留言:

張貼留言