Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve gif classification #401

Merged
merged 20 commits into from
Oct 23, 2020
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,12 @@ Example of passing a configuration:
// returns top 1 prediction of each GIF frame, and logs the status to console
const myConfig = {
topk: 1,
setGifControl: (gifControl) => console.log(gifControl),
onFrame: ({ index, totalFrames, predictions }) =>
console.log(index, totalFrames, predictions),
fps: 1,
onFrame: ({ index, totalFrames, predictions, image }) => {
console.log({ index, totalFrames, predictions })
// document.body.appendChild(image)
// require('fs').writeFileSync(`./file.jpeg`, require('jpeg-js').encode(image).data)
}
}
const framePredictions = await classifyGif(img, myConfig)
```
Expand All @@ -171,11 +174,12 @@ const framePredictions = await classifyGif(img, myConfig)
- Image element to check
- Configuration object with the following possible key/values:
- `topk` - Number of results to return per frame (default all 5)
- `setGifControl` - Function callback receives SuperGif object as an argument, allows a user to save it for later use
- `fps` - Frames per seconds, frames picks proportionally from the middle (default all frames)
- `onFrame` - Function callback on each frame - Param is an object with the following key/values:
- `index` - the current GIF frame that was classified (starting at 1)
- `index` - the current GIF frame that was classified (starting at 0)
- `totalFrames` - the complete number of frames for this GIF (for progress calculations)
- `predictions` - an array of length `topk`, returning top results from classify
- `image` - an image of specific frame

**Returns**

Expand Down Expand Up @@ -313,7 +317,7 @@ You can also use [`lovell/sharp`](/~https://github.com/lovell/sharp) for preproces

### NSFW Filter

[**NSFW Filter**](/~https://github.com/navendu-pottekkat/nsfw-filter) is a web extension that uses NSFWJS for filtering out NSFW images from your browser.
[**NSFW Filter**](/~https://github.com/navendu-pottekkat/nsfw-filter) is a web extension that uses NSFWJS for filtering out NSFW images from your browser.

It is currently available for Chrome and Firefox and is completely open-source.

Expand Down
79 changes: 79 additions & 0 deletions __tests__/classifyGif.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import { load, predictionType, NSFWJS } from '../src/index'
const fs = require('fs');

// Fix for JEST
const globalAny: any = global
globalAny.fetch = require('node-fetch')
const timeoutMS = 10000

const path = `${__dirname}/../example/manual-testing/data/animations/smile.gif`

const roundPredicitonProbability = ({ className, probability }: predictionType) => {
return {className, probability: Math.floor(probability * 10000) / 10000}
}

describe('NSFWJS classify GIF', () => {
let model: NSFWJS
let buffer: Buffer

beforeAll(async () => {
model = await load()
buffer = fs.readFileSync(path)
});

it("Probabilities match", async () => {
const expectedResults = [
[
{ className: 'Neutral', probability: 0.8766 },
{ className: 'Porn', probability: 0.091 },
{ className: 'Sexy', probability: 0.0316 }
],
[
{ className: 'Neutral', probability: 0.8995 },
{ className: 'Porn', probability: 0.0511 },
{ className: 'Sexy', probability: 0.0487 }
],
[
{ className: 'Neutral', probability: 0.8541 },
{ className: 'Sexy', probability: 0.1027 },
{ className: 'Porn', probability: 0.0424 }
]
]

const predictions = await model.classifyGif(buffer, { topk: 3, fps: 0.4 })
expect(predictions.length).toBe(3)

let index = 0
predictions[index].map((actualObj, id) => {
expect(roundPredicitonProbability(actualObj)).toEqual(expectedResults[index][id])
})

index = 1
predictions[index].map((actualObj, id) => {
expect(roundPredicitonProbability(actualObj)).toEqual(expectedResults[index][id])
})

index = 2
predictions[index].map((actualObj, id) => {
expect(roundPredicitonProbability(actualObj)).toEqual(expectedResults[index][id])
})
},
timeoutMS
)

it("0 fps - single frame from the middle", async () => {
const predictions = await model.classifyGif(buffer, { topk: 3, fps: 0 })
expect(predictions.length).toBe(1)
},
timeoutMS
)

// Takes too long
it.skip("All frames", async () => {
const predictions = await model.classifyGif(buffer, { topk: 3 })
console.log(predictions.length)
expect(predictions.length).toBe(190)
},
timeoutMS
)
})
4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
},
"repository": {
"type": "git",
"url": "git+/~https://github.com/infinitered/nsfwjs.git"
"url": "git+/~https://github.com/nsfw-filter/nsfwjs.git"
},
"dependencies": {
"libgif": "0.0.3"
"@nsfw-filter/gif-frames": "1.0.2"
},
"peerDependencies": {
"@tensorflow/tfjs": "^1.7.4"
Expand Down
99 changes: 64 additions & 35 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
import * as tf from "@tensorflow/tfjs";
import { NSFW_CLASSES } from "./nsfw_classes";
import * as SuperGif from "libgif";
import gifFrames, { GifFrameCanvas, GifFrameBuffer } from "@nsfw-filter/gif-frames";

interface frameResult {
export type frameResult = {
index: number;
totalFrames: number;
predictions: Array<Object>;
predictions: Array<predictionType>;
image: HTMLCanvasElement | ImageData;
}

interface classifyConfig {
export type classifyConfig = {
topk?: number;
onFrame?: (result: frameResult) => {};
setGifControl?: (gifControl: typeof SuperGif) => {};
fps?: number;
onFrame?: (result: frameResult) => any;
}

interface nsfwjsOptions {
size?: number;
type?: string;
}

export type predictionType = {
className: string
probability: number
}

const BASE_PATH =
"https://s3.amazonaws.com/ir_public/nsfwjscdn/TFJS_nsfw_mobilenet/tfjs_quant_nsfw_mobilenet/";
const IMAGE_SIZE = 224; // default to Mobilenet v2
Expand Down Expand Up @@ -174,7 +180,7 @@ export class NSFWJS {
| HTMLCanvasElement
| HTMLVideoElement,
topk = 5
): Promise<Array<{ className: string; probability: number }>> {
): Promise<Array<predictionType>> {
const logits = this.infer(img) as tf.Tensor2D;

const classes = await getTopKClasses(logits, topk);
Expand All @@ -192,41 +198,64 @@ export class NSFWJS {
* @param config param configuration for run
*/
async classifyGif(
gif: HTMLImageElement,
gif: HTMLImageElement | Buffer,
config: classifyConfig = { topk: 5 }
): Promise<Array<Array<{ className: string; probability: number }>>> {
let gifObj = new SuperGif({ gif });
return new Promise((resolve) => {
gifObj.load(async () => {
const arrayOfClasses = [];
const gifLength = gifObj.get_length();
for (let i = 1; i <= gifLength; i++) {
gifObj.move_to(i);
const classes = await this.classify(gifObj.get_canvas(), config.topk);
// Update to onFrame
if (config.onFrame)
config.onFrame({
index: i,
totalFrames: gifLength,
predictions: classes,
});
// Store the goods
arrayOfClasses.push(classes);
): Promise<Array<Array<predictionType>>> {
let frameData: (GifFrameCanvas | GifFrameBuffer)[] = []

if (Buffer.isBuffer(gif)) {
frameData = await gifFrames({ url: gif, frames: 'all', outputType: 'jpg' });
} else {
frameData = await gifFrames({ url: gif.src, frames: 'all', outputType: 'canvas' });
}

let acceptedFrames: number[] = [];
if (typeof config.fps !== 'number') {
acceptedFrames = frameData.map((_element, index) => index);
} else {
let totalTimeInMs = 0;
for (let i = 0; i < frameData.length; i++) {
totalTimeInMs = totalTimeInMs + (frameData[i].frameInfo.delay * 10);
}

const totalFrames: number = Math.floor(totalTimeInMs / 1000 * config.fps);
if (totalFrames <= 1) {
acceptedFrames = [Math.floor(frameData.length / 2)];
} else if (totalFrames >= frameData.length) {
acceptedFrames = frameData.map((_element, index) => index);
} else {
const interval: number = Math.floor(frameData.length / (totalFrames + 1));
for (let i = 1; i <= totalFrames; i++) {
acceptedFrames.push(i * interval);
}
// save gifObj if needed
config.setGifControl && config.setGifControl(gifObj);
// try to clean up
gifObj = null;
resolve(arrayOfClasses);
});
});
}
}

const arrayOfPredictions: predictionType[][] = []
for (let i = 0; i < acceptedFrames.length; i++) {
const image = frameData[acceptedFrames[i]].getImage()
const predictions = await this.classify(image, config.topk);

if (typeof config.onFrame === 'function') {
config.onFrame({
index: acceptedFrames[i],
totalFrames: frameData.length,
predictions,
image
});
}

arrayOfPredictions.push(predictions);
}

return arrayOfPredictions;
}
}

async function getTopKClasses(
logits: tf.Tensor2D,
topK: number
): Promise<Array<{ className: string; probability: number }>> {
): Promise<Array<predictionType>> {
const values = await logits.data();

const valuesAndIndices = [];
Expand All @@ -243,7 +272,7 @@ async function getTopKClasses(
topkIndices[i] = valuesAndIndices[i].index;
}

const topClassesAndProbs = [];
const topClassesAndProbs: predictionType[] = [];
for (let i = 0; i < topkIndices.length; i++) {
topClassesAndProbs.push({
className: NSFW_CLASSES[topkIndices[i]],
Expand Down
1 change: 0 additions & 1 deletion src/libgif.d.ts

This file was deleted.

1 change: 1 addition & 0 deletions tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"noImplicitThis": true,
"alwaysStrict": true,
"noUnusedParameters": false,
"esModuleInterop": true,
"pretty": true,
"noFallthroughCasesInSwitch": true,
"allowUnreachableCode": false,
Expand Down
Loading