Skip to content

Commit

Permalink
test: added a new test case for thread pool
Browse files Browse the repository at this point in the history
  • Loading branch information
changemyminds committed Apr 10, 2024
1 parent 0f9a25c commit d264789
Showing 1 changed file with 54 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import ThreadPoolExecutor
from typing import List

from opentelemetry import trace
Expand Down Expand Up @@ -51,11 +51,13 @@ def run_threading_test(self, thread: threading.Thread):

# check result
self.assertEqual(len(self._mock_span_contexts), 1)
self.assert_span_context_equality(
self.assertEqual(
self._mock_span_contexts[0], expected_span_context
)

def test_trace_context_propagation_in_thread_pool(self):
def test_trace_context_propagation_in_thread_pool_with_multiple_workers(
self,
):
max_workers = 10
executor = ThreadPoolExecutor(max_workers=max_workers)

Expand All @@ -65,38 +67,65 @@ def test_trace_context_propagation_in_thread_pool(self):
with self._tracer.start_as_current_span(f"trace_{num}") as span:
expected_span_context = span.get_span_context()
expected_span_contexts.append(expected_span_context)
future = executor.submit(self.fake_func)
future = executor.submit(
self.get_current_span_context_for_test
)
futures_list.append(future)

for future in as_completed(futures_list):
future.result()
result_span_contexts = [future.result() for future in futures_list]

# check result
self.assertEqual(len(self._mock_span_contexts), max_workers)
self.assertEqual(len(result_span_contexts), max_workers)
self.assertEqual(
len(self._mock_span_contexts), len(expected_span_contexts)
len(result_span_contexts), len(expected_span_contexts)
)
for index, mock_span_context in enumerate(self._mock_span_contexts):
self.assert_span_context_equality(
mock_span_context, expected_span_contexts[index]
for index, result_span_context in enumerate(result_span_contexts):
self.assertEqual(
result_span_context, expected_span_contexts[index]
)

def fake_func(self):
span_context = trace.get_current_span().get_span_context()
def test_trace_context_propagation_in_thread_pool_with_single_worker(self):
max_workers = 1
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# test propagation of the same trace context across multiple tasks
with self._tracer.start_as_current_span(f"task") as task_span:
expected_task_context = task_span.get_span_context()
future1 = executor.submit(
self.get_current_span_context_for_test
)
future2 = executor.submit(
self.get_current_span_context_for_test
)

# check result
self.assertEqual(future1.result(), expected_task_context)
self.assertEqual(future2.result(), expected_task_context)

# test propagation of different trace contexts across tasks in sequence
with self._tracer.start_as_current_span(f"task1") as task1_span:
expected_task1_context = task1_span.get_span_context()
future1 = executor.submit(
self.get_current_span_context_for_test
)

# check result
self.assertEqual(future1.result(), expected_task1_context)

with self._tracer.start_as_current_span(f"task2") as task2_span:
expected_task2_context = task2_span.get_span_context()
future2 = executor.submit(
self.get_current_span_context_for_test
)

# check result
self.assertEqual(future2.result(), expected_task2_context)

def fake_func(self) -> trace.SpanContext:
span_context = self.get_current_span_context_for_test()
self._mock_span_contexts.append(span_context)

def assert_span_context_equality(
self,
result_span_context: trace.SpanContext,
expected_span_context: trace.SpanContext,
):
self.assertEqual(result_span_context, expected_span_context)
self.assertEqual(
result_span_context.trace_id, expected_span_context.trace_id
)
self.assertEqual(
result_span_context.span_id, expected_span_context.span_id
)
def get_current_span_context_for_test(self) -> trace.SpanContext:
return trace.get_current_span().get_span_context()

def print_square(self, num):
with self._tracer.start_as_current_span("square"):
Expand Down

0 comments on commit d264789

Please sign in to comment.