-
Notifications
You must be signed in to change notification settings - Fork 91
/
Copy pathexample1_assignment1.cu
97 lines (80 loc) · 3.85 KB
/
example1_assignment1.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include <matx.h>
using namespace matx;
/**
* MatX training assignment 1. This training goes through basic tensor
* operations that were learned in the 01_introduction notebook. Uncomment each
* verification block as you go to ensure your solutions are correct.
*/
int main() {
/****************************************************************************************************
* Create a rank-2 tensor data object of ints with 5 rows and 4 columns called
*"t2"
*https://devtech-compute.gitlab-master-pages.nvidia.com/matx/quickstart.html#tensor-views
****************************************************************************************************/
/*** End editing ***/
/****************************************************************************************************
* Initialize the t2 view to a 4x5 matrix of increasing values starting at 1
* https://devtech-compute.gitlab-master-pages.nvidia.com/matx/quickstart.html#tensor-views
****************************************************************************************************/
// t2 = ;
/*** End editing ***/
t2.PrefetchDevice(0);
/****************************************************************************************************
* Get a slice of the second and third rows with all columns
* https://devtech-compute.gitlab-master-pages.nvidia.com/matx/quickstart.html#slicing-and-dicing
*****************************************************************************************************/
auto t2s = t2;
/*** End editing ***/
// Verify slice is correct
// for (int row = 1; row <= 2; row++) {
// for (int col = 0; col < t2.Size(1); col++) {
// if (t2(row, col) != t2s(row - 1, col)) {
// printf("Mismatch in sliced view! actual = %d, expected = %d\n",
// t2s(row - 1, col), t2(row, col)); exit(-1);
// }
// }
// }
// print(t2s);
// printf("Slice verification passed!\n");
/****************************************************************************************************
* Take the slice and clone it into a 3D tensor with new outer dimensions as
*follows: First dim: keep existing row dimension from t2s Second dim: 2 Third
*dim: keep existing col dimension from t2s
https://devtech-compute.gitlab-master-pages.nvidia.com/matx/quickstart.html#increasing-dimensionality
*****************************************************************************************************/
auto t3c = t2s;
/*** End editing ***/
// Verify clone
// for (int first = 0; first < t3c.Size(0); first++) {
// for (int sec = 0; sec < t3c.Size(1); sec++) {
// for (int third = 0; third < t3c.Size(2); third++) {
// if (t3c(first, sec, third) != t2s(first, third)) {
// printf("Mismatch in cloned view! actual = %d, expected = %d\n",
// t3c(first, sec, third), t2s(first, third)); exit(-1);
// }
// }
// }
// }
// print(t3c);
// printf("Clone verification passed!\n");
/****************************************************************************************************
* Permute the two outer dimensions of the cloned tensor
* https://devtech-compute.gitlab-master-pages.nvidia.com/matx/quickstart.html#permuting
*****************************************************************************************************/
auto t3p = t3c;
/*** End editing ***/
// Verify clone
// for (int first = 0; first < t3p.Size(0); first++) {
// for (int sec = 0; sec < t3p.Size(1); sec++) {
// for (int third = 0; third < t3p.Size(2); third++) {
// if (t3c(first, sec, third) != t2s(first, third)) {
// printf("Mismatch in permuted view! actual = %d, expected = %d\n",
// t3c(first, sec, third), t2s(sec, third)); exit(-1);
// }
// }
// }
// }
// print(t3p);
// printf("Permute verification passed!\n");
return 0;
}