fix: gliner model (#715)
micaelakaplan authored Feb 21, 2025
1 parent 80650cb commit 5af3d95
Showing 5 changed files with 68 additions and 94 deletions.
13 changes: 13 additions & 0 deletions label_studio_ml/examples/gliner/README.md
@@ -85,3 +85,16 @@ The following common parameters are available:
- `THREADS` - Specify the number of threads for the model server.
- `LABEL_STUDIO_URL` - Specify the URL of your Label Studio instance. Note that this might need to be `http://host.docker.internal:8080` if you are running Label Studio on another Docker container.
- `LABEL_STUDIO_API_KEY` - Specify the API key for authenticating your Label Studio instance. You can find this by logging into Label Studio and [going to the **Account & Settings** page](https://labelstud.io/guide/user_account#Access-token).

## A Note on Model Training

If you plan to trigger training for this model on "Start Training", note that you do
not need to configure a separate webhook. Instead, open the three-dot menu next to your model
on the **Model** tab in your project settings and click "Start Training".

Additionally, note that this container is configured for a **VERY SMALL** demo set, with only 1
non-eval sample (the first 10 data samples are expected to be used for evaluation, as sketched below).
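
A hypothetical illustration of that split (the values and slicing below are illustrative only; the real task handling lives in `fit()` in `model.py`):

```python
# Hypothetical sketch of the demo split described above -- illustrative values,
# not code from model.py.
tasks = [{"id": i, "data": {"text": f"sample {i}"}} for i in range(11)]  # 11 demo tasks

eval_tasks = tasks[:10]    # the first 10 samples are reserved for evaluation
train_tasks = tasks[10:]   # leaving exactly 1 non-eval sample, as noted above

print(len(eval_tasks), len(train_tasks))  # 10 1
```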

If you're working with a larger dataset, be sure to:
1. Update `num_steps` and `batch_size` to the number of training steps you want and the batch size that works for your dataset.
2. Change the model that is loaded after training (line 239 of `model.py`) to the highest checkpoint you have, as sketched below.
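
For example, a rough sketch of what those two edits could look like; the values below are placeholders, and only the `num_steps`/`batch_size` variables and the `GLiNER.from_pretrained(...)` call mirror `model.py`:

```python
from gliner import GLiNER

# Inside fit() in model.py -- placeholder values for a larger dataset:
num_steps = 1000   # total number of training steps you want
batch_size = 8     # batch size that works for your data and hardware

# After training, reload the highest checkpoint the trainer wrote under
# "models/"; the step number below is a placeholder -- use your highest one.
model = GLiNER.from_pretrained("models/checkpoint-1000", local_files_only=True)
```
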
1 change: 0 additions & 1 deletion label_studio_ml/examples/gliner/docker-compose.yml
@@ -29,7 +29,6 @@ services:
# Determine the actual IP using 'ifconfig' (Linux/Mac) or 'ipconfig' (Windows).
- LABEL_STUDIO_URL=http://host.docker.internal:8080
- LABEL_STUDIO_API_KEY=

ports:
- "9090:9090"
volumes:
141 changes: 51 additions & 90 deletions label_studio_ml/examples/gliner/model.py
@@ -1,16 +1,16 @@
import logging
import os
from types import SimpleNamespace
from math import floor
from typing import List, Dict, Optional

import label_studio_sdk
import torch
from gliner import GLiNER
from gliner.data_processing.collator import DataCollator
from gliner.training import Trainer, TrainingArguments
from label_studio_sdk.label_interface.objects import PredictionValue

from label_studio_ml.model import LabelStudioMLBase
from label_studio_ml.response import ModelResponse
from label_studio_sdk.label_interface.objects import PredictionValue
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup

logger = logging.getLogger(__name__)

@@ -109,7 +109,7 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -

# make predictions with currently set model
from_name, to_name, value = self.label_interface.get_first_tag_occurence('Labels', 'Text')

# get labels from the labeling configuration
labels = sorted(self.label_interface.get_tag(from_name).labels)

@@ -141,7 +141,7 @@ def process_training_data(self, task):
ner.append([start_token, end_token, label])
return tokens, ner
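
For reference, and not part of this diff: a minimal illustration of the kind of training sample the `tokens`/`ner` pair above typically feeds into, assuming the usual GLiNER fine-tuning layout of `{"tokenized_text": ..., "ner": ...}`; the label string is borrowed from the test fixture in `test_api.py`:

```python
# Illustrative only -- assumes the standard GLiNER fine-tuning sample layout;
# the exact dict assembly is an assumption and is not shown in this diff.
sample = {
    "tokenized_text": ["atomoxetine", "is", "prescribed", "for", "ADHD"],
    # [start_token, end_token, label] spans, as built by process_training_data
    "ner": [[0, 0, "Medication/Vaccine"]],
}
print(sample["ner"])
```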

def train(self, model, config, train_data, eval_data=None):
def train(self, model, training_args, train_data, eval_data=None):
"""
retrain the GLiNER model. Code adapted from the GLiNER finetuning notebook.
:param model: the model to train
@@ -150,69 +150,23 @@ def train(self, model, config, train_data, eval_data=None):
:param eval_data: the eval data
"""
logger.info("Training Model")
model = model.to(config.device)

# Set sampling parameters from config
model.set_sampling_params(
max_types=config.max_types,
shuffle_types=config.shuffle_types,
random_drop=config.random_drop,
max_neg_type_ratio=config.max_neg_type_ratio,
max_len=config.max_len
)
if training_args.use_cpu == True:
model = model.to('cpu')
else:
model = model.to("cuda")

model.train()
train_loader = model.create_dataloader(train_data, batch_size=config.train_batch_size, shuffle=True)
optimizer = model.get_optimizer(config.lr_encoder, config.lr_others, config.freeze_token_rep)
pbar = tqdm(range(config.num_steps))
num_warmup_steps = int(config.num_steps * config.warmup_ratio) if config.warmup_ratio < 1 else int(
config.warmup_ratio)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, config.num_steps)
iter_train_loader = iter(train_loader)

for step in pbar:
try:
x = next(iter_train_loader)
except StopIteration:
iter_train_loader = iter(train_loader)
x = next(iter_train_loader)

for k, v in x.items():
if isinstance(v, torch.Tensor):
x[k] = v.to(config.device)

try:
loss = model(x) # Forward pass
except RuntimeError as e:
print(f"Error during forward pass at step {step}: {e}")
print(f"x: {x}")
continue

if torch.isnan(loss):
print("Loss is NaN, skipping...")
continue

loss.backward() # Compute gradients
optimizer.step() # Update parameters
scheduler.step() # Update learning rate schedule
optimizer.zero_grad() # Reset gradients

description = f"step: {step} | epoch: {step // len(train_loader)} | loss: {loss.item():.2f}"
pbar.set_description(description)

if (step + 1) % config.eval_every == 0:
model.eval()
if eval_data:
results, f1 = model.evaluate(eval_data["samples"], flat_ner=True, threshold=0.5, batch_size=12,
entity_types=eval_data["entity_types"])
print(f"Step={step}\n{results}")

if not os.path.exists(config.save_directory):
os.makedirs(config.save_directory)

model.save_pretrained(f"{config.save_directory}/finetuned_{step}")
model.train()
data_collator = DataCollator(model.config, data_processor=model.data_processor, prepare_labels=True)

trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_data,
eval_dataset=eval_data,
tokenizer=model.data_processor.transformer_tokenizer,
data_collator=data_collator,
)

trainer.train()

def fit(self, event, data, **kwargs):
"""
@@ -250,31 +204,38 @@ def fit(self, event, data, **kwargs):

# Define the hyperparameters in a config variable
# This comes from the pretraining example in the GLiNER repo
config = SimpleNamespace(
num_steps=10000, # number of training iteration
train_batch_size=2,
eval_every=1000, # evaluation/saving steps
save_directory="logs", # where to save checkpoints
warmup_ratio=0.1, # warmup steps
device='cpu',
lr_encoder=1e-5, # learning rate for the backbone
lr_others=5e-5, # learning rate for other parameters
freeze_token_rep=False, # freeze of not the backbone

# Parameters for set_sampling_params
max_types=25, # maximum number of entity types during training
shuffle_types=True, # if shuffle or not entity types
random_drop=True, # randomly drop entity types
max_neg_type_ratio=1,
# ratio of positive/negative types, 1 mean 50%/50%, 2 mean 33%/66%, 3 mean 25%/75% ...
max_len=384 # maximum sentence length
num_steps = 10
batch_size = 1
data_size = len(training_data)
num_batches = floor(data_size / batch_size)
num_epochs = max(1, floor(num_steps / num_batches))

training_args = TrainingArguments(
output_dir="models",
learning_rate=5e-6,
weight_decay=0.01,
others_lr=1e-5,
others_weight_decay=0.01,
lr_scheduler_type="linear", # cosine
warmup_ratio=0.1,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
focal_loss_alpha=0.75,
focal_loss_gamma=2,
num_train_epochs=num_epochs,
evaluation_strategy="steps",
save_steps=100,
save_total_limit=10,
dataloader_num_workers=0,
use_cpu=True,
report_to="none",
)

self.train(self.model, config, training_data, eval_data)
self.train(self.model, training_args, training_data, eval_data)

logger.info("Saving new fine-tuned model as the default model")
self.model = GLiNERModel.from_pretrained("finetuned", local_files_only=True)
model_version = self.model_version[-1] + 1
self.model = GLiNER.from_pretrained(f"models/checkpoint-10", local_files_only=True)
model_version = int(self.model_version[-1]) + 1
self.set("model_version", f'{self.__class__.__name__}-v{model_version}')
else:
logger.info("Model training not triggered")
logger.info("Model training not triggered")
5 changes: 3 additions & 2 deletions label_studio_ml/examples/gliner/requirements.txt
@@ -1,2 +1,3 @@
gliner==0.1.12
torch==2.2.0
gliner==0.2.16
torch==2.2.0
accelerate>=0.26.0
2 changes: 1 addition & 1 deletion label_studio_ml/examples/gliner/test_api.py
@@ -56,7 +56,7 @@ def test_predict(client):
}

expected_response = {"results": [{"model_version": "GLiNERModel-v0.0.1", "result": [
{"from_name": "label", "score": 0.9220, "to_name": "text", "type": "labels",
{"from_name": "label", "score": 0.922, "to_name": "text", "type": "labels",
"value": {"end": 11, "labels": ["Medication/Vaccine"], "start": 0, "text": "atomoxetine"}},
{"from_name": "label", "score": 0.7053, "to_name": "text", "type": "labels",
"value": {"end": 65, "labels": ["Medication/Vaccine"], "start": 32,
