使用ML.NET实现德州扑克牌型分类器
2018-05-24

本文将基于ML.NET v0.2预览版,重点介绍提取特征的思路和方法,实现德州扑克牌型分类器。

先介绍一下德州扑克的基本牌型,一手完整的牌共有五张扑克,10种牌型分别是:

1. 高牌,花色和点数同时没有相同的牌。

2. 一对,点数有且仅有两张相同的牌。

3. 两对,两张相同点数的牌,加另外两张相同点数的牌。

4. 三条,有三张同一点数的牌。

5. 顺子,五张顺连的牌。

6. 同花,五张同一花色的牌。

7. 葫芦,三张同一点数的牌,加一对其他点数的牌。

8. 四条,有四张同一点数的牌。

9. 同花顺,同一花色五张顺连的牌。

10. 皇家同花顺,最高点数是A的同花顺的牌。

这一次我们将使用逻辑回归模型,来训练数据完成我们想要的分类模型。

准备数据集


数据来源在Poker Hand Data Set,下载链接为:poker-hand-testing.datapoker-hand-training-true.data。内容类似如下:

3,92,3,3,2,2,9,3,5,1

4,4,1,11,2,9,4,13,2,7,0

1,5,1,9,2,8,2,4,4,3,0

4,12,4,7,4,5,2,10,2,2,0

4,3,2,4,4,13,3,6,4,12,0

2,5,4,5,4,1,4,9,2,7,1

2,12,3,12,3,7,3,11,2,7,2

4,13,2,6,4,6,4,10,4,9,1

...

说明一下每一行的格式:

第1张花色,第1张点数,第2张花色,第2张点数,第3张花色,第3张点数,第4张花色,第4张点数,第5张花色,第5张点数,牌型

花色是1-4代表红心,黑桃,方块,梅花。点数1表示A,2-10保持不变,11表示J,12表示Q,13表示K。

特征分析


前几篇数据集的内容,基本上分割好就是特征了,这一次不同,每一行的数值仅仅是元数据,也就是说,通过五张牌的花色和点数值是不能直接和牌型形成一一对应的联系,需要先按本文开头介绍的10种牌型的描述,找到关键可数值化的字段。因此,我选择了这样一些字段:对子数,是否是三条,是否是四条,是否是顺子,是否同花。通过这5个字段值的组合,一定能判断出牌型。

于是,我定义出我想要的数据类型:

public class PokerHandData

{

    [Column(ordinal: "0")]

    public float S1;

    [Column(ordinal: "1")]

    public float C1;

    [Column(ordinal: "2")]

    public float S2;

    [Column(ordinal: "3")]

    public float C2;

    [Column(ordinal: "4")]

    public float S3;

    [Column(ordinal: "5")]

    public float C3;

    [Column(ordinal: "6")]

    public float S4;

    [Column(ordinal: "7")]

    public float C4;

    [Column(ordinal: "8")]

    public float S5;

    [Column(ordinal: "9")]

    public float C5;

    [Column(ordinal: "10", name: "Label")]

    public float Power;


    [Column(ordinal: "11")]

    public float IsSameSuit;


    [Column(ordinal: "12")]

    public float IsStraight;


    [Column(ordinal: "13")]

    public float FourOfKind;


    [Column(ordinal: "14")]

    public float ThreeOfKind;


    [Column(ordinal: "15")]

    public float PairsCount;

}

S表示花色,C表示点数,Power表示牌型,PairsCount表示对子数,ThreeOfKind表示是否是三条,FourOfKind表示是否是四条,IsStraight表示是否顺子,IsSameSuit表示是否同花。

判断是否同花,只需要比较S1-S5的值即可。

public float GetIsSameSuit()

{

    if (S1 == S2 && S2 == S3 && S3 == S4 && S4 == S5)

        return 1;

    else

        return 0;

}

判断其它几个特征,我需要一个通用方法,先统计出每一行的C1-C5,每种点数出现的次数。

static Dictionary<int, int> GetValueCountsOfCondition(IEnumerable<int> cards)

{

    var dic = new Dictionary<int, int>();


    foreach (var item in cards)

    {

        if (dic.ContainsKey(item))

        {

            dic[item] += 1;

        }

        else

        {

            dic.Add(item, 1);

        }

    }

    return dic;

}

然后再按特征涵义计算值。

public float GetFourOfKind()

{

    return GetCountOfCondition(4);

}


public float GetThreeOfKind()

{

    return GetCountOfCondition(3);

}


public float GetPairsCount()

{

    return GetCountOfCondition(2);

}


private IEnumerable<int> GetCards()

{

    if (cards == null)

    {

        cards = new[] { Convert.ToInt32(C1), Convert.ToInt32(C2), Convert.ToInt32(C3), Convert.ToInt32(C4), Convert.ToInt32(C5) };

    }


    return cards;

}


private float GetCountOfCondition(int target)

{

    if (valueCounts == null)

    {

        valueCounts = GetValueCountsOfCondition(GetCards());

    }


    return valueCounts.Count(i => i.Value == target);

}

判断是否为顺子的方法,简单而直接,就是看间隔差是不是为1,或者最高点有A剩下的必须是10、J、Q、K,都算顺子。

public float GetIsStraight()

{

    var keys = GetCards().ToArray();

    Array.Sort(keys);

    if (keys[1] - keys[0] == keys[2] - keys[1] && keys[2] - keys[1] == keys[3] - keys[2] && keys[3] - keys[2] == keys[4] - keys[3] && keys[4] - keys[3] == 1)

    {

        return 1;

    }

    else if (keys[0] == 1 && keys[1] == 10 && keys[2] == 11 && keys[3] == 12 && keys[4] == 13)

    {

        return 1;

    }

    else

    {

        return 0;

    }

}


加载数据


这次由于使用了ML.NET v0.2,该版本的LearningPipeline新增了一种支持集合类型的数据源。因此,我将示范一种全新的载入数据集的方法,先以文件载入元数据,然后直接初始化特征的值。

static IEnumerable<PokerHandData> LoadData(string path)

{

    using (var environment = new TlcEnvironment())

    {

        var pokerHandData = new List<PokerHandData>();

        var textLoader = new Microsoft.ML.Data.TextLoader(path).CreateFrom<PokerHandData>(useHeader: false, separator: ',', trimWhitespace: false);

        var experiment = environment.CreateExperiment();

        var output = textLoader.ApplyStep(null, experiment) as ILearningPipelineDataStep;


        experiment.Compile();

        textLoader.SetInput(environment, experiment);

        experiment.Run();


        var data = experiment.GetOutput(output.Data);


        using (var cursor = data.GetRowCursor((a => true)))

        {

            var getters = new ValueGetter<float>[]{

                cursor.GetGetter<float>(0),

                cursor.GetGetter<float>(1),

                cursor.GetGetter<float>(2),

                cursor.GetGetter<float>(3),

                cursor.GetGetter<float>(4),

                cursor.GetGetter<float>(5),

                cursor.GetGetter<float>(6),

                cursor.GetGetter<float>(7),

                cursor.GetGetter<float>(8),

                cursor.GetGetter<float>(9),

                cursor.GetGetter<float>(10)

            };


            while (cursor.MoveNext())

            {

                float value0 = 0;

                float value1 = 0;

                float value2 = 0;

                float value3 = 0;

                float value4 = 0;

                float value5 = 0;

                float value6 = 0;

                float value7 = 0;

                float value8 = 0;

                float value9 = 0;

                float value10 = 0;

                getters[0](ref value0);

                getters[1](ref value1);

                getters[2](ref value2);

                getters[3](ref value3);

                getters[4](ref value4);

                getters[5](ref value5);

                getters[6](ref value6);

                getters[7](ref value7);

                getters[8](ref value8);

                getters[9](ref value9);

                getters[10](ref value10);


                var hands = new PokerHandData()

                {

                    S1 = value0,

                    C1 = value1,

                    S2 = value2,

                    C2 = value3,

                    S3 = value4,

                    C3 = value5,

                    S4 = value6,

                    C4 = value7,

                    S5 = value8,

                    C5 = value9,

                    Power = value10

                };

                hands.Init();

                pokerHandData.Add(hands);

            }

        }


        return pokerHandData;

    }

}

其中PokerHandData类增加一个初始化的方法。

public void Init()

{

    IsSameSuit = GetIsSameSuit();

    IsStraight = GetIsStraight();

    FourOfKind = GetFourOfKind();

    ThreeOfKind = GetThreeOfKind();

    PairsCount = GetPairsCount();

}

训练模型


预测的结构定义,以计分为目标,float[]类型表示是对每一种牌型有一个得分,分值越高属于那一种牌型的概率越大。

public class PokerHandPrediction
{
    [ColumnName("Score")]       public float[] PredictedPower;
}

模型的选择是LogisticRegressionClassifier,CollectionDataSource就是用来创建集合类型数据载入的对象。而特征的指定不再是全部字段,而是之前增加的那几个。

public static PredictionModel<PokerHandData, PokerHandPrediction> Train(IEnumerable<PokerHandData> data)

{

    var pipeline = new LearningPipeline();

    var collection = CollectionDataSource.Create(data);

    pipeline.Add(collection);

    pipeline.Add(new ColumnConcatenator("Features", "IsSameSuit", "IsStraight", "FourOfKind", "ThreeOfKind", "PairsCount"));

    pipeline.Add(new LogisticRegressionClassifier());

    var model = pipeline.Train<PokerHandData, PokerHandPrediction>();

    return model;

}


预测结果


首先,对预测的得分,我们需要判断一个概率倾向。

static string GetPower(float[] nums)

{

    var index = -1;

    var last = 0F;

    for (int i = 0; i < nums.Length; i++)

    {

        if (nums[i] > last)

        {

            index = i;

            last = nums[i];

        }

    }


    var suit = string.Empty;

    switch (index)

    {

        case 0:

            suit = "高牌";

            break;

        case 1:

            suit = "一对";

            break;

        case 2:

            suit = "两对";

            break;

        case 3:

            suit = "三条";

            break;

        case 4:

            suit = "顺子";

            break;

        case 5:

            suit = "同花";

            break;

        case 6:

            suit = "葫芦";

            break;

        case 7:

            suit = "四条";

            break;

        case 8:

            suit = "同花顺";

            break;

        case 9:

            suit = "皇家同花顺";

            break;


    }

    return suit;

}

最后就是进行预测的部分了。

public static void Predict(PredictionModel<PokerHandData, PokerHandPrediction> model)

{

    var test1 = new PokerHandData

    {

        S1 = 1,

        C1 = 2,

        S2 = 1,

        C2 = 3,

        S3 = 3,

        C3 = 4,

        S4 = 4,

        C4 = 5,

        S5 = 2,

        C5 = 6

    };


    var test2 = new PokerHandData

    {

        S1 = 4,

        C1 = 5,

        S2 = 1,

        C2 = 5,

        S3 = 3,

        C3 = 5,

        S4 = 2,

        C4 = 12,

        S5 = 4,

        C5 = 7

    };

    test1.Init();

    test2.Init();

    IEnumerable<PokerHandData> pokerHands = new[]

    {

        test1,

        test2

    };

    IEnumerable<PokerHandPrediction> predictions = model.Predict(pokerHands);

    Console.WriteLine();

    Console.WriteLine("PokerHand Predictions");

    Console.WriteLine("---------------------");


    var pokerHandsAndPredictions = pokerHands.Zip(predictions, (pokerHand, prediction) => (pokerHand, prediction));

    foreach (var (pokerHand, prediction) in pokerHandsAndPredictions)

    {

        Console.WriteLine($"PokerHand: {ShowHand(pokerHand)} | Prediction: { GetPower(prediction.PredictedPower)}");

    }

    Console.WriteLine();


}

创建项目的步骤请参看本文开头导读给出的文章链接,不再赘述,运行结果如下:

最后放出源代码文件:https://files.cnblogs.com/files/BeanHsiang/pokerhand.zip

希望读者们保持对ML.NET的持续关注,相信新的特性一定能实现更复杂有趣的场景。

相关文章:

原文地址: https://www.cnblogs.com/BeanHsiang/p/9080358.html