ws-adapter.ts 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. import { INestApplicationContext, Logger } from '@nestjs/common';
  2. import { loadPackage } from '@nestjs/common/utils/load-package.util';
  3. import { AbstractWsAdapter } from '@nestjs/websockets';
  4. import {
  5. CLOSE_EVENT,
  6. CONNECTION_EVENT,
  7. ERROR_EVENT,
  8. } from '@nestjs/websockets/constants';
  9. import { MessageMappingProperties } from '@nestjs/websockets/gateway-metadata-explorer';
  10. import * as http from 'http';
  11. import { EMPTY, fromEvent, Observable } from 'rxjs';
  12. import { filter, first, mergeMap, share, takeUntil } from 'rxjs/operators';
  /**
   * Handle to the `ws` package. Starts as an empty placeholder and is replaced
   * in the `WsAdapter` constructor via `loadPackage('ws', ...)`.
   */
  let wsPackage: any = {};

  /**
   * `WebSocket.readyState` values (0-3, matching the WebSocket standard).
   * Declared as an `as const` object rather than a numeric `enum`: access via
   * `READY_STATE.OPEN_STATE` is unchanged, but no reverse (number -> name)
   * mapping is generated.
   */
  const READY_STATE = {
    CONNECTING_STATE: 0,
    OPEN_STATE: 1,
    CLOSING_STATE: 2,
    CLOSED_STATE: 3,
  } as const;

  // Both registries are keyed by port number.
  type HttpServerRegistryKey = number;
  // The HTTP server registered for a port (left untyped: it may be the
  // application's own server).
  type HttpServerRegistryEntry = any;
  type WsServerRegistryKey = number;
  // Every `ws` server the HTTP server on that port dispatches upgrades to.
  type WsServerRegistryEntry = any[];

  /**
   * Port value meaning "attach to the application's own HTTP server". Servers
   * registered under it are not closed by `WsAdapter.dispose()`.
   */
  const UNDERLYING_HTTP_SERVER_PORT = 0;
  /**
   * NestJS WebSocket adapter backed by the `ws` package.
   *
   * Several `ws` servers can share one HTTP server: each is created with
   * `noServer: true` and registered under that HTTP server's port together with
   * a path. The HTTP server's `upgrade` listener then hands each handshake to
   * the registered `ws` server whose path equals the request pathname.
   */
  export class WsAdapter extends AbstractWsAdapter {
    protected readonly logger = new Logger(WsAdapter.name);

    /** HTTP servers keyed by port (`UNDERLYING_HTTP_SERVER_PORT` = the application's server). */
    protected readonly httpServersRegistry = new Map<
      HttpServerRegistryKey,
      HttpServerRegistryEntry
    >();

    /** `ws` servers keyed by the port of the HTTP server that dispatches to them. */
    protected readonly wsServersRegistry = new Map<
      WsServerRegistryKey,
      WsServerRegistryEntry
    >();

    constructor(appOrHttpServer?: INestApplicationContext | any) {
      super(appOrHttpServer);
      wsPackage = loadPackage('ws', 'WsAdapter', () => require('ws'));
    }

    /**
     * Creates (or returns) the `ws` server for a gateway.
     *
     * - `port` is `UNDERLYING_HTTP_SERVER_PORT` and an application HTTP server
     *   exists: a `noServer` `ws` server attached to the application's HTTP
     *   server under `options.path` (default `'/'`).
     * - `options.server` is set: that server is returned unchanged.
     * - `options.path` is set and `port` is non-zero: a `noServer` `ws` server
     *   attached to a dedicated HTTP server listening on `port`, shared by all
     *   gateways on that port.
     * - otherwise: a standalone `ws` server created with `port`.
     *
     * @param port - Port to serve on; `0` selects the application's HTTP server.
     * @param options - `ws` server options plus `server` / `namespace`; may be
     *   omitted (treated as `{}`).
     * @returns The `ws` server (or `options.server`).
     * @throws Error when `options.namespace` is set (namespaces are unsupported).
     */
    public create(
      port: number,
      options: Record<string, any> & { namespace?: string; server?: any } = {},
    ) {
      // `options` used to be destructured without a default, so calling
      // `create(port)` threw a TypeError before any branch ran.
      const { server, ...wsOptions } = options;
      if (wsOptions.namespace) {
        const error = new Error(
          '"WsAdapter" does not support namespaces. If you need namespaces in your project, consider using the "@nestjs/platform-socket.io" package instead.',
        );
        this.logger.error(error);
        throw error;
      }
      // Note: `path` also stays inside `wsOptions`, so it is still forwarded to `ws`.
      const path: string | undefined = options.path;
      // A `ws` server that does not listen itself; upgrades are routed to it by
      // the listener installed in `ensureHttpServerExists`.
      const createDetachedServer = () =>
        this.bindErrorHandler(
          new wsPackage.Server({
            noServer: true,
            ...wsOptions,
          }),
        );

      if (port === UNDERLYING_HTTP_SERVER_PORT && this.httpServer) {
        this.ensureHttpServerExists(port, this.httpServer);
        const wsServer = createDetachedServer();
        this.addWsServerToRegistry(wsServer, port, path || '/');
        return wsServer;
      }
      if (server) {
        return server;
      }
      if (path && port !== UNDERLYING_HTTP_SERVER_PORT) {
        // Multiple servers with different paths sharing a single HTTP/S server
        // running on a different port than the regular HTTP application.
        // `ensureHttpServerExists` returns undefined when the port is already
        // registered, so `listen` runs only for the first gateway on the port.
        const httpServer = this.ensureHttpServerExists(port);
        httpServer?.listen(port);
        const wsServer = createDetachedServer();
        this.addWsServerToRegistry(wsServer, port, path);
        return wsServer;
      }
      return this.bindErrorHandler(
        new wsPackage.Server({
          port,
          ...wsOptions,
        }),
      );
    }

    /**
     * Routes a client's `message` events through `bindMessageHandler` and sends
     * each truthy result back as JSON. Results arriving after the socket left
     * the OPEN state are dropped. The pipeline completes on the client's
     * `close` event.
     */
    public bindMessageHandlers(
      client: any,
      handlers: MessageMappingProperties[],
      transform: (data: any) => Observable<any>,
    ) {
      const close$ = fromEvent(client, CLOSE_EVENT).pipe(share(), first());
      const source$ = fromEvent(client, 'message').pipe(
        mergeMap((data) =>
          this.bindMessageHandler(data, handlers, transform).pipe(
            // Falsy handler results are not sent to the client.
            filter((result) => result),
          ),
        ),
        takeUntil(close$),
      );
      const onMessage = (response: any) => {
        if (client.readyState !== READY_STATE.OPEN_STATE) {
          return;
        }
        client.send(JSON.stringify(response));
      };
      source$.subscribe(onMessage);
    }

    /**
     * Parses one message event and runs the handler it targets.
     *
     * `buffer.data` is parsed as JSON. The handler whose `message` equals the
     * parsed object's `id` is called with its `data` field, and the result is
     * passed through `transform`. Malformed JSON, an unknown `id`, or a
     * synchronous throw from the handler or `transform` all yield `EMPTY`, so
     * the message is ignored.
     */
    public bindMessageHandler(
      buffer: any,
      handlers: MessageMappingProperties[],
      transform: (data: any) => Observable<any>,
    ): Observable<any> {
      try {
        const message = JSON.parse(buffer.data);
        const messageHandler = handlers.find(
          (handler) => handler.message === message.id,
        );
        if (!messageHandler) {
          return EMPTY;
        }
        // Destructured so the callback is invoked unbound, as before.
        const { callback } = messageHandler;
        return transform(callback(message.data));
      } catch {
        return EMPTY;
      }
    }

    /**
     * Logs `error` events from the server and from every client socket it
     * accepts. Returns the same server for chaining.
     */
    public bindErrorHandler(server: any) {
      server.on(CONNECTION_EVENT, (ws: any) =>
        ws.on(ERROR_EVENT, (err: any) => this.logger.error(err)),
      );
      server.on(ERROR_EVENT, (err: any) => this.logger.error(err));
      return server;
    }

    /** Invokes `callback` when the client emits its `close` event. */
    public bindClientDisconnect(client: any, callback: Function) {
      client.on(CLOSE_EVENT, callback);
    }

    /**
     * Closes every registered HTTP server except the one under
     * `UNDERLYING_HTTP_SERVER_PORT` (the application's own server), then clears
     * both registries.
     */
    public async dispose() {
      const closeEventSignals = Array.from(this.httpServersRegistry)
        .filter(([port]) => port !== UNDERLYING_HTTP_SERVER_PORT)
        // `resolve` doubles as the close callback, so a close error still
        // settles the promise (as a resolution) instead of rejecting.
        .map(([_, server]) => new Promise((resolve) => server.close(resolve)));
      await Promise.all(closeEventSignals);
      this.httpServersRegistry.clear();
      this.wsServersRegistry.clear();
    }

    /**
     * Registers `httpServer` under `port` and installs an `upgrade` listener
     * that hands each handshake to the first registered `ws` server whose
     * `path` equals the request pathname. Requests that match no server, or
     * whose target cannot be parsed, have their socket destroyed.
     *
     * @returns The newly registered server, or `undefined` when `port` was
     *   already registered (no listener is added in that case).
     */
    protected ensureHttpServerExists(
      port: number,
      httpServer = http.createServer(),
    ) {
      if (this.httpServersRegistry.has(port)) {
        return;
      }
      this.httpServersRegistry.set(port, httpServer);
      httpServer.on('upgrade', (request, socket, head) => {
        let pathname: string;
        try {
          // The base URL only resolves origin-form targets such as "/chat?x=1",
          // and the resulting pathname never depends on the base host. A
          // constant host is therefore used instead of the client-controlled
          // Host header, which could make the base (and the parse) invalid.
          pathname = new URL(request.url ?? '/', 'ws://localhost/').pathname;
        } catch {
          // An unparsable target (e.g. "//") must not throw out of the event
          // listener, where it would become an uncaught exception.
          socket.destroy();
          return;
        }
        // The entry can be missing, e.g. after `dispose()` cleared the registry
        // while the application's HTTP server keeps this listener.
        const wsServersCollection = this.wsServersRegistry.get(port) ?? [];
        const wsServer = wsServersCollection.find(
          (candidate) => candidate.path === pathname,
        );
        if (!wsServer) {
          socket.destroy();
          return;
        }
        wsServer.handleUpgrade(request, socket, head, (ws: unknown) => {
          wsServer.emit('connection', ws, request);
        });
      });
      return httpServer;
    }

    /**
     * Appends `wsServer` to the list dispatched by the HTTP server on `port` and
     * stamps it with `path`. The `upgrade` listener compares that path against
     * the request pathname.
     */
    protected addWsServerToRegistry<T extends Record<'path', string> = any>(
      wsServer: T,
      port: number,
      path: string,
    ) {
      const entries = this.wsServersRegistry.get(port) ?? [];
      entries.push(wsServer);
      wsServer.path = path;
      this.wsServersRegistry.set(port, entries);
    }
  }