-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathRandomForest.h
50 lines (33 loc) · 1.17 KB
/
RandomForest.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
/*
* This file is part of dsa-decision-tree
*
* Developed for the DSA UET course.
* This project was developed by Ba Luong and Gia Linh.
*/
#pragma once
#ifndef RANDOM_FOREST
#define RANDOM_FOREST
#include <string>
#include <vector>
#include "Tree.h"
#include "Console.h"
class RandomForest
{
private:
vector<Tree *> forest;
Console *console;
public:
RandomForest(bool consoleActivate = true);
RandomForest(vector<Data *> *dataset, int minSize, int maxDepth, int numberOfTrees, bool consoleActivate = true);
RandomForest(DataSet *dataset, DataSet *valid, int numberOfTrees, bool consoleActivate = true);
~RandomForest();
void buildForest(vector<Data *> *dataset, int minSize, int maxDepth, int numberOfTrees);
void buildForest(vector<Data *> *dataset, vector<Data *> *valid, int numberOfTrees);
double calcAccuracy(DataSet *dataset);
// Predict the label of the data given.
bool predict(Data *data);
char getPredict(Data *data);
void predictToFile(DataSet *dataset, string filena);
void importFromFile(string filename);
void exportToFile(string filename);
};
// Builds a RandomForest of numberOfTrees trees from `dataset`, presumably
// using `valid` to select the best-performing configuration -- confirm in
// the .cpp. NOTE(review): returns a raw pointer; the caller presumably owns
// the result and must delete it.
RandomForest *buildBestModel(DataSet *dataset,
                             DataSet *valid,
                             int numberOfTrees);
#endif