-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathreport.py
38 lines (26 loc) · 1.44 KB
/
report.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from nl2sql360.core import Core
from nl2sql360.arguments import CoreArguments, EvaluationArguments
from nl2sql360.filter import Filter, Scenario, Field, Operator
if __name__ == "__main__":
    # Example driver: print leaderboards and filter/scenario breakdowns for
    # the "spider_dev" dataset, then print a combined evaluation report.
    core_args = CoreArguments()
    core = Core(core_args)

    dataset_name = "spider_dev"
    eval_name = "SuperSQL"

    # Selects samples whose Field.SUBQUERY value is > 0 (presumably: gold SQL
    # containing at least one subquery -- confirm against nl2sql360.filter).
    SUBQUERY_FILTER = Filter(
        name="subquery",
        field=Field.SUBQUERY,
        operator=Operator.GT,
        value=0,
    )

    # "BI" scenario built from an aggregation filter and a join filter.
    # NOTE(review): whether Scenario ANDs or ORs its filters is defined in
    # nl2sql360.filter, not here -- confirm before relying on it.
    BI_SCENARIO = Scenario(
        name="BI",
        filters=[
            Filter(name="agg", field=Field.AGGREGATION, operator=Operator.GT, value=0),
            Filter(name="join", field=Field.JOIN, operator=Operator.GT, value=0),
        ],
    )

    print(core.query_overall_leaderboard(dataset_name=dataset_name, metric="ex"))
    # Pass the subquery Filter object; previously these two calls passed the
    # Python builtin `filter` function by mistake.
    print(core.query_filter_performance(dataset_name=dataset_name, metric="ex", filter=SUBQUERY_FILTER, eval_name=eval_name))
    print(core.query_filter_leaderboard(dataset_name=dataset_name, metric="ex", filter=SUBQUERY_FILTER))
    print(core.query_scenario_performance(dataset_name=dataset_name, metric="ex", eval_name=eval_name, scenario=BI_SCENARIO))
    print(core.query_scenario_leaderboard(dataset_name=dataset_name, metric="ex", scenario=BI_SCENARIO))
    print(core.query_dataset_domain_distribution(dataset_name=dataset_name))
    print(core.generate_evaluation_report(dataset_name=dataset_name,
                                          filters=[SUBQUERY_FILTER],
                                          scenarios=[BI_SCENARIO],
                                          metrics=["ex", "em", "ves"]))