import { parser } from '../../../src/parser/shrimp'
import * as Terms from '../../../src/parser/shrimp.terms'
import { SyntaxNode } from '@lezer/common'
import { TextDocument } from 'vscode-languageserver-textdocument'
import {
  SemanticTokensBuilder,
  SemanticTokenTypes,
  SemanticTokenModifiers,
} from 'vscode-languageserver/node'

/**
 * Semantic token types, in legend order. A token's `type` is an index into
 * this array, so the order must match the legend advertised to the client
 * (presumably this array is registered as the legend elsewhere — confirm).
 */
export const TOKEN_TYPES = [
  SemanticTokenTypes.function,
  SemanticTokenTypes.variable,
  SemanticTokenTypes.string,
  SemanticTokenTypes.number,
  SemanticTokenTypes.operator,
  SemanticTokenTypes.keyword,
  SemanticTokenTypes.parameter,
  SemanticTokenTypes.property,
  SemanticTokenTypes.regexp,
  SemanticTokenTypes.comment,
]

/**
 * Semantic token modifiers, in legend order. Modifier `i` is encoded as bit
 * `1 << i` of a token's modifier mask (see `getModifierBits`).
 */
export const TOKEN_MODIFIERS = [
  SemanticTokenModifiers.declaration,
  SemanticTokenModifiers.modification,
  SemanticTokenModifiers.readonly,
]

/** Classification of one syntax node: legend index plus modifier bitmask. */
type TokenInfo = { type: number; modifiers: number }

/** A token confined to a single line, in LSP (line, UTF-16 character) coordinates. */
interface PendingToken {
  line: number
  character: number
  length: number
  type: number
  modifiers: number
}

const CHAR_LF = 0x0a
const CHAR_CR = 0x0d

/**
 * Parse `document` and encode its semantic tokens in the LSP relative format.
 *
 * @param document - The document to tokenize.
 * @returns The `data` array of an LSP `SemanticTokens` result.
 */
export function buildSemanticTokens(document: TextDocument): number[] {
  const text = document.getText()
  const tree = parser.parse(text)

  const tokens: PendingToken[] = []
  walkTree(tree.topNode, document, text, tokens)

  // The LSP encoding stores each token relative to the previous one, so tokens
  // must be pushed in document order. A multi-line node contributes segments on
  // later lines before its children (on earlier lines) are visited, so sort.
  // Array#sort is stable: a parent stays ahead of a child starting at the same
  // position, exactly as in the pre-order walk.
  tokens.sort((a, b) => a.line - b.line || a.character - b.character)

  const builder = new SemanticTokensBuilder()
  for (const t of tokens) {
    builder.push(t.line, t.character, t.length, t.type, t.modifiers)
  }
  return builder.build().data
}

/**
 * Depth-first, pre-order walk that appends tokens for every classified node.
 * Children are visited even when their parent produced a token.
 */
function walkTree(
  node: SyntaxNode,
  document: TextDocument,
  text: string,
  out: PendingToken[]
): void {
  const tokenInfo = getTokenType(node.type.id, node.parent?.type.id)
  if (tokenInfo !== undefined) {
    pushLineSegments(document, text, node.from, node.to, tokenInfo, out)
  }

  for (let child = node.firstChild; child; child = child.nextSibling) {
    walkTree(child, document, text, out)
  }
}

/**
 * Append the span [from, to) as one token per line it covers.
 *
 * LSP tokens may only span lines when the client advertises
 * `multilineTokenSupport`; single-line segments are valid for every client.
 * Line-break characters are excluded from the segments, and empty segments
 * (including zero-length nodes) are dropped.
 */
function pushLineSegments(
  document: TextDocument,
  text: string,
  from: number,
  to: number,
  info: TokenInfo,
  out: PendingToken[]
): void {
  let segmentStart = from
  for (let i = from; i < to; i++) {
    const ch = text.charCodeAt(i)
    if (ch !== CHAR_LF && ch !== CHAR_CR) continue
    pushSegment(document, segmentStart, i, info, out)
    // Treat "\r\n" as one break so no empty segment is produced between them.
    if (ch === CHAR_CR && text.charCodeAt(i + 1) === CHAR_LF) i++
    segmentStart = i + 1
  }
  pushSegment(document, segmentStart, to, info, out)
}

/** Append a single-line token for [from, to); does nothing if the span is empty. */
function pushSegment(
  document: TextDocument,
  from: number,
  to: number,
  info: TokenInfo,
  out: PendingToken[]
): void {
  if (to <= from) return
  const start = document.positionAt(from)
  out.push({ line: start.line, character: start.character, length: to - from, ...info })
}

/**
 * Map a Lezer node type to a semantic token, or `undefined` if the node is
 * not highlighted.
 *
 * @param nodeTypeId - `type.id` of the node being classified.
 * @param parentTypeId - `type.id` of its parent, used to disambiguate identifiers.
 */
function getTokenType(nodeTypeId: number, parentTypeId?: number): TokenInfo | undefined {
  switch (nodeTypeId) {
    case Terms.Identifier:
      return identifierToken(parentTypeId)

    case Terms.IdentifierBeforeDot:
      return makeToken(SemanticTokenTypes.variable)

    case Terms.NamedArgPrefix:
      return makeToken(SemanticTokenTypes.property)

    case Terms.AssignableIdentifier:
      return makeToken(SemanticTokenTypes.variable, SemanticTokenModifiers.modification)

    // NOTE(review): if String nodes contain StringFragment children, both are
    // emitted and the tokens overlap; clients without `overlappingTokenSupport`
    // may render this inconsistently — confirm against the grammar.
    case Terms.String:
    case Terms.StringFragment:
    case Terms.Word:
      return makeToken(SemanticTokenTypes.string)

    case Terms.Number:
      return makeToken(SemanticTokenTypes.number)

    case Terms.Plus:
    case Terms.Minus:
    case Terms.Star:
    case Terms.Slash:
    case Terms.Eq:
    case Terms.EqEq:
    case Terms.Neq:
    case Terms.Lt:
    case Terms.Lte:
    case Terms.Gt:
    case Terms.Gte:
    case Terms.Modulo:
    case Terms.And:
    case Terms.Or:
      return makeToken(SemanticTokenTypes.operator)

    case Terms.keyword:
    case Terms.Do:
    case Terms.colon:
      return makeToken(SemanticTokenTypes.keyword)

    case Terms.Regex:
      return makeToken(SemanticTokenTypes.regexp)

    case Terms.Comment:
      return makeToken(SemanticTokenTypes.comment)

    default:
      return undefined
  }
}

/** Classify an `Identifier` node by its syntactic context (its parent node type). */
function identifierToken(parentTypeId: number | undefined): TokenInfo {
  switch (parentTypeId) {
    case Terms.FunctionCall:
    case Terms.FunctionCallOrIdentifier:
      return makeToken(SemanticTokenTypes.function)
    case Terms.FunctionDef:
      return makeToken(SemanticTokenTypes.function, SemanticTokenModifiers.declaration)
    case Terms.Params:
      return makeToken(SemanticTokenTypes.parameter)
    case Terms.DotGet:
      return makeToken(SemanticTokenTypes.property)
    default:
      // Any other context is an ordinary variable reference.
      return makeToken(SemanticTokenTypes.variable)
  }
}

/** Build a `TokenInfo` from a token type and zero or more modifiers. */
function makeToken(
  type: SemanticTokenTypes,
  ...modifiers: SemanticTokenModifiers[]
): TokenInfo {
  return { type: TOKEN_TYPES.indexOf(type), modifiers: getModifierBits(...modifiers) }
}

/**
 * OR together the legend bits of `modifiers`; modifiers missing from
 * `TOKEN_MODIFIERS` are ignored. Returns 0 when no modifiers are given.
 */
function getModifierBits(...modifiers: SemanticTokenModifiers[]): number {
  let bits = 0
  for (const modifier of modifiers) {
    const index = TOKEN_MODIFIERS.indexOf(modifier)
    if (index !== -1) bits |= 1 << index
  }
  return bits
}