-
Notifications
You must be signed in to change notification settings - Fork 329
/
Copy pathchronos.py
564 lines (481 loc) · 19.3 KB
/
chronos.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
GenerationConfig,
PreTrainedModel,
)
import chronos
@dataclass
class ChronosConfig:
    """
    Configuration parameters shared by ``ChronosTokenizer`` and ``ChronosModel``.

    ``ChronosPipeline.from_pretrained`` builds one of these from the
    ``chronos_config`` entry of a ``transformers`` config.
    """

    # Name of the tokenizer class, looked up as an attribute of the
    # ``chronos`` module by ``create_tokenizer``.
    tokenizer_class: str
    # Extra keyword arguments forwarded to the tokenizer constructor.
    tokenizer_kwargs: Dict[str, Any]
    context_length: int
    prediction_length: int
    n_tokens: int
    n_special_tokens: int
    pad_token_id: int
    eos_token_id: int
    use_eos_token: bool
    model_type: Literal["causal", "seq2seq"]
    num_samples: int
    temperature: float
    top_k: int
    top_p: float

    def __post_init__(self):
        # Both special ids (pad, then eos) must lie below n_special_tokens.
        special_ids = (self.pad_token_id, self.eos_token_id)
        assert all(
            token_id < self.n_special_tokens for token_id in special_ids
        ), f"Special token id's must be smaller than {self.n_special_tokens=}"

    def create_tokenizer(self) -> "ChronosTokenizer":
        """Instantiate the tokenizer class named by ``tokenizer_class``."""
        tokenizer_cls = getattr(chronos, self.tokenizer_class)
        return tokenizer_cls(**self.tokenizer_kwargs, config=self)
class ChronosTokenizer:
    """
    A ``ChronosTokenizer`` defines how time series are mapped into token IDs
    and back.

    For details, see the ``context_input_transform``,
    ``label_input_transform`` and ``output_transform`` methods, which
    concrete classes must implement.
    """

    def context_input_transform(
        self,
        context: torch.Tensor,
    ) -> Tuple:
        """
        Turn a batch of time series into token IDs, attention map, and tokenizer_state.

        Parameters
        ----------
        context
            A tensor shaped (batch_size, time_length), containing the
            timeseries to forecast. Use left-padding with ``torch.nan``
            to align time series of different lengths.

        Returns
        -------
        token_ids
            A tensor of integers, shaped (batch_size, time_length + 1)
            if ``config.use_eos_token`` and (batch_size, time_length)
            otherwise, containing token IDs for the input series.
        attention_mask
            A boolean tensor, same shape as ``token_ids``, indicating
            which input observations are not ``torch.nan`` (i.e. not
            missing nor padding).
        tokenizer_state
            An object that can be passed to ``label_input_transform``
            and ``output_transform``. Contains the relevant information
            to decode output samples into real values,
            such as location and scale parameters.
        """
        raise NotImplementedError()

    def label_input_transform(self, label: torch.Tensor, tokenizer_state: Any) -> Tuple:
        """
        Turn a batch of label slices of time series into token IDs and attention map
        using the ``tokenizer_state`` provided by ``context_input_transform``.

        Parameters
        ----------
        label
            A tensor shaped (batch_size, time_length), containing the
            label slice of each time series (the values to be predicted).
            ``torch.nan`` entries are treated as missing.
        tokenizer_state
            An object returned by ``context_input_transform`` containing
            relevant information to preprocess data, such as location and
            scale. The nature of this depends on the specific tokenizer.
            This is used for tokenizing the label, in order to use the same
            scaling used to tokenize the context.

        Returns
        -------
        token_ids
            A tensor of integers, shaped (batch_size, time_length + 1)
            if ``config.use_eos_token`` and (batch_size, time_length)
            otherwise, containing token IDs for the label series.
        attention_mask
            A boolean tensor, same shape as ``token_ids``, indicating
            which label observations are not ``torch.nan`` (i.e. not
            missing nor padding).
        """
        raise NotImplementedError()

    def output_transform(
        self, samples: torch.Tensor, tokenizer_state: Any
    ) -> torch.Tensor:
        """
        Turn a batch of sample token IDs into real values.

        Parameters
        ----------
        samples
            A tensor of integers, shaped (batch_size, num_samples, time_length),
            containing token IDs of sample trajectories.
        tokenizer_state
            An object returned by ``context_input_transform`` containing
            relevant context to decode samples, such as location and scale.
            The nature of this depends on the specific tokenizer.

        Returns
        -------
        forecasts
            A real tensor, shaped (batch_size, num_samples, time_length),
            containing forecasted sample paths.
        """
        raise NotImplementedError()
class MeanScaleUniformBins(ChronosTokenizer):
    """
    Tokenizer that divides each series by its mean absolute value and then
    quantizes the scaled values into uniformly spaced bins.

    Parameters
    ----------
    low_limit
        Center of the lowest bin, in scaled units.
    high_limit
        Center of the highest bin, in scaled units.
    config
        The ``ChronosConfig`` supplying the token-id layout
        (``n_tokens``, ``n_special_tokens``, ``pad_token_id``,
        ``eos_token_id``) and the context / prediction lengths.
    """

    def __init__(
        self, low_limit: float, high_limit: float, config: ChronosConfig
    ) -> None:
        self.config = config
        # Evenly spaced bin centers from low_limit to high_limit (inclusive).
        self.centers = torch.linspace(
            low_limit,
            high_limit,
            config.n_tokens - config.n_special_tokens - 1,
        )
        # Bin edges are the midpoints between neighbouring centers, framed by
        # +/-1e20 sentinels at both ends.
        self.boundaries = torch.concat(
            (
                torch.tensor([-1e20], device=self.centers.device),
                (self.centers[1:] + self.centers[:-1]) / 2,
                torch.tensor([1e20], device=self.centers.device),
            )
        )

    def _input_transform(
        self, context: torch.Tensor, scale: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Scale and bucketize ``context`` along its last dimension.

        Returns ``(token_ids, attention_mask, scale)``. When ``scale`` is not
        given it is computed per series as the mean absolute value of the
        non-NaN entries; results that are not strictly positive (including
        NaN for all-NaN series) are replaced by 1.0. NaN positions receive
        ``pad_token_id`` and a False attention mask.
        """
        # Keep the bin edges on the same device as the input, and bring the
        # input to the edges' dtype, so bucketize sees matching tensors even
        # for integer / float64 inputs or non-CPU devices.
        boundaries = self.boundaries.to(context.device)
        context = context.to(dtype=boundaries.dtype)
        attention_mask = ~torch.isnan(context)

        if scale is None:
            scale = torch.nansum(
                torch.abs(context) * attention_mask, dim=-1
            ) / torch.nansum(attention_mask, dim=-1)
            scale[~(scale > 0)] = 1.0

        scaled_context = context / scale.unsqueeze(dim=-1)
        token_ids = (
            torch.bucketize(
                input=scaled_context,
                boundaries=boundaries,
                # buckets are open to the right, see:
                # https://pytorch.org/docs/2.1/generated/torch.bucketize.html#torch-bucketize
                right=True,
            )
            # Shift past the reserved special-token ids.
            + self.config.n_special_tokens
        )
        token_ids.clamp_(0, self.config.n_tokens - 1)
        token_ids[~attention_mask] = self.config.pad_token_id

        return token_ids, attention_mask, scale

    def _append_eos_token(
        self, token_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append one ``eos_token_id`` column (with a True mask) to each row."""
        batch_size = token_ids.shape[0]
        # Create the new columns on the inputs' devices; otherwise concat
        # fails whenever the inputs are not on the default (CPU) device.
        eos_tokens = torch.full(
            (batch_size, 1),
            fill_value=self.config.eos_token_id,
            device=token_ids.device,
        )
        token_ids = torch.concat((token_ids, eos_tokens), dim=1)
        eos_mask = torch.full(
            (batch_size, 1), fill_value=True, device=attention_mask.device
        )
        attention_mask = torch.concat((attention_mask, eos_mask), dim=1)

        return token_ids, attention_mask

    def context_input_transform(
        self, context: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Tokenize the context, keeping at most the last ``config.context_length``
        steps. An EOS token is appended only for seq2seq models with
        ``use_eos_token`` set. Returns ``(token_ids, attention_mask, scale)``.
        """
        length = context.shape[-1]

        if length > self.config.context_length:
            context = context[..., -self.config.context_length :]

        token_ids, attention_mask, scale = self._input_transform(context=context)

        if self.config.use_eos_token and self.config.model_type == "seq2seq":
            token_ids, attention_mask = self._append_eos_token(
                token_ids=token_ids, attention_mask=attention_mask
            )

        return token_ids, attention_mask, scale

    def label_input_transform(
        self, label: torch.Tensor, scale: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Tokenize a label slice of length ``config.prediction_length`` using
        the ``scale`` returned by ``context_input_transform``.
        """
        length = label.shape[-1]

        assert length == self.config.prediction_length
        token_ids, attention_mask, _ = self._input_transform(context=label, scale=scale)

        if self.config.use_eos_token:
            token_ids, attention_mask = self._append_eos_token(
                token_ids=token_ids, attention_mask=attention_mask
            )

        return token_ids, attention_mask

    def output_transform(
        self, samples: torch.Tensor, scale: torch.Tensor
    ) -> torch.Tensor:
        """
        Map sampled token ids back to values: undo the special-token offset,
        clamp to valid center indices, look up the bin centers and multiply
        by the per-series ``scale``.
        """
        # scale is (batch_size,); broadcast over (batch_size, num_samples, time).
        scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1)
        indices = torch.clamp(
            samples - self.config.n_special_tokens - 1,
            min=0,
            max=len(self.centers) - 1,
        )
        # Index centers on the same device as the indices.
        return self.centers.to(indices.device)[indices] * scale_unsqueezed
class ChronosModel(nn.Module):
    """
    ``nn.Module`` wrapper around a ``transformers`` ``PreTrainedModel`` that
    samples future token paths for Chronos token sequences.

    Parameters
    ----------
    config
        Chronos configuration; supplies the default generation settings.
    model
        The wrapped pretrained model (causal or seq2seq).
    """

    def __init__(self, config: ChronosConfig, model: PreTrainedModel) -> None:
        super().__init__()
        self.config = config
        self.model = model

    @property
    def device(self):
        """Device reported by the wrapped model."""
        return self.model.device

    def encode(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ):
        """
        Compute encoder embeddings for a batch of token sequences.

        Parameters
        ----------
        input_ids
            Token indices, shaped (batch_size, sequence_length).
        attention_mask
            Mask with the same shape as ``input_ids``; positions that are
            padding or missing should not be attended to.

        Returns
        -------
        embedding
            Encoder hidden states, shaped
            (batch_size, sequence_length, d_model).
        """
        is_encoder_decoder = self.config.model_type == "seq2seq"
        assert (
            is_encoder_decoder
        ), "Encoder embeddings are only supported for encoder-decoder models"
        encoder_output = self.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        )
        return encoder_output.last_hidden_state

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        prediction_length: Optional[int] = None,
        num_samples: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> torch.Tensor:
        """
        Sample future token paths for the given token sequences.

        Any of ``prediction_length``, ``num_samples``, ``temperature``,
        ``top_k`` and ``top_p`` left as ``None`` falls back to the matching
        attribute of ``self.config``.

        Returns
        -------
        samples
            Integer tensor shaped (batch_size, num_samples, time_length)
            holding the sampled token paths.
        """
        cfg = self.config
        prediction_length = (
            cfg.prediction_length if prediction_length is None else prediction_length
        )
        num_samples = cfg.num_samples if num_samples is None else num_samples
        temperature = cfg.temperature if temperature is None else temperature
        top_k = cfg.top_k if top_k is None else top_k
        top_p = cfg.top_p if top_p is None else top_p

        # Exactly prediction_length new tokens, sampled num_samples times.
        generation_config = GenerationConfig(
            min_new_tokens=prediction_length,
            max_new_tokens=prediction_length,
            do_sample=True,
            num_return_sequences=num_samples,
            eos_token_id=cfg.eos_token_id,
            pad_token_id=cfg.pad_token_id,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        generated = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
        )

        if cfg.model_type == "seq2seq":
            # Drop the decoder start token.
            generated = generated[..., 1:]
        else:
            assert cfg.model_type == "causal"
            assert generated.size(-1) == input_ids.size(-1) + prediction_length
            # A causal model echoes the prompt; keep only the new tokens.
            generated = generated[..., -prediction_length:]

        batch_size = input_ids.size(0)
        return generated.reshape(batch_size, num_samples, -1)
def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
    """
    Left-pad 1D tensors with ``torch.nan`` to a common length and stack them.

    Parameters
    ----------
    tensors
        Non-empty list of 1D tensors, possibly of different lengths.

    Returns
    -------
    A 2D tensor shaped (len(tensors), max_len). Each row holds one input
    tensor, right-aligned and preceded by NaN padding.

    Raises
    ------
    ValueError
        If ``tensors`` is empty.
    """
    if len(tensors) == 0:
        raise ValueError("left_pad_and_stack_1D requires at least one tensor")

    # Validate every element before calling len() on it, so malformed inputs
    # (e.g. 0-d tensors) fail on these explicit checks rather than in len().
    for c in tensors:
        assert isinstance(c, torch.Tensor)
        assert c.ndim == 1

    max_len = max(len(c) for c in tensors)
    padded = []
    for c in tensors:
        padding = torch.full(
            size=(max_len - len(c),), fill_value=torch.nan, device=c.device
        )
        padded.append(torch.concat((padding, c), dim=-1))
    return torch.stack(padded)
@dataclass
class ChronosPipeline:
    """
    A ``ChronosPipeline`` uses the given tokenizer and model to forecast
    input time series.

    Use the ``from_pretrained`` class method to load serialized models.
    Use the ``predict`` method to get forecasts.

    Parameters
    ----------
    tokenizer
        The tokenizer object to use.
    model
        The model to use.
    """

    tokenizer: ChronosTokenizer
    model: ChronosModel

    def _prepare_and_validate_context(
        self, context: Union[torch.Tensor, List[torch.Tensor]]
    ):
        """
        Normalize ``context`` to a 2D (batch, time) tensor.

        A list is left-padded with NaN and stacked; a 1D tensor gets a
        leading batch dimension.
        """
        if isinstance(context, list):
            context = left_pad_and_stack_1D(context)
        assert isinstance(context, torch.Tensor)
        if context.ndim == 1:
            context = context.unsqueeze(0)
        assert context.ndim == 2

        return context

    @torch.no_grad()
    def embed(
        self, context: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Tuple[torch.Tensor, Any]:
        """
        Get encoder embeddings for the given time series.

        Parameters
        ----------
        context
            Input series. This is either a 1D tensor, or a list
            of 1D tensors, or a 2D tensor whose first dimension
            is batch. In the latter case, use left-padding with
            ``torch.nan`` to align series of different lengths.

        Returns
        -------
        embeddings, tokenizer_state
            A tuple of two items: the encoder embeddings and the tokenizer_state,
            e.g., the scale of the time series in the case of mean scaling.
            The encoder embeddings are shaped (batch_size, context_length, d_model)
            or (batch_size, context_length + 1, d_model), where context_length
            is the size of the context along the time axis if a 2D tensor was provided
            or the length of the longest time series, if a list of 1D tensors was
            provided, and the extra 1 is for EOS. The embeddings are returned
            on the CPU.
        """
        context_tensor = self._prepare_and_validate_context(context=context)
        token_ids, attention_mask, tokenizer_state = (
            self.tokenizer.context_input_transform(context_tensor)
        )
        embeddings = self.model.encode(
            input_ids=token_ids.to(self.model.device),
            attention_mask=attention_mask.to(self.model.device),
        ).cpu()
        return embeddings, tokenizer_state

    def predict(
        self,
        context: Union[torch.Tensor, List[torch.Tensor]],
        prediction_length: Optional[int] = None,
        num_samples: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        limit_prediction_length: bool = True,
    ) -> torch.Tensor:
        """
        Get forecasts for the given time series.

        Parameters
        ----------
        context
            Input series. This is either a 1D tensor, or a list
            of 1D tensors, or a 2D tensor whose first dimension
            is batch. In the latter case, use left-padding with
            ``torch.nan`` to align series of different lengths.
        prediction_length
            Time steps to predict. Defaults to the value specified
            in ``self.model.config``.
        num_samples
            Number of sample paths to predict. Defaults to the value
            specified in ``self.model.config``.
        temperature
            Temperature to use for generating sample tokens.
            Defaults to the value specified in ``self.model.config``.
        top_k
            Top-k parameter to use for generating sample tokens.
            Defaults to the value specified in ``self.model.config``.
        top_p
            Top-p parameter to use for generating sample tokens.
            Defaults to the value specified in ``self.model.config``.
        limit_prediction_length
            Force the prediction length to be smaller than or equal to the
            built-in prediction length of the model. True by default.
            When true, fail loudly if longer predictions are requested;
            otherwise longer predictions are allowed (with a warning).

        Returns
        -------
        samples
            Tensor of sample forecasts, of shape
            (batch_size, num_samples, prediction_length).

        Raises
        ------
        ValueError
            If ``prediction_length`` is not positive, or if it exceeds the
            model's built-in prediction length while
            ``limit_prediction_length`` is True.
        """
        context_tensor = self._prepare_and_validate_context(context=context)

        if prediction_length is None:
            prediction_length = self.model.config.prediction_length

        if prediction_length <= 0:
            # Without this check the loop below never runs and torch.cat
            # fails on an empty list with an unhelpful error.
            raise ValueError(
                f"prediction_length must be a positive integer, got {prediction_length}."
            )

        if prediction_length > self.model.config.prediction_length:
            msg = (
                f"We recommend keeping prediction length <= {self.model.config.prediction_length}. "
                "The quality of longer predictions may degrade since the model is not optimized for it. "
            )
            if limit_prediction_length:
                msg += "You can turn off this check by setting `limit_prediction_length=False`."
                raise ValueError(msg)
            warnings.warn(msg)

        predictions = []
        remaining = prediction_length

        # Forecast in chunks of at most the model's built-in prediction length.
        # After each chunk, the median sample path of each series (median over
        # the num_samples dimension) is appended to the context for the next chunk.
        while remaining > 0:
            token_ids, attention_mask, scale = self.tokenizer.context_input_transform(
                context_tensor
            )
            samples = self.model(
                token_ids.to(self.model.device),
                attention_mask.to(self.model.device),
                min(remaining, self.model.config.prediction_length),
                num_samples,
                temperature,
                top_k,
                top_p,
            )
            prediction = self.tokenizer.output_transform(
                samples.to(scale.device), scale
            )

            predictions.append(prediction)
            remaining -= prediction.shape[-1]

            if remaining <= 0:
                break

            context_tensor = torch.cat(
                [context_tensor, prediction.median(dim=1).values], dim=-1
            )

        return torch.cat(predictions, dim=-1)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """
        Load the model, either from a local path or from the HuggingFace Hub.
        Supports the same arguments as ``AutoConfig`` and ``AutoModel``
        from ``transformers``.
        """
        config = AutoConfig.from_pretrained(*args, **kwargs)

        assert hasattr(config, "chronos_config"), "Not a Chronos config file"

        chronos_config = ChronosConfig(**config.chronos_config)

        if chronos_config.model_type == "seq2seq":
            inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs)
        else:
            assert chronos_config.model_type == "causal"
            inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs)

        return cls(
            tokenizer=chronos_config.create_tokenizer(),
            model=ChronosModel(config=chronos_config, model=inner_model),
        )