
Commit

fix: model loading in api
ForYourEyesOnlyyy committed Oct 31, 2024
1 parent 516df68 commit 9cc99cd
Showing 2 changed files with 22 additions and 6 deletions.
12 changes: 11 additions & 1 deletion config/config.py
@@ -21,6 +21,10 @@
     # Model Parameters
     tokenizer_name (str): The name of the tokenizer used for preprocessing.
     model_name (str): Name of the model used for sentiment analysis.
+    model_class (obj): Instance of the model class used for sentiment analysis.
+    model_weights_path (str): File path for saving and loading the model's weights.
+
+    # Device Configuration
     device (str): Device to run the model on (e.g., 'cpu' or 'cuda').
     """

@@ -39,4 +43,10 @@
 # MODELS
 tokenizer_name = 'bert-base-uncased'
 model_name = 'simple_sentiment_analysis_model'
-device = 'cpu'
+from models.simple_sentiment_analysis_model.simple_sentiment_analysis_model import SentimentAnalysisModel
+
+model_class = SentimentAnalysisModel()
+model_weights_path = f'models/{model_name}/model_weights.pth'
+
+# DEVICE
+device = 'cpu'
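With the model instance, weights path, and device all exposed on config, any module can load the trained model the same way the API does. A minimal consumer sketch, assuming SentimentAnalysisModel is a torch.nn.Module (the eval() call is an assumption, not part of this diff):

# Sketch: loading the configured model from anywhere in the codebase.
import torch

from config import config

model = config.model_class  # shared instance created in config/config.py
model.load_state_dict(
    torch.load(config.model_weights_path,   # f'models/{model_name}/model_weights.pth'
               map_location=config.device))  # 'cpu' per the config above
model.eval()  # assumption: put the model in inference mode before serving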
16 changes: 11 additions & 5 deletions deployment/api.py
@@ -9,9 +9,17 @@
 Classes:
     TweetInput: Pydantic model for input validation, ensuring tweet text is within 1-280 characters.
+
+Global Variables:
+    model (SentimentAnalysisModel): The sentiment analysis model loaded at startup.
+    tokenizer (AutoTokenizer): Tokenizer instance associated with the model.
 Endpoints:
     /predict-sentiment/ (POST): Accepts a tweet and returns the predicted sentiment.
 Attributes:
     model (SentimentAnalysisModel): The sentiment analysis model used for predictions.
     tokenizer (AutoTokenizer): Tokenizer used for preprocessing tweets before sentiment prediction.
 Usage:
     Run this app to serve sentiment predictions via an API. The model and tokenizer are loaded
     at startup for efficient processing, and request processing time is logged for each prediction.
@@ -23,6 +31,7 @@
 from contextlib import asynccontextmanager
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel, Field
+import torch

 from config import config
 from src import data
@@ -70,12 +79,9 @@ async def lifespan(app: FastAPI):
     global model, tokenizer

     # Load model and tokenizer at startup
-    import torch
-    from models.ssam.simple_sentiment_analysis_model import SentimentAnalysisModel
-    model = SentimentAnalysisModel()
+    model = config.model_class
     model.load_state_dict(
-        torch.load('models/ssam/model_weights.pth',
-                   map_location=config.device))
+        torch.load(config.model_weights_path, map_location=config.device))
     logging.info(f"Model {config.model_name} loaded successfully at startup.")

     tokenizer = data.get_tokenizer(config.tokenizer_name)
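Pulled together, the startup path after this commit reads roughly as follows. The lifespan body matches the diff; the scaffolding around it (the logging import, the module-level globals, the yield, and the FastAPI(lifespan=...) wiring) is assumed from the imports shown above and the standard FastAPI lifespan pattern.

# Sketch of deployment/api.py startup after this commit (scaffolding assumed).
import logging
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI

from config import config
from src import data

model = None      # set during startup
tokenizer = None  # set during startup

@asynccontextmanager
async def lifespan(app: FastAPI):
    global model, tokenizer

    # Load model and tokenizer at startup, now driven entirely by config
    model = config.model_class
    model.load_state_dict(
        torch.load(config.model_weights_path, map_location=config.device))
    logging.info(f"Model {config.model_name} loaded successfully at startup.")

    tokenizer = data.get_tokenizer(config.tokenizer_name)
    yield  # the app serves requests here; teardown would follow the yield

app = FastAPI(lifespan=lifespan)

The net effect of the commit: the hard-coded 'models/ssam/...' weights path and the in-function imports are gone from the serving module, so pointing the API at a different model now only touches config/config.py.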
