Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add backend api for hidi #868

Merged
merged 3 commits into from
Dec 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions visualdl/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,20 @@ def embedding_embedding(self, run, tag='default', reduction='pca', dimension=2):
key = os.path.join('data/plugin/embeddings/embeddings', run, str(dimension), reduction)
return self._get_with_retry(key, lib.get_embeddings, run, tag, reduction, dimension)

@result()
def embedding_list(self):
    """Return the list of available embeddings.

    Delegates to ``self._get_with_retry`` using a fixed key and
    ``lib.get_embeddings_list`` as the data source.
    """
    cache_key = 'data/plugin/embeddings/list'
    fetch = lib.get_embeddings_list
    return self._get_with_retry(cache_key, fetch)

@result('text/tab-separated-values')
def embedding_metadata(self, name):
    """Return the label metadata of the embedding called ``name``.

    The key passed to ``self._get_with_retry`` is specific to ``name``;
    the data itself comes from ``lib.get_embedding_labels``.
    """
    cache_key = os.path.join('data/plugin/embeddings/metadata', name)
    fetch = lib.get_embedding_labels
    return self._get_with_retry(cache_key, fetch, name)

@result('application/octet-stream')
def embedding_tensor(self, name):
    """Return the tensor data of the embedding called ``name``.

    The key passed to ``self._get_with_retry`` is specific to ``name``;
    the data itself comes from ``lib.get_embedding_tensors``.
    """
    cache_key = os.path.join('data/plugin/embeddings/tensor', name)
    fetch = lib.get_embedding_tensors
    return self._get_with_retry(cache_key, fetch, name)

@result()
def histogram_tags(self):
    """Return the histogram tags.

    Delegates to ``self._get_with_retry`` using a fixed key and
    ``lib.get_histogram_tags`` as the data source.
    """
    cache_key = 'data/plugin/histogram/tags'
    fetch = lib.get_histogram_tags
    return self._get_with_retry(cache_key, fetch)
Expand Down Expand Up @@ -190,6 +204,9 @@ def create_api_call(logdir, model, cache_timeout):
'audio/list': (api.audio_list, ['run', 'tag']),
'audio/audio': (api.audio_audio, ['run', 'tag', 'index']),
'embedding/embedding': (api.embedding_embedding, ['run', 'tag', 'reduction', 'dimension']),
'embedding/list': (api.embedding_list, []),
'embedding/tensor': (api.embedding_tensor, ['name']),
'embedding/metadata': (api.embedding_metadata, ['name']),
'histogram/list': (api.histogram_list, ['run', 'tag']),
'graph/graph': (api.graph_graph, []),
'pr-curve/list': (api.pr_curves_pr_curve, ['run', 'tag']),
Expand Down
2 changes: 1 addition & 1 deletion visualdl/server/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"scalar": 1000,
"image": 10,
"histogram": 100,
"embeddings": 50000,
"embeddings": 50000000,
"audio": 10,
"pr_curve": 300,
"meta_data": 100
Expand Down
54 changes: 54 additions & 0 deletions visualdl/server/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import sys
import time
import os
import io
import csv
from functools import partial
import numpy as np
from visualdl.server.log import logger
Expand All @@ -27,6 +29,8 @@

MODIFY_PREFIX = {}
MODIFIED_RUNS = []
# Registry of known embeddings: maps the os.path.join(run, tag) name built by
# get_embeddings_list to {'run': run, 'tag': tag}. get_embedding_labels and
# get_embedding_tensors index it to resolve a name back to its run and tag.
EMBEDDING_NAME = {}
# Accumulated listing returned by get_embeddings_list; each entry is a dict
# with 'name', 'shape' and 'path' keys.
embedding_names = []


def s2ms(timestamp):
Expand Down Expand Up @@ -196,6 +200,56 @@ def get_pr_curve_step(log_reader, run, tag=None):
return results


def get_embeddings_list(log_reader):
    """List every embedding found in the logs, registering new ones.

    Each (run, tag) pair reported by ``get_logs(log_reader, 'embeddings')``
    is named ``os.path.join(run, tag)``. Names not yet present in the
    module-level ``EMBEDDING_NAME`` registry are added to it and appended to
    the module-level ``embedding_names`` list.

    Args:
        log_reader: Reader object exposing ``data_manager`` and accepted by
            ``get_logs``.

    Returns:
        The module-level ``embedding_names`` list; each entry is a dict with
        ``'name'``, ``'shape'`` (``[rows, columns]``) and ``'path'`` keys.
    """
    run2tag = get_logs(log_reader, 'embeddings')

    for run, tags in zip(run2tag['runs'], run2tag['tags']):
        for tag in tags:
            name = os.path.join(run, tag)
            if name in EMBEDDING_NAME:
                # Already listed. Keep scanning instead of returning so that
                # runs/tags appearing after a known one are still discovered.
                continue
            records = log_reader.data_manager.get_reservoir("embeddings").get_items(
                run, decode_tag(tag))
            if not records:
                # Nothing stored yet for this tag; leave it unregistered so a
                # later call can pick it up once data has arrived.
                continue
            points = records[0].embeddings.embeddings
            row_len = len(points)
            col_len = len(points[0].vectors) if row_len else 0
            # Register only after the shape was read successfully, keeping
            # EMBEDDING_NAME and embedding_names consistent with each other.
            EMBEDDING_NAME[name] = {'run': run, 'tag': tag}
            embedding_names.append(
                {'name': name, 'shape': [row_len, col_len], 'path': name})
    return embedding_names


def get_embedding_labels(log_reader, name):
    """Return the labels of the embedding ``name`` as tab-delimited text.

    ``name`` is resolved to its run and tag through the module-level
    ``EMBEDDING_NAME`` registry, which is filled by ``get_embeddings_list``.
    The labels of the first stored record are written with ``csv.writer``,
    one label per row.

    Args:
        log_reader: Reader object exposing ``load_new_data`` and
            ``data_manager``.
        name: Embedding name as registered in ``EMBEDDING_NAME``.

    Returns:
        A string with one label per line.

    Raises:
        KeyError: If ``name`` is not present in ``EMBEDDING_NAME``.
    """
    entry = EMBEDDING_NAME[name]
    run, tag = entry['run'], entry['tag']
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("embeddings")
    records = reservoir.get_items(run, decode_tag(tag))
    rows = [[point.label] for point in records[0].embeddings.embeddings]

    with io.StringIO() as buffer:
        csv.writer(buffer, delimiter='\t').writerows(rows)
        return buffer.getvalue()


def get_embedding_tensors(log_reader, name):
    """Return the vectors of the embedding ``name`` as raw float32 bytes.

    ``name`` is resolved to its run and tag through the module-level
    ``EMBEDDING_NAME`` registry, which is filled by ``get_embeddings_list``.
    The vectors of the first stored record are stacked into a NumPy array,
    flattened, cast to ``float32`` and serialised with ``tobytes``.

    Args:
        log_reader: Reader object exposing ``load_new_data`` and
            ``data_manager``.
        name: Embedding name as registered in ``EMBEDDING_NAME``.

    Returns:
        A ``bytes`` object holding the flattened float32 values.

    Raises:
        KeyError: If ``name`` is not present in ``EMBEDDING_NAME``.
    """
    entry = EMBEDDING_NAME[name]
    run, tag = entry['run'], entry['tag']
    log_reader.load_new_data()
    reservoir = log_reader.data_manager.get_reservoir("embeddings")
    records = reservoir.get_items(run, decode_tag(tag))
    rows = [point.vectors for point in records[0].embeddings.embeddings]
    matrix = np.array(rows)
    return matrix.flatten().astype(np.float32).tobytes()


def get_embeddings(log_reader, run, tag, reduction, dimension=2):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data()
Expand Down