Skip to content

Commit

Permalink
fix label names
Browse files Browse the repository at this point in the history
  • Loading branch information
jmargutt committed Jan 29, 2024
1 parent b8fec6f commit 30ac67c
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions pipeline/src/pipeline/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,10 +876,10 @@ def classify_text(df_tweets, text_raw, text_processed, labels, config, n_example
  list_df = [df_tweets[i:i+n_examples] for i in range(0, len(df_tweets), n_examples)]

  topics = [
- "ANOMALY", "ARMY", "CHILDREN", "CONNECTIVITY", "RC CONNECT WITH REDCROSS", "EDUCATION", "FOOD", "GOODS/SERVICES",
+ "ANOMALY", "ARMY", "CHILDREN", "CONNECTIVITY", "RC CONNECT WITH RED CROSS", "EDUCATION", "FOOD", "GOODS/SERVICES",
  "HEALTH", "CVA INCLUSION", "LEGAL", "MONEY/BANKING", "NFI", "OTHER PROGRAMS/NGOS", "PARCEL",
- "CVA PAYMENT", "PETS", "PMER/NEW PROGRAMS", "RC PROGRAM INFO", "PSS/RFL",
- "CVA REGISTRATION", "SENTIMENT/FEEDBACK", "SHELTER", "TRANSLATION/LANGUAGE", "TRANSPORT/CAR",
+ "CVA PAYMENT", "PETS", "RC PMER/NEW PROGRAMS", "CVA PROGRAM INFO", "RC PROGRAM INFO", "PSS & RFL",
+ "CVA REGISTRATION", "SENTIMENT", "SHELTER", "TRANSLATION/LANGUAGE", "CAR",
  "TRANSPORT/MOVEMENT", "WASH", "WORK/JOBS"
  ]

Expand All @@ -895,17 +895,14 @@ def classify_text(df_tweets, text_raw, text_processed, labels, config, n_example
  output = response.json()
  if 'predictions' in output:
  for idx, prediction in zip(df_tweets_.index, output['predictions']):
- prediction_label = prediction['label']
- if prediction['label'] == "PROGRAMINFORMATION":
- prediction_label = "RC PROGRAM INFO"
  if prediction['probability'] > threshold:
- df_tweets.at[idx, 'topic'] = prediction_label
+ df_tweets.at[idx, 'topic'] = prediction['label']

- if not prediction_label:
+ if not prediction['label']:
  tmp_prediction = [(topic, 0.) for topic in topics]
  else:
- tmp_prediction = [(prediction_label, prediction['probability'])]
- tmp_prediction += [(topic, 0.) for topic in topics if topic != prediction_label]
+ tmp_prediction = [(prediction['label'], prediction['probability'])]
+ tmp_prediction += [(topic, 0.) for topic in topics if topic != prediction['label']]
  tmp_prediction.sort()

  df_tweets.at[idx, 'prediction'] = tmp_prediction
Expand Down

0 comments on commit 30ac67c

Please sign in to comment.