Module:Sandbox/Jackmcbarn/ustring.lua

-- Module table. Every public function below is attached to it, and it is
-- returned at the end of the file.
local ustring = {}

-- Copy these, just in case
-- Local references to the byte-oriented string library functions used by the
-- implementation. Presumably captured so that later changes to the global
-- `string` table cannot affect this module.
local S = {
	byte = string.byte,
	char = string.char,
	len = string.len,
	sub = string.sub,
	find = string.find,
	match = string.match,
	gmatch = string.gmatch,
	gsub = string.gsub,
	format = string.format,
}

---- Configuration ----
-- To limit the length of strings or patterns processed, set these
-- Both limits are compared against byte lengths (see checkString and
-- checkPattern); the default of math.huge never triggers the check.
ustring.maxStringLength = math.huge
ustring.maxPatternLength = math.huge

---- Utility functions ----

-- Validate the type of an argument and raise a standard "bad argument" error
-- when it does not match.
--
-- @param name string  function name to show in the error message
-- @param argidx int  argument position to show in the error message
-- @param arg any  the value to check
-- @param expecttype string  the type() name that is required
-- @param nilok boolean  when true, a nil value is accepted
local function checkType( name, argidx, arg, expecttype, nilok )
	if nilok and arg == nil then
		return
	end
	local actual = type( arg )
	if actual == expecttype then
		return
	end
	-- Level 3: attribute the error two frames above this helper
	error( S.format( "bad argument #%d to '%s' (%s expected, got %s)",
		argidx, name, expecttype, actual
	), 3 )
end

-- Validate the subject-string argument (always argument #1).
-- Numbers are accepted (they are converted only for the purpose of the
-- checks; the caller's value is not changed). Anything else that is not a
-- string, or a string longer than ustring.maxStringLength bytes, raises an
-- error.
--
-- @param name string  function name to show in the error message
-- @param s any  the value to check
local function checkString( name, s )
	local kind = type( s )
	if kind == 'number' then
		s = tostring( s )
		kind = 'string'
	end
	if kind ~= 'string' then
		error( S.format( "bad argument #1 to '%s' (string expected, got %s)",
			name, kind
		), 3 )
	end
	local limit = ustring.maxStringLength
	if S.len( s ) > limit then
		error( S.format( "bad argument #1 to '%s' (string is longer than %d bytes)",
			name, limit
		), 3 )
	end
end

-- Validate the pattern argument (always argument #2).
-- Numbers are accepted (converted only for the purpose of the checks).
-- Anything else that is not a string, or a pattern longer than
-- ustring.maxPatternLength bytes, raises an error.
--
-- @param name string  function name to show in the error message
-- @param pattern any  the value to check
local function checkPattern( name, pattern )
	local kind = type( pattern )
	if kind == 'number' then
		pattern = tostring( pattern )
		kind = 'string'
	end
	if kind ~= 'string' then
		error( S.format( "bad argument #2 to '%s' (string expected, got %s)",
			name, kind
		), 3 )
	end
	local limit = ustring.maxPatternLength
	if S.len( pattern ) > limit then
		error( S.format( "bad argument #2 to '%s' (pattern is longer than %d bytes)",
			name, limit
		), 3 )
	end
end

-- A private helper that splits a string into codepoints, and also collects the
-- starting position of each character and the total length in codepoints.
--
-- The returned table has the fields:
--   len         number of codepoints
--   codepoints  array of codepoint values, 1..len
--   bytepos     array of starting byte offsets, 1..len, followed by two extra
--               entries equal to (byte length + 1) so that callers may index
--               one or two past the end
--
-- Overlong encodings, stray continuation bytes, truncated sequences and
-- values above U+10FFFF are rejected.
--
-- @param s string  utf8-encoded string to decode
-- @return table|nil  nil if s is not valid UTF-8
local function utf8_explode( s )
	local ret = {
		len = 0,
		codepoints = {},
		bytepos = {},
	}

	local i = 1
	local l = S.len( s )
	local cp, b, b2, trail
	local min
	while i <= l do
		b = S.byte( s, i )
		if b < 0x80 then
			-- 1-byte code point, 00-7F
			cp = b
			trail = 0
			min = 0
		elseif b < 0xc2 then
			-- Either a non-initial code point (invalid here) or
			-- an overlong encoding for a 1-byte code point
			return nil
		elseif b < 0xe0 then
			-- 2-byte code point, C2-DF
			trail = 1
			cp = b - 0xc0
			min = 0x80
		elseif b < 0xf0 then
			-- 3-byte code point, E0-EF
			trail = 2
			cp = b - 0xe0
			min = 0x800
		elseif b < 0xf4 then
			-- 4-byte code point, F0-F3
			trail = 3
			cp = b - 0xf0
			min = 0x10000
		elseif b == 0xf4 then
			-- 4-byte code point, F4
			-- Make sure it doesn't decode to over U+10FFFF. The next byte may
			-- be missing entirely (string truncated after F4), in which case
			-- S.byte returns nil and the string is invalid.
			b2 = S.byte( s, i + 1 )
			if not b2 or b2 > 0x8f then
				return nil
			end
			trail = 3
			cp = 4
			min = 0x100000
		else
			-- Code point over U+10FFFF, or invalid byte
			return nil
		end

		-- Check subsequent bytes for multibyte code points
		for j = i + 1, i + trail do
			b = S.byte( s, j )
			if not b or b < 0x80 or b > 0xbf then
				return nil
			end
			cp = cp * 0x40 + b - 0x80
		end
		if cp < min then
			-- Overlong encoding
			return nil
		end

		ret.codepoints[#ret.codepoints + 1] = cp
		ret.bytepos[#ret.bytepos + 1] = i
		ret.len = ret.len + 1
		i = i + 1 + trail
	end

	-- Two past the end (for sub with empty string)
	ret.bytepos[#ret.bytepos + 1] = l + 1
	ret.bytepos[#ret.bytepos + 1] = l + 1

	return ret
end

-- A private helper that maps a byte offset to a character offset by binary
-- search over the table of character start positions. A byte offset that
-- falls inside a multibyte character maps to that character.
--
-- @param cps table  from utf8_explode
-- @param i int  byte offset (0 maps to 0)
-- @return int
local function cpoffset( cps, i )
	local lo, hi = 0, cps.len + 1
	if i == 0 then
		return 0
	end
	while hi - lo > 1 do
		local mid = math.floor( ( lo + hi ) / 2 )
		local start = cps.bytepos[mid + 1]
		-- When start == i both bounds collapse onto mid and the loop ends
		if start <= i then
			lo = mid
		end
		if start >= i then
			hi = mid
		end
	end
	return lo + 1
end

---- Trivial functions ----
-- These functions are the same as the standard string versions
-- (they are direct aliases, so they work on bytes and perform none of the
-- argument or UTF-8 checks used elsewhere in this module).

ustring.byte = string.byte
ustring.format = string.format
ustring.rep = string.rep

---- Non-trivial functions ----
-- These functions actually have to be UTF-8 aware


-- Report whether a string decodes cleanly as UTF-8.
--
-- @param s string
-- @return boolean  true when utf8_explode accepts the string
function ustring.isutf8( s )
	checkString( 'isutf8', s )
	local exploded = utf8_explode( s )
	return exploded ~= nil
end

-- Return the byte offset of a character in a string
--
-- The character is located relative to the character containing byte i:
--  * l > 0 counts forward; when i is itself the first byte of a character
--    that character counts as number 1.
--  * l == 0 gives the start of the character containing byte i.
--  * l < 0 counts backward from the character containing byte i.
--
-- @param s string
-- @param l int  codepoint number [default 1]
-- @param i int  starting byte offset [default 1]; negative counts from the
--   end of the string
-- @return int|nil  nil if i or the requested character is out of range
function ustring.byteoffset( s, l, i )
	checkString( 'byteoffset', s )
	checkType( 'byteoffset', 2, l, 'number', true )
	checkType( 'byteoffset', 3, i, 'number', true )
	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'byteoffset' (string is not UTF-8)", 2 )
	end

	-- Apply the documented default; without it the comparison below would
	-- raise on a nil l.
	l = l or 1
	i = i or 1
	if i < 0 then
		i = S.len( s ) + i + 1
	end
	if i < 1 or i > S.len( s ) then
		return nil
	end
	local p = cpoffset( cps, i )
	if l > 0 and cps.bytepos[p] == i then
		l = l - 1
	end
	if p + l > cps.len then
		return nil
	end
	-- For p + l < 1 the lookup yields nil, which is the out-of-range result
	return cps.bytepos[p + l]
end

-- Return the codepoints of a range of characters.
--
-- Negative indexes count from the end of the string. After that adjustment
-- an inverted range (j < i) returns nothing; otherwise both indexes are
-- clamped to 1 .. length + 1 before the codepoints are unpacked.
--
-- @see string.byte
-- @param s string
-- @param i int  Starting character [default 1]
-- @param j int  Ending character [default i]
-- @return int*  Zero or more codepoints
function ustring.codepoint( s, i, j )
	checkString( 'codepoint', s )
	checkType( 'codepoint', 2, i, 'number', true )
	checkType( 'codepoint', 3, j, 'number', true )
	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'codepoint' (string is not UTF-8)", 2 )
	end
	local count = cps.len

	i = i or 1
	if i < 0 then
		i = count + i + 1
	end
	j = j or i
	if j < 0 then
		j = count + j + 1
	end
	if j < i then
		return
	end

	local first = math.max( 1, math.min( i, count + 1 ) )
	local last = math.max( 1, math.min( j, count + 1 ) )
	return unpack( cps.codepoints, first, last )
end

-- Return an iterator over the codepoints (as integers)
--   for cp in ustring.gcodepoint( s ) do ... end
--
-- All codepoints in the range are decoded up front; the iterator then walks
-- the resulting array with an index cursor, so a full iteration is linear in
-- the number of codepoints (removing the head of the array on every step
-- would be quadratic).
--
-- @param s string
-- @param i int  Starting character [default 1]
-- @param j int  Ending character [default -1]
-- @return function  iterator returning the next codepoint, or nil when done
function ustring.gcodepoint( s, i, j )
	checkString( 'gcodepoint', s )
	checkType( 'gcodepoint', 2, i, 'number', true )
	checkType( 'gcodepoint', 3, j, 'number', true )
	local cp = { ustring.codepoint( s, i or 1, j or -1 ) }
	local idx = 0
	return function ()
		idx = idx + 1
		return cp[idx]
	end
end

-- Private helper that encodes the codepoints t[s] .. t[e] as UTF-8.
--
-- Each codepoint is encoded to its own short string and the pieces are
-- joined with table.concat. Passing every byte to a single S.char call via
-- unpack would fail ("too many results to unpack") once the output grows to
-- several thousand bytes.
--
-- @param t table  array of codepoints (fractional values are floored)
-- @param s int  first index to encode
-- @param e int  last index to encode
-- @return string
local function internalChar( t, s, e )
	local ret = {}
	for i = s, e do
		local v = t[i]
		if type( v ) ~= 'number' then
			checkType( 'char', i, v, 'number' )
		end
		v = math.floor( v )
		if v < 0 or v > 0x10ffff then
			error( S.format( "bad argument #%d to 'char' (value out of range)", i ), 2 )
		elseif v < 0x80 then
			-- 1 byte: 0xxxxxxx
			ret[#ret + 1] = S.char( v )
		elseif v < 0x800 then
			-- 2 bytes: 110xxxxx 10xxxxxx
			ret[#ret + 1] = S.char(
				0xc0 + math.floor( v / 0x40 ) % 0x20,
				0x80 + v % 0x40
			)
		elseif v < 0x10000 then
			-- 3 bytes: 1110xxxx 10xxxxxx 10xxxxxx
			ret[#ret + 1] = S.char(
				0xe0 + math.floor( v / 0x1000 ) % 0x10,
				0x80 + math.floor( v / 0x40 ) % 0x40,
				0x80 + v % 0x40
			)
		else
			-- 4 bytes: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
			ret[#ret + 1] = S.char(
				0xf0 + math.floor( v / 0x40000 ) % 0x08,
				0x80 + math.floor( v / 0x1000 ) % 0x40,
				0x80 + math.floor( v / 0x40 ) % 0x40,
				0x80 + v % 0x40
			)
		end
	end
	return table.concat( ret )
end

-- Convert codepoints to a string
--
-- @see string.char
-- @param ... int  List of codepoints, each in the range 0 .. 0x10FFFF
-- @return string  UTF-8 encoding of the codepoints, in order
function ustring.char( ... )
	-- select( '#', ... ) rather than the table length, so that a nil among
	-- the arguments is reported instead of silently truncating the list
	return internalChar( { ... }, 1, select( '#', ... ) )
end

-- Count the codepoints in a string.
--
-- @see string.len
-- @param s string
-- @return int|nil  number of codepoints, or nil if s is not valid UTF-8
function ustring.len( s )
	checkString( 'len', s )
	local cps = utf8_explode( s )
	-- cps.len is always a number, so the and/or idiom is safe here
	return cps and cps.len or nil
end

-- Private function to cut characters i .. j out of an already exploded
-- string. No defaults or range adjustments are applied here; callers must
-- pass indexes for which cps.bytepos[i] and cps.bytepos[j + 1] exist.
--
-- @param s string
-- @param cps table  Exploded string (from utf8_explode)
-- @param i int  Starting character
-- @param j int  Ending character
-- @return string
local function sub( s, cps, i, j )
	local firstByte = cps.bytepos[i]
	-- The byte just before the start of character j + 1
	local lastByte = cps.bytepos[j + 1] - 1
	return S.sub( s, firstByte, lastByte )
end

-- Return the characters i .. j of a string.
--
-- Negative indexes count from the end of the string. After that adjustment
-- an inverted range (j < i) yields the empty string; otherwise both indexes
-- are clamped to 1 .. length + 1.
--
-- @see string.sub
-- @param s string
-- @param i int  Starting character [default 1]
-- @param j int  Ending character [default -1]
-- @return string
function ustring.sub( s, i, j )
	checkString( 'sub', s )
	checkType( 'sub', 2, i, 'number', true )
	checkType( 'sub', 3, j, 'number', true )
	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'sub' (string is not UTF-8)", 2 )
	end
	local count = cps.len

	i = i or 1
	if i < 0 then
		i = count + i + 1
	end
	j = j or -1
	if j < 0 then
		j = count + j + 1
	end
	if j < i then
		return ''
	end

	local first = math.max( 1, math.min( i, count + 1 ) )
	local last = math.max( 1, math.min( j, count + 1 ) )
	return sub( s, cps, first, last )
end

---- Table-driven functions ----
-- These functions load a conversion table when called

-- Convert a string to uppercase
--
-- Each lead byte together with its continuation bytes (0x80-0xBF) is looked
-- up in the table loaded from 'ustring/upper'; sequences with no entry in
-- the table are left as they are.
--
-- @see string.upper
-- @param s string
-- @return string
function ustring.upper( s )
	checkString( 'upper', s )
	local map = require 'ustring/upper'
	-- Parentheses drop gsub's second result (the substitution count)
	return ( S.gsub( s, '([^\128-\191][\128-\191]*)', map ) )
end

-- Convert a string to lowercase
--
-- Each lead byte together with its continuation bytes (0x80-0xBF) is looked
-- up in the table loaded from 'ustring/lower'; sequences with no entry in
-- the table are left as they are.
--
-- @see string.lower
-- @param s string
-- @return string
function ustring.lower( s )
	checkString( 'lower', s )
	local map = require 'ustring/lower'
	-- Parentheses drop gsub's second result (the substitution count)
	return ( S.gsub( s, '([^\128-\191][\128-\191]*)', map ) )
end

---- Pattern functions ----
-- Ugh. Just ugh.

-- Cache for character sets (e.g. [a-z])
-- Maps the raw text of a bracketed set (or '.') to { pattern length, set }.
-- The table is weak so that cached sets can be reclaimed by the garbage
-- collector; weakness is controlled by the metatable field __mode
-- (a field named __weak has no effect).
local charset_cache = {}
setmetatable( charset_cache, { __mode = 'kv' } )

-- Private function to find a pattern in a string
-- Yes, this basically reimplements the whole of Lua's pattern matching, in
-- Lua.
--
-- All positions handled here (init, the returned indexes, position captures)
-- are character offsets, not byte offsets.
--
-- @see ustring.find
-- @param s string
-- @param cps table  Exploded string
-- @param rawpat string  Pattern
-- @param pattern table  Exploded pattern
-- @param init int  Starting index [default 1]; negative counts from the end
-- @param noAnchor boolean  True to ignore '^'
-- @return int starting index of the match (nil if there is no match)
-- @return int ending index of the match
-- @return string|int* captures
local function find( s, cps, rawpat, pattern, init, noAnchor )
	local charsets = require 'ustring/charsets'
	local anchor = false
	local ncapt, captures
	-- captparen[n] is the pattern position of the '(' that opened capture n
	local captparen = {}

	-- Extract the value of a capture from the
	-- upvalues ncapt and captures.
	-- Returns the captured value and its length in characters (for a
	-- position capture, the number of decimal digits of the position).
	local function getcapt( n, err, errl )
		if n < 1 or n > ncapt then
			-- Also rejects index 0, which is never a valid capture here
			error( err, errl + 1 )
		elseif type( captures[n] ) == 'table' then
			if captures[n][2] == '' then
				-- Capture was opened but never closed
				error( err, errl + 1 )
			end
			return sub( s, cps, captures[n][1], captures[n][2] ), captures[n][2] - captures[n][1] + 1
		else
			return captures[n], math.floor( math.log10( captures[n] ) ) + 1
		end
	end

	local match, match_charset, parse_charset

	-- Main matching function. Uses tail recursion where possible.
	-- Returns the position of the character after the match, and updates the
	-- upvalues ncapt and captures.
	match = function ( sp, pp )
		local c = pattern.codepoints[pp]
		if c == 0x28 then -- '(': starts capture group
			ncapt = ncapt + 1
			captparen[ncapt] = pp
			local ret
			if pattern.codepoints[pp + 1] == 0x29 then -- ')': Pattern is '()', capture position
				captures[ncapt] = sp
				ret = match( sp, pp + 2 )
			else
				-- Start capture group
				captures[ncapt] = { sp, '' }
				ret = match( sp, pp + 1 )
			end
			if ret then
				return ret
			else
				-- Failed, rollback
				ncapt = ncapt - 1
				return nil
			end
		elseif c == 0x29 then -- ')': ends capture group, pop current capture index from stack
			for n = ncapt, 1, -1 do
				-- The innermost capture that is still open has '' as its end
				if type( captures[n] ) == 'table' and captures[n][2] == '' then
					captures[n][2] = sp - 1
					local ret = match( sp, pp + 1 )
					if ret then
						return ret
					else
						-- Failed, rollback
						captures[n][2] = ''
						return nil
					end
				end
			end
			error( 'Unmatched close-paren at pattern character ' .. pp, 3 )
		elseif c == 0x5b then -- '[': starts character set
			return match_charset( sp, parse_charset( pp ) )
		elseif c == 0x5d then -- ']'
			error( 'Unmatched close-bracket at pattern character ' .. pp, 3 )
		elseif c == 0x25 then -- '%'
			c = pattern.codepoints[pp + 1]
			if not c then -- percent at the end of the pattern
				-- Must be tested first: the numeric comparisons below would
				-- raise on a nil c before this case could be reported.
				error( 'malformed pattern (ends with \'%\')', 3 )
			elseif charsets[c] then -- A character set like '%a'
				return match_charset( sp, pp + 2, charsets[c] )
			elseif c == 0x62 then -- '%b': balanced delimiter match
				local d1 = pattern.codepoints[pp + 2]
				local d2 = pattern.codepoints[pp + 3]
				if not d1 or not d2 then
					error( 'malformed pattern (missing arguments to \'%b\')', 3 )
				end
				if cps.codepoints[sp] ~= d1 then
					return nil
				end
				sp = sp + 1
				local ct = 1
				while true do
					c = cps.codepoints[sp]
					sp = sp + 1
					if not c then
						return nil
					elseif c == d2 then
						if ct == 1 then
							return match( sp, pp + 4 )
						end
						ct = ct - 1
					elseif c == d1 then
						ct = ct + 1
					end
				end
			elseif c == 0x66 then -- '%f': frontier pattern match
				if pattern.codepoints[pp + 2] ~= 0x5b then
					error( 'missing \'[\' after %f in pattern at pattern character ' .. pp, 3 )
				end
				local pp, charset = parse_charset( pp + 2 )
				-- Outside the string the neighbouring character counts as 0
				local c1 = cps.codepoints[sp - 1] or 0
				local c2 = cps.codepoints[sp] or 0
				if not charset[c1] and charset[c2] then
					return match( sp, pp )
				else
					return nil
				end
			elseif c >= 0x30 and c <= 0x39 then -- '%0' to '%9': backreference
				local idx = c - 0x30
				local m, l = getcapt( idx, 'invalid capture index %' .. idx .. ' at pattern character ' .. pp, 3 )
				local ep = math.min( cps.len + 1, sp + l )
				if sub( s, cps, sp, ep - 1 ) == m then
					return match( ep, pp + 2 )
				else
					return nil
				end
			else -- something else, treat as a literal
				return match_charset( sp, pp + 2, { [c] = 1 } )
			end
		elseif c == 0x2e then -- '.': match anything
			if not charset_cache['.'] then
				local t = {}
				-- Every non-nil key is "in" the set; a nil key (past the end
				-- of the string) yields nil
				setmetatable( t, { __index = function ( t, k ) return k end } )
				charset_cache['.'] = { 1, t }
			end
			return match_charset( sp, pp + 1, charset_cache['.'][2] )
		elseif c == nil then -- end of pattern
			return sp
		elseif c == 0x24 and pattern.len == pp then -- '$': assert end of string
			return ( sp == cps.len + 1 ) and sp or nil
		else
			-- Any other character matches itself
			return match_charset( sp, pp + 1, { [c] = 1 } )
		end
	end

	-- Parse a bracketed character set (e.g. [a-z])
	-- Returns the position after the set and a table holding the matching characters
	parse_charset = function ( pp )
		-- Locate the closing bracket in the raw pattern: the first ']' that
		-- is not escaped by a single '%'. The raw text of the set is the
		-- cache key.
		local _, ep
		local epp = pattern.bytepos[pp]
		repeat
			_, ep = S.find( rawpat, ']', epp, true )
			if not ep then
				error( 'Missing close-bracket for character set beginning at pattern character ' .. pp, 3 )
			end
			epp = ep + 1
		until S.byte( rawpat, ep - 1 ) ~= 0x25 or S.byte( rawpat, ep - 2 ) == 0x25
		local key = S.sub( rawpat, pattern.bytepos[pp], ep )
		if charset_cache[key] then
			local pl, cs = unpack( charset_cache[key] )
			return pp + pl, cs
		end

		local p0 = pp
		local cs = {}
		local csrefs = { cs }
		local invert = false
		pp = pp + 1
		if pattern.codepoints[pp] == 0x5e then -- '^'
			invert = true
			pp = pp + 1
		end
		while true do
			local c = pattern.codepoints[pp]
			if c == 0x25 then -- '%'
				c = pattern.codepoints[pp + 1]
				if charsets[c] then
					csrefs[#csrefs + 1] = charsets[c]
				else
					cs[c] = 1
				end
				pp = pp + 2
			elseif c == 0x5d then -- closing ']'
				-- Tested before the range case: an unescaped ']' ends the
				-- set even when a quantifier '-' follows it (as in
				-- '[%a]-x'); it must not be read as the start of a range.
				pp = pp + 1
				break
			elseif pattern.codepoints[pp + 1] == 0x2d and pattern.codepoints[pp + 2] and pattern.codepoints[pp + 2] ~= 0x5d then -- '-' followed by another char (not ']'), it's a range
				for i = c, pattern.codepoints[pp + 2] do
					cs[i] = 1
				end
				pp = pp + 3
			elseif not c then -- Should never get here, but Just In Case...
				error( 'Missing close-bracket', 3 )
			else
				cs[c] = 1
				pp = pp + 1
			end
		end

		local ret
		if not csrefs[2] then
			if not invert then
				-- If there's only the one charset table, we can use it directly
				ret = cs
			else
				-- Simple invert
				ret = {}
				setmetatable( ret, { __index = function ( t, k ) return k and not cs[k] end } )
			end
		else
			-- Ok, we have to iterate over multiple charset tables
			ret = {}
			setmetatable( ret, { __index = function ( t, k )
				if not k then
					return nil
				end
				for i = 1, #csrefs do
					if csrefs[i][k] then
						return not invert
					end
				end
				return invert
			end } )
		end

		charset_cache[key] = { pp - p0, ret }
		return pp, ret
	end

	-- Match a character set table with optional quantifier, followed by
	-- the rest of the pattern.
	-- Returns same as 'match' above.
	match_charset = function ( sp, pp, charset )
		local q = pattern.codepoints[pp]
		if q == 0x2a then -- '*', 0 or more matches
			pp = pp + 1
			local i = 0
			while charset[cps.codepoints[sp + i]] do
				i = i + 1
			end
			-- Greedy: try the longest run first, then back off
			while i >= 0 do
				local ret = match( sp + i, pp )
				if ret then
					return ret
				end
				i = i - 1
			end
			return nil
		elseif q == 0x2b then -- '+', 1 or more matches
			pp = pp + 1
			local i = 0
			while charset[cps.codepoints[sp + i]] do
				i = i + 1
			end
			while i > 0 do
				local ret = match( sp + i, pp )
				if ret then
					return ret
				end
				i = i - 1
			end
			return nil
		elseif q == 0x2d then -- '-', 0 or more matches non-greedy
			pp = pp + 1
			while true do
				local ret = match( sp, pp )
				if ret then
					return ret
				end
				if not charset[cps.codepoints[sp]] then
					return nil
				end
				sp = sp + 1
			end
		elseif q == 0x3f then -- '?', 0 or 1 match
			pp = pp + 1
			if charset[cps.codepoints[sp]] then
				local ret = match( sp + 1, pp )
				if ret then
					return ret
				end
			end
			return match( sp, pp )
		else -- no suffix, must match 1
			if charset[cps.codepoints[sp]] then
				return match( sp + 1, pp )
			else
				return nil
			end
		end
	end

	init = init or 1
	if init < 0 then
		init = cps.len + init + 1
	end
	init = math.max( 1, math.min( init, cps.len + 1 ) )

	-- Here is the actual match loop. It just calls 'match' on successive
	-- starting positions (or not, if the pattern is anchored) until it finds a
	-- match.
	local sp = init
	local pp = 1
	if not noAnchor and pattern.codepoints[1] == 0x5e then -- '^': Pattern is anchored
		anchor = true
		pp = 2
	end

	repeat
		ncapt, captures = 0, {}
		local ep = match( sp, pp )
		if ep then
			for i = 1, ncapt do
				-- captparen is indexed by capture number, not by pattern
				-- position
				captures[i] = getcapt( i, 'Unclosed capture beginning at pattern character ' .. captparen[i], 2 )
			end
			return sp, ep - 1, unpack( captures )
		end
		sp = sp + 1
	until anchor or sp > cps.len + 1
	return nil
end

-- Private function to decide if a pattern looks simple enough to use
-- Lua's built-in string library. The following make a pattern not simple:
--  * If it contains any bytes over 0x7f. We could skip these if they're not
--    inside brackets and aren't followed by quantifiers and aren't part of a
--    '%b', but that's too complicated to check.
--  * If it contains a negated character set.
--  * If it contains "%a" or any of the other %-prefixed character sets except
--    %z or %Z.
--  * If it contains a '.' not followed by '*', '+', or '-'. A bare '.' or '.?'
--    would try to match a partial UTF-8 character, but the others will happily
--    enough match a whole character thinking it's 2 or 4. This includes a '.'
--    that is the last character of the pattern, which needs its own test
--    because there is no following character for the set to match.
--  * If it contains position-captures.
--
-- The checks are textual and deliberately conservative: an escaped '%.' is
-- also treated as not simple, which only costs speed.
--
-- @param string pattern
-- @return boolean
local function patternIsSimple( pattern )
	return not (
		S.find( pattern, '[\128-\255]' ) or
		S.find( pattern, '%[%^' ) or
		S.find( pattern, '%%[acdlpsuwxACDLPSUWX]' ) or
		S.find( pattern, '%.[^*+-]' ) or
		S.find( pattern, '%.$' ) or
		S.find( pattern, '()', 1, true )
	)
end

-- Find a pattern in a string
--
-- This works just like string.find, with the following changes:
--  * Everything works on UTF-8 characters rather than bytes
--  * Character classes are redefined in terms of Unicode properties:
--    * %a - Letter
--    * %c - Control
--    * %d - Decimal Number
--    * %l - Lower case letter
--    * %p - Punctuation
--    * %s - Separator, plus HT, LF, FF, CR, and VT
--    * %u - Upper case letter
--    * %w - Letter or Decimal Number
--    * %x - [0-9A-Fa-f] (the earlier comment listed the range twice; the
--      second copy presumably stood for other digit forms - confirm against
--      the 'ustring/charsets' data)
--
-- Plain searches and patterns accepted by patternIsSimple are delegated to
-- the built-in string.find, with character offsets converted to byte
-- offsets and back.
--
-- @see string.find
-- @param s string
-- @param pattern string  Pattern
-- @param init int  Starting index [default 1]; negative counts from the end
-- @param plain boolean  Literal match, no pattern matching
-- @return int starting index of the match (nil if there is no match)
-- @return int ending index of the match
-- @return string|int* captures
function ustring.find( s, pattern, init, plain )
	checkString( 'find', s )
	checkPattern( 'find', pattern )
	checkType( 'find', 3, init, 'number', true )
	checkType( 'find', 4, plain, 'boolean', true )
	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'find' (string is not UTF-8)", 2 )
	end
	local pat = utf8_explode( pattern )
	if pat == nil then
		error( "bad argument #2 for 'find' (string is not UTF-8)", 2 )
	end

	if plain or patternIsSimple( pattern ) then
		-- Normalize init the same way the private find() does, so that a
		-- negative index counts from the end instead of producing a nil
		-- byte offset (which would silently restart the search at 1).
		local start = init or 1
		if start < 0 then
			start = cps.len + start + 1
		end
		start = math.max( 1, math.min( start, cps.len + 1 ) )
		local m = { S.find( s, pattern, cps.bytepos[start], plain ) }
		if m[1] then
			-- Convert the byte offsets of the match to character offsets
			m[1] = cpoffset( cps, m[1] )
			m[2] = cpoffset( cps, m[2] )
		end
		return unpack( m )
	end

	return find( s, cps, pattern, pat, init )
end

-- Match a string against a pattern
--
-- Patterns accepted by patternIsSimple are delegated to the built-in
-- string.match, with the starting character offset converted to a byte
-- offset.
--
-- @see ustring.find
-- @see string.match
-- @param s string
-- @param pattern string
-- @param init int Starting offset for match [default 1]; negative counts
--   from the end
-- @return string|int* captures, or the whole match if there are none
function ustring.match( s, pattern, init )
	checkString( 'match', s )
	checkPattern( 'match', pattern )
	checkType( 'match', 3, init, 'number', true )
	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'match' (string is not UTF-8)", 2 )
	end
	local pat = utf8_explode( pattern )
	if pat == nil then
		error( "bad argument #2 for 'match' (string is not UTF-8)", 2 )
	end

	if patternIsSimple( pattern ) then
		-- Normalize init the same way the private find() does. Without
		-- this, a negative or too-large index has no entry in cps.bytepos
		-- and the match would silently restart at the first byte.
		local start = init or 1
		if start < 0 then
			start = cps.len + start + 1
		end
		start = math.max( 1, math.min( start, cps.len + 1 ) )
		return S.match( s, pattern, cps.bytepos[start] )
	end

	local m = { find( s, cps, pattern, pat, init ) }
	if not m[1] then
		return nil
	end
	if m[3] then
		-- The pattern has captures: return those
		return unpack( m, 3 )
	end
	return sub( s, cps, m[1], m[2] )
end

-- Return an iterator function over the matches for a pattern
--
-- Patterns accepted by patternIsSimple are delegated to the built-in
-- string.gmatch. Otherwise each call of the iterator searches from just
-- after the previous match; a '^' in the pattern is not treated as an
-- anchor.
--
-- @see ustring.find
-- @see string.gmatch
-- @param s string
-- @param pattern string
-- @return function  iterator returning the captures of the next match (or
--   the whole match if there are none), or nil when there are no more
function ustring.gmatch( s, pattern )
	checkString( 'gmatch', s )
	checkPattern( 'gmatch', pattern )
	if patternIsSimple( pattern ) then
		return S.gmatch( s, pattern )
	end

	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'gmatch' (string is not UTF-8)", 2 )
	end
	local pat = utf8_explode( pattern )
	if pat == nil then
		error( "bad argument #2 for 'gmatch' (string is not UTF-8)", 2 )
	end
	local init = 1

	return function ()
		-- find() clamps its starting index to length + 1, so the end of
		-- the iteration has to be detected here.
		if init > cps.len + 1 then
			return nil
		end
		local m = { find( s, cps, pattern, pat, init, true ) }
		if not m[1] then
			return nil
		end
		if m[2] < m[1] then
			-- Empty match: step past its position, otherwise the same
			-- empty match would be returned forever.
			init = m[1] + 1
		else
			init = m[2] + 1
		end
		if m[3] then
			return unpack( m, 3 )
		end
		return sub( s, cps, m[1], m[2] )
	end
end

-- Replace pattern matches in a string
--
-- Patterns accepted by patternIsSimple are delegated to the built-in
-- string.gsub. Otherwise matches are located with the private find() and
-- the result is assembled piece by piece.
--
-- The replacement is chosen as in string.gsub:
--  * function: called with the captures (or the whole match)
--  * table: indexed with the first capture (or the whole match)
--  * string: '%0' - '%9' and '%%' are expanded
-- A false or nil replacement keeps the matched text.
--
-- @see ustring.find
-- @see string.gsub
-- @param s string
-- @param pattern string
-- @param repl string|function|table
-- @param int n  maximum number of replacements [default: no limit]
-- @return string
-- @return int  number of matches that were processed
function ustring.gsub( s, pattern, repl, n )
	checkString( 'gsub', s )
	checkPattern( 'gsub', pattern )
	checkType( 'gsub', 4, n, 'number', true )
	if patternIsSimple( pattern ) then
		return S.gsub( s, pattern, repl, n )
	end

	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'gsub' (string is not UTF-8)", 2 )
	end
	local pat = utf8_explode( pattern )
	if pat == nil then
		error( "bad argument #2 for 'gsub' (string is not UTF-8)", 2 )
	end
	if n == nil then
		n = 1e100
	end

	if pat.codepoints[1] == 0x5e then -- '^': Pattern is anchored
		-- There can be at most the one match, so make that explicit, but
		-- never raise a smaller limit given by the caller (e.g. n = 0)
		n = math.min( n, 1 )
	end

	local tp
	if type( repl ) == 'function' then
		tp = 1
	elseif type( repl ) == 'table' then
		tp = 2
	elseif type( repl ) == 'string' then
		tp = 3
	elseif type( repl ) == 'number' then
		repl = tostring( repl )
		tp = 3
	else
		-- Always raises: no value has this type name
		checkType( 'gsub', 3, repl, 'function or table or string' )
	end

	local init = 1
	local ct = 0
	local ret = {}
	-- init may reach length + 1 so that the last character, and an empty
	-- match at the very end of the string, are both examined
	while init <= cps.len + 1 and ct < n do
		local m = { find( s, cps, pattern, pat, init ) }
		if not m[1] then
			break
		end
		if init < m[1] then
			-- Copy the unmatched text before the match
			ret[#ret + 1] = sub( s, cps, init, m[1] - 1 )
		end
		local mm = sub( s, cps, m[1], m[2] )
		local val
		if tp == 1 then
			if m[3] then
				val = repl( unpack( m, 3 ) )
			else
				val = repl( mm )
			end
		elseif tp == 2 then
			val = repl[m[3] or mm]
		elseif tp == 3 then
			if ct == 0 and #m < 11 then
				-- Once, on the first match: blank out every valid escape,
				-- then anything that still looks like '%<digit>' refers to
				-- a capture that does not exist
				local ss = S.gsub( repl, '%%[%%0-' .. ( #m - 2 ) .. ']', 'x' )
				ss = S.match( ss, '%%[0-9]' )
				if ss then
					error( 'invalid capture index ' .. ss .. ' in replacement string', 2 )
				end
			end
			local t = {
				["%0"] = mm,
				["%1"] = m[3],
				["%2"] = m[4],
				["%3"] = m[5],
				["%4"] = m[6],
				["%5"] = m[7],
				["%6"] = m[8],
				["%7"] = m[9],
				["%8"] = m[10],
				["%9"] = m[11],
				["%%"] = "%"
			}
			val = S.gsub( repl, '%%[%%0-9]', t )
		end
		ret[#ret + 1] = val or mm
		ct = ct + 1
		if m[2] >= m[1] then
			init = m[2] + 1
		else
			-- Empty match at m[1]: copy the character at that position
			-- unchanged (if there is one) and continue after it, so the
			-- loop always makes progress
			if m[1] <= cps.len then
				ret[#ret + 1] = sub( s, cps, m[1], m[1] )
			end
			init = m[1] + 1
		end
	end
	if init <= cps.len then
		-- Copy whatever follows the last match
		ret[#ret + 1] = sub( s, cps, init, cps.len )
	end
	return table.concat( ret ), ct
end

---- Unicode Normalization ----
-- These functions load a conversion table when called

-- Private helper that decomposes an exploded string and puts the combining
-- characters into order by combining class.
--
-- @param cps table  Exploded string (from utf8_explode)
-- @return table  array of codepoints
-- @return int  first index of the result (always 1)
-- @return int  last index of the result
local function internalToNFD( cps )
	local out = {}
	local normal = require 'ustring/normalization-data'

	-- Replace every codepoint that has an entry in normal.decomp by the
	-- contents of that entry; copy the others unchanged.
	for idx = 1, cps.len do
		local c = cps.codepoints[idx]
		local expansion = normal.decomp[c]
		if not expansion then
			out[#out + 1] = c
		else
			-- Index 0 is visited as well; if the entry has no element 0 the
			-- assignment stores nil and appends nothing.
			for k = 0, #expansion do
				out[#out + 1] = expansion[k]
			end
		end
	end

	-- Now sort combiners by class: swap adjacent codepoints whose classes
	-- are out of order, stepping back one place after a swap so the moved
	-- codepoint is compared with its new predecessor.
	local pos, total = 1, #out
	while pos < total do
		local here = normal.combclass[out[pos]]
		local following = normal.combclass[out[pos + 1]]
		local step = 1
		if here and following and here > following then
			out[pos], out[pos + 1] = out[pos + 1], out[pos]
			if pos > 1 then
				step = -1
			end
		end
		pos = pos + step
	end

	return out, 1, total
end

-- Normalize a string to NFC
--
-- Based on MediaWiki's UtfNormal class. Returns nil if the string is not valid
-- UTF-8.
--
-- The work is done in three stages: a quick scan that returns the input
-- unchanged when no codepoint is flagged in normal.check, decomposition via
-- internalToNFD, and an in-place recomposition pass.
--
-- @param s string
-- @return string|nil
function ustring.toNFC( s )
	checkString( 'toNFC', s )

	-- ASCII is always NFC
	if not S.find( s, '[\128-\255]' ) then
		return s
	end

	local cps = utf8_explode( s )
	if cps == nil then
		return nil
	end
	local normal = require 'ustring/normalization-data'

	-- First, scan through to see if the string is definitely already NFC
	local ok = true
	for i = 1, cps.len do
		local c = cps.codepoints[i]
		if normal.check[c] then
			ok = false
			break
		end
	end
	if ok then
		return s
	end

	-- Next, expand to NFD
	local cp, _, l = internalToNFD( cps )

	-- Then combine to NFC. Since NFD->NFC can never expand a character
	-- sequence, we can do this in-place.
	-- State of the pass:
	--   sc         index in cp of the current starter
	--   comp       composition table of the current starter (normal.comp)
	--   j          index of the last codepoint written to the output
	--   lastclass  combining class of the last combiner that was kept
	--              (0 directly after a starter)
	local comp = normal.comp[cp[1]]
	local sc = 1
	local j = 1
	local lastclass = 0
	for i = 2, l do
		local c = cp[i]
		local ccc = normal.combclass[c]
		if ccc then
			-- Trying a combiner with the starter
			if comp and lastclass < ccc and comp[c] then
				-- Yes!
				-- The composed codepoint replaces the starter, and may
				-- itself combine further
				c = comp[c]
				cp[sc] = c
				comp = normal.comp[c]
			else
				-- No, copy it to the right place for output
				j = j + 1
				cp[j] = c
				lastclass = ccc
			end
		elseif comp and lastclass == 0 and comp[c] then
			-- Combining two adjacent starters
			c = comp[c]
			cp[sc] = c
			comp = normal.comp[c]
		else
			-- New starter, doesn't combine
			j = j + 1
			cp[j] = c
			comp = normal.comp[c]
			sc = j
			lastclass = 0
		end
	end

	-- Only cp[1] .. cp[j] hold output; anything beyond j is leftover input
	return internalChar( cp, 1, j )
end

-- Normalize a string to NFD
--
-- Based on MediaWiki's UtfNormal class. Returns nil if the string is not valid
-- UTF-8.
--
-- @param s string
-- @return string|nil
function ustring.toNFD( s )
	checkString( 'toNFD', s )

	if S.find( s, '[\128-\255]' ) then
		local cps = utf8_explode( s )
		if cps == nil then
			return nil
		end
		-- internalToNFD returns the codepoint array together with the first
		-- and last index to encode
		local cp, first, last = internalToNFD( cps )
		return internalChar( cp, first, last )
	end

	-- Pure ASCII input is returned unchanged
	return s
end

-- Export the module table holding the public ustring functions.
return ustring

Content Disclaimer

Informasi ini disarikan dari Wikipedia dan disajikan kembali untuk tujuan edukasi. Konten tersedia di bawah lisensi CC BY-SA 3.0. Kami tidak bertanggung jawab atas ketidakakuratan data yang bersumber dari kontribusi publik tersebut.

The information displayed on this website is sourced in part or in whole from Wikipedia and has been adapted and restated here. We strive to provide accurate and relevant information; however:

  1. There is no guarantee of absolute accuracy. Wikipedia is an open, collaborative project that can be edited by anyone, so information is subject to change.
  2. It is not intended to constitute professional advice. The content displayed is for informational and educational purposes only. For important decisions (e.g., medical, legal, or financial), please consult a professional.
  3. Content copyright. Wikipedia is licensed under the Creative Commons Attribution-ShareAlike License (CC BY-SA). This means that content may be reused with appropriate attribution and shared under a similar license.
  4. Responsible use. Any risk arising from the use of information from this website is entirely the responsibility of the user.