Introducing My First AI/ML-Based Indicator (No External Python Libraries)

Created at 10 Dec 2024, 19:05
How’s your experience with the cTrader Platform?
Your feedback is crucial to cTrader's development. Please take a few seconds to share your opinion and help us improve your trading experience. Thanks!
ED

EDG777

Joined 05.11.2024

Introducing My First AI/ML-Based Indicator (No External Python Libraries)
10 Dec 2024, 19:05


Greetings!
I am pleased to introduce my very first AI-driven indicator. It applies machine learning directly in C# through Microsoft's ML.NET library (Microsoft.ML), so no external Python libraries are needed. I hope some of you will find it useful.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using cAlgo.API;
using cAlgo.API.Indicators;
using cAlgo.API.Internals;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

namespace cAlgo.Indicators
{
    /// <summary>
    /// Looks at the most recent swing points on the live bar and asks an ML.NET binary
    /// classifier whether their retracement ratios and timing resemble a "three drives"
    /// pattern. The classifier is trained once at start-up on synthetic samples drawn from
    /// the user-configured retracement and time-symmetry ranges.
    /// </summary>
    [Indicator(IsOverlay = true, TimeZone = TimeZones.UTC, AccessRights = AccessRights.None, AutoRescale = false)]
    public class ThreeDrivesPatternMLGPU : Indicator
    {
        [Parameter("Minimum Retracement (%)", DefaultValue = 61.8)]
        public double MinRetracement { get; set; }

        [Parameter("Maximum Retracement (%)", DefaultValue = 78.6)]
        public double MaxRetracement { get; set; }

        [Parameter("Time Symmetry Tolerance (%)", DefaultValue = 15)]
        public double TimeSymmetryTolerance { get; set; }

        [Parameter("Lookback Window", DefaultValue = 30, MinValue = 10)]
        public int LookbackWindow { get; set; }

        [Parameter("Swing Strength", DefaultValue = 2, MinValue = 1)]
        public int SwingStrength { get; set; }

        [Parameter("Number of CPU Cores", DefaultValue = 4, MinValue = 1)]
        public int NumCores { get; set; }

        // Number of swing points that make up one candidate pattern (three legs of
        // start / extreme / retracement points sharing their endpoints).
        private const int PatternPointCount = 7;

        // Total number of synthetic samples used to train the classifier
        // (half valid, half invalid). More samples = slower start-up, smoother model.
        private const int TotalTrainingSamples = 2000;

        // Minimum model probability required before a pattern is drawn and reported.
        private const float ProbabilityThreshold = 0.8f;

        private MLContext mlContext;
        private ITransformer trainedModel;
        private PredictionEngine<PatternData, PatternPrediction> predictionEngine;
        private List<SwingPoint> lastSwingPoints;

        // Bar index of the last point of the most recently reported pattern. Calculate()
        // runs on every tick of the live bar, so this prevents printing the same signal
        // over and over.
        private int lastReportedPatternIndex = -1;

        /// <summary>Model input: five features plus the training label.</summary>
        private class PatternData
        {
            [VectorType(5)]
            public float[] Features { get; set; }

            public bool Label { get; set; }

            public PatternData()
            {
                Features = new float[5];
            }
        }

        /// <summary>Model output columns produced by the binary classification trainer.</summary>
        private class PatternPrediction
        {
            [ColumnName("PredictedLabel")]
            public bool PredictedLabel { get; set; }

            public float Probability { get; set; }
            public float Score { get; set; }
        }

        /// <summary>A detected local high or low at a given bar index.</summary>
        private class SwingPoint
        {
            public int Index { get; set; }
            public double Price { get; set; }
            public bool IsHigh { get; set; }
        }

        protected override void Initialize()
        {
            lastSwingPoints = new List<SwingPoint>();
            InitializeModel();
        }

        /// <summary>
        /// Generates the synthetic training set and fits an SDCA logistic-regression
        /// classifier, then builds the prediction engine used on live bars.
        /// </summary>
        private void InitializeModel()
        {
            mlContext = new MLContext(seed: 0);
            var trainingData = GenerateTrainingDataParallel().ToList();
            var dataView = mlContext.Data.LoadFromEnumerable(trainingData);

            // The input column is already named "Features"; the Concatenate step simply
            // passes it through to the trainer under the expected column name.
            var pipeline = mlContext.Transforms.Concatenate("Features", nameof(PatternData.Features))
                .Append(mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(new SdcaLogisticRegressionBinaryTrainer.Options
                {
                    MaximumNumberOfIterations = 100,
                    NumberOfThreads = NumCores,
                    L2Regularization = 0.1f,
                    L1Regularization = 0.1f
                }));

            trainedModel = pipeline.Fit(dataView);
            predictionEngine = mlContext.Model.CreatePredictionEngine<PatternData, PatternPrediction>(trainedModel);
        }

        /// <summary>
        /// Produces <see cref="TotalTrainingSamples"/> samples (half valid, half invalid)
        /// using up to <see cref="NumCores"/> worker tasks per class.
        /// </summary>
        /// <returns>Valid samples first, then invalid samples.</returns>
        private IEnumerable<PatternData> GenerateTrainingDataParallel()
        {
            int perClass = TotalTrainingSamples / 2;

            // Never start more workers than there are samples, so every worker produces at
            // least one sample even when NumCores is very large.
            int workers = Math.Max(1, Math.Min(NumCores, perClass));
            var tasks = new List<Task<List<PatternData>>>(workers * 2);

            // Generate valid patterns in parallel
            for (int w = 0; w < workers; w++)
            {
                tasks.Add(StartSampleGeneration(SamplesForWorker(perClass, workers, w), GenerateValidPattern));
            }

            // Generate invalid patterns in parallel
            for (int w = 0; w < workers; w++)
            {
                tasks.Add(StartSampleGeneration(SamplesForWorker(perClass, workers, w), GenerateInvalidPattern));
            }

            Task.WaitAll(tasks.ToArray());
            return tasks.SelectMany(t => t.Result).ToList();
        }

        /// <summary>
        /// Splits <paramref name="total"/> samples across <paramref name="workers"/> workers,
        /// giving the remainder to the first workers so the overall count is exact.
        /// </summary>
        private static int SamplesForWorker(int total, int workers, int worker)
        {
            return total / workers + (worker < total % workers ? 1 : 0);
        }

        /// <summary>
        /// Runs <paramref name="generator"/> <paramref name="count"/> times on a thread-pool
        /// task with its own <see cref="Random"/> instance (Random is not thread-safe).
        /// </summary>
        private static Task<List<PatternData>> StartSampleGeneration(int count, Func<Random, PatternData> generator)
        {
            return Task.Run(() =>
            {
                var localRandom = new Random(Guid.NewGuid().GetHashCode());
                return Enumerable.Range(0, count)
                    .Select(_ => generator(localRandom))
                    .ToList();
            });
        }

        /// <summary>
        /// A positive sample: all three retracements inside [MinRetracement, MaxRetracement]
        /// and both time-asymmetry values inside [0, TimeSymmetryTolerance].
        /// </summary>
        private PatternData GenerateValidPattern(Random localRandom)
        {
            return new PatternData
            {
                Features = new float[]
                {
                    (float)GetRandomValue(MinRetracement, MaxRetracement, localRandom),
                    (float)GetRandomValue(MinRetracement, MaxRetracement, localRandom),
                    (float)GetRandomValue(MinRetracement, MaxRetracement, localRandom),
                    (float)GetRandomValue(0, TimeSymmetryTolerance, localRandom),
                    (float)GetRandomValue(0, TimeSymmetryTolerance, localRandom)
                },
                Label = true
            };
        }

        /// <summary>
        /// A negative sample: every retracement outside the valid band and every
        /// time-asymmetry value above the tolerance.
        /// </summary>
        private PatternData GenerateInvalidPattern(Random localRandom)
        {
            return new PatternData
            {
                Features = new float[]
                {
                    (float)GetRandomInvalidRetracement(localRandom),
                    (float)GetRandomInvalidRetracement(localRandom),
                    (float)GetRandomInvalidRetracement(localRandom),
                    (float)GetRandomInvalidTimeSymmetry(localRandom),
                    (float)GetRandomInvalidTimeSymmetry(localRandom)
                },
                Label = false
            };
        }

        /// <summary>Random retracement below MinRetracement or above MaxRetracement (up to 100%).</summary>
        private double GetRandomInvalidRetracement(Random localRandom)
        {
            if (localRandom.Next(2) == 0)
                return GetRandomValue(0, MinRetracement - 0.1, localRandom);
            else
                return GetRandomValue(MaxRetracement + 0.1, 100, localRandom);
        }

        /// <summary>Random time-asymmetry value above the tolerance (up to 100%).</summary>
        private double GetRandomInvalidTimeSymmetry(Random localRandom)
        {
            return GetRandomValue(TimeSymmetryTolerance + 0.1, 100, localRandom);
        }

        public override void Calculate(int index)
        {
            if (index < SwingStrength)
                return;

            // Only the last (live) bar is processed; historical bars are not scanned, so swing
            // points accumulate from the moment the indicator starts. cTrader calls Calculate
            // on every tick of the live bar, so the same index arrives repeatedly.
            if (IsLastBar)
            {
                ProcessNewBar(index);
            }
        }

        private void ProcessNewBar(int index)
        {
            UpdateSwingPoints(index);
            CheckForPattern(index);
        }

        /// <summary>
        /// Records a swing high and/or low at <paramref name="index"/> and discards points
        /// that have fallen out of the lookback window.
        /// </summary>
        private void UpdateSwingPoints(int index)
        {
            bool found = false;

            if (IsSwingHigh(index))
            {
                UpsertSwingPoint(index, Bars.HighPrices[index], true);
                found = true;
            }

            if (IsSwingLow(index))
            {
                UpsertSwingPoint(index, Bars.LowPrices[index], false);
                found = true;
            }

            if (found)
            {
                // Keep only the swing points within the lookback window
                lastSwingPoints.RemoveAll(p => p.Index < index - LookbackWindow);
            }
        }

        /// <summary>
        /// Adds a swing point, or refreshes its price if one of the same kind already exists
        /// at this bar. Without this, every tick of the live bar would append a duplicate.
        /// </summary>
        private void UpsertSwingPoint(int index, double price, bool isHigh)
        {
            var existing = lastSwingPoints.FirstOrDefault(p => p.Index == index && p.IsHigh == isHigh);
            if (existing != null)
            {
                existing.Price = price;
                return;
            }

            lastSwingPoints.Add(new SwingPoint
            {
                Index = index,
                Price = price,
                IsHigh = isHigh
            });
        }

        /// <summary>
        /// Takes the most recent <see cref="PatternPointCount"/> swing points (by bar index)
        /// and evaluates them with the model when their features are well defined.
        /// </summary>
        private void CheckForPattern(int index)
        {
            if (lastSwingPoints.Count < PatternPointCount)
                return;

            var points = lastSwingPoints.OrderBy(p => p.Index).ToList();
            var lastPoints = points.GetRange(points.Count - PatternPointCount, PatternPointCount);

            var patternData = ExtractPatternFeatures(lastPoints);
            if (patternData == null)
                return; // degenerate geometry (zero-height leg or zero-length timing)

            EvaluatePattern(patternData, lastPoints);
        }

        /// <summary>
        /// Draws the pattern when the model is confident enough, and prints a signal once
        /// per distinct pattern end bar.
        /// </summary>
        private void EvaluatePattern(PatternData patternData, List<SwingPoint> points)
        {
            var prediction = predictionEngine.Predict(patternData);

            if (!prediction.PredictedLabel || prediction.Probability <= ProbabilityThreshold)
                return;

            bool isBullish = points[0].Price < points[1].Price;
            DrawPattern(points, isBullish);

            int patternEnd = points[points.Count - 1].Index;
            if (patternEnd != lastReportedPatternIndex)
            {
                lastReportedPatternIndex = patternEnd;
                PrintSignalMessage(points, isBullish);
            }
        }

        /// <summary>
        /// Marks points[0], points[2] and points[4] with diamonds, connects them with trend
        /// lines and places a direction arrow at points[4].
        /// </summary>
        private void DrawPattern(List<SwingPoint> points, bool isBullish)
        {
            // Object names are computed once so removal and drawing always agree.
            string drive0Name = "DrivePoint_0_" + points[0].Index;
            string drive1Name = "DrivePoint_1_" + points[2].Index;
            string drive2Name = "DrivePoint_2_" + points[4].Index;
            string line1Name = "DriveLine1_" + points[0].Index;
            string line2Name = "DriveLine2_" + points[2].Index;
            string arrowName = "PatternArrow_" + points[4].Index;

            // Remove previous objects with the same names, if any
            Chart.RemoveObject(drive0Name);
            Chart.RemoveObject(drive1Name);
            Chart.RemoveObject(drive2Name);
            Chart.RemoveObject(line1Name);
            Chart.RemoveObject(line2Name);
            Chart.RemoveObject(arrowName);

            // Color scheme based on bullish or bearish
            Color firstColor = isBullish ? Color.Green : Color.Red;
            Color secondColor = Color.Orange;
            Color thirdColor = isBullish ? Color.Lime : Color.Maroon;

            // Draw icons at the drives: points[0], points[2], points[4]
            Chart.DrawIcon(drive0Name, ChartIconType.Diamond, points[0].Index, points[0].Price, firstColor);
            Chart.DrawIcon(drive1Name, ChartIconType.Diamond, points[2].Index, points[2].Price, secondColor);
            Chart.DrawIcon(drive2Name, ChartIconType.Diamond, points[4].Index, points[4].Price, thirdColor);

            // Draw trend lines to connect the drives
            DrawConnectingLine(line1Name, points[0].Index, points[0].Price, points[2].Index, points[2].Price, firstColor);
            DrawConnectingLine(line2Name, points[2].Index, points[2].Price, points[4].Index, points[4].Price, secondColor);

            // Draw arrow indicating potential direction
            var arrowType = isBullish ? ChartIconType.UpArrow : ChartIconType.DownArrow;
            Color arrowColor = isBullish ? Color.Green : Color.Red;
            Chart.DrawIcon(arrowName, arrowType, points[4].Index, points[4].Price, arrowColor);
        }

        private void PrintSignalMessage(List<SwingPoint> points, bool isBullish)
        {
            string direction = isBullish ? "BULLISH (BUY)" : "BEARISH (SELL)";
            Print($"[Three Drives Pattern Detected] at index {points[4].Index}: {direction}");
        }

        /// <summary>Draws a solid, 2-px trend line between two bar/price coordinates.</summary>
        private void DrawConnectingLine(string name, int startIndex, double startPrice, int endIndex, double endPrice, Color color)
        {
            // Use positional arguments only
            var line = Chart.DrawTrendLine(
                name,
                startIndex,
                startPrice,
                endIndex,
                endPrice,
                color
            );

            line.Thickness = 2;
            line.LineStyle = LineStyle.Solid;
        }

        /// <summary>
        /// Builds the five model features from seven swing points: three retracement
        /// percentages (legs 0-1-2, 2-3-4, 4-5-6) and two time-asymmetry percentages.
        /// </summary>
        /// <returns>
        /// The feature vector, or <c>null</c> when any feature is NaN or infinite (e.g. a
        /// leg with zero price movement, or two points on the same bar).
        /// </returns>
        private PatternData ExtractPatternFeatures(List<SwingPoint> points)
        {
            bool isBullish = points[0].Price < points[1].Price;

            var features = new float[]
            {
                (float)CalculateRetracement(points[0], points[1], points[2], isBullish),
                (float)CalculateRetracement(points[2], points[3], points[4], isBullish),
                (float)CalculateRetracement(points[4], points[5], points[6], isBullish),
                (float)CalculateTimeSymmetry(points[0].Index, points[2].Index, points[4].Index),
                (float)CalculateTimeSymmetry(points[2].Index, points[4].Index, points[6].Index)
            };

            if (features.Any(f => float.IsNaN(f) || float.IsInfinity(f)))
                return null;

            return new PatternData { Features = features };
        }

        /// <summary>
        /// True when the high of bar <paramref name="index"/> is strictly above the highs of
        /// the preceding <see cref="SwingStrength"/> bars. Only bars to the left are checked.
        /// </summary>
        private bool IsSwingHigh(int index)
        {
            if (index < SwingStrength) return false;

            double currentHigh = Bars.HighPrices[index];
            for (int i = 1; i <= SwingStrength; i++)
            {
                if (currentHigh <= Bars.HighPrices[index - i])
                    return false;
            }
            return true;
        }

        /// <summary>
        /// True when the low of bar <paramref name="index"/> is strictly below the lows of
        /// the preceding <see cref="SwingStrength"/> bars. Only bars to the left are checked.
        /// </summary>
        private bool IsSwingLow(int index)
        {
            if (index < SwingStrength) return false;

            double currentLow = Bars.LowPrices[index];
            for (int i = 1; i <= SwingStrength; i++)
            {
                if (currentLow >= Bars.LowPrices[index - i])
                    return false;
            }
            return true;
        }

        /// <summary>Uniform random value in [min, max).</summary>
        private double GetRandomValue(double min, double max, Random localRandom)
        {
            return min + (localRandom.NextDouble() * (max - min));
        }

        /// <summary>
        /// Retracement of leg p2→p3 as a percentage of leg p1→p2. Returns NaN/Infinity when
        /// p1 and p2 have the same price; callers must filter such values.
        /// </summary>
        private double CalculateRetracement(SwingPoint p1, SwingPoint p2, SwingPoint p3, bool isBullish)
        {
            double move = isBullish ? (p2.Price - p1.Price) : (p1.Price - p2.Price);
            double retrace = isBullish ? (p2.Price - p3.Price) : (p3.Price - p2.Price);
            return (retrace / move) * 100.0;
        }

        /// <summary>
        /// Difference between the two bar spans (i1→i2, i2→i3) as a percentage of their
        /// mean; 0 means perfectly symmetric. Returns NaN when all three indices are equal.
        /// </summary>
        private double CalculateTimeSymmetry(int i1, int i2, int i3)
        {
            var t1 = Math.Abs(i2 - i1);
            var t2 = Math.Abs(i3 - i2);
            var avg = (t1 + t2) / 2.0;
            return Math.Abs(t1 - t2) / avg * 100.0;
        }
    }
}

@EDG777
Replies

firemyst
11 Dec 2024, 06:12 ( Updated at: 11 Dec 2024, 06:14 )

Are you able to post/share any screen captures of it in action on the charts?

That might get more people's interest.


@firemyst

EDG777
11 Dec 2024, 18:34


@EDG777