Skip to content

Commit

Permalink
✨ (core): Device reconnection with 2s timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
jdabbech-ledger committed Jul 3, 2024
1 parent 2a9d8f7 commit 3b59289
Show file tree
Hide file tree
Showing 8 changed files with 168 additions and 49 deletions.
5 changes: 5 additions & 0 deletions .changeset/rare-tips-stare.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@ledgerhq/device-sdk-core": patch
---

Device reconnection on app change
22 changes: 19 additions & 3 deletions packages/core/src/internal/discovery/use-case/ConnectUseCase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ import { loggerTypes } from "@internal/logger-publisher/di/loggerTypes";
import { LoggerPublisherService } from "@internal/logger-publisher/service/LoggerPublisherService";
import { usbDiTypes } from "@internal/usb/di/usbDiTypes";
import type { UsbHidTransport } from "@internal/usb/transport/UsbHidTransport";
import type { DisconnectHandler } from "@internal/usb/transport/WebUsbHidTransport";
import {
DisconnectHandler,
ReconnectHandler,
} from "@internal/usb/transport/WebUsbHidTransport";

/**
* The arguments for the ConnectUseCase.
Expand Down Expand Up @@ -45,18 +48,31 @@ export class ConnectUseCase {
this._logger = loggerFactory("ConnectUseCase");
}

private handleHardwareDisconnect: DisconnectHandler = (deviceId) => {
private handleDeviceDisconnect: DisconnectHandler = (deviceId) => {
const deviceSessionOrError =
this._sessionService.getDeviceSessionByDeviceId(deviceId);
deviceSessionOrError.map((deviceSession) => {
this._sessionService.removeDeviceSession(deviceSession.id);
});
};

private handleDeviceReconnect: ReconnectHandler = (
deviceId,
deviceConnection,
) => {
const deviceSessionOrError =
this._sessionService.getDeviceSessionByDeviceId(deviceId);
deviceSessionOrError.map((deviceSession) => {
const { connectedDevice } = deviceSession;
connectedDevice.sendApdu = deviceConnection.sendApdu;
});
};

async execute({ deviceId }: ConnectUseCaseArgs): Promise<DeviceSessionId> {
const either = await this._usbHidTransport.connect({
deviceId,
onDisconnect: this.handleHardwareDisconnect,
onDisconnect: this.handleDeviceDisconnect,
onReconnect: this.handleDeviceReconnect,
});

return either.caseOf({
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/internal/usb/data/UsbHidConfig.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// [SHOULD] Move it to device-model module
export const LEDGER_VENDOR_ID = 0x2c97;
export const FRAME_SIZE = 64;
export const RECONNECT_DEVICE_TIMEOUT = 5000;
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export type ConnectedDeviceConstructorArgs = {
export class InternalConnectedDevice {
public readonly id: DeviceId;
public readonly deviceModel: InternalDeviceModel;
public readonly sendApdu: SendApduFnType;
public sendApdu: SendApduFnType;
public readonly type: ConnectionType;

constructor({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ export class UsbHidDeviceConnection implements DeviceConnection {
this._logger.debug("Sending Frame", {
data: { frame: frame.getRawData() },
});
await this._device.sendReport(0, frame.getRawData());
try {
await this._device.sendReport(0, frame.getRawData());
} catch (error) {
this._logger.error("Error sending frame", { data: { error } });
}
}

return new Promise((resolve) => {
Expand Down
6 changes: 5 additions & 1 deletion packages/core/src/internal/usb/transport/UsbHidTransport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ import { SdkError } from "@api/Error";
import { ConnectError } from "@internal/usb/model/Errors";
import { InternalConnectedDevice } from "@internal/usb/model/InternalConnectedDevice";
import { InternalDiscoveredDevice } from "@internal/usb/model/InternalDiscoveredDevice";
import type { DisconnectHandler } from "@internal/usb/transport/WebUsbHidTransport";
import type {
DisconnectHandler,
ReconnectHandler,
} from "@internal/usb/transport/WebUsbHidTransport";

/**
* Transport interface representing a USB HID communication
Expand All @@ -27,6 +30,7 @@ export interface UsbHidTransport {
connect(params: {
deviceId: DeviceId;
onDisconnect: DisconnectHandler;
onReconnect: ReconnectHandler;
}): Promise<Either<ConnectError, InternalConnectedDevice>>;

disconnect(params: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { Left, Right } from "purify-ts";
import { DeviceModel, DeviceModelId } from "@api/device/DeviceModel";
import { StaticDeviceModelDataSource } from "@internal/device-model/data/StaticDeviceModelDataSource";
import { DefaultLoggerPublisherService } from "@internal/logger-publisher/service/DefaultLoggerPublisherService";
import { RECONNECT_DEVICE_TIMEOUT } from "@internal/usb/data/UsbHidConfig";
import {
DeviceNotRecognizedError,
NoAccessibleDeviceError,
Expand Down Expand Up @@ -239,7 +240,11 @@ describe("WebUsbHidTransport", () => {

describe("connect", () => {
it("should throw UnknownDeviceError if no internal device", async () => {
const connectParams = { deviceId: "fake", onDisconnect: jest.fn() };
const connectParams = {
deviceId: "fake",
onDisconnect: jest.fn(),
onReconnect: jest.fn(),
};

const connect = await transport.connect(connectParams);

Expand All @@ -249,7 +254,11 @@ describe("WebUsbHidTransport", () => {
});

it("should throw OpeningConnectionError if the device is already opened", async () => {
const device = { deviceId: "fake", onDisconnect: jest.fn() };
const device = {
deviceId: "fake",
onDisconnect: jest.fn(),
onReconnect: jest.fn(),
};

const connect = await transport.connect(device);

Expand All @@ -275,6 +284,7 @@ describe("WebUsbHidTransport", () => {
.connect({
deviceId: discoveredDevice.id,
onDisconnect: jest.fn(),
onReconnect: jest.fn(),
})
.then((value) => {
expect(value).toStrictEqual(
Expand Down Expand Up @@ -309,6 +319,7 @@ describe("WebUsbHidTransport", () => {
.connect({
deviceId: discoveredDevice.id,
onDisconnect: jest.fn(),
onReconnect: jest.fn(),
})
.then((connectedDevice) => {
connectedDevice
Expand Down Expand Up @@ -341,6 +352,7 @@ describe("WebUsbHidTransport", () => {
.connect({
deviceId: discoveredDevice.id,
onDisconnect: jest.fn(),
onReconnect: jest.fn(),
})
.then((connectedDevice) => {
connectedDevice
Expand All @@ -366,6 +378,9 @@ describe("WebUsbHidTransport", () => {
});

describe("disconnect", () => {
beforeAll(() => {
jest.useFakeTimers();
});
it("should throw an error if the device is not connected", async () => {
// given
const connectedDevice = connectedDeviceStubBuilder();
Expand All @@ -389,6 +404,7 @@ describe("WebUsbHidTransport", () => {
.connect({
deviceId: discoveredDevice.id,
onDisconnect: jest.fn(),
onReconnect: jest.fn(),
})
.then((connectedDevice) => {
connectedDevice
Expand Down Expand Up @@ -420,6 +436,8 @@ describe("WebUsbHidTransport", () => {
it("should call disconnect handler if a connected device is unplugged", (done) => {
// given
const onDisconnect = jest.fn();
const onReconnect = jest.fn();
const disconnectSpy = jest.spyOn(transport, "disconnect");
mockedRequestDevice.mockResolvedValueOnce([stubDevice]);

// when
Expand All @@ -429,15 +447,17 @@ describe("WebUsbHidTransport", () => {
.connect({
deviceId: discoveredDevice.id,
onDisconnect,
onReconnect,
})
.then(() => {
// @ts-expect-error trying to access private member
transport.handleDeviceDisconnectionEvent({
device: { productId: stubDevice.productId },
} as Event);

jest.advanceTimersByTime(RECONNECT_DEVICE_TIMEOUT);
// then
expect(onDisconnect).toHaveBeenCalled();
expect(disconnectSpy).toHaveBeenCalled();
done();
})
.catch((error) => {
Expand Down
Loading

0 comments on commit 3b59289

Please sign in to comment.