需求:

1.提供一个供用户使用的网页,该网页支持用户输入程序,然后用户可以运行程序,然后得到返回的结果。

2.该编译网站支持多个用户提交请求。

思路

我们需要来考虑一下java编译网站做了什么事情。

用户在网页的栏目中编程,那么用户实质上提交的是 字符串,我们设计的系统需要对字符串进行处理,然后返回对应的结果。

首先明确一点,我们设计的是java在线编译系统,因此,至少要完成字符串到可运行程序的这个过程。

以我们在本地运行java程序进行类比,那么就相当于是我们编辑了一个.java的文件,而最终要将该文件以程序的方式运行起来。

回忆java从文件到执行文件的过程,第一步我们需要将.java这个文件变成.class文件。

实现.java—–.class的转变,一个直接的思路是调用javac命令。可以在java程序中调用javac命令吗?答案是肯定的,但是我们应该考虑到,利用javac固然可以将.java文件转换为.class文件,但是这样一次编译过程,就会产生一个class文件,我们还需要去做class文件的删除工作。而有时候可能会出现多个用户同时执行请求,这个时候就可能会出现问题;另外,通过生成文件是需要io操作的,该操作的消耗较大。因此,我们应当采用java动态编译技术,直接在内存中将源代码编译为字节码的字节数组。

当我们获得了字节码的字节数组之后,我们考虑如何将字节码变成可执行的程序。

考虑真实的情况,当我们启动java虚拟机执行程序的时候,首先会进入到java.c的main方法,做一些创建执行环境,初始化java虚拟机的事情。java虚拟机创建好之后,就开始要执行我们的用户程序了。我们直到,程序的入口在main函数,所以我们首先得找到main函数。而我们知道,想要找到main函数,前提是main函数所属的类已经存在了。因此,在获得字节码之后,我们应当通过类加载器将字节码加载为class对象。一旦将类加载进虚拟机之后,我们就可以利用反射机制来运行该类的main方法了。

因为我们主要通过输入输出来显示程序运行的结果,因此程序必然要调用system类的输入输出方法。但是,将system类的控制权交给用户实在是太危险了,因此我们希望自己可以定义一个新的system类,替换掉系统的system类,限制用户可以调用的方法。

同时,还要考虑多个用户都要执行程序的情况,因此还要考虑多线程问题,即我们需要将HackSystem变成一个线程安全的类。

因此,我们可以给出我们设计的总体框架如下:

image-20220907210229045

动态编译

参考资料

java6之后提供了一套compiler的api。用户通过

1
JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); 

获得compiler。利用该compiler的getTask方法,会返回一个CompliationTask接口,该接口继承了Callable接口,而Callable接口调用call即可执行编译的任务。

1
2
Boolean result = compiler.getTask(null, javaFileManager, compileCollector,
null, null, Arrays.asList(sourceJavaFileObject)).call();

接下来,让我们看看call方法到底是如何执行编译任务的:

1.首先,会调用JavaFileObject的getCharContent方法,得到需要编译的对象CharSequence。

2.利用自己编写的JavaFileManager 的getJavaFileForOutput方法,将编译生成的字节码放到我们在该方法中new出来的自定义TmpJavaFileObject对象中。为了存放字节码数组,我们在自定义TmpJavaFileObject中加入一个ByteArrayOutputStream 属性用于存储字节码,编译器会通过openOutputStream来创建输出流对象,并把这个用来存储字节的容器返回编译器,让其把编译生成的字节码放进去

3.最后,我们需要返回一个byte字节数组,因此还需要在TmpJavaFIleObject加入一个getCompileBytes方法将ByteArrayOutputStream 中的内容变成 byte[] 返回。

所以,我们实现的javafileobject类如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
public static class TmpJavaFileObject extends SimpleJavaFileObject {
private String source;
private ByteArrayOutputStream outputStream;

/**
* 构造用来存储源代码的JavaFileObject
* 需要传入源码source,然后调用父类的构造方法创建kind = Kind.SOURCE的JavaFileObject对象
*/
public TmpJavaFileObject(String name, String source) {
super(URI.create("String:///" + name + Kind.SOURCE.extension), Kind.SOURCE);
this.source = source;
}

/**
* 构造用来存储字节码的JavaFileObject
* 需要传入kind,即我们想要构建一个存储什么类型文件的JavaFileObject
*/
public TmpJavaFileObject(String name, Kind kind) {
super(URI.create("String:///" + name + Kind.SOURCE.extension), kind);
this.source = null;
}

@Override
public CharSequence getCharContent(boolean ignoreEncodingErrors) throws IOException {
if (source == null) {
throw new IllegalArgumentException("source == null");
}
return source;
}

@Override
public OutputStream openOutputStream() throws IOException {
outputStream = new ByteArrayOutputStream();
return outputStream;//获取输出流对象,编译器将输出的字节码放入该输出流对象中
}

public byte[] getCompiledBytes() {
return outputStream.toByteArray();
}
}

而我们的javaFileManager则继承ForwardingJavaFileManager作出如下的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
public static class TmpJavaFileManager extends ForwardingJavaFileManager<JavaFileManager> {
protected TmpJavaFileManager(JavaFileManager fileManager) {
super(fileManager);
}

@Override
public JavaFileObject getJavaFileForInput(JavaFileManager.Location location,
String className,
JavaFileObject.Kind kind) throws IOException {
JavaFileObject javaFileObject = fileObjectMap.get(className);
if (javaFileObject == null) {
return super.getJavaFileForInput(location, className, kind);
}
return javaFileObject;
}

@Override
public JavaFileObject getJavaFileForOutput(JavaFileManager.Location location,
String className,
JavaFileObject.Kind kind,
FileObject sibling) throws IOException {
JavaFileObject javaFileObject = new TmpJavaFileObject(className, kind);
fileObjectMap.put(className, javaFileObject);
return javaFileObject;
}
}

实现编译器如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
public class StringSourceCompiler {
private static Map<String, JavaFileObject> fileObjectMap = new ConcurrentHashMap<>();

public static byte[] compile(String source) {
JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
DiagnosticCollector<JavaFileObject> collector = new DiagnosticCollector<>();
JavaFileManager javaFileManager =
new TmpJavaFileManager(compiler.getStandardFileManager(collector, null, null));

// 从源码字符串中匹配类名
Pattern CLASS_PATTERN = Pattern.compile("class\\s+([$_a-zA-Z][$_a-zA-Z0-9]*)\\s*");
Matcher matcher = CLASS_PATTERN.matcher(source);
String className;
if (matcher.find()) {
className = matcher.group(1);
} else {
throw new IllegalArgumentException("No valid class");
}

// 把源码字符串构造成JavaFileObject,供编译使用
JavaFileObject sourceJavaFileObject = new TmpJavaFileObject(className, source);

Boolean result = compiler.getTask(null, javaFileManager, collector,
null, null, Arrays.asList(sourceJavaFileObject)).call();

JavaFileObject bytesJavaFileObject = fileObjectMap.get(className);
if (result && bytesJavaFileObject != null) {
return ((TmpJavaFileObject) bytesJavaFileObject).getCompiledBytes();
}
return null;
}

/**
* 管理JavaFileObject对象的工具
*/
public static class TmpJavaFileManager extends ForwardingJavaFileManager<JavaFileManager> {
// ...
}

/**
* 用来封装表示源码与字节码的对象
*/
public static class TmpJavaFileObject extends SimpleJavaFileObject {
// ...
}
}

上述过程几个类之间的关系可以用下图描述:

未命名文件

执行程序

执行程序首先要找到用户的main方法,因此我们首先需要利用类加载器将用户定义的类加载到jvm中。

但是,考虑到我们的需求,当用户多次提交自己的运行代码的时候,如果类的名字没有改变,如果都是用系统类加载器进行加载的话,那么应用程序的类加载器会认为该类已经加载过了,就不会再加载该类,除非重启服务器,否则我们无法执行用户提交的新代码。

我们知道,两个类相等需要满足以下 3 个条件:

  • 同一个 .class 文件;
  • 被同一个虚拟机加载;
  • 被同一个类加载器加载;

我们希望破坏上面的3个条件的一个条件,如果可以破话,那么用户新提交的请求就可以被重新编译执行。

首先是第一个条件,我们可以想办法让用户提交的class类名不同,一个可能的思路是直接更改字符串中class的名字,比如在原始的类名上根据用户提交的时间添加上一个数字。但是系统可能同时有多个用户提交,所以这个思路可能还是会出现冲突。

第二个条件难以破坏,每次编译都要另起一个虚拟机实在是代价太大。

第三个条件是我们可以尝试破坏的,如果每次编译得到字节码,我们都用自己new 出来的类加载器进行加载,那么每次用户提交执行,我们就可以用自己的类加载器加载,就jvm就可以编译新生成的字节码了。

那么,如何用定义自己的类加载器呢?

Java提供了抽象类java.lang.ClassLoader,所有用户自定义的类加载器都应该继承ClassLoader类。

在定义ClassLoader的子类的时候,我们会看见两种做法:

  • 方式一:重写loadClass()方法

  • 方式二:重写findclass()方法

其中,loadClass方法的定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
protected synchronized Class<?> loadClass(String name, boolean resolve)
throws ClassNotFoundException
{
// First, check if the class has already been loaded
Class c = findLoadedClass(name);
if (c == null) {
try {
if (parent != null) {
c = parent.loadClass(name, false);
} else {
c = findBootstrapClass0(name);
}
} catch (ClassNotFoundException e) {
// If still not found, then invoke findClass in order
// to find the class.
c = findClass(name);
}
}
if (resolve) {
resolveClass(c);
}
return c;
}

首先,利用findLoadedClass 查看class是否被定义,如果没有被定义,那么就利用 该类加载器的父类加载器来加载对应的类。 可以看到,会一直找到启动类加载器。如果直到启动类加载器都没有找到对应的类,那么就会调用findClass函数。而findClass函数才是真正实现类加载的函数。

可以看到,相比于findClass,实际上loadClass方法主要是增加了双亲委托机制的实现。在我们的设计中,不需要破话双亲委托机制,因此只需要重写findClass方法即可。

不过,我们要看到,findClass传入的参数是name,但是实际上我们已经生成了字节码数组,所以,如果能用字节码数组字节直接转换为JVM可以识别的Class对象就好了。这里我们可以用defineClass来实现该功能,defineClass() 方法可以将byte字节流解析成JVM可以识别的Class对象。该方法是受保护的方法,只能在自定义的ClassLoader子类可以使用。我们可以对该方法做一个封装如下:

1
2
3
4
5
6
7
8
9
10
public class HotSwapClassLoader extends ClassLoader {
public HotSwapClassLoader() {
super(HotSwapClassLoader.class.getClassLoader());
}

public Class loadByte(byte[] classBytes) {
return defineClass(null, classBytes, 0, classBytes.length);
}
}

定义了自己的类加载器,我们接下来就可以用下面的方法无数次加载客户端要运行的类了:

1
2
HotSwapClassLoader classLoader = new HotSwapClassLoader();
Class clazz = classLoader.loadByte(modifyBytes);

获取到了我们的Class实例之后,我们就可以用反射方法调用用户运行类的main方法了:

1
2
Method mainMethod = clazz.getMethod("main", new Class[] { String[].class });
mainMethod.invoke(null, new String[] { null });

多用户请求设计

考虑多个用户发起请求的情况。对于一个用户的提交运行请求,我们实际上要做的事情就分为上面描述的两个部分,一个是进行动态编译,另外一个就是根据动态编译后的字节码执行用户提交的程序。

对于第一个过程,各个用户提交的代码编译时间应该是大致相同的,但是对于第二个过程,执行用户提交代码的时间可能就有所差异。因此,我们要做两件事情。一件事情是对于用户代码执行时间过长的,我们要及时终止。令一方面,我们应当利用多线程来执行用户程序。因为不同用户的程序运行时间是不同的,如果某用户程序运行时间较短,但需要串行等待运行时间较长的用户的程序运行完毕,显然是不合理的,因此我们需要设计多线程机制来执行用户的请求。

这里还要解决的一个问题是,我们主要是通过标准输出打印将用户运行的结果显示出来。但是标准输出是虚拟机全局共享的资源,如果用户可以访问到System资源,可能就出现安全问题(比如调用exit方法)。同时,由于我们需要多线程来执行程序,我们最好是实现一个线程安全的类来替代jvm的system系统。

那么,我们如何实现这个替换的过程呢?一个思路是,直接对用户发来的源代码字符串进行修改替换,另外一个思路是,直接在字节码中,将执行的类对System的符号引用替换为我们准备的HackSystem的符号引用。

所以,我们要知道System的符号引用到底在哪里,并且想办法利用程序将该符号引用找出来。

我们知道,class文件具有一定的格式,其中头8个字节是魔数和版本号,而第9个字节开始,就是常量池的入口。常量池入口的前两个字节,放置了一个u2类型的数据,标识了常量池中常量的数量。

常量池的每一项常量都通过一个表来存储。目前共有14中常量,其中CONSTANT_Utf8_info 类型常量一般用来描述类的全限定名、方法名和字段名。我们只需要修改值为java/lang/System 的 CONSTANT_Utf8_info 的常量,变成我们类的全限定名,就可以在运行时调用我们自己的system系统。

我们实现的思路是,首先取出常量池中常量的个数CPC,然后遍历常量,检查tag=1的CONSTANT_Utf8_info 常量。找到存储的常量值为java/lang/System的常量,将其替换成我们自己写的org/olexec/execute/HackSystem;

下面是我们实现替换的方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
public byte[] modifyUTF8Constant(String oldStr, String newStr) {
int cpc = getConstantPoolCount();//因为位置是固定的,因此就是将固定位置的byte转换为int而已
int offset = CONSTANT_POOL_COUNT_INDEX + u2; // 真实的常量起始位置,8字节魔数版本号,1一个u2类型
for (int i = 1; i < cpc; i++) {//注意从1开始,0空出来用来表示索引不引用任何常量池项目
int tag = ByteUtils.byte2Int(classByte, offset, u1);
if (tag == CONSTANT_UTF8_INFO) {
int len = ByteUtils.byte2Int(classByte, offset + u1, u2);
offset += u1 + u2;
String str = ByteUtils.byte2String(classByte, offset, len);
if (str.equals(oldStr)) {
byte[] strReplaceBytes = ByteUtils.string2Byte(newStr);
byte[] intReplaceBytes = ByteUtils.int2Byte(strReplaceBytes.length, u2);
// 替换新的字符串的长度
classByte = ByteUtils.byteReplace(classByte, offset - u2, u2, intReplaceBytes);
// 替换字符串本身
classByte = ByteUtils.byteReplace(classByte, offset, len, strReplaceBytes);
return classByte; // 就一个地方需要改,改完就可以返回了
} else {
offset += len;//如果没有对应字符串,继续向前查找
}
} else {
offset += CONSTANT_ITEM_LENGTH[tag];//根据tag继续往前查找
}
}
return classByte;
}

自定义hacksystem的设计

再一次厘清我们的需求:我们设计的系统,主要是希望利用System的打印函数来将字符串显示出来,在我的博文《当我们在System.out.println()的时候我们在做什么?》里面已经描述了我们调用程序向控制台输出的全过程。我们设计的hackSystem显然不是向控制台打印结果,而是返回一串结果字符串交由网页显示。这也就意味着,我们不必如同System函数一样,需要调用一个初始化方法,将静态变量out所代表的输出流指向控制台,而只需要设计对每一个线程,维护一个属于该线程的输出流,在需要获取结果的时候,将输出流转换为字符串并返回即可。

我们可以仿照System重写我们的hackSystem的大部分内容。但是要注意,因为我们不会像System那样调用一个初始化函数,所以我们需要在开始定义的时候,就要将我们的out静态变量指向我们定义的HackPrintStream对象。同时,我们需要新增一个返回String类型的方法,即将我们当前的输出对象转换为字符串交由网页处理:

1
2
3
public static String getBufferString() {
return out.toString();
}

对于System类一些和系统有关的方法,我们应当重写禁止用户调用;而对于System的一些工具类写法,直接在方法内部调用System的方法即可。

自定义HackPrintStream的设计

System.out.println实际上调用的是PrintStream.println()方法。因此,我们自定义的HackPrintStream应当继承PrintStream并重写PrintStream的公共方法。但是我们的HackPrintStream是支持多线程的。因此,与PrintStream不同,我们的HackPrintStream需要为每一个线程维护一个输出流。因此,我们想到了ThreadLocal。

所以,我们在HackPrintStream中添加如下的字段:

1
2
3
private ThreadLocal<ByteArrayOutputStream> out;
private ThreadLocal<Boolean> trouble;

让我们看一下PrintStream的println方法是如何实现的:

1
2
3
4
5
6
public void println(int x) {
synchronized (this) {
print(x);
newLine();
}
}

首先,这是一个加锁的方法。然后这里调用print代码如下:

1
2
3
public void print(int i) {
write(String.valueOf(i));
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
private void write(String s) {
try {
synchronized (this) {
ensureOpen();
textOut.write(s);
textOut.flushBuffer();
charOut.flushBuffer();
if (autoFlush && (s.indexOf('\n') >= 0))
out.flush();
}
}
catch (InterruptedIOException x) {
Thread.currentThread().interrupt();
}
catch (IOException x) {
trouble = true;
}
}

write同样是一个加锁的方法,首先确保当前有一个输出流(即输出流没有关闭)

1
2
3
4
private void ensureOpen() throws IOException {
if (out == null)
throw new IOException("Stream closed");
}

然后就是往输出流里写入内容。

而对于我们实现的HackPrintStream,我们的print方法是一样,但是这里需要对write方法做不同的实现:

1
2
3
public void print(int i) {
write(String.valueOf(i));
}

因为我们这里用了ThreadLocal来为每一个线程分配一个输出流,因此我们不对write进行加锁,而是直接获取对应的输出流,然后写入内容:

1
2
3
4
5
6
7
8
9
10
11
12
private void write(String s) {
try {
ensureOpen();
out.get().write(s.getBytes());
}
catch (InterruptedIOException x) {
Thread.currentThread().interrupt();
}
catch (IOException x) {
trouble.set(true);
}
}

由于对于每一个线程,我们都需要放入一个新的输入流,因此我们的ensureOpen函数也要重写如下所示:

1
2
3
4
5
private void ensureOpen() throws IOException {
if (out.get() == null) {
out.set(new ByteArrayOutputStream());
}
}

即如果发现当前的线程对应的ThreadLocalMap中没有输出的Stream的话,不是抛出异常,而是通过set方法向map中放置一个输出流作为value。